You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/src/training/train_forward.py

634 lines
23 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from __future__ import annotations
import json
import random
from dataclasses import dataclass
from pathlib import Path
import joblib
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from src.models.forward_surrogate import ForwardSurrogate
"""正演代理模型训练循环。
训练目标是拼接后的曲线向量:log_pressure、log_derivative 和辅助 slope。
默认损失主要约束压力与导数曲线;slope 通常只记录监控,权重为 0。
"""
class ForwardDataset(Dataset):
    """Expose three preprocessed numpy arrays as a PyTorch ``Dataset``.

    Every array is copied into a float32 tensor at construction time. Indexing
    returns the ``(params, schedule, curve)`` entries found at the same
    position in each tensor.
    """

    def __init__(self, params_x: np.ndarray, schedule_x: np.ndarray, curve_y: np.ndarray):
        def as_float32(array: np.ndarray) -> torch.Tensor:
            # torch.tensor always copies, so the dataset never aliases the inputs.
            return torch.tensor(array, dtype=torch.float32)

        self.params_x = as_float32(params_x)
        self.schedule_x = as_float32(schedule_x)
        self.curve_y = as_float32(curve_y)

    def __len__(self) -> int:
        # The dataset length is taken from the parameter tensor alone.
        return len(self.params_x)

    def __getitem__(self, idx: int):
        columns = (self.params_x, self.schedule_x, self.curve_y)
        return tuple(column[idx] for column in columns)
@dataclass
class TrainConfig:
    """Training run parameters used by scripts/train_forward.py."""

    # Path of the preprocessed dataset; handed to load_processed_dataset().
    processed_path: Path
    # Directory created by train_forward(); receives the best checkpoint,
    # history.json and metrics.json.
    output_dir: Path
    # Seed for set_global_seed() and for the training DataLoader's generator.
    seed: int = 42
    # Batch size shared by the train/val/test DataLoaders.
    batch_size: int = 128
    # Number of passes over the training loader.
    epochs: int = 100
    # Adam learning rate and weight decay.
    lr: float = 1e-3
    weight_decay: float = 1e-5
    # Forwarded unchanged to the ForwardSurrogate constructor.
    hidden_dim: int = 128
    dropout: float = 0.0
    # Coefficients of the individual terms summed in compute_weighted_loss():
    # point-wise pressure / derivative / slope losses ...
    w_pressure: float = 1.0
    w_derivative: float = 2.0
    w_slope: float = 0.0
    # ... per-sample mean-offset (bias) losses for pressure and derivative ...
    w_bias_pressure: float = 0.15
    w_bias_derivative: float = 0.05
    # ... first-difference (shape) loss on the derivative curve ...
    w_derivative_shape: float = 0.10
    # ... and the autofit-style losses computed on the un-scaled curves.
    w_autofit_pressure: float = 0.0
    w_autofit_derivative: float = 0.0
    # NOTE(review): this flag is not read anywhere in the visible part of this
    # file; the pressure/derivative terms always use the smooth-L1 loss.
    use_huber: bool = True
    # ``beta`` passed to smooth_l1_per_sample().
    huber_beta: float = 0.05
    # Per-sample re-weighting (see build_sample_weight()): on/off switch,
    # strength, and the clamp range applied to the resulting weights.
    use_sample_reweight: bool = True
    sample_reweight_alpha: float = 0.4
    sample_weight_min: float = 1.0
    sample_weight_max: float = 2.5
    # When False, model_forward() passes None instead of the schedule tensor.
    use_schedule: bool = True
    # Evaluated once, when the class body executes, not per instance.
    # NOTE(review): train_forward() additionally attaches ``curve_mean_raw`` and
    # ``curve_scale_raw`` tensors to the instance at runtime; they are not
    # declared as dataclass fields.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
