"""Training entry point for the time-conditioned surrogate model.

Curves are flattened into one training example per (sample, time point):
each example carries the sample's parameter and schedule features, the
time features of one point, and the two target channels at that point.
"""

from __future__ import annotations

import json
import random
from dataclasses import dataclass
from pathlib import Path

import joblib
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from src.data.param_features import inverse_transform_param_features
from src.models.time_conditioned_surrogate import TimeConditionedSurrogate
from src.training.train_forward import get_part_slices, infer_curve_layout


class PointCurveDataset(Dataset):
    """Expose every (sample, time index) pair of a curve set as one item.

    Item ``idx`` maps to ``sample_idx = idx // n_time`` and
    ``time_idx = idx % n_time``, so the dataset length is
    ``n_samples * n_time``.
    """

    def __init__(
        self,
        params_x: np.ndarray,
        schedule_x: np.ndarray,
        time_x: np.ndarray,
        curve_y: np.ndarray,
        layout: dict,
        sample_weight: np.ndarray | None = None,
    ):
        """Convert the input arrays to float32 tensors and validate weights.

        Args:
            params_x: Per-sample parameter features; axis 0 is the sample axis.
            schedule_x: Per-sample schedule features; axis 0 is the sample axis.
            time_x: Per-sample, per-time-point time features; axis 1 is the
                time axis (its length defines ``n_time``).
            curve_y: Per-sample curve targets; the columns selected by the
                ``"log_pressure"`` and ``"log_derivative"`` entries of
                ``get_part_slices(layout)`` are used.
            layout: Curve layout passed to ``get_part_slices``.
            sample_weight: Optional per-sample weights; defaults to all ones.

        Raises:
            ValueError: If ``sample_weight`` does not have one entry per sample.
        """
        self.params_x = torch.tensor(params_x, dtype=torch.float32)
        self.schedule_x = torch.tensor(schedule_x, dtype=torch.float32)
        self.time_x = torch.tensor(time_x, dtype=torch.float32)

        slices = get_part_slices(layout)
        p = curve_y[:, slices["log_pressure"]]
        d = curve_y[:, slices["log_derivative"]]
        # Last axis holds the two target channels: 0 = pressure part,
        # 1 = derivative part.
        self.y = torch.tensor(np.stack([p, d], axis=-1), dtype=torch.float32)

        self.n_samples = int(self.params_x.shape[0])
        self.n_time = int(self.time_x.shape[1])
        # NOTE(review): the time axis of ``y`` is assumed to have the same
        # length as ``time_x.shape[1]``; this is not checked here.

        if sample_weight is None:
            sample_weight = np.ones((self.n_samples,), dtype=np.float32)
        sample_weight = np.asarray(sample_weight, dtype=np.float32).reshape(-1)
        if sample_weight.shape[0] != self.n_samples:
            raise ValueError(f"sample_weight length mismatch: {sample_weight.shape[0]} != {self.n_samples}")
        self.sample_weight = torch.tensor(sample_weight, dtype=torch.float32)

    def __len__(self) -> int:
        """Return the number of (sample, time point) pairs."""
        return self.n_samples * self.n_time

    def __getitem__(self, idx: int):
        """Return ``(params, schedule, time, target, weight)`` for one point.

        The weight is the per-sample weight, shared by every time point of
        that sample.
        """
        sample_idx = idx // self.n_time
        time_idx = idx % self.n_time
        return (
            self.params_x[sample_idx],
            self.schedule_x[sample_idx],
            self.time_x[sample_idx, time_idx],
            self.y[sample_idx, time_idx],
            self.sample_weight[sample_idx],
        )


