You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/train_forward_ensemble.py

135 lines
5.8 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""训练多个随机种子的正演代理模型集成。
脚本基于同一份 processed 数据和训练超参数循环启动多次 `train_forward`,每个成员使用
独立 seed 和输出目录。所得模型集合用于后续不确定性估计、误差风险分析和 fallback 筛选。
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
# ROOT is the directory one level above this script's own directory (the
# second ancestor of the resolved file path). It is appended to sys.path
# before the `src.*` imports below so they resolve when this file is run
# directly as a script.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
from src.training.train_forward import TrainConfig, train_forward
def parse_seed_list(seed_text: str) -> list[int]:
    """Parse a comma-separated string of random seeds into a list of ints.

    Whitespace around each item is ignored and empty items (for example the
    one produced by a trailing comma) are skipped. The input is passed through
    ``str()`` first, so a bare integer is accepted as well.

    Args:
        seed_text: Text such as ``"41,42,43"``.

    Returns:
        The seeds in the order they appear in ``seed_text``.

    Raises:
        ValueError: If an item is not a valid integer literal, or if no seed
            remains once empty items have been dropped.
    """
    tokens = (chunk.strip() for chunk in str(seed_text).split(","))
    parsed = [int(token) for token in tokens if token]
    if parsed:
        return parsed
    raise ValueError("至少需要一个 seed")
def default_output_root(tag: str | None, use_schedule: bool) -> Path:
    """Return the default root directory for an ensemble's model outputs.

    The directory lives under ``models/``. Its name embeds ``tag`` when the
    tag is truthy and ends with ``_no_schedule`` when ``use_schedule`` is
    false.

    Args:
        tag: Experiment tag; ``None`` or an empty string means "no tag".
        use_schedule: Whether the training schedule is enabled.

    Returns:
        A relative path of the form
        ``models/forward_surrogate[_<tag>]_ensemble[_no_schedule]``.
    """
    dir_name = "forward_surrogate"
    if tag:
        dir_name += f"_{tag}"
    dir_name += "_ensemble"
    if not use_schedule:
        dir_name += "_no_schedule"
    return Path("models") / dir_name
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for ensemble training."""
    parser = argparse.ArgumentParser(description="Train a deep-ensemble forward surrogate for UQ")
    parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
    parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming")
    parser.add_argument("--output-root", type=str, default=None, help="Optional ensemble root directory")
    parser.add_argument("--seeds", type=str, default="41,42,43", help="Comma-separated seed list")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=220)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=0.0005)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--w-pressure", type=float, default=1.0)
    parser.add_argument("--w-derivative", type=float, default=2.0)
    parser.add_argument("--w-slope", type=float, default=0.0)
    parser.add_argument("--w-bias-pressure", type=float, default=0.15)
    parser.add_argument("--w-bias-derivative", type=float, default=0.05)
    parser.add_argument("--w-derivative-shape", type=float, default=0.10)
    parser.add_argument("--w-autofit-pressure", type=float, default=0.0)
    parser.add_argument("--w-autofit-derivative", type=float, default=0.0)
    parser.add_argument("--huber-beta", type=float, default=0.05)
    # Sample reweighting is on by default. A lone `store_true` flag with
    # default=True could never be switched off, so it is paired with an
    # explicit opt-out that writes to the same destination. The original
    # flag is kept so existing command lines still parse.
    parser.add_argument(
        "--use-sample-reweight",
        dest="use_sample_reweight",
        action="store_true",
        default=True,
        help="Enable sample reweighting (default: enabled)",
    )
    parser.add_argument(
        "--no-use-sample-reweight",
        dest="use_sample_reweight",
        action="store_false",
        help="Disable sample reweighting",
    )
    parser.add_argument("--sample-reweight-alpha", type=float, default=0.4)
    parser.add_argument("--sample-weight-min", type=float, default=1.0)
    parser.add_argument("--sample-weight-max", type=float, default=2.5)
    parser.add_argument("--no-schedule", action="store_true")
    return parser


def main() -> None:
    """Train one forward surrogate per random seed and write an ensemble manifest.

    Command-line arguments supply the dataset, the seeds and the training
    hyperparameters. For each seed a ``TrainConfig`` is built and passed to
    ``train_forward``, with ``<output_root>/seed_<seed>`` as the member's
    output directory. After all members have been trained,
    ``<output_root>/ensemble_manifest.json`` is written, listing every
    member's seed, directory, checkpoint path and metrics path.

    Raises:
        ValueError: If ``--seeds`` contains no seed or a non-integer item.
    """
    args = _build_arg_parser().parse_args()

    tag = normalize_tag(args.tag)
    use_schedule = not args.no_schedule
    seeds = parse_seed_list(args.seeds)

    # Explicit CLI paths win; otherwise both locations are derived from the tag.
    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    output_root = Path(args.output_root) if args.output_root is not None else default_output_root(tag, use_schedule)
    output_root.mkdir(parents=True, exist_ok=True)

    # The manifest records each member's seed, checkpoint and metrics paths so
    # that downstream uncertainty-evaluation scripts can load them in bulk.
    manifest = {
        "tag": tag,
        "processed_path": str(processed_path),
        "use_schedule": use_schedule,
        "seeds": seeds,
        "members": [],
    }

    for seed in seeds:
        member_dir = output_root / f"seed_{seed}"
        print(f"\n=== Training ensemble member seed={seed} -> {member_dir} ===")
        # Apart from the random seed and the output directory, every member
        # uses the same data and loss weights, which forms the deep ensemble.
        cfg = TrainConfig(
            processed_path=processed_path,
            output_dir=member_dir,
            seed=seed,
            batch_size=args.batch_size,
            epochs=args.epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            hidden_dim=args.hidden_dim,
            dropout=args.dropout,
            w_pressure=args.w_pressure,
            w_derivative=args.w_derivative,
            w_slope=args.w_slope,
            w_bias_pressure=args.w_bias_pressure,
            w_bias_derivative=args.w_bias_derivative,
            w_derivative_shape=args.w_derivative_shape,
            w_autofit_pressure=args.w_autofit_pressure,
            w_autofit_derivative=args.w_autofit_derivative,
            huber_beta=args.huber_beta,
            use_sample_reweight=args.use_sample_reweight,
            sample_reweight_alpha=args.sample_reweight_alpha,
            sample_weight_min=args.sample_weight_min,
            sample_weight_max=args.sample_weight_max,
            use_schedule=use_schedule,
        )
        train_forward(cfg)
        # NOTE(review): these file names are assumed to match what
        # train_forward writes into member_dir — confirm against train_forward.
        manifest["members"].append(
            {
                "seed": seed,
                "dir": str(member_dir),
                "checkpoint": str(member_dir / "forward_surrogate_best.pt"),
                "metrics": str(member_dir / "metrics.json"),
            }
        )

    manifest_path = output_root / "ensemble_manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)

    print("\nEnsemble training complete.")
    print(f"Manifest saved: {manifest_path}")
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()