You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/train_forward_ensemble.py

135 lines
5.8 KiB
Python

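"""Train an ensemble of forward surrogate models, one per random seed.

The script repeatedly launches `train_forward` on the same processed dataset with the
same training hyper-parameters; each member uses its own seed and its own output
directory. The resulting set of models is used downstream for uncertainty estimation,
error-risk analysis, and fallback screening.
"""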
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
# Project root = the parent of the directory holding this script; it is put on sys.path
# so the ``src.*`` imports below resolve when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
# Prepend instead of append: ``src`` is a generic top-level name, and any other ``src``
# package already on sys.path would otherwise shadow this project's modules.
sys.path.insert(0, str(ROOT))
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
from src.training.train_forward import TrainConfig, train_forward
def parse_seed_list(seed_text: str) -> list[int]:
    """Parse a comma-separated list of random seeds for training or evaluating an ensemble.

    Whitespace around each entry is stripped and blank entries (e.g. from trailing or
    doubled commas) are ignored. Duplicate seeds are dropped, keeping the first
    occurrence. The ensemble trainer maps each seed to its own ``seed_<n>`` directory, so
    a repeated seed would retrain the same directory and list that model twice.

    Args:
        seed_text: Text such as ``"41, 42,43"``; non-string values are passed through ``str``.

    Returns:
        The unique seeds, in the order they first appear.

    Raises:
        ValueError: If an entry is not an integer, or if no seed is given.
    """
    parsed = [int(item.strip()) for item in str(seed_text).split(",") if item.strip()]
    # dict.fromkeys de-duplicates while preserving first-seen order.
    seeds = list(dict.fromkeys(parsed))
    if not seeds:
        raise ValueError("至少需要一个 seed")
    return seeds
def default_output_root(tag: str | None, use_schedule: bool) -> Path:
"""根据实验标签生成集成训练的默认模型输出根目录。"""
suffix = "" if use_schedule else "_no_schedule"
if tag:
return Path("models") / f"forward_surrogate_{tag}_ensemble{suffix}"
return Path("models") / f"forward_surrogate_ensemble{suffix}"
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser: dataset/output locations, seed list, and TrainConfig hyper-parameters."""
    parser = argparse.ArgumentParser(description="Train a deep-ensemble forward surrogate for UQ")
    parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
    parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming")
    parser.add_argument("--output-root", type=str, default=None, help="Optional ensemble root directory")
    parser.add_argument("--seeds", type=str, default="41,42,43", help="Comma-separated seed list")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=220)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=0.0005)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--w-pressure", type=float, default=1.0)
    parser.add_argument("--w-derivative", type=float, default=2.0)
    parser.add_argument("--w-slope", type=float, default=0.0)
    parser.add_argument("--w-bias-pressure", type=float, default=0.15)
    parser.add_argument("--w-bias-derivative", type=float, default=0.05)
    parser.add_argument("--w-derivative-shape", type=float, default=0.10)
    parser.add_argument("--w-autofit-pressure", type=float, default=0.0)
    parser.add_argument("--w-autofit-derivative", type=float, default=0.0)
    parser.add_argument("--huber-beta", type=float, default=0.05)
    # Sample reweighting is on by default. A store_true flag whose default is already True
    # can never switch it off, so ``--no-sample-reweight`` is added. Both flags write the
    # same ``use_sample_reweight`` destination, and the existing flag keeps working.
    parser.add_argument("--use-sample-reweight", dest="use_sample_reweight", action="store_true", default=True)
    parser.add_argument(
        "--no-sample-reweight",
        dest="use_sample_reweight",
        action="store_false",
        help="Disable sample reweighting (enabled by default)",
    )
    parser.add_argument("--sample-reweight-alpha", type=float, default=0.4)
    parser.add_argument("--sample-weight-min", type=float, default=1.0)
    parser.add_argument("--sample-weight-max", type=float, default=2.5)
    parser.add_argument("--no-schedule", action="store_true")
    return parser


def _write_manifest(output_root: Path, manifest: dict) -> Path:
    """Write ``manifest`` as UTF-8 JSON to ``output_root/ensemble_manifest.json``; return that path.

    The JSON goes to a temporary sibling file that is then renamed over the target. On
    filesystems with atomic rename, readers never observe a half-written manifest.
    """
    manifest_path = output_root / "ensemble_manifest.json"
    tmp_path = manifest_path.with_name(manifest_path.name + ".tmp")
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)
    tmp_path.replace(manifest_path)
    return manifest_path


def main() -> None:
    """Train the forward surrogate once per seed and write an ensemble manifest.

    All members share the same processed data and hyper-parameters. Only the seed and the
    output directory (``<output_root>/seed_<seed>``) differ. The manifest is rewritten
    after each member finishes, so members already trained stay recorded if a later
    member fails.
    """
    args = _build_arg_parser().parse_args()
    tag = normalize_tag(args.tag)
    use_schedule = not args.no_schedule
    seeds = parse_seed_list(args.seeds)
    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    output_root = Path(args.output_root) if args.output_root is not None else default_output_root(tag, use_schedule)
    output_root.mkdir(parents=True, exist_ok=True)

    # The manifest records each member's seed, checkpoint and metrics paths so the
    # uncertainty-evaluation scripts can load the whole ensemble in one batch.
    manifest = {
        "tag": tag,
        "processed_path": str(processed_path),
        "use_schedule": use_schedule,
        "seeds": seeds,
        "members": [],
    }

    for seed in seeds:
        member_dir = output_root / f"seed_{seed}"
        print(f"\n=== Training ensemble member seed={seed} -> {member_dir} ===")
        # Apart from the seed and output directory, every member uses identical data and
        # loss weights, which is what makes this a deep ensemble.
        cfg = TrainConfig(
            processed_path=processed_path,
            output_dir=member_dir,
            seed=seed,
            batch_size=args.batch_size,
            epochs=args.epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            hidden_dim=args.hidden_dim,
            dropout=args.dropout,
            w_pressure=args.w_pressure,
            w_derivative=args.w_derivative,
            w_slope=args.w_slope,
            w_bias_pressure=args.w_bias_pressure,
            w_bias_derivative=args.w_bias_derivative,
            w_derivative_shape=args.w_derivative_shape,
            w_autofit_pressure=args.w_autofit_pressure,
            w_autofit_derivative=args.w_autofit_derivative,
            huber_beta=args.huber_beta,
            use_sample_reweight=args.use_sample_reweight,
            sample_reweight_alpha=args.sample_reweight_alpha,
            sample_weight_min=args.sample_weight_min,
            sample_weight_max=args.sample_weight_max,
            use_schedule=use_schedule,
        )
        train_forward(cfg)
        # NOTE(review): these file names are assumed to match what train_forward writes
        # into output_dir; this script does not verify that.
        manifest["members"].append(
            {
                "seed": seed,
                "dir": str(member_dir),
                "checkpoint": str(member_dir / "forward_surrogate_best.pt"),
                "metrics": str(member_dir / "metrics.json"),
            }
        )
        # Persist progress so a failure on a later seed keeps the finished members listed.
        _write_manifest(output_root, manifest)

    # Final write (same content) also yields the manifest path for the summary message.
    manifest_path = _write_manifest(output_root, manifest)
    print("\nEnsemble training complete.")
    print(f"Manifest saved: {manifest_path}")
if __name__ == "__main__":
    # Start ensemble training only when executed as a script, not when imported.
    main()