You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/train_forward.py

95 lines
3.6 KiB
Python

from __future__ import annotations
import argparse
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import model_dir_for_tag, normalize_tag, processed_path_for_tag
from src.training.train_forward import TrainConfig, train_forward
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the forward-surrogate training CLI."""
    parser = argparse.ArgumentParser(description="Train forward surrogate model")

    # Input/output locations; when omitted they are derived from --tag in main().
    parser.add_argument(
        "--processed",
        type=str,
        default=None,
        help="Processed dataset path",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Optional model output directory",
    )
    parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming")

    # Optimisation and model-size hyperparameters.
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=220)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=0.0005)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.1)

    # Loss-term weights, forwarded unchanged to TrainConfig.
    parser.add_argument("--w-pressure", type=float, default=1.0)
    parser.add_argument("--w-derivative", type=float, default=2.0)
    parser.add_argument("--w-slope", type=float, default=0.0)
    parser.add_argument("--w-bias-pressure", type=float, default=0.15)
    parser.add_argument("--w-bias-derivative", type=float, default=0.05)
    parser.add_argument("--w-derivative-shape", type=float, default=0.10)
    parser.add_argument("--w-autofit-pressure", type=float, default=0.0)
    parser.add_argument("--w-autofit-derivative", type=float, default=0.0)
    parser.add_argument("--huber-beta", type=float, default=0.05)

    # Sample reweighting is on by default.  The original flag used
    # ``store_true`` with ``default=True`` and therefore could never yield
    # False; it is kept for backward compatibility and paired with an
    # explicit opt-out that writes to the same destination.
    parser.add_argument(
        "--use-sample-reweight",
        dest="use_sample_reweight",
        action="store_true",
        default=True,
        help="Enable sample reweighting (default; kept for backward compatibility)",
    )
    parser.add_argument(
        "--no-sample-reweight",
        dest="use_sample_reweight",
        action="store_false",
        help="Disable sample reweighting",
    )
    parser.add_argument("--sample-reweight-alpha", type=float, default=0.4)
    parser.add_argument("--sample-weight-min", type=float, default=1.0)
    parser.add_argument("--sample-weight-max", type=float, default=2.5)

    parser.add_argument(
        "--no-schedule",
        action="store_true",
        help="Disable the schedule branch and train a parameter-only forward surrogate",
    )
    return parser


def main() -> None:
    """Parse command-line options and run forward-surrogate training.

    The processed dataset path and the model output directory are taken from
    ``--processed`` / ``--output-dir`` when given; otherwise they are obtained
    from ``processed_path_for_tag`` and ``model_dir_for_tag`` using the
    normalized ``--tag``.  All remaining options are copied into a
    ``TrainConfig`` which is handed to ``train_forward``.
    """
    args = _build_parser().parse_args()

    tag = normalize_tag(args.tag)
    use_schedule = not args.no_schedule

    # Explicit paths win; the tag-based helpers are only consulted as fallback.
    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    output_dir = Path(args.output_dir) if args.output_dir is not None else model_dir_for_tag(tag, use_schedule)

    cfg = TrainConfig(
        processed_path=processed_path,
        output_dir=output_dir,
        seed=args.seed,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
        w_pressure=args.w_pressure,
        w_derivative=args.w_derivative,
        w_slope=args.w_slope,
        w_bias_pressure=args.w_bias_pressure,
        w_bias_derivative=args.w_bias_derivative,
        w_derivative_shape=args.w_derivative_shape,
        w_autofit_pressure=args.w_autofit_pressure,
        w_autofit_derivative=args.w_autofit_derivative,
        huber_beta=args.huber_beta,
        use_sample_reweight=args.use_sample_reweight,
        sample_reweight_alpha=args.sample_reweight_alpha,
        sample_weight_min=args.sample_weight_min,
        sample_weight_max=args.sample_weight_max,
        use_schedule=use_schedule,
    )
    train_forward(cfg)


if __name__ == "__main__":
    main()