def set_global_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    When CUDA is available, every CUDA device generator is seeded as well.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def load_processed_dataset(path: Path) -> dict:
    """Load the preprocessed joblib file and check the required split fields.

    Args:
        path: Location of the joblib file to load.

    Returns:
        The loaded object, unchanged.

    Raises:
        KeyError: If one of the nine ``{X_params,X_schedule,Y_curve}_{train,val,test}``
            entries is absent; the message names the first missing key.
    """
    # NOTE(review): joblib files are pickle-based; only load trusted files.
    data = joblib.load(path)
    required_keys = tuple(
        f"{prefix}_{split}"
        for split in ("train", "val", "test")
        for prefix in ("X_params", "X_schedule", "Y_curve")
    )
    first_missing = next((key for key in required_keys if key not in data), None)
    if first_missing is not None:
        raise KeyError(f"processed dataset 缺少字段: {first_missing}")
    return data
def infer_curve_layout(data: dict) -> dict:
    """Return the curve layout from metadata, or fall back to three equal parts.

    If ``data["meta"]["curve_layout"]`` is present it is returned as-is.
    Otherwise the curve dimension (``meta["curve_dim"]`` or the second axis of
    ``Y_curve_train``) is split evenly into ``log_pressure``,
    ``log_derivative`` and ``slope`` segments, in that order.

    Raises:
        ValueError: If no layout metadata exists and the curve dimension is
            not divisible by 3.
    """
    meta = data.get("meta", {})
    fallback_dim = data["Y_curve_train"].shape[1]
    curve_dim = int(meta.get("curve_dim", fallback_dim))

    explicit_layout = meta.get("curve_layout")
    if explicit_layout is not None:
        return explicit_layout

    n_time_points, remainder = divmod(curve_dim, 3)
    if remainder:
        raise ValueError(f"curve_dim={curve_dim} 不能按 3 段均分,且 meta 中没有 curve_layout")

    part_names = ("log_pressure", "log_derivative", "slope")
    parts = [
        {"name": part_name, "start": index * n_time_points, "end": (index + 1) * n_time_points}
        for index, part_name in enumerate(part_names)
    ]
    return {"n_time_points": int(n_time_points), "parts": parts}
def get_part_slices(curve_layout: dict) -> dict[str, slice]:
    """Map each part name in ``curve_layout["parts"]`` to ``slice(start, end)``.

    Names are coerced with ``str`` and the bounds with ``int``.
    """
    return {
        str(part["name"]): slice(int(part["start"]), int(part["end"]))
        for part in curve_layout["parts"]
    }
