You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/src/training/train_forward.py

797 lines
27 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# -*- coding: utf-8 -*-
"""Training pipeline for the forward surrogate model.

This module reads the joblib data produced by preprocess.py, builds PyTorch
Dataset/DataLoader objects, trains ForwardSurrogate, and saves the best
checkpoint according to the validation loss. The loss is not a single MSE: it
combines pressure, derivative, mean-bias, derivative-shape and optional
autofit-style terms, so that the model fits point values while also keeping the
curve shape that matters for automatic well-test matching.

Training writes history.json, metrics.json and forward_surrogate_best.pt; later
evaluation scripts can rebuild the model from the dimensions, curve_layout and
loss weights stored in the checkpoint.
"""
# pylint: disable=import-error,duplicate-code,too-many-instance-attributes
# pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
from __future__ import annotations
import json
import random
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
import joblib
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from src.models.forward_surrogate import ForwardSurrogate, ForwardSurrogateConfig
# Every scalar produced by compute_weighted_loss and accumulated per epoch.
# The tuple order fixes the key order of the metric accumulator.
METRIC_KEYS: tuple[str, ...] = (
    # Weighted composite objective that is back-propagated.
    "loss",
    # Point-value terms of the two curve channels.
    "loss_pressure",
    "loss_derivative",
    # Per-sample mean-offset terms.
    "loss_bias_pressure",
    "loss_bias_derivative",
    # Shape term on first differences of the derivative channel.
    "loss_derivative_shape",
    # Autofit-style terms on de-standardized curves.
    "loss_autofit_pressure",
    "loss_autofit_derivative",
    # Diagnostics of the per-sample reweighting.
    "sample_weight_mean",
    "sample_weight_max",
)
class ForwardDataset(Dataset):
    """Wrap preprocessed parameter, schedule and curve arrays as a PyTorch Dataset.

    Each item is the triple ``(params, schedule, curve)`` for one sample, all as
    float32 tensors.
    """

    def __init__(self, params_x: np.ndarray, schedule_x: np.ndarray, curve_y: np.ndarray):
        """Convert the three arrays to float32 tensors for direct indexing.

        Args:
            params_x: Per-sample model parameters; the first axis indexes samples.
            schedule_x: Per-sample flow-schedule features; same first-axis length.
            curve_y: Per-sample target curves; same first-axis length.

        Raises:
            ValueError: If the arrays do not share the same number of samples.
        """
        self.params_x = torch.tensor(params_x, dtype=torch.float32)
        self.schedule_x = torch.tensor(schedule_x, dtype=torch.float32)
        self.curve_y = torch.tensor(curve_y, dtype=torch.float32)
        # Fail fast on misaligned splits instead of raising IndexError (or
        # silently ignoring trailing samples) halfway through an epoch.
        lengths = {len(self.params_x), len(self.schedule_x), len(self.curve_y)}
        if len(lengths) != 1:
            raise ValueError(
                "ForwardDataset arrays must have the same number of samples: "
                f"params_x={len(self.params_x)}, schedule_x={len(self.schedule_x)}, "
                f"curve_y={len(self.curve_y)}"
            )

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.params_x)

    def __getitem__(self, idx: int):
        """Return the ``(params, schedule, curve)`` tensors of sample ``idx``."""
        return self.params_x[idx], self.schedule_x[idx], self.curve_y[idx]
@dataclass
class ModelConfig:
    """Settings that shape the ForwardSurrogate network."""

    # Forwarded to ForwardSurrogateConfig.hidden_dim.
    hidden_dim: int = field(default=128)
    # Forwarded to ForwardSurrogateConfig.dropout.
    dropout: float = field(default=0.0)
    # When False the model is called with ``None`` instead of the schedule input.
    use_schedule: bool = field(default=True)
@dataclass
class OptimConfig:
    """Optimizer and training-length settings."""

    # Mini-batch size shared by the train/val/test DataLoaders.
    batch_size: int = field(default=128)
    # Number of passes over the training split.
    epochs: int = field(default=100)
    # Adam learning rate.
    lr: float = field(default=1e-3)
    # Adam weight decay.
    weight_decay: float = field(default=1e-5)
@dataclass
class LossWeights:
    """Multipliers of the individual terms in the composite loss."""

    # Point-value term on the pressure channel.
    pressure: float = field(default=1.0)
    # Point-value term on the derivative channel.
    derivative: float = field(default=2.0)
    # Per-sample mean-offset term on pressure.
    bias_pressure: float = field(default=0.15)
    # Per-sample mean-offset term on the derivative.
    bias_derivative: float = field(default=0.05)
    # Term on first differences of the derivative channel.
    derivative_shape: float = field(default=0.10)
    # Autofit-style terms; at 0.0 they are still computed and logged but add nothing.
    autofit_pressure: float = field(default=0.0)
    autofit_derivative: float = field(default=0.0)
@dataclass
class LossConfig:
    """Composite-loss settings: term weights and the point-value loss type."""

    # Multipliers of each loss term (a fresh LossWeights per instance).
    weights: LossWeights = field(default_factory=LossWeights)
    # True selects Smooth L1 for the point-value terms, False selects MSE.
    use_huber: bool = field(default=True)
    # Smooth L1 transition point; only read when use_huber is True.
    huber_beta: float = field(default=0.05)
@dataclass
class SampleReweightConfig:
    """Settings of the amplitude-based per-sample loss reweighting."""

    # When False every sample gets weight 1.
    enabled: bool = field(default=True)
    # weight = 1 + alpha * (relative_level - 1) before clipping.
    alpha: float = field(default=0.4)
    # Lower clip bound of the weight.
    weight_min: float = field(default=1.0)
    # Upper clip bound of the weight.
    weight_max: float = field(default=2.5)
@dataclass
class TrainRuntime:
    """Run-time environment settings."""

    # Seeds Python/NumPy/PyTorch RNGs and the training shuffle generator.
    seed: int = field(default=42)
    # Decided once when the class is defined: CUDA if available, else CPU.
    device: str = field(default="cuda" if torch.cuda.is_available() else "cpu")
@dataclass
class TrainConfig:
    """Complete configuration of one forward-surrogate training run."""

    # joblib file produced by the preprocessing step.
    processed_path: Path
    # Directory receiving forward_surrogate_best.pt, history.json and metrics.json.
    output_dir: Path
    runtime: TrainRuntime = field(default_factory=TrainRuntime)
    optim: OptimConfig = field(default_factory=OptimConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    loss: LossConfig = field(default_factory=LossConfig)
    sample_reweight: SampleReweightConfig = field(default_factory=SampleReweightConfig)

    def __post_init__(self) -> None:
        """Normalize the path fields so plain strings (e.g. from a CLI) also work."""
        # The training code calls output_dir.mkdir() and joins paths with "/",
        # both of which fail on str; Path() leaves existing Paths equal.
        self.processed_path = Path(self.processed_path)
        self.output_dir = Path(self.output_dir)
@dataclass
class CurveStats:
    """Curve-scaler statistics as torch tensors, used to de-standardize curves in-graph."""

    # Scaler mean per curve feature (built as a 1-D tensor by build_curve_stats).
    mean_raw: torch.Tensor
    # Scaler scale per curve feature (built as a 1-D tensor by build_curve_stats).
    scale_raw: torch.Tensor
@dataclass
class LossBatchParts:
    """Pressure and derivative segments of a batch of predictions and targets."""

    # Predicted pressure segment (presumably (batch, n_points) — confirm).
    pred_p: torch.Tensor
    # Predicted derivative segment.
    pred_d: torch.Tensor
    # Target pressure segment.
    true_p: torch.Tensor
    # Target derivative segment.
    true_d: torch.Tensor
@dataclass
class LossContext:
    """Everything compute_weighted_loss needs besides predictions and targets."""

    # Curve part name -> column slice (from get_part_slices).
    slices: dict[str, slice]
    # Scaler statistics for the autofit terms.
    curve_stats: CurveStats
    # Term weights and point-value loss settings.
    loss_cfg: LossConfig
    # Per-sample reweighting settings.
    reweight_cfg: SampleReweightConfig
@dataclass
class DatasetBundle:
    """Train/validation/test DataLoaders plus the input/output widths used to size the model."""

    train_loader: DataLoader
    val_loader: DataLoader
    test_loader: DataLoader
    # Width of the parameter input (second axis of X_params_train).
    param_dim: int
    # Width of the schedule input (second axis of X_schedule_train).
    schedule_dim: int
    # Width of the concatenated curve output (second axis of Y_curve_train).
    curve_dim: int
def set_global_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs (and all CUDA devices when present).

    Args:
        seed: Seed value applied to every generator.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def load_processed_dataset(path: Path) -> dict:
    """Load the preprocessed joblib bundle and check that all nine split arrays exist.

    Args:
        path: joblib file written by the preprocessing step.

    Returns:
        The loaded mapping. Besides the split arrays it may hold other entries
        (this module also reads ``meta`` and ``scaler_curve`` from it).

    Raises:
        KeyError: If any required train/val/test array is missing; the message
            lists every missing key at once instead of only the first.
    """
    # NOTE(review): joblib.load unpickles arbitrary objects; only load files
    # produced by a trusted preprocessing run.
    data = joblib.load(path)
    required_keys = [
        f"{prefix}_{split}"
        for split in ("train", "val", "test")
        for prefix in ("X_params", "X_schedule", "Y_curve")
    ]
    missing = [key for key in required_keys if key not in data]
    if missing:
        raise KeyError(f"processed dataset 缺少字段: {', '.join(missing)}")
    return data
def infer_curve_layout(data: dict) -> dict:
    """Return the two-channel (pressure/derivative) curve layout of the dataset.

    A layout stored in ``data["meta"]["curve_layout"]`` is validated and
    returned unchanged. Otherwise one is derived by splitting the curve width
    into two equal halves: ``log_pressure`` first, then ``log_derivative``.

    Args:
        data: Mapping returned by ``load_processed_dataset``.

    Returns:
        The stored layout, or ``{"n_time_points": n, "parts": [...]}`` whose
        parts carry ``name``/``start``/``end``.

    Raises:
        ValueError: If the stored layout still contains the legacy ``slope``
            channel or lacks a ``log_pressure``/``log_derivative`` part, or if
            the curve width cannot be split into two equal halves.
    """
    # ``meta`` may be missing or stored explicitly as None.
    meta = data.get("meta") or {}
    curve_layout = meta.get("curve_layout")
    if curve_layout is not None:
        names = {str(part["name"]) for part in curve_layout.get("parts", [])}
        if "slope" in names:
            raise ValueError(
                "processed 数据仍包含旧版 slope 通道;请先重新运行 "
                "scripts/preprocess_dataset.py预处理会自动裁掉 slope"
            )
        # The loss code slices curves by exactly these two part names; reject
        # an incomplete layout here instead of hitting a KeyError in the loss.
        missing = [name for name in ("log_pressure", "log_derivative") if name not in names]
        if missing:
            raise ValueError(f"curve_layout is missing required parts: {missing}")
        return curve_layout
    curve_dim = int(meta.get("curve_dim", data["Y_curve_train"].shape[1]))
    if curve_dim % 2 != 0:
        raise ValueError(f"curve_dim={curve_dim} 不能按压力/导数两段均分")
    n_time_points = curve_dim // 2
    return {
        "n_time_points": int(n_time_points),
        "parts": [
            {"name": "log_pressure", "start": 0, "end": n_time_points},
            {"name": "log_derivative", "start": n_time_points, "end": 2 * n_time_points},
        ],
    }
def get_part_slices(curve_layout: dict) -> dict[str, slice]:
    """Map every curve part name to the column ``slice(start, end)`` it occupies.

    Args:
        curve_layout: Layout dict whose ``parts`` entries carry name/start/end.

    Returns:
        ``{name: slice}``; names are stringified and bounds cast to int.
    """
    return {
        str(segment["name"]): slice(int(segment["start"]), int(segment["end"]))
        for segment in curve_layout["parts"]
    }
def smooth_l1_per_sample(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor:
    """Per-sample Smooth L1 loss, averaged over dimension 1.

    With ``d = |pred - target|`` each element contributes ``0.5 * d**2 / beta``
    where ``d < beta`` and ``d - 0.5 * beta`` elsewhere, the same piecewise
    definition as ``torch.nn.SmoothL1Loss``.

    Args:
        pred: Predictions; dimension 1 is averaged.
        target: Targets broadcastable against ``pred``.
        beta: Transition point between quadratic and linear regimes.
            ``beta == 0`` reduces to plain L1, as in PyTorch.

    Returns:
        One loss value per sample.

    Raises:
        ValueError: If ``beta`` is negative.
    """
    if beta < 0:
        raise ValueError(f"smooth_l1 beta must be non-negative, got {beta}")
    diff = torch.abs(pred - target)
    if beta == 0:
        # The quadratic branch would divide by zero; even though torch.where
        # discards it in the forward pass, its NaN leaks into the gradient.
        return diff.mean(dim=1)
    loss = torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)
    return loss.mean(dim=1)
def mse_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean squared error of each sample, reduced over dimension 1."""
    squared_error = (pred - target) ** 2
    return squared_error.mean(dim=1)
def l1_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Mean absolute error of each sample, reduced over dimension 1."""
    return (pred - target).abs().mean(dim=1)
def regression_per_sample(
    pred: torch.Tensor,
    target: torch.Tensor,
    loss_cfg: LossConfig,
) -> torch.Tensor:
    """Point-value loss per sample: Smooth L1 when ``use_huber`` is set, else MSE.

    Args:
        pred: Predictions.
        target: Targets.
        loss_cfg: Supplies ``use_huber`` and ``huber_beta``.
    """
    if not loss_cfg.use_huber:
        return mse_per_sample(pred, target)
    return smooth_l1_per_sample(pred, target, beta=float(loss_cfg.huber_beta))
def first_diff(x: torch.Tensor) -> torch.Tensor:
    """Differences between neighbouring points along dimension 1 (``x[i+1] - x[i]``)."""
    later = x[:, 1:]
    earlier = x[:, :-1]
    return later - earlier
def affine_restore(x_scaled: torch.Tensor, mean: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Undo standardization on tensors: ``x_scaled * scale + mean`` (broadcasting)."""
    return x_scaled.mul(scale).add(mean)
def autofit_curve_objective_per_sample(
    pred: torch.Tensor,
    target: torch.Tensor,
    eps: float = 1e-12,
) -> torch.Tensor:
    """Differentiable per-sample autofit-style curve error (a weighted RMS).

    Each point's error mixes a relative term ``|t - p| / max(|t|, |p|)`` (0.7)
    with an absolute term ``|t - p|`` (0.3). Points are weighted by
    ``1 / (1 + min(0.01 * |t|, 100))``, so large-magnitude targets count less,
    and the weighted mean of squared point errors over dimension 1 is
    square-rooted.

    Args:
        pred: Predicted curves (this module passes de-standardized values).
        target: Target curves with the same shape as ``pred``.
        eps: Positive floor for the relative-error denominator, the weight
            normalizer and the value under the square root.

    Returns:
        One objective value per sample.
    """
    abs_target = torch.abs(target)
    weight = 1.0 / (1.0 + torch.clamp(abs_target * 0.01, max=100.0))
    abs_error = torch.abs(target - pred)
    scale = torch.maximum(
        torch.maximum(abs_target, torch.abs(pred)),
        torch.full_like(target, eps),
    )
    point_error = 0.7 * (abs_error / scale) + 0.3 * abs_error
    weighted_mse = (weight * (point_error**2)).sum(dim=1)
    weighted_mse = weighted_mse / torch.clamp(weight.sum(dim=1), min=eps)
    # d sqrt(x)/dx is infinite at x == 0 (a sample matched exactly); that
    # becomes NaN gradients for the whole batch even when the caller scales
    # this term by 0. Flooring x keeps the gradient finite.
    return torch.sqrt(torch.clamp(weighted_mse, min=eps))
def build_sample_weight(
    true_p: torch.Tensor,
    true_d: torch.Tensor,
    reweight_cfg: SampleReweightConfig,
) -> torch.Tensor:
    """Per-sample loss weights that grow with target amplitude relative to the batch.

    Each channel's mean absolute value per sample is divided by its batch mean
    (plus 1e-6). The two ratios are averaged into ``raw``, mapped to
    ``1 + alpha * (raw - 1)`` and clipped to ``[weight_min, weight_max]``.

    Args:
        true_p: Target pressure segment (values as passed by the caller).
        true_d: Target derivative segment.
        reweight_cfg: Supplies ``alpha``, ``weight_min`` and ``weight_max``.

    Returns:
        One weight per sample.
    """

    def relative_level(curve: torch.Tensor) -> torch.Tensor:
        level = curve.abs().mean(dim=1)
        return level / (level.mean().detach() + 1e-6)

    raw = 0.5 * relative_level(true_p) + 0.5 * relative_level(true_d)
    weight = 1.0 + reweight_cfg.alpha * (raw - 1.0)
    return weight.clamp(min=reweight_cfg.weight_min, max=reweight_cfg.weight_max)
def split_curve_parts(
    pred: torch.Tensor,
    target: torch.Tensor,
    slices: dict[str, slice],
) -> LossBatchParts:
    """Cut concatenated curves into their pressure and derivative column segments.

    Args:
        pred: Predicted concatenated curves (columns indexed by ``slices``).
        target: Target concatenated curves.
        slices: Must contain ``log_pressure`` and ``log_derivative``.
    """
    pressure_cols = slices["log_pressure"]
    derivative_cols = slices["log_derivative"]
    return LossBatchParts(
        pred_p=pred[:, pressure_cols],
        pred_d=pred[:, derivative_cols],
        true_p=target[:, pressure_cols],
        true_d=target[:, derivative_cols],
    )
def compute_basic_loss_vectors(
    parts: LossBatchParts,
    loss_cfg: LossConfig,
) -> dict[str, torch.Tensor]:
    """Per-sample point-value, mean-bias and derivative-shape losses.

    All terms are computed on the tensors exactly as given (the scaled space
    used by the model).

    Returns:
        Dict with ``loss_pressure``, ``loss_derivative``, ``loss_bias_pressure``,
        ``loss_bias_derivative`` and ``loss_derivative_shape``, one value per sample.
    """
    pred_p, pred_d = parts.pred_p, parts.pred_d
    true_p, true_d = parts.true_p, parts.true_d

    def channel_bias(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
        # L1 distance between the per-sample means of a channel.
        return l1_per_sample(pred.mean(dim=1, keepdim=True), true.mean(dim=1, keepdim=True))

    vectors: dict[str, torch.Tensor] = {}
    vectors["loss_pressure"] = regression_per_sample(pred_p, true_p, loss_cfg)
    vectors["loss_derivative"] = regression_per_sample(pred_d, true_d, loss_cfg)
    vectors["loss_bias_pressure"] = channel_bias(pred_p, true_p)
    vectors["loss_bias_derivative"] = channel_bias(pred_d, true_d)
    vectors["loss_derivative_shape"] = regression_per_sample(
        first_diff(pred_d), first_diff(true_d), loss_cfg
    )
    return vectors
def compute_autofit_loss_vectors(
    parts: LossBatchParts,
    context: LossContext,
) -> dict[str, torch.Tensor]:
    """Autofit-style losses on curves restored to the scaler's original units.

    Returns:
        Dict with ``loss_autofit_pressure`` and ``loss_autofit_derivative``.
    """
    stats = context.curve_stats

    def channel_objective(channel: str, pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
        cols = context.slices[channel]
        # unsqueeze(0) adds a leading axis so the stats broadcast over the batch.
        mean = stats.mean_raw[cols].unsqueeze(0)
        scale = stats.scale_raw[cols].unsqueeze(0)
        return autofit_curve_objective_per_sample(
            affine_restore(pred, mean, scale),
            affine_restore(true, mean, scale),
        )

    return {
        "loss_autofit_pressure": channel_objective("log_pressure", parts.pred_p, parts.true_p),
        "loss_autofit_derivative": channel_objective("log_derivative", parts.pred_d, parts.true_d),
    }
def weighted_total_vector(
    loss_vectors: dict[str, torch.Tensor],
    weights: LossWeights,
) -> torch.Tensor:
    """Combine the per-sample loss terms into one weighted total per sample.

    Terms are added in a fixed order: pressure, derivative, both biases,
    derivative shape, then both autofit terms.
    """
    term_order = (
        ("pressure", "loss_pressure"),
        ("derivative", "loss_derivative"),
        ("bias_pressure", "loss_bias_pressure"),
        ("bias_derivative", "loss_bias_derivative"),
        ("derivative_shape", "loss_derivative_shape"),
        ("autofit_pressure", "loss_autofit_pressure"),
        ("autofit_derivative", "loss_autofit_derivative"),
    )
    total = None
    for weight_name, vector_key in term_order:
        contribution = getattr(weights, weight_name) * loss_vectors[vector_key]
        total = contribution if total is None else total + contribution
    return total
def compute_weighted_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    context: LossContext,
) -> dict[str, torch.Tensor]:
    """Composite training loss of the forward surrogate for one batch.

    Returns:
        Batch means of every per-sample term, plus ``loss`` (the
        sample-weighted mean of the weighted total, the value to optimize)
        and the mean/max of the sample weights.
    """
    parts = split_curve_parts(pred, target, context.slices)
    per_sample = {
        **compute_basic_loss_vectors(parts, context.loss_cfg),
        **compute_autofit_loss_vectors(parts, context),
    }
    total_vec = weighted_total_vector(per_sample, context.loss_cfg.weights)
    sample_weight = (
        build_sample_weight(parts.true_p, parts.true_d, context.reweight_cfg)
        if context.reweight_cfg.enabled
        else torch.ones_like(total_vec)
    )
    metrics = {name: vec.mean() for name, vec in per_sample.items()}
    metrics.update(
        loss=(total_vec * sample_weight).mean(),
        sample_weight_mean=sample_weight.mean(),
        sample_weight_max=sample_weight.max(),
    )
    return metrics
def model_forward(
    model: nn.Module,
    params_x: torch.Tensor,
    schedule_x: torch.Tensor,
    use_schedule: bool,
) -> torch.Tensor:
    """Call ``model(params, schedule)``, passing ``None`` as schedule when it is disabled."""
    schedule_input = schedule_x if use_schedule else None
    return model(params_x, schedule_input)
def init_metric_accumulator() -> dict[str, float]:
    """Return a fresh ``{metric_name: 0.0}`` dict over METRIC_KEYS, in that order."""
    return dict.fromkeys(METRIC_KEYS, 0.0)
def accumulate_metrics(
    total: dict[str, float],
    losses: dict[str, torch.Tensor],
    batch_size: int,
) -> None:
    """Add ``batch_size * value`` for every key already in ``total`` (in place).

    Keys present only in ``losses`` are ignored.
    """
    for name, running in total.items():
        total[name] = running + losses[name].item() * batch_size
def average_metrics(total: dict[str, float], total_n: int) -> dict[str, float]:
    """Divide the accumulated sums by the sample count (treated as at least 1)."""
    denom = total_n if total_n >= 1 else 1
    averaged: dict[str, float] = {}
    for name, summed in total.items():
        averaged[name] = summed / denom
    return averaged
def run_loader_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: str,
    context: LossContext,
    use_schedule: bool,
    optimizer: torch.optim.Optimizer | None = None,
) -> dict[str, float]:
    """Run one pass over ``loader`` and return sample-averaged metrics.

    When ``optimizer`` is given, the model is put in train mode with gradients
    enabled and stepped once per batch. Otherwise the model is put in eval
    mode and no gradients are tracked.

    Note: ``sample_weight_max`` is averaged across batches like the other
    metrics, so it is a mean of per-batch maxima.
    """
    training = optimizer is not None
    model.train(mode=training)
    running = init_metric_accumulator()
    seen = 0
    with torch.set_grad_enabled(training):
        for params_x, schedule_x, curve_y in loader:
            params_x = params_x.to(device)
            schedule_x = schedule_x.to(device)
            curve_y = curve_y.to(device)
            if training:
                optimizer.zero_grad()
            pred = model_forward(model, params_x, schedule_x, use_schedule)
            batch_losses = compute_weighted_loss(pred=pred, target=curve_y, context=context)
            if training:
                batch_losses["loss"].backward()
                optimizer.step()
            n_batch = params_x.size(0)
            accumulate_metrics(running, batch_losses, n_batch)
            seen += n_batch
    return average_metrics(running, seen)
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: str,
    context: LossContext,
    use_schedule: bool,
) -> dict[str, float]:
    """Average loss and loss components over a validation/test loader (no updates).

    Thin wrapper around ``run_loader_epoch`` without an optimizer.
    """
    return run_loader_epoch(model, loader, device, context, use_schedule)
def build_curve_stats(data: dict, device: str) -> CurveStats:
    """Copy the curve scaler's ``mean_`` and ``scale_`` into 1-D float32 tensors on ``device``."""
    scaler = data["scaler_curve"]

    def to_device(values) -> torch.Tensor:
        flat = np.asarray(values, dtype=np.float32).reshape(-1)
        return torch.tensor(flat, dtype=torch.float32, device=device)

    return CurveStats(mean_raw=to_device(scaler.mean_), scale_raw=to_device(scaler.scale_))
def build_dataloaders(data: dict, cfg: TrainConfig) -> DatasetBundle:
    """Build train/val/test DataLoaders from the preprocessed split arrays.

    Only the training loader shuffles, using a generator seeded from
    ``cfg.runtime.seed`` so batch order is reproducible.
    """

    def make_dataset(split: str) -> ForwardDataset:
        return ForwardDataset(
            data[f"X_params_{split}"], data[f"X_schedule_{split}"], data[f"Y_curve_{split}"]
        )

    train_ds, val_ds, test_ds = (make_dataset(split) for split in ("train", "val", "test"))
    batch_size = cfg.optim.batch_size
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(int(cfg.runtime.seed))
    return DatasetBundle(
        train_loader=DataLoader(train_ds, batch_size=batch_size, shuffle=True, generator=shuffle_rng),
        val_loader=DataLoader(val_ds, batch_size=batch_size, shuffle=False),
        test_loader=DataLoader(test_ds, batch_size=batch_size, shuffle=False),
        param_dim=data["X_params_train"].shape[1],
        schedule_dim=data["X_schedule_train"].shape[1],
        curve_dim=data["Y_curve_train"].shape[1],
    )
def build_forward_model(
    model_cfg: ModelConfig,
    param_dim: int,
    schedule_dim: int,
    curve_dim: int,
    device: str,
) -> nn.Module:
    """Instantiate ForwardSurrogate from a ForwardSurrogateConfig and move it to ``device``.

    Only the config-object constructor is used; there is no fallback to a
    keyword-argument constructor.
    """
    network = ForwardSurrogate(
        ForwardSurrogateConfig(
            param_dim=param_dim,
            schedule_dim=schedule_dim,
            curve_dim=curve_dim,
            hidden_dim=model_cfg.hidden_dim,
            dropout=model_cfg.dropout,
            use_schedule=model_cfg.use_schedule,
        )
    )
    return network.to(device)
def build_optimizer(model: nn.Module, optim_cfg: OptimConfig) -> torch.optim.Optimizer:
    """Create an Adam optimizer over all model parameters with the configured lr/decay."""
    trainable = model.parameters()
    return torch.optim.Adam(trainable, lr=optim_cfg.lr, weight_decay=optim_cfg.weight_decay)
def print_training_config(cfg: TrainConfig, curve_layout: dict) -> None:
    """Print a one-screen summary of the run configuration to stdout.

    Args:
        cfg: Training configuration to summarize.
        curve_layout: Printed verbatim.
    """
    w = cfg.loss.weights
    rw = cfg.sample_reweight
    print("训练配置:")
    print(f" device={cfg.runtime.device}")
    print(f" seed={cfg.runtime.seed}")
    print(
        f" batch_size={cfg.optim.batch_size}, epochs={cfg.optim.epochs}, lr={cfg.optim.lr}, "
        f"weight_decay={cfg.optim.weight_decay}"
    )
    print(f" hidden_dim={cfg.model.hidden_dim}, dropout={cfg.model.dropout}")
    print(f" use_schedule={cfg.model.use_schedule}")
    print(
        f" weights: pressure={w.pressure}, derivative={w.derivative}, bias_p={w.bias_pressure}, "
        f"bias_d={w.bias_derivative}, d_shape={w.derivative_shape}, "
        f"autofit_p={w.autofit_pressure}, autofit_d={w.autofit_derivative}"
    )
    print(f" sample_reweight={rw.enabled}, alpha={rw.alpha}, clip=[{rw.weight_min}, {rw.weight_max}]")
    print(f" curve_layout={curve_layout}")
    print(" note: 当前重点训练 pressure + derivative可显式关闭 schedule 分支做固定制度对照")
def format_metric_line(epoch: int, train_metrics: dict[str, float], val_metrics: dict[str, float]) -> str:
    """Render one epoch's train and validation metrics as a single log line.

    Losses use six decimals and the sample-weight statistics use four.
    """
    detail_fields = (
        ("p", "loss_pressure", ".6f"),
        ("d", "loss_derivative", ".6f"),
        ("bp", "loss_bias_pressure", ".6f"),
        ("bd", "loss_bias_derivative", ".6f"),
        ("ds", "loss_derivative_shape", ".6f"),
        ("ap", "loss_autofit_pressure", ".6f"),
        ("ad", "loss_autofit_derivative", ".6f"),
        ("wmean", "sample_weight_mean", ".4f"),
        ("wmax", "sample_weight_max", ".4f"),
    )

    def summarize(tag: str, metrics: dict[str, float]) -> str:
        head = f"{tag}={metrics['loss']:.6f}"
        detail = ", ".join(f"{label}={metrics[key]:{spec}}" for label, key, spec in detail_fields)
        return f"{head} ({detail})"

    return f"[Epoch {epoch:03d}] {summarize('train', train_metrics)} {summarize('val', val_metrics)}"
def format_final_line(test_metrics: dict[str, float]) -> str:
    """Render the final test-set metrics as one log line.

    Returns:
        ``"[Final] test=<loss> (p=..., ..., wmax=...)"`` with losses at six
        decimals and the sample-weight statistics at four.
    """
    m = test_metrics
    head = f"[Final] test={m['loss']:.6f}"
    details = ", ".join(
        [
            f"p={m['loss_pressure']:.6f}",
            f"d={m['loss_derivative']:.6f}",
            f"bp={m['loss_bias_pressure']:.6f}",
            f"bd={m['loss_bias_derivative']:.6f}",
            f"ds={m['loss_derivative_shape']:.6f}",
            f"ap={m['loss_autofit_pressure']:.6f}",
            f"ad={m['loss_autofit_derivative']:.6f}",
            f"wmean={m['sample_weight_mean']:.4f}",
            f"wmax={m['sample_weight_max']:.4f}",
        ]
    )
    return f"{head} ({details})"
def build_checkpoint_payload(
    model: nn.Module,
    bundle: DatasetBundle,
    cfg: TrainConfig,
    curve_layout: dict,
) -> dict[str, Any]:
    """Assemble the checkpoint dict: weights plus everything needed to rebuild the model.

    ``loss_weights`` and ``sample_reweight`` are stored as plain dicts via
    ``dataclasses.asdict``.
    """
    payload: dict[str, Any] = {"model_state_dict": model.state_dict()}
    payload.update(
        param_dim=bundle.param_dim,
        schedule_dim=bundle.schedule_dim,
        curve_dim=bundle.curve_dim,
        hidden_dim=cfg.model.hidden_dim,
        dropout=cfg.model.dropout,
        use_schedule=cfg.model.use_schedule,
        seed=int(cfg.runtime.seed),
        curve_layout=curve_layout,
        loss_weights=asdict(cfg.loss.weights),
        sample_reweight=asdict(cfg.sample_reweight),
    )
    return payload
def append_history_row(
    history: list[dict],
    epoch: int,
    train_metrics: dict[str, float],
    val_metrics: dict[str, float],
) -> None:
    """Append one row ``{"epoch", "train_<k>"..., "val_<k>"...}`` (values cast to float)."""
    row: dict = {"epoch": epoch}
    for prefix, metrics in (("train", train_metrics), ("val", val_metrics)):
        for key, value in metrics.items():
            row[f"{prefix}_{key}"] = float(value)
    history.append(row)
def save_json(path: Path, payload: dict | list) -> None:
    """Write ``payload`` as UTF-8 JSON (indent 2, non-ASCII kept as-is).

    The payload is serialized before the file is opened, so a non-serializable
    value raises (TypeError from ``json``) without truncating an existing file.

    Args:
        path: Destination file; overwritten on success.
        payload: JSON-serializable dict or list.
    """
    text = json.dumps(payload, ensure_ascii=False, indent=2)
    with open(path, "w", encoding="utf-8") as file_obj:
        file_obj.write(text)
def train_epochs(
    model: nn.Module,
    bundle: DatasetBundle,
    cfg: TrainConfig,
    context: LossContext,
    curve_layout: dict,
) -> tuple[float, Path, list[dict]]:
    """Run the epoch loop and checkpoint whenever the validation loss improves.

    Args:
        model: Model trained in place.
        bundle: Train/validation loaders and data dimensions.
        cfg: Training configuration (epochs, device, output_dir, ...).
        context: Loss context shared by training and validation.
        curve_layout: Stored in every checkpoint.

    Returns:
        ``(best_val_loss, best_checkpoint_path, history_rows)``.

    Raises:
        RuntimeError: If no checkpoint was written during this run (zero
            epochs, or a validation loss that never dropped below ``inf``,
            e.g. NaN every epoch). history.json is written before raising.
    """
    optimizer = build_optimizer(model, cfg.optim)
    best_val = float("inf")
    best_path = cfg.output_dir / "forward_surrogate_best.pt"
    history: list[dict] = []
    saved_this_run = False
    for epoch in range(1, cfg.optim.epochs + 1):
        train_metrics = run_loader_epoch(
            model=model,
            loader=bundle.train_loader,
            device=cfg.runtime.device,
            context=context,
            use_schedule=cfg.model.use_schedule,
            optimizer=optimizer,
        )
        val_metrics = evaluate(
            model=model,
            loader=bundle.val_loader,
            device=cfg.runtime.device,
            context=context,
            use_schedule=cfg.model.use_schedule,
        )
        append_history_row(history, epoch, train_metrics, val_metrics)
        print(format_metric_line(epoch, train_metrics, val_metrics))
        # NaN never compares smaller, so a diverged epoch is never saved.
        if val_metrics["loss"] < best_val:
            best_val = val_metrics["loss"]
            torch.save(build_checkpoint_payload(model, bundle, cfg, curve_layout), best_path)
            saved_this_run = True
            print(f" -> best model saved to: {best_path}")
    if not saved_this_run:
        # Otherwise the caller reloads best_path, which is either missing or a
        # stale checkpoint from an earlier run in the same output_dir, and
        # would report that model's test metrics as this run's.
        save_json(cfg.output_dir / "history.json", history)
        raise RuntimeError(
            f"no checkpoint was saved to {best_path} during this run "
            f"(epochs={cfg.optim.epochs}, best validation loss={best_val}); "
            "check for NaN/inf losses"
        )
    return best_val, best_path, history
def load_best_model(best_path: Path, device: str) -> nn.Module:
    """Rebuild a ForwardSurrogate from a checkpoint written by build_checkpoint_payload.

    Args:
        best_path: Checkpoint file.
        device: Device used both as ``map_location`` and for the rebuilt model.

    Returns:
        The model with the checkpoint's weights loaded.
    """
    checkpoint = torch.load(best_path, map_location=device)
    restored = build_forward_model(
        model_cfg=ModelConfig(
            hidden_dim=checkpoint["hidden_dim"],
            dropout=checkpoint["dropout"],
            # Presumably for checkpoints written without this key — defaults to enabled.
            use_schedule=checkpoint.get("use_schedule", True),
        ),
        param_dim=checkpoint["param_dim"],
        schedule_dim=checkpoint["schedule_dim"],
        curve_dim=checkpoint["curve_dim"],
        device=device,
    )
    restored.load_state_dict(checkpoint["model_state_dict"])
    return restored
def build_metrics_payload(
    best_val: float,
    test_metrics: dict[str, float],
    cfg: TrainConfig,
    curve_layout: dict,
) -> dict[str, Any]:
    """Assemble the content of metrics.json (numbers cast to float/int, configs as dicts)."""
    payload: dict[str, Any] = {}
    payload["best_val_loss"] = float(best_val)
    payload["test_metrics"] = {name: float(score) for name, score in test_metrics.items()}
    payload["use_schedule"] = cfg.model.use_schedule
    payload["seed"] = int(cfg.runtime.seed)
    payload["loss_weights"] = asdict(cfg.loss.weights)
    payload["sample_reweight"] = asdict(cfg.sample_reweight)
    payload["curve_layout"] = curve_layout
    return payload
def train_forward(cfg: TrainConfig) -> None:
    """Train the full-curve forward surrogate model end to end.

    Steps:
    1. Create the output directory and seed the RNGs.
    2. Load and validate the preprocessed data.
    3. Build the loss context, data loaders and model.
    4. Train, checkpointing on the best validation loss.
    5. Write history.json.
    6. Reload the best checkpoint, evaluate it on the test split and write
       metrics.json.

    Args:
        cfg: Full training configuration.
    """
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    # Seed before the loaders and model are built so any randomness they use is reproducible.
    set_global_seed(int(cfg.runtime.seed))
    data = load_processed_dataset(cfg.processed_path)
    curve_layout = infer_curve_layout(data)
    # One loss context shared by training, validation and the final test pass.
    context = LossContext(
        slices=get_part_slices(curve_layout),
        curve_stats=build_curve_stats(data, cfg.runtime.device),
        loss_cfg=cfg.loss,
        reweight_cfg=cfg.sample_reweight,
    )
    bundle = build_dataloaders(data, cfg)
    model = build_forward_model(
        model_cfg=cfg.model,
        param_dim=bundle.param_dim,
        schedule_dim=bundle.schedule_dim,
        curve_dim=bundle.curve_dim,
        device=cfg.runtime.device,
    )
    print_training_config(cfg, curve_layout)
    best_val, best_path, history = train_epochs(model, bundle, cfg, context, curve_layout)
    save_json(cfg.output_dir / "history.json", history)
    # Test metrics come from the reloaded best-validation checkpoint, not from
    # the last-epoch weights still held by ``model``.
    best_model = load_best_model(best_path, cfg.runtime.device)
    test_metrics = evaluate(
        model=best_model,
        loader=bundle.test_loader,
        device=cfg.runtime.device,
        context=context,
        use_schedule=cfg.model.use_schedule,
    )
    print(format_final_line(test_metrics))
    save_json(
        cfg.output_dir / "metrics.json",
        build_metrics_payload(best_val, test_metrics, cfg, curve_layout),
    )