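# NOTE: the lines below are repository web-UI residue captured with this file;
# they are not part of the script and are kept only as comments.
#   You cannot select more than 25 topics
#   Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
#   78 lines
#   3.2 KiB
#   Python
#   78 lines
#   3.2 KiB
#   Python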
from __future__ import annotations
|
|
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Directory two levels above this file (parents[1] of the resolved script path).
ROOT = Path(__file__).resolve().parents[1]

# Make the `src.*` imports below resolvable when the script is run directly.
# Guarded so that re-importing or reloading this module does not keep
# appending duplicate entries to sys.path.
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))
|
|
|
|
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
|
|
from src.training.train_time_conditioned import TimeConditionedTrainConfig, train_time_conditioned
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Create the argument parser for the time-conditioned surrogate trainer."""
    parser = argparse.ArgumentParser(description="Train a time-conditioned point-wise forward surrogate")

    # Input / output locations.
    parser.add_argument("--processed", type=str, default=None)
    parser.add_argument("--tag", type=str, default=None)
    parser.add_argument("--output-dir", type=str, default=None)

    # Optimisation and model-size settings.
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch-size", type=int, default=4096)
    parser.add_argument("--epochs", type=int, default=120)
    parser.add_argument("--lr", type=float, default=1.0e-3)
    parser.add_argument("--weight-decay", type=float, default=1.0e-4)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--n-blocks", type=int, default=4)
    parser.add_argument("--dropout", type=float, default=0.05)

    # Loss settings.
    parser.add_argument("--w-pressure", type=float, default=1.0)
    parser.add_argument("--w-derivative", type=float, default=2.0)
    parser.add_argument("--huber-beta", type=float, default=0.05)
    parser.add_argument("--no-schedule", action="store_true")

    # Optional sample weighting.
    parser.add_argument(
        "--sample-weight-mode",
        choices=["none", "risk_region"],
        default="none",
        help="Optional risk-region sample weighting; default keeps the original unweighted training behavior",
    )
    parser.add_argument("--risk-weight", type=float, default=2.5)
    parser.add_argument("--skin-lt-minus8-weight", type=float, default=3.5)
    parser.add_argument("--sample-weight-min", type=float, default=1.0)
    parser.add_argument("--sample-weight-max", type=float, default=4.0)
    return parser


def main(argv: list[str] | None = None) -> None:
    """Parse command-line options and launch time-conditioned training.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so a plain ``main()`` call behaves
            as before; passing a list allows programmatic invocation.

    The processed-data path is ``--processed`` when given, otherwise
    ``processed_path_for_tag(tag)``. The output directory is ``--output-dir``
    when given, otherwise ``models/time_conditioned_surrogate_<tag>`` for a
    non-empty normalized tag and ``models/time_conditioned_surrogate`` when
    the tag is empty.
    """
    args = _build_parser().parse_args(argv)

    tag = normalize_tag(args.tag)

    if args.processed is not None:
        processed_path = Path(args.processed)
    else:
        processed_path = processed_path_for_tag(tag)

    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
    elif tag:
        output_dir = Path("models") / f"time_conditioned_surrogate_{tag}"
    else:
        output_dir = Path("models") / "time_conditioned_surrogate"

    # argparse has already applied each option's `type=` converter (and the
    # defaults are literals of the matching type), so values are passed as-is.
    cfg = TimeConditionedTrainConfig(
        processed_path=processed_path,
        output_dir=output_dir,
        seed=args.seed,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        hidden_dim=args.hidden_dim,
        n_blocks=args.n_blocks,
        dropout=args.dropout,
        w_pressure=args.w_pressure,
        w_derivative=args.w_derivative,
        huber_beta=args.huber_beta,
        # The CLI flag is negative ("--no-schedule"); the config field is positive.
        use_schedule=not args.no_schedule,
        sample_weight_mode=args.sample_weight_mode,
        risk_weight=args.risk_weight,
        skin_lt_minus8_weight=args.skin_lt_minus8_weight,
        sample_weight_min=args.sample_weight_min,
        sample_weight_max=args.sample_weight_max,
    )
    train_time_conditioned(cfg)
|
|
|
|
|
|
# Script entry point: only start training when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|