from __future__ import annotations import argparse import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import normalize_tag, processed_path_for_tag from src.training.train_time_conditioned import TimeConditionedTrainConfig, train_time_conditioned def main() -> None: parser = argparse.ArgumentParser(description="Train a time-conditioned point-wise forward surrogate") parser.add_argument("--processed", type=str, default=None) parser.add_argument("--tag", type=str, default=None) parser.add_argument("--output-dir", type=str, default=None) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--batch-size", type=int, default=4096) parser.add_argument("--epochs", type=int, default=120) parser.add_argument("--lr", type=float, default=1.0e-3) parser.add_argument("--weight-decay", type=float, default=1.0e-4) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--n-blocks", type=int, default=4) parser.add_argument("--dropout", type=float, default=0.05) parser.add_argument("--w-pressure", type=float, default=1.0) parser.add_argument("--w-derivative", type=float, default=2.0) parser.add_argument("--huber-beta", type=float, default=0.05) parser.add_argument("--no-schedule", action="store_true") parser.add_argument( "--sample-weight-mode", choices=["none", "pso_domain", "pso_domain_risk"], default="none", help="Optional sample weighting; default keeps the original unweighted training behavior", ) parser.add_argument("--pso-outside-weight", type=float, default=0.5) parser.add_argument("--pso-inside-weight", type=float, default=1.0) parser.add_argument("--risk-weight", type=float, default=2.5) parser.add_argument("--skin-lt-minus8-weight", type=float, default=3.5) parser.add_argument("--sample-weight-min", type=float, default=0.25) parser.add_argument("--sample-weight-max", type=float, default=4.0) args = parser.parse_args() tag = normalize_tag(args.tag) processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) if args.output_dir is not None: output_dir = Path(args.output_dir) elif tag: output_dir = Path("models") / f"time_conditioned_surrogate_{tag}" else: output_dir = Path("models") / "time_conditioned_surrogate" cfg = TimeConditionedTrainConfig( processed_path=processed_path, output_dir=output_dir, seed=int(args.seed), batch_size=int(args.batch_size), epochs=int(args.epochs), lr=float(args.lr), weight_decay=float(args.weight_decay), hidden_dim=int(args.hidden_dim), n_blocks=int(args.n_blocks), dropout=float(args.dropout), w_pressure=float(args.w_pressure), w_derivative=float(args.w_derivative), huber_beta=float(args.huber_beta), use_schedule=not bool(args.no_schedule), sample_weight_mode=str(args.sample_weight_mode), pso_outside_weight=float(args.pso_outside_weight), pso_inside_weight=float(args.pso_inside_weight), risk_weight=float(args.risk_weight), skin_lt_minus8_weight=float(args.skin_lt_minus8_weight), sample_weight_min=float(args.sample_weight_min), sample_weight_max=float(args.sample_weight_max), ) train_time_conditioned(cfg) if __name__ == "__main__": main()