|
|
# -*- coding: utf-8 -*-
|
|
|
"""正演代理模型训练流程。
|
|
|
|
|
|
本模块读取 preprocess.py 生成的 joblib 数据,构造 PyTorch Dataset/DataLoader,训练
|
|
|
ForwardSurrogate,并按验证集损失保存最佳 checkpoint。损失函数不是单一 MSE,而是
|
|
|
由压力、导数、均值偏置、导数形状约束和可选自动拟合目标组成,目的是让
|
|
|
模型既能拟合点值,也能保持对自动试井拟合有意义的曲线形态。
|
|
|
|
|
|
训练过程会保存 history.json、metrics.json 和 forward_surrogate_best.pt,后续评估
|
|
|
脚本可以根据 checkpoint 中保存的维度、curve_layout 和损失权重恢复模型。
|
|
|
"""
|
|
|
|
|
|
# pylint: disable=import-error,duplicate-code,too-many-instance-attributes
|
|
|
# pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import json
|
|
|
import random
|
|
|
from dataclasses import asdict, dataclass, field
|
|
|
from pathlib import Path
|
|
|
from typing import Any
|
|
|
|
|
|
import joblib
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
from torch.utils.data import DataLoader, Dataset
|
|
|
|
|
|
from src.models.forward_surrogate import ForwardSurrogate, ForwardSurrogateConfig
|
|
|
|
|
|
|
|
|
# Every scalar reported per epoch. The metric accumulator is keyed by these
# names, so each of them must be present in the dict returned by the loss.
METRIC_KEYS = (
    # Weighted total loss.
    "loss",
    # Individual loss components.
    "loss_pressure",
    "loss_derivative",
    "loss_bias_pressure",
    "loss_bias_derivative",
    "loss_derivative_shape",
    "loss_autofit_pressure",
    "loss_autofit_derivative",
    # Sample re-weighting statistics.
    "sample_weight_mean",
    "sample_weight_max",
)
|
|
|
|
|
|
|
|
|
class ForwardDataset(Dataset):
    """Expose preprocessed parameter, schedule and curve arrays as a Dataset.

    Item ``i`` is the tuple ``(params_x[i], schedule_x[i], curve_y[i])`` of
    float32 tensors.
    """

    def __init__(self, params_x: np.ndarray, schedule_x: np.ndarray, curve_y: np.ndarray):
        """Copy the three arrays into float32 tensors stored on the instance."""

        def to_float_tensor(array: np.ndarray) -> torch.Tensor:
            return torch.tensor(array, dtype=torch.float32)

        self.params_x = to_float_tensor(params_x)
        self.schedule_x = to_float_tensor(schedule_x)
        self.curve_y = to_float_tensor(curve_y)

    def __len__(self) -> int:
        """Return the number of rows in the parameter tensor."""
        return len(self.params_x)

    def __getitem__(self, idx: int):
        """Return the ``(params, schedule, curve)`` tensors stored at ``idx``."""
        sample = (self.params_x[idx], self.schedule_x[idx], self.curve_y[idx])
        return sample
