|
|
# -*- coding: utf-8 -*-
|
|
|
"""正演代理模型网络结构。
|
|
|
|
|
|
ForwardSurrogate 输入标准化后的物理参数特征和可选的流量制度编码,输出固定长度
|
|
|
拼接曲线:log_pressure 和 log_derivative。模型采用“参数分支 + 流量制度
|
|
|
分支 + 融合主干 + 多输出头”的结构,便于分别学习静态地层信息、动态制度信息以及
|
|
|
二者共同决定的曲线形态。
|
|
|
|
|
|
压力和导数输出被拆成 level 与 shape 两部分:level 学习整条曲线的纵向偏移,shape
|
|
|
学习去均值后的局部形态,从结构上减少整体幅值与局部形状之间的相互干扰。
|
|
|
"""
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
# pylint: disable=import-error,duplicate-code,too-many-arguments,too-many-positional-arguments
|
|
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
|
|
import torch
|
|
|
from torch import nn
|
|
|
|
|
|
|
|
|
@dataclass(slots=True)
|
|
|
class ForwardSurrogateConfig:
|
|
|
"""ForwardSurrogate 的结构配置。
|
|
|
|
|
|
使用配置对象可以避免模型构造函数参数过多,同时让训练脚本中的超参数更集中。
|
|
|
"""
|
|
|
|
|
|
param_dim: int
|
|
|
schedule_dim: int
|
|
|
curve_dim: int
|
|
|
hidden_dim: int = 128
|
|
|
fusion_hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
|
|
dropout: float = 0.0
|
|
|
use_schedule: bool = True
|
|
|
|
|
|
|
|
|
def build_mlp(
    in_dim: int,
    hidden_dims: list[int],
    out_dim: int,
    dropout: float = 0.0,
) -> nn.Sequential:
    """Assemble a multilayer perceptron from a list of hidden widths.

    Every hidden width contributes a ``Linear`` layer followed by ``ReLU`` and,
    when ``dropout`` is greater than zero, a ``Dropout`` layer. A final
    ``Linear`` layer without activation projects to ``out_dim``.

    Args:
        in_dim: Number of input features.
        hidden_dims: Width of each hidden layer, in order; may be empty.
        out_dim: Number of output features.
        dropout: Dropout probability; values ``<= 0`` add no dropout layers.

    Returns:
        The layers wrapped in an ``nn.Sequential``.
    """
    with_dropout = dropout > 0
    widths = [in_dim, *hidden_dims]

    stack: list[nn.Module] = []
    # Consecutive width pairs give the (fan_in, fan_out) of each hidden layer.
    for fan_in, fan_out in zip(widths, widths[1:]):
        stack.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        if with_dropout:
            stack.append(nn.Dropout(dropout))

    # Output projection: plain Linear, no activation.
    stack.append(nn.Linear(widths[-1], out_dim))
    return nn.Sequential(*stack)
