from __future__ import annotations import argparse import json import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import normalize_tag, processed_path_for_tag from src.training.train_forward import TrainConfig, train_forward def parse_seed_list(seed_text: str) -> list[int]: seeds = [] for item in str(seed_text).split(","): item = item.strip() if not item: continue seeds.append(int(item)) if not seeds: raise ValueError("至少需要一个 seed") return seeds def default_output_root(tag: str | None, use_schedule: bool) -> Path: suffix = "" if use_schedule else "_no_schedule" if tag: return Path("models") / f"forward_surrogate_{tag}_ensemble{suffix}" return Path("models") / f"forward_surrogate_ensemble{suffix}" def main() -> None: parser = argparse.ArgumentParser(description="Train a deep-ensemble forward surrogate for UQ") parser.add_argument("--processed", type=str, default=None, help="Processed dataset path") parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming") parser.add_argument("--output-root", type=str, default=None, help="Optional ensemble root directory") parser.add_argument("--seeds", type=str, default="41,42,43", help="Comma-separated seed list") parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--epochs", type=int, default=220) parser.add_argument("--lr", type=float, default=0.001) parser.add_argument("--weight-decay", type=float, default=0.0005) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--w-pressure", type=float, default=1.0) parser.add_argument("--w-derivative", type=float, default=2.0) parser.add_argument("--w-slope", type=float, default=0.0) parser.add_argument("--w-bias-pressure", type=float, default=0.15) parser.add_argument("--w-bias-derivative", type=float, default=0.05) parser.add_argument("--w-derivative-shape", type=float, default=0.10) parser.add_argument("--w-autofit-pressure", type=float, default=0.0) parser.add_argument("--w-autofit-derivative", type=float, default=0.0) parser.add_argument("--huber-beta", type=float, default=0.05) parser.add_argument("--use-sample-reweight", action="store_true", default=True) parser.add_argument("--sample-reweight-alpha", type=float, default=0.4) parser.add_argument("--sample-weight-min", type=float, default=1.0) parser.add_argument("--sample-weight-max", type=float, default=2.5) parser.add_argument("--no-schedule", action="store_true") args = parser.parse_args() tag = normalize_tag(args.tag) use_schedule = not args.no_schedule seeds = parse_seed_list(args.seeds) processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) output_root = Path(args.output_root) if args.output_root is not None else default_output_root(tag, use_schedule) output_root.mkdir(parents=True, exist_ok=True) manifest = { "tag": tag, "processed_path": str(processed_path), "use_schedule": use_schedule, "seeds": seeds, "members": [], } for seed in seeds: member_dir = output_root / f"seed_{seed}" print(f"\n=== Training ensemble member seed={seed} -> {member_dir} ===") cfg = TrainConfig( processed_path=processed_path, output_dir=member_dir, seed=seed, batch_size=args.batch_size, epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay, hidden_dim=args.hidden_dim, dropout=args.dropout, w_pressure=args.w_pressure, w_derivative=args.w_derivative, w_slope=args.w_slope, w_bias_pressure=args.w_bias_pressure, w_bias_derivative=args.w_bias_derivative, w_derivative_shape=args.w_derivative_shape, w_autofit_pressure=args.w_autofit_pressure, w_autofit_derivative=args.w_autofit_derivative, huber_beta=args.huber_beta, use_sample_reweight=args.use_sample_reweight, sample_reweight_alpha=args.sample_reweight_alpha, sample_weight_min=args.sample_weight_min, sample_weight_max=args.sample_weight_max, use_schedule=use_schedule, ) train_forward(cfg) manifest["members"].append( { "seed": seed, "dir": str(member_dir), "checkpoint": str(member_dir / "forward_surrogate_best.pt"), "metrics": str(member_dir / "metrics.json"), } ) with open(output_root / "ensemble_manifest.json", "w", encoding="utf-8") as f: json.dump(manifest, f, ensure_ascii=False, indent=2) print("\nEnsemble training complete.") print(f"Manifest saved: {output_root / 'ensemble_manifest.json'}") if __name__ == "__main__": main()