|
|
|
|
@ -1,12 +1,12 @@
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
"""Time-conditioned surrogate network.
|
|
|
|
|
"""时间条件代理模型网络结构。
|
|
|
|
|
|
|
|
|
|
The model predicts one point on the log-pressure and log-derivative curves:
|
|
|
|
|
TimeConditionedSurrogate 不一次性输出完整曲线,而是把“物理参数 + 流量制度 + 某个
|
|
|
|
|
时间点特征”作为输入,预测该时间点的 log_pressure 和 log_derivative。它适合用于
|
|
|
|
|
更灵活的时间采样、局部曲线重建以及需要按时间点加权的训练策略。
|
|
|
|
|
|
|
|
|
|
f(params, schedule, time_point) -> [log_pressure(t), log_derivative(t)]
|
|
|
|
|
|
|
|
|
|
It is intended for flexible time grids, point-wise training, and later PSO
|
|
|
|
|
candidate screening where only selected curves need full solver evaluation.
|
|
|
|
|
模型主体由参数编码器、制度编码器、时间编码器、融合投影和若干残差块组成。残差
|
|
|
|
|
块让较深的全连接网络更容易优化,同时保留原始融合特征。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
@ -21,7 +21,7 @@ from torch import nn
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
|
|
class TimeConditionedSurrogateConfig:
|
|
|
|
|
"""Network dimensions and architecture options."""
|
|
|
|
|
"""时间条件代理模型配置。"""
|
|
|
|
|
|
|
|
|
|
param_dim: int
|
|
|
|
|
schedule_dim: int
|
|
|
|
|
@ -33,9 +33,10 @@ class TimeConditionedSurrogateConfig:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResidualBlock(nn.Module):
|
|
|
|
|
"""Fully connected residual block used in the fusion trunk."""
|
|
|
|
|
"""时间条件模型使用的全连接残差块,用于在较深网络中稳定传播特征。"""
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim: int, dropout: float = 0.0):
|
|
|
|
|
"""构造两层全连接残差块,并根据 dropout 参数决定是否启用随机失活。"""
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.net = nn.Sequential(
|
|
|
|
|
nn.LayerNorm(dim),
|
|
|
|
|
@ -47,12 +48,18 @@ class ResidualBlock(nn.Module):
|
|
|
|
|
self.act = nn.GELU()
|
|
|
|
|
|
|
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
|
|
|
"""Apply a residual correction and activation."""
|
|
|
|
|
"""在保留原始隐藏表示的基础上叠加两层非线性修正。"""
|
|
|
|
|
# 残差连接让块学习修正量,而不是每层都重新表示完整特征。
|
|
|
|
|
return self.act(x + self.net(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TimeConditionedSurrogate(nn.Module):
|
|
|
|
|
"""Point-wise time-conditioned pressure/derivative predictor."""
|
|
|
|
|
"""逐时间点预测的时间条件代理模型。
|
|
|
|
|
|
|
|
|
|
与 ForwardSurrogate 一次输出整条曲线不同,该模型每次只预测一个时间点的
|
|
|
|
|
log_pressure 和 log_derivative。训练数据通常由 PointCurveDataset 将 [N, T] 曲线
|
|
|
|
|
展开为 N*T 个样本,因此 batch 中每一行都带有自己的 time_x。
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
|
self,
|
|
|
|
|
@ -66,6 +73,7 @@ class TimeConditionedSurrogate(nn.Module):
|
|
|
|
|
dropout: float = 0.05,
|
|
|
|
|
use_schedule: bool = True,
|
|
|
|
|
):
|
|
|
|
|
"""按配置组装时间条件代理模型的编码、融合和输出层。"""
|
|
|
|
|
super().__init__()
|
|
|
|
|
config = self._coerce_config(
|
|
|
|
|
config=config,
|
|
|
|
|
@ -121,7 +129,7 @@ class TimeConditionedSurrogate(nn.Module):
|
|
|
|
|
dropout: float,
|
|
|
|
|
use_schedule: bool,
|
|
|
|
|
) -> TimeConditionedSurrogateConfig:
|
|
|
|
|
"""Support both config-object and legacy keyword constructors."""
|
|
|
|
|
"""兼容配置对象式构造和旧版关键字参数式构造。"""
|
|
|
|
|
if config is not None:
|
|
|
|
|
return config
|
|
|
|
|
if param_dim is None or schedule_dim is None or time_dim is None:
|
|
|
|
|
@ -141,6 +149,7 @@ class TimeConditionedSurrogate(nn.Module):
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _build_encoder(in_dim: int, out_dim: int) -> nn.Sequential:
|
|
|
|
|
"""构建 Linear-LayerNorm-GELU 编码器。"""
|
|
|
|
|
return nn.Sequential(
|
|
|
|
|
nn.Linear(in_dim, out_dim),
|
|
|
|
|
nn.LayerNorm(out_dim),
|
|
|
|
|
@ -153,18 +162,25 @@ class TimeConditionedSurrogate(nn.Module):
|
|
|
|
|
time_x: torch.Tensor,
|
|
|
|
|
schedule_x: torch.Tensor | None = None,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
"""Fuse sample-level and point-level features into a two-channel output."""
|
|
|
|
|
"""融合样本级特征和点级时间特征,输出双通道响应。
|
|
|
|
|
|
|
|
|
|
params_x 与 schedule_x 在同一条曲线的所有时间点上相同,time_x 则随时间点变化。
|
|
|
|
|
三类特征编码后拼接进入残差主干,最后输出 [B, 2],分别对应标准化空间中的
|
|
|
|
|
log_pressure 和 log_derivative。
|
|
|
|
|
"""
|
|
|
|
|
# params_x 和 schedule_x 是样本级特征;time_x 是展开后的点级特征。
|
|
|
|
|
param_feat = self.param_encoder(params_x)
|
|
|
|
|
time_feat = self.time_encoder(time_x)
|
|
|
|
|
|
|
|
|
|
if self.use_schedule:
|
|
|
|
|
if schedule_x is None:
|
|
|
|
|
raise ValueError("use_schedule=True, but schedule_x was not provided")
|
|
|
|
|
raise ValueError("use_schedule=True,但 forward 没有传入 schedule_x")
|
|
|
|
|
schedule_feat = self.schedule_encoder(schedule_x)
|
|
|
|
|
fused = torch.cat([param_feat, schedule_feat, time_feat], dim=-1)
|
|
|
|
|
else:
|
|
|
|
|
fused = torch.cat([param_feat, time_feat], dim=-1)
|
|
|
|
|
|
|
|
|
|
# 融合后的特征经过残差主干,输出 log_pressure 和 log_derivative 两个通道。
|
|
|
|
|
hidden = self.input_proj(fused)
|
|
|
|
|
hidden = self.blocks(hidden)
|
|
|
|
|
return self.head(hidden)
|
|
|
|
|
@ -173,5 +189,5 @@ class TimeConditionedSurrogate(nn.Module):
|
|
|
|
|
def build_time_conditioned_surrogate(
|
|
|
|
|
config: TimeConditionedSurrogateConfig,
|
|
|
|
|
) -> TimeConditionedSurrogate:
|
|
|
|
|
"""Create a time-conditioned surrogate from a config object."""
|
|
|
|
|
"""根据配置创建时间条件代理模型。"""
|
|
|
|
|
return TimeConditionedSurrogate(config)
|
|
|
|
|
|