You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
63 lines
2.4 KiB
Python
63 lines
2.4 KiB
Python
|
3 weeks ago
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Second ancestor of this file's resolved path, i.e. the parent of the
# directory that contains this script.
# NOTE(review): presumably the repository root - confirm against the layout.
ROOT = Path(__file__).resolve().parents[1]
|
||
|
|
# Append ROOT to the module search path; this has to run before the
# ``src.*`` imports that follow (presumably so they resolve when the file is
# executed directly as a script - confirm).
sys.path.append(str(ROOT))
|
||
|
|
|
||
|
|
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
|
||
|
|
from src.training.train_time_conditioned import TimeConditionedTrainConfig, train_time_conditioned
|
||
|
|
|
||
|
|
|
||
|
|
def main() -> None:
    """Command-line entry point: parse options, build the config, start training.

    The parsed options are packed into a ``TimeConditionedTrainConfig`` which
    is then passed to ``train_time_conditioned``.

    Path resolution:
      * ``--processed`` is used as given; when omitted, the path is obtained
        from ``processed_path_for_tag`` applied to the normalized tag.
      * ``--output-dir`` is used as given; when omitted, the directory is
        ``models/time_conditioned_surrogate_<tag>`` if the normalized tag is
        truthy and ``models/time_conditioned_surrogate`` otherwise.
    """
    cli = argparse.ArgumentParser(description="Train a time-conditioned point-wise forward surrogate")

    # (flag, type, default) triples. They are registered in exactly this order
    # so the generated --help listing keeps its layout.
    valued_flags = (
        ("--processed", str, None),
        ("--tag", str, None),
        ("--output-dir", str, None),
        ("--seed", int, 42),
        ("--batch-size", int, 4096),
        ("--epochs", int, 120),
        ("--lr", float, 1.0e-3),
        ("--weight-decay", float, 1.0e-4),
        ("--hidden-dim", int, 256),
        ("--n-blocks", int, 4),
        ("--dropout", float, 0.05),
        ("--w-pressure", float, 1.0),
        ("--w-derivative", float, 2.0),
        ("--huber-beta", float, 0.05),
    )
    for flag, caster, fallback in valued_flags:
        cli.add_argument(flag, type=caster, default=fallback)
    # Boolean switch; when present, ``use_schedule`` is set to False below.
    cli.add_argument("--no-schedule", action="store_true")
    opts = cli.parse_args()

    tag = normalize_tag(opts.tag)

    # An explicit --processed wins; the tag-derived path is only looked up
    # when the option was not supplied.
    if opts.processed is None:
        processed_path = processed_path_for_tag(tag)
    else:
        processed_path = Path(opts.processed)

    # An explicit --output-dir wins; otherwise fall back to a directory under
    # "models", suffixed with the tag when there is one.
    if opts.output_dir is None:
        leaf = f"time_conditioned_surrogate_{tag}" if tag else "time_conditioned_surrogate"
        output_dir = Path("models") / leaf
    else:
        output_dir = Path(opts.output_dir)

    settings = {
        "processed_path": processed_path,
        "output_dir": output_dir,
        "seed": int(opts.seed),
        "batch_size": int(opts.batch_size),
        "epochs": int(opts.epochs),
        "lr": float(opts.lr),
        "weight_decay": float(opts.weight_decay),
        "hidden_dim": int(opts.hidden_dim),
        "n_blocks": int(opts.n_blocks),
        "dropout": float(opts.dropout),
        "w_pressure": float(opts.w_pressure),
        "w_derivative": float(opts.w_derivative),
        "huber_beta": float(opts.huber_beta),
        # The CLI exposes the negative form, so invert it for the config.
        "use_schedule": not bool(opts.no_schedule),
    }
    train_time_conditioned(TimeConditionedTrainConfig(**settings))
|
||
|
|
|
||
|
|
|
||
|
|
# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
|