|
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Architecture settings for the surrogate network."""

    # Hidden width forwarded to the model configuration.
    hidden_dim: int = 128
    # Dropout value forwarded to the model configuration.
    dropout: float = 0.0
    # Whether the schedule input is passed to the model in addition to the parameters.
    use_schedule: bool = True
|
|
|
|
|
|
|
|
|
@dataclass
class OptimConfig:
    """Optimizer hyper-parameters and training length."""

    # Mini-batch size used by every DataLoader.
    batch_size: int = 128
    # Number of training epochs.
    epochs: int = 100
    # Learning rate handed to the optimizer.
    lr: float = 1e-3
    # Weight decay handed to the optimizer.
    weight_decay: float = 1e-5
|
|
|
|
|
|
|
|
|
@dataclass
class LossWeights:
    """Multipliers applied to each component of the composite loss."""

    # Point-wise pressure term.
    pressure: float = 1.0
    # Point-wise derivative term.
    derivative: float = 2.0
    # Mean-level (bias) term of the pressure curve.
    bias_pressure: float = 0.15
    # Mean-level (bias) term of the derivative curve.
    bias_derivative: float = 0.05
    # First-difference (shape) term of the derivative curve.
    derivative_shape: float = 0.10
    # Auto-fit style pressure objective; 0.0 by default.
    autofit_pressure: float = 0.0
    # Auto-fit style derivative objective; 0.0 by default.
    autofit_derivative: float = 0.0
|
|
|
|
|
|
|
|
|
@dataclass
class LossConfig:
    """Settings of the composite loss."""

    # Per-component multipliers; a fresh LossWeights is created for every instance.
    weights: LossWeights = field(default_factory=LossWeights)
    # True selects Smooth L1 for the point-wise terms, False selects MSE.
    use_huber: bool = True
    # Transition point between the quadratic and linear parts of Smooth L1.
    huber_beta: float = 0.05
|
|
|
|
|
|
|
|
|
@dataclass
class SampleReweightConfig:
    """Settings of the per-sample loss re-weighting."""

    # When False every sample gets weight 1.
    enabled: bool = True
    # Strength with which the relative curve level moves the weight away from 1.
    alpha: float = 0.4
    # Lower clamp applied to the weight.
    weight_min: float = 1.0
    # Upper clamp applied to the weight.
    weight_max: float = 2.5
|
|
|
|
|
|
|
|
|
@dataclass
class TrainRuntime:
    """Runtime settings: random seed and compute device."""

    # Seed used for the global RNGs and the training DataLoader generator.
    seed: int = 42
    # Default device, decided once when the class is defined: CUDA if available, else CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
|
|
|
|
@dataclass
class TrainConfig:
    """Top-level configuration for one forward-surrogate training run."""

    # Path of the preprocessed joblib dataset.
    processed_path: Path
    # Directory receiving the checkpoint, history.json and metrics.json.
    output_dir: Path
    # Nested settings; each default is a fresh instance per TrainConfig.
    runtime: TrainRuntime = field(default_factory=TrainRuntime)
    optim: OptimConfig = field(default_factory=OptimConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    loss: LossConfig = field(default_factory=LossConfig)
    sample_reweight: SampleReweightConfig = field(default_factory=SampleReweightConfig)
|
|
|
|
|
|
|
|
|
@dataclass
class CurveStats:
    """Curve scaler statistics held as torch tensors."""

    # Scaler mean, used as the offset when undoing the standardization.
    mean_raw: torch.Tensor
    # Scaler scale, used as the multiplier when undoing the standardization.
    scale_raw: torch.Tensor
|
|
|
|
|
|
|
|
|
@dataclass
class LossBatchParts:
    """Predicted and true curves of a batch, split into pressure and derivative."""

    # Predicted pressure segment.
    pred_p: torch.Tensor
    # Predicted derivative segment.
    pred_d: torch.Tensor
    # True pressure segment.
    true_p: torch.Tensor
    # True derivative segment.
    true_d: torch.Tensor
|
|
|
|
|
|
|
|
|
@dataclass
class LossContext:
    """Everything the composite loss needs besides predictions and targets."""

    # Column slice of each named curve segment.
    slices: dict[str, slice]
    # Scaler statistics used to map curves back to the raw scale.
    curve_stats: CurveStats
    # Loss component settings.
    loss_cfg: LossConfig
    # Per-sample re-weighting settings.
    reweight_cfg: SampleReweightConfig
|
|
|
|
|
|
|
|
|
@dataclass
class DatasetBundle:
    """The three DataLoaders together with the data dimensions."""

    # Loader over the training split.
    train_loader: DataLoader
    # Loader over the validation split.
    val_loader: DataLoader
    # Loader over the test split.
    test_loader: DataLoader
    # Number of columns of the parameter input.
    param_dim: int
    # Number of columns of the schedule input.
    schedule_dim: int
    # Number of columns of the curve target.
    curve_dim: int
|
|
|
|
|
|
|
|
|
def set_global_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs, plus all CUDA devices when CUDA is available."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
|
|
|
def load_processed_dataset(path: Path) -> dict:
    """Load the preprocessed joblib dataset and verify all split arrays exist.

    Args:
        path: Location of the joblib file.

    Returns:
        The loaded object, unchanged.

    Raises:
        KeyError: If any of the train/val/test parameter, schedule or curve
            arrays is missing; the first missing key is reported.
    """
    # NOTE(review): joblib.load unpickles arbitrary objects, so only load
    # files produced by a trusted preprocessing run.
    dataset = joblib.load(path)

    expected_keys = (
        "X_params_train", "X_schedule_train", "Y_curve_train",
        "X_params_val", "X_schedule_val", "Y_curve_val",
        "X_params_test", "X_schedule_test", "Y_curve_test",
    )
    first_missing = next((key for key in expected_keys if key not in dataset), None)
    if first_missing is not None:
        raise KeyError(f"processed dataset 缺少字段: {first_missing}")
    return dataset
|
|
|
|
|
|
|
|
|
def infer_curve_layout(data: dict) -> dict:
    """Return the two-channel curve layout, rejecting legacy slope data.

    A layout stored under ``data["meta"]["curve_layout"]`` is returned as is,
    unless one of its parts is named ``slope``. Without a stored layout the
    curve columns are split into equal pressure and derivative halves.

    Raises:
        ValueError: If the stored layout contains a ``slope`` part, or if no
            layout is stored and the curve dimension is odd.
    """
    meta = data.get("meta", {})
    curve_dim = int(meta.get("curve_dim", data["Y_curve_train"].shape[1]))
    stored_layout = meta.get("curve_layout")

    if stored_layout is None:
        # No layout recorded: assume pressure then derivative, equally long.
        if curve_dim % 2:
            raise ValueError(f"curve_dim={curve_dim} 不能按压力/导数两段均分")
        half = curve_dim // 2
        pressure_part = {"name": "log_pressure", "start": 0, "end": half}
        derivative_part = {"name": "log_derivative", "start": half, "end": 2 * half}
        return {"n_time_points": int(half), "parts": [pressure_part, derivative_part]}

    part_names = {str(part["name"]) for part in stored_layout.get("parts", [])}
    if "slope" in part_names:
        raise ValueError(
            "processed 数据仍包含旧版 slope 通道;请先重新运行 "
            "scripts/preprocess_dataset.py,预处理会自动裁掉 slope"
        )
    return stored_layout
|
|
|
|
|
|
|
|
|
def get_part_slices(curve_layout: dict) -> dict[str, slice]:
    """Map each part name in ``curve_layout`` to its ``slice(start, end)``."""
    return {
        str(part["name"]): slice(int(part["start"]), int(part["end"]))
        for part in curve_layout["parts"]
    }
|
|
|
|
|
|
|
|
|
def smooth_l1_per_sample(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor:
    """Compute the Smooth L1 loss averaged over dim 1, one value per sample.

    Args:
        pred: Predicted values.
        target: Reference values.
        beta: Absolute error below which the loss is quadratic. A value
            ``<= 0`` selects plain L1.

    Returns:
        The element-wise loss reduced with ``mean(dim=1)``.
    """
    abs_err = torch.abs(pred - target)

    # With beta == 0 the quadratic branch divides by zero. torch.where would
    # not select it in the forward pass, but its inf/NaN values still poison
    # the backward pass, so degenerate to L1 explicitly.
    if beta <= 0.0:
        return abs_err.mean(dim=1)

    quadratic = 0.5 * abs_err * abs_err / beta
    linear = abs_err - 0.5 * beta
    return torch.where(abs_err < beta, quadratic, linear).mean(dim=1)
|
|
|
|
|
|
|
|
|
def mse_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the mean squared error over dim 1, one value per sample."""
    residual = pred - target
    return torch.square(residual).mean(dim=1)
|
|
|
|
|
|
|
|
|
def l1_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute error over dim 1, one value per sample."""
    residual = pred - target
    return residual.abs().mean(dim=1)
|
|
|
|
|
|
|
|
|
def regression_per_sample(
    pred: torch.Tensor,
    target: torch.Tensor,
    loss_cfg: LossConfig,
) -> torch.Tensor:
    """Return the per-sample point loss selected by ``loss_cfg``.

    Smooth L1 with ``loss_cfg.huber_beta`` is used when ``loss_cfg.use_huber``
    is true, otherwise MSE.
    """
    if not loss_cfg.use_huber:
        return mse_per_sample(pred, target)

    beta = float(loss_cfg.huber_beta)
    return smooth_l1_per_sample(pred, target, beta=beta)
|
|
|
|
|
|
|
|
|
def first_diff(x: torch.Tensor) -> torch.Tensor:
    """Return the difference between neighbouring columns along dim 1.

    The result has one column fewer than ``x``; it is used to compare the
    shape of derivative curves.
    """
    later = x[:, 1:]
    earlier = x[:, :-1]
    return later - earlier
|
|
|
|
|
|
|
|
|
def affine_restore(x_scaled: torch.Tensor, mean: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Undo a standardization on torch tensors: ``x_scaled * scale + mean``."""
    rescaled = x_scaled * scale
    return rescaled + mean
|
|
|
|
|
|
|
|
|
def autofit_curve_objective_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Compute an auto-fit style curve error per sample, usable as a training loss.

    Each point error blends 70% relative error with 30% absolute error. The
    squared point errors are averaged with weights ``1 / (1 + min(|target| *
    0.01, 100))`` and the square root of that weighted mean is returned.

    Args:
        pred: Predicted curves.
        target: Reference curves.

    Returns:
        The reduction over dim 1, one value per sample.
    """
    abs_target = torch.abs(target)
    weight = 1.0 / (1.0 + torch.clamp(abs_target * 0.01, max=100.0))

    # Guard the relative error against a zero denominator.
    scale = torch.maximum(
        torch.maximum(abs_target, torch.abs(pred)),
        torch.full_like(target, 1e-12),
    )
    absolute_error = torch.abs(target - pred)
    relative_error = absolute_error / scale
    point_error = 0.7 * relative_error + 0.3 * absolute_error

    weighted_mse = (weight * (point_error**2)).sum(dim=1)
    weighted_mse = weighted_mse / torch.clamp(weight.sum(dim=1), min=1e-12)

    # sqrt has an infinite derivative at 0, which turns the gradient of a
    # perfectly fitted sample into NaN (even when this term is later
    # multiplied by a zero weight). Clamping first gives that sample a zero
    # gradient and leaves every value above the floor unchanged.
    sqrt_floor = 1e-24
    return torch.sqrt(torch.clamp(weighted_mse, min=sqrt_floor))
|
|
|
|
|
|
|
|
|
def build_sample_weight(
    true_p: torch.Tensor,
    true_d: torch.Tensor,
    reweight_cfg: SampleReweightConfig,
) -> torch.Tensor:
    """Build per-sample loss weights from the magnitude of the true curves.

    Each sample's mean absolute pressure and derivative level is divided by
    the batch mean of that level; the two ratios are averaged, scaled around
    1 by ``reweight_cfg.alpha`` and clamped to
    ``[reweight_cfg.weight_min, reweight_cfg.weight_max]``.
    """

    def relative_level(curve: torch.Tensor) -> torch.Tensor:
        level = curve.abs().mean(dim=1)
        return level / (level.mean().detach() + 1e-6)

    blended = 0.5 * relative_level(true_p) + 0.5 * relative_level(true_d)
    unclipped = 1.0 + reweight_cfg.alpha * (blended - 1.0)
    return torch.clamp(
        unclipped,
        min=reweight_cfg.weight_min,
        max=reweight_cfg.weight_max,
    )
|
|
|
|
|
|
|
|
|
def split_curve_parts(
    pred: torch.Tensor,
    target: torch.Tensor,
    slices: dict[str, slice],
) -> LossBatchParts:
    """Split concatenated curves into their pressure and derivative columns.

    ``slices`` must contain the keys ``log_pressure`` and ``log_derivative``.
    """
    pressure_cols = slices["log_pressure"]
    derivative_cols = slices["log_derivative"]
    return LossBatchParts(
        pred_p=pred[:, pressure_cols],
        pred_d=pred[:, derivative_cols],
        true_p=target[:, pressure_cols],
        true_d=target[:, derivative_cols],
    )
|
|
|
|
|
|
|
|
|
def compute_basic_loss_vectors(
    parts: LossBatchParts,
    loss_cfg: LossConfig,
) -> dict[str, torch.Tensor]:
    """Compute the per-sample point, bias and derivative-shape losses.

    All terms are evaluated on the tensors in ``parts`` as given (the
    standardized space during training).
    """
    vectors: dict[str, torch.Tensor] = {}

    # Point-wise fit of each curve.
    vectors["loss_pressure"] = regression_per_sample(parts.pred_p, parts.true_p, loss_cfg)
    vectors["loss_derivative"] = regression_per_sample(parts.pred_d, parts.true_d, loss_cfg)

    # Bias: L1 distance between the per-sample curve means.
    pred_p_mean = parts.pred_p.mean(dim=1, keepdim=True)
    true_p_mean = parts.true_p.mean(dim=1, keepdim=True)
    vectors["loss_bias_pressure"] = l1_per_sample(pred_p_mean, true_p_mean)

    pred_d_mean = parts.pred_d.mean(dim=1, keepdim=True)
    true_d_mean = parts.true_d.mean(dim=1, keepdim=True)
    vectors["loss_bias_derivative"] = l1_per_sample(pred_d_mean, true_d_mean)

    # Shape: point loss on the first differences of the derivative curve.
    vectors["loss_derivative_shape"] = regression_per_sample(
        first_diff(parts.pred_d),
        first_diff(parts.true_d),
        loss_cfg,
    )
    return vectors
|
|
|
|
|
|
|
|
|
def compute_autofit_loss_vectors(
    parts: LossBatchParts,
    context: LossContext,
) -> dict[str, torch.Tensor]:
    """Compute the auto-fit style losses after undoing the curve scaling.

    Predictions and targets of each segment are restored with the matching
    slice of the scaler statistics in ``context.curve_stats`` before the
    objective is evaluated.
    """
    stats = context.curve_stats

    def restored_objective(
        segment: slice,
        pred_part: torch.Tensor,
        true_part: torch.Tensor,
    ) -> torch.Tensor:
        # Add a leading axis so the statistics broadcast over the batch.
        offset = stats.mean_raw[segment].unsqueeze(0)
        factor = stats.scale_raw[segment].unsqueeze(0)
        return autofit_curve_objective_per_sample(
            affine_restore(pred_part, offset, factor),
            affine_restore(true_part, offset, factor),
        )

    pressure_loss = restored_objective(context.slices["log_pressure"], parts.pred_p, parts.true_p)
    derivative_loss = restored_objective(context.slices["log_derivative"], parts.pred_d, parts.true_d)
    return {
        "loss_autofit_pressure": pressure_loss,
        "loss_autofit_derivative": derivative_loss,
    }
|
|
|
|
|
|
|
|
|
def weighted_total_vector(
    loss_vectors: dict[str, torch.Tensor],
    weights: LossWeights,
) -> torch.Tensor:
    """Combine the per-sample loss components into one weighted total per sample."""
    # (attribute on ``weights``, key in ``loss_vectors``), in summation order.
    terms = (
        ("pressure", "loss_pressure"),
        ("derivative", "loss_derivative"),
        ("bias_pressure", "loss_bias_pressure"),
        ("bias_derivative", "loss_bias_derivative"),
        ("derivative_shape", "loss_derivative_shape"),
        ("autofit_pressure", "loss_autofit_pressure"),
        ("autofit_derivative", "loss_autofit_derivative"),
    )
    (first_weight, first_key), *remaining = terms
    total = getattr(weights, first_weight) * loss_vectors[first_key]
    for weight_name, vector_key in remaining:
        total = total + getattr(weights, weight_name) * loss_vectors[vector_key]
    return total
|
|
|
|
|
|
|
|
|
def compute_weighted_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    context: LossContext,
) -> dict[str, torch.Tensor]:
    """Compute the composite training loss and its components for one batch.

    Returns:
        Scalar tensors: the batch mean of every loss component, ``loss`` (the
        mean of the weighted total times the sample weight), and the mean and
        maximum of the sample weights.
    """
    parts = split_curve_parts(pred, target, context.slices)
    per_sample = {
        **compute_basic_loss_vectors(parts, context.loss_cfg),
        **compute_autofit_loss_vectors(parts, context),
    }
    combined = weighted_total_vector(per_sample, context.loss_cfg.weights)

    sample_weight = (
        build_sample_weight(parts.true_p, parts.true_d, context.reweight_cfg)
        if context.reweight_cfg.enabled
        else torch.ones_like(combined)
    )

    metrics = {name: vector.mean() for name, vector in per_sample.items()}
    metrics.update(
        loss=(combined * sample_weight).mean(),
        sample_weight_mean=sample_weight.mean(),
        sample_weight_max=sample_weight.max(),
    )
    return metrics
|
|
|
|
|
|
|
|
|
def model_forward(
    model: nn.Module,
    params_x: torch.Tensor,
    schedule_x: torch.Tensor,
    use_schedule: bool,
) -> torch.Tensor:
    """Call ``model`` with or without the schedule input.

    The model receives ``(params_x, schedule_x)`` when ``use_schedule`` is
    true and ``(params_x, None)`` otherwise.
    """
    schedule_arg = schedule_x if use_schedule else None
    return model(params_x, schedule_arg)
|
|
|
|
|
|
|
|
|
def init_metric_accumulator() -> dict[str, float]:
    """Return a fresh accumulator with one ``0.0`` entry per name in METRIC_KEYS."""
    return dict.fromkeys(METRIC_KEYS, 0.0)
|
|
|
|
|
|
|
|
|
def accumulate_metrics(
    total: dict[str, float],
    losses: dict[str, torch.Tensor],
    batch_size: int,
) -> None:
    """Add the batch metrics to ``total`` in place, weighted by ``batch_size``.

    Only the keys already present in ``total`` are read from ``losses``.
    """
    for name, running_sum in total.items():
        # Only existing keys are reassigned, which is safe while iterating.
        total[name] = running_sum + losses[name].item() * batch_size
|
|
|
|
|
|
|
|
|
def average_metrics(total: dict[str, float], total_n: int) -> dict[str, float]:
    """Turn accumulated sums into per-sample averages.

    The divisor is at least 1, so an empty epoch yields the raw sums.
    """
    divisor = total_n if total_n >= 1 else 1
    averaged: dict[str, float] = {}
    for name, summed in total.items():
        averaged[name] = summed / divisor
    return averaged
|
|
|
|
|
|
|
|
|
def run_loader_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: str,
    context: LossContext,
    use_schedule: bool,
    optimizer: torch.optim.Optimizer | None = None,
) -> dict[str, float]:
    """Run one training or evaluation pass over ``loader``.

    Args:
        model: Network to run; put in train mode only when ``optimizer`` is given.
        loader: Yields ``(params_x, schedule_x, curve_y)`` batches.
        device: Device every batch is moved to.
        context: Loss context passed to ``compute_weighted_loss``.
        use_schedule: Whether the schedule input is passed to the model.
        optimizer: When given, gradients are computed and a step is taken per
            batch; when ``None`` the pass runs under ``torch.no_grad()``.

    Returns:
        Sample-weighted averages of every tracked metric, except
        ``sample_weight_max``, which is the largest per-batch maximum seen in
        the epoch (0.0 when the loader is empty).
    """
    is_train = optimizer is not None
    model.train(mode=is_train)

    total = init_metric_accumulator()
    total_n = 0
    # A maximum must not be averaged across batches like the loss terms, so
    # it is tracked separately and written over the averaged value below.
    epoch_weight_max = 0.0
    grad_context = torch.enable_grad() if is_train else torch.no_grad()

    with grad_context:
        for params_x, schedule_x, curve_y in loader:
            params_x = params_x.to(device)
            schedule_x = schedule_x.to(device)
            curve_y = curve_y.to(device)

            if is_train:
                optimizer.zero_grad()

            pred = model_forward(model, params_x, schedule_x, use_schedule)
            losses = compute_weighted_loss(pred=pred, target=curve_y, context=context)

            if is_train:
                losses["loss"].backward()
                optimizer.step()

            batch_size = params_x.size(0)
            accumulate_metrics(total, losses, batch_size)

            batch_weight_max = losses["sample_weight_max"].item()
            # The first batch always initializes the running maximum.
            if total_n == 0 or batch_weight_max > epoch_weight_max:
                epoch_weight_max = batch_weight_max
            total_n += batch_size

    metrics = average_metrics(total, total_n)
    metrics["sample_weight_max"] = epoch_weight_max
    return metrics
|
|
|
|
|
|
|
|
|
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    device: str,
    context: LossContext,
    use_schedule: bool,
) -> dict[str, float]:
    """Compute the average loss and loss components over ``loader``.

    This is ``run_loader_epoch`` without an optimizer, i.e. an evaluation
    pass with no parameter updates.
    """
    return run_loader_epoch(model, loader, device, context, use_schedule)
|
|
|
|
|
|
|
|
|
def build_curve_stats(data: dict, device: str) -> CurveStats:
    """Convert the curve scaler stored in ``data`` into torch statistics.

    Reads ``mean_`` and ``scale_`` from ``data["scaler_curve"]`` and returns
    them as flat float32 tensors on ``device``.
    """
    scaler = data["scaler_curve"]

    def flat_float_tensor(values) -> torch.Tensor:
        flattened = np.asarray(values, dtype=np.float32).reshape(-1)
        return torch.tensor(flattened, dtype=torch.float32, device=device)

    mean_tensor = flat_float_tensor(scaler.mean_)
    scale_tensor = flat_float_tensor(scaler.scale_)
    return CurveStats(mean_raw=mean_tensor, scale_raw=scale_tensor)
|
|
|
|
|
|
|
|
|
def build_dataloaders(data: dict, cfg: TrainConfig) -> DatasetBundle:
    """Build the train/val/test DataLoaders from the preprocessed arrays.

    Only the training loader shuffles; it uses a dedicated generator seeded
    with ``cfg.runtime.seed``. The returned dimensions are read from the
    training arrays.
    """

    def dataset_for(split: str) -> ForwardDataset:
        return ForwardDataset(
            data[f"X_params_{split}"],
            data[f"X_schedule_{split}"],
            data[f"Y_curve_{split}"],
        )

    train_ds = dataset_for("train")
    val_ds = dataset_for("val")
    test_ds = dataset_for("test")

    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(int(cfg.runtime.seed))

    batch_size = cfg.optim.batch_size
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        generator=shuffle_rng,
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

    return DatasetBundle(
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        param_dim=data["X_params_train"].shape[1],
        schedule_dim=data["X_schedule_train"].shape[1],
        curve_dim=data["Y_curve_train"].shape[1],
    )
|
|
|
|
|
|
|
|
|
def build_forward_model(
    model_cfg: ModelConfig,
    param_dim: int,
    schedule_dim: int,
    curve_dim: int,
    device: str,
) -> nn.Module:
    """Build a ForwardSurrogate from a ForwardSurrogateConfig and move it to ``device``.

    The data dimensions come from the arguments; hidden width, dropout and
    the schedule switch come from ``model_cfg``.
    """
    settings = {
        "param_dim": param_dim,
        "schedule_dim": schedule_dim,
        "curve_dim": curve_dim,
        "hidden_dim": model_cfg.hidden_dim,
        "dropout": model_cfg.dropout,
        "use_schedule": model_cfg.use_schedule,
    }
    network = ForwardSurrogate(ForwardSurrogateConfig(**settings))
    return network.to(device)
|
|
|
|
|
|
|
|
|
def build_optimizer(model: nn.Module, optim_cfg: OptimConfig) -> torch.optim.Optimizer:
    """Create an Adam optimizer over ``model``'s parameters.

    Learning rate and weight decay are taken from ``optim_cfg``.
    """
    hyper_params = {"lr": optim_cfg.lr, "weight_decay": optim_cfg.weight_decay}
    return torch.optim.Adam(model.parameters(), **hyper_params)
|
|
|
|
|
|
|
|
|
def print_training_config(cfg: TrainConfig, curve_layout: dict) -> None:
    """Print a summary of the training configuration to stdout."""
    runtime, optim, model = cfg.runtime, cfg.optim, cfg.model
    weights = cfg.loss.weights
    reweight = cfg.sample_reweight

    summary = [
        "训练配置:",
        f" device={runtime.device}",
        f" seed={runtime.seed}",
        (
            f" batch_size={optim.batch_size}, epochs={optim.epochs}, "
            f"lr={optim.lr}, weight_decay={optim.weight_decay}"
        ),
        f" hidden_dim={model.hidden_dim}, dropout={model.dropout}",
        f" use_schedule={model.use_schedule}",
        (
            f" weights: pressure={weights.pressure}, derivative={weights.derivative}, "
            f"bias_p={weights.bias_pressure}, "
            f"bias_d={weights.bias_derivative}, d_shape={weights.derivative_shape}, "
            f"autofit_p={weights.autofit_pressure}, autofit_d={weights.autofit_derivative}"
        ),
        (
            f" sample_reweight={reweight.enabled}, alpha={reweight.alpha}, "
            f"clip=[{reweight.weight_min}, {reweight.weight_max}]"
        ),
        f" curve_layout={curve_layout}",
        " note: 当前重点训练 pressure + derivative;可显式关闭 schedule 分支做固定制度对照",
    ]
    for line in summary:
        print(line)
|
|
|
|
|
|
|
|
|
def format_metric_line(epoch: int, train_metrics: dict[str, float], val_metrics: dict[str, float]) -> str:
    """Format the training and validation metrics of one epoch as a single log line."""
    # (short label, metric key) pairs printed with six decimals.
    loss_fields = (
        ("p", "loss_pressure"),
        ("d", "loss_derivative"),
        ("bp", "loss_bias_pressure"),
        ("bd", "loss_bias_derivative"),
        ("ds", "loss_derivative_shape"),
        ("ap", "loss_autofit_pressure"),
        ("ad", "loss_autofit_derivative"),
    )
    # Sample-weight statistics are printed with four decimals.
    weight_fields = (
        ("wmean", "sample_weight_mean"),
        ("wmax", "sample_weight_max"),
    )

    def summarize(tag: str, metrics: dict[str, float]) -> str:
        details = [f"{label}={metrics[key]:.6f}" for label, key in loss_fields]
        details += [f"{label}={metrics[key]:.4f}" for label, key in weight_fields]
        return f"{tag}={metrics['loss']:.6f} ({', '.join(details)})"

    train_part = summarize("train", train_metrics)
    val_part = summarize("val", val_metrics)
    return f"[Epoch {epoch:03d}] {train_part} {val_part}"
|
|
|
|
|
|
|
|
|
def format_final_line(test_metrics: dict[str, float]) -> str:
    """Format the final test-set metrics as a single log line."""
    # (short label, metric key, format spec) in print order.
    fields = (
        ("p", "loss_pressure", ".6f"),
        ("d", "loss_derivative", ".6f"),
        ("bp", "loss_bias_pressure", ".6f"),
        ("bd", "loss_bias_derivative", ".6f"),
        ("ds", "loss_derivative_shape", ".6f"),
        ("ap", "loss_autofit_pressure", ".6f"),
        ("ad", "loss_autofit_derivative", ".6f"),
        ("wmean", "sample_weight_mean", ".4f"),
        ("wmax", "sample_weight_max", ".4f"),
    )
    details = ", ".join(
        f"{label}={format(test_metrics[key], spec)}" for label, key, spec in fields
    )
    return f"[Final] test={test_metrics['loss']:.6f} ({details})"
|
|
|
|
|
|
|
|
|
def build_checkpoint_payload(
    model: nn.Module,
    bundle: DatasetBundle,
    cfg: TrainConfig,
    curve_layout: dict,
) -> dict[str, Any]:
    """Assemble the dict that is saved as a checkpoint.

    It holds the model weights plus the dimensions, model settings, seed,
    curve layout, loss weights and re-weighting settings.
    """
    payload: dict[str, Any] = {"model_state_dict": model.state_dict()}

    # Data dimensions.
    payload["param_dim"] = bundle.param_dim
    payload["schedule_dim"] = bundle.schedule_dim
    payload["curve_dim"] = bundle.curve_dim

    # Architecture settings.
    model_cfg = cfg.model
    payload["hidden_dim"] = model_cfg.hidden_dim
    payload["dropout"] = model_cfg.dropout
    payload["use_schedule"] = model_cfg.use_schedule

    # Training provenance.
    payload["seed"] = int(cfg.runtime.seed)
    payload["curve_layout"] = curve_layout
    payload["loss_weights"] = asdict(cfg.loss.weights)
    payload["sample_reweight"] = asdict(cfg.sample_reweight)
    return payload
|
|
|
|
|
|
|
|
|
def append_history_row(
    history: list[dict],
    epoch: int,
    train_metrics: dict[str, float],
    val_metrics: dict[str, float],
) -> None:
    """Append one epoch record to ``history`` in place.

    The record holds ``epoch`` plus every metric as a float, prefixed with
    ``train_`` or ``val_``.
    """
    record: dict = {"epoch": epoch}
    for prefix, metrics in (("train", train_metrics), ("val", val_metrics)):
        for name, value in metrics.items():
            record[f"{prefix}_{name}"] = float(value)
    history.append(record)
|
|
|
|
|
|
|
|
|
def save_json(path: Path, payload: dict | list) -> None:
    """Write ``payload`` to ``path`` as UTF-8 JSON.

    The output is indented by two spaces and non-ASCII characters are kept
    unescaped.
    """
    with Path(path).open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)
|
|
|
|
|
|
|
|
|
def train_epochs(
    model: nn.Module,
    bundle: DatasetBundle,
    cfg: TrainConfig,
    context: LossContext,
    curve_layout: dict,
) -> tuple[float, Path, list[dict]]:
    """Run the training loop and keep the checkpoint with the lowest validation loss.

    Args:
        model: Network to train in place.
        bundle: DataLoaders and data dimensions.
        cfg: Training configuration (epochs, device, output directory, ...).
        context: Loss context shared by training and validation.
        curve_layout: Curve layout stored inside the checkpoint.

    Returns:
        ``(best_val, best_path, history)``: the lowest validation loss seen
        (``inf`` if no epoch improved), the checkpoint path, and one history
        row per epoch.
    """
    optimizer = build_optimizer(model, cfg.optim)
    best_val = float("inf")
    best_path = cfg.output_dir / "forward_surrogate_best.pt"
    history: list[dict] = []
    checkpoint_saved = False

    for epoch in range(1, cfg.optim.epochs + 1):
        train_metrics = run_loader_epoch(
            model=model,
            loader=bundle.train_loader,
            device=cfg.runtime.device,
            context=context,
            use_schedule=cfg.model.use_schedule,
            optimizer=optimizer,
        )
        val_metrics = evaluate(
            model=model,
            loader=bundle.val_loader,
            device=cfg.runtime.device,
            context=context,
            use_schedule=cfg.model.use_schedule,
        )

        append_history_row(history, epoch, train_metrics, val_metrics)
        print(format_metric_line(epoch, train_metrics, val_metrics))

        if val_metrics["loss"] < best_val:
            best_val = val_metrics["loss"]
            torch.save(build_checkpoint_payload(model, bundle, cfg, curve_layout), best_path)
            checkpoint_saved = True
            print(f" -> best model saved to: {best_path}")

    # A NaN validation loss never compares below ``inf``, and zero epochs
    # never enter the loop. Without this fallback the caller would either
    # fail to load the checkpoint or silently load a stale file left by an
    # earlier run in the same output directory.
    if not checkpoint_saved:
        torch.save(build_checkpoint_payload(model, bundle, cfg, curve_layout), best_path)
        print(f" -> no epoch improved the validation loss; final model saved to: {best_path}")

    return best_val, best_path, history
|
|
|
|
|
|
|
|
|
def load_best_model(best_path: Path, device: str) -> nn.Module:
    """Rebuild the model described by a checkpoint and load its weights.

    The architecture is reconstructed from the dimensions and settings
    stored in the checkpoint; ``use_schedule`` falls back to ``True`` when
    the checkpoint does not contain it.
    """
    state = torch.load(best_path, map_location=device)

    restored = build_forward_model(
        model_cfg=ModelConfig(
            hidden_dim=state["hidden_dim"],
            dropout=state["dropout"],
            use_schedule=state.get("use_schedule", True),
        ),
        param_dim=state["param_dim"],
        schedule_dim=state["schedule_dim"],
        curve_dim=state["curve_dim"],
        device=device,
    )
    restored.load_state_dict(state["model_state_dict"])
    return restored
|
|
|
|
|
|
|
|
|
def build_metrics_payload(
    best_val: float,
    test_metrics: dict[str, float],
    cfg: TrainConfig,
    curve_layout: dict,
) -> dict[str, Any]:
    """Assemble the content of metrics.json.

    It holds the best validation loss, the test metrics as floats, and the
    settings needed to interpret them.
    """
    test_as_float = {name: float(value) for name, value in test_metrics.items()}

    payload: dict[str, Any] = {}
    payload["best_val_loss"] = float(best_val)
    payload["test_metrics"] = test_as_float
    payload["use_schedule"] = cfg.model.use_schedule
    payload["seed"] = int(cfg.runtime.seed)
    payload["loss_weights"] = asdict(cfg.loss.weights)
    payload["sample_reweight"] = asdict(cfg.sample_reweight)
    payload["curve_layout"] = curve_layout
    return payload
|
|
|
|
|
|
|
|
|
def train_forward(cfg: TrainConfig) -> None:
    """Train the full-curve forward surrogate and write its artifacts.

    Steps: create the output directory, seed the RNGs, load the data, build
    the loss context, loaders and model, train, write history.json, reload
    the best checkpoint, evaluate it on the test split and write
    metrics.json.
    """
    out_dir = cfg.output_dir
    device = cfg.runtime.device
    use_schedule = cfg.model.use_schedule

    out_dir.mkdir(parents=True, exist_ok=True)
    set_global_seed(int(cfg.runtime.seed))

    # Data, layout and loss context.
    dataset = load_processed_dataset(cfg.processed_path)
    layout = infer_curve_layout(dataset)
    part_slices = get_part_slices(layout)
    stats = build_curve_stats(dataset, device)
    loss_context = LossContext(
        slices=part_slices,
        curve_stats=stats,
        loss_cfg=cfg.loss,
        reweight_cfg=cfg.sample_reweight,
    )
    loaders = build_dataloaders(dataset, cfg)

    network = build_forward_model(
        model_cfg=cfg.model,
        param_dim=loaders.param_dim,
        schedule_dim=loaders.schedule_dim,
        curve_dim=loaders.curve_dim,
        device=device,
    )

    # Training.
    print_training_config(cfg, layout)
    best_val, best_path, history = train_epochs(network, loaders, cfg, loss_context, layout)
    save_json(out_dir / "history.json", history)

    # Final evaluation of the best checkpoint on the test split.
    best_network = load_best_model(best_path, device)
    test_metrics = evaluate(
        model=best_network,
        loader=loaders.test_loader,
        device=device,
        context=loss_context,
        use_schedule=use_schedule,
    )
    print(format_final_line(test_metrics))

    metrics_payload = build_metrics_payload(best_val, test_metrics, cfg, layout)
    save_json(out_dir / "metrics.json", metrics_payload)
|