|
|
|
|
|
|
|
|
|
class ScheduleEncoder(nn.Module):
    """Flow-schedule branch: encodes a fixed-length schedule vector into hidden features."""

    def __init__(self, schedule_dim: int, hidden_dim: int, dropout: float = 0.0):
        """Build the two-layer encoder.

        Args:
            schedule_dim: Length of the incoming schedule vector.
            hidden_dim: Width of both linear layers' outputs.
            dropout: Dropout probability applied after the first activation;
                values ``<= 0`` insert an ``nn.Identity`` placeholder instead.
        """
        super().__init__()
        input_proj = nn.Linear(schedule_dim, hidden_dim)
        hidden_proj = nn.Linear(hidden_dim, hidden_dim)
        # The placeholder keeps the Sequential indices (and therefore the
        # state-dict keys net.0 / net.3) the same with or without dropout.
        if dropout > 0:
            regulariser: nn.Module = nn.Dropout(dropout)
        else:
            regulariser = nn.Identity()
        self.net = nn.Sequential(
            input_proj,
            nn.ReLU(),
            regulariser,
            hidden_proj,
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map schedule features to a hidden representation of width ``hidden_dim``."""
        encoded = self.net(x)
        return encoded
|
|
|
|
|
|
|
|
|
class ParamEncoder(nn.Module):
    """Parameter branch: encodes the transformed physical parameters into hidden features."""

    def __init__(self, param_dim: int, hidden_dim: int, dropout: float = 0.0):
        """Build the two-layer encoder.

        Args:
            param_dim: Length of the incoming parameter feature vector.
            hidden_dim: Width of both linear layers' outputs.
            dropout: Dropout probability applied after the first activation;
                values ``<= 0`` insert an ``nn.Identity`` placeholder instead.
        """
        super().__init__()
        input_proj = nn.Linear(param_dim, hidden_dim)
        hidden_proj = nn.Linear(hidden_dim, hidden_dim)
        # The placeholder keeps the Sequential indices (and therefore the
        # state-dict keys net.0 / net.3) the same with or without dropout.
        if dropout > 0:
            regulariser: nn.Module = nn.Dropout(dropout)
        else:
            regulariser = nn.Identity()
        self.net = nn.Sequential(
            input_proj,
            nn.ReLU(),
            regulariser,
            hidden_proj,
            nn.ReLU(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map the transformed parameters to a hidden representation of width ``hidden_dim``."""
        encoded = self.net(x)
        return encoded
|
|
|
|
|
|
|
|
|
class ForwardSurrogate(nn.Module):
    """Forward surrogate model that predicts a complete concatenated curve.

    Inputs:
        params_x: Standardized physical-parameter features, shape
            ``[..., param_dim]`` (typically ``[B, param_dim]``).
        schedule_x: Standardized flow-schedule vector, shape
            ``[..., schedule_dim]``; may be omitted when ``use_schedule`` is
            False.

    Output:
        curve_pred: Shape ``[..., curve_dim]``; the log_pressure segment is
        followed by the log_derivative segment along the last axis. New models
        use this two-part layout; legacy three-part checkpoints (with an extra
        slope segment) can still be loaded for deployment compatibility.
    """

    def __init__(
        self,
        config: ForwardSurrogateConfig | None = None,
        *,
        param_dim: int | None = None,
        schedule_dim: int | None = None,
        curve_dim: int | None = None,
        hidden_dim: int = 128,
        fusion_hidden_dims: list[int] | None = None,
        dropout: float = 0.0,
        use_schedule: bool = True,
    ):
        """Build the parameter branch, optional schedule branch, trunk and heads.

        Either pass a ready ``config`` object, or pass ``param_dim``,
        ``schedule_dim`` and ``curve_dim`` (plus optional hyper-parameters) as
        keyword arguments. When ``config`` is given the keyword arguments are
        ignored.

        Raises:
            TypeError: If neither ``config`` nor the three required dimensions
                are supplied.
            ValueError: If the resulting configuration is structurally invalid.
        """
        super().__init__()
        config = self._coerce_config(
            config=config,
            param_dim=param_dim,
            schedule_dim=schedule_dim,
            curve_dim=curve_dim,
            hidden_dim=hidden_dim,
            fusion_hidden_dims=fusion_hidden_dims,
            dropout=dropout,
            use_schedule=use_schedule,
        )
        self._validate_config(config)

        self.config = config
        # Sub-modules are created in a fixed order (encoders, trunk, heads) so
        # that seeded initialisation stays reproducible.
        self.encoders = self._build_encoders()
        self.trunk = self._build_trunk()
        self.heads = self._build_heads()

    @staticmethod
    def _coerce_config(
        config: ForwardSurrogateConfig | None,
        *,
        param_dim: int | None,
        schedule_dim: int | None,
        curve_dim: int | None,
        hidden_dim: int,
        fusion_hidden_dims: list[int] | None,
        dropout: float,
        use_schedule: bool,
    ) -> ForwardSurrogateConfig:
        """Support both config-object and legacy keyword-argument construction.

        Returns ``config`` unchanged when it is provided; otherwise builds a new
        ``ForwardSurrogateConfig`` from the keyword arguments.

        Raises:
            TypeError: If ``config`` is None and any of ``param_dim``,
                ``schedule_dim`` or ``curve_dim`` is missing.
        """
        if config is not None:
            return config
        if param_dim is None or schedule_dim is None or curve_dim is None:
            raise TypeError(
                "ForwardSurrogate requires either a ForwardSurrogateConfig or "
                "param_dim, schedule_dim and curve_dim keyword arguments"
            )
        return ForwardSurrogateConfig(
            param_dim=int(param_dim),
            schedule_dim=int(schedule_dim),
            curve_dim=int(curve_dim),
            hidden_dim=int(hidden_dim),
            # Copy the caller's sequence so later mutation of it cannot change
            # the stored configuration; empty/None falls back to the default.
            fusion_hidden_dims=list(fusion_hidden_dims) if fusion_hidden_dims else [256, 256],
            dropout=float(dropout),
            use_schedule=bool(use_schedule),
        )

    @property
    def curve_dim(self) -> int:
        """Total length of the concatenated output curve."""
        return self.config.curve_dim

    @property
    def part_dim(self) -> int:
        """Number of points in each curve segment (``curve_dim // n_curve_parts``)."""
        return self.config.curve_dim // self.n_curve_parts

    @property
    def n_curve_parts(self) -> int:
        """Number of output segments: 3 if ``curve_dim`` is divisible by 3, else 2.

        The divisible-by-3 rule keeps legacy three-part checkpoints (e.g.
        ``curve_dim=480``) working.

        NOTE(review): a two-part curve whose total length is also a multiple of
        3 (e.g. 2 x 150 = 300) is classified as the legacy three-part layout by
        this rule - confirm that two-part curve lengths in use never hit it.
        """
        return 3 if self.config.curve_dim % 3 == 0 else 2

    @property
    def use_schedule(self) -> bool:
        """Whether the flow-schedule branch is enabled."""
        return bool(self.config.use_schedule)

    @staticmethod
    def _validate_config(config: ForwardSurrogateConfig) -> None:
        """Check that the configuration satisfies the structural constraints.

        Raises:
            ValueError: If ``curve_dim`` is divisible by neither 2 nor 3, or if
                ``fusion_hidden_dims`` is empty.
        """
        if config.curve_dim % 2 != 0 and config.curve_dim % 3 != 0:
            msg = (
                f"curve_dim={config.curve_dim} 无法识别;"
                "期望为 pressure/derivative 双通道或旧版三通道布局"
            )
            raise ValueError(msg)

        if not config.fusion_hidden_dims:
            raise ValueError("fusion_hidden_dims 不能为空")

    def _build_encoders(self) -> nn.ModuleDict:
        """Build the parameter branch and, if enabled, the schedule branch."""
        encoders = nn.ModuleDict(
            {
                "param": ParamEncoder(
                    self.config.param_dim,
                    self.config.hidden_dim,
                    dropout=self.config.dropout,
                )
            }
        )

        if self.use_schedule:
            encoders["schedule"] = ScheduleEncoder(
                self.config.schedule_dim,
                self.config.hidden_dim,
                dropout=self.config.dropout,
            )

        return encoders

    def _build_trunk(self) -> nn.Sequential:
        """Build the fusion trunk that consumes the concatenated branch features."""
        # With the schedule branch enabled, both branch outputs are concatenated.
        trunk_in_dim = self.config.hidden_dim * 2 if self.use_schedule else self.config.hidden_dim
        # The trunk's output width equals its last hidden width.
        trunk_out_dim = self.config.fusion_hidden_dims[-1]
        return build_mlp(
            in_dim=trunk_in_dim,
            hidden_dims=self.config.fusion_hidden_dims,
            out_dim=trunk_out_dim,
            dropout=self.config.dropout,
        )

    def _build_heads(self) -> nn.ModuleDict:
        """Build pressure and derivative heads; legacy three-part models add a slope head."""
        trunk_out_dim = self.config.fusion_hidden_dims[-1]
        heads = {
            "pressure_level": self._build_single_head(trunk_out_dim, 1),
            "pressure_shape": self._build_single_head(trunk_out_dim, self.part_dim),
            "derivative_level": self._build_single_head(trunk_out_dim, 1),
            "derivative_shape": self._build_single_head(trunk_out_dim, self.part_dim),
        }
        if self.n_curve_parts == 3:
            heads["slope"] = self._build_single_head(trunk_out_dim, self.part_dim)
        return nn.ModuleDict(heads)

    def _build_single_head(self, in_dim: int, out_dim: int) -> nn.Sequential:
        """Build one output head: a single 128-wide hidden layer plus projection."""
        return build_mlp(
            in_dim=in_dim,
            hidden_dims=[128],
            out_dim=out_dim,
            dropout=self.config.dropout,
        )

    @staticmethod
    def _upgrade_state_dict_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
        """Translate legacy checkpoint key names to the current ModuleDict layout.

        Keys that start with a legacy prefix (e.g. ``param_encoder.``) are
        rewritten to the matching ``encoders.*`` / ``heads.*`` prefix; all other
        keys pass through unchanged. Key order is preserved.

        If the incoming mapping carries the ``_metadata`` attribute that
        ``nn.Module.state_dict()`` attaches, it is carried over (with its module
        paths renamed the same way) so that ``load_state_dict`` still sees the
        per-module version information.

        Returns:
            A new ordered mapping; the input is not modified.
        """
        # Local import: keeps this block self-contained without touching the
        # module's import section.
        from collections import OrderedDict  # pylint: disable=import-outside-toplevel

        prefix_map = {
            "param_encoder.": "encoders.param.",
            "schedule_encoder.": "encoders.schedule.",
            "pressure_level_head.": "heads.pressure_level.",
            "pressure_shape_head.": "heads.pressure_shape.",
            "derivative_level_head.": "heads.derivative_level.",
            "derivative_shape_head.": "heads.derivative_shape.",
            "slope_head.": "heads.slope.",
        }

        def rename(key: str) -> str:
            """Rewrite the first matching legacy prefix of ``key``, if any."""
            for old_prefix, new_prefix in prefix_map.items():
                if key.startswith(old_prefix):
                    return new_prefix + key[len(old_prefix) :]
            return key

        upgraded: dict[str, torch.Tensor] = OrderedDict(
            (rename(key), value) for key, value in state_dict.items()
        )

        metadata = getattr(state_dict, "_metadata", None)
        if metadata is not None:
            # Metadata keys are module paths without a trailing dot (the root
            # module is ""); add the dot so the prefix match works, then strip it.
            upgraded._metadata = OrderedDict(  # type: ignore[attr-defined]
                (rename(path + ".")[:-1], info) for path, info in metadata.items()
            )
        return upgraded

    def load_state_dict(
        self,
        state_dict: dict[str, torch.Tensor],
        strict: bool = True,
        assign: bool = False,
    ):
        """Load a current or legacy ForwardSurrogate checkpoint.

        Legacy key names are upgraded first; ``strict`` and ``assign`` are
        forwarded unchanged to ``nn.Module.load_state_dict``, whose result is
        returned.
        """
        return super().load_state_dict(
            self._upgrade_state_dict_keys(state_dict),
            strict=strict,
            assign=assign,
        )

    @staticmethod
    def center_shape(x: torch.Tensor) -> torch.Tensor:
        """Subtract each sample's mean so the level head alone carries the offset.

        The mean is taken over the last axis (the curve points). For the usual
        2-D ``[B, part_dim]`` input this is the same as axis 1; using the last
        axis also handles unbatched or extra-leading-dimension inputs.
        """
        return x - x.mean(dim=-1, keepdim=True)

    def _encode_features(
        self,
        params_x: torch.Tensor,
        schedule_x: torch.Tensor | None,
    ) -> torch.Tensor:
        """Encode parameters and schedule separately, then fuse in hidden space.

        Returns the parameter features alone when the schedule branch is
        disabled (``schedule_x`` is then ignored); otherwise the two feature
        tensors concatenated along the last axis.

        Raises:
            ValueError: If the schedule branch is enabled but ``schedule_x`` is
                None.
        """
        param_feat = self.encoders["param"](params_x)

        if not self.use_schedule:
            return param_feat

        if schedule_x is None:
            raise ValueError("use_schedule=True,但 forward 没有传入 schedule_x")

        schedule_feat = self.encoders["schedule"](schedule_x)
        return torch.cat([param_feat, schedule_feat], dim=-1)

    def _predict_level_shape(
        self,
        trunk_feat: torch.Tensor,
        level_head: str,
        shape_head: str,
    ) -> torch.Tensor:
        """Produce one curve segment as ``level + centered shape``.

        ``level_head`` and ``shape_head`` are keys into ``self.heads``. The
        single-value level output broadcasts across the centred shape output.
        """
        level = self.heads[level_head](trunk_feat)
        shape = self.center_shape(self.heads[shape_head](trunk_feat))
        return level + shape

    def forward(
        self,
        params_x: torch.Tensor,
        schedule_x: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run one forward prediction.

        The parameter and schedule branches are encoded separately and
        concatenated in hidden space. After the fusion trunk, pressure and
        derivative are each generated by a level head plus a centred shape
        head. Legacy three-part models additionally emit the slope segment so
        that existing checkpoints keep working.

        Returns:
            The segments concatenated along the last axis, shape
            ``[..., curve_dim]``.

        Raises:
            ValueError: If the schedule branch is enabled but ``schedule_x`` is
                None.
        """
        fused_feat = self._encode_features(params_x, schedule_x)
        trunk_feat = self.trunk(fused_feat)

        pressure_pred = self._predict_level_shape(
            trunk_feat,
            "pressure_level",
            "pressure_shape",
        )
        derivative_pred = self._predict_level_shape(
            trunk_feat,
            "derivative_level",
            "derivative_shape",
        )
        outputs = [pressure_pred, derivative_pred]
        if "slope" in self.heads:
            # The slope segment is emitted raw (no level/shape split).
            outputs.append(self.heads["slope"](trunk_feat))
        # Last axis, matching the feature concatenation in _encode_features.
        return torch.cat(outputs, dim=-1)
|