You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/src/training/train_time_conditioned.py

294 lines
12 KiB
Python

from __future__ import annotations
import json
import random
from dataclasses import dataclass
from pathlib import Path
import joblib
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from src.data.param_features import inverse_transform_param_features
from src.models.time_conditioned_surrogate import TimeConditionedSurrogate
from src.training.train_forward import get_part_slices, infer_curve_layout
class PointCurveDataset(Dataset):
    """Flatten per-sample curves into one dataset item per (sample, time point).

    Item ``idx`` maps to sample ``idx // n_time`` and time index ``idx % n_time``
    and yields ``(params, schedule, time features, target, sample_weight)``.
    The target stacks the ``log_pressure`` and ``log_derivative`` parts of
    ``curve_y`` (selected via ``get_part_slices(layout)``) along the last axis,
    so each item's target holds one value per part.

    ``time_x`` is indexed as ``time_x[sample, time, ...]``; ``n_time`` is taken
    from its second axis.

    Raises:
        ValueError: if the inputs disagree on the number of samples, if the
            stacked curve does not have exactly ``n_time`` points, or if
            ``sample_weight`` has the wrong length.
    """

    def __init__(
        self,
        params_x: np.ndarray,
        schedule_x: np.ndarray,
        time_x: np.ndarray,
        curve_y: np.ndarray,
        layout: dict,
        sample_weight: np.ndarray | None = None,
    ):
        self.params_x = torch.tensor(params_x, dtype=torch.float32)
        self.schedule_x = torch.tensor(schedule_x, dtype=torch.float32)
        self.time_x = torch.tensor(time_x, dtype=torch.float32)
        self.n_samples = int(self.params_x.shape[0])
        self.n_time = int(self.time_x.shape[1])

        # All per-sample arrays must describe the same samples; otherwise the
        # mismatch only shows up later as an IndexError inside the DataLoader.
        row_counts = {
            "schedule_x": int(self.schedule_x.shape[0]),
            "time_x": int(self.time_x.shape[0]),
            "curve_y": int(np.shape(curve_y)[0]),
        }
        for name, rows in row_counts.items():
            if rows != self.n_samples:
                raise ValueError(f"{name} length mismatch: {rows} != {self.n_samples}")

        slices = get_part_slices(layout)
        p = curve_y[:, slices["log_pressure"]]
        d = curve_y[:, slices["log_derivative"]]
        self.y = torch.tensor(np.stack([p, d], axis=-1), dtype=torch.float32)
        # __getitem__ pairs time_x[:, t] with y[:, t]. Any extra curve points
        # would be silently ignored, and too few would raise at item access.
        n_curve = int(self.y.shape[1])
        if n_curve != self.n_time:
            raise ValueError(f"curve length mismatch: {n_curve} != n_time={self.n_time}")

        if sample_weight is None:
            sample_weight = np.ones((self.n_samples,), dtype=np.float32)
        sample_weight = np.asarray(sample_weight, dtype=np.float32).reshape(-1)
        if sample_weight.shape[0] != self.n_samples:
            raise ValueError(f"sample_weight length mismatch: {sample_weight.shape[0]} != {self.n_samples}")
        self.sample_weight = torch.tensor(sample_weight, dtype=torch.float32)

    def __len__(self) -> int:
        """Total number of (sample, time point) pairs."""
        return self.n_samples * self.n_time

    def __getitem__(self, idx: int):
        """Return the tensors for the sample and time point encoded by ``idx``."""
        sample_idx = idx // self.n_time
        time_idx = idx % self.n_time
        return (
            self.params_x[sample_idx],
            self.schedule_x[sample_idx],
            self.time_x[sample_idx, time_idx],
            self.y[sample_idx, time_idx],
            self.sample_weight[sample_idx],
        )
@dataclass
class TimeConditionedTrainConfig:
    """Settings consumed by ``train_time_conditioned``.

    ``processed_path`` and ``output_dir`` may be passed as ``str`` or ``Path``;
    both are normalised to ``Path`` in ``__post_init__`` so that callers such
    as argparse front-ends can pass plain strings.
    """

    # joblib file holding the processed splits (X_params_*, X_schedule_*, X_time_*, Y_curve_*, ...).
    processed_path: Path
    # Directory for the best checkpoint, history.json and metrics.json (created if missing).
    output_dir: Path
    seed: int = 42
    batch_size: int = 4096
    epochs: int = 120
    # AdamW optimiser settings.
    lr: float = 1.0e-3
    weight_decay: float = 1.0e-4
    # Surrogate architecture.
    hidden_dim: int = 256
    n_blocks: int = 4
    dropout: float = 0.05
    # Smooth-L1 loss: per-channel weights and the beta threshold.
    w_pressure: float = 1.0
    w_derivative: float = 2.0
    huber_beta: float = 0.05
    # Whether schedule features are passed to the model.
    use_schedule: bool = True
    # "none"/"off"/"false" (uniform weights) or "risk_region" (up-weight selected skin/wellboreC samples).
    sample_weight_mode: str = "none"
    risk_weight: float = 2.5
    skin_lt_minus8_weight: float = 3.5
    # Final per-sample weights are clipped into [sample_weight_min, sample_weight_max].
    sample_weight_min: float = 1.0
    sample_weight_max: float = 4.0
    # Evaluated once, when the class body runs (i.e. at module import).
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self) -> None:
        # Downstream code relies on Path methods (mkdir, "/"), so accept str too.
        self.processed_path = Path(self.processed_path)
        self.output_dir = Path(self.output_dir)
def set_global_seed(seed: int) -> None:
    """Seed Python ``random``, NumPy and PyTorch, plus all CUDA devices when CUDA is available."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def _smooth_l1_vector(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor:
    """Element-wise (unreduced) Smooth-L1 loss between ``pred`` and ``target`` with threshold ``beta``."""
    beta_value = float(beta)
    elementwise = F.smooth_l1_loss(pred, target, reduction="none", beta=beta_value)
    return elementwise
def _loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    cfg: TimeConditionedTrainConfig,
    sample_weight: torch.Tensor | None = None,
) -> torch.Tensor:
    """Combine per-channel Smooth-L1 losses into one scalar.

    Column 0 of ``pred``/``target`` is weighted by ``cfg.w_pressure`` and
    column 1 by ``cfg.w_derivative``, both using ``cfg.huber_beta``.
    Without ``sample_weight`` the per-row loss is averaged. With it, the result
    is a weighted mean: negative weights are clamped to zero, and the
    denominator is floored at 1e-12, so all-zero weights give 0.
    """
    beta = float(cfg.huber_beta)
    pressure_term = _smooth_l1_vector(pred[:, 0], target[:, 0], beta=beta)
    derivative_term = _smooth_l1_vector(pred[:, 1], target[:, 1], beta=beta)
    per_row = float(cfg.w_pressure) * pressure_term + float(cfg.w_derivative) * derivative_term
    if sample_weight is not None:
        weights = sample_weight.to(per_row.device).reshape(-1).clamp_min(0.0)
        normaliser = torch.clamp(weights.sum(), min=1.0e-12)
        return (per_row * weights).sum() / normaliser
    return per_row.mean()
def _evaluate(model: TimeConditionedSurrogate, loader: DataLoader, cfg: TimeConditionedTrainConfig) -> float:
    """Score ``model`` on ``loader`` without gradients and return the mean loss.

    Each batch's ``_loss`` value is computed without sample weights, so the
    weights yielded by the loader are ignored. That value is scaled by the
    batch size, and the sum is divided by the total number of rows (or by 1
    for an empty loader). Leaves ``model`` in eval mode.
    """
    model.eval()
    device = cfg.device
    loss_sum = 0.0
    row_count = 0
    with torch.no_grad():
        for params_x, schedule_x, time_x, y, _ignored_weight in loader:
            params_x, schedule_x, time_x, y = (
                tensor.to(device) for tensor in (params_x, schedule_x, time_x, y)
            )
            schedule_arg = schedule_x if cfg.use_schedule else None
            prediction = model(params_x, time_x, schedule_arg)
            batch_loss = _loss(prediction, y, cfg)
            batch_rows = int(y.shape[0])
            loss_sum += float(batch_loss.detach().cpu()) * batch_rows
            row_count += batch_rows
    return loss_sum / max(row_count, 1)
def _raw_params_from_processed_split(data: dict, split: str) -> dict[str, np.ndarray]:
    """Return the un-scaled parameter columns of ``X_params_<split>`` keyed by name.

    The features are passed through ``data["scaler_params"].inverse_transform``
    (presumably a fitted sklearn-style scaler — confirm against the
    preprocessing code). They are then passed through
    ``inverse_transform_param_features`` using ``meta["param_feature_transform"]``.
    Column names come from ``meta["param_names"]``, falling back to
    ``["k", "skin", "wellboreC", "phi", "h", "Cf"]``. Only as many names as
    there are columns are used, and each column is returned as ``float64``.
    A missing or ``None`` ``meta`` entry is treated as an empty dict.
    """
    # `data.get("meta", {})` would return None when the key is present but None.
    meta = data.get("meta") or {}
    features = data["scaler_params"].inverse_transform(data[f"X_params_{split}"])
    raw = inverse_transform_param_features(features, meta.get("param_feature_transform"))
    names = list(meta.get("param_names") or ["k", "skin", "wellboreC", "phi", "h", "Cf"])
    return {name: raw[:, idx].astype(np.float64) for idx, name in enumerate(names[: raw.shape[1]])}
def _build_sample_weight(data: dict, cfg: TimeConditionedTrainConfig, split: str = "train") -> np.ndarray:
    """Build one float32 training weight per row of ``X_params_<split>``.

    Modes (case-insensitive; a falsy mode means ``"none"``):

    * ``"none"``, ``"off"``, ``"false"``: all weights are 1.
    * ``"risk_region"``: start from 1. Rows with ``skin < -5`` and
      ``wellboreC > 0.1`` are raised to at least ``cfg.risk_weight``, and rows
      with ``skin < -8`` to at least ``cfg.skin_lt_minus8_weight``. Everything
      is then clipped to ``[cfg.sample_weight_min, cfg.sample_weight_max]``.

    Raises:
        ValueError: for any other mode.
    """
    mode = str(cfg.sample_weight_mode or "none").lower()
    n_rows = int(data[f"X_params_{split}"].shape[0])
    weight = np.ones((n_rows,), dtype=np.float32)
    if mode in {"none", "off", "false"}:
        return weight
    if mode != "risk_region":
        raise ValueError(f"Unknown sample_weight_mode={cfg.sample_weight_mode!r}")

    raw = _raw_params_from_processed_split(data, split)
    skin = raw["skin"]
    risk_mask = (skin < -5.0) & (raw["wellboreC"] > 0.1)
    extreme_skin_mask = skin < -8.0
    for mask, floor in ((risk_mask, cfg.risk_weight), (extreme_skin_mask, cfg.skin_lt_minus8_weight)):
        weight[mask] = np.maximum(weight[mask], float(floor))
    clipped = np.clip(weight, float(cfg.sample_weight_min), float(cfg.sample_weight_max))
    return clipped.astype(np.float32)
def _summarize_sample_weight(sample_weight: np.ndarray) -> dict:
    """Summarise weights as min/mean/median/max plus counts of values above and below 1.

    The input is flattened and cast to float32 first. All statistics are plain
    Python ``float``/``int`` values, so the dict is JSON-serialisable.
    """
    flat = np.asarray(sample_weight, dtype=np.float32).ravel()
    reducers = (("min", np.min), ("mean", np.mean), ("median", np.median), ("max", np.max))
    summary = {label: float(reduce_fn(flat)) for label, reduce_fn in reducers}
    summary["n_weight_gt_1"] = int(np.count_nonzero(flat > 1.0))
    summary["n_weight_lt_1"] = int(np.count_nonzero(flat < 1.0))
    return summary
def _train_one_epoch(
    model: TimeConditionedSurrogate,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    cfg: TimeConditionedTrainConfig,
) -> float:
    """Run one optimisation pass over ``loader`` and return its mean training loss.

    Uses the sample-weighted ``_loss`` and clips the global gradient norm to 1.0.
    Returns the batch-size-weighted mean of the per-batch losses, or 0.0 for an
    empty loader.
    """
    model.train()
    total = 0.0
    total_n = 0
    for params_x, schedule_x, time_x, y, sample_weight in loader:
        params_x = params_x.to(cfg.device)
        schedule_x = schedule_x.to(cfg.device)
        time_x = time_x.to(cfg.device)
        y = y.to(cfg.device)
        sample_weight = sample_weight.to(cfg.device)
        optimizer.zero_grad()
        pred = model(params_x, time_x, schedule_x if cfg.use_schedule else None)
        loss = _loss(pred, y, cfg, sample_weight=sample_weight)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        bs = int(y.shape[0])
        total += float(loss.detach().cpu()) * bs
        total_n += bs
    return total / max(total_n, 1)


def _checkpoint_payload(
    state_dict: dict,
    data: dict,
    cfg: TimeConditionedTrainConfig,
    curve_layout: dict,
    weight_summary: dict,
) -> dict:
    """Assemble the checkpoint dict: model weights plus the metadata needed to rebuild the model."""
    return {
        "model_state_dict": state_dict,
        "param_dim": int(data["X_params_train"].shape[1]),
        "schedule_dim": int(data["X_schedule_train"].shape[1]),
        "time_dim": int(data["X_time_train"].shape[-1]),
        "hidden_dim": int(cfg.hidden_dim),
        "n_blocks": int(cfg.n_blocks),
        "dropout": float(cfg.dropout),
        "use_schedule": bool(cfg.use_schedule),
        "curve_layout": curve_layout,
        "processed_path": str(cfg.processed_path),
        "seed": int(cfg.seed),
        "sample_weight_mode": str(cfg.sample_weight_mode),
        "sample_weight_summary": weight_summary,
    }


def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None:
    """Train a ``TimeConditionedSurrogate`` on the processed dataset at ``cfg.processed_path``.

    Each epoch trains on the train split and evaluates the val split. Whenever
    the validation loss improves, the weights are kept in memory and written
    to ``<output_dir>/time_conditioned_surrogate_best.pt``. After training:

    * ``history.json`` is written;
    * the best in-memory weights are restored (never a possibly stale file on
      disk);
    * the test split is evaluated and ``metrics.json`` is written.

    Raises:
        KeyError: if the processed dataset lacks ``X_time_{train,val,test}``.
        RuntimeError: if no epoch produced a validation loss below +inf
            (e.g. ``epochs <= 0`` or only NaN/inf losses). In that case there
            is no best model from this run to test.
    """
    import copy  # local: only needed for snapshotting the best weights

    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    set_global_seed(int(cfg.seed))
    # NOTE: joblib.load unpickles the file; only point this at trusted data.
    data = joblib.load(cfg.processed_path)
    required = ["X_time_train", "X_time_val", "X_time_test"]
    missing = [key for key in required if key not in data]
    if missing:
        raise KeyError(f"processed dataset is missing time-conditioned fields: {missing}")
    curve_layout = infer_curve_layout(data)
    train_weight = _build_sample_weight(data, cfg, split="train")
    train_weight_summary = _summarize_sample_weight(train_weight)

    train_ds = PointCurveDataset(
        data["X_params_train"],
        data["X_schedule_train"],
        data["X_time_train"],
        data["Y_curve_train"],
        curve_layout,
        sample_weight=train_weight,
    )
    val_ds = PointCurveDataset(data["X_params_val"], data["X_schedule_val"], data["X_time_val"], data["Y_curve_val"], curve_layout)
    test_ds = PointCurveDataset(data["X_params_test"], data["X_schedule_test"], data["X_time_test"], data["Y_curve_test"], curve_layout)
    # A dedicated seeded generator makes the train shuffling order reproducible.
    generator = torch.Generator()
    generator.manual_seed(int(cfg.seed))
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, generator=generator)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False)

    model = TimeConditionedSurrogate(
        param_dim=int(data["X_params_train"].shape[1]),
        schedule_dim=int(data["X_schedule_train"].shape[1]),
        time_dim=int(data["X_time_train"].shape[-1]),
        hidden_dim=int(cfg.hidden_dim),
        n_blocks=int(cfg.n_blocks),
        dropout=float(cfg.dropout),
        use_schedule=bool(cfg.use_schedule),
    ).to(cfg.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg.lr), weight_decay=float(cfg.weight_decay))
    best_val = float("inf")
    best_state = None  # snapshot of the best weights produced by *this* run
    best_path = cfg.output_dir / "time_conditioned_surrogate_best.pt"
    history: list[dict] = []

    print("Time-conditioned training config:")
    print(f" processed={cfg.processed_path}")
    print(f" output_dir={cfg.output_dir}")
    print(f" device={cfg.device}, batch_size={cfg.batch_size}, epochs={cfg.epochs}")
    print(
        f" dims: param={data['X_params_train'].shape[1]}, "
        f"schedule={data['X_schedule_train'].shape[1]}, time={data['X_time_train'].shape[-1]}"
    )
    print(f" curve_time_source={data.get('meta', {}).get('curve_time_source', 'unknown')}")
    print(f" sample_weight_mode={cfg.sample_weight_mode}, sample_weight={train_weight_summary}")

    for epoch in range(1, int(cfg.epochs) + 1):
        train_loss = _train_one_epoch(model, train_loader, optimizer, cfg)
        val_loss = _evaluate(model, val_loader, cfg)
        history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})
        print(f"[Epoch {epoch:03d}] train={train_loss:.6f} val={val_loss:.6f}")
        # NaN never compares below best_val, so NaN epochs are never selected.
        if val_loss < best_val:
            best_val = val_loss
            # Deep copy: state_dict() tensors alias the live parameters.
            best_state = copy.deepcopy(model.state_dict())
            torch.save(
                _checkpoint_payload(best_state, data, cfg, curve_layout, train_weight_summary),
                best_path,
            )
            print(f" -> best model saved to: {best_path}")

    # Persist the training curve even if no best model was produced.
    (cfg.output_dir / "history.json").write_text(json.dumps(history, indent=2, ensure_ascii=False), encoding="utf-8")
    if best_state is None:
        # Reloading best_path here could silently pick up a checkpoint from an
        # earlier run in the same output_dir.
        raise RuntimeError(
            f"No best model was produced during this run (epochs={cfg.epochs}, best_val={best_val}); "
            f"refusing to evaluate the test split from a possibly stale {best_path}"
        )
    model.load_state_dict(best_state)
    test_loss = _evaluate(model, test_loader, cfg)
    (cfg.output_dir / "metrics.json").write_text(
        json.dumps(
            {
                "best_val_loss": best_val,
                "test_loss": test_loss,
                "sample_weight_mode": str(cfg.sample_weight_mode),
                "sample_weight_summary": train_weight_summary,
            },
            indent=2,
            ensure_ascii=False,
        ),
        encoding="utf-8",
    )
    print(f"[Final] test={test_loss:.6f}")