1、c++自动拟合范围适配

feature/ModelOpt20260526
1294271022 2 weeks ago
parent 4f698b5d0c
commit 47d00106b1

@ -1,12 +1,12 @@
# -*- coding: utf-8 -*-
"""Time-conditioned surrogate network.
"""时间条件代理模型网络结构。
The model predicts one point on the log-pressure and log-derivative curves:
TimeConditionedSurrogate 不一次性输出完整曲线而是把物理参数 + 流量制度 + 某个
时间点特征作为输入预测该时间点的 log_pressure log_derivative它适合用于
更灵活的时间采样局部曲线重建以及需要按时间点加权的训练策略
f(params, schedule, time_point) -> [log_pressure(t), log_derivative(t)]
It is intended for flexible time grids, point-wise training, and later PSO
candidate screening where only selected curves need full solver evaluation.
模型主体由参数编码器制度编码器时间编码器融合投影和若干残差块组成残差
块让较深的全连接网络更容易优化同时保留原始融合特征
"""
from __future__ import annotations
@ -21,7 +21,7 @@ from torch import nn
@dataclass(frozen=True)
class TimeConditionedSurrogateConfig:
"""Network dimensions and architecture options."""
"""时间条件代理模型配置。"""
param_dim: int
schedule_dim: int
@ -33,9 +33,10 @@ class TimeConditionedSurrogateConfig:
class ResidualBlock(nn.Module):
"""Fully connected residual block used in the fusion trunk."""
"""时间条件模型使用的全连接残差块,用于在较深网络中稳定传播特征。"""
def __init__(self, dim: int, dropout: float = 0.0):
"""构造两层全连接残差块,并根据 dropout 参数决定是否启用随机失活。"""
super().__init__()
self.net = nn.Sequential(
nn.LayerNorm(dim),
@ -47,12 +48,18 @@ class ResidualBlock(nn.Module):
self.act = nn.GELU()
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Apply a residual correction and activation."""
"""在保留原始隐藏表示的基础上叠加两层非线性修正。"""
# 残差连接让块学习修正量,而不是每层都重新表示完整特征。
return self.act(x + self.net(x))
class TimeConditionedSurrogate(nn.Module):
"""Point-wise time-conditioned pressure/derivative predictor."""
"""逐时间点预测的时间条件代理模型。
ForwardSurrogate 一次输出整条曲线不同该模型每次只预测一个时间点的
log_pressure log_derivative训练数据通常由 PointCurveDataset [N, T] 曲线
展开为 N*T 个样本因此 batch 中每一行都带有自己的 time_x
"""
def __init__(
self,
@ -66,6 +73,7 @@ class TimeConditionedSurrogate(nn.Module):
dropout: float = 0.05,
use_schedule: bool = True,
):
"""按配置组装时间条件代理模型的编码、融合和输出层。"""
super().__init__()
config = self._coerce_config(
config=config,
@ -121,7 +129,7 @@ class TimeConditionedSurrogate(nn.Module):
dropout: float,
use_schedule: bool,
) -> TimeConditionedSurrogateConfig:
"""Support both config-object and legacy keyword constructors."""
"""兼容配置对象式构造和旧版关键字参数式构造。"""
if config is not None:
return config
if param_dim is None or schedule_dim is None or time_dim is None:
@ -141,6 +149,7 @@ class TimeConditionedSurrogate(nn.Module):
@staticmethod
def _build_encoder(in_dim: int, out_dim: int) -> nn.Sequential:
"""构建 Linear-LayerNorm-GELU 编码器。"""
return nn.Sequential(
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim),
@ -153,18 +162,25 @@ class TimeConditionedSurrogate(nn.Module):
time_x: torch.Tensor,
schedule_x: torch.Tensor | None = None,
) -> torch.Tensor:
"""Fuse sample-level and point-level features into a two-channel output."""
"""融合样本级特征和点级时间特征,输出双通道响应。
params_x schedule_x 在同一条曲线的所有时间点上相同time_x 则随时间点变化
三类特征编码后拼接进入残差主干最后输出 [B, 2]分别对应标准化空间中的
log_pressure log_derivative
"""
# params_x 和 schedule_x 是样本级特征time_x 是展开后的点级特征。
param_feat = self.param_encoder(params_x)
time_feat = self.time_encoder(time_x)
if self.use_schedule:
if schedule_x is None:
raise ValueError("use_schedule=True, but schedule_x was not provided")
raise ValueError("use_schedule=True,但 forward 没有传入 schedule_x")
schedule_feat = self.schedule_encoder(schedule_x)
fused = torch.cat([param_feat, schedule_feat, time_feat], dim=-1)
else:
fused = torch.cat([param_feat, time_feat], dim=-1)
# 融合后的特征经过残差主干,输出 log_pressure 和 log_derivative 两个通道。
hidden = self.input_proj(fused)
hidden = self.blocks(hidden)
return self.head(hidden)
@ -173,5 +189,5 @@ class TimeConditionedSurrogate(nn.Module):
def build_time_conditioned_surrogate(
config: TimeConditionedSurrogateConfig,
) -> TimeConditionedSurrogate:
"""Create a time-conditioned surrogate from a config object."""
"""根据配置创建时间条件代理模型。"""
return TimeConditionedSurrogate(config)

@ -47,8 +47,8 @@ static const int kSurrogateScoringMaxAttempts = 2;
static const bool kUseFixedPsoSeed = true;
static const unsigned int kFixedPsoSeed = 1792008679u;
static const double kSurrogateKMin = 1.0e-4;
static const double kSurrogateKMax = 100.0;
static const double kSurrogateKMin = 1.0e-3;
static const double kSurrogateKMax = 10.0;
static const double kSurrogateSkinMin = -10.0;
static const double kSurrogateSkinMax = 10.0;
static const double kSurrogateWellboreCMin = 1.0e-4;
@ -56,7 +56,7 @@ static const double kSurrogateWellboreCMax = 2.0;
static const double kSurrogatePhiMin = 1.0e-2;
static const double kSurrogatePhiMax = 0.50;
static const double kSurrogateHMin = 2.0;
static const double kSurrogateHMax = 100.0;
static const double kSurrogateHMax = 50.0;
static const double kSurrogateCfFixed = 4.315e-4;
static const double kSurrogateCfRelTol = 0.25;
static const int kSurrogateMinProdSections = 3;

Loading…
Cancel
Save