@dataclass
class TimeConditionedTrainConfig:
    """Settings consumed by :func:`train_time_conditioned`."""

    # File loaded with ``joblib.load``; must contain the processed splits.
    processed_path: Path
    # Directory receiving the checkpoint, ``history.json`` and ``metrics.json``.
    output_dir: Path
    seed: int = 42
    batch_size: int = 4096
    epochs: int = 120
    lr: float = 1.0e-3
    weight_decay: float = 1.0e-4
    hidden_dim: int = 256
    n_blocks: int = 4
    dropout: float = 0.05
    # Multipliers of the channel-0 (pressure) and channel-1 (derivative) losses.
    w_pressure: float = 1.0
    w_derivative: float = 2.0
    # ``beta`` of the smooth-L1 loss.
    huber_beta: float = 0.05
    # When False the model is called with ``None`` in place of the schedule.
    use_schedule: bool = True
    # "none"/"off"/"false" (uniform weights) or "risk_region".
    sample_weight_mode: str = "none"
    # Weight floor for samples with skin < -5 and wellboreC > 0.1.
    risk_weight: float = 2.5
    # Weight floor for samples with skin < -8.
    skin_lt_minus8_weight: float = 3.5
    # Final clip range applied to the risk-region weights.
    sample_weight_min: float = 1.0
    sample_weight_max: float = 4.0
    # Evaluated once, when the class is defined.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


def set_global_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch (CPU and, if present, CUDA)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def _smooth_l1_vector(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor:
    """Return the element-wise (unreduced) smooth-L1 loss."""
    return F.smooth_l1_loss(pred, target, beta=float(beta), reduction="none")


def _loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    cfg: TimeConditionedTrainConfig,
    sample_weight: torch.Tensor | None = None,
) -> torch.Tensor:
    """Combine the two channel losses into one scalar.

    Column 0 and column 1 of ``pred``/``target`` are weighted by
    ``cfg.w_pressure`` and ``cfg.w_derivative`` respectively. Without
    ``sample_weight`` the result is the plain batch mean; with it, the result
    is the weighted mean using weights clamped to be non-negative.
    """
    beta = float(cfg.huber_beta)
    loss_p = _smooth_l1_vector(pred[:, 0], target[:, 0], beta=beta)
    loss_d = _smooth_l1_vector(pred[:, 1], target[:, 1], beta=beta)
    loss_vec = float(cfg.w_pressure) * loss_p + float(cfg.w_derivative) * loss_d
    if sample_weight is None:
        return loss_vec.mean()
    w = sample_weight.to(loss_vec.device).reshape(-1).clamp_min(0.0)
    # Guard the denominator so an all-zero weight batch yields 0, not NaN.
    return (loss_vec * w).sum() / torch.clamp(w.sum(), min=1.0e-12)


def _evaluate(model: TimeConditionedSurrogate, loader: DataLoader, cfg: TimeConditionedTrainConfig) -> float:
    """Return the unweighted mean loss of ``model`` over ``loader``.

    Per-batch mean losses are re-weighted by batch size so the result is the
    mean over all items. Sample weights yielded by the loader are ignored.
    Returns 0.0 for an empty loader.
    """
    model.eval()
    total = 0.0
    total_n = 0
    with torch.no_grad():
        for params_x, schedule_x, time_x, y, _sample_weight in loader:
            params_x = params_x.to(cfg.device)
            schedule_x = schedule_x.to(cfg.device)
            time_x = time_x.to(cfg.device)
            y = y.to(cfg.device)
            pred = model(params_x, time_x, schedule_x if cfg.use_schedule else None)
            loss = _loss(pred, y, cfg)
            bs = int(y.shape[0])
            total += float(loss.detach().cpu()) * bs
            total_n += bs
    return total / max(total_n, 1)


def _raw_params_from_processed_split(data: dict, split: str) -> dict[str, np.ndarray]:
    """Recover named, un-scaled parameter columns for one split.

    ``X_params_<split>`` is passed through ``scaler_params.inverse_transform``
    and then ``inverse_transform_param_features``. Column names come from
    ``meta["param_names"]``, falling back to a built-in six-name list; names
    beyond the number of available columns are dropped.
    """
    key = f"X_params_{split}"
    meta = data.get("meta", {})
    features = data["scaler_params"].inverse_transform(data[key])
    raw = inverse_transform_param_features(features, meta.get("param_feature_transform"))
    names = list(meta.get("param_names") or ["k", "skin", "wellboreC", "phi", "h", "Cf"])
    return {name: raw[:, idx].astype(np.float64) for idx, name in enumerate(names[: raw.shape[1]])}


