1、模型新增注释

feature/ModelOpt20260526
1294271022 2 weeks ago
parent 137ee0bc1b
commit 4f698b5d0c

@ -1,12 +1,12 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
"""时间条件代理模型网络结构。 """Time-conditioned surrogate network.
TimeConditionedSurrogate 不一次性输出完整曲线而是把物理参数 + 流量制度 + 某个 The model predicts one point on the log-pressure and log-derivative curves:
时间点特征作为输入预测该时间点的 log_pressure log_derivative它适合用于
更灵活的时间采样局部曲线重建以及需要按时间点加权的训练策略
模型主体由参数编码器制度编码器时间编码器融合投影和若干残差块组成残差 f(params, schedule, time_point) -> [log_pressure(t), log_derivative(t)]
块让较深的全连接网络更容易优化同时保留原始融合特征
It is intended for flexible time grids, point-wise training, and later PSO
candidate screening where only selected curves need full solver evaluation.
""" """
from __future__ import annotations from __future__ import annotations
@ -21,7 +21,7 @@ from torch import nn
@dataclass(frozen=True) @dataclass(frozen=True)
class TimeConditionedSurrogateConfig: class TimeConditionedSurrogateConfig:
"""时间条件代理模型配置。""" """Network dimensions and architecture options."""
param_dim: int param_dim: int
schedule_dim: int schedule_dim: int
@ -33,10 +33,9 @@ class TimeConditionedSurrogateConfig:
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
"""时间条件模型使用的全连接残差块,用于在较深网络中稳定传播特征。""" """Fully connected residual block used in the fusion trunk."""
def __init__(self, dim: int, dropout: float = 0.0): def __init__(self, dim: int, dropout: float = 0.0):
"""构造两层全连接残差块,并根据 dropout 参数决定是否启用随机失活。"""
super().__init__() super().__init__()
self.net = nn.Sequential( self.net = nn.Sequential(
nn.LayerNorm(dim), nn.LayerNorm(dim),
@ -48,18 +47,12 @@ class ResidualBlock(nn.Module):
self.act = nn.GELU() self.act = nn.GELU()
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
"""在保留原始隐藏表示的基础上叠加两层非线性修正。""" """Apply a residual correction and activation."""
# 残差连接让块学习修正量,而不是每层都重新表示完整特征。
return self.act(x + self.net(x)) return self.act(x + self.net(x))
class TimeConditionedSurrogate(nn.Module): class TimeConditionedSurrogate(nn.Module):
"""逐时间点预测的时间条件代理模型。 """Point-wise time-conditioned pressure/derivative predictor."""
ForwardSurrogate 一次输出整条曲线不同该模型每次只预测一个时间点的
log_pressure log_derivative训练数据通常由 PointCurveDataset [N, T] 曲线
展开为 N*T 个样本因此 batch 中每一行都带有自己的 time_x
"""
def __init__( def __init__(
self, self,
@ -73,7 +66,6 @@ class TimeConditionedSurrogate(nn.Module):
dropout: float = 0.05, dropout: float = 0.05,
use_schedule: bool = True, use_schedule: bool = True,
): ):
"""按配置组装时间条件代理模型的编码、融合和输出层。"""
super().__init__() super().__init__()
config = self._coerce_config( config = self._coerce_config(
config=config, config=config,
@ -129,7 +121,7 @@ class TimeConditionedSurrogate(nn.Module):
dropout: float, dropout: float,
use_schedule: bool, use_schedule: bool,
) -> TimeConditionedSurrogateConfig: ) -> TimeConditionedSurrogateConfig:
"""兼容配置对象式构造和旧版关键字参数式构造。""" """Support both config-object and legacy keyword constructors."""
if config is not None: if config is not None:
return config return config
if param_dim is None or schedule_dim is None or time_dim is None: if param_dim is None or schedule_dim is None or time_dim is None:
@ -149,7 +141,6 @@ class TimeConditionedSurrogate(nn.Module):
@staticmethod @staticmethod
def _build_encoder(in_dim: int, out_dim: int) -> nn.Sequential: def _build_encoder(in_dim: int, out_dim: int) -> nn.Sequential:
"""构建 Linear-LayerNorm-GELU 编码器。"""
return nn.Sequential( return nn.Sequential(
nn.Linear(in_dim, out_dim), nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim), nn.LayerNorm(out_dim),
@ -162,25 +153,18 @@ class TimeConditionedSurrogate(nn.Module):
time_x: torch.Tensor, time_x: torch.Tensor,
schedule_x: torch.Tensor | None = None, schedule_x: torch.Tensor | None = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""融合样本级特征和点级时间特征,输出双通道响应。 """Fuse sample-level and point-level features into a two-channel output."""
params_x schedule_x 在同一条曲线的所有时间点上相同time_x 则随时间点变化
三类特征编码后拼接进入残差主干最后输出 [B, 2]分别对应标准化空间中的
log_pressure log_derivative
"""
# params_x 和 schedule_x 是样本级特征time_x 是展开后的点级特征。
param_feat = self.param_encoder(params_x) param_feat = self.param_encoder(params_x)
time_feat = self.time_encoder(time_x) time_feat = self.time_encoder(time_x)
if self.use_schedule: if self.use_schedule:
if schedule_x is None: if schedule_x is None:
raise ValueError("use_schedule=True,但 forward 没有传入 schedule_x") raise ValueError("use_schedule=True, but schedule_x was not provided")
schedule_feat = self.schedule_encoder(schedule_x) schedule_feat = self.schedule_encoder(schedule_x)
fused = torch.cat([param_feat, schedule_feat, time_feat], dim=-1) fused = torch.cat([param_feat, schedule_feat, time_feat], dim=-1)
else: else:
fused = torch.cat([param_feat, time_feat], dim=-1) fused = torch.cat([param_feat, time_feat], dim=-1)
# 融合后的特征经过残差主干,输出 log_pressure 和 log_derivative 两个通道。
hidden = self.input_proj(fused) hidden = self.input_proj(fused)
hidden = self.blocks(hidden) hidden = self.blocks(hidden)
return self.head(hidden) return self.head(hidden)
@ -189,5 +173,5 @@ class TimeConditionedSurrogate(nn.Module):
def build_time_conditioned_surrogate( def build_time_conditioned_surrogate(
config: TimeConditionedSurrogateConfig, config: TimeConditionedSurrogateConfig,
) -> TimeConditionedSurrogate: ) -> TimeConditionedSurrogate:
"""根据配置创建时间条件代理模型。""" """Create a time-conditioned surrogate from a config object."""
return TimeConditionedSurrogate(config) return TimeConditionedSurrogate(config)

Loading…
Cancel
Save