def smooth_l1_per_sample(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor:
    """Return the smooth-L1 (Huber-style) loss averaged over dim 1 of each row.

    Element-wise, with ``d = |pred - target|``, the loss is
    ``0.5 * d**2 / beta`` where ``d < beta`` and ``d - 0.5 * beta`` otherwise.

    The library implementation is used instead of a hand-written
    ``torch.where``: the manual form divides by ``beta`` in the branch that is
    not selected, which turns the gradient into NaN when ``beta == 0``. The
    library version falls back to a plain L1 loss in that case.

    Args:
        pred: Predictions; reduced over dim 1.
        target: Targets with the same shape as ``pred``.
        beta: Non-negative threshold between the quadratic and linear regimes.

    Returns:
        A tensor with one loss value per row.
    """
    elementwise = nn.functional.smooth_l1_loss(pred, target, reduction="none", beta=beta)
    return elementwise.mean(dim=1)
def mse_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the mean squared error of each row, averaged over dim 1."""
    residual = pred - target
    return torch.square(residual).mean(dim=1)
def l1_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute error of each row, averaged over dim 1."""
    residual = pred - target
    return residual.abs().mean(dim=1)
def first_diff(x: torch.Tensor) -> torch.Tensor:
    """Return differences between adjacent columns: ``x[:, i + 1] - x[:, i]``."""
    following = x[:, 1:]
    preceding = x[:, :-1]
    return following - preceding
def affine_restore(x_scaled: torch.Tensor, mean: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Undo a StandardScaler transform in torch: ``x_scaled * scale + mean``.

    Used so the autofit-style losses can be computed on un-scaled curves.
    """
    rescaled = x_scaled * scale
    return rescaled + mean
def autofit_curve_objective_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Torch version of the mixed relative/absolute curve objective.

    For every point the error is ``0.7 * relative + 0.3 * absolute``, where
    the relative error is normalised by ``max(|target|, |pred|, 1e-12)``.
    Points are weighted by ``1 / (1 + min(0.01 * |target|, 100))`` and the
    result is the square root of the weighted mean squared point error of
    each row.
    """
    target_mag = torch.abs(target)

    # Larger target magnitudes receive smaller weights; the factor saturates at 100.
    damping = torch.clamp(target_mag * 0.01, max=100.0)
    weight = 1.0 / (1.0 + damping)

    # Normaliser for the relative error, floored to avoid dividing by zero.
    floor = torch.full_like(target, 1e-12)
    scale = torch.maximum(torch.maximum(target_mag, torch.abs(pred)), floor)

    absolute_error = torch.abs(target - pred)
    relative_error = absolute_error / scale
    point_error = 0.7 * relative_error + 0.3 * absolute_error

    numerator = (weight * (point_error**2)).sum(dim=1)
    denominator = torch.clamp(weight.sum(dim=1), min=1e-12)
    return torch.sqrt(numerator / denominator)
def build_sample_weight(
    true_p: torch.Tensor,
    true_d: torch.Tensor,
    alpha: float,
    w_min: float,
    w_max: float,
) -> torch.Tensor:
    """Give samples with higher-magnitude curves a moderately larger weight.

    Each sample's mean absolute pressure and derivative level is divided by
    the batch mean of that level, the two ratios are averaged, and the weight
    is ``1 + alpha * (ratio - 1)`` clamped to ``[w_min, w_max]``.
    """

    def relative_level(curve: torch.Tensor) -> torch.Tensor:
        # Per-sample magnitude relative to the (detached) batch average.
        level = curve.abs().mean(dim=1)
        return level / (level.mean().detach() + 1e-6)

    raw = 0.5 * relative_level(true_p) + 0.5 * relative_level(true_d)
    return torch.clamp(1.0 + alpha * (raw - 1.0), min=w_min, max=w_max)
def compute_weighted_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    slices: dict[str, slice],
    curve_mean_raw: torch.Tensor,
    curve_scale_raw: torch.Tensor,
    huber_beta: float,
    w_pressure: float,
    w_derivative: float,
    w_slope: float,
    w_bias_pressure: float,
    w_bias_derivative: float,
    w_derivative_shape: float,
    w_autofit_pressure: float,
    w_autofit_derivative: float,
    use_sample_reweight: bool,
    sample_reweight_alpha: float,
    sample_weight_min: float,
    sample_weight_max: float,
):
    """Compute every loss component used to train and monitor the forward surrogate.

    ``pred`` and ``target`` are split column-wise with the ``log_pressure``,
    ``log_derivative`` and ``slope`` entries of ``slices``. Eight per-sample
    terms are combined with the ``w_*`` coefficients, optionally multiplied by
    per-sample weights, and averaged into the scalar ``"loss"``.

    Returns:
        A dict of scalar tensors: ``"loss"`` (the weighted total), the
        unweighted batch mean of each individual term, and the mean and
        maximum of the sample weights.
    """
    p_cols = slices["log_pressure"]
    d_cols = slices["log_derivative"]
    s_cols = slices["slope"]

    pred_p, pred_d, pred_s = pred[:, p_cols], pred[:, d_cols], pred[:, s_cols]
    true_p, true_d, true_s = target[:, p_cols], target[:, d_cols], target[:, s_cols]

    def to_raw(scaled: torch.Tensor, cols: slice) -> torch.Tensor:
        # Undo the curve scaler for the given column range.
        return affine_restore(
            scaled,
            curve_mean_raw[cols].unsqueeze(0),
            curve_scale_raw[cols].unsqueeze(0),
        )

    # Point-wise terms: smooth-L1 for pressure/derivative, MSE for slope.
    point_p = smooth_l1_per_sample(pred_p, true_p, beta=huber_beta)
    point_d = smooth_l1_per_sample(pred_d, true_d, beta=huber_beta)
    point_s = mse_per_sample(pred_s, true_s)

    # Bias terms penalise a vertical offset of the whole predicted curve;
    # the point-wise terms are left to capture the finer local shape.
    bias_p = l1_per_sample(
        pred_p.mean(dim=1, keepdim=True), true_p.mean(dim=1, keepdim=True)
    )
    bias_d = l1_per_sample(
        pred_d.mean(dim=1, keepdim=True), true_d.mean(dim=1, keepdim=True)
    )

    # Shape term: compare first differences of the derivative curve.
    shape_d = smooth_l1_per_sample(first_diff(pred_d), first_diff(true_d), beta=huber_beta)

    # Autofit-style terms are evaluated on the original curve scale. Their
    # default weight is 0, but they are always computed so they can be logged
    # and switched on for later experiments.
    autofit_p = autofit_curve_objective_per_sample(to_raw(pred_p, p_cols), to_raw(true_p, p_cols))
    autofit_d = autofit_curve_objective_per_sample(to_raw(pred_d, d_cols), to_raw(true_d, d_cols))

    weighted_terms = (
        (w_pressure, point_p),
        (w_derivative, point_d),
        (w_slope, point_s),
        (w_bias_pressure, bias_p),
        (w_bias_derivative, bias_d),
        (w_derivative_shape, shape_d),
        (w_autofit_pressure, autofit_p),
        (w_autofit_derivative, autofit_d),
    )
    total_vec = None
    for coefficient, term in weighted_terms:
        contribution = coefficient * term
        total_vec = contribution if total_vec is None else total_vec + contribution

    if not use_sample_reweight:
        sample_weight = torch.ones_like(total_vec)
    else:
        sample_weight = build_sample_weight(
            true_p=true_p,
            true_d=true_d,
            alpha=sample_reweight_alpha,
            w_min=sample_weight_min,
            w_max=sample_weight_max,
        )

    return {
        "loss": (total_vec * sample_weight).mean(),
        "loss_pressure": point_p.mean(),
        "loss_derivative": point_d.mean(),
        "loss_slope": point_s.mean(),
        "loss_bias_pressure": bias_p.mean(),
        "loss_bias_derivative": bias_d.mean(),
        "loss_derivative_shape": shape_d.mean(),
        "loss_autofit_pressure": autofit_p.mean(),
        "loss_autofit_derivative": autofit_d.mean(),
        "sample_weight_mean": sample_weight.mean(),
        "sample_weight_max": sample_weight.max(),
    }
def model_forward(model: nn.Module, params_x: torch.Tensor, schedule_x: torch.Tensor, use_schedule: bool) -> torch.Tensor:
    """Call ``model(params_x, schedule)``, passing ``None`` as the schedule when it is disabled."""
    schedule_input = schedule_x if use_schedule else None
    return model(params_x, schedule_input)
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: str,
    slices: dict[str, slice],
    cfg: TrainConfig,
) -> dict:
    """Run ``model`` over ``loader`` without gradients and aggregate the loss metrics.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Yields ``(params_x, schedule_x, curve_y)`` batches.
        device: Device the batches are moved to.
        slices: Column slices of the curve parts, as used by ``compute_weighted_loss``.
        cfg: Supplies the loss settings plus the ``curve_mean_raw`` /
            ``curve_scale_raw`` tensors, which must already be attached to it.

    Returns:
        A dict of floats with the same keys as ``compute_weighted_loss``. Every
        entry is the batch-size-weighted mean over the loader, except
        ``"sample_weight_max"``, which is the largest value seen in any batch
        (averaging per-batch maxima would under-report the true maximum).
        All entries are ``0.0`` for an empty loader.
    """
    model.eval()
    mean_keys = (
        "loss",
        "loss_pressure",
        "loss_derivative",
        "loss_slope",
        "loss_bias_pressure",
        "loss_bias_derivative",
        "loss_derivative_shape",
        "loss_autofit_pressure",
        "loss_autofit_derivative",
        "sample_weight_mean",
    )
    sums = dict.fromkeys(mean_keys, 0.0)
    batch_weight_maxima: list[float] = []
    total_n = 0

    with torch.no_grad():
        for params_x, schedule_x, curve_y in loader:
            params_x = params_x.to(device)
            schedule_x = schedule_x.to(device)
            curve_y = curve_y.to(device)

            pred = model_forward(model, params_x, schedule_x, cfg.use_schedule)
            losses = compute_weighted_loss(
                pred=pred,
                target=curve_y,
                slices=slices,
                curve_mean_raw=cfg.curve_mean_raw,
                curve_scale_raw=cfg.curve_scale_raw,
                huber_beta=cfg.huber_beta,
                w_pressure=cfg.w_pressure,
                w_derivative=cfg.w_derivative,
                w_slope=cfg.w_slope,
                w_bias_pressure=cfg.w_bias_pressure,
                w_bias_derivative=cfg.w_bias_derivative,
                w_derivative_shape=cfg.w_derivative_shape,
                w_autofit_pressure=cfg.w_autofit_pressure,
                w_autofit_derivative=cfg.w_autofit_derivative,
                use_sample_reweight=cfg.use_sample_reweight,
                sample_reweight_alpha=cfg.sample_reweight_alpha,
                sample_weight_min=cfg.sample_weight_min,
                sample_weight_max=cfg.sample_weight_max,
            )

            bs = params_x.size(0)
            # Weight each batch mean by its size so a short final batch does
            # not skew the dataset-level average.
            for key in mean_keys:
                sums[key] += losses[key].item() * bs
            batch_weight_maxima.append(losses["sample_weight_max"].item())
            total_n += bs

    denom = max(total_n, 1)
    metrics = {key: value / denom for key, value in sums.items()}
    metrics["sample_weight_max"] = max(batch_weight_maxima, default=0.0)
    return metrics
def _format_loss_summary(metrics: dict) -> str:
    """Render one metrics dict as the ``total (p=..., d=..., ...)`` log fragment."""
    return (
        f"{metrics['loss']:.6f} "
        f"(p={metrics['loss_pressure']:.6f}, "
        f"d={metrics['loss_derivative']:.6f}, "
        f"s={metrics['loss_slope']:.6f}, "
        f"bp={metrics['loss_bias_pressure']:.6f}, "
        f"bd={metrics['loss_bias_derivative']:.6f}, "
        f"ds={metrics['loss_derivative_shape']:.6f}, "
        f"ap={metrics['loss_autofit_pressure']:.6f}, "
        f"ad={metrics['loss_autofit_derivative']:.6f}, "
        f"wmean={metrics['sample_weight_mean']:.4f}, "
        f"wmax={metrics['sample_weight_max']:.4f})"
    )


def train_forward(cfg: TrainConfig) -> None:
    """Train the forward surrogate, save the best checkpoint and evaluate it on the test split.

    Files written into ``cfg.output_dir``:
        * ``forward_surrogate_best.pt`` - checkpoint of the epoch with the
          lowest validation loss, together with the model dimensions and the
          loss configuration;
        * ``history.json`` - per-epoch train/validation metrics;
        * ``metrics.json`` - best validation loss and test metrics.

    Side effects:
        ``cfg.curve_mean_raw`` and ``cfg.curve_scale_raw`` are attached to
        ``cfg`` as tensors so training and evaluation share one loss path.

    Raises:
        RuntimeError: If no epoch produced a validation loss below infinity
            (for example ``epochs == 0`` or a NaN loss), so no checkpoint was
            written by this run. ``history.json`` is still written first.
    """
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    set_global_seed(int(cfg.seed))

    data = load_processed_dataset(cfg.processed_path)
    curve_layout = infer_curve_layout(data)
    part_slices = get_part_slices(curve_layout)

    scaler_curve = data["scaler_curve"]
    curve_mean_raw = np.asarray(scaler_curve.mean_, dtype=np.float32).reshape(-1)
    curve_scale_raw = np.asarray(scaler_curve.scale_, dtype=np.float32).reshape(-1)
    # Attach the curve scaler tensors to cfg so training and evaluation can
    # share the same loss code.
    cfg.curve_mean_raw = torch.tensor(curve_mean_raw, dtype=torch.float32, device=cfg.device)
    cfg.curve_scale_raw = torch.tensor(curve_scale_raw, dtype=torch.float32, device=cfg.device)

    train_ds = ForwardDataset(
        data["X_params_train"], data["X_schedule_train"], data["Y_curve_train"]
    )
    val_ds = ForwardDataset(
        data["X_params_val"], data["X_schedule_val"], data["Y_curve_val"]
    )
    test_ds = ForwardDataset(
        data["X_params_test"], data["X_schedule_test"], data["Y_curve_test"]
    )

    # A dedicated, seeded generator keeps the shuffling order reproducible.
    loader_generator = torch.Generator()
    loader_generator.manual_seed(int(cfg.seed))
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, generator=loader_generator)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False)

    param_dim = data["X_params_train"].shape[1]
    schedule_dim = data["X_schedule_train"].shape[1]
    curve_dim = data["Y_curve_train"].shape[1]

    # Model dimensions are inferred from the preprocessed dataset so different
    # dataset versions can share this training entry point.
    model = ForwardSurrogate(
        param_dim=param_dim,
        schedule_dim=schedule_dim,
        curve_dim=curve_dim,
        hidden_dim=cfg.hidden_dim,
        dropout=cfg.dropout,
        use_schedule=cfg.use_schedule,
    ).to(cfg.device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )

    # Loss settings are fixed for the whole run; build them once and reuse
    # them for every batch, the checkpoint and metrics.json.
    loss_kwargs = {
        "slices": part_slices,
        "curve_mean_raw": cfg.curve_mean_raw,
        "curve_scale_raw": cfg.curve_scale_raw,
        "huber_beta": cfg.huber_beta,
        "w_pressure": cfg.w_pressure,
        "w_derivative": cfg.w_derivative,
        "w_slope": cfg.w_slope,
        "w_bias_pressure": cfg.w_bias_pressure,
        "w_bias_derivative": cfg.w_bias_derivative,
        "w_derivative_shape": cfg.w_derivative_shape,
        "w_autofit_pressure": cfg.w_autofit_pressure,
        "w_autofit_derivative": cfg.w_autofit_derivative,
        "use_sample_reweight": cfg.use_sample_reweight,
        "sample_reweight_alpha": cfg.sample_reweight_alpha,
        "sample_weight_min": cfg.sample_weight_min,
        "sample_weight_max": cfg.sample_weight_max,
    }
    loss_weights = {
        "pressure": cfg.w_pressure,
        "derivative": cfg.w_derivative,
        "slope": cfg.w_slope,
        "bias_pressure": cfg.w_bias_pressure,
        "bias_derivative": cfg.w_bias_derivative,
        "derivative_shape": cfg.w_derivative_shape,
        "autofit_pressure": cfg.w_autofit_pressure,
        "autofit_derivative": cfg.w_autofit_derivative,
    }
    sample_reweight = {
        "enabled": cfg.use_sample_reweight,
        "alpha": cfg.sample_reweight_alpha,
        "weight_min": cfg.sample_weight_min,
        "weight_max": cfg.sample_weight_max,
    }

    best_val = float("inf")
    best_path = cfg.output_dir / "forward_surrogate_best.pt"
    # Tracks whether THIS run wrote best_path, so a stale checkpoint from an
    # earlier run is never mistaken for the result of the current one.
    saved_best = False
    history: list[dict] = []

    print("训练配置:")
    print(f" device={cfg.device}")
    print(f" seed={cfg.seed}")
    print(f" batch_size={cfg.batch_size}, epochs={cfg.epochs}, lr={cfg.lr}, weight_decay={cfg.weight_decay}")
    print(f" hidden_dim={cfg.hidden_dim}, dropout={cfg.dropout}")
    print(f" use_schedule={cfg.use_schedule}")
    print(
        f" weights: pressure={cfg.w_pressure}, derivative={cfg.w_derivative}, "
        f"slope={cfg.w_slope}, bias_p={cfg.w_bias_pressure}, "
        f"bias_d={cfg.w_bias_derivative}, d_shape={cfg.w_derivative_shape}, "
        f"autofit_p={cfg.w_autofit_pressure}, autofit_d={cfg.w_autofit_derivative}"
    )
    print(
        f" sample_reweight={cfg.use_sample_reweight}, alpha={cfg.sample_reweight_alpha}, "
        f"clip=[{cfg.sample_weight_min}, {cfg.sample_weight_max}]"
    )
    print(f" curve_layout={curve_layout}")
    print(" note: 当前重点训练 pressure + derivative可显式关闭 schedule 分支做固定制度对照")

    # Metrics reported as batch-size-weighted means over the epoch.
    mean_keys = (
        "loss",
        "loss_pressure",
        "loss_derivative",
        "loss_slope",
        "loss_bias_pressure",
        "loss_bias_derivative",
        "loss_derivative_shape",
        "loss_autofit_pressure",
        "loss_autofit_derivative",
        "sample_weight_mean",
    )

    for epoch in range(1, cfg.epochs + 1):
        model.train()
        sums = dict.fromkeys(mean_keys, 0.0)
        batch_weight_maxima: list[float] = []
        total_n = 0

        for params_x, schedule_x, curve_y in train_loader:
            params_x = params_x.to(cfg.device)
            schedule_x = schedule_x.to(cfg.device)
            curve_y = curve_y.to(cfg.device)

            optimizer.zero_grad()
            pred = model_forward(model, params_x, schedule_x, cfg.use_schedule)
            losses = compute_weighted_loss(pred=pred, target=curve_y, **loss_kwargs)
            losses["loss"].backward()
            optimizer.step()

            bs = params_x.size(0)
            for key in mean_keys:
                sums[key] += losses[key].item() * bs
            batch_weight_maxima.append(losses["sample_weight_max"].item())
            total_n += bs

        denom = max(total_n, 1)
        train_metrics = {key: value / denom for key, value in sums.items()}
        # A maximum must be reduced with max(); averaging per-batch maxima
        # would under-report the largest sample weight of the epoch.
        train_metrics["sample_weight_max"] = max(batch_weight_maxima, default=0.0)

        val_metrics = evaluate(
            model=model,
            loader=val_loader,
            device=cfg.device,
            slices=part_slices,
            cfg=cfg,
        )

        row = {"epoch": epoch}
        row.update({f"train_{k}": float(v) for k, v in train_metrics.items()})
        row.update({f"val_{k}": float(v) for k, v in val_metrics.items()})
        history.append(row)

        print(
            f"[Epoch {epoch:03d}] "
            f"train={_format_loss_summary(train_metrics)} "
            f"val={_format_loss_summary(val_metrics)}"
        )

        if val_metrics["loss"] < best_val:
            best_val = val_metrics["loss"]
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "param_dim": param_dim,
                    "schedule_dim": schedule_dim,
                    "curve_dim": curve_dim,
                    "hidden_dim": cfg.hidden_dim,
                    "dropout": cfg.dropout,
                    "use_schedule": cfg.use_schedule,
                    "seed": int(cfg.seed),
                    "curve_layout": curve_layout,
                    "loss_weights": loss_weights,
                    "sample_reweight": sample_reweight,
                },
                best_path,
            )
            saved_best = True
            print(f" -> best model saved to: {best_path}")

    with open(cfg.output_dir / "history.json", "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)

    if not saved_best:
        # Without this guard torch.load would either fail with an unrelated
        # FileNotFoundError or silently evaluate a checkpoint left behind by
        # a previous run in the same output directory.
        raise RuntimeError(
            "no checkpoint was saved during this run "
            f"(epochs={cfg.epochs}, best_val_loss={best_val}); "
            "check that epochs > 0 and that the validation loss is finite"
        )

    checkpoint = torch.load(best_path, map_location=cfg.device)
    best_model = ForwardSurrogate(
        param_dim=checkpoint["param_dim"],
        schedule_dim=checkpoint["schedule_dim"],
        curve_dim=checkpoint["curve_dim"],
        hidden_dim=checkpoint["hidden_dim"],
        dropout=checkpoint["dropout"],
        use_schedule=checkpoint.get("use_schedule", True),
    ).to(cfg.device)
    best_model.load_state_dict(checkpoint["model_state_dict"])

    test_metrics = evaluate(
        model=best_model,
        loader=test_loader,
        device=cfg.device,
        slices=part_slices,
        cfg=cfg,
    )
    print(f"[Final] test={_format_loss_summary(test_metrics)}")

    with open(cfg.output_dir / "metrics.json", "w", encoding="utf-8") as f:
        json.dump(
            {
                "best_val_loss": float(best_val),
                "test_metrics": {k: float(v) for k, v in test_metrics.items()},
                "use_schedule": cfg.use_schedule,
                "seed": int(cfg.seed),
                "loss_weights": loss_weights,
                "sample_reweight": sample_reweight,
                "curve_layout": curve_layout,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )