"""Training loop for the forward surrogate model.

The training target is the concatenated curve vector ``log_pressure``,
``log_derivative`` and an auxiliary ``slope`` segment. The default loss mainly
constrains the pressure and derivative curves; ``slope`` is normally only
logged for monitoring and carries a weight of 0.
"""

from __future__ import annotations

import json
import random
from dataclasses import dataclass
from pathlib import Path

import joblib
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

from src.models.forward_surrogate import ForwardSurrogate

# Scalar metrics returned by ``compute_weighted_loss``; this order is the order
# in which they appear in the logged history and metrics files.
_METRIC_KEYS: tuple[str, ...] = (
    "loss",
    "loss_pressure",
    "loss_derivative",
    "loss_slope",
    "loss_bias_pressure",
    "loss_bias_derivative",
    "loss_derivative_shape",
    "loss_autofit_pressure",
    "loss_autofit_derivative",
    "sample_weight_mean",
    "sample_weight_max",
)


class ForwardDataset(Dataset):
    """Wrap preprocessed numpy arrays as a PyTorch ``Dataset``.

    Each item is a ``(params_x, schedule_x, curve_y)`` triple of float32
    tensors taken from the same row (first axis) of the three input arrays.
    """

    def __init__(self, params_x: np.ndarray, schedule_x: np.ndarray, curve_y: np.ndarray):
        self.params_x = torch.tensor(params_x, dtype=torch.float32)
        self.schedule_x = torch.tensor(schedule_x, dtype=torch.float32)
        self.curve_y = torch.tensor(curve_y, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.params_x)

    def __getitem__(self, idx: int):
        return self.params_x[idx], self.schedule_x[idx], self.curve_y[idx]


@dataclass
class TrainConfig:
    """Run parameters used by ``scripts/train_forward.py``.

    ``train_forward`` additionally attaches ``curve_mean_raw`` and
    ``curve_scale_raw`` (float32 tensors on ``device`` built from the curve
    scaler's ``mean_`` / ``scale_``) to the instance at runtime, so that the
    training and evaluation loops share the same loss code.
    """

    processed_path: Path  # joblib file with the preprocessed splits
    output_dir: Path  # receives the checkpoint, history.json and metrics.json
    seed: int = 42
    batch_size: int = 128
    epochs: int = 100
    lr: float = 1e-3
    weight_decay: float = 1e-5
    hidden_dim: int = 128
    dropout: float = 0.0
    # Loss-term weights (see compute_weighted_loss).
    w_pressure: float = 1.0
    w_derivative: float = 2.0
    w_slope: float = 0.0
    w_bias_pressure: float = 0.15
    w_bias_derivative: float = 0.05
    w_derivative_shape: float = 0.10
    w_autofit_pressure: float = 0.0
    w_autofit_derivative: float = 0.0
    # Pointwise loss: smooth L1 with ``huber_beta`` when True, MSE when False.
    use_huber: bool = True
    huber_beta: float = 0.05
    # Per-sample reweighting of high-amplitude curves (see build_sample_weight).
    use_sample_reweight: bool = True
    sample_reweight_alpha: float = 0.4
    sample_weight_min: float = 1.0
    sample_weight_max: float = 2.5
    # When False the model is called with ``schedule_x=None``.
    use_schedule: bool = True
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


def set_global_seed(seed: int) -> None:
    """Seed Python's, NumPy's and PyTorch's (including all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def load_processed_dataset(path: Path) -> dict:
    """Load the preprocessed joblib payload and check the required split keys.

    Raises:
        KeyError: if any train/val/test params, schedule or curve array is missing.
    """
    data = joblib.load(path)
    required_keys = [
        "X_params_train", "X_schedule_train", "Y_curve_train",
        "X_params_val", "X_schedule_val", "Y_curve_val",
        "X_params_test", "X_schedule_test", "Y_curve_test",
    ]
    for k in required_keys:
        if k not in data:
            raise KeyError(f"processed dataset 缺少字段: {k}")
    return data


def infer_curve_layout(data: dict) -> dict:
    """Return the curve layout stored in ``data["meta"]``.

    Without metadata, fall back to the legacy layout: three equal consecutive
    segments ``log_pressure`` / ``log_derivative`` / ``slope``.

    Raises:
        ValueError: if no layout is stored and the curve dimension is not
            divisible by 3.
    """
    meta = data.get("meta") or {}
    curve_layout = meta.get("curve_layout")
    if curve_layout is not None:
        return curve_layout

    curve_dim = int(meta.get("curve_dim", data["Y_curve_train"].shape[1]))
    if curve_dim % 3 != 0:
        raise ValueError(f"curve_dim={curve_dim} 不能按 3 段均分,且 meta 中没有 curve_layout")
    n_time_points = curve_dim // 3
    return {
        "n_time_points": int(n_time_points),
        "parts": [
            {"name": "log_pressure", "start": 0, "end": n_time_points},
            {"name": "log_derivative", "start": n_time_points, "end": 2 * n_time_points},
            {"name": "slope", "start": 2 * n_time_points, "end": 3 * n_time_points},
        ],
    }


def get_part_slices(curve_layout: dict) -> dict[str, slice]:
    """Map each part name in ``curve_layout["parts"]`` to its column slice."""
    return {
        str(part["name"]): slice(int(part["start"]), int(part["end"]))
        for part in curve_layout["parts"]
    }


def smooth_l1_per_sample(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor:
    """Smooth-L1 (Huber-style) loss averaged over dim 1, one value per sample.

    ``beta == 0`` is exactly L1 and is special-cased: the quadratic branch
    would divide by zero, and because ``torch.where`` still back-propagates
    through the unselected branch, that inf becomes a NaN gradient.

    Raises:
        ValueError: if ``beta`` is negative.
    """
    if beta < 0:
        raise ValueError(f"smooth L1 beta must be non-negative, got {beta}")
    diff = torch.abs(pred - target)
    if beta == 0:
        return diff.mean(dim=1)
    loss = torch.where(
        diff < beta,
        0.5 * diff * diff / beta,
        diff - 0.5 * beta,
    )
    return loss.mean(dim=1)


def mse_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Squared error averaged over dim 1, one value per sample."""
    return ((pred - target) ** 2).mean(dim=1)


def l1_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Absolute error averaged over dim 1, one value per sample."""
    return torch.abs(pred - target).mean(dim=1)


def first_diff(x: torch.Tensor) -> torch.Tensor:
    """First-order difference along dim 1 (output is one column shorter)."""
    return x[:, 1:] - x[:, :-1]


def affine_restore(x_scaled: torch.Tensor, mean: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Invert a StandardScaler transform in torch (used by the auto-fit loss)."""
    return x_scaled * scale + mean


def autofit_curve_objective_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Torch version of the mixed relative/absolute-error curve objective.

    Pointwise error is ``0.7 * relative + 0.3 * absolute``. Points with large
    ``|target|`` are down-weighted by ``1 / (1 + min(0.01 * |target|, 100))``.
    The result is the square root of the weighted mean squared point error
    over dim 1.
    """
    weight_factor = torch.clamp(torch.abs(target) * 0.01, max=100.0)
    weight = 1.0 / (1.0 + weight_factor)
    scale = torch.maximum(torch.maximum(torch.abs(target), torch.abs(pred)), torch.full_like(target, 1e-12))
    relative_error = torch.abs(target - pred) / scale
    absolute_error = torch.abs(target - pred)
    point_error = 0.7 * relative_error + 0.3 * absolute_error
    weighted_mse = (weight * (point_error**2)).sum(dim=1) / torch.clamp(weight.sum(dim=1), min=1e-12)
    return torch.sqrt(weighted_mse)


def build_sample_weight(
    true_p: torch.Tensor,
    true_d: torch.Tensor,
    alpha: float,
    w_min: float,
    w_max: float,
) -> torch.Tensor:
    """Give higher-amplitude curves a moderately larger training weight.

    Each sample's mean ``|value|`` is normalised by the batch mean for pressure
    and derivative separately. The two ratios are averaged, shrunk towards 1 by
    ``alpha`` and clipped to ``[w_min, w_max]``. Weights therefore depend on
    the batch composition.
    """
    p_level = true_p.abs().mean(dim=1)
    d_level = true_d.abs().mean(dim=1)
    p_norm = p_level / (p_level.mean().detach() + 1e-6)
    d_norm = d_level / (d_level.mean().detach() + 1e-6)
    raw = 0.5 * p_norm + 0.5 * d_norm
    weight = 1.0 + alpha * (raw - 1.0)
    weight = torch.clamp(weight, min=w_min, max=w_max)
    return weight


def compute_weighted_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    slices: dict[str, slice],
    curve_mean_raw: torch.Tensor,
    curve_scale_raw: torch.Tensor,
    huber_beta: float,
    w_pressure: float,
    w_derivative: float,
    w_slope: float,
    w_bias_pressure: float,
    w_bias_derivative: float,
    w_derivative_shape: float,
    w_autofit_pressure: float,
    w_autofit_derivative: float,
    use_sample_reweight: bool,
    sample_reweight_alpha: float,
    sample_weight_min: float,
    sample_weight_max: float,
    use_huber: bool = True,
):
    """Compute every loss component used for training and monitoring.

    ``pred`` and ``target`` are indexed as ``[:, slice]`` with the
    ``log_pressure``, ``log_derivative`` and ``slope`` entries of ``slices``.
    ``curve_mean_raw`` / ``curve_scale_raw`` undo the curve scaling for the
    auto-fit terms.

    ``use_huber`` selects the pointwise loss for the pressure, derivative and
    derivative-shape terms: smooth L1 with ``huber_beta`` (default) or MSE.

    Returns:
        dict with the scalar weighted total under ``"loss"`` plus the
        batch-mean of each component and sample-weight statistics (see
        ``_METRIC_KEYS``).
    """
    if use_huber:
        def point_loss(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
            return smooth_l1_per_sample(a, b, beta=huber_beta)
    else:
        point_loss = mse_per_sample

    pred_p = pred[:, slices["log_pressure"]]
    pred_d = pred[:, slices["log_derivative"]]
    pred_s = pred[:, slices["slope"]]
    true_p = target[:, slices["log_pressure"]]
    true_d = target[:, slices["log_derivative"]]
    true_s = target[:, slices["slope"]]

    mean_p = curve_mean_raw[slices["log_pressure"]].unsqueeze(0)
    scale_p = curve_scale_raw[slices["log_pressure"]].unsqueeze(0)
    mean_d = curve_mean_raw[slices["log_derivative"]].unsqueeze(0)
    scale_d = curve_scale_raw[slices["log_derivative"]].unsqueeze(0)

    loss_p_vec = point_loss(pred_p, true_p)
    loss_d_vec = point_loss(pred_d, true_d)
    loss_s_vec = mse_per_sample(pred_s, true_s)

    # The bias terms suppress a vertical shift of the whole predicted curve;
    # the pointwise terms are responsible for the finer local shape.
    loss_bias_p_vec = l1_per_sample(pred_p.mean(dim=1, keepdim=True), true_p.mean(dim=1, keepdim=True))
    loss_bias_d_vec = l1_per_sample(pred_d.mean(dim=1, keepdim=True), true_d.mean(dim=1, keepdim=True))

    loss_d_shape_vec = point_loss(first_diff(pred_d), first_diff(true_d))

    # The auto-fit style terms are computed on the original (unscaled) curve.
    # Their default weight is 0; they are kept so experiments can enable them.
    pred_p_raw = affine_restore(pred_p, mean_p, scale_p)
    true_p_raw = affine_restore(true_p, mean_p, scale_p)
    pred_d_raw = affine_restore(pred_d, mean_d, scale_d)
    true_d_raw = affine_restore(true_d, mean_d, scale_d)
    loss_autofit_p_vec = autofit_curve_objective_per_sample(pred_p_raw, true_p_raw)
    loss_autofit_d_vec = autofit_curve_objective_per_sample(pred_d_raw, true_d_raw)

    total_vec = (
        w_pressure * loss_p_vec
        + w_derivative * loss_d_vec
        + w_slope * loss_s_vec
        + w_bias_pressure * loss_bias_p_vec
        + w_bias_derivative * loss_bias_d_vec
        + w_derivative_shape * loss_d_shape_vec
        + w_autofit_pressure * loss_autofit_p_vec
        + w_autofit_derivative * loss_autofit_d_vec
    )

    if use_sample_reweight:
        sample_weight = build_sample_weight(
            true_p=true_p,
            true_d=true_d,
            alpha=sample_reweight_alpha,
            w_min=sample_weight_min,
            w_max=sample_weight_max,
        )
    else:
        sample_weight = torch.ones_like(total_vec)

    total = (total_vec * sample_weight).mean()

    return {
        "loss": total,
        "loss_pressure": loss_p_vec.mean(),
        "loss_derivative": loss_d_vec.mean(),
        "loss_slope": loss_s_vec.mean(),
        "loss_bias_pressure": loss_bias_p_vec.mean(),
        "loss_bias_derivative": loss_bias_d_vec.mean(),
        "loss_derivative_shape": loss_d_shape_vec.mean(),
        "loss_autofit_pressure": loss_autofit_p_vec.mean(),
        "loss_autofit_derivative": loss_autofit_d_vec.mean(),
        "sample_weight_mean": sample_weight.mean(),
        "sample_weight_max": sample_weight.max(),
    }


def _compute_loss_from_cfg(
    pred: torch.Tensor,
    target: torch.Tensor,
    slices: dict[str, slice],
    cfg: TrainConfig,
) -> dict:
    """Call ``compute_weighted_loss`` with every weight/option taken from ``cfg``."""
    return compute_weighted_loss(
        pred=pred,
        target=target,
        slices=slices,
        curve_mean_raw=cfg.curve_mean_raw,
        curve_scale_raw=cfg.curve_scale_raw,
        huber_beta=cfg.huber_beta,
        w_pressure=cfg.w_pressure,
        w_derivative=cfg.w_derivative,
        w_slope=cfg.w_slope,
        w_bias_pressure=cfg.w_bias_pressure,
        w_bias_derivative=cfg.w_bias_derivative,
        w_derivative_shape=cfg.w_derivative_shape,
        w_autofit_pressure=cfg.w_autofit_pressure,
        w_autofit_derivative=cfg.w_autofit_derivative,
        use_sample_reweight=cfg.use_sample_reweight,
        sample_reweight_alpha=cfg.sample_reweight_alpha,
        sample_weight_min=cfg.sample_weight_min,
        sample_weight_max=cfg.sample_weight_max,
        use_huber=cfg.use_huber,
    )


class _MetricTracker:
    """Aggregate per-batch loss dicts into dataset-level metrics.

    Every metric is a sample-count-weighted mean over batches, except
    ``sample_weight_max``, which is the true maximum over all batches.
    """

    def __init__(self) -> None:
        self._sums = {k: 0.0 for k in _METRIC_KEYS if k != "sample_weight_max"}
        self._max_weight = float("-inf")
        self._n = 0

    def update(self, losses: dict, batch_size: int) -> None:
        for k in self._sums:
            self._sums[k] += losses[k].item() * batch_size
        self._max_weight = max(self._max_weight, losses["sample_weight_max"].item())
        self._n += batch_size

    def result(self) -> dict[str, float]:
        denom = max(self._n, 1)
        out = {k: v / denom for k, v in self._sums.items()}
        # An empty loader reports 0.0, matching the other metrics.
        out["sample_weight_max"] = self._max_weight if self._n > 0 else 0.0
        return out


def _format_loss_components(metrics: dict[str, float]) -> str:
    """Render the per-component loss breakdown used in the log lines."""
    return (
        f"(p={metrics['loss_pressure']:.6f}, "
        f"d={metrics['loss_derivative']:.6f}, "
        f"s={metrics['loss_slope']:.6f}, "
        f"bp={metrics['loss_bias_pressure']:.6f}, "
        f"bd={metrics['loss_bias_derivative']:.6f}, "
        f"ds={metrics['loss_derivative_shape']:.6f}, "
        f"ap={metrics['loss_autofit_pressure']:.6f}, "
        f"ad={metrics['loss_autofit_derivative']:.6f}, "
        f"wmean={metrics['sample_weight_mean']:.4f}, "
        f"wmax={metrics['sample_weight_max']:.4f})"
    )


def _loss_weights_summary(cfg: TrainConfig) -> dict:
    """Loss weights recorded in the checkpoint and metrics.json."""
    return {
        "pressure": cfg.w_pressure,
        "derivative": cfg.w_derivative,
        "slope": cfg.w_slope,
        "bias_pressure": cfg.w_bias_pressure,
        "bias_derivative": cfg.w_bias_derivative,
        "derivative_shape": cfg.w_derivative_shape,
        "autofit_pressure": cfg.w_autofit_pressure,
        "autofit_derivative": cfg.w_autofit_derivative,
    }


def _sample_reweight_summary(cfg: TrainConfig) -> dict:
    """Sample-reweighting settings recorded in the checkpoint and metrics.json."""
    return {
        "enabled": cfg.use_sample_reweight,
        "alpha": cfg.sample_reweight_alpha,
        "weight_min": cfg.sample_weight_min,
        "weight_max": cfg.sample_weight_max,
    }


def model_forward(model: nn.Module, params_x: torch.Tensor, schedule_x: torch.Tensor, use_schedule: bool) -> torch.Tensor:
    """Run ``model``, passing ``None`` instead of the schedule when it is disabled."""
    if use_schedule:
        return model(params_x, schedule_x)
    return model(params_x, None)


def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: str,
    slices: dict[str, slice],
    cfg: TrainConfig,
) -> dict:
    """Evaluate ``model`` on ``loader`` without gradients.

    Requires ``cfg.curve_mean_raw`` / ``cfg.curve_scale_raw`` (set by
    ``train_forward``). Returns the dataset-level metrics keyed as in
    ``_METRIC_KEYS``.
    """
    model.eval()
    tracker = _MetricTracker()
    with torch.no_grad():
        for params_x, schedule_x, curve_y in loader:
            params_x = params_x.to(device)
            schedule_x = schedule_x.to(device)
            curve_y = curve_y.to(device)
            pred = model_forward(model, params_x, schedule_x, cfg.use_schedule)
            losses = _compute_loss_from_cfg(pred, curve_y, slices, cfg)
            tracker.update(losses, params_x.size(0))
    return tracker.result()


def _train_one_epoch(
    model: nn.Module,
    loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    slices: dict[str, slice],
    cfg: TrainConfig,
) -> dict[str, float]:
    """Run one optimisation pass over ``loader`` and return its aggregated metrics."""
    model.train()
    tracker = _MetricTracker()
    for params_x, schedule_x, curve_y in loader:
        params_x = params_x.to(cfg.device)
        schedule_x = schedule_x.to(cfg.device)
        curve_y = curve_y.to(cfg.device)

        optimizer.zero_grad()
        pred = model_forward(model, params_x, schedule_x, cfg.use_schedule)
        losses = _compute_loss_from_cfg(pred, curve_y, slices, cfg)
        losses["loss"].backward()
        optimizer.step()

        tracker.update(losses, params_x.size(0))
    return tracker.result()


def _print_train_config(cfg: TrainConfig, curve_layout: dict) -> None:
    """Print the run configuration at the start of training."""
    print("训练配置:")
    print(f" device={cfg.device}")
    print(f" seed={cfg.seed}")
    print(f" batch_size={cfg.batch_size}, epochs={cfg.epochs}, lr={cfg.lr}, weight_decay={cfg.weight_decay}")
    print(f" hidden_dim={cfg.hidden_dim}, dropout={cfg.dropout}")
    print(f" use_schedule={cfg.use_schedule}")
    print(f" use_huber={cfg.use_huber}, huber_beta={cfg.huber_beta}")
    print(
        f" weights: pressure={cfg.w_pressure}, derivative={cfg.w_derivative}, "
        f"slope={cfg.w_slope}, bias_p={cfg.w_bias_pressure}, "
        f"bias_d={cfg.w_bias_derivative}, d_shape={cfg.w_derivative_shape}, "
        f"autofit_p={cfg.w_autofit_pressure}, autofit_d={cfg.w_autofit_derivative}"
    )
    print(
        f" sample_reweight={cfg.use_sample_reweight}, alpha={cfg.sample_reweight_alpha}, "
        f"clip=[{cfg.sample_weight_min}, {cfg.sample_weight_max}]"
    )
    print(f" curve_layout={curve_layout}")
    print(" note: 当前重点训练 pressure + derivative;可显式关闭 schedule 分支做固定制度对照")


def train_forward(cfg: TrainConfig) -> None:
    """Train the forward surrogate, keep the best checkpoint and evaluate it on test.

    Writes ``forward_surrogate_best.pt`` (whenever the validation loss
    improves), ``history.json`` and ``metrics.json`` under ``cfg.output_dir``.
    It also sets ``cfg.curve_mean_raw`` / ``cfg.curve_scale_raw``.

    Raises:
        RuntimeError: if no checkpoint was saved during this run (e.g.
            ``epochs <= 0`` or the validation loss was never finite). This
            avoids evaluating a stale checkpoint left by an earlier run.
    """
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    set_global_seed(int(cfg.seed))

    data = load_processed_dataset(cfg.processed_path)
    curve_layout = infer_curve_layout(data)
    part_slices = get_part_slices(curve_layout)

    scaler_curve = data["scaler_curve"]
    curve_mean_raw = np.asarray(scaler_curve.mean_, dtype=np.float32).reshape(-1)
    curve_scale_raw = np.asarray(scaler_curve.scale_, dtype=np.float32).reshape(-1)

    # Attach the curve scaler tensors to cfg so training and evaluation share
    # the same loss code.
    cfg.curve_mean_raw = torch.tensor(curve_mean_raw, dtype=torch.float32, device=cfg.device)
    cfg.curve_scale_raw = torch.tensor(curve_scale_raw, dtype=torch.float32, device=cfg.device)

    train_ds = ForwardDataset(data["X_params_train"], data["X_schedule_train"], data["Y_curve_train"])
    val_ds = ForwardDataset(data["X_params_val"], data["X_schedule_val"], data["Y_curve_val"])
    test_ds = ForwardDataset(data["X_params_test"], data["X_schedule_test"], data["Y_curve_test"])

    # A dedicated, seeded generator keeps the shuffling order reproducible.
    loader_generator = torch.Generator()
    loader_generator.manual_seed(int(cfg.seed))

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, generator=loader_generator)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False)

    # Model dimensions are inferred from the preprocessed data so that
    # different dataset versions can share this training entry point.
    param_dim = data["X_params_train"].shape[1]
    schedule_dim = data["X_schedule_train"].shape[1]
    curve_dim = data["Y_curve_train"].shape[1]

    model = ForwardSurrogate(
        param_dim=param_dim,
        schedule_dim=schedule_dim,
        curve_dim=curve_dim,
        hidden_dim=cfg.hidden_dim,
        dropout=cfg.dropout,
        use_schedule=cfg.use_schedule,
    ).to(cfg.device)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
    )

    best_val = float("inf")
    best_path = cfg.output_dir / "forward_surrogate_best.pt"
    history: list[dict] = []

    _print_train_config(cfg, curve_layout)

    for epoch in range(1, cfg.epochs + 1):
        train_metrics = _train_one_epoch(model, train_loader, optimizer, part_slices, cfg)
        val_metrics = evaluate(
            model=model,
            loader=val_loader,
            device=cfg.device,
            slices=part_slices,
            cfg=cfg,
        )

        row = {"epoch": epoch}
        row.update({f"train_{k}": float(v) for k, v in train_metrics.items()})
        row.update({f"val_{k}": float(v) for k, v in val_metrics.items()})
        history.append(row)

        print(
            f"[Epoch {epoch:03d}] "
            f"train={train_metrics['loss']:.6f} {_format_loss_components(train_metrics)} "
            f"val={val_metrics['loss']:.6f} {_format_loss_components(val_metrics)}"
        )

        # NaN never compares less than best_val, so a diverged epoch is never saved.
        if val_metrics["loss"] < best_val:
            best_val = val_metrics["loss"]
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "param_dim": param_dim,
                    "schedule_dim": schedule_dim,
                    "curve_dim": curve_dim,
                    "hidden_dim": cfg.hidden_dim,
                    "dropout": cfg.dropout,
                    "use_schedule": cfg.use_schedule,
                    "seed": int(cfg.seed),
                    "curve_layout": curve_layout,
                    "loss_weights": _loss_weights_summary(cfg),
                    "sample_reweight": _sample_reweight_summary(cfg),
                },
                best_path,
            )
            print(f" -> best model saved to: {best_path}")

    with open(cfg.output_dir / "history.json", "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)

    if best_val == float("inf"):
        raise RuntimeError(
            f"no checkpoint was saved during this run (epochs={cfg.epochs}); "
            "validation loss was never finite, refusing to evaluate "
            f"a possibly stale {best_path}"
        )

    checkpoint = torch.load(best_path, map_location=cfg.device)
    best_model = ForwardSurrogate(
        param_dim=checkpoint["param_dim"],
        schedule_dim=checkpoint["schedule_dim"],
        curve_dim=checkpoint["curve_dim"],
        hidden_dim=checkpoint["hidden_dim"],
        dropout=checkpoint["dropout"],
        use_schedule=checkpoint.get("use_schedule", True),
    ).to(cfg.device)
    best_model.load_state_dict(checkpoint["model_state_dict"])

    test_metrics = evaluate(
        model=best_model,
        loader=test_loader,
        device=cfg.device,
        slices=part_slices,
        cfg=cfg,
    )

    print(f"[Final] test={test_metrics['loss']:.6f} {_format_loss_components(test_metrics)}")

    with open(cfg.output_dir / "metrics.json", "w", encoding="utf-8") as f:
        json.dump(
            {
                "best_val_loss": float(best_val),
                "test_metrics": {k: float(v) for k, v in test_metrics.items()},
                "use_schedule": cfg.use_schedule,
                "seed": int(cfg.seed),
                "loss_weights": _loss_weights_summary(cfg),
                "sample_reweight": _sample_reweight_summary(cfg),
                "curve_layout": curve_layout,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )