diff --git a/ML/nmWTAI-ML/scripts/train_time_conditioned.py b/ML/nmWTAI-ML/scripts/train_time_conditioned.py index 399ad1e..2455d17 100644 --- a/ML/nmWTAI-ML/scripts/train_time_conditioned.py +++ b/ML/nmWTAI-ML/scripts/train_time_conditioned.py @@ -28,6 +28,18 @@ def main() -> None: parser.add_argument("--w-derivative", type=float, default=2.0) parser.add_argument("--huber-beta", type=float, default=0.05) parser.add_argument("--no-schedule", action="store_true") + parser.add_argument( + "--sample-weight-mode", + choices=["none", "pso_domain", "pso_domain_risk"], + default="none", + help="Optional sample weighting; default keeps the original unweighted training behavior", + ) + parser.add_argument("--pso-outside-weight", type=float, default=0.5) + parser.add_argument("--pso-inside-weight", type=float, default=1.0) + parser.add_argument("--risk-weight", type=float, default=2.5) + parser.add_argument("--skin-lt-minus8-weight", type=float, default=3.5) + parser.add_argument("--sample-weight-min", type=float, default=0.25) + parser.add_argument("--sample-weight-max", type=float, default=4.0) args = parser.parse_args() tag = normalize_tag(args.tag) @@ -54,6 +66,13 @@ def main() -> None: w_derivative=float(args.w_derivative), huber_beta=float(args.huber_beta), use_schedule=not bool(args.no_schedule), + sample_weight_mode=str(args.sample_weight_mode), + pso_outside_weight=float(args.pso_outside_weight), + pso_inside_weight=float(args.pso_inside_weight), + risk_weight=float(args.risk_weight), + skin_lt_minus8_weight=float(args.skin_lt_minus8_weight), + sample_weight_min=float(args.sample_weight_min), + sample_weight_max=float(args.sample_weight_max), ) train_time_conditioned(cfg) diff --git a/ML/nmWTAI-ML/src/training/train_time_conditioned.py b/ML/nmWTAI-ML/src/training/train_time_conditioned.py index 3ce66df..2ee87bf 100644 --- a/ML/nmWTAI-ML/src/training/train_time_conditioned.py +++ b/ML/nmWTAI-ML/src/training/train_time_conditioned.py @@ -11,12 +11,21 @@ import torch 
import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset +from src.data.param_features import inverse_transform_param_features from src.models.time_conditioned_surrogate import TimeConditionedSurrogate from src.training.train_forward import get_part_slices, infer_curve_layout class PointCurveDataset(Dataset): - def __init__(self, params_x: np.ndarray, schedule_x: np.ndarray, time_x: np.ndarray, curve_y: np.ndarray, layout: dict): + def __init__( + self, + params_x: np.ndarray, + schedule_x: np.ndarray, + time_x: np.ndarray, + curve_y: np.ndarray, + layout: dict, + sample_weight: np.ndarray | None = None, + ): self.params_x = torch.tensor(params_x, dtype=torch.float32) self.schedule_x = torch.tensor(schedule_x, dtype=torch.float32) self.time_x = torch.tensor(time_x, dtype=torch.float32) @@ -28,6 +37,12 @@ class PointCurveDataset(Dataset): self.n_samples = int(self.params_x.shape[0]) self.n_time = int(self.time_x.shape[1]) + if sample_weight is None: + sample_weight = np.ones((self.n_samples,), dtype=np.float32) + sample_weight = np.asarray(sample_weight, dtype=np.float32).reshape(-1) + if sample_weight.shape[0] != self.n_samples: + raise ValueError(f"sample_weight length mismatch: {sample_weight.shape[0]} != {self.n_samples}") + self.sample_weight = torch.tensor(sample_weight, dtype=torch.float32) def __len__(self) -> int: return self.n_samples * self.n_time @@ -40,6 +55,7 @@ class PointCurveDataset(Dataset): self.schedule_x[sample_idx], self.time_x[sample_idx, time_idx], self.y[sample_idx, time_idx], + self.sample_weight[sample_idx], ) @@ -59,6 +75,13 @@ class TimeConditionedTrainConfig: w_derivative: float = 2.0 huber_beta: float = 0.05 use_schedule: bool = True + sample_weight_mode: str = "none" + pso_outside_weight: float = 0.5 + pso_inside_weight: float = 1.0 + risk_weight: float = 2.5 + skin_lt_minus8_weight: float = 3.5 + sample_weight_min: float = 0.25 + sample_weight_max: float = 4.0 device: str = "cuda" if torch.cuda.is_available() 
else "cpu" @@ -70,10 +93,23 @@ def set_global_seed(seed: int) -> None: torch.cuda.manual_seed_all(seed) -def _loss(pred: torch.Tensor, target: torch.Tensor, cfg: TimeConditionedTrainConfig) -> torch.Tensor: - loss_p = F.smooth_l1_loss(pred[:, 0], target[:, 0], beta=float(cfg.huber_beta), reduction="mean") - loss_d = F.smooth_l1_loss(pred[:, 1], target[:, 1], beta=float(cfg.huber_beta), reduction="mean") - return float(cfg.w_pressure) * loss_p + float(cfg.w_derivative) * loss_d +def _smooth_l1_vector(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor: + return F.smooth_l1_loss(pred, target, beta=float(beta), reduction="none") + + +def _loss( + pred: torch.Tensor, + target: torch.Tensor, + cfg: TimeConditionedTrainConfig, + sample_weight: torch.Tensor | None = None, +) -> torch.Tensor: + loss_p = _smooth_l1_vector(pred[:, 0], target[:, 0], beta=float(cfg.huber_beta)) + loss_d = _smooth_l1_vector(pred[:, 1], target[:, 1], beta=float(cfg.huber_beta)) + loss_vec = float(cfg.w_pressure) * loss_p + float(cfg.w_derivative) * loss_d + if sample_weight is None: + return loss_vec.mean() + w = sample_weight.to(loss_vec.device).reshape(-1).clamp_min(0.0) + return (loss_vec * w).sum() / torch.clamp(w.sum(), min=1.0e-12) def _evaluate(model: TimeConditionedSurrogate, loader: DataLoader, cfg: TimeConditionedTrainConfig) -> float: @@ -81,7 +117,7 @@ def _evaluate(model: TimeConditionedSurrogate, loader: DataLoader, cfg: TimeCond total = 0.0 total_n = 0 with torch.no_grad(): - for params_x, schedule_x, time_x, y in loader: + for params_x, schedule_x, time_x, y, _sample_weight in loader: params_x = params_x.to(cfg.device) schedule_x = schedule_x.to(cfg.device) time_x = time_x.to(cfg.device) @@ -94,6 +130,59 @@ def _evaluate(model: TimeConditionedSurrogate, loader: DataLoader, cfg: TimeCond return total / max(total_n, 1) +def _raw_params_from_processed_split(data: dict, split: str) -> dict[str, np.ndarray]: + key = f"X_params_{split}" + features = 
# Inclusive (low, high) bounds a raw parameter must satisfy for a sample to
# count as inside the "pso" domain. Kept in one table so the domain mask and
# the required-name check below cannot drift apart.
_PSO_DOMAIN_BOUNDS = {
    "k": (0.001, 10.0),
    "skin": (-10.0, 10.0),
    "wellboreC": (1.0e-4, 2.0),
    "phi": (0.01, 0.5),
    "h": (2.0, 50.0),
}


def _build_sample_weight(data: dict, cfg: "TimeConditionedTrainConfig", split: str = "train") -> np.ndarray:
    """Build one float32 loss weight per sample of the given split.

    Modes, taken case-insensitively from ``cfg.sample_weight_mode``:

    * ``"none"`` / ``"off"`` / ``"false"`` (or an empty value): all ones.
    * ``"pso_domain"``: ``cfg.pso_inside_weight`` for samples whose raw
      parameters all lie within ``_PSO_DOMAIN_BOUNDS``, otherwise
      ``cfg.pso_outside_weight``.
    * ``"pso_domain_risk"``: as above, then in-domain samples with
      ``skin < -5`` and ``wellboreC > 0.1`` are raised to at least
      ``cfg.risk_weight``, and in-domain samples with ``skin < -8`` to at
      least ``cfg.skin_lt_minus8_weight``.

    In the two weighted modes the result is finally clipped to
    ``[cfg.sample_weight_min, cfg.sample_weight_max]``.

    Args:
        data: Processed dataset mapping; must contain ``X_params_<split>``.
        cfg: Object exposing the ``sample_weight_*``, ``pso_*_weight``,
            ``risk_weight`` and ``skin_lt_minus8_weight`` attributes.
        split: Split suffix used to look up ``X_params_<split>``.

    Returns:
        A 1-D float32 array with one weight per row of ``X_params_<split>``.

    Raises:
        ValueError: If the mode is unknown, or if
            ``cfg.sample_weight_min`` exceeds ``cfg.sample_weight_max``.
        KeyError: If a parameter required by the domain bounds is missing
            from the recovered raw parameters.
    """
    mode = str(cfg.sample_weight_mode or "none").lower()
    n = int(data[f"X_params_{split}"].shape[0])
    if mode in {"none", "off", "false"}:
        return np.ones((n,), dtype=np.float32)
    if mode not in {"pso_domain", "pso_domain_risk"}:
        raise ValueError(f"Unknown sample_weight_mode={cfg.sample_weight_mode!r}")

    # np.clip does not check its bounds; an inverted range would silently
    # flatten every weight to the max, so reject it before any work is done.
    w_min = float(cfg.sample_weight_min)
    w_max = float(cfg.sample_weight_max)
    if w_min > w_max:
        raise ValueError(f"sample_weight_min={w_min} must not exceed sample_weight_max={w_max}")

    params = _raw_params_from_processed_split(data, split)
    missing = [name for name in _PSO_DOMAIN_BOUNDS if name not in params]
    if missing:
        raise KeyError(f"sample weighting requires raw parameters {missing}; available: {sorted(params)}")

    pso_mask = np.ones((n,), dtype=bool)
    for name, (low, high) in _PSO_DOMAIN_BOUNDS.items():
        values = params[name]
        pso_mask &= (values >= low) & (values <= high)

    weight = np.where(pso_mask, float(cfg.pso_inside_weight), float(cfg.pso_outside_weight)).astype(np.float32)
    if mode == "pso_domain_risk":
        skin = params["skin"]
        risk = pso_mask & (skin < -5.0) & (params["wellboreC"] > 0.1)
        skin_extreme = pso_mask & (skin < -8.0)
        # np.maximum only ever raises a weight, so a sample matching both
        # masks ends up with the larger of the two boosts.
        weight[risk] = np.maximum(weight[risk], float(cfg.risk_weight))
        weight[skin_extreme] = np.maximum(weight[skin_extreme], float(cfg.skin_lt_minus8_weight))

    return np.clip(weight, w_min, w_max).astype(np.float32)
"n_weight_lt_1": int(np.sum(w < 1.0)), + } + + def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None: cfg.output_dir.mkdir(parents=True, exist_ok=True) set_global_seed(int(cfg.seed)) @@ -105,7 +194,16 @@ def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None: raise KeyError(f"processed dataset is missing time-conditioned fields: {missing}") curve_layout = infer_curve_layout(data) - train_ds = PointCurveDataset(data["X_params_train"], data["X_schedule_train"], data["X_time_train"], data["Y_curve_train"], curve_layout) + train_weight = _build_sample_weight(data, cfg, split="train") + train_weight_summary = _summarize_sample_weight(train_weight) + train_ds = PointCurveDataset( + data["X_params_train"], + data["X_schedule_train"], + data["X_time_train"], + data["Y_curve_train"], + curve_layout, + sample_weight=train_weight, + ) val_ds = PointCurveDataset(data["X_params_val"], data["X_schedule_val"], data["X_time_val"], data["Y_curve_val"], curve_layout) test_ds = PointCurveDataset(data["X_params_test"], data["X_schedule_test"], data["X_time_test"], data["Y_curve_test"], curve_layout) @@ -139,20 +237,22 @@ def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None: f"schedule={data['X_schedule_train'].shape[1]}, time={data['X_time_train'].shape[-1]}" ) print(f" curve_time_source={data.get('meta', {}).get('curve_time_source', 'unknown')}") + print(f" sample_weight_mode={cfg.sample_weight_mode}, sample_weight={train_weight_summary}") for epoch in range(1, int(cfg.epochs) + 1): model.train() total = 0.0 total_n = 0 - for params_x, schedule_x, time_x, y in train_loader: + for params_x, schedule_x, time_x, y, sample_weight in train_loader: params_x = params_x.to(cfg.device) schedule_x = schedule_x.to(cfg.device) time_x = time_x.to(cfg.device) y = y.to(cfg.device) + sample_weight = sample_weight.to(cfg.device) optimizer.zero_grad() pred = model(params_x, time_x, schedule_x if cfg.use_schedule else None) - loss = _loss(pred, y, cfg) + loss = 
_loss(pred, y, cfg, sample_weight=sample_weight) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() @@ -181,6 +281,8 @@ def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None: "curve_layout": curve_layout, "processed_path": str(cfg.processed_path), "seed": int(cfg.seed), + "sample_weight_mode": str(cfg.sample_weight_mode), + "sample_weight_summary": train_weight_summary, }, best_path, ) @@ -192,7 +294,16 @@ def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None: (cfg.output_dir / "history.json").write_text(json.dumps(history, indent=2, ensure_ascii=False), encoding="utf-8") (cfg.output_dir / "metrics.json").write_text( - json.dumps({"best_val_loss": best_val, "test_loss": test_loss}, indent=2, ensure_ascii=False), + json.dumps( + { + "best_val_loss": best_val, + "test_loss": test_loss, + "sample_weight_mode": str(cfg.sample_weight_mode), + "sample_weight_summary": train_weight_summary, + }, + indent=2, + ensure_ascii=False, + ), encoding="utf-8", ) print(f"[Final] test={test_loss:.6f}")