You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/src/models/forward_surrogate.py

354 lines
12 KiB
Python

# -*- coding: utf-8 -*-
"""正演代理模型网络结构。
ForwardSurrogate 输入标准化后的物理参数特征和可选的流量制度编码输出固定长度
拼接曲线log_pressure log_derivative模型采用参数分支 + 流量制度
分支 + 融合主干 + 多输出头的结构便于分别学习静态地层信息动态制度信息以及
二者共同决定的曲线形态
压力和导数输出被拆成 level shape 两部分level 学习整条曲线的纵向偏移shape
学习去均值后的局部形态从结构上减少整体幅值与局部形状之间的相互干扰
"""
from __future__ import annotations
# pylint: disable=import-error,duplicate-code,too-many-arguments,too-many-positional-arguments
from dataclasses import dataclass, field
import torch
from torch import nn
@dataclass(slots=True)
class ForwardSurrogateConfig:
"""ForwardSurrogate 的结构配置。
使用配置对象可以避免模型构造函数参数过多同时让训练脚本中的超参数更集中
"""
param_dim: int
schedule_dim: int
curve_dim: int
hidden_dim: int = 128
fusion_hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
dropout: float = 0.0
use_schedule: bool = True
def build_mlp(
in_dim: int,
hidden_dims: list[int],
out_dim: int,
dropout: float = 0.0,
) -> nn.Sequential:
"""按隐藏层列表搭建 Linear-ReLU-Dropout 组成的多层感知机。"""
layers: list[nn.Module] = []
prev_dim = in_dim
for hidden_dim in hidden_dims:
layers.append(nn.Linear(prev_dim, hidden_dim))
layers.append(nn.ReLU())
if dropout > 0:
layers.append(nn.Dropout(dropout))
prev_dim = hidden_dim
layers.append(nn.Linear(prev_dim, out_dim))
return nn.Sequential(*layers)
class ScheduleEncoder(nn.Module):
"""神经网络中的流量制度分支,把固定长度制度向量编码为隐层特征。"""
def __init__(self, schedule_dim: int, hidden_dim: int, dropout: float = 0.0):
"""按流量制度向量维度构建两层编码网络。"""
super().__init__()
self.net = nn.Sequential(
nn.Linear(schedule_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""把流量制度统计特征映射到与参数分支同宽度的隐藏表示。"""
return self.net(x)
class ParamEncoder(nn.Module):
"""神经网络中的参数分支,把变换后的物理参数编码为隐层特征。"""
def __init__(self, param_dim: int, hidden_dim: int, dropout: float = 0.0):
"""按物理参数特征维度构建两层编码网络。"""
super().__init__()
self.net = nn.Sequential(
nn.Linear(param_dim, hidden_dim),
nn.ReLU(),
nn.Dropout(dropout) if dropout > 0 else nn.Identity(),
nn.Linear(hidden_dim, hidden_dim),
nn.ReLU(),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""把变换后的地层和井筒参数映射为隐藏表示。"""
return self.net(x)
class ForwardSurrogate(nn.Module):
"""完整曲线正演代理模型。
输入:
params_x: 标准化后的物理参数特征形状 [B, param_dim]
schedule_x: 标准化后的流量制度向量形状 [B, schedule_dim]
use_schedule=False 时该输入可为空
输出:
curve_pred: 形状 [B, curve_dim] log_pressurelog_derivative 顺序拼接
新模型使用双通道布局旧版三通道 checkpoint 仍可加载用于兼容部署
"""
def __init__(
self,
config: ForwardSurrogateConfig | None = None,
*,
param_dim: int | None = None,
schedule_dim: int | None = None,
curve_dim: int | None = None,
hidden_dim: int = 128,
fusion_hidden_dims: list[int] | None = None,
dropout: float = 0.0,
use_schedule: bool = True,
):
"""构建参数分支、可选流量制度分支、融合主干和曲线输出头。"""
super().__init__()
config = self._coerce_config(
config=config,
param_dim=param_dim,
schedule_dim=schedule_dim,
curve_dim=curve_dim,
hidden_dim=hidden_dim,
fusion_hidden_dims=fusion_hidden_dims,
dropout=dropout,
use_schedule=use_schedule,
)
self._validate_config(config)
self.config = config
self.encoders = self._build_encoders()
self.trunk = self._build_trunk()
self.heads = self._build_heads()
@staticmethod
def _coerce_config(
config: ForwardSurrogateConfig | None,
*,
param_dim: int | None,
schedule_dim: int | None,
curve_dim: int | None,
hidden_dim: int,
fusion_hidden_dims: list[int] | None,
dropout: float,
use_schedule: bool,
) -> ForwardSurrogateConfig:
"""兼容配置对象式构造和旧版关键字参数式构造。"""
if config is not None:
return config
if param_dim is None or schedule_dim is None or curve_dim is None:
raise TypeError(
"ForwardSurrogate requires either a ForwardSurrogateConfig or "
"param_dim, schedule_dim and curve_dim keyword arguments"
)
return ForwardSurrogateConfig(
param_dim=int(param_dim),
schedule_dim=int(schedule_dim),
curve_dim=int(curve_dim),
hidden_dim=int(hidden_dim),
fusion_hidden_dims=fusion_hidden_dims or [256, 256],
dropout=float(dropout),
use_schedule=bool(use_schedule),
)
@property
def curve_dim(self) -> int:
"""曲线拼接后的总维度。"""
return self.config.curve_dim
@property
def part_dim(self) -> int:
"""每一段曲线的时间点数量。"""
return self.config.curve_dim // self.n_curve_parts
@property
def n_curve_parts(self) -> int:
"""返回输出通道数480 等旧 checkpoint 保持三通道兼容。"""
return 3 if self.config.curve_dim % 3 == 0 else 2
@property
def use_schedule(self) -> bool:
"""是否启用流量制度分支。"""
return bool(self.config.use_schedule)
@staticmethod
def _validate_config(config: ForwardSurrogateConfig) -> None:
"""检查模型配置是否满足网络结构约束。"""
if config.curve_dim % 2 != 0 and config.curve_dim % 3 != 0:
msg = (
f"curve_dim={config.curve_dim} 无法识别;"
"期望为 pressure/derivative 双通道或旧版三通道布局"
)
raise ValueError(msg)
if not config.fusion_hidden_dims:
raise ValueError("fusion_hidden_dims 不能为空")
def _build_encoders(self) -> nn.ModuleDict:
"""构建参数分支和可选流量制度分支。"""
encoders = nn.ModuleDict(
{
"param": ParamEncoder(
self.config.param_dim,
self.config.hidden_dim,
dropout=self.config.dropout,
)
}
)
if self.use_schedule:
encoders["schedule"] = ScheduleEncoder(
self.config.schedule_dim,
self.config.hidden_dim,
dropout=self.config.dropout,
)
return encoders
def _build_trunk(self) -> nn.Sequential:
"""构建融合主干网络。"""
trunk_in_dim = self.config.hidden_dim * 2 if self.use_schedule else self.config.hidden_dim
trunk_out_dim = self.config.fusion_hidden_dims[-1]
return build_mlp(
in_dim=trunk_in_dim,
hidden_dims=self.config.fusion_hidden_dims,
out_dim=trunk_out_dim,
dropout=self.config.dropout,
)
def _build_heads(self) -> nn.ModuleDict:
"""构建压力和导数输出头;仅旧版三通道模型增加 slope 头。"""
trunk_out_dim = self.config.fusion_hidden_dims[-1]
heads = {
"pressure_level": self._build_single_head(trunk_out_dim, 1),
"pressure_shape": self._build_single_head(trunk_out_dim, self.part_dim),
"derivative_level": self._build_single_head(trunk_out_dim, 1),
"derivative_shape": self._build_single_head(trunk_out_dim, self.part_dim),
}
if self.n_curve_parts == 3:
heads["slope"] = self._build_single_head(trunk_out_dim, self.part_dim)
return nn.ModuleDict(heads)
def _build_single_head(self, in_dim: int, out_dim: int) -> nn.Sequential:
"""构建一个曲线输出头。"""
return build_mlp(
in_dim=in_dim,
hidden_dims=[128],
out_dim=out_dim,
dropout=self.config.dropout,
)
@staticmethod
def _upgrade_state_dict_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""把旧版 checkpoint 键名转换为当前 ModuleDict 结构键名。"""
prefix_map = {
"param_encoder.": "encoders.param.",
"schedule_encoder.": "encoders.schedule.",
"pressure_level_head.": "heads.pressure_level.",
"pressure_shape_head.": "heads.pressure_shape.",
"derivative_level_head.": "heads.derivative_level.",
"derivative_shape_head.": "heads.derivative_shape.",
"slope_head.": "heads.slope.",
}
upgraded: dict[str, torch.Tensor] = {}
for key, value in state_dict.items():
new_key = key
for old_prefix, new_prefix in prefix_map.items():
if key.startswith(old_prefix):
new_key = new_prefix + key[len(old_prefix) :]
break
upgraded[new_key] = value
return upgraded
def load_state_dict(
self,
state_dict: dict[str, torch.Tensor],
strict: bool = True,
assign: bool = False,
):
"""加载当前或旧版 ForwardSurrogate checkpoint。"""
return super().load_state_dict(
self._upgrade_state_dict_keys(state_dict),
strict=strict,
assign=assign,
)
@staticmethod
def center_shape(x: torch.Tensor) -> torch.Tensor:
"""去除每个样本 shape 分支的均值,让 level 分支专门学习整体偏移。"""
return x - x.mean(dim=1, keepdim=True)
def _encode_features(
self,
params_x: torch.Tensor,
schedule_x: torch.Tensor | None,
) -> torch.Tensor:
"""分别编码物理参数和流量制度,再在隐空间融合。"""
param_feat = self.encoders["param"](params_x)
if not self.use_schedule:
return param_feat
if schedule_x is None:
raise ValueError("use_schedule=True但 forward 没有传入 schedule_x")
schedule_feat = self.encoders["schedule"](schedule_x)
return torch.cat([param_feat, schedule_feat], dim=-1)
def _predict_level_shape(
self,
trunk_feat: torch.Tensor,
level_head: str,
shape_head: str,
) -> torch.Tensor:
"""用 level + centered shape 生成一段曲线。"""
level = self.heads[level_head](trunk_feat)
shape = self.center_shape(self.heads[shape_head](trunk_feat))
return level + shape
def forward(
self,
params_x: torch.Tensor,
schedule_x: torch.Tensor | None = None,
) -> torch.Tensor:
"""执行一次前向预测。
参数分支和流量制度分支先分别编码再在隐空间拼接融合主干提取共同特征后
压力和导数各自通过 level + centered shape 两个输出头生成旧版三通道模型
会继续生成 slope以便已有 checkpoint 保持可加载
"""
fused_feat = self._encode_features(params_x, schedule_x)
trunk_feat = self.trunk(fused_feat)
pressure_pred = self._predict_level_shape(
trunk_feat,
"pressure_level",
"pressure_shape",
)
derivative_pred = self._predict_level_shape(
trunk_feat,
"derivative_level",
"derivative_shape",
)
outputs = [pressure_pred, derivative_pred]
if "slope" in self.heads:
outputs.append(self.heads["slope"](trunk_feat))
return torch.cat(outputs, dim=1)