# -*- coding: utf-8 -*- """正演代理模型网络结构。 ForwardSurrogate 输入标准化后的物理参数特征和可选的流量制度编码,输出固定长度 拼接曲线:log_pressure 和 log_derivative。模型采用“参数分支 + 流量制度 分支 + 融合主干 + 多输出头”的结构,便于分别学习静态地层信息、动态制度信息以及 二者共同决定的曲线形态。 压力和导数输出被拆成 level 与 shape 两部分:level 学习整条曲线的纵向偏移,shape 学习去均值后的局部形态,从结构上减少整体幅值与局部形状之间的相互干扰。 """ from __future__ import annotations # pylint: disable=import-error,duplicate-code,too-many-arguments,too-many-positional-arguments from dataclasses import dataclass, field import torch from torch import nn @dataclass(slots=True) class ForwardSurrogateConfig: """ForwardSurrogate 的结构配置。 使用配置对象可以避免模型构造函数参数过多,同时让训练脚本中的超参数更集中。 """ param_dim: int schedule_dim: int curve_dim: int hidden_dim: int = 128 fusion_hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) dropout: float = 0.0 use_schedule: bool = True def build_mlp( in_dim: int, hidden_dims: list[int], out_dim: int, dropout: float = 0.0, ) -> nn.Sequential: """按隐藏层列表搭建 Linear-ReLU-Dropout 组成的多层感知机。""" layers: list[nn.Module] = [] prev_dim = in_dim for hidden_dim in hidden_dims: layers.append(nn.Linear(prev_dim, hidden_dim)) layers.append(nn.ReLU()) if dropout > 0: layers.append(nn.Dropout(dropout)) prev_dim = hidden_dim layers.append(nn.Linear(prev_dim, out_dim)) return nn.Sequential(*layers) class ScheduleEncoder(nn.Module): """神经网络中的流量制度分支,把固定长度制度向量编码为隐层特征。""" def __init__(self, schedule_dim: int, hidden_dim: int, dropout: float = 0.0): """按流量制度向量维度构建两层编码网络。""" super().__init__() self.net = nn.Sequential( nn.Linear(schedule_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout) if dropout > 0 else nn.Identity(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), ) def forward(self, x: torch.Tensor) -> torch.Tensor: """把流量制度统计特征映射到与参数分支同宽度的隐藏表示。""" return self.net(x) class ParamEncoder(nn.Module): """神经网络中的参数分支,把变换后的物理参数编码为隐层特征。""" def __init__(self, param_dim: int, hidden_dim: int, dropout: float = 0.0): """按物理参数特征维度构建两层编码网络。""" super().__init__() self.net = nn.Sequential( nn.Linear(param_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout) if dropout > 0 else nn.Identity(), nn.Linear(hidden_dim, hidden_dim), nn.ReLU(), ) def forward(self, x: torch.Tensor) -> torch.Tensor: """把变换后的地层和井筒参数映射为隐藏表示。""" return self.net(x) class ForwardSurrogate(nn.Module): """完整曲线正演代理模型。 输入: params_x: 标准化后的物理参数特征,形状 [B, param_dim]。 schedule_x: 标准化后的流量制度向量,形状 [B, schedule_dim]; 当 use_schedule=False 时该输入可为空。 输出: curve_pred: 形状 [B, curve_dim],按 log_pressure、log_derivative 顺序拼接。 新模型使用双通道布局;旧版三通道 checkpoint 仍可加载用于兼容部署。 """ def __init__( self, config: ForwardSurrogateConfig | None = None, *, param_dim: int | None = None, schedule_dim: int | None = None, curve_dim: int | None = None, hidden_dim: int = 128, fusion_hidden_dims: list[int] | None = None, dropout: float = 0.0, use_schedule: bool = True, ): """构建参数分支、可选流量制度分支、融合主干和曲线输出头。""" super().__init__() config = self._coerce_config( config=config, param_dim=param_dim, schedule_dim=schedule_dim, curve_dim=curve_dim, hidden_dim=hidden_dim, fusion_hidden_dims=fusion_hidden_dims, dropout=dropout, use_schedule=use_schedule, ) self._validate_config(config) self.config = config self.encoders = self._build_encoders() self.trunk = self._build_trunk() self.heads = self._build_heads() @staticmethod def _coerce_config( config: ForwardSurrogateConfig | None, *, param_dim: int | None, schedule_dim: int | None, curve_dim: int | None, hidden_dim: int, fusion_hidden_dims: list[int] | None, dropout: float, use_schedule: bool, ) -> ForwardSurrogateConfig: """兼容配置对象式构造和旧版关键字参数式构造。""" if config is not None: return config if param_dim is None or schedule_dim is None or curve_dim is None: raise TypeError( "ForwardSurrogate requires either a ForwardSurrogateConfig or " "param_dim, schedule_dim and curve_dim keyword arguments" ) return ForwardSurrogateConfig( param_dim=int(param_dim), schedule_dim=int(schedule_dim), curve_dim=int(curve_dim), hidden_dim=int(hidden_dim), fusion_hidden_dims=fusion_hidden_dims or [256, 256], dropout=float(dropout), use_schedule=bool(use_schedule), ) @property def curve_dim(self) -> int: """曲线拼接后的总维度。""" return self.config.curve_dim @property def part_dim(self) -> int: """每一段曲线的时间点数量。""" return self.config.curve_dim // self.n_curve_parts @property def n_curve_parts(self) -> int: """返回输出通道数;480 等旧 checkpoint 保持三通道兼容。""" return 3 if self.config.curve_dim % 3 == 0 else 2 @property def use_schedule(self) -> bool: """是否启用流量制度分支。""" return bool(self.config.use_schedule) @staticmethod def _validate_config(config: ForwardSurrogateConfig) -> None: """检查模型配置是否满足网络结构约束。""" if config.curve_dim % 2 != 0 and config.curve_dim % 3 != 0: msg = ( f"curve_dim={config.curve_dim} 无法识别;" "期望为 pressure/derivative 双通道或旧版三通道布局" ) raise ValueError(msg) if not config.fusion_hidden_dims: raise ValueError("fusion_hidden_dims 不能为空") def _build_encoders(self) -> nn.ModuleDict: """构建参数分支和可选流量制度分支。""" encoders = nn.ModuleDict( { "param": ParamEncoder( self.config.param_dim, self.config.hidden_dim, dropout=self.config.dropout, ) } ) if self.use_schedule: encoders["schedule"] = ScheduleEncoder( self.config.schedule_dim, self.config.hidden_dim, dropout=self.config.dropout, ) return encoders def _build_trunk(self) -> nn.Sequential: """构建融合主干网络。""" trunk_in_dim = self.config.hidden_dim * 2 if self.use_schedule else self.config.hidden_dim trunk_out_dim = self.config.fusion_hidden_dims[-1] return build_mlp( in_dim=trunk_in_dim, hidden_dims=self.config.fusion_hidden_dims, out_dim=trunk_out_dim, dropout=self.config.dropout, ) def _build_heads(self) -> nn.ModuleDict: """构建压力和导数输出头;仅旧版三通道模型增加 slope 头。""" trunk_out_dim = self.config.fusion_hidden_dims[-1] heads = { "pressure_level": self._build_single_head(trunk_out_dim, 1), "pressure_shape": self._build_single_head(trunk_out_dim, self.part_dim), "derivative_level": self._build_single_head(trunk_out_dim, 1), "derivative_shape": self._build_single_head(trunk_out_dim, self.part_dim), } if self.n_curve_parts == 3: heads["slope"] = self._build_single_head(trunk_out_dim, self.part_dim) return nn.ModuleDict(heads) def _build_single_head(self, in_dim: int, out_dim: int) -> nn.Sequential: """构建一个曲线输出头。""" return build_mlp( in_dim=in_dim, hidden_dims=[128], out_dim=out_dim, dropout=self.config.dropout, ) @staticmethod def _upgrade_state_dict_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """把旧版 checkpoint 键名转换为当前 ModuleDict 结构键名。""" prefix_map = { "param_encoder.": "encoders.param.", "schedule_encoder.": "encoders.schedule.", "pressure_level_head.": "heads.pressure_level.", "pressure_shape_head.": "heads.pressure_shape.", "derivative_level_head.": "heads.derivative_level.", "derivative_shape_head.": "heads.derivative_shape.", "slope_head.": "heads.slope.", } upgraded: dict[str, torch.Tensor] = {} for key, value in state_dict.items(): new_key = key for old_prefix, new_prefix in prefix_map.items(): if key.startswith(old_prefix): new_key = new_prefix + key[len(old_prefix) :] break upgraded[new_key] = value return upgraded def load_state_dict( self, state_dict: dict[str, torch.Tensor], strict: bool = True, assign: bool = False, ): """加载当前或旧版 ForwardSurrogate checkpoint。""" return super().load_state_dict( self._upgrade_state_dict_keys(state_dict), strict=strict, assign=assign, ) @staticmethod def center_shape(x: torch.Tensor) -> torch.Tensor: """去除每个样本 shape 分支的均值,让 level 分支专门学习整体偏移。""" return x - x.mean(dim=1, keepdim=True) def _encode_features( self, params_x: torch.Tensor, schedule_x: torch.Tensor | None, ) -> torch.Tensor: """分别编码物理参数和流量制度,再在隐空间融合。""" param_feat = self.encoders["param"](params_x) if not self.use_schedule: return param_feat if schedule_x is None: raise ValueError("use_schedule=True,但 forward 没有传入 schedule_x") schedule_feat = self.encoders["schedule"](schedule_x) return torch.cat([param_feat, schedule_feat], dim=-1) def _predict_level_shape( self, trunk_feat: torch.Tensor, level_head: str, shape_head: str, ) -> torch.Tensor: """用 level + centered shape 生成一段曲线。""" level = self.heads[level_head](trunk_feat) shape = self.center_shape(self.heads[shape_head](trunk_feat)) return level + shape def forward( self, params_x: torch.Tensor, schedule_x: torch.Tensor | None = None, ) -> torch.Tensor: """执行一次前向预测。 参数分支和流量制度分支先分别编码,再在隐空间拼接。融合主干提取共同特征后, 压力和导数各自通过 level + centered shape 两个输出头生成。旧版三通道模型 会继续生成 slope,以便已有 checkpoint 保持可加载。 """ fused_feat = self._encode_features(params_x, schedule_x) trunk_feat = self.trunk(fused_feat) pressure_pred = self._predict_level_shape( trunk_feat, "pressure_level", "pressure_shape", ) derivative_pred = self._predict_level_shape( trunk_feat, "derivative_level", "derivative_shape", ) outputs = [pressure_pred, derivative_pred] if "slope" in self.heads: outputs.append(self.heads["slope"](trunk_feat)) return torch.cat(outputs, dim=1)