You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/train_time_conditioned.py

78 lines
3.2 KiB
Python

from __future__ import annotations
import argparse
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
from src.training.train_time_conditioned import TimeConditionedTrainConfig, train_time_conditioned
def main() -> None:
parser = argparse.ArgumentParser(description="Train a time-conditioned point-wise forward surrogate")
parser.add_argument("--processed", type=str, default=None)
parser.add_argument("--tag", type=str, default=None)
parser.add_argument("--output-dir", type=str, default=None)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--batch-size", type=int, default=4096)
parser.add_argument("--epochs", type=int, default=120)
parser.add_argument("--lr", type=float, default=1.0e-3)
parser.add_argument("--weight-decay", type=float, default=1.0e-4)
parser.add_argument("--hidden-dim", type=int, default=256)
parser.add_argument("--n-blocks", type=int, default=4)
parser.add_argument("--dropout", type=float, default=0.05)
parser.add_argument("--w-pressure", type=float, default=1.0)
parser.add_argument("--w-derivative", type=float, default=2.0)
parser.add_argument("--huber-beta", type=float, default=0.05)
parser.add_argument("--no-schedule", action="store_true")
parser.add_argument(
"--sample-weight-mode",
choices=["none", "risk_region"],
default="none",
help="Optional risk-region sample weighting; default keeps the original unweighted training behavior",
)
parser.add_argument("--risk-weight", type=float, default=2.5)
parser.add_argument("--skin-lt-minus8-weight", type=float, default=3.5)
parser.add_argument("--sample-weight-min", type=float, default=1.0)
parser.add_argument("--sample-weight-max", type=float, default=4.0)
args = parser.parse_args()
tag = normalize_tag(args.tag)
processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
if args.output_dir is not None:
output_dir = Path(args.output_dir)
elif tag:
output_dir = Path("models") / f"time_conditioned_surrogate_{tag}"
else:
output_dir = Path("models") / "time_conditioned_surrogate"
cfg = TimeConditionedTrainConfig(
processed_path=processed_path,
output_dir=output_dir,
seed=int(args.seed),
batch_size=int(args.batch_size),
epochs=int(args.epochs),
lr=float(args.lr),
weight_decay=float(args.weight_decay),
hidden_dim=int(args.hidden_dim),
n_blocks=int(args.n_blocks),
dropout=float(args.dropout),
w_pressure=float(args.w_pressure),
w_derivative=float(args.w_derivative),
huber_beta=float(args.huber_beta),
use_schedule=not bool(args.no_schedule),
sample_weight_mode=str(args.sample_weight_mode),
risk_weight=float(args.risk_weight),
skin_lt_minus8_weight=float(args.skin_lt_minus8_weight),
sample_weight_min=float(args.sample_weight_min),
sample_weight_max=float(args.sample_weight_max),
)
train_time_conditioned(cfg)
if __name__ == "__main__":
main()