from __future__ import annotations

import json
import random
from dataclasses import dataclass
from pathlib import Path

import joblib
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from src.models.time_conditioned_surrogate import TimeConditionedSurrogate
from src.training.train_forward import get_part_slices, infer_curve_layout


class PointCurveDataset(Dataset):
    """Expose per-sample curves as one item per (sample, time point).

    Item ``idx`` maps to ``sample_idx = idx // n_time`` and
    ``time_idx = idx % n_time`` and yields
    ``(params_x[sample_idx], schedule_x[sample_idx], time_x[sample_idx, time_idx], y[sample_idx, time_idx])``.
    The target ``y`` stacks the ``"log_pressure"`` and ``"log_derivative"``
    parts of ``curve_y`` along a new last axis (channel 0 = pressure,
    channel 1 = derivative).

    Args:
        params_x: Per-sample parameter features; axis 0 is the sample axis.
        schedule_x: Per-sample schedule features; axis 0 is the sample axis.
        time_x: Time features indexed as ``time_x[sample, time_point, ...]``.
        curve_y: Curve targets; the column selections returned by
            ``get_part_slices(layout)`` pick out the two curve parts.
        layout: Curve layout forwarded to ``get_part_slices``.

    Raises:
        ValueError: If the inputs disagree on the number of samples, or if the
            curve parts do not provide exactly one value per time point.
    """

    def __init__(self, params_x: np.ndarray, schedule_x: np.ndarray, time_x: np.ndarray, curve_y: np.ndarray, layout: dict):
        self.params_x = torch.tensor(params_x, dtype=torch.float32)
        self.schedule_x = torch.tensor(schedule_x, dtype=torch.float32)
        self.time_x = torch.tensor(time_x, dtype=torch.float32)
        slices = get_part_slices(layout)
        p = curve_y[:, slices["log_pressure"]]
        d = curve_y[:, slices["log_derivative"]]
        self.y = torch.tensor(np.stack([p, d], axis=-1), dtype=torch.float32)
        self.n_samples = int(self.params_x.shape[0])
        self.n_time = int(self.time_x.shape[1])
        self._validate_shapes()

    def _validate_shapes(self) -> None:
        """Fail fast on misaligned inputs instead of mis-pairing items later."""
        counts = {
            "params_x": int(self.params_x.shape[0]),
            "schedule_x": int(self.schedule_x.shape[0]),
            "time_x": int(self.time_x.shape[0]),
            "curve_y": int(self.y.shape[0]),
        }
        if len(set(counts.values())) != 1:
            raise ValueError(f"inconsistent sample counts across inputs: {counts}")
        # y must be (sample, time_point, channel) so y[sample, time] is a channel vector.
        if self.y.ndim < 3 or int(self.y.shape[1]) != self.n_time:
            raise ValueError(
                f"curve parts have shape {tuple(self.y.shape)}; expected "
                f"(n_samples, {self.n_time}, 2) to match time_x's time axis"
            )

    def __len__(self) -> int:
        """Total number of (sample, time point) pairs."""
        return self.n_samples * self.n_time

    def __getitem__(self, idx: int):
        """Return ``(params, schedule, time, target)`` for flat index ``idx``."""
        sample_idx, time_idx = divmod(idx, self.n_time)
        return (
            self.params_x[sample_idx],
            self.schedule_x[sample_idx],
            self.time_x[sample_idx, time_idx],
            self.y[sample_idx, time_idx],
        )


@dataclass
class TimeConditionedTrainConfig:
    """Hyper-parameters and paths for ``train_time_conditioned``."""

    processed_path: Path  # joblib file holding the processed dataset splits
    output_dir: Path  # receives the best checkpoint, history.json and metrics.json
    seed: int = 42
    batch_size: int = 4096
    epochs: int = 120
    lr: float = 1.0e-3  # AdamW learning rate
    weight_decay: float = 1.0e-4  # AdamW weight decay
    hidden_dim: int = 256
    n_blocks: int = 4
    dropout: float = 0.05
    w_pressure: float = 1.0  # weight of the pressure-channel loss term
    w_derivative: float = 2.0  # weight of the derivative-channel loss term
    huber_beta: float = 0.05  # `beta` passed to F.smooth_l1_loss
    use_schedule: bool = True  # pass schedule features to the model (else None)
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    grad_clip_norm: float = 1.0  # max total gradient norm for clip_grad_norm_