def _build_sample_weight(data: dict, cfg: TimeConditionedTrainConfig, split: str = "train") -> np.ndarray:
    """Build one float32 weight per sample of ``split``.

    Modes "none"/"off"/"false" give all ones. Mode "risk_region" raises the
    weight to at least ``cfg.risk_weight`` where ``skin < -5`` and
    ``wellboreC > 0.1``, and to at least ``cfg.skin_lt_minus8_weight`` where
    ``skin < -8``, then clips to
    ``[cfg.sample_weight_min, cfg.sample_weight_max]``.

    Raises:
        ValueError: If ``cfg.sample_weight_mode`` is not a recognised mode.
    """
    mode = str(cfg.sample_weight_mode or "none").lower()
    n = int(data[f"X_params_{split}"].shape[0])
    if mode in {"none", "off", "false"}:
        return np.ones((n,), dtype=np.float32)
    if mode != "risk_region":
        raise ValueError(f"Unknown sample_weight_mode={cfg.sample_weight_mode!r}")

    params = _raw_params_from_processed_split(data, split)
    weight = np.ones((n,), dtype=np.float32)
    risk = (params["skin"] < -5.0) & (params["wellboreC"] > 0.1)
    skin_extreme = params["skin"] < -8.0
    # ``np.maximum`` keeps the larger weight when both masks hit a sample.
    weight[risk] = np.maximum(weight[risk], float(cfg.risk_weight))
    weight[skin_extreme] = np.maximum(weight[skin_extreme], float(cfg.skin_lt_minus8_weight))
    weight = np.clip(weight, float(cfg.sample_weight_min), float(cfg.sample_weight_max))
    return weight.astype(np.float32)


def _summarize_sample_weight(sample_weight: np.ndarray) -> dict:
    """Return min/mean/median/max and counts above/below 1 as plain numbers."""
    w = np.asarray(sample_weight, dtype=np.float32).reshape(-1)
    return {
        "min": float(np.min(w)),
        "mean": float(np.mean(w)),
        "median": float(np.median(w)),
        "max": float(np.max(w)),
        "n_weight_gt_1": int(np.sum(w > 1.0)),
        "n_weight_lt_1": int(np.sum(w < 1.0)),
    }


def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None:
    """Train the time-conditioned surrogate and write its artifacts.

    Loads the processed dataset from ``cfg.processed_path``, trains with AdamW
    and gradient-norm clipping, saves a checkpoint each time the validation
    loss improves, then evaluates on the test split and writes
    ``history.json`` and ``metrics.json`` into ``cfg.output_dir``.

    The test evaluation uses the best checkpoint saved by *this* run. If no
    checkpoint was saved (for example the validation loss was never finite, or
    ``cfg.epochs`` is not positive), the in-memory model is evaluated instead
    and ``best_val_loss`` is recorded as ``null``; a checkpoint file left over
    from an earlier run is never loaded.

    Raises:
        KeyError: If the dataset lacks any of the ``X_time_*`` fields.
    """
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    set_global_seed(int(cfg.seed))

    # NOTE(review): joblib.load unpickles the file; only use trusted datasets.
    data = joblib.load(cfg.processed_path)
    required = ["X_time_train", "X_time_val", "X_time_test"]
    missing = [key for key in required if key not in data]
    if missing:
        raise KeyError(f"processed dataset is missing time-conditioned fields: {missing}")

    curve_layout = infer_curve_layout(data)
    train_weight = _build_sample_weight(data, cfg, split="train")
    train_weight_summary = _summarize_sample_weight(train_weight)

    train_ds = PointCurveDataset(
        data["X_params_train"],
        data["X_schedule_train"],
        data["X_time_train"],
        data["Y_curve_train"],
        curve_layout,
        sample_weight=train_weight,
    )
    val_ds = PointCurveDataset(
        data["X_params_val"], data["X_schedule_val"], data["X_time_val"], data["Y_curve_val"], curve_layout
    )
    test_ds = PointCurveDataset(
        data["X_params_test"], data["X_schedule_test"], data["X_time_test"], data["Y_curve_test"], curve_layout
    )

    # Dedicated generator so the shuffle order depends only on cfg.seed.
    generator = torch.Generator()
    generator.manual_seed(int(cfg.seed))
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, generator=generator)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False)

    param_dim = int(data["X_params_train"].shape[1])
    schedule_dim = int(data["X_schedule_train"].shape[1])
    time_dim = int(data["X_time_train"].shape[-1])

    model = TimeConditionedSurrogate(
        param_dim=param_dim,
        schedule_dim=schedule_dim,
        time_dim=time_dim,
        hidden_dim=int(cfg.hidden_dim),
        n_blocks=int(cfg.n_blocks),
        dropout=float(cfg.dropout),
        use_schedule=bool(cfg.use_schedule),
    ).to(cfg.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg.lr), weight_decay=float(cfg.weight_decay))

    best_val = float("inf")
    best_saved = False  # True once this run has written best_path.
    best_path = cfg.output_dir / "time_conditioned_surrogate_best.pt"
    history: list[dict] = []

    print("Time-conditioned training config:")
    print(f" processed={cfg.processed_path}")
    print(f" output_dir={cfg.output_dir}")
    print(f" device={cfg.device}, batch_size={cfg.batch_size}, epochs={cfg.epochs}")
    print(
        f" dims: param={param_dim}, "
        f"schedule={schedule_dim}, time={time_dim}"
    )
    print(f" curve_time_source={data.get('meta', {}).get('curve_time_source', 'unknown')}")
    print(f" sample_weight_mode={cfg.sample_weight_mode}, sample_weight={train_weight_summary}")

    for epoch in range(1, int(cfg.epochs) + 1):
        model.train()
        total = 0.0
        total_n = 0
        for params_x, schedule_x, time_x, y, sample_weight in train_loader:
            params_x = params_x.to(cfg.device)
            schedule_x = schedule_x.to(cfg.device)
            time_x = time_x.to(cfg.device)
            y = y.to(cfg.device)
            sample_weight = sample_weight.to(cfg.device)

            optimizer.zero_grad()
            pred = model(params_x, time_x, schedule_x if cfg.use_schedule else None)
            loss = _loss(pred, y, cfg, sample_weight=sample_weight)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            bs = int(y.shape[0])
            total += float(loss.detach().cpu()) * bs
            total_n += bs

        train_loss = total / max(total_n, 1)
        val_loss = _evaluate(model, val_loader, cfg)
        history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})
        print(f"[Epoch {epoch:03d}] train={train_loss:.6f} val={val_loss:.6f}")

        # A NaN validation loss never compares as smaller, so it never saves.
        if val_loss < best_val:
            best_val = val_loss
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "param_dim": param_dim,
                    "schedule_dim": schedule_dim,
                    "time_dim": time_dim,
                    "hidden_dim": int(cfg.hidden_dim),
                    "n_blocks": int(cfg.n_blocks),
                    "dropout": float(cfg.dropout),
                    "use_schedule": bool(cfg.use_schedule),
                    "curve_layout": curve_layout,
                    "processed_path": str(cfg.processed_path),
                    "seed": int(cfg.seed),
                    "sample_weight_mode": str(cfg.sample_weight_mode),
                    "sample_weight_summary": train_weight_summary,
                },
                best_path,
            )
            best_saved = True
            print(f" -> best model saved to: {best_path}")

    # Written before the reload/test step so the training history survives a
    # failure there.
    (cfg.output_dir / "history.json").write_text(
        json.dumps(history, indent=2, ensure_ascii=False), encoding="utf-8"
    )

    if best_saved:
        checkpoint = torch.load(best_path, map_location=cfg.device)
        model.load_state_dict(checkpoint["model_state_dict"])
    else:
        # Loading best_path here would either fail or pick up a stale file
        # from a previous run in the same output directory.
        print(
            "[Warning] no checkpoint was saved during this run; "
            "evaluating the final in-memory model weights instead."
        )

    test_loss = _evaluate(model, test_loader, cfg)
    (cfg.output_dir / "metrics.json").write_text(
        json.dumps(
            {
                # ``inf`` would serialise as the non-standard token
                # ``Infinity``; emit null when there is no best value.
                "best_val_loss": best_val if best_saved else None,
                "test_loss": test_loss,
                "sample_weight_mode": str(cfg.sample_weight_mode),
                "sample_weight_summary": train_weight_summary,
            },
            indent=2,
            ensure_ascii=False,
        ),
        encoding="utf-8",
    )
    print(f"[Final] test={test_loss:.6f}")