diff --git a/ML/nmWTAI-ML/src/models/time_conditioned_surrogate.py b/ML/nmWTAI-ML/src/models/time_conditioned_surrogate.py index cde54e4..7e93160 100644 --- a/ML/nmWTAI-ML/src/models/time_conditioned_surrogate.py +++ b/ML/nmWTAI-ML/src/models/time_conditioned_surrogate.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- -"""时间条件代理模型网络结构。 +"""Time-conditioned surrogate network. -TimeConditionedSurrogate 不一次性输出完整曲线,而是把“物理参数 + 流量制度 + 某个 -时间点特征”作为输入,预测该时间点的 log_pressure 和 log_derivative。它适合用于 -更灵活的时间采样、局部曲线重建以及需要按时间点加权的训练策略。 +The model predicts one point on the log-pressure and log-derivative curves: -模型主体由参数编码器、制度编码器、时间编码器、融合投影和若干残差块组成。残差 -块让较深的全连接网络更容易优化,同时保留原始融合特征。 + f(params, schedule, time_point) -> [log_pressure(t), log_derivative(t)] + +It is intended for flexible time grids, point-wise training, and later PSO +candidate screening where only selected curves need full solver evaluation. """ from __future__ import annotations @@ -21,7 +21,7 @@ from torch import nn @dataclass(frozen=True) class TimeConditionedSurrogateConfig: - """时间条件代理模型配置。""" + """Network dimensions and architecture options.""" param_dim: int schedule_dim: int @@ -33,10 +33,9 @@ class TimeConditionedSurrogateConfig: class ResidualBlock(nn.Module): - """时间条件模型使用的全连接残差块,用于在较深网络中稳定传播特征。""" + """Fully connected residual block used in the fusion trunk.""" def __init__(self, dim: int, dropout: float = 0.0): - """构造两层全连接残差块,并根据 dropout 参数决定是否启用随机失活。""" super().__init__() self.net = nn.Sequential( nn.LayerNorm(dim), @@ -48,18 +47,12 @@ class ResidualBlock(nn.Module): self.act = nn.GELU() def forward(self, x: torch.Tensor) -> torch.Tensor: - """在保留原始隐藏表示的基础上叠加两层非线性修正。""" - # 残差连接让块学习修正量,而不是每层都重新表示完整特征。 + """Apply a residual correction and activation.""" return self.act(x + self.net(x)) class TimeConditionedSurrogate(nn.Module): - """逐时间点预测的时间条件代理模型。 - - 与 ForwardSurrogate 一次输出整条曲线不同,该模型每次只预测一个时间点的 - log_pressure 和 log_derivative。训练数据通常由 PointCurveDataset 将 [N, T] 曲线 - 展开为 N*T 个样本,因此 batch 中每一行都带有自己的 time_x。 - """ + """Point-wise time-conditioned pressure/derivative predictor.""" def __init__( self, @@ -73,7 +66,6 @@ class TimeConditionedSurrogate(nn.Module): dropout: float = 0.05, use_schedule: bool = True, ): - """按配置组装时间条件代理模型的编码、融合和输出层。""" super().__init__() config = self._coerce_config( config=config, @@ -129,7 +121,7 @@ class TimeConditionedSurrogate(nn.Module): dropout: float, use_schedule: bool, ) -> TimeConditionedSurrogateConfig: - """兼容配置对象式构造和旧版关键字参数式构造。""" + """Support both config-object and legacy keyword constructors.""" if config is not None: return config if param_dim is None or schedule_dim is None or time_dim is None: @@ -149,7 +141,6 @@ class TimeConditionedSurrogate(nn.Module): @staticmethod def _build_encoder(in_dim: int, out_dim: int) -> nn.Sequential: - """构建 Linear-LayerNorm-GELU 编码器。""" return nn.Sequential( nn.Linear(in_dim, out_dim), nn.LayerNorm(out_dim), @@ -162,25 +153,18 @@ class TimeConditionedSurrogate(nn.Module): time_x: torch.Tensor, schedule_x: torch.Tensor | None = None, ) -> torch.Tensor: - """融合样本级特征和点级时间特征,输出双通道响应。 - - params_x 与 schedule_x 在同一条曲线的所有时间点上相同,time_x 则随时间点变化。 - 三类特征编码后拼接进入残差主干,最后输出 [B, 2],分别对应标准化空间中的 - log_pressure 和 log_derivative。 - """ - # params_x 和 schedule_x 是样本级特征;time_x 是展开后的点级特征。 + """Fuse sample-level and point-level features into a two-channel output.""" param_feat = self.param_encoder(params_x) time_feat = self.time_encoder(time_x) if self.use_schedule: if schedule_x is None: - raise ValueError("use_schedule=True,但 forward 没有传入 schedule_x") + raise ValueError("use_schedule=True, but schedule_x was not provided") schedule_feat = self.schedule_encoder(schedule_x) fused = torch.cat([param_feat, schedule_feat, time_feat], dim=-1) else: fused = torch.cat([param_feat, time_feat], dim=-1) - # 融合后的特征经过残差主干,输出 log_pressure 和 log_derivative 两个通道。 hidden = self.input_proj(fused) hidden = self.blocks(hidden) return self.head(hidden) @@ -189,5 +173,5 @@ class TimeConditionedSurrogate(nn.Module): def build_time_conditioned_surrogate( config: TimeConditionedSurrogateConfig, ) -> TimeConditionedSurrogate: - """根据配置创建时间条件代理模型。""" + """Create a time-conditioned surrogate from a config object.""" return TimeConditionedSurrogate(config)