def set_global_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and torch (all CUDA devices when available)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _loss(pred: torch.Tensor, target: torch.Tensor, cfg: TimeConditionedTrainConfig) -> torch.Tensor:
    """Weighted Smooth-L1 loss over the two output channels.

    Channel 0 is compared against the log-pressure target and channel 1
    against the log-derivative target (the stacking order used by
    ``PointCurveDataset``). Each term is a batch mean with
    ``beta=cfg.huber_beta``; the result is
    ``w_pressure * loss_p + w_derivative * loss_d``.
    """
    loss_p = F.smooth_l1_loss(pred[:, 0], target[:, 0], beta=float(cfg.huber_beta), reduction="mean")
    loss_d = F.smooth_l1_loss(pred[:, 1], target[:, 1], beta=float(cfg.huber_beta), reduction="mean")
    return float(cfg.w_pressure) * loss_p + float(cfg.w_derivative) * loss_d


def _forward_batch(
    model: TimeConditionedSurrogate, batch, cfg: TimeConditionedTrainConfig
) -> tuple[torch.Tensor, torch.Tensor]:
    """Move one ``(params, schedule, time, y)`` batch to ``cfg.device`` and run the model.

    Returns ``(pred, y)`` with ``y`` already on the device.
    """
    params_x, schedule_x, time_x, y = (t.to(cfg.device) for t in batch)
    pred = model(params_x, time_x, schedule_x if cfg.use_schedule else None)
    return pred, y


def _evaluate(model: TimeConditionedSurrogate, loader: DataLoader, cfg: TimeConditionedTrainConfig) -> float:
    """Return the sample-weighted mean of ``_loss`` over ``loader`` (0.0 if empty).

    Puts the model in eval mode and disables autograd.
    """
    model.eval()
    total = 0.0
    total_n = 0
    with torch.no_grad():
        for batch in loader:
            pred, y = _forward_batch(model, batch, cfg)
            loss = _loss(pred, y, cfg)
            bs = int(y.shape[0])
            # _loss is a per-batch mean; weight by batch size so a short last
            # batch does not skew the average.
            total += float(loss.detach().cpu()) * bs
            total_n += bs
    return total / max(total_n, 1)


def _train_one_epoch(
    model: TimeConditionedSurrogate,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    cfg: TimeConditionedTrainConfig,
) -> float:
    """Run one optimisation pass over ``loader``; return the sample-weighted mean loss."""
    model.train()
    total = 0.0
    total_n = 0
    for batch in loader:
        optimizer.zero_grad()
        pred, y = _forward_batch(model, batch, cfg)
        loss = _loss(pred, y, cfg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float(cfg.grad_clip_norm))
        optimizer.step()
        bs = int(y.shape[0])
        total += float(loss.detach().cpu()) * bs
        total_n += bs
    return total / max(total_n, 1)


def _make_split_dataset(data: dict, split: str, curve_layout: dict) -> PointCurveDataset:
    """Build the ``PointCurveDataset`` for one split (``"train"``, ``"val"`` or ``"test"``)."""
    return PointCurveDataset(
        data[f"X_params_{split}"],
        data[f"X_schedule_{split}"],
        data[f"X_time_{split}"],
        data[f"Y_curve_{split}"],
        curve_layout,
    )


def _model_dims(data: dict) -> dict[str, int]:
    """Input dimensions of the model, read from the training arrays."""
    return {
        "param_dim": int(data["X_params_train"].shape[1]),
        "schedule_dim": int(data["X_schedule_train"].shape[1]),
        "time_dim": int(data["X_time_train"].shape[-1]),
    }


def _checkpoint_payload(
    model: TimeConditionedSurrogate, dims: dict[str, int], curve_layout: dict, cfg: TimeConditionedTrainConfig
) -> dict:
    """Assemble the checkpoint dict: weights, constructor arguments and provenance."""
    return {
        "model_state_dict": model.state_dict(),
        "param_dim": dims["param_dim"],
        "schedule_dim": dims["schedule_dim"],
        "time_dim": dims["time_dim"],
        "hidden_dim": int(cfg.hidden_dim),
        "n_blocks": int(cfg.n_blocks),
        "dropout": float(cfg.dropout),
        "use_schedule": bool(cfg.use_schedule),
        "curve_layout": curve_layout,
        "processed_path": str(cfg.processed_path),
        "seed": int(cfg.seed),
    }


