# -*- coding: utf-8 -*- """正演代理模型训练流程。 本模块读取 preprocess.py 生成的 joblib 数据,构造 PyTorch Dataset/DataLoader,训练 ForwardSurrogate,并按验证集损失保存最佳 checkpoint。损失函数不是单一 MSE,而是 由压力、导数、均值偏置、导数形状约束和可选自动拟合目标组成,目的是让 模型既能拟合点值,也能保持对自动试井拟合有意义的曲线形态。 训练过程会保存 history.json、metrics.json 和 forward_surrogate_best.pt,后续评估 脚本可以根据 checkpoint 中保存的维度、curve_layout 和损失权重恢复模型。 """ # pylint: disable=import-error,duplicate-code,too-many-instance-attributes # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals from __future__ import annotations import json import random from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Any import joblib import numpy as np import torch from torch import nn from torch.utils.data import DataLoader, Dataset from src.models.forward_surrogate import ForwardSurrogate, ForwardSurrogateConfig METRIC_KEYS = ( "loss", "loss_pressure", "loss_derivative", "loss_bias_pressure", "loss_bias_derivative", "loss_derivative_shape", "loss_autofit_pressure", "loss_autofit_derivative", "sample_weight_mean", "sample_weight_max", ) class ForwardDataset(Dataset): """把预处理后的参数、流量制度和曲线数组封装成 PyTorch Dataset。""" def __init__(self, params_x: np.ndarray, schedule_x: np.ndarray, curve_y: np.ndarray): """把三个 numpy 数组转为 float32 张量,后续 DataLoader 可直接按样本读取。""" self.params_x = torch.tensor(params_x, dtype=torch.float32) self.schedule_x = torch.tensor(schedule_x, dtype=torch.float32) self.curve_y = torch.tensor(curve_y, dtype=torch.float32) def __len__(self) -> int: """返回数据集或容器中可迭代样本的数量。""" return len(self.params_x) def __getitem__(self, idx: int): """按索引取出一个训练样本或数据项。""" return self.params_x[idx], self.schedule_x[idx], self.curve_y[idx] @dataclass class ModelConfig: """模型结构相关配置。""" hidden_dim: int = 128 dropout: float = 0.0 use_schedule: bool = True @dataclass class OptimConfig: """优化器与训练轮次配置。""" batch_size: int = 128 epochs: int = 100 lr: float = 1e-3 weight_decay: float = 1e-5 @dataclass class LossWeights: """复合损失的各项权重。""" pressure: float = 1.0 derivative: float = 2.0 bias_pressure: float = 0.15 bias_derivative: float = 0.05 derivative_shape: float = 0.10 autofit_pressure: float = 0.0 autofit_derivative: float = 0.0 @dataclass class LossConfig: """损失函数配置。""" weights: LossWeights = field(default_factory=LossWeights) use_huber: bool = True huber_beta: float = 0.05 @dataclass class SampleReweightConfig: """样本重加权配置。""" enabled: bool = True alpha: float = 0.4 weight_min: float = 1.0 weight_max: float = 2.5 @dataclass class TrainRuntime: """训练运行时配置。""" seed: int = 42 device: str = "cuda" if torch.cuda.is_available() else "cpu" @dataclass class TrainConfig: """正演代理模型训练配置。""" processed_path: Path output_dir: Path runtime: TrainRuntime = field(default_factory=TrainRuntime) optim: OptimConfig = field(default_factory=OptimConfig) model: ModelConfig = field(default_factory=ModelConfig) loss: LossConfig = field(default_factory=LossConfig) sample_reweight: SampleReweightConfig = field(default_factory=SampleReweightConfig) @dataclass class CurveStats: """曲线 scaler 的 torch 形式统计量。""" mean_raw: torch.Tensor scale_raw: torch.Tensor @dataclass class LossBatchParts: """压力和导数的预测值与真实值。""" pred_p: torch.Tensor pred_d: torch.Tensor true_p: torch.Tensor true_d: torch.Tensor @dataclass class LossContext: """计算复合损失所需的上下文。""" slices: dict[str, slice] curve_stats: CurveStats loss_cfg: LossConfig reweight_cfg: SampleReweightConfig @dataclass class DatasetBundle: """训练、验证、测试 DataLoader 与数据维度。""" train_loader: DataLoader val_loader: DataLoader test_loader: DataLoader param_dim: int schedule_dim: int curve_dim: int def set_global_seed(seed: int) -> None: """设置 Python、NumPy 和 PyTorch 随机种子,并在 CUDA 可用时同步设置 GPU 随机种子。""" random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def load_processed_dataset(path: Path) -> dict: """读取预处理后的 joblib 数据,并检查训练、验证、测试三套数组是否齐全。""" data = joblib.load(path) required_keys = [ "X_params_train", "X_schedule_train", "Y_curve_train", "X_params_val", "X_schedule_val", "Y_curve_val", "X_params_test", "X_schedule_test", "Y_curve_test", ] for key in required_keys: if key not in data: raise KeyError(f"processed dataset 缺少字段: {key}") return data def infer_curve_layout(data: dict) -> dict: """读取双通道布局,并拒绝未经重新预处理的旧三通道数据。""" meta = data.get("meta", {}) curve_dim = int(meta.get("curve_dim", data["Y_curve_train"].shape[1])) curve_layout = meta.get("curve_layout") if curve_layout is not None: names = {str(part["name"]) for part in curve_layout.get("parts", [])} if "slope" in names: raise ValueError( "processed 数据仍包含旧版 slope 通道;请先重新运行 " "scripts/preprocess_dataset.py,预处理会自动裁掉 slope" ) return curve_layout if curve_dim % 2 != 0: raise ValueError(f"curve_dim={curve_dim} 不能按压力/导数两段均分") n_time_points = curve_dim // 2 return { "n_time_points": int(n_time_points), "parts": [ {"name": "log_pressure", "start": 0, "end": n_time_points}, {"name": "log_derivative", "start": n_time_points, "end": 2 * n_time_points}, ], } def get_part_slices(curve_layout: dict) -> dict[str, slice]: """把 curve_layout 中的 start/end 信息转换成各曲线分段的 slice。""" out: dict[str, slice] = {} for part in curve_layout["parts"]: name = str(part["name"]) out[name] = slice(int(part["start"]), int(part["end"])) return out def smooth_l1_per_sample(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor: """按样本计算 Smooth L1 损失,返回每个样本一个损失值。""" diff = torch.abs(pred - target) loss = torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta) return loss.mean(dim=1) def mse_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """按样本计算均方误差。""" return ((pred - target) ** 2).mean(dim=1) def l1_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """按样本计算平均绝对误差。""" return torch.abs(pred - target).mean(dim=1) def regression_per_sample( pred: torch.Tensor, target: torch.Tensor, loss_cfg: LossConfig, ) -> torch.Tensor: """按配置在 Smooth L1 和 MSE 之间切换点值损失。""" if loss_cfg.use_huber: return smooth_l1_per_sample(pred, target, beta=float(loss_cfg.huber_beta)) return mse_per_sample(pred, target) def first_diff(x: torch.Tensor) -> torch.Tensor: """计算曲线相邻时间点的一阶差分,用于约束导数形态。""" return x[:, 1:] - x[:, :-1] def affine_restore(x_scaled: torch.Tensor, mean: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: """在 torch 张量上执行 scaler 的反标准化公式 x * scale + mean。""" return x_scaled * scale + mean def autofit_curve_objective_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """用 torch 计算自动拟合风格的曲线误差,作为训练附加目标。""" weight_factor = torch.clamp(torch.abs(target) * 0.01, max=100.0) weight = 1.0 / (1.0 + weight_factor) scale = torch.maximum( torch.maximum(torch.abs(target), torch.abs(pred)), torch.full_like(target, 1e-12), ) relative_error = torch.abs(target - pred) / scale absolute_error = torch.abs(target - pred) point_error = 0.7 * relative_error + 0.3 * absolute_error weighted_mse = (weight * (point_error**2)).sum(dim=1) weighted_mse = weighted_mse / torch.clamp(weight.sum(dim=1), min=1e-12) return torch.sqrt(weighted_mse) def build_sample_weight( true_p: torch.Tensor, true_d: torch.Tensor, reweight_cfg: SampleReweightConfig, ) -> torch.Tensor: """根据真实曲线幅值构造样本权重,让高幅值样本训练时权重更高。""" p_level = true_p.abs().mean(dim=1) d_level = true_d.abs().mean(dim=1) p_norm = p_level / (p_level.mean().detach() + 1e-6) d_norm = d_level / (d_level.mean().detach() + 1e-6) raw = 0.5 * p_norm + 0.5 * d_norm weight = 1.0 + reweight_cfg.alpha * (raw - 1.0) return torch.clamp(weight, min=reweight_cfg.weight_min, max=reweight_cfg.weight_max) def split_curve_parts( pred: torch.Tensor, target: torch.Tensor, slices: dict[str, slice], ) -> LossBatchParts: """把拼接曲线拆成压力和导数两段。""" return LossBatchParts( pred_p=pred[:, slices["log_pressure"]], pred_d=pred[:, slices["log_derivative"]], true_p=target[:, slices["log_pressure"]], true_d=target[:, slices["log_derivative"]], ) def compute_basic_loss_vectors( parts: LossBatchParts, loss_cfg: LossConfig, ) -> dict[str, torch.Tensor]: """计算标准化空间中的基础点值、偏置和导数形状损失。""" return { "loss_pressure": regression_per_sample(parts.pred_p, parts.true_p, loss_cfg), "loss_derivative": regression_per_sample(parts.pred_d, parts.true_d, loss_cfg), "loss_bias_pressure": l1_per_sample( parts.pred_p.mean(dim=1, keepdim=True), parts.true_p.mean(dim=1, keepdim=True), ), "loss_bias_derivative": l1_per_sample( parts.pred_d.mean(dim=1, keepdim=True), parts.true_d.mean(dim=1, keepdim=True), ), "loss_derivative_shape": regression_per_sample( first_diff(parts.pred_d), first_diff(parts.true_d), loss_cfg, ), } def compute_autofit_loss_vectors( parts: LossBatchParts, context: LossContext, ) -> dict[str, torch.Tensor]: """在原始尺度上计算自动拟合风格损失。""" pressure_slice = context.slices["log_pressure"] derivative_slice = context.slices["log_derivative"] mean_p = context.curve_stats.mean_raw[pressure_slice].unsqueeze(0) scale_p = context.curve_stats.scale_raw[pressure_slice].unsqueeze(0) mean_d = context.curve_stats.mean_raw[derivative_slice].unsqueeze(0) scale_d = context.curve_stats.scale_raw[derivative_slice].unsqueeze(0) return { "loss_autofit_pressure": autofit_curve_objective_per_sample( affine_restore(parts.pred_p, mean_p, scale_p), affine_restore(parts.true_p, mean_p, scale_p), ), "loss_autofit_derivative": autofit_curve_objective_per_sample( affine_restore(parts.pred_d, mean_d, scale_d), affine_restore(parts.true_d, mean_d, scale_d), ), } def weighted_total_vector( loss_vectors: dict[str, torch.Tensor], weights: LossWeights, ) -> torch.Tensor: """按配置权重合成每个样本的总损失向量。""" return ( weights.pressure * loss_vectors["loss_pressure"] + weights.derivative * loss_vectors["loss_derivative"] + weights.bias_pressure * loss_vectors["loss_bias_pressure"] + weights.bias_derivative * loss_vectors["loss_bias_derivative"] + weights.derivative_shape * loss_vectors["loss_derivative_shape"] + weights.autofit_pressure * loss_vectors["loss_autofit_pressure"] + weights.autofit_derivative * loss_vectors["loss_autofit_derivative"] ) def compute_weighted_loss( pred: torch.Tensor, target: torch.Tensor, context: LossContext, ) -> dict[str, torch.Tensor]: """计算正演代理模型的复合训练损失。""" parts = split_curve_parts(pred, target, context.slices) loss_vectors = compute_basic_loss_vectors(parts, context.loss_cfg) loss_vectors.update(compute_autofit_loss_vectors(parts, context)) total_vec = weighted_total_vector(loss_vectors, context.loss_cfg.weights) if context.reweight_cfg.enabled: sample_weight = build_sample_weight(parts.true_p, parts.true_d, context.reweight_cfg) else: sample_weight = torch.ones_like(total_vec) metrics = {key: value.mean() for key, value in loss_vectors.items()} metrics["loss"] = (total_vec * sample_weight).mean() metrics["sample_weight_mean"] = sample_weight.mean() metrics["sample_weight_max"] = sample_weight.max() return metrics def model_forward( model: nn.Module, params_x: torch.Tensor, schedule_x: torch.Tensor, use_schedule: bool, ) -> torch.Tensor: """按 use_schedule 开关统一调用模型,兼容只用参数输入和参数+流量制度输入。""" if use_schedule: return model(params_x, schedule_x) return model(params_x, None) def init_metric_accumulator() -> dict[str, float]: """创建指标累加器。""" return {key: 0.0 for key in METRIC_KEYS} def accumulate_metrics( total: dict[str, float], losses: dict[str, torch.Tensor], batch_size: int, ) -> None: """按 batch 样本数加权累加指标。""" for key in total: total[key] += losses[key].item() * batch_size def average_metrics(total: dict[str, float], total_n: int) -> dict[str, float]: """将累加指标转换为样本平均指标。""" denom = max(total_n, 1) return {key: value / denom for key, value in total.items()} def run_loader_epoch( model: nn.Module, loader: DataLoader, device: str, context: LossContext, use_schedule: bool, optimizer: torch.optim.Optimizer | None = None, ) -> dict[str, float]: """执行一个训练或评估 epoch,并返回平均指标。""" is_train = optimizer is not None model.train(mode=is_train) total = init_metric_accumulator() total_n = 0 grad_context = torch.enable_grad() if is_train else torch.no_grad() with grad_context: for params_x, schedule_x, curve_y in loader: params_x = params_x.to(device) schedule_x = schedule_x.to(device) curve_y = curve_y.to(device) if is_train: optimizer.zero_grad() pred = model_forward(model, params_x, schedule_x, use_schedule) losses = compute_weighted_loss(pred=pred, target=curve_y, context=context) if is_train: losses["loss"].backward() optimizer.step() batch_size = params_x.size(0) accumulate_metrics(total, losses, batch_size) total_n += batch_size return average_metrics(total, total_n) def evaluate( model: nn.Module, loader: DataLoader, device: str, context: LossContext, use_schedule: bool, ) -> dict[str, float]: """在验证或测试 DataLoader 上计算平均损失和各损失分量。""" return run_loader_epoch( model=model, loader=loader, device=device, context=context, use_schedule=use_schedule, optimizer=None, ) def build_curve_stats(data: dict, device: str) -> CurveStats: """从预处理数据中的曲线 scaler 构建 torch 统计量。""" scaler_curve = data["scaler_curve"] curve_mean_raw = np.asarray(scaler_curve.mean_, dtype=np.float32).reshape(-1) curve_scale_raw = np.asarray(scaler_curve.scale_, dtype=np.float32).reshape(-1) return CurveStats( mean_raw=torch.tensor(curve_mean_raw, dtype=torch.float32, device=device), scale_raw=torch.tensor(curve_scale_raw, dtype=torch.float32, device=device), ) def build_dataloaders(data: dict, cfg: TrainConfig) -> DatasetBundle: """根据预处理数组构造训练、验证、测试 DataLoader。""" train_ds = ForwardDataset(data["X_params_train"], data["X_schedule_train"], data["Y_curve_train"]) val_ds = ForwardDataset(data["X_params_val"], data["X_schedule_val"], data["Y_curve_val"]) test_ds = ForwardDataset(data["X_params_test"], data["X_schedule_test"], data["Y_curve_test"]) loader_generator = torch.Generator() loader_generator.manual_seed(int(cfg.runtime.seed)) train_loader = DataLoader( train_ds, batch_size=cfg.optim.batch_size, shuffle=True, generator=loader_generator, ) val_loader = DataLoader(val_ds, batch_size=cfg.optim.batch_size, shuffle=False) test_loader = DataLoader(test_ds, batch_size=cfg.optim.batch_size, shuffle=False) return DatasetBundle( train_loader=train_loader, val_loader=val_loader, test_loader=test_loader, param_dim=data["X_params_train"].shape[1], schedule_dim=data["X_schedule_train"].shape[1], curve_dim=data["Y_curve_train"].shape[1], ) def build_forward_model( model_cfg: ModelConfig, param_dim: int, schedule_dim: int, curve_dim: int, device: str, ) -> nn.Module: """兼容新版配置式 ForwardSurrogate 和旧版关键字参数式 ForwardSurrogate。""" surrogate_cfg = ForwardSurrogateConfig( param_dim=param_dim, schedule_dim=schedule_dim, curve_dim=curve_dim, hidden_dim=model_cfg.hidden_dim, dropout=model_cfg.dropout, use_schedule=model_cfg.use_schedule, ) return ForwardSurrogate(surrogate_cfg).to(device) def build_optimizer(model: nn.Module, optim_cfg: OptimConfig) -> torch.optim.Optimizer: """构建 Adam 优化器。""" return torch.optim.Adam( model.parameters(), lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay, ) def print_training_config(cfg: TrainConfig, curve_layout: dict) -> None: """打印训练配置摘要。""" weights = cfg.loss.weights reweight = cfg.sample_reweight print("训练配置:") print(f" device={cfg.runtime.device}") print(f" seed={cfg.runtime.seed}") print( f" batch_size={cfg.optim.batch_size}, epochs={cfg.optim.epochs}, " f"lr={cfg.optim.lr}, weight_decay={cfg.optim.weight_decay}" ) print(f" hidden_dim={cfg.model.hidden_dim}, dropout={cfg.model.dropout}") print(f" use_schedule={cfg.model.use_schedule}") print( f" weights: pressure={weights.pressure}, derivative={weights.derivative}, " f"bias_p={weights.bias_pressure}, " f"bias_d={weights.bias_derivative}, d_shape={weights.derivative_shape}, " f"autofit_p={weights.autofit_pressure}, autofit_d={weights.autofit_derivative}" ) print( f" sample_reweight={reweight.enabled}, alpha={reweight.alpha}, " f"clip=[{reweight.weight_min}, {reweight.weight_max}]" ) print(f" curve_layout={curve_layout}") print(" note: 当前重点训练 pressure + derivative;可显式关闭 schedule 分支做固定制度对照") def format_metric_line(epoch: int, train_metrics: dict[str, float], val_metrics: dict[str, float]) -> str: """格式化单个 epoch 的训练与验证指标。""" return ( f"[Epoch {epoch:03d}] " f"train={train_metrics['loss']:.6f} " f"(p={train_metrics['loss_pressure']:.6f}, " f"d={train_metrics['loss_derivative']:.6f}, " f"bp={train_metrics['loss_bias_pressure']:.6f}, " f"bd={train_metrics['loss_bias_derivative']:.6f}, " f"ds={train_metrics['loss_derivative_shape']:.6f}, " f"ap={train_metrics['loss_autofit_pressure']:.6f}, " f"ad={train_metrics['loss_autofit_derivative']:.6f}, " f"wmean={train_metrics['sample_weight_mean']:.4f}, " f"wmax={train_metrics['sample_weight_max']:.4f}) " f"val={val_metrics['loss']:.6f} " f"(p={val_metrics['loss_pressure']:.6f}, " f"d={val_metrics['loss_derivative']:.6f}, " f"bp={val_metrics['loss_bias_pressure']:.6f}, " f"bd={val_metrics['loss_bias_derivative']:.6f}, " f"ds={val_metrics['loss_derivative_shape']:.6f}, " f"ap={val_metrics['loss_autofit_pressure']:.6f}, " f"ad={val_metrics['loss_autofit_derivative']:.6f}, " f"wmean={val_metrics['sample_weight_mean']:.4f}, " f"wmax={val_metrics['sample_weight_max']:.4f})" ) def format_final_line(test_metrics: dict[str, float]) -> str: """格式化最终测试集指标。""" return ( f"[Final] test={test_metrics['loss']:.6f} " f"(p={test_metrics['loss_pressure']:.6f}, " f"d={test_metrics['loss_derivative']:.6f}, " f"bp={test_metrics['loss_bias_pressure']:.6f}, " f"bd={test_metrics['loss_bias_derivative']:.6f}, " f"ds={test_metrics['loss_derivative_shape']:.6f}, " f"ap={test_metrics['loss_autofit_pressure']:.6f}, " f"ad={test_metrics['loss_autofit_derivative']:.6f}, " f"wmean={test_metrics['sample_weight_mean']:.4f}, " f"wmax={test_metrics['sample_weight_max']:.4f})" ) def build_checkpoint_payload( model: nn.Module, bundle: DatasetBundle, cfg: TrainConfig, curve_layout: dict, ) -> dict[str, Any]: """构建 checkpoint 保存内容。""" return { "model_state_dict": model.state_dict(), "param_dim": bundle.param_dim, "schedule_dim": bundle.schedule_dim, "curve_dim": bundle.curve_dim, "hidden_dim": cfg.model.hidden_dim, "dropout": cfg.model.dropout, "use_schedule": cfg.model.use_schedule, "seed": int(cfg.runtime.seed), "curve_layout": curve_layout, "loss_weights": asdict(cfg.loss.weights), "sample_reweight": asdict(cfg.sample_reweight), } def append_history_row( history: list[dict], epoch: int, train_metrics: dict[str, float], val_metrics: dict[str, float], ) -> None: """把当前 epoch 的指标写入 history 列表。""" row = {"epoch": epoch} row.update({f"train_{key}": float(value) for key, value in train_metrics.items()}) row.update({f"val_{key}": float(value) for key, value in val_metrics.items()}) history.append(row) def save_json(path: Path, payload: dict | list) -> None: """保存 JSON 文件。""" with open(path, "w", encoding="utf-8") as file_obj: json.dump(payload, file_obj, ensure_ascii=False, indent=2) def train_epochs( model: nn.Module, bundle: DatasetBundle, cfg: TrainConfig, context: LossContext, curve_layout: dict, ) -> tuple[float, Path, list[dict]]: """执行训练循环并保存最佳模型。""" optimizer = build_optimizer(model, cfg.optim) best_val = float("inf") best_path = cfg.output_dir / "forward_surrogate_best.pt" history: list[dict] = [] for epoch in range(1, cfg.optim.epochs + 1): train_metrics = run_loader_epoch( model=model, loader=bundle.train_loader, device=cfg.runtime.device, context=context, use_schedule=cfg.model.use_schedule, optimizer=optimizer, ) val_metrics = evaluate( model=model, loader=bundle.val_loader, device=cfg.runtime.device, context=context, use_schedule=cfg.model.use_schedule, ) append_history_row(history, epoch, train_metrics, val_metrics) print(format_metric_line(epoch, train_metrics, val_metrics)) if val_metrics["loss"] < best_val: best_val = val_metrics["loss"] torch.save(build_checkpoint_payload(model, bundle, cfg, curve_layout), best_path) print(f" -> best model saved to: {best_path}") return best_val, best_path, history def load_best_model(best_path: Path, device: str) -> nn.Module: """从 checkpoint 恢复最佳模型。""" checkpoint = torch.load(best_path, map_location=device) model_cfg = ModelConfig( hidden_dim=checkpoint["hidden_dim"], dropout=checkpoint["dropout"], use_schedule=checkpoint.get("use_schedule", True), ) best_model = build_forward_model( model_cfg=model_cfg, param_dim=checkpoint["param_dim"], schedule_dim=checkpoint["schedule_dim"], curve_dim=checkpoint["curve_dim"], device=device, ) best_model.load_state_dict(checkpoint["model_state_dict"]) return best_model def build_metrics_payload( best_val: float, test_metrics: dict[str, float], cfg: TrainConfig, curve_layout: dict, ) -> dict[str, Any]: """构建 metrics.json 内容。""" return { "best_val_loss": float(best_val), "test_metrics": {key: float(value) for key, value in test_metrics.items()}, "use_schedule": cfg.model.use_schedule, "seed": int(cfg.runtime.seed), "loss_weights": asdict(cfg.loss.weights), "sample_reweight": asdict(cfg.sample_reweight), "curve_layout": curve_layout, } def train_forward(cfg: TrainConfig) -> None: """训练完整曲线正演代理模型。""" cfg.output_dir.mkdir(parents=True, exist_ok=True) set_global_seed(int(cfg.runtime.seed)) data = load_processed_dataset(cfg.processed_path) curve_layout = infer_curve_layout(data) context = LossContext( slices=get_part_slices(curve_layout), curve_stats=build_curve_stats(data, cfg.runtime.device), loss_cfg=cfg.loss, reweight_cfg=cfg.sample_reweight, ) bundle = build_dataloaders(data, cfg) model = build_forward_model( model_cfg=cfg.model, param_dim=bundle.param_dim, schedule_dim=bundle.schedule_dim, curve_dim=bundle.curve_dim, device=cfg.runtime.device, ) print_training_config(cfg, curve_layout) best_val, best_path, history = train_epochs(model, bundle, cfg, context, curve_layout) save_json(cfg.output_dir / "history.json", history) best_model = load_best_model(best_path, cfg.runtime.device) test_metrics = evaluate( model=best_model, loader=bundle.test_loader, device=cfg.runtime.device, context=context, use_schedule=cfg.model.use_schedule, ) print(format_final_line(test_metrics)) save_json( cfg.output_dir / "metrics.json", build_metrics_payload(best_val, test_metrics, cfg, curve_layout), )