from __future__ import annotations import argparse import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import normalize_tag, processed_path_for_tag from src.training.train_time_conditioned import TimeConditionedTrainConfig, train_time_conditioned def main() -> None: parser = argparse.ArgumentParser(description="Train a time-conditioned point-wise forward surrogate") parser.add_argument("--processed", type=str, default=None) parser.add_argument("--tag", type=str, default=None) parser.add_argument("--output-dir", type=str, default=None) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--batch-size", type=int, default=4096) parser.add_argument("--epochs", type=int, default=120) parser.add_argument("--lr", type=float, default=1.0e-3) parser.add_argument("--weight-decay", type=float, default=1.0e-4) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--n-blocks", type=int, default=4) parser.add_argument("--dropout", type=float, default=0.05) parser.add_argument("--w-pressure", type=float, default=1.0) parser.add_argument("--w-derivative", type=float, default=2.0) parser.add_argument("--huber-beta", type=float, default=0.05) parser.add_argument("--no-schedule", action="store_true") args = parser.parse_args() tag = normalize_tag(args.tag) processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) if args.output_dir is not None: output_dir = Path(args.output_dir) elif tag: output_dir = Path("models") / f"time_conditioned_surrogate_{tag}" else: output_dir = Path("models") / "time_conditioned_surrogate" cfg = TimeConditionedTrainConfig( processed_path=processed_path, output_dir=output_dir, seed=int(args.seed), batch_size=int(args.batch_size), epochs=int(args.epochs), lr=float(args.lr), weight_decay=float(args.weight_decay), hidden_dim=int(args.hidden_dim), n_blocks=int(args.n_blocks), dropout=float(args.dropout), w_pressure=float(args.w_pressure), w_derivative=float(args.w_derivative), huber_beta=float(args.huber_beta), use_schedule=not bool(args.no_schedule), ) train_time_conditioned(cfg) if __name__ == "__main__": main()