def _print_config(cfg: TimeConditionedTrainConfig, data: dict, dims: dict[str, int]) -> None:
    """Print the run summary shown before training starts."""
    print("Time-conditioned training config:")
    print(f" processed={cfg.processed_path}")
    print(f" output_dir={cfg.output_dir}")
    print(f" device={cfg.device}, batch_size={cfg.batch_size}, epochs={cfg.epochs}")
    print(
        f" dims: param={dims['param_dim']}, "
        f"schedule={dims['schedule_dim']}, time={dims['time_dim']}"
    )
    print(f" curve_time_source={data.get('meta', {}).get('curve_time_source', 'unknown')}")


def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None:
    """Train a ``TimeConditionedSurrogate`` and evaluate its best epoch on the test split.

    Steps:
    1. Load the processed dataset from ``cfg.processed_path`` with joblib.
    2. Train with AdamW on one item per (sample, time point).
    3. Keep the weights with the lowest validation loss, also saving them to
       ``cfg.output_dir / "time_conditioned_surrogate_best.pt"``.
    4. Restore those weights and compute the test loss.
    5. Write ``history.json`` (per-epoch losses) and ``metrics.json``.

    Raises:
        KeyError: If the processed dataset lacks the ``X_time_*`` fields.
        RuntimeError: If no epoch recorded a validation loss below ``inf``
            (e.g. ``cfg.epochs <= 0`` or every validation loss was NaN/inf),
            so there is no best model to evaluate.
    """
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    set_global_seed(int(cfg.seed))

    # NOTE: joblib.load unpickles the file; only point processed_path at trusted data.
    data = joblib.load(cfg.processed_path)
    required = ["X_time_train", "X_time_val", "X_time_test"]
    missing = [key for key in required if key not in data]
    if missing:
        raise KeyError(f"processed dataset is missing time-conditioned fields: {missing}")

    curve_layout = infer_curve_layout(data)
    train_ds = _make_split_dataset(data, "train", curve_layout)
    val_ds = _make_split_dataset(data, "val", curve_layout)
    test_ds = _make_split_dataset(data, "test", curve_layout)

    # Dedicated seeded generator keeps the shuffle order reproducible.
    generator = torch.Generator()
    generator.manual_seed(int(cfg.seed))
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, generator=generator)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False)

    dims = _model_dims(data)
    model = TimeConditionedSurrogate(
        param_dim=dims["param_dim"],
        schedule_dim=dims["schedule_dim"],
        time_dim=dims["time_dim"],
        hidden_dim=int(cfg.hidden_dim),
        n_blocks=int(cfg.n_blocks),
        dropout=float(cfg.dropout),
        use_schedule=bool(cfg.use_schedule),
    ).to(cfg.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg.lr), weight_decay=float(cfg.weight_decay))

    best_val = float("inf")
    best_state: dict[str, torch.Tensor] | None = None
    best_path = cfg.output_dir / "time_conditioned_surrogate_best.pt"
    history: list[dict] = []

    _print_config(cfg, data, dims)

    for epoch in range(1, int(cfg.epochs) + 1):
        train_loss = _train_one_epoch(model, train_loader, optimizer, cfg)
        val_loss = _evaluate(model, val_loader, cfg)
        history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})
        print(f"[Epoch {epoch:03d}] train={train_loss:.6f} val={val_loss:.6f}")

        # NaN never compares less than best_val, so diverged epochs are never kept.
        if val_loss < best_val:
            best_val = val_loss
            # state_dict() holds references to the live parameters; clone so
            # later optimizer steps do not overwrite the snapshot.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            torch.save(_checkpoint_payload(model, dims, curve_layout, cfg), best_path)
            print(f" -> best model saved to: {best_path}")

    # Persist per-epoch diagnostics even if no usable model was produced.
    (cfg.output_dir / "history.json").write_text(json.dumps(history, indent=2, ensure_ascii=False), encoding="utf-8")

    if best_state is None:
        raise RuntimeError(
            f"no best model was recorded after {int(cfg.epochs)} epoch(s); "
            "epochs must be >= 1 and at least one validation loss must be finite"
        )

    # Restore from memory rather than re-reading best_path: a file left by an
    # earlier run can never be picked up, and torch.load pickle settings do not matter.
    model.load_state_dict(best_state)
    test_loss = _evaluate(model, test_loader, cfg)

    (cfg.output_dir / "metrics.json").write_text(
        json.dumps({"best_val_loss": best_val, "test_loss": test_loss}, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    print(f"[Final] test={test_loss:.6f}")