You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
199 lines
7.8 KiB
Python
199 lines
7.8 KiB
Python
|
3 weeks ago
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import json
|
||
|
|
import random
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import joblib
|
||
|
|
import numpy as np
|
||
|
|
import torch
|
||
|
|
import torch.nn.functional as F
|
||
|
|
from torch.utils.data import DataLoader, Dataset
|
||
|
|
|
||
|
|
from src.models.time_conditioned_surrogate import TimeConditionedSurrogate
|
||
|
|
from src.training.train_forward import get_part_slices, infer_curve_layout
|
||
|
|
|
||
|
|
|
||
|
|
class PointCurveDataset(Dataset):
    """Expose per-sample curves as individual (sample, time-point) examples.

    Every sample contributes ``n_time`` items. Flat index ``idx`` maps to
    sample ``idx // n_time`` and time index ``idx % n_time``. Each item is the
    tuple ``(params_x[s], schedule_x[s], time_x[s, t], y[s, t])``, where
    ``y[s, t]`` holds ``(log_pressure, log_derivative)`` at that time point.

    Args:
        params_x: per-sample parameter features; axis 0 indexes samples.
        schedule_x: per-sample schedule features; axis 0 indexes samples.
        time_x: time inputs; axis 0 indexes samples and axis 1 indexes time
            points (any trailing axes are returned unchanged per item).
        curve_y: curve targets, one row per sample. The ``"log_pressure"`` and
            ``"log_derivative"`` column ranges come from ``get_part_slices(layout)``.
        layout: curve layout forwarded to ``get_part_slices``.

    Raises:
        ValueError: if the inputs disagree on the number of samples, or if
            either curve part does not provide exactly one value per time
            point of ``time_x`` (which would otherwise misalign targets).
    """

    def __init__(self, params_x: np.ndarray, schedule_x: np.ndarray, time_x: np.ndarray, curve_y: np.ndarray, layout: dict):
        self.params_x = torch.tensor(params_x, dtype=torch.float32)
        self.schedule_x = torch.tensor(schedule_x, dtype=torch.float32)
        self.time_x = torch.tensor(time_x, dtype=torch.float32)

        slices = get_part_slices(layout)
        p = curve_y[:, slices["log_pressure"]]
        d = curve_y[:, slices["log_derivative"]]

        self.n_samples = int(self.params_x.shape[0])
        self.n_time = int(self.time_x.shape[1])

        # Every per-sample input is indexed with the same sample index in
        # __getitem__, so all of them must have the same number of rows.
        row_counts = {
            "params_x": self.n_samples,
            "schedule_x": int(self.schedule_x.shape[0]),
            "time_x": int(self.time_x.shape[0]),
            "curve_y": int(curve_y.shape[0]),
        }
        if len(set(row_counts.values())) != 1:
            raise ValueError(f"inputs disagree on the number of samples: {row_counts}")

        # y[s, t] is paired with time_x[s, t]; a width mismatch would either
        # fail mid-epoch or silently pair targets with the wrong time point.
        if p.shape[1] != self.n_time or d.shape[1] != self.n_time:
            raise ValueError(
                "curve parts must have one value per time point: "
                f"time_x has {self.n_time}, log_pressure has {p.shape[1]}, "
                f"log_derivative has {d.shape[1]}"
            )

        # Last axis: channel 0 = log_pressure, channel 1 = log_derivative.
        self.y = torch.tensor(np.stack([p, d], axis=-1), dtype=torch.float32)

    def __len__(self) -> int:
        """Total number of (sample, time-point) pairs."""
        return self.n_samples * self.n_time

    def __getitem__(self, idx: int):
        """Return ``(params, schedule, time, target)`` for flat index ``idx``."""
        sample_idx, time_idx = divmod(idx, self.n_time)
        return (
            self.params_x[sample_idx],
            self.schedule_x[sample_idx],
            self.time_x[sample_idx, time_idx],
            self.y[sample_idx, time_idx],
        )
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class TimeConditionedTrainConfig:
    """Settings consumed by ``train_time_conditioned``.

    ``processed_path`` and ``output_dir`` may be given as strings; they are
    converted to :class:`pathlib.Path` in ``__post_init__`` because the
    training code calls ``Path`` methods (``mkdir``, ``/``) on them.
    """

    # joblib file holding the processed train/val/test arrays.
    processed_path: Path
    # Directory receiving the best checkpoint, history.json and metrics.json.
    output_dir: Path
    # Seed for Python, NumPy, torch and the training DataLoader shuffle.
    seed: int = 42
    batch_size: int = 4096
    epochs: int = 120
    # AdamW learning rate and weight decay.
    lr: float = 1.0e-3
    weight_decay: float = 1.0e-4
    # Model size/regularisation forwarded to TimeConditionedSurrogate.
    hidden_dim: int = 256
    n_blocks: int = 4
    dropout: float = 0.05
    # Weights applied to the pressure (channel 0) and derivative (channel 1) losses.
    w_pressure: float = 1.0
    w_derivative: float = 2.0
    # Smooth-L1 ``beta`` (transition point between quadratic and linear regions).
    huber_beta: float = 0.05
    # When False, the model is called with ``None`` instead of the schedule features.
    use_schedule: bool = True
    # NOTE: this default is evaluated once, when the module is imported.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __post_init__(self) -> None:
        # Accept str paths (e.g. straight from argparse) as well as Path objects.
        self.processed_path = Path(self.processed_path)
        self.output_dir = Path(self.output_dir)
|
||
|
|
|
||
|
|
|
||
|
|
def set_global_seed(seed: int) -> None:
    """Seed every RNG the training run draws from.

    Seeds Python's ``random``, NumPy's global RNG and torch's default RNG,
    then additionally seeds all CUDA devices when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
||
|
|
|
||
|
|
|
||
|
|
def _loss(pred: torch.Tensor, target: torch.Tensor, cfg: TimeConditionedTrainConfig) -> torch.Tensor:
|
||
|
|
loss_p = F.smooth_l1_loss(pred[:, 0], target[:, 0], beta=float(cfg.huber_beta), reduction="mean")
|
||
|
|
loss_d = F.smooth_l1_loss(pred[:, 1], target[:, 1], beta=float(cfg.huber_beta), reduction="mean")
|
||
|
|
return float(cfg.w_pressure) * loss_p + float(cfg.w_derivative) * loss_d
|
||
|
|
|
||
|
|
|
||
|
|
def _evaluate(model: TimeConditionedSurrogate, loader: DataLoader, cfg: TimeConditionedTrainConfig) -> float:
    """Return the sample-weighted mean of ``_loss`` over every batch of ``loader``.

    Switches ``model`` to eval mode (and leaves it there) and runs without
    autograd. Each batch's mean loss is weighted by its batch size, so the
    result is a per-example average. Returns 0.0 when ``loader`` is empty.
    """
    model.eval()
    device = cfg.device
    schedule_enabled = cfg.use_schedule
    weighted_sum = 0.0
    n_seen = 0
    with torch.no_grad():
        for batch in loader:
            params_x, schedule_x, time_x, y = (tensor.to(device) for tensor in batch)
            pred = model(params_x, time_x, schedule_x if schedule_enabled else None)
            batch_size = int(y.shape[0])
            weighted_sum += _loss(pred, y, cfg).item() * batch_size
            n_seen += batch_size
    return weighted_sum / max(n_seen, 1)
|
||
|
|
|
||
|
|
|
||
|
|
def _split_keys(split: str) -> tuple[str, str, str, str]:
    """Return the processed-data keys ``(params, schedule, time, curve)`` for ``split``."""
    return (f"X_params_{split}", f"X_schedule_{split}", f"X_time_{split}", f"Y_curve_{split}")


def _train_one_epoch(
    model: TimeConditionedSurrogate,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    cfg: TimeConditionedTrainConfig,
) -> float:
    """Run one optimisation pass over ``loader`` and return the sample-weighted mean loss.

    Gradients are clipped to a global norm of 1.0 before every optimizer step.
    Returns 0.0 if ``loader`` yields no batches.
    """
    model.train()
    total = 0.0
    total_n = 0
    for params_x, schedule_x, time_x, y in loader:
        params_x = params_x.to(cfg.device)
        schedule_x = schedule_x.to(cfg.device)
        time_x = time_x.to(cfg.device)
        y = y.to(cfg.device)

        optimizer.zero_grad()
        pred = model(params_x, time_x, schedule_x if cfg.use_schedule else None)
        loss = _loss(pred, y, cfg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Weight each batch mean by its size so the epoch value is a per-example mean.
        bs = int(y.shape[0])
        total += loss.item() * bs
        total_n += bs
    return total / max(total_n, 1)


def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None:
    """Train a ``TimeConditionedSurrogate`` and score its best-validation weights on test.

    Loads the processed dataset from ``cfg.processed_path`` with joblib and
    trains with AdamW for ``cfg.epochs`` epochs. After every epoch whose
    validation loss beats the best so far, it writes a checkpoint to
    ``cfg.output_dir / "time_conditioned_surrogate_best.pt"``.
    The best weights from *this* run are then evaluated on the test split,
    and ``history.json`` and ``metrics.json`` are written to ``cfg.output_dir``.

    Raises:
        KeyError: if any params/schedule/time/curve array is missing for the
            train, val or test split.
        RuntimeError: if no epoch produced a validation loss below ``inf``
            (e.g. ``cfg.epochs < 1`` or a NaN loss throughout). There are then
            no best weights to evaluate; ``history.json`` is still written.
    """
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    set_global_seed(int(cfg.seed))

    data = joblib.load(cfg.processed_path)
    # Check every array used below up front, not just the time inputs, so a
    # malformed file fails before any work is done.
    required = [key for split in ("train", "val", "test") for key in _split_keys(split)]
    missing = [key for key in required if key not in data]
    if missing:
        raise KeyError(f"processed dataset is missing time-conditioned fields: {missing}")

    curve_layout = infer_curve_layout(data)
    train_ds, val_ds, test_ds = (
        PointCurveDataset(*(data[key] for key in _split_keys(split)), curve_layout)
        for split in ("train", "val", "test")
    )

    # A dedicated, seeded generator makes the training shuffle order reproducible.
    generator = torch.Generator()
    generator.manual_seed(int(cfg.seed))
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, generator=generator)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False)

    param_dim = int(data["X_params_train"].shape[1])
    schedule_dim = int(data["X_schedule_train"].shape[1])
    time_dim = int(data["X_time_train"].shape[-1])
    model = TimeConditionedSurrogate(
        param_dim=param_dim,
        schedule_dim=schedule_dim,
        time_dim=time_dim,
        hidden_dim=int(cfg.hidden_dim),
        n_blocks=int(cfg.n_blocks),
        dropout=float(cfg.dropout),
        use_schedule=bool(cfg.use_schedule),
    ).to(cfg.device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg.lr), weight_decay=float(cfg.weight_decay))
    best_val = float("inf")
    best_state: dict[str, torch.Tensor] | None = None
    best_path = cfg.output_dir / "time_conditioned_surrogate_best.pt"
    history: list[dict] = []

    print("Time-conditioned training config:")
    print(f" processed={cfg.processed_path}")
    print(f" output_dir={cfg.output_dir}")
    print(f" device={cfg.device}, batch_size={cfg.batch_size}, epochs={cfg.epochs}")
    print(f" dims: param={param_dim}, schedule={schedule_dim}, time={time_dim}")
    print(f" curve_time_source={data.get('meta', {}).get('curve_time_source', 'unknown')}")

    for epoch in range(1, int(cfg.epochs) + 1):
        train_loss = _train_one_epoch(model, train_loader, optimizer, cfg)
        val_loss = _evaluate(model, val_loader, cfg)
        history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})
        print(f"[Epoch {epoch:03d}] train={train_loss:.6f} val={val_loss:.6f}")

        # NaN never compares less than best_val, so a diverged epoch is never kept.
        if val_loss < best_val:
            best_val = val_loss
            # Snapshot in memory: the final test run uses this copy instead of
            # re-reading best_path, which could be a stale file from an earlier run.
            best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
            torch.save(
                {
                    "model_state_dict": best_state,
                    "param_dim": param_dim,
                    "schedule_dim": schedule_dim,
                    "time_dim": time_dim,
                    "hidden_dim": int(cfg.hidden_dim),
                    "n_blocks": int(cfg.n_blocks),
                    "dropout": float(cfg.dropout),
                    "use_schedule": bool(cfg.use_schedule),
                    "curve_layout": curve_layout,
                    "processed_path": str(cfg.processed_path),
                    "seed": int(cfg.seed),
                },
                best_path,
            )
            print(f" -> best model saved to: {best_path}")

    (cfg.output_dir / "history.json").write_text(json.dumps(history, indent=2, ensure_ascii=False), encoding="utf-8")

    if best_state is None:
        raise RuntimeError(
            f"no epoch produced a validation loss below inf (epochs={cfg.epochs}); "
            f"refusing to evaluate a possibly stale checkpoint at {best_path}"
        )

    model.load_state_dict(best_state)
    test_loss = _evaluate(model, test_loader, cfg)

    (cfg.output_dir / "metrics.json").write_text(
        json.dumps({"best_val_loss": best_val, "test_loss": test_loss}, indent=2, ensure_ascii=False),
        encoding="utf-8",
    )
    print(f"[Final] test={test_loss:.6f}")
|