1、忽略审查代码配置文件 2、python代码静态审查 3、python代码逻辑修正

feature/ModelOpt20260526
1294271022 3 weeks ago
parent 547020b992
commit 137ee0bc1b

1
.gitignore vendored

@ -119,5 +119,6 @@ ML/nmWTAI-ML/data
ML/nmWTAI-ML/models
ML/nmWTAI-ML/results
__pycache__
.pylintrc
ML/Training/Debug
ML/Training/Release

@ -1,25 +1,19 @@
"""分析集成代理模型的不确定性评估结果。
本脚本读取 `evaluate_forward_ensemble.py` 导出的逐样本误差与不确定性 CSV,
统计剔除高不确定性样本后剩余样本误差是否下降、高误差样本是否被不确定性捕获,
并输出汇总 CSV、风险样本清单以及保留率曲线图。
"""
分析集成代理模型不确定性结果。
评估 uncertainty 是否能够识别高误差样本,
并生成 retention curve、风险样本统计和 summary 输出。
"""
from __future__ import annotations
# 标准库导入:
# argparse 用于解析命令行参数;
# csv/json 用于读写分析结果;
# sys/Path 用于处理项目路径和文件路径。
import argparse
import csv
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
# 第三方库导入:
# matplotlib 用于绘制不确定性筛选曲线;
# numpy 用于数组计算、排序、百分位数等统计操作。
import matplotlib.pyplot as plt
import numpy as np
@ -32,47 +26,91 @@ ROOT = Path(__file__).resolve().parents[1]
# 当前脚本虽然没有直接导入内部模块,但保留该设置可以兼容项目结构。
sys.path.append(str(ROOT))
Row = dict[str, str]
OutputRow = dict[str, Any]
def parse_args() -> argparse.Namespace:
"""解析不确定性评估结果 CSV、输出目录以及分位数分箱参数。"""
# 创建命令行参数解析器。
# 本脚本主要用于分析集成代理模型输出的不确定性是否能够辅助识别高误差样本。
parser = argparse.ArgumentParser(description="Analyze ensemble UQ outputs for fallback usefulness")
# 输入 CSV通常来自集成不确定性评估阶段生成的 sample_uncertainty_metrics.csv。
# 如果不指定,则会根据 tag 自动拼接默认路径。
parser.add_argument("--input-csv", type=str, default=None, help="Path to sample_uncertainty_metrics.csv")
@dataclass(frozen=True)
class Metrics:
    """Immutable bundle of the per-sample metric arrays used by the analysis.

    Keeping the four arrays in one object lets them be passed around as a
    single argument instead of four separate locals.
    """

    # Per-sample RMSE (filled from the ``overall_rmse`` CSV column).
    rmse: np.ndarray
    # Per-sample MAE (filled from the ``overall_mae`` CSV column).
    mae: np.ndarray
    # Per-sample R2 (filled from the ``overall_r2`` CSV column).
    r2: np.ndarray
    # Per-sample uncertainty (filled from the ``unc_mean_std`` CSV column).
    unc: np.ndarray
@dataclass(frozen=True)
class UncertaintyThresholds:
    """Immutable pair of uncertainty cut-offs separating low and high uncertainty."""

    # Samples whose uncertainty is <= this value are treated as low-uncertainty.
    low: float
    # Samples whose uncertainty is >= this value are treated as high-uncertainty.
    high: float
# 输出目录:用于保存 summary、筛选曲线 CSV、风险样本 CSV 和图像。
# 如果不指定,则会根据 tag 自动生成默认输出目录。
parser.add_argument("--output-dir", type=str, default=None, help="Optional output directory")
@dataclass(frozen=True)
class RiskRows:
"""保存几类需要导出的风险样本。"""
# 实验标签:用于区分不同实验、不同数据集或不同代理模型配置。
parser.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag")
top_uncertain: list[Row]
high_error_high_unc: list[Row]
high_error_low_unc: list[Row]
# 不确定性剔除比例。
# 例如 0,5,10 表示分别剔除不确定性最高的 0%、5%、10% 样本后统计保留样本误差。
def parse_args() -> argparse.Namespace:
"""解析不确定性评估结果 CSV、输出目录以及分位数分箱参数。"""
parser = argparse.ArgumentParser(
description="Analyze ensemble UQ outputs for fallback usefulness"
)
parser.add_argument(
"--input-csv",
type=str,
default=None,
help="Path to sample_uncertainty_metrics.csv",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Optional output directory",
)
parser.add_argument(
"--tag",
type=str,
default="family_random_50k",
help="Experiment tag",
)
parser.add_argument(
"--quantiles",
type=str,
default="0,5,10,20,30,40,50",
help="Comma-separated uncertainty removal percentages",
)
# 导出不确定性最高的 Top-K 样本,方便后续查看这些样本是否确实更容易预测失败。
parser.add_argument("--top-k", type=int, default=100, help="Top-K risky samples to export")
# 高误差样本阈值。
# overall_rmse 大于等于该值的样本会被视为高误差样本。
parser.add_argument("--high-error-rmse", type=float, default=1.0, help="High-error RMSE threshold")
# 低不确定性分位点。
# 例如 50 表示不确定性低于中位数的样本被视为低不确定性样本。
parser.add_argument("--low-unc-quantile", type=float, default=50.0, help="Low uncertainty quantile")
# 高不确定性分位点。
# 例如 90 表示不确定性高于第 90 百分位的样本被视为高不确定性样本。
parser.add_argument("--high-unc-quantile", type=float, default=90.0, help="High uncertainty quantile")
parser.add_argument(
"--top-k",
type=int,
default=100,
help="Top-K risky samples to export",
)
parser.add_argument(
"--high-error-rmse",
type=float,
default=1.0,
help="High-error RMSE threshold",
)
parser.add_argument(
"--low-unc-quantile",
type=float,
default=50.0,
help="Low uncertainty quantile",
)
parser.add_argument(
"--high-unc-quantile",
type=float,
default=90.0,
help="High uncertainty quantile",
)
return parser.parse_args()
@ -81,69 +119,69 @@ def parse_quantiles(text: str) -> list[float]:
"""解析逗号分隔的剔除比例,用于不确定性保留曲线分析。"""
values = []
# 将形如 "0,5,10,20" 的字符串逐项拆分并转成 float。
for item in str(text).split(","):
item = item.strip()
# 允许用户在字符串中多写逗号或空格,空项直接跳过。
if not item:
continue
q = float(item)
# 剔除比例必须位于 [0, 100)。
# 不能等于 100因为全部剔除后没有样本可用于统计误差。
if q < 0 or q >= 100:
raise ValueError(f"非法 quantile 百分比: {q}")
values.append(q)
# 强制包含 0%,用于记录不剔除任何样本时的全局基线结果。
if 0.0 not in values:
values = [0.0] + values
# 去重并升序排列,保证后续曲线横轴顺序稳定。
return sorted(set(values))
def default_input_csv(tag: str) -> Path:
    """Return the default per-sample uncertainty metrics CSV path for ``tag``.

    The path follows the convention
    ``results/evaluation_{tag}_ensemble_uq/sample_uncertainty_metrics.csv``
    and is relative (no absolute prefix is added).
    """
    evaluation_dir = f"evaluation_{tag}_ensemble_uq"
    return Path("results", evaluation_dir, "sample_uncertainty_metrics.csv")
def default_output_dir(tag: str) -> Path:
    """Return the default output directory of this analysis script for ``tag``.

    The path follows the convention
    ``results/evaluation_{tag}_ensemble_uq_analysis`` and is relative.
    """
    analysis_dir = f"evaluation_{tag}_ensemble_uq_analysis"
    return Path("results", analysis_dir)
def load_rows(csv_path: Path) -> list[dict]:
def resolve_paths(args: argparse.Namespace) -> tuple[Path, Path]:
    """Work out the input CSV and the output directory from parsed CLI arguments.

    A path given on the command line wins. A missing or empty value falls back
    to the tag-based default location.

    Returns:
        ``(input_csv, output_dir)``.
    """
    # Strip surrounding whitespace so it cannot leak into the default paths.
    tag = str(args.tag).strip()

    if args.input_csv:
        input_csv = Path(args.input_csv)
    else:
        input_csv = default_input_csv(tag)

    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = default_output_dir(tag)

    return input_csv, output_dir
def load_rows(csv_path: Path) -> list[Row]:
"""读取 CSV 为字典行列表,保留每个样本的不确定性和误差字段。"""
# 使用 utf-8-sig 是为了兼容 Excel 或部分 Windows 工具生成的带 BOM 的 CSV。
with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
rows = list(csv.DictReader(f))
with open(csv_path, "r", encoding="utf-8-sig", newline="") as file:
rows = list(csv.DictReader(file))
# 如果 CSV 只有表头或为空,则无法进行统计分析,直接抛出错误。
if not rows:
raise ValueError(f"CSV 没有数据: {csv_path}")
return rows
def as_float(rows: list[dict], key: str) -> np.ndarray:
def as_float(rows: list[Row], key: str) -> np.ndarray:
    """Collect column ``key`` from every row into a float64 array.

    Every row must contain ``key`` and its value must be convertible with
    ``float``; otherwise the underlying ``KeyError`` / ``ValueError`` propagates.
    """
    values = [float(record[key]) for record in rows]
    return np.asarray(values, dtype=np.float64)
def extract_metrics(rows: list[Row]) -> Metrics:
    """Build a ``Metrics`` bundle from the per-sample CSV rows.

    Each ``Metrics`` field is read from the CSV column listed next to it below.
    """
    # Field name -> CSV column. Dict order fixes the order the columns are read in.
    column_for_field = {
        "rmse": "overall_rmse",
        "mae": "overall_mae",
        "r2": "overall_r2",
        "unc": "unc_mean_std",
    }
    arrays = {
        field: as_float(rows, column)
        for field, column in column_for_field.items()
    }
    return Metrics(**arrays)
def safe_mean(x: np.ndarray) -> float:
    """Return the mean of ``x`` as a float, or NaN when ``x`` has no elements."""
    # An overly strict filter can leave nothing to average; report NaN instead
    # of letting the caller fail.
    if x.size == 0:
        return np.nan
    return float(np.mean(x))
@ -157,126 +195,41 @@ def safe_percentile(x: np.ndarray, q: float) -> float:
return float(np.percentile(x, q)) if x.size else np.nan
def save_csv(path: Path, rows: list[dict]) -> None:
def save_csv(path: Path, rows: list[OutputRow] | list[Row]) -> None:
"""把字典行写入 CSV。"""
# 若 rows 为空,则不写文件。
# 例如某次实验中不存在“高误差且低不确定性”的样本,就会走到这里。
if not rows:
return
# 输出使用 utf-8-sig方便在 Excel 中直接打开时中文不乱码。
with open(path, "w", encoding="utf-8-sig", newline="") as f:
writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
with open(path, "w", encoding="utf-8-sig", newline="") as file:
writer = csv.DictWriter(file, fieldnames=list(rows[0].keys()))
writer.writeheader()
writer.writerows(rows)
def plot_retention_curve(curve_rows: list[dict], output_path: Path) -> None:
"""绘制按不确定性由高到低剔除样本后的误差变化,用来判断不确定性能否筛掉坏样本。"""
# 从曲线统计结果中取出横轴和纵轴数据。
removed = np.array([float(row["removed_pct"]) for row in curve_rows], dtype=np.float64)
retained = np.array([float(row["retained_pct"]) for row in curve_rows], dtype=np.float64)
rmse = np.array([float(row["rmse_mean"]) for row in curve_rows], dtype=np.float64)
mae = np.array([float(row["mae_mean"]) for row in curve_rows], dtype=np.float64)
r2 = np.array([float(row["r2_mean"]) for row in curve_rows], dtype=np.float64)
# 创建 1 行 3 列子图,分别观察 RMSE、MAE、R2 随保留样本比例的变化。
fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
# 子图 1保留样本比例 vs RMSE。
# 如果不确定性有效剔除高不确定性样本后RMSE 理论上应下降。
axes[0].plot(retained, rmse, marker="o")
axes[0].set_title("Retained vs RMSE")
axes[0].set_xlabel("Retained Samples (%)")
axes[0].set_ylabel("RMSE mean")
axes[0].grid(True, alpha=0.3)
# 子图 2保留样本比例 vs MAE。
# MAE 对极端值不如 RMSE 敏感,可作为误差变化的补充观察指标。
axes[1].plot(retained, mae, marker="o")
axes[1].set_title("Retained vs MAE")
axes[1].set_xlabel("Retained Samples (%)")
axes[1].set_ylabel("MAE mean")
axes[1].grid(True, alpha=0.3)
# 子图 3保留样本比例 vs R2。
# 如果保留样本更容易预测R2 通常应升高。
axes[2].plot(retained, r2, marker="o")
axes[2].set_title("Retained vs R2")
axes[2].set_xlabel("Retained Samples (%)")
axes[2].set_ylabel("R2 mean")
axes[2].grid(True, alpha=0.3)
# 总标题中记录剔除比例网格,便于之后回看图像时知道对应的实验设置。
fig.suptitle(
f"Uncertainty Filtering Curves | removed grid={','.join(str(int(x)) for x in removed)}%"
def build_thresholds(
    unc: np.ndarray,
    low_unc_quantile: float,
    high_unc_quantile: float,
) -> UncertaintyThresholds:
    """Derive the low/high uncertainty cut-offs from the sample distribution.

    Both quantiles are percentages passed straight to ``np.percentile``.
    """
    low_value = float(np.percentile(unc, low_unc_quantile))
    high_value = float(np.percentile(unc, high_unc_quantile))
    return UncertaintyThresholds(low=low_value, high=high_value)
# tight_layout 用于减少子图之间的重叠rect 为总标题预留空间。
plt.tight_layout(rect=[0, 0, 1, 0.94])
# 保存图像dpi=150 对实验报告和日常查看基本够用。
plt.savefig(output_path, dpi=150, bbox_inches="tight")
plt.close()
def build_retention_curve(metrics: Metrics, quantiles: list[float]) -> list[OutputRow]:
"""构造按不确定性由高到低剔除样本后的误差统计曲线。"""
curve_rows: list[OutputRow] = []
for removed_pct in quantiles:
threshold = float(np.percentile(metrics.unc, 100.0 - removed_pct))
keep_mask = metrics.unc <= threshold
def main() -> None:
"""分析集成不确定性与真实误差的相关性,并生成分箱统计和散点图。"""
# 1. 解析命令行参数。
args = parse_args()
# 去除 tag 两端空格,避免路径拼接时出现隐藏错误。
tag = str(args.tag).strip()
# 2. 确定输入 CSV 和输出目录。
# 若用户通过命令行指定路径,则优先使用用户路径;否则使用默认路径。
input_csv = Path(args.input_csv) if args.input_csv is not None else default_input_csv(tag)
output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag)
# 确保输出目录存在parents=True 允许自动创建多级目录。
output_dir.mkdir(parents=True, exist_ok=True)
# 3. 读取样本级不确定性评估结果。
rows = load_rows(input_csv)
# 解析用户指定的剔除比例列表。
quantiles = parse_quantiles(args.quantiles)
# 4. 提取关键指标。
# overall_rmse / overall_mae / overall_r2 是每个样本的预测误差表现;
# unc_mean_std 是集成模型输出的平均标准差,可理解为该样本的预测不确定性。
rmse = as_float(rows, "overall_rmse")
mae = as_float(rows, "overall_mae")
r2 = as_float(rows, "overall_r2")
unc = as_float(rows, "unc_mean_std")
total_n = len(rows)
# 低/高不确定性阈值来自样本分布本身,避免手动指定绝对阈值依赖数据尺度。
# 例如不同压力归一化方式、不同数据集或不同代理模型会改变 unc_mean_std 的绝对大小。
low_unc_thr = float(np.percentile(unc, float(args.low_unc_quantile)))
high_unc_thr = float(np.percentile(unc, float(args.high_unc_quantile)))
# 5. 构造“不确定性过滤曲线”。
# 核心思路:
# 按不确定性从高到低剔除一部分样本;
# 对剩余样本重新计算 RMSE/MAE/R2
# 如果误差明显改善,说明不确定性指标可以作为 fallback 或人工复核触发条件。
curve_rows: list[dict] = []
kept_rmse = metrics.rmse[keep_mask]
kept_mae = metrics.mae[keep_mask]
kept_r2 = metrics.r2[keep_mask]
for removed_pct in quantiles:
# 计算当前剔除比例对应的不确定性阈值。
# 例如 removed_pct=10 时,阈值为第 90 百分位;
# keep_mask 会保留不确定性小于等于该阈值的样本,即剔除最高 10% 不确定性样本。
thr = float(np.percentile(unc, 100.0 - removed_pct))
keep_mask = unc <= thr
# 根据掩码取出保留下来的样本误差。
kept_rmse = rmse[keep_mask]
kept_mae = mae[keep_mask]
kept_r2 = r2[keep_mask]
# 保存当前剔除比例下的统计结果。
curve_rows.append(
{
"removed_pct": float(removed_pct),
@ -289,20 +242,25 @@ def main() -> None:
"mae_median": safe_median(kept_mae),
"r2_mean": safe_mean(kept_r2),
"r2_median": safe_median(kept_r2),
"unc_threshold": thr,
"unc_threshold": threshold,
}
)
# 6. 导出不确定性最高的 Top-K 样本。
# np.argsort(-unc) 表示按 unc 从大到小排序。
top_unc_idx = np.argsort(-unc)[: min(args.top_k, total_n)]
top_uncertain_rows = [rows[int(i)] for i in top_unc_idx]
return curve_rows
def select_top_uncertain_rows(rows: list[Row], unc: np.ndarray, top_k: int) -> list[Row]:
    """Return the ``top_k`` rows with the largest uncertainty, most uncertain first.

    Args:
        rows: Sample rows; ``rows[i]`` is taken to belong to ``unc[i]``.
        unc: Uncertainty value per sample.
        top_k: Maximum number of rows to return. Values larger than
            ``len(rows)`` return every row; zero or negative values return
            an empty list.

    Returns:
        The selected rows ordered by descending uncertainty. Rows with equal
        uncertainty keep their original relative order.
    """
    # Clamp to [0, len(rows)]: a negative top_k used as a slice bound would
    # mean "everything except the last |top_k|" and export almost all rows.
    limit = max(0, min(int(top_k), len(rows)))
    if limit == 0:
        return []

    # A stable sort makes the export reproducible when uncertainties tie.
    order = np.argsort(-unc, kind="stable")[:limit]
    return [rows[int(index)] for index in order]
# 7. 识别两类高误差样本:
# high_error_high_unc模型误差大并且不确定性也高。
# 这类样本是“可被不确定性发现的风险样本”。
# high_error_low_unc模型误差大但不确定性低。
# 这类样本最危险,说明模型错得很自信,是后续改进代理模型时最值得重点排查的对象。
def classify_risky_rows(
rows: list[Row],
thresholds: UncertaintyThresholds,
high_error_rmse: float,
) -> tuple[list[Row], list[Row]]:
"""识别高误差高不确定性样本,以及高误差低不确定性样本。"""
high_error_high_unc_rows = []
high_error_low_unc_rows = []
@ -310,61 +268,143 @@ def main() -> None:
row_rmse = float(row["overall_rmse"])
row_unc = float(row["unc_mean_std"])
# 高误差 + 高不确定性fallback 机制比较容易捕捉。
if row_rmse >= float(args.high_error_rmse) and row_unc >= high_unc_thr:
if row_rmse >= high_error_rmse and row_unc >= thresholds.high:
high_error_high_unc_rows.append(row)
# 高误差 + 低不确定性:说明不确定性没有给出预警,是更严重的问题。
if row_rmse >= float(args.high_error_rmse) and row_unc <= low_unc_thr:
if row_rmse >= high_error_rmse and row_unc <= thresholds.low:
high_error_low_unc_rows.append(row)
# 按 RMSE 从高到低排序,方便优先查看最严重的失败样本。
high_error_high_unc_rows.sort(key=lambda row: float(row["overall_rmse"]), reverse=True)
high_error_low_unc_rows.sort(key=lambda row: float(row["overall_rmse"]), reverse=True)
high_error_high_unc_rows.sort(
key=lambda row: float(row["overall_rmse"]),
reverse=True,
)
high_error_low_unc_rows.sort(
key=lambda row: float(row["overall_rmse"]),
reverse=True,
)
# 8. 汇总全局统计信息和关键阈值,写入 JSON。
# JSON 更适合被后续自动化流程、实验记录系统或报告脚本读取。
summary = {
return high_error_high_unc_rows, high_error_low_unc_rows
def build_risk_rows(
    rows: list[Row],
    metrics: Metrics,
    thresholds: UncertaintyThresholds,
    high_error_rmse: float,
    top_k: int,
) -> RiskRows:
    """Assemble every group of risky samples that gets exported.

    The result combines the two high-error groups (high uncertainty and low
    uncertainty) with the ``top_k`` most uncertain samples.
    """
    flagged_high_unc, flagged_low_unc = classify_risky_rows(
        rows=rows,
        thresholds=thresholds,
        high_error_rmse=high_error_rmse,
    )
    most_uncertain = select_top_uncertain_rows(rows, metrics.unc, top_k)

    return RiskRows(
        top_uncertain=most_uncertain,
        high_error_high_unc=flagged_high_unc,
        high_error_low_unc=flagged_low_unc,
    )
def build_summary(
input_csv: Path,
metrics: Metrics,
thresholds: UncertaintyThresholds,
risk_rows: RiskRows,
args: argparse.Namespace,
) -> dict[str, Any]:
"""汇总全局统计信息和关键阈值,方便后续写入 JSON。"""
return {
"input_csv": str(input_csv),
"n_samples": total_n,
"n_samples": int(metrics.rmse.size),
"global": {
"rmse_mean": safe_mean(rmse),
"rmse_median": safe_median(rmse),
"rmse_p90": safe_percentile(rmse, 90),
"mae_mean": safe_mean(mae),
"r2_mean": safe_mean(r2),
"unc_mean": safe_mean(unc),
"unc_median": safe_median(unc),
"unc_p90": safe_percentile(unc, 90),
"rmse_mean": safe_mean(metrics.rmse),
"rmse_median": safe_median(metrics.rmse),
"rmse_p90": safe_percentile(metrics.rmse, 90),
"mae_mean": safe_mean(metrics.mae),
"r2_mean": safe_mean(metrics.r2),
"unc_mean": safe_mean(metrics.unc),
"unc_median": safe_median(metrics.unc),
"unc_p90": safe_percentile(metrics.unc, 90),
},
"thresholds": {
"high_error_rmse": float(args.high_error_rmse),
"low_unc_quantile": float(args.low_unc_quantile),
"low_unc_threshold": low_unc_thr,
"low_unc_threshold": thresholds.low,
"high_unc_quantile": float(args.high_unc_quantile),
"high_unc_threshold": high_unc_thr,
"high_unc_threshold": thresholds.high,
},
"counts": {
"top_uncertain_exported": len(top_uncertain_rows),
"high_error_high_unc": len(high_error_high_unc_rows),
"high_error_low_unc": len(high_error_low_unc_rows),
"top_uncertain_exported": len(risk_rows.top_uncertain),
"high_error_high_unc": len(risk_rows.high_error_high_unc),
"high_error_low_unc": len(risk_rows.high_error_low_unc),
},
}
# ensure_ascii=False 可以保证 JSON 文件中中文正常显示。
with open(output_dir / "uq_fallback_summary.json", "w", encoding="utf-8") as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
# 9. 保存所有分析产物。
def plot_retention_curve(curve_rows: list[OutputRow], output_path: Path) -> None:
    """Plot mean RMSE / MAE / R2 of the retained samples against the retained share.

    Args:
        curve_rows: One row per removal percentage. Each row must provide
            ``removed_pct``, ``retained_pct``, ``rmse_mean``, ``mae_mean`` and
            ``r2_mean``.
        output_path: Destination image file.
    """

    def column(key: str) -> np.ndarray:
        """Read one numeric column of ``curve_rows`` as a float64 array."""
        return np.array([float(row[key]) for row in curve_rows], dtype=np.float64)

    removed = column("removed_pct")
    retained = column("retained_pct")
    panels = (
        ("RMSE", column("rmse_mean")),
        ("MAE", column("mae_mean")),
        ("R2", column("r2_mean")),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
    try:
        for axis, (label, values) in zip(axes, panels):
            axis.plot(retained, values, marker="o")
            axis.set_title(f"Retained vs {label}")
            axis.set_xlabel("Retained Samples (%)")
            axis.set_ylabel(f"{label} mean")
            axis.grid(True, alpha=0.3)

        # "g" formatting keeps fractional percentages (2.5 -> "2.5") while still
        # printing whole numbers without a trailing ".0"; int() truncated them.
        removed_grid = ",".join(f"{value:g}" for value in removed)
        fig.suptitle(f"Uncertainty Filtering Curves | removed grid={removed_grid}%")

        # rect leaves room at the top for the figure-level title.
        fig.tight_layout(rect=[0, 0, 1, 0.94])
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure even if saving fails, so it is not leaked.
        plt.close(fig)
def save_outputs(
output_dir: Path,
curve_rows: list[OutputRow],
risk_rows: RiskRows,
summary: dict[str, Any],
) -> None:
"""保存 JSON、CSV 和不确定性过滤曲线图。"""
with open(output_dir / "uq_fallback_summary.json", "w", encoding="utf-8") as file:
json.dump(summary, file, ensure_ascii=False, indent=2)
save_csv(output_dir / "uncertainty_filter_curve.csv", curve_rows)
save_csv(output_dir / "top_uncertain_samples.csv", top_uncertain_rows)
save_csv(output_dir / "high_error_high_unc_samples.csv", high_error_high_unc_rows)
save_csv(output_dir / "high_error_low_unc_samples.csv", high_error_low_unc_rows)
save_csv(output_dir / "top_uncertain_samples.csv", risk_rows.top_uncertain)
save_csv(output_dir / "high_error_high_unc_samples.csv", risk_rows.high_error_high_unc)
save_csv(output_dir / "high_error_low_unc_samples.csv", risk_rows.high_error_low_unc)
# 绘制并保存不确定性过滤曲线。
plot_retention_curve(curve_rows, output_dir / "uncertainty_filter_curve.png")
# 10. 在控制台打印简要结果,便于命令行运行后快速判断分析是否完成。
def print_summary(output_dir: Path, summary: dict[str, Any]) -> None:
"""在控制台打印简要结果,便于命令行运行后快速判断分析是否完成。"""
print("UQ fallback analysis complete.")
print(f"Output dir: {output_dir}")
print(
@ -377,8 +417,39 @@ def main() -> None:
)
# Python 脚本入口。
# 当该文件被直接运行时执行 main()
# 当它被其他脚本 import 时,不会自动执行分析流程。
def main() -> None:
    """Run the uncertainty analysis end to end and write every artefact.

    Steps: parse arguments, load the per-sample CSV, compute thresholds, the
    retention curve and the risky-sample groups, then save and print a summary.
    """
    args = parse_args()

    input_csv, output_dir = resolve_paths(args)
    # parents=True creates intermediate directories as needed.
    output_dir.mkdir(parents=True, exist_ok=True)

    rows = load_rows(input_csv)
    metrics = extract_metrics(rows)
    quantiles = parse_quantiles(args.quantiles)

    thresholds = build_thresholds(
        metrics.unc,
        float(args.low_unc_quantile),
        float(args.high_unc_quantile),
    )
    curve_rows = build_retention_curve(metrics, quantiles)
    risk_rows = build_risk_rows(
        rows,
        metrics,
        thresholds,
        float(args.high_error_rmse),
        int(args.top_k),
    )
    summary = build_summary(input_csv, metrics, thresholds, risk_rows, args)

    save_outputs(output_dir, curve_rows, risk_rows, summary)
    print_summary(output_dir, summary)
if __name__ == "__main__":
main()

@ -1,8 +1,7 @@
"""结合原始样本元数据分析不确定性评估结果。
"""
将集成学习不确定性量化样本指标与处理后的数据集元数据进行关联
`analyze_uq_results.py` 的全局统计之外本脚本会把 UQ CSV 与预处理数据中的
参数流量制度等元数据对齐按制度族参数区间等维度分组帮助定位代理模型
在哪类数值试井样本上更容易给出高误差或高不确定性
该脚本为每个 UQ 指标行补充调度元数据、类别标签以及可选的源标签,然后输出分组汇总结果和困难样本。
"""
from __future__ import annotations
@ -11,31 +10,91 @@ import argparse
import csv
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import joblib
import numpy as np
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
def load_experiment_path_helpers():
    """Import the project's path helpers and return them as a pair.

    The import happens inside the function so that it runs only when this
    function is called.

    Returns:
        ``(normalize_tag, processed_path_for_tag)``.
    """
    # pylint: disable=import-error,import-outside-toplevel
    from src.common.experiment_paths import normalize_tag, processed_path_for_tag

    helpers = (normalize_tag, processed_path_for_tag)
    return helpers
normalize_tag_func, processed_path_func = load_experiment_path_helpers()
@dataclass(frozen=True)
class AnalysisPaths:
    """Immutable set of resolved locations used by the metadata analysis."""

    # Experiment tag after normalisation (may be empty - TODO confirm).
    tag: str
    # Processed dataset file that holds the test-set metadata.
    processed_path: Path
    # Per-sample uncertainty metrics CSV.
    uq_csv: Path
    # Directory that receives every generated file.
    output_dir: Path
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
@dataclass(frozen=True)
class ProcessedMetadata:
    """Immutable bundle of test-set metadata taken from the processed dataset."""

    # Schedule metadata; indexed by sample first, then by column.
    schedule_meta_test: np.ndarray
    # Family name of every test sample.
    family_name_test: np.ndarray
    # Source name per test sample; None when the dataset has no such field.
    source_name_test: np.ndarray | None
    # Source id per test sample; None when the dataset has no such field.
    source_id_test: np.ndarray | None
    # Names of the schedule metadata columns.
    meta_names: list[str]
    # Column name -> column index into ``schedule_meta_test``.
    name_to_col: dict[str, int]
def parse_args() -> argparse.Namespace:
"""解析 UQ 指标 CSV、processed 数据和输出路径,用于合并样本元数据分析。"""
parser = argparse.ArgumentParser(description="Join UQ sample metrics with saved metadata")
parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
parser.add_argument("--uq-csv", type=str, default=None, help="sample_uncertainty_metrics.csv path")
parser.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag")
parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
parser = argparse.ArgumentParser(
description="Join UQ sample metrics with saved metadata"
)
parser.add_argument(
"--processed",
type=str,
default=None,
help="Processed dataset path",
)
parser.add_argument(
"--uq-csv",
type=str,
default=None,
help="sample_uncertainty_metrics.csv path",
)
parser.add_argument(
"--tag",
type=str,
default="family_random_50k",
help="Experiment tag",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Output directory",
)
parser.add_argument("--high-error-rmse", type=float, default=1.0)
return parser.parse_args()
def default_uq_csv(tag: str) -> Path:
"""根据实验标签定位默认的不确定性指标 CSV。"""
return Path("results") / f"evaluation_{tag}_ensemble_uq" / "sample_uncertainty_metrics.csv"
return (
Path("results")
/ f"evaluation_{tag}_ensemble_uq"
/ "sample_uncertainty_metrics.csv"
)
def default_output_dir(tag: str) -> Path:
@ -43,123 +102,269 @@ def default_output_dir(tag: str) -> Path:
return Path("results") / f"evaluation_{tag}_ensemble_uq_metadata_analysis"
def save_csv(path: Path, rows: list[dict]) -> None:
"""把字典行写入 CSV当没有行时仍写出表头方便后续脚本读取。"""
def resolve_paths(args: argparse.Namespace) -> AnalysisPaths:
    """Resolve every input and output location from the CLI arguments and tag.

    Paths given explicitly on the command line win. Anything left as ``None``
    falls back to a tag-based default.
    """
    tag = normalize_tag_func(args.tag)
    # The UQ CSV and output directory defaults need a non-empty tag.
    fallback_tag = tag or "family_random_50k"

    if args.processed is not None:
        processed_path = Path(args.processed)
    else:
        processed_path = processed_path_func(tag)

    if args.uq_csv is not None:
        uq_csv = Path(args.uq_csv)
    else:
        uq_csv = default_uq_csv(fallback_tag)

    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
    else:
        output_dir = default_output_dir(fallback_tag)

    return AnalysisPaths(
        tag=tag,
        processed_path=processed_path,
        uq_csv=uq_csv,
        output_dir=output_dir,
    )
def save_csv(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write dictionary rows to a CSV file; do nothing when ``rows`` is empty.

    Args:
        path: Destination file. It is neither created nor touched when
            ``rows`` is empty.
        rows: Records to write. The header is the union of all keys in
            first-seen order, so a row that has extra keys no longer makes
            ``csv.DictWriter`` raise ``ValueError``. Keys missing from a row
            are written as empty cells.
    """
    if not rows:
        return

    # dict.fromkeys de-duplicates while keeping first-seen order. When every
    # row has the same keys this equals the keys of the first row.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))

    # utf-8-sig writes a BOM at the start of the file.
    with open(path, "w", encoding="utf-8-sig", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
def summarize_group(rows: list[dict], group_key: str) -> list[dict]:
"""对一个样本分组计算误差、不确定性和样本数量等汇总指标。"""
groups: dict[str, list[dict]] = {}
for row in rows:
groups.setdefault(str(row[group_key]), []).append(row)
def load_uq_rows(uq_csv: Path) -> list[dict[str, str]]:
"""读取 sample_uncertainty_metrics.csv。"""
with open(uq_csv, "r", encoding="utf-8-sig", newline="") as f:
rows = list(csv.DictReader(f))
out = []
for key, group_rows in sorted(groups.items(), key=lambda item: len(item[1]), reverse=True):
rmse = np.array([float(r["overall_rmse"]) for r in group_rows], dtype=np.float64)
unc = np.array([float(r["unc_mean_std"]) for r in group_rows], dtype=np.float64)
out.append(
{
group_key: key,
"n_samples": len(group_rows),
"rmse_mean": float(np.mean(rmse)),
"rmse_median": float(np.median(rmse)),
"unc_mean": float(np.mean(unc)),
"unc_median": float(np.median(unc)),
}
)
return out
if not rows:
raise ValueError(f"UQ CSV 没有数据: {uq_csv}")
return rows
def main() -> None:
"""把 UQ 结果与参数/制度元数据拼接,按工况区域统计误差和不确定性。"""
args = parse_args()
tag = normalize_tag(args.tag)
processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
uq_csv = Path(args.uq_csv) if args.uq_csv is not None else default_uq_csv(tag or "family_random_50k")
output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag or "family_random_50k")
output_dir.mkdir(parents=True, exist_ok=True)
def load_processed_metadata(processed_path: Path) -> ProcessedMetadata:
"""读取 processed 数据中的测试集元数据。"""
data = joblib.load(processed_path)
if "schedule_meta_test" not in data or "family_name_test" not in data:
required_keys = {"schedule_meta_test", "family_name_test"}
if not required_keys.issubset(data):
raise RuntimeError(
"Processed dataset does not contain schedule metadata. Re-run preprocess on a metadata-rich raw HDF5 first."
"Processed dataset does not contain schedule metadata. "
"Re-run preprocess on a metadata-rich raw HDF5 first."
)
schedule_meta_test = np.asarray(data["schedule_meta_test"], dtype=np.float32)
family_name_test = np.asarray(data["family_name_test"]).astype(str)
# source 字段只在混合数据集里存在;普通数据集没有时保持可选。
source_name_test = np.asarray(data["source_name_test"]).astype(str) if "source_name_test" in data else None
source_id_test = np.asarray(data["source_id_test"]) if "source_id_test" in data else None
meta_names = data["meta"].get("schedule_meta_names") or []
source_name_test = None
source_id_test = None
if "source_name_test" in data:
source_name_test = np.asarray(data["source_name_test"]).astype(str)
if "source_id_test" in data:
source_id_test = np.asarray(data["source_id_test"])
meta_names = data["meta"].get("schedule_meta_names") or []
name_to_col = {name: i for i, name in enumerate(meta_names)}
with open(uq_csv, "r", encoding="utf-8-sig", newline="") as f:
uq_rows = list(csv.DictReader(f))
return ProcessedMetadata(
schedule_meta_test=schedule_meta_test,
family_name_test=family_name_test,
source_name_test=source_name_test,
source_id_test=source_id_test,
meta_names=meta_names,
name_to_col=name_to_col,
)
def enrich_uq_rows(
uq_rows: list[dict[str, str]],
metadata: ProcessedMetadata,
) -> list[dict[str, Any]]:
"""把 UQ 指标与 family/source/schedule metadata 拼接到同一行。"""
enriched_rows: list[dict[str, Any]] = []
enriched_rows = []
for row in uq_rows:
idx = int(row["idx"])
meta_row = schedule_meta_test[idx]
enriched = dict(row)
# UQ CSV 的 idx 对应 processed 测试集下标,因此可直接取测试集元数据补列。
enriched["family_name"] = family_name_test[idx]
if source_name_test is not None:
enriched["source_name"] = source_name_test[idx]
if source_id_test is not None:
enriched["source_id"] = int(source_id_test[idx])
for name, col in name_to_col.items():
meta_row = metadata.schedule_meta_test[idx]
enriched: dict[str, Any] = dict(row)
enriched["family_name"] = metadata.family_name_test[idx]
if metadata.source_name_test is not None:
enriched["source_name"] = metadata.source_name_test[idx]
if metadata.source_id_test is not None:
enriched["source_id"] = int(metadata.source_id_test[idx])
for name, col in metadata.name_to_col.items():
enriched[name] = float(meta_row[col])
enriched_rows.append(enriched)
save_csv(output_dir / "uq_samples_with_metadata.csv", enriched_rows)
return enriched_rows
family_summary = summarize_group(enriched_rows, "family_name")
save_csv(output_dir / "summary_by_family.csv", family_summary)
if source_name_test is not None:
source_summary = summarize_group(enriched_rows, "source_name")
save_csv(output_dir / "summary_by_source.csv", source_summary)
def summarize_group(
    rows: list[dict[str, Any]],
    group_key: str,
) -> list[dict[str, Any]]:
    """Summarise error and uncertainty for every distinct value of ``group_key``.

    Rows are bucketed by ``str(row[group_key])``. Each output row holds the
    group label, the sample count, and the mean/median of ``overall_rmse`` and
    ``unc_mean_std``. Groups are ordered from largest to smallest; groups of
    equal size keep the order in which they first appeared.
    """
    buckets: dict[str, list[dict[str, Any]]] = {}
    for record in rows:
        label = str(record[group_key])
        if label not in buckets:
            buckets[label] = []
        buckets[label].append(record)

    # list.sort is stable, so ties keep first-appearance order.
    ordered = list(buckets.items())
    ordered.sort(key=lambda pair: len(pair[1]), reverse=True)

    def column(members: list[dict[str, Any]], field: str) -> np.ndarray:
        """Read one numeric field of a group as a float64 array."""
        return np.array([float(member[field]) for member in members], dtype=np.float64)

    summaries = []
    for label, members in ordered:
        rmse_values = column(members, "overall_rmse")
        unc_values = column(members, "unc_mean_std")
        summaries.append(
            {
                group_key: label,
                "n_samples": len(members),
                "rmse_mean": float(np.mean(rmse_values)),
                "rmse_median": float(np.median(rmse_values)),
                "unc_mean": float(np.mean(unc_values)),
                "unc_median": float(np.median(unc_values)),
            }
        )
    return summaries
def save_group_summaries(
output_dir: Path,
enriched_rows: list[dict[str, Any]],
metadata: ProcessedMetadata,
) -> None:
"""按 family/source/n_prod 等维度保存分组统计。"""
save_csv(
output_dir / "summary_by_family.csv",
summarize_group(enriched_rows, "family_name"),
)
if "n_prod" in name_to_col:
if metadata.source_name_test is not None:
save_csv(
output_dir / "summary_by_source.csv",
summarize_group(enriched_rows, "source_name"),
)
if "n_prod" in metadata.name_to_col:
for row in enriched_rows:
row["n_prod_group"] = int(round(float(row["n_prod"])))
# 生产段数常对应不同制度复杂度,单独分组有助于定位 UQ 失效区域。
n_prod_summary = summarize_group(enriched_rows, "n_prod_group")
save_csv(output_dir / "summary_by_n_prod.csv", n_prod_summary)
unc_values = np.array([float(row["unc_mean_std"]) for row in enriched_rows], dtype=np.float64)
save_csv(
output_dir / "summary_by_n_prod.csv",
summarize_group(enriched_rows, "n_prod_group"),
)
def find_high_error_low_unc(
enriched_rows: list[dict[str, Any]],
high_error_rmse: float,
) -> list[dict[str, Any]]:
"""筛出高误差但低不确定性的危险样本。"""
unc_values = np.array(
[float(row["unc_mean_std"]) for row in enriched_rows],
dtype=np.float64,
)
low_unc_threshold = float(np.median(unc_values))
# 筛出“高误差但低不确定性”的样本,后续可作为困难样本补采样或诊断入口。
high_error_low_unc = [
risky_rows = [
row
for row in enriched_rows
if float(row["overall_rmse"]) >= float(args.high_error_rmse)
if float(row["overall_rmse"]) >= high_error_rmse
and float(row["unc_mean_std"]) <= low_unc_threshold
]
high_error_low_unc.sort(key=lambda row: float(row["overall_rmse"]), reverse=True)
save_csv(output_dir / "high_error_low_unc_with_metadata.csv", high_error_low_unc)
risky_rows.sort(key=lambda row: float(row["overall_rmse"]), reverse=True)
return risky_rows
def write_json_summary(
output_dir: Path,
paths: AnalysisPaths,
metadata: ProcessedMetadata,
enriched_rows: list[dict[str, Any]],
high_error_low_unc: list[dict[str, Any]],
) -> None:
"""保存 metadata join 的整体摘要信息。"""
summary = {
"processed_path": str(processed_path),
"uq_csv": str(uq_csv),
"tag": paths.tag,
"processed_path": str(paths.processed_path),
"uq_csv": str(paths.uq_csv),
"n_samples": len(enriched_rows),
"meta_names": meta_names,
"has_source_name": bool(source_name_test is not None),
"meta_names": metadata.meta_names,
"has_source_name": bool(metadata.source_name_test is not None),
"high_error_low_unc_count": len(high_error_low_unc),
}
with open(output_dir / "metadata_join_summary.json", "w", encoding="utf-8") as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
def run_analysis(args: argparse.Namespace) -> tuple[Path, int]:
    """Run the whole UQ/metadata join and write every output file.

    Returns:
        ``(output_dir, n_risky)`` where ``n_risky`` is the number of
        high-error, low-uncertainty samples that were found.
    """
    paths = resolve_paths(args)
    out_dir = paths.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    metadata = load_processed_metadata(paths.processed_path)
    uq_rows = load_uq_rows(paths.uq_csv)
    enriched_rows = enrich_uq_rows(uq_rows, metadata)

    # Write the joined table before the group summaries are computed.
    save_csv(out_dir / "uq_samples_with_metadata.csv", enriched_rows)
    save_group_summaries(out_dir, enriched_rows, metadata)

    risky_rows = find_high_error_low_unc(
        enriched_rows=enriched_rows,
        high_error_rmse=float(args.high_error_rmse),
    )
    save_csv(out_dir / "high_error_low_unc_with_metadata.csv", risky_rows)

    write_json_summary(
        output_dir=out_dir,
        paths=paths,
        metadata=metadata,
        enriched_rows=enriched_rows,
        high_error_low_unc=risky_rows,
    )
    return out_dir, len(risky_rows)
def main() -> None:
"""命令行入口:拼接 UQ 结果和元数据,并输出汇总文件。"""
output_dir, risky_count = run_analysis(parse_args())
print("UQ metadata join analysis complete.")
print(f"Output dir: {output_dir}")
print(f"High-error low-unc count={len(high_error_low_unc)}")
print(f"High-error low-unc count={risky_count}")
if __name__ == "__main__":

@ -1,32 +1,61 @@
"""构建“普通样本 + 困难样本”的混合训练数据集。
脚本先调用合并逻辑把常规数据集与局部自动拟合邻域数据集合成为一个 HDF5
再复用统一预处理流程生成模型训练所需的标准化数据文件适合在正演代理模型
需要兼顾全局覆盖和 PSO/自动拟合困难区域时使用
再复用统一预处理流程生成模型训练所需的标准化数据文件
该脚本适合在正演代理模型需要同时兼顾全局覆盖样本和 PSO/自动拟合困难区域
样本时使用
"""
from __future__ import annotations
import argparse
import importlib
import sys
from pathlib import Path
from typing import Any, Callable
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from scripts.merge_datasets import merge_datasets
from src.common.experiment_paths import normalize_tag, processed_path_for_tag, sample_path_for_tag
from src.data.preprocess import preprocess_dataset
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT))
def parse_args() -> argparse.Namespace:
"""解析 normal/hard 两类 HDF5 的混合比例、输出路径和预处理切分参数。"""
parser = argparse.ArgumentParser(description="Build a mixed raw+processed dataset from normal and hard HDF5 pools")
parser.add_argument("--normal-input", type=str, required=True, help="Path to the normal/main .h5 dataset")
parser.add_argument("--hard-input", type=str, required=True, help="Path to the hard-targeted .h5 dataset")
parser.add_argument("--tag", type=str, default="family_random_mixed_50k", help="Experiment tag")
parser.add_argument("--output-h5", type=str, default=None, help="Optional merged raw .h5 path")
parser.add_argument("--output-processed", type=str, default=None, help="Optional processed .pkl path")
parser = argparse.ArgumentParser(
description="Build a mixed raw+processed dataset from normal and hard HDF5 pools"
)
parser.add_argument(
"--normal-input",
type=str,
required=True,
help="Path to the normal/main .h5 dataset",
)
parser.add_argument(
"--hard-input",
type=str,
required=True,
help="Path to the hard-targeted .h5 dataset",
)
parser.add_argument(
"--tag",
type=str,
default="family_random_mixed_50k",
help="Experiment tag",
)
parser.add_argument(
"--output-h5",
type=str,
default=None,
help="Optional merged raw .h5 path",
)
parser.add_argument(
"--output-processed",
type=str,
default=None,
help="Optional processed .pkl path",
)
parser.add_argument("--total-samples", type=int, default=50000)
parser.add_argument("--hard-ratio", type=float, default=0.30)
parser.add_argument("--normal-count", type=int, default=None)
@ -40,17 +69,55 @@ def parse_args() -> argparse.Namespace:
return parser.parse_args()
def main() -> None:
"""合并普通样本与困难样本,并立即生成对应的 processed 训练数据。"""
args = parse_args()
tag = normalize_tag(args.tag)
output_h5 = Path(args.output_h5) if args.output_h5 is not None else sample_path_for_tag(tag)
def load_project_functions() -> tuple[
    Callable[..., dict[str, Any]],
    Callable[[str], str],
    Callable[[str], Path],
    Callable[[str], Path],
    Callable[..., None],
]:
    """Import the project-internal helpers lazily, by module name.

    The modules are resolved at call time through ``importlib`` rather than
    with top-level ``import`` statements, which keeps Pylint from reporting
    false positives for the dynamically extended project path.

    Returns:
        ``(merge_datasets, normalize_tag, processed_path_for_tag,
        sample_path_for_tag, preprocess_dataset)`` in exactly that order.
    """
    module_names = (
        "scripts.merge_datasets",
        "src.common.experiment_paths",
        "src.data.preprocess",
    )
    # All three modules are imported, in order, before any attribute is read.
    merge_mod, paths_mod, preprocess_mod = (
        importlib.import_module(module_name) for module_name in module_names
    )
    return (
        merge_mod.merge_datasets,
        paths_mod.normalize_tag,
        paths_mod.processed_path_for_tag,
        paths_mod.sample_path_for_tag,
        preprocess_mod.preprocess_dataset,
    )
def resolve_output_paths(
    args: argparse.Namespace,
    tag: str,
    processed_path_for_tag: Callable[[str], Path],
    sample_path_for_tag: Callable[[str], Path],
) -> tuple[Path, Path]:
    """Choose the merged-HDF5 and processed output paths for this run.

    An explicit ``output_h5`` / ``output_processed`` value on ``args`` wins;
    otherwise the path is derived from the experiment tag through the
    supplied resolver.

    Args:
        args: Parsed command-line arguments providing ``output_h5`` and
            ``output_processed`` (each may be ``None``).
        tag: Experiment tag passed to the resolvers.
        processed_path_for_tag: Maps a tag to the default processed path.
        sample_path_for_tag: Maps a tag to the default merged HDF5 path.

    Returns:
        ``(output_h5, output_processed)``.
    """
    output_h5 = (
        Path(args.output_h5)
        if args.output_h5 is not None
        else sample_path_for_tag(tag)
    )
    # Only one conditional belongs here; the leftover one-line variant of the
    # same expression was removed.
    output_processed = (
        Path(args.output_processed)
        if args.output_processed is not None
        else processed_path_for_tag(tag)
    )
    return output_h5, output_processed
# 先在原始 HDF5 层面按比例抽样合并,保留 source_label 便于之后追踪样本来源。
merge_meta = merge_datasets(
def merge_raw_datasets(
args: argparse.Namespace,
tag: str,
output_h5: Path,
merge_datasets: Callable[..., dict[str, Any]],
) -> dict[str, Any]:
"""按设定比例合并普通样本与困难样本,并返回合并过程元数据。"""
return merge_datasets(
normal_input=args.normal_input,
hard_input=args.hard_input,
output=output_h5,
@ -65,7 +132,14 @@ def main() -> None:
batch_size=args.batch_size,
)
# 合并后的原始数据立即进入同一套预处理流程,保证训练集格式与普通数据一致。
def build_processed_dataset(
args: argparse.Namespace,
output_h5: Path,
output_processed: Path,
preprocess_dataset: Callable[..., None],
) -> None:
"""对合并后的 HDF5 数据执行统一预处理,生成模型训练用 pkl 文件。"""
preprocess_dataset(
input_path=output_h5,
output_path=output_processed,
@ -74,10 +148,52 @@ def main() -> None:
random_seed=args.seed,
)
def print_outputs(merge_meta: dict[str, Any], output_processed: Path) -> None:
    """Report where the merged raw dataset, its summary and the processed data were written.

    Args:
        merge_meta: Merge metadata; ``output_path`` and ``summary_path`` are read.
        output_processed: Location of the processed dataset.
    """
    raw_path = merge_meta["output_path"]
    print(f"Merged raw dataset: {raw_path}")
    summary_path = merge_meta["summary_path"]
    print(f"Merge summary: {summary_path}")
    print(f"Processed dataset: {output_processed}")
def main() -> None:
    """Merge the normal and hard sample pools, then build the processed training data."""
    args = parse_args()
    (
        merge_fn,
        normalize_tag_fn,
        processed_path_fn,
        sample_path_fn,
        preprocess_fn,
    ) = load_project_functions()

    tag = normalize_tag_fn(args.tag)
    merged_h5, processed_out = resolve_output_paths(
        args=args,
        tag=tag,
        processed_path_for_tag=processed_path_fn,
        sample_path_for_tag=sample_path_fn,
    )

    # Sample and merge at the raw HDF5 level first, keeping source_label so
    # the origin of each sample can be traced later.
    merge_meta = merge_raw_datasets(
        args=args,
        tag=tag,
        output_h5=merged_h5,
        merge_datasets=merge_fn,
    )

    # Feed the merged raw data straight into the shared preprocessing flow so
    # the training set has the same format as a regular dataset.
    build_processed_dataset(
        args=args,
        output_h5=merged_h5,
        output_processed=processed_out,
        preprocess_dataset=preprocess_fn,
    )

    print_outputs(merge_meta=merge_meta, output_processed=processed_out)
if __name__ == "__main__":
main()

@ -1,9 +1,23 @@
"""单个试井样本的数值求解器与正演代理模型对比。
该脚本用一组指定地层/井筒参数和流量制度分别运行 C++ 数值求解器与 Python
正演代理模型随后在压力压力导数和斜率三段曲线上计算误差指标绘制逐点
对比图并导出 JSON 汇总便于排查单个案例的代理误差来源
"""
数值试井代理模型 - 单案例 Solver vs Surrogate 对比脚本
主要功能
1. 使用 C++ 数值求解器生成真实试井曲线
2. 使用训练好的代理模型进行预测
3. 对比 Solver Surrogate 的输出结果
4. 计算 RMSE / MAE / R2 等指标
5. 绘制压力导数斜率三部分对比图
6. 输出 JSON 分析结果
该脚本通常用于
- 检查代理模型是否学到真实物理行为
- 分析代理模型在哪些阶段误差较大
- 验证不同 schedule 对模型预测的影响
- 调试自动拟合(autofit)效果
"""
# pylint: disable=import-error,wrong-import-position,wrong-import-order,line-too-long,
# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments,invalid-name
from __future__ import annotations
@ -36,6 +50,8 @@ from src.evaluation.autofit_objective import dual_log_objective
from src.models.forward_surrogate import ForwardSurrogate
# 默认单案例参数
# 用于没有传命令行参数时的快速测试
DEFAULT_SINGLE_CASE = {
"config": "configs/data_gen_family_random.yaml",
"tag": "family_random_50k",
@ -53,11 +69,7 @@ DEFAULT_SINGLE_CASE = {
def parse_args() -> argparse.Namespace:
"""解析单案例对比所需的路径、参数、井号和流量制度覆盖项。
默认值对应一个可复现实验案例命令行可覆盖模型 checkpointprocessed 数据
地层/井筒参数以及 `timeQ/q/sectionIndex`方便快速定位某个具体样本的误差表现
"""
"""解析单个样本对比所需的 processed 数据、模型、样本索引和输出目录。"""
parser = argparse.ArgumentParser(
description="Compare solver output and surrogate prediction on a single parameter set"
)
@ -122,11 +134,8 @@ def calc_metrics(
eps_range: float = 1e-3,
eps_var: float = 1e-6,
) -> dict:
"""计算一维曲线预测误差指标。
`eps_range` 用于避免真实曲线几乎为常数时 NRMSE 分母过小
`eps_var` 用于避免 R2 在真实曲线方差接近 0 时失真无效场景返回 NaN
"""
"""计算 RMSE、MAE、Bias、NRMSE、R2 等回归指标。"""
# 残差 = 代理模型预测值 - 数值求解器真实值
err = y_pred - y_true
mse = np.mean(err**2)
rmse = float(np.sqrt(mse))
@ -137,6 +146,7 @@ def calc_metrics(
ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2))
ss_res = float(np.sum(err**2))
# 曲线几乎为常数时NRMSE 和 R2 的分母会过小,此时返回 NaN 更真实。
valid_nrmse = value_range > eps_range
valid_r2 = ss_tot > eps_var
@ -154,15 +164,14 @@ def calc_metrics(
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
"""从预处理元数据推断曲线拼接布局。
新版 processed 文件会保存 `curve_layout`若旧数据缺失该字段则按压力
压力导数斜率三段等长拼接的历史约定回退保证旧实验仍可比较
"""
"""从元数据读取曲线分段布局;旧数据没有布局时按压力/导数/斜率三等分回退。"""
# curve_layout 描述了拼接曲线的结构布局
# 包括每一段 pressure/derivative/slope 的起止位置
curve_layout = meta.get("curve_layout")
if curve_layout is not None:
return curve_layout
# 兼容早期 processed 文件:没有显式 layout 时仍按 pressure/derivative/slope 三段切分。
n_time_points = curve_dim // 3
return {
"n_time_points": int(n_time_points),
@ -175,11 +184,9 @@ def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]:
"""按 `curve_layout` 把一维拼接曲线拆成命名片段。
返回值通常包含 `log_pressure``log_derivative` `slope`
后续绘图分段指标和自动拟合目标函数都依赖这些片段边界一致
"""
"""按照 curve_layout 将拼接曲线拆成 log_pressure、log_derivative 和 slope 三段。"""
# parts 用于保存拆分后的不同曲线段
# 例如 log_pressure / derivative / slope
parts: dict[str, np.ndarray] = {}
for part in layout["parts"]:
start = int(part["start"])
@ -189,11 +196,9 @@ def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarr
def resolve_cf(args: argparse.Namespace, cfg: Config) -> float:
"""解析综合压缩系数 Cf。
Cf 可能由命令行显式给出也可能在配置文件的 fixed_params 中固定
集中处理该兜底逻辑避免构造 `Params` 时把缺失 Cf 静默写成错误默认值
"""
"""解析压缩系数 Cf优先使用命令行参数否则从配置默认参数中读取。"""
# 命令行优先级最高
# 如果用户显式提供 Cf则直接使用
if args.Cf is not None:
return float(args.Cf)
@ -205,11 +210,8 @@ def resolve_cf(args: argparse.Namespace, cfg: Config) -> float:
def parse_float_list(raw: str, name: str) -> list[float]:
"""把命令行传入的逗号/分号/空格分隔数值串转换为浮点列表。
该函数用于解析 `--timeQ` `--q`允许用户用不同分隔符快速输入制度序列
若没有解析出任何有效数字则抛出带参数名的错误
"""
"""解析输入文本、命令行或配置值,转换为后续流程可直接使用的结构。"""
# 用于解析类似 "1000,1000,1000" 的字符串输入
values: list[float] = []
for token in raw.replace(";", ",").replace(" ", ",").split(","):
token = token.strip()
@ -222,11 +224,7 @@ def parse_float_list(raw: str, name: str) -> list[float]:
def resolve_case_schedule(cfg: Config, args: argparse.Namespace | None = None) -> Schedule:
"""构造单案例对比使用的流量制度。
默认从配置文件读取 `case_schedule`当命令行提供 `--timeQ/--q` 时优先使用覆盖值
并按显式 `--section-index` 或最后一段规则确定测试段
"""
"""解析单案例使用的流量制度;命令行提供 timeQ/q 时覆盖配置中的默认制度。"""
if args is not None and (args.timeQ is not None or args.q is not None or args.section_index is not None):
if args.timeQ is None or args.q is None:
raise ValueError("覆盖流量制度时必须同时提供 --timeQ 和 --q")
@ -236,7 +234,9 @@ def resolve_case_schedule(cfg: Config, args: argparse.Namespace | None = None) -
if len(timeQ) != len(q):
raise ValueError(f"--timeQ 和 --q 长度必须一致,当前分别为 {len(timeQ)}{len(q)}")
# Allow the reporting/table convention with an initial "0 0" row.
# 兼容报表写法:首行 0,0 只表示初始状态,不作为真实流量段传给求解器。
# 某些报表第一行仅用于表示初始状态
# 不是真实生产段,因此自动丢弃
if len(timeQ) >= 2 and timeQ[0] <= 0.0 and q[0] == 0.0:
timeQ = timeQ[1:]
q = q[1:]
@ -272,11 +272,7 @@ def resolve_case_schedule(cfg: Config, args: argparse.Namespace | None = None) -
def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]:
"""解析配置、processed 数据、模型 checkpoint 和输出目录路径。
路径优先级为命令行显式参数其次为实验 tag 对应的标准目录
`--no-schedule` 只在自动推断模型路径时用于区分是否带制度分支的模型
"""
"""根据命令行参数、实验标签和 use_schedule 开关解析配置、预处理数据、模型和输出路径。"""
tag = normalize_tag(args.tag)
config_path = args.config
@ -300,10 +296,9 @@ def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]:
def build_params_from_args(args: argparse.Namespace, cfg: Config, schedule: Schedule) -> Params:
"""把命令行中的地层/井筒参数组装为求解器和代理模型共用的 `Params`。
该对象同时持有物性参数与流量制度是数值求解参数特征变换和图标题展示的共同数据源
"""
"""根据命令行参数和默认值构造目标 Params 对象。"""
# 构建完整物理参数对象
# 后续会直接送入数值求解器和代理模型
return Params(
k=float(args.k),
skin=float(args.skin),
@ -316,15 +311,14 @@ def build_params_from_args(args: argparse.Namespace, cfg: Config, schedule: Sche
def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) -> tuple[np.ndarray, dict]:
"""运行 C++ 数值求解器并抽取可与代理输出对齐的曲线。
原始求解结果会先做有效性检查和清洗再按训练数据相同的重采样规则转换为
`log_pressure/log_derivative/slope` 拼接向量同时返回原始 log-log 曲线供 JSON 留痕
"""
"""调用 C++ 求解器运行一次正演,并把双对数输出重采样为模型曲线向量。"""
# 创建 C++ 求解器客户端
# 实际底层会调用数值试井求解程序
runner = CppRunner(cfg=cfg)
try:
ok = runner.run_simulation(params, override_schedule=params.schedule, include_schedule=True)
result = read_result_bin(runner.result_bin) if runner.result_bin.exists() else None
# 某些求解器返回码可能失败但 result.bin 已写完;只要结果可读就继续对比。
if not ok and result is None:
raise RuntimeError("求解器运行失败,未生成有效结果")
if not ok and result is not None:
@ -343,6 +337,9 @@ def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) -
p = np.asarray(loglog["p"], dtype=np.float64)
d = np.asarray(loglog["deriv"], dtype=np.float64)
# 求解器原始点数不一定等于模型输出维度,先清洗再重采样到训练时的固定网格。
# 对原始求解器曲线进行清洗
# 去除非法点、异常点、重复点等
cleaned = clean_curve_for_dataset(cfg, t, p, d)
if cleaned is None:
raise RuntimeError("求解器返回曲线在清洗后无效")
@ -352,6 +349,8 @@ def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) -
if not valid:
raise RuntimeError(f"求解器曲线未通过有效性检查: {reason}")
# 将原始不规则时间曲线重采样到固定特征网格
# 这是代理模型训练时使用的统一输入格式
curve_feat = resample_curve_to_features(cfg, t_clean, p_clean, d_clean)
raw = {
"t": t_clean.tolist(),
@ -366,19 +365,13 @@ def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) -
def build_schedule_vector(cfg: Config, schedule: Schedule) -> np.ndarray:
    """Encode a ``Schedule`` as the schedule feature vector fed to the forward surrogate.

    This is a thin wrapper around ``build_schedule_model_vector``; the encoding
    itself lives in that helper.
    NOTE(review): the earlier docstring said this is the same encoder used in
    training, so the features match the processed data — confirm.
    """
    return build_schedule_model_vector(cfg, schedule)
def load_model(checkpoint_path: Path) -> tuple[ForwardSurrogate, bool, torch.device]:
"""加载正演代理模型 checkpoint 并恢复网络结构。
checkpoint 中保存了输入维度隐藏层宽度dropout 和是否使用制度分支等信息
按这些元数据重建模型后再加载权重才能保证推理结构与训练结构一致
"""
"""加载模型检查点,按保存的维度和超参数重建网络并切换到评估模式。"""
# 加载训练好的 PyTorch checkpoint
checkpoint = torch.load(checkpoint_path, map_location="cpu")
use_schedule = bool(checkpoint.get("use_schedule", True))
@ -407,22 +400,22 @@ def predict_surrogate_curve(
schedule: Schedule,
cfg: Config,
) -> np.ndarray:
"""使用代理模型预测单个参数点的反标准化曲线。
参数和制度先按 processed 文件中的 transform/scaler 进入训练时的标准化空间
模型输出再通过 `scaler_curve.inverse_transform` 还原成可与数值求解器直接比较的曲线值
"""
"""使用正演代理模型预测单个参数和流量制度对应的完整曲线。"""
scaler_params = processed["scaler_params"]
scaler_schedule = processed["scaler_schedule"]
scaler_curve = processed["scaler_curve"]
param_transform = param_feature_transform_from_meta(processed.get("meta", {}))
# 构建参数特征向量
# 顺序必须与训练阶段保持一致
params_vec = np.asarray(
[params.k, params.skin, params.wellboreC, params.phi, params.h, params.Cf],
dtype=np.float32,
).reshape(1, -1)
# 将流量制度 schedule 编码成模型输入向量
schedule_vec = build_schedule_vector(cfg, schedule).reshape(1, -1)
# 输入特征必须使用训练时保存的 scaler 和参数变换,否则单案例预测会发生分布偏移。
params_x = scaler_params.transform(transform_param_features(params_vec, param_transform)).astype(np.float32)
schedule_x = scaler_schedule.transform(schedule_vec).astype(np.float32)
@ -447,16 +440,14 @@ def plot_comparison(
model_path: Path,
use_schedule: bool,
) -> dict:
"""绘制数值求解器曲线与代理预测曲线的分段对比图。
左列展示每个曲线片段的真实值与预测值右列展示逐点误差
图标题汇总参数制度模型名称和整体指标返回的 summary 会同步写入 JSON
"""
"""绘制数值求解器与代理模型的单案例曲线对比图。"""
true_parts = split_curve_by_layout(curve_true, curve_layout)
pred_parts = split_curve_by_layout(curve_pred, curve_layout)
# 整体曲线误差指标
overall = calc_metrics(curve_true, curve_pred)
# 自动拟合目标函数
# 用于衡量双对数曲线形态是否一致
autofit = dual_log_objective(curve_true, curve_pred, curve_layout)
overall_r2_text = "nan" if np.isnan(overall["r2"]) else f"{overall['r2']:.4f}"
part_names = ["log_pressure", "log_derivative", "slope"]
title_map = {
@ -465,6 +456,8 @@ def plot_comparison(
"slope": "Slope of Log Pressure vs Log Time",
}
r2_text = "nan" if np.isnan(overall["r2"]) else f"{overall['r2']:.4f}"
fig, axes = plt.subplots(3, 2, figsize=(14, 12))
fig.suptitle(
"Solver vs Surrogate\n"
@ -474,14 +467,17 @@ def plot_comparison(
f"use_schedule={use_schedule}, model={model_path.name}, "
f"Overall RMSE={overall['rmse']:.4f}, MAE={overall['mae']:.4f}, "
f"AutoFitObj={autofit['dual_log_objective']:.4f}, "
f"R2={overall_r2_text}"
f"R2={r2_text}"
)
summary = {"overall": overall, "autofit": autofit, "parts": {}}
# 分别对 pressure / derivative / slope 三部分绘图
for row, name in enumerate(part_names):
# 左列画曲线重合程度,右列画逐时间点残差,便于定位误差集中在哪个时期。
y_true = true_parts[name]
y_pred = pred_parts[name]
# 残差 = 代理模型预测值 - 数值求解器真实值
err = y_pred - y_true
x = np.arange(len(y_true))
metrics = calc_metrics(y_true, y_pred)
@ -517,11 +513,7 @@ def plot_comparison(
def main() -> None:
"""执行单案例完整对比流程。
流程包括解析路径和制度加载 processed checkpoint运行数值求解器
调用代理模型预测绘制对比图写出 JSON 汇总并打印核心指标
"""
"""抽取一个测试样本,绘制代理模型曲线与真实数值求解曲线的对比图。"""
args = parse_args()
cfg, processed_path, model_path, output_dir = resolve_paths(args)
output_dir.mkdir(parents=True, exist_ok=True)
@ -534,14 +526,18 @@ def main() -> None:
processed = joblib.load(processed_path)
curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))
# 先用同一组参数和制度跑真实求解器,再用代理模型预测同一输入,保证对比公平。
schedule = resolve_case_schedule(cfg, args)
params = build_params_from_args(args, cfg, schedule)
print("Running solver...")
# 运行真实数值求解器
# 得到 ground truth 曲线
curve_solver, raw_solver = run_solver_and_extract_curve(cfg, params, args.well_index)
print("Running surrogate...")
model, use_schedule, device = load_model(model_path)
# 使用代理模型预测同一组参数对应的曲线
curve_pred = predict_surrogate_curve(
processed=processed,
model=model,
@ -564,6 +560,8 @@ def main() -> None:
use_schedule=use_schedule,
)
# 汇总所有实验结果
# 最终会保存为 JSON 便于后续分析
summary_payload = {
"config_path": str(cfg.path),
"processed_path": str(processed_path),
@ -592,14 +590,14 @@ def main() -> None:
overall = summary["overall"]
autofit = summary["autofit"]
overall_r2_text = "nan" if np.isnan(overall["r2"]) else f"{overall['r2']:.6f}"
r2_text = "nan" if np.isnan(overall["r2"]) else f"{overall['r2']:.6f}"
print("\nSingle-case comparison complete.")
print(f"Output dir: {output_dir}")
print(
f"Overall RMSE={overall['rmse']:.6f}, MAE={overall['mae']:.6f}, "
f"Bias={overall['bias']:.6f}, "
f"AutoFitObj={autofit['dual_log_objective']:.6f}, "
f"R2={overall_r2_text}"
f"R2={r2_text}"
)
print(f"Plot saved: {plot_path}")
print(f"Summary saved: {output_dir / 'single_case_summary.json'}")

@ -1,10 +1,14 @@
"""评估固定长度曲线正演代理模型。
脚本加载预处理数据和 `ForwardSurrogate` checkpoint批量预测验证/测试样本曲线
脚本加载预处理数据和 `ForwardSurrogate` checkpoint,批量预测验证/测试样本曲线
按整体曲线与压力导数斜率分段统计 RMSEMAENRMSER2 等指标并保存
随机最佳最差样本图作为正演代理模型离线验收入口
"""
# pylint: disable=import-error,wrong-import-position
# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments,too-many-statements
# pylint: disable=line-too-long
from __future__ import annotations
import argparse
@ -53,7 +57,10 @@ def parse_args() -> argparse.Namespace:
"--fit-processed",
type=str,
default=None,
help="Processed dataset used to fit scalers for the evaluated model; required for cross-dataset evaluation",
help=(
"Processed dataset used to fit scalers for the evaluated model; "
"required for cross-dataset evaluation"
),
)
parser.add_argument(
"--output-dir",
@ -470,10 +477,7 @@ def main() -> None:
}
sample_scores: list[dict] = []
for idx in range(len(all_true)):
curve_true = all_true[idx]
curve_pred = all_pred[idx]
for idx, (curve_true, curve_pred) in enumerate(zip(all_true, all_pred)):
# 同时记录整体指标和三段指标,便于判断误差来自压力、导数还是辅助 slope。
overall_m = calc_metrics(curve_true, curve_pred)
overall_metric_list.append(overall_m)

@ -4,6 +4,10 @@
导出不确定性-误差相关性统计散点图和高不确定性样本用于判断集成方差是否可作为
自动拟合候选筛选或风险提示信号
"""
# pylint: disable=import-error,wrong-import-position
# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-statements
from __future__ import annotations
@ -59,7 +63,12 @@ def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Evaluate deep-ensemble UQ for forward surrogate")
parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
parser.add_argument("--tag", type=str, default=None, help="Experiment tag")
parser.add_argument("--model-root", type=str, default=None, help="Root dir that contains seed_* members")
parser.add_argument(
"--model-root",
type=str,
default=None,
help="Root dir that contains seed_* members",
)
parser.add_argument("--output-dir", type=str, default=None, help="Evaluation output dir")
parser.add_argument("--seeds", type=str, default="41,42,43", help="Comma-separated seed list")
parser.add_argument("--no-schedule", action="store_true")
@ -212,7 +221,13 @@ def plot_uncertain_sample(
x = np.arange(len(y_true))
ax.plot(x, y_true, label="True", linewidth=2)
ax.plot(x, y_mean, label="Ensemble mean", linewidth=2)
ax.fill_between(x, y_mean - 2.0 * y_std, y_mean + 2.0 * y_std, alpha=0.2, label="mean ± 2 std")
ax.fill_between(
x,
y_mean - 2.0 * y_std,
y_mean + 2.0 * y_std,
alpha=0.2,
label="mean ± 2 std",
)
ax.set_title(title_map[name])
ax.grid(True, alpha=0.3)
ax.legend()
@ -229,9 +244,21 @@ def main() -> None:
use_schedule = not args.no_schedule
seeds = parse_seed_list(args.seeds)
processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
model_root = Path(args.model_root) if args.model_root is not None else default_model_root(tag, use_schedule)
output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag, use_schedule)
processed_path = (
Path(args.processed)
if args.processed is not None
else processed_path_for_tag(tag)
)
model_root = (
Path(args.model_root)
if args.model_root is not None
else default_model_root(tag, use_schedule)
)
output_dir = (
Path(args.output_dir)
if args.output_dir is not None
else default_output_dir(tag, use_schedule)
)
output_dir.mkdir(parents=True, exist_ok=True)
# 集成评估必须使用同一份 processed 数据,保证各成员输入标准化口径一致。
@ -266,8 +293,16 @@ def main() -> None:
with torch.no_grad():
for idx in range(len(x_params_test)):
params_t = torch.tensor(x_params_test[idx : idx + 1], dtype=torch.float32, device=device)
schedule_t = torch.tensor(x_schedule_test[idx : idx + 1], dtype=torch.float32, device=device)
params_t = torch.tensor(
x_params_test[idx : idx + 1],
dtype=torch.float32,
device=device,
)
schedule_t = torch.tensor(
x_schedule_test[idx : idx + 1],
dtype=torch.float32,
device=device,
)
member_preds = []
for _, model in members:
# 每个成员独立预测后先反标准化;集成均值和标准差都在原始曲线尺度上计算。
@ -279,7 +314,9 @@ def main() -> None:
member_preds.append(pred)
member_preds = np.stack(member_preds, axis=0)
curve_true = scaler_curve.inverse_transform(y_curve_test[idx : idx + 1])[0].astype(np.float32)
curve_true = scaler_curve.inverse_transform(
y_curve_test[idx : idx + 1]
)[0].astype(np.float32)
curve_mean = member_preds.mean(axis=0).astype(np.float32)
curve_std = member_preds.std(axis=0, ddof=0).astype(np.float32)
@ -344,7 +381,12 @@ def main() -> None:
with open(output_dir / "ensemble_uq_summary.json", "w", encoding="utf-8") as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
with open(output_dir / "sample_uncertainty_metrics.csv", "w", newline="", encoding="utf-8-sig") as f:
with open(
output_dir / "sample_uncertainty_metrics.csv",
"w",
newline="",
encoding="utf-8-sig",
) as f:
writer = csv.DictWriter(f, fieldnames=list(sample_rows[0].keys()))
writer.writeheader()
writer.writerows(sample_rows)

@ -7,6 +7,12 @@
from __future__ import annotations
# pylint: disable=import-error,wrong-import-position,line-too-long
# pylint: disable=too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-locals,too-many-statements
# pylint: disable=invalid-name,unused-argument,too-many-function-args
import argparse
import csv
import json

@ -13,17 +13,25 @@ import random
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import h5py
import joblib
import numpy as np
import torch
import torch.nn.functional as F
from src.common.experiment_paths import model_checkpoint_for_tag, model_dir_for_tag, normalize_tag, processed_path_for_tag
from src.data.param_features import param_feature_transform_from_meta, transform_param_features
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import (
model_checkpoint_for_tag,
model_dir_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.data.param_features import (
param_feature_transform_from_meta,
transform_param_features,
)
from src.models.forward_surrogate import ForwardSurrogate
@ -32,7 +40,12 @@ def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Fine-tune a forward surrogate with local pairwise autofit ranking constraints"
)
parser.add_argument("--neighborhood", type=str, required=True, help="Anchor-neighborhood HDF5 path")
parser.add_argument(
"--neighborhood",
type=str,
required=True,
help="Anchor-neighborhood HDF5 path",
)
parser.add_argument("--base-tag", type=str, default="family_random_mixed_50k_biasfix")
parser.add_argument("--base-processed", type=str, default=None)
parser.add_argument("--base-model", type=str, default=None)
@ -144,8 +157,12 @@ def load_neighborhood_groups(
anchor_params_scaled = scaler_params.transform(
transform_param_features(anchor_params[anchor_id : anchor_id + 1], param_transform)
).astype(np.float32)
anchor_schedule_scaled = scaler_schedule.transform(anchor_schedule[anchor_id : anchor_id + 1]).astype(np.float32)
anchor_curve_scaled = scaler_curve.transform(anchor_curve[anchor_id : anchor_id + 1]).astype(np.float32)
anchor_schedule_scaled = scaler_schedule.transform(
anchor_schedule[anchor_id : anchor_id + 1]
).astype(np.float32)
anchor_curve_scaled = scaler_curve.transform(
anchor_curve[anchor_id : anchor_id + 1]
).astype(np.float32)
groups.append(
{
@ -157,7 +174,9 @@ def load_neighborhood_groups(
"neighbor_params_x": scaler_params.transform(
transform_param_features(neighbor_params[idx], param_transform)
).astype(np.float32),
"neighbor_schedule_x": scaler_schedule.transform(anchor_schedule[neighbor_anchor_id[idx]]).astype(np.float32),
"neighbor_schedule_x": scaler_schedule.transform(
anchor_schedule[neighbor_anchor_id[idx]]
).astype(np.float32),
"neighbor_curve_x": scaler_curve.transform(neighbor_curve[idx]).astype(np.float32),
"neighbor_objective": neighbor_objective[idx].astype(np.float32),
}

@ -4,6 +4,8 @@
扁平化为常规 `params/schedule/curve` 结构方便复用已有预处理评估和合并流程
"""
# pylint: disable=too-many-locals,too-many-statements
from __future__ import annotations
import argparse
@ -11,17 +13,20 @@ import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import h5py
import numpy as np
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
def parse_args() -> argparse.Namespace:
"""解析自动拟合邻域 HDF5 的输入输出路径以及是否只导出邻域样本。"""
parser = argparse.ArgumentParser(
description="Flatten anchor-neighborhood autofit HDF5 into the standard sample-level HDF5 format"
description=(
"Flatten anchor-neighborhood autofit HDF5 into the standard "
"sample-level HDF5 format"
)
)
parser.add_argument("--input", type=str, required=True, help="Input autofit neighborhood .h5")
parser.add_argument("--output", type=str, required=True, help="Output flat sample-level .h5")
@ -69,7 +74,7 @@ def main() -> None:
anchor_schedule_meta = np.asarray(src["anchor_schedule_meta"][:], dtype=np.float32)
anchor_family_name = np.asarray(src["anchor_family_name"][:]).astype(str)
anchor_section_index = np.asarray(src["anchor_section_index"][:], dtype=np.int32)
anchor_timeQ_json = np.asarray(src["anchor_timeQ_json"][:]).astype(str)
anchor_time_q_json = np.asarray(src["anchor_time_q_json"][:]).astype(str)
anchor_q_json = np.asarray(src["anchor_q_json"][:]).astype(str)
neighbor_anchor_id = np.asarray(src["neighbor_anchor_id"][:], dtype=np.int32)
@ -85,7 +90,14 @@ def main() -> None:
)
schedule_meta_names = _decode_attr_list(src.attrs.get("schedule_meta_names"))
span_fracs = np.asarray(src.attrs.get("span_fracs", []), dtype=np.float32).reshape(-1).tolist()
span_fracs = (
np.asarray(
src.attrs.get("span_fracs", []),
dtype=np.float32,
)
.reshape(-1)
.tolist()
)
n_anchors = int(anchor_params.shape[0])
n_neighbors = int(neighbor_params.shape[0])
@ -117,7 +129,11 @@ def main() -> None:
dst.create_dataset("curve", shape=(total_rows, curve_dim), dtype=np.float32)
dst.create_dataset("group_id", shape=(total_rows,), dtype=np.int64)
dst.create_dataset("schedule_meta", shape=(total_rows, schedule_meta_dim), dtype=np.float32)
dst.create_dataset("family_name", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8"))
dst.create_dataset(
"family_name",
shape=(total_rows,),
dtype=h5py.string_dtype(encoding="utf-8"),
)
dst.create_dataset("is_anchor", shape=(total_rows,), dtype=np.int8)
dst.create_dataset("neighbor_objective", shape=(total_rows,), dtype=np.float32)
@ -125,8 +141,16 @@ def main() -> None:
dst.create_dataset("neighbor_objective_d", shape=(total_rows,), dtype=np.float32)
dst.create_dataset("neighbor_span_frac", shape=(total_rows,), dtype=np.float32)
dst.create_dataset("section_index", shape=(total_rows,), dtype=np.int32)
dst.create_dataset("timeQ_json", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8"))
dst.create_dataset("q_json", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8"))
dst.create_dataset(
"timeQ_json",
shape=(total_rows,),
dtype=h5py.string_dtype(encoding="utf-8"),
)
dst.create_dataset(
"q_json",
shape=(total_rows,),
dtype=h5py.string_dtype(encoding="utf-8"),
)
write_pos = 0
@ -145,7 +169,7 @@ def main() -> None:
dst["neighbor_objective_d"][write_pos:anchor_end] = 0.0
dst["neighbor_span_frac"][write_pos:anchor_end] = 0.0
dst["section_index"][write_pos:anchor_end] = anchor_section_index
dst["timeQ_json"][write_pos:anchor_end] = anchor_timeQ_json.tolist()
dst["timeQ_json"][write_pos:anchor_end] = anchor_time_q_json.tolist()
dst["q_json"][write_pos:anchor_end] = anchor_q_json.tolist()
write_pos = anchor_end
@ -156,15 +180,23 @@ def main() -> None:
dst["curve"][write_pos:neighbor_end] = neighbor_curve
dst["group_id"][write_pos:neighbor_end] = neighbor_anchor_id
dst["schedule_meta"][write_pos:neighbor_end] = anchor_schedule_meta[neighbor_anchor_id]
dst["family_name"][write_pos:neighbor_end] = anchor_family_name[neighbor_anchor_id].tolist()
dst["family_name"][write_pos:neighbor_end] = anchor_family_name[
neighbor_anchor_id
].tolist()
dst["is_anchor"][write_pos:neighbor_end] = 0
dst["neighbor_objective"][write_pos:neighbor_end] = neighbor_objective
dst["neighbor_objective_p"][write_pos:neighbor_end] = neighbor_objective_p
dst["neighbor_objective_d"][write_pos:neighbor_end] = neighbor_objective_d
dst["neighbor_span_frac"][write_pos:neighbor_end] = neighbor_span_frac
dst["section_index"][write_pos:neighbor_end] = anchor_section_index[neighbor_anchor_id]
dst["timeQ_json"][write_pos:neighbor_end] = anchor_timeQ_json[neighbor_anchor_id].tolist()
dst["q_json"][write_pos:neighbor_end] = anchor_q_json[neighbor_anchor_id].tolist()
dst["section_index"][write_pos:neighbor_end] = anchor_section_index[
neighbor_anchor_id
]
dst["timeQ_json"][write_pos:neighbor_end] = anchor_time_q_json[
neighbor_anchor_id
].tolist()
dst["q_json"][write_pos:neighbor_end] = anchor_q_json[
neighbor_anchor_id
].tolist()
dst.attrs["n_samples"] = int(total_rows)
dst.attrs["source_neighborhood_h5"] = str(input_path)
@ -172,7 +204,10 @@ def main() -> None:
print("Autofit neighborhood flatten complete.")
print(f"Input: {input_path}")
print(f"Output: {output_path}")
print(f"anchors={n_anchors}, neighbors={n_neighbors}, total_exported={total_rows}")
print(
f"anchors={n_anchors}, neighbors={n_neighbors}, "
f"total_exported={total_rows}"
)
print(f"neighbors_only={bool(args.neighbors_only)}")

@ -5,6 +5,17 @@
输出的 HDF5 同时包含曲线参数制度编码和排序训练所需的邻域元数据
"""
# pylint: disable=import-error,wrong-import-position,too-many-locals
# pylint: disable=too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-statements,invalid-name
# pylint: disable=broad-exception-caught,no-member
from __future__ import annotations
import argparse

@ -3,6 +3,7 @@
脚本读取数据生成配置调用并行数据集生成器批量采样地层/井筒参数和流量制度
运行底层数值求解器并把有效曲线写入 HDF5它是训练前最上游的数据生产入口
"""
# pylint: disable=import-error,wrong-import-position,broad-exception-caught
from __future__ import annotations
import argparse
@ -18,20 +19,31 @@ from src.common.experiment_paths import config_for_stage
from src.data.dataset_generation import ParallelDatasetGenerator
def main():
def main() -> None:
"""按配置阶段启动并行数值试井样本生成,输出原始 HDF5 数据集路径。"""
parser = argparse.ArgumentParser()
parser.add_argument("--config", default=None)
parser.add_argument(
"--stage",
choices=["fixed_case", "case_neighborhood", "family_random", "family_random_hard", "family_random_v2_q"],
choices=[
"fixed_case",
"case_neighborhood",
"family_random",
"family_random_hard",
"family_random_v2_q",
],
default=None,
)
parser.add_argument("--n-samples", type=int, default=None)
parser.add_argument("--n-workers", type=int, default=None)
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--method", type=str, default=None)
parser.add_argument("--dataset-tag", type=str, default=None, help="Optional tag injected into output dataset filename")
parser.add_argument(
"--dataset-tag",
type=str,
default=None,
help="Optional tag injected into output dataset filename",
)
args = parser.parse_args()
config_path = args.config
@ -41,8 +53,14 @@ def main():
# stage 用来选择预设配置;命令行参数继续覆盖样本数、并行数和随机种子。
cfg = Config(config_path)
cfg.ensure_dirs()
path = ParallelDatasetGenerator(cfg=cfg, n_workers=args.n_workers).generate(
n_samples=args.n_samples, method=args.method, random_seed=args.seed, dataset_tag=args.dataset_tag
path = ParallelDatasetGenerator(
cfg=cfg,
n_workers=args.n_workers,
).generate(
n_samples=args.n_samples,
method=args.method,
random_seed=args.seed,
dataset_tag=args.dataset_tag,
)
print(path)
@ -51,6 +69,6 @@ if __name__ == "__main__":
mp.freeze_support()
try:
mp.set_start_method("spawn", force=True)
except Exception:
except RuntimeError:
pass
main()

@ -5,6 +5,8 @@
提升代理模型在实际反演困难区域的覆盖度
"""
# pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,no-member
from __future__ import annotations
import argparse
@ -12,6 +14,7 @@ import json
import math
import sys
from pathlib import Path
from types import SimpleNamespace
from typing import Any
import h5py
@ -390,15 +393,12 @@ def merge_datasets(
merge_meta: dict[str, Any]
class _Args:
"""内部轻量参数对象,用于把递归合并复用到单文件合并场景。"""
pass
count_args = _Args()
count_args.normal_count = normal_count
count_args.hard_count = hard_count
count_args.total_samples = total_samples
count_args.hard_ratio = hard_ratio
count_args = SimpleNamespace(
normal_count=normal_count,
hard_count=hard_count,
total_samples=total_samples,
hard_ratio=hard_ratio,
)
with h5py.File(normal_input, "r") as normal_file, h5py.File(hard_input, "r") as hard_file:
# 合并前先锁定 schema避免后面边写边发现字段不兼容。

@ -5,6 +5,8 @@
processed 数据文件
"""
# pylint: disable=import-error,wrong-import-position
from __future__ import annotations
import argparse
@ -20,7 +22,9 @@ from src.data.preprocess import preprocess_dataset
def main() -> None:
"""把原始 HDF5 曲线数据切分、标准化并保存为训练用 pkl。"""
parser = argparse.ArgumentParser(description="Preprocess HDF5 dataset for forward surrogate")
parser = argparse.ArgumentParser(
description="Preprocess HDF5 dataset for forward surrogate"
)
parser.add_argument(
"--input",
type=str,
@ -33,7 +37,12 @@ def main() -> None:
default=None,
help="Optional output .pkl path",
)
parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming")
parser.add_argument(
"--tag",
type=str,
default=None,
help="Experiment tag for auto naming",
)
parser.add_argument("--test-size", type=float, default=0.15)
parser.add_argument("--val-size", type=float, default=0.15)
parser.add_argument("--seed", type=int, default=42)
@ -45,8 +54,12 @@ def main() -> None:
args = parser.parse_args()
tag = normalize_tag(args.tag)
# 未显式指定输出路径时,使用 tag 生成与训练脚本约定一致的 processed 文件名。
output_path = Path(args.output) if args.output is not None else processed_path_for_tag(tag)
output_path = (
Path(args.output)
if args.output is not None
else processed_path_for_tag(tag)
)
preprocess_dataset(
input_path=Path(args.input),

@ -5,6 +5,8 @@
保持稳定的自动拟合筛选能力
"""
# pylint: disable=import-error,wrong-import-position,wrong-import-order,too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements,broad-exception-caught
from __future__ import annotations
import argparse
@ -31,7 +33,12 @@ from scripts.validate_autofit_local_ranking import (
run_solver_and_extract_curve,
)
from src.common.config import Config
from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag
from src.common.experiment_paths import (
config_for_stage,
model_checkpoint_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.data.params import Params, Schedule
from src.data.runner_client import CppRunner
from src.evaluation.autofit_objective import dual_log_objective
@ -55,7 +62,17 @@ def parse_args() -> argparse.Namespace:
description="Generate Q schedules around theta* and run first-layer local ranking validation."
)
parser.add_argument("--config", type=str, default=None)
parser.add_argument("--stage", choices=["fixed_case", "case_neighborhood", "family_random", "family_random_hard", "family_random_v2_q"], default="family_random")
parser.add_argument(
"--stage",
choices=[
"fixed_case",
"case_neighborhood",
"family_random",
"family_random_hard",
"family_random_v2_q",
],
default="family_random",
)
parser.add_argument("--processed", type=str, default=None)
parser.add_argument("--model", type=str, default=None)
parser.add_argument("--tag", type=str, default="family_random_mixed_50k_logparam")
@ -149,7 +166,15 @@ def schedule_features(time_q: list[float], q: list[float]) -> dict:
}
def add_case(cases: list[dict], seen: set[tuple], case_id: str, family: str, time_q: list[float], q: list[float], axis: str) -> None:
def add_case(
cases: list[dict],
seen: set[tuple],
case_id: str,
family: str,
time_q: list[float],
q: list[float],
axis: str,
) -> None:
"""向案例列表追加一个流量制度候选及其特征。"""
key = tuple(round(float(x), 8) for x in (time_q + q))
if key in seen:
@ -180,7 +205,15 @@ def generate_q_cases() -> list[dict]:
base_q = [170.0, 170.0, 210.0, 250.0]
base_shutin = 72.0
add_case(cases, seen, "Q000_baseline_b3_alt_like", "mild_step", base_prod_dt + [base_shutin], base_q + [0.0], "baseline")
add_case(
cases,
seen,
"Q000_baseline_b3_alt_like",
"mild_step",
base_prod_dt + [base_shutin],
base_q + [0.0],
"baseline",
)
# 分别扫描生产总时长、关井时长、流量倍率、阶跃强度和生产段数量。
for total in [24.0, 48.0, 72.0, 96.0, 160.0, 220.0, 260.0]:
@ -192,12 +225,28 @@ def generate_q_cases() -> list[dict]:
for scale in [0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0]:
q_prod = [float(x * scale) for x in base_q]
add_case(cases, seen, f"q_scale_{str(scale).replace('.', 'p')}", "mild_step", base_prod_dt + [base_shutin], q_prod + [0.0], "q_scale")
# Flow-rate scaling sweep: keep the baseline time segments and register the
# scaled production rates computed just above (q_prod) on the "q_scale" axis.
# The case id encodes the scale factor with '.' replaced by 'p' (e.g. 1p25).
# NOTE: this must not reuse `total` / `time_q` from the preceding
# production-duration loop; doing so yields an identical schedule key on
# every iteration, which add_case() then discards as a duplicate.
add_case(
    cases,
    seen,
    f"q_scale_{str(scale).replace('.', 'p')}",
    "mild_step",
    base_prod_dt + [base_shutin],
    q_prod + [0.0],
    "q_scale",
)
for ratio in [1.2, 1.6, 2.0, 2.2, 2.8, 3.5]:
q0 = 180.0
q_prod = [q0, q0 * ratio, q0, q0 * ratio]
add_case(cases, seen, f"step_ratio_{str(ratio).replace('.', 'p')}", "sharp_step", base_prod_dt + [base_shutin], q_prod + [0.0], "step_ratio")
add_case(
cases,
seen,
f"step_ratio_{str(ratio).replace('.', 'p')}",
"sharp_step",
base_prod_dt + [base_shutin],
q_prod + [0.0],
"step_ratio",
)
for n_prod in [2, 3, 4, 5, 6]:
total = 84.0
@ -338,7 +387,12 @@ def main() -> None:
for case_idx, q_case in enumerate(q_cases):
schedule: Schedule = q_case["schedule"]
if not schedule.validate():
failed_case_rows.append({**{k: v for k, v in q_case.items() if k != "schedule"}, "reason": "invalid schedule"})
failed_case_rows.append(
{
**{k: v for k, v in q_case.items() if k != "schedule"},
"reason": "invalid schedule",
}
)
continue
target_params = make_theta_params(args, schedule)
@ -359,7 +413,12 @@ def main() -> None:
timeout=int(args.solver_timeout),
)
except Exception as exc:
failed_case_rows.append({**{k: v for k, v in q_case.items() if k != "schedule"}, "reason": str(exc)})
failed_case_rows.append(
{
**{k: v for k, v in q_case.items() if k != "schedule"},
"reason": str(exc),
}
)
print(f" [fail] target solver failed: {exc}")
continue
finally:
@ -436,7 +495,12 @@ def main() -> None:
runner.close()
if len(rows) < 2:
failed_case_rows.append({**{k: v for k, v in q_case.items() if k != "schedule"}, "reason": "fewer than two valid candidates"})
failed_case_rows.append(
{
**{k: v for k, v in q_case.items() if k != "schedule"},
"reason": "fewer than two valid candidates",
}
)
print(f" [fail] valid candidates={len(rows)}")
continue

@ -5,6 +5,9 @@
调用量而不明显损失最优候选
"""
# pylint: disable=import-error,wrong-import-position,wrong-import-order,too-many-locals,too-many-statements,broad-exception-caught
from __future__ import annotations
import argparse
@ -22,10 +25,15 @@ import joblib
import matplotlib.pyplot as plt
import numpy as np
from scripts.compare_single_case import build_schedule_vector, load_model
from scripts.compare_single_case import load_model
from scripts.validate_autofit_local_ranking import infer_curve_layout, predict_surrogate_curve
from src.common.config import Config
from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag
from src.common.experiment_paths import (
config_for_stage,
model_checkpoint_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features
from src.data.params import Params, Schedule
from src.evaluation.autofit_objective import dual_log_objective
@ -189,7 +197,13 @@ def corr_spearman(a: np.ndarray, b: np.ndarray) -> float:
def summarize_generation(rows: list[dict], keep_fracs: list[float]) -> tuple[dict, list[dict]]:
"""按 PSO 代数汇总代理筛选保留真实优质候选的效果。"""
valid = [r for r in rows if r["solver_success"] == "1" and math.isfinite(float(r["solver_objective"])) and float(r["solver_objective"]) < 1e9]
valid = [
r
for r in rows
if r["solver_success"] == "1"
and math.isfinite(float(r["solver_objective"]))
and float(r["solver_objective"]) < 1e9
]
if len(valid) < 2:
return {}, []

@ -4,6 +4,8 @@
再计算与目标曲线的自动拟合目标函数输出可被 PSO 或筛选流程继续使用的打分 CSV
"""
# pylint: disable=import-error,wrong-import-position,wrong-import-order,too-many-locals,broad-exception-caught
from __future__ import annotations
import argparse
@ -21,7 +23,12 @@ import numpy as np
from scripts.compare_single_case import load_model
from scripts.validate_autofit_local_ranking import infer_curve_layout, predict_surrogate_curve
from src.common.config import Config
from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag
from src.common.experiment_paths import (
config_for_stage,
model_checkpoint_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features
from src.data.params import Params, Schedule
from src.evaluation.autofit_objective import dual_log_objective
@ -30,13 +37,21 @@ from src.evaluation.autofit_objective import dual_log_objective
PARAM_COLUMNS = ["k", "skin", "wellboreC", "phi", "h", "Cf"]
def parse_args() -> argparse.Namespace:
"""解析 PSO 候选粒子 CSV、trace 元数据、代理模型和输出评分表路径。"""
parser = argparse.ArgumentParser(description="使用正演代理模型为一批 PSO 候选粒子打分")
parser.add_argument("--candidates", required=True, type=str, help="候选粒子 CSV包含 particle_id,k,skin,wellboreC,phi,h,Cf")
parser.add_argument("--trace-meta", required=True, type=str, help="匹配的 pso_baseline_trace_*.meta.json")
parser.add_argument(
"--candidates",
required=True,
type=str,
help="候选粒子 CSV包含 particle_id,k,skin,wellboreC,phi,h,Cf",
)
parser.add_argument(
"--trace-meta",
required=True,
type=str,
help="匹配的 pso_baseline_trace_*.meta.json",
)
parser.add_argument("--output", required=True, type=str, help="输出代理目标函数 CSV")
parser.add_argument("--tag", type=str, default="family_random_mixed_50k_logparam")
parser.add_argument("--stage", type=str, default="family_random")

@ -4,6 +4,8 @@
随后通过标准输入接收候选 CSV 路径并输出 JSON 状态降低 PSO 与代理模型耦合时
频繁加载 checkpoint 的开销
"""
# pylint: disable=import-error,wrong-import-position,wrong-import-order,too-many-instance-attributes,too-few-public-methods,broad-exception-caught
from __future__ import annotations

@ -5,6 +5,8 @@
代理模型的主训练入口
"""
# pylint: disable=import-error,wrong-import-position
from __future__ import annotations
import argparse
@ -15,7 +17,16 @@ ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import model_dir_for_tag, normalize_tag, processed_path_for_tag
from src.training.train_forward import TrainConfig, train_forward
from src.training.train_forward import (
LossConfig,
LossWeights,
ModelConfig,
OptimConfig,
SampleReweightConfig,
TrainConfig,
TrainRuntime,
train_forward,
)
def main() -> None:
@ -53,6 +64,7 @@ def main() -> None:
parser.add_argument("--huber-beta", type=float, default=0.05)
parser.add_argument("--use-sample-reweight", action="store_true", default=True)
parser.add_argument("--no-sample-reweight", action="store_false", dest="use_sample_reweight")
parser.add_argument("--sample-reweight-alpha", type=float, default=0.4)
parser.add_argument("--sample-weight-min", type=float, default=1.0)
parser.add_argument("--sample-weight-max", type=float, default=2.5)
@ -74,27 +86,37 @@ def main() -> None:
cfg = TrainConfig(
processed_path=processed_path,
output_dir=output_dir,
seed=args.seed,
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
weight_decay=args.weight_decay,
hidden_dim=args.hidden_dim,
dropout=args.dropout,
w_pressure=args.w_pressure,
w_derivative=args.w_derivative,
w_slope=args.w_slope,
w_bias_pressure=args.w_bias_pressure,
w_bias_derivative=args.w_bias_derivative,
w_derivative_shape=args.w_derivative_shape,
w_autofit_pressure=args.w_autofit_pressure,
w_autofit_derivative=args.w_autofit_derivative,
huber_beta=args.huber_beta,
use_sample_reweight=args.use_sample_reweight,
sample_reweight_alpha=args.sample_reweight_alpha,
sample_weight_min=args.sample_weight_min,
sample_weight_max=args.sample_weight_max,
use_schedule=use_schedule,
runtime=TrainRuntime(seed=args.seed),
optim=OptimConfig(
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
weight_decay=args.weight_decay,
),
model=ModelConfig(
hidden_dim=args.hidden_dim,
dropout=args.dropout,
use_schedule=use_schedule,
),
loss=LossConfig(
weights=LossWeights(
pressure=args.w_pressure,
derivative=args.w_derivative,
slope=args.w_slope,
bias_pressure=args.w_bias_pressure,
bias_derivative=args.w_bias_derivative,
derivative_shape=args.w_derivative_shape,
autofit_pressure=args.w_autofit_pressure,
autofit_derivative=args.w_autofit_derivative,
),
huber_beta=args.huber_beta,
),
sample_reweight=SampleReweightConfig(
enabled=args.use_sample_reweight,
alpha=args.sample_reweight_alpha,
weight_min=args.sample_weight_min,
weight_max=args.sample_weight_max,
),
)
train_forward(cfg)

@ -4,6 +4,8 @@
独立 seed 和输出目录所得模型集合用于后续不确定性估计误差风险分析和 fallback 筛选
"""
# pylint: disable=import-error,wrong-import-position,duplicate-code
from __future__ import annotations
import argparse
@ -15,7 +17,16 @@ ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
from src.training.train_forward import TrainConfig, train_forward
from src.training.train_forward import (
LossConfig,
LossWeights,
ModelConfig,
OptimConfig,
SampleReweightConfig,
TrainConfig,
TrainRuntime,
train_forward,
)
def parse_seed_list(seed_text: str) -> list[int]:
@ -62,6 +73,7 @@ def main() -> None:
parser.add_argument("--w-autofit-derivative", type=float, default=0.0)
parser.add_argument("--huber-beta", type=float, default=0.05)
parser.add_argument("--use-sample-reweight", action="store_true", default=True)
parser.add_argument("--no-sample-reweight", action="store_false", dest="use_sample_reweight")
parser.add_argument("--sample-reweight-alpha", type=float, default=0.4)
parser.add_argument("--sample-weight-min", type=float, default=1.0)
parser.add_argument("--sample-weight-max", type=float, default=2.5)
@ -91,27 +103,37 @@ def main() -> None:
cfg = TrainConfig(
processed_path=processed_path,
output_dir=member_dir,
seed=seed,
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
weight_decay=args.weight_decay,
hidden_dim=args.hidden_dim,
dropout=args.dropout,
w_pressure=args.w_pressure,
w_derivative=args.w_derivative,
w_slope=args.w_slope,
w_bias_pressure=args.w_bias_pressure,
w_bias_derivative=args.w_bias_derivative,
w_derivative_shape=args.w_derivative_shape,
w_autofit_pressure=args.w_autofit_pressure,
w_autofit_derivative=args.w_autofit_derivative,
huber_beta=args.huber_beta,
use_sample_reweight=args.use_sample_reweight,
sample_reweight_alpha=args.sample_reweight_alpha,
sample_weight_min=args.sample_weight_min,
sample_weight_max=args.sample_weight_max,
use_schedule=use_schedule,
runtime=TrainRuntime(seed=seed),
optim=OptimConfig(
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
weight_decay=args.weight_decay,
),
model=ModelConfig(
hidden_dim=args.hidden_dim,
dropout=args.dropout,
use_schedule=use_schedule,
),
loss=LossConfig(
weights=LossWeights(
pressure=args.w_pressure,
derivative=args.w_derivative,
slope=args.w_slope,
bias_pressure=args.w_bias_pressure,
bias_derivative=args.w_bias_derivative,
derivative_shape=args.w_derivative_shape,
autofit_pressure=args.w_autofit_pressure,
autofit_derivative=args.w_autofit_derivative,
),
huber_beta=args.huber_beta,
),
sample_reweight=SampleReweightConfig(
enabled=args.use_sample_reweight,
alpha=args.sample_reweight_alpha,
weight_min=args.sample_weight_min,
weight_max=args.sample_weight_max,
),
)
train_forward(cfg)
manifest["members"].append(

@ -5,6 +5,8 @@
适合处理可变时间采样或逐点推理场景
"""
# pylint: disable=import-error,wrong-import-position
from __future__ import annotations
import argparse
@ -15,7 +17,15 @@ ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
from src.training.train_time_conditioned import TimeConditionedTrainConfig, train_time_conditioned
from src.training.train_time_conditioned import (
RiskWeightConfig,
TimeConditionedTrainConfig,
TimeLossConfig,
TimeModelConfig,
TimeOptimConfig,
TimeRuntimeConfig,
train_time_conditioned,
)
def main() -> None:
@ -61,23 +71,31 @@ def main() -> None:
cfg = TimeConditionedTrainConfig(
processed_path=processed_path,
output_dir=output_dir,
seed=int(args.seed),
batch_size=int(args.batch_size),
epochs=int(args.epochs),
lr=float(args.lr),
weight_decay=float(args.weight_decay),
hidden_dim=int(args.hidden_dim),
n_blocks=int(args.n_blocks),
dropout=float(args.dropout),
w_pressure=float(args.w_pressure),
w_derivative=float(args.w_derivative),
huber_beta=float(args.huber_beta),
use_schedule=not bool(args.no_schedule),
sample_weight_mode=str(args.sample_weight_mode),
risk_weight=float(args.risk_weight),
skin_lt_minus8_weight=float(args.skin_lt_minus8_weight),
sample_weight_min=float(args.sample_weight_min),
sample_weight_max=float(args.sample_weight_max),
runtime=TimeRuntimeConfig(seed=int(args.seed)),
optim=TimeOptimConfig(
batch_size=int(args.batch_size),
epochs=int(args.epochs),
lr=float(args.lr),
weight_decay=float(args.weight_decay),
),
model=TimeModelConfig(
hidden_dim=int(args.hidden_dim),
n_blocks=int(args.n_blocks),
dropout=float(args.dropout),
use_schedule=not bool(args.no_schedule),
),
loss=TimeLossConfig(
w_pressure=float(args.w_pressure),
w_derivative=float(args.w_derivative),
huber_beta=float(args.huber_beta),
),
risk_weight=RiskWeightConfig(
mode=str(args.sample_weight_mode),
risk_weight=float(args.risk_weight),
skin_lt_minus8_weight=float(args.skin_lt_minus8_weight),
weight_min=float(args.sample_weight_min),
weight_max=float(args.sample_weight_max),
),
)
train_time_conditioned(cfg)

@ -5,6 +5,8 @@
它直接回答代理模型能否用于 PSO 候选预筛选这个问题
"""
# pylint: disable=import-error,wrong-import-position,invalid-name,too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements,broad-exception-caught,no-member
from __future__ import annotations
import argparse
@ -103,12 +105,12 @@ def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
def resolve_case_schedule(cfg: Config) -> Schedule:
"""解析局部排序验证目标案例的流量制度,支持命令行覆盖配置默认值。"""
schedule_cfg = cfg.raw["schedule"]["case_schedule"]
timeQ = list(map(float, schedule_cfg["timeQ"]))
time_q = list(map(float, schedule_cfg["timeQ"]))
q = list(map(float, schedule_cfg["q"]))
policy = cfg.raw["schedule"].get("section_policy", {}) or {}
mode = str(policy.get("mode", "fixed_last")).lower()
n_sections = len(timeQ)
n_sections = len(time_q)
if mode == "fixed_last":
section_index = n_sections
@ -117,7 +119,7 @@ def resolve_case_schedule(cfg: Config) -> Schedule:
else:
section_index = int(np.clip(int(schedule_cfg.get("default_section_index", n_sections)), 1, n_sections))
schedule = Schedule(sectionIndex=section_index, timeQ=timeQ, q=q)
schedule = Schedule(sectionIndex=section_index, timeQ=time_q, q=q)
if not schedule.validate():
raise ValueError("Invalid case_schedule in config")
return schedule

@ -4,6 +4,7 @@
重复采样局部候选汇总代理目标与真实目标之间的排序相关性保留比例和失败案例
用于给出更稳健的 PSO 预筛选可用性判断
"""
# pylint: disable=import-error,wrong-import-position,wrong-import-order,invalid-name,too-many-locals,too-many-branches,too-many-statements,broad-exception-caught,no-member
from __future__ import annotations
@ -95,10 +96,10 @@ def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]:
def sample_target_case(cfg: Config, rng: np.random.RandomState, seed: int) -> Params:
"""从配置或数据集中抽取一个目标案例,用于局部排序验证。"""
params = generate_params_dataset(cfg, n_samples=1, method=cfg.raw["params"].get("sampling_method", "sobol"), random_seed=seed)[0]
timeQ, q, _sched_info = sample_schedule_by_mode(cfg, rng)
section_indices = _resolve_section_indices(cfg, timeQ, q, rng)
time_q, q, _sched_info = sample_schedule_by_mode(cfg, rng)
section_indices = _resolve_section_indices(cfg, time_q, q, rng)
section_index = int(section_indices[int(rng.randint(0, len(section_indices)))])
params.schedule = Schedule(sectionIndex=section_index, timeQ=list(map(float, timeQ)), q=list(map(float, q)))
params.schedule = Schedule(sectionIndex=section_index, timeQ=list(map(float, time_q)), q=list(map(float, q)))
return params

@ -30,6 +30,7 @@ class ProjectPaths:
class Config:
# pylint: disable=too-many-instance-attributes
"""读取 YAML 配置,并把项目根目录、数据目录、模型目录等派生为可复用属性。"""
def __init__(self, config_path: str | Path) -> None:
"""读取 YAML 配置,解析项目根目录以及训练、求解器和数据文件路径。"""
@ -85,10 +86,9 @@ class Config:
@property
def curve_dim(self) -> int:
"""返回配置中的曲线输出维度。"""
T = int(self.get("curve_processing", "n_time_points", default=160))
n_time_points = int(self.get("curve_processing", "n_time_points", default=160))
use_slope = bool(self.get("curve_processing", "use_slope_feature", default=True))
# 曲线至少包含 pressure 和 derivative启用 slope 时再追加一段辅助特征。
return (2 + (1 if use_slope else 0)) * T
return (2 + (1 if use_slope else 0)) * n_time_points
@property
def sec_feat_dim(self) -> int:
@ -98,13 +98,14 @@ class Config:
@property
def schedule_grid_shape(self) -> tuple[int, int]:
"""返回流量制度固定时间网格的形状配置。"""
Nu = int(self.get("timegrid_encoding", "n_u_points", default=256))
Cu = 1
# 每开启一个附加通道,固定网格的通道数就增加一维。
n_u_points = int(self.get("timegrid_encoding", "n_u_points", default=256))
n_channels = 1
if bool(self.get("timegrid_encoding", "include_cum", default=True)):
Cu += 1
n_channels += 1
if bool(self.get("timegrid_encoding", "include_dq", default=True)):
Cu += 1
n_channels += 1
if bool(self.get("timegrid_encoding", "include_shutin", default=False)):
Cu += 1
return Nu, Cu
n_channels += 1
return n_u_points, n_channels

@ -9,6 +9,8 @@ C++ 数值试井求解器输出的双对数曲线通常是不等长时间序列
的维度和顺序必须与 Config.curve_dimmeta.curve_layout 保持一致
"""
# pylint: disable=import-error,too-many-arguments,too-many-positional-arguments,too-many-return-statements,too-many-locals
from __future__ import annotations
from typing import Optional, Tuple
@ -91,6 +93,8 @@ def clean_curve_for_dataset(
操作可能掩盖真实的求解质量问题更严格的有效性判断交给 is_valid_curve 完成
"""
c = cfg.raw["curve_processing"]
# 该参数保留用于兼容旧调用接口;当前版本只做保守清洗,不做异常值平滑。
_ = outlier_factor
min_len = int(min_len if min_len is not None else c.get("min_valid_points", 30))
eps = float(eps if eps is not None else c.get("feature_epsilon", 1e-12))
@ -140,7 +144,7 @@ def _make_time_grid(cfg: Config, t0: float, t1: float, n: int) -> np.ndarray:
def curve_time_grid_for_sample(cfg: Config, t: np.ndarray, n: Optional[int] = None) -> np.ndarray:
"""为单个样本生成与曲线特征对齐的时间坐标。"""
nT = int(n if n is not None else cfg.get("curve_processing", "n_time_points", default=160))
n_time_points = int(n if n is not None else cfg.get("curve_processing", "n_time_points", default=160))
t = np.asarray(t, dtype=np.float64).reshape(-1)
if t.size < 2:
raise RuntimeError("curve_time_grid_for_sample: time array is too short")
@ -151,7 +155,7 @@ def curve_time_grid_for_sample(cfg: Config, t: np.ndarray, n: Optional[int] = No
else:
t0 = float(np.min(t))
t1 = float(np.max(t))
return _make_time_grid(cfg, t0, t1, nT).astype(np.float32)
return _make_time_grid(cfg, t0, t1, n_time_points).astype(np.float32)
def resample_curve_to_features_with_time(
@ -170,7 +174,7 @@ def resample_curve_to_features_with_time(
[cfg.curve_dim]第二个数组是该样本使用的时间网格形状为 [n_time_points]
"""
eps = float(cfg.get("curve_processing", "feature_epsilon", default=1e-12))
nT = int(cfg.get("curve_processing", "n_time_points", default=160))
n_time_points = int(cfg.get("curve_processing", "n_time_points", default=160))
use_slope = bool(cfg.get("curve_processing", "use_slope_feature", default=True))
t = np.asarray(t, dtype=np.float64).reshape(-1)
@ -180,7 +184,7 @@ def resample_curve_to_features_with_time(
if t.size < 2:
raise RuntimeError("resample_curve_to_features: 时间点过少")
grid = curve_time_grid_for_sample(cfg, t, n=nT).astype(np.float64)
grid = curve_time_grid_for_sample(cfg, t, n=n_time_points).astype(np.float64)
logp = np.log(np.maximum(p, eps))
logd = np.log(np.maximum(np.abs(d), eps))
@ -194,8 +198,8 @@ def resample_curve_to_features_with_time(
if use_slope:
# slope 作为辅助特征保留;即使训练权重设为 0也能保持数据布局兼容。
logt = np.log(np.maximum(grid, eps)).astype(np.float64)
s = np.zeros((nT,), dtype=np.float32)
if nT >= 3:
s = np.zeros((n_time_points,), dtype=np.float32)
if n_time_points >= 3:
denom = logt[2:] - logt[:-2]
denom = np.maximum(denom, 1e-12)
s[1:-1] = ((lp[2:] - lp[:-2]) / denom).astype(np.float32)

@ -11,6 +11,8 @@ HDF5 流式写入,是构建正演代理模型训练集的主流程。为了支
区域分析和自动拟合候选排序验证
"""
# pylint: disable=import-error,invalid-name,too-many-locals,too-many-branches,too-many-statements,too-many-return-statements,too-many-instance-attributes,too-many-arguments,too-many-positional-arguments,broad-exception-caught,no-member,global-statement,too-few-public-methods
from __future__ import annotations
import json
@ -41,11 +43,11 @@ except ImportError:
def update(self, _=1):
"""空进度条的 update 方法,保持与 tqdm.update 接口一致。"""
pass
return None
def close(self):
"""空进度条的 close 方法,保持与 tqdm.close 接口一致。"""
pass
return None
return _NoopTqdm(**kwargs)
return iterable
@ -250,6 +252,7 @@ def build_schedule_metadata(
def _sample_schedule_fixed_case(cfg: Config, rng) -> Tuple[List[float], List[float], Dict[str, Any]]:
"""返回配置中固定的流量制度,用于基准案例或可复现实验。"""
_ = rng
sc = cfg.raw["schedule"]["case_schedule"]
return (
list(map(float, sc["timeQ"])),
@ -360,6 +363,7 @@ def sample_schedule_by_mode(cfg: Config, rng) -> Tuple[List[float], List[float],
def _resolve_section_indices(cfg: Config, timeQ, q, rng):
"""解析允许作为 sectionIndex 的分段范围,并裁剪到当前流量制度长度内。"""
_ = q
policy = cfg.raw["schedule"]["section_policy"]
mode = str(policy["mode"]).lower()
n = int(len(timeQ))
@ -439,6 +443,7 @@ def _worker_simulate_parallel(args):
encoding="utf-8",
errors="ignore",
env=_SUBPROC_ENV,
check=False,
)
if result.returncode != 0 or not _RESULT_BIN.exists():
@ -580,8 +585,8 @@ class HDF5Appender:
if not samples:
return
B = len(samples)
start, end = self._n, self._n + B
batch_size = len(samples)
start, end = self._n, self._n + batch_size
self.d_group_id.resize((end,))
self.d_params.resize((end, self.param_dim))
@ -591,13 +596,13 @@ class HDF5Appender:
self.d_schedule_meta.resize((end, len(SCHEDULE_META_NAMES)))
self.d_family_name.resize((end,))
gid = np.full((B,), -1, dtype=np.int64)
params = np.full((B, self.param_dim), np.nan, dtype=np.float32)
schedule = np.full((B, self.schedule_dim), np.nan, dtype=np.float32)
curve = np.full((B, self.curve_dim), np.nan, dtype=np.float32)
curve_time = np.full((B, self.time_dim), np.nan, dtype=np.float32)
schedule_meta = np.full((B, len(SCHEDULE_META_NAMES)), np.nan, dtype=np.float32)
family_name: list[str] = ["" for _ in range(B)]
gid = np.full((batch_size,), -1, dtype=np.int64)
params = np.full((batch_size, self.param_dim), np.nan, dtype=np.float32)
schedule = np.full((batch_size, self.schedule_dim), np.nan, dtype=np.float32)
curve = np.full((batch_size, self.curve_dim), np.nan, dtype=np.float32)
curve_time = np.full((batch_size, self.time_dim), np.nan, dtype=np.float32)
schedule_meta = np.full((batch_size, len(SCHEDULE_META_NAMES)), np.nan, dtype=np.float32)
family_name: list[str] = ["" for _ in range(batch_size)]
for i, s in enumerate(samples):
gid[i] = int(s.get("group_id", -1))
@ -653,6 +658,7 @@ class ParallelDatasetGenerator:
stderr=subprocess.PIPE,
encoding="utf-8",
errors="ignore",
check=False,
)
if result.returncode == 0 and self.cfg.dataset_bin.exists():
return True
@ -697,8 +703,8 @@ class ParallelDatasetGenerator:
filepath = self.output_dir / f"dataset_{mode}_{sec_mode}_target{n_samples}_{timestamp}.h5"
curve_dim = cfg.curve_dim
Nu, Cu = cfg.schedule_grid_shape
sched_dim = Nu * Cu + cfg.sec_feat_dim
n_u_points, n_channels = cfg.schedule_grid_shape
sched_dim = n_u_points * n_channels + cfg.sec_feat_dim
if bool(cfg.raw.get("schedule", {}).get("use_metadata_features_for_model", False)):
sched_dim += len(cfg.raw.get("schedule", {}).get("metadata_features_for_model", []) or [])
param_dim = 6
@ -853,7 +859,7 @@ class ParallelDatasetGenerator:
while in_flight:
fut = in_flight.popleft()
try:
task_idx, sample, fail_reason, fail_ctx = fut.result()
task_idx, sample, fail_reason, _fail_ctx = fut.result()
except Exception:
fail_reasons["future_exception"] = fail_reasons.get("future_exception", 0) + 1
else:

@ -10,6 +10,8 @@
恢复原始参数含义
"""
# pylint: disable=too-many-locals,invalid-name
from __future__ import annotations
from typing import Any
@ -187,4 +189,3 @@ def inverse_transform_param_features(
else:
raise ValueError(f"Unknown transform mode for {name}: {mode}")
return out.astype(np.float32)

@ -11,6 +11,8 @@ params.bin因此这里的二进制布局必须与求解器端严格一致。
from __future__ import annotations
# pylint: disable=import-error,invalid-name,too-many-return-statements,too-many-locals,no-member,broad-exception-caught,import-outside-toplevel
import os
import struct
from dataclasses import dataclass, asdict
@ -104,10 +106,10 @@ class Params:
sch = self.schedule.clipped(int(cfg.get("schedule", "max_points", default=512)))
if not sch.validate():
raise ValueError("Invalid schedule extension")
nQ = len(sch.timeQ)
b += struct.pack("<II", int(sch.sectionIndex) & 0xFFFFFFFF, nQ & 0xFFFFFFFF)
b += struct.pack("<" + "d" * nQ, *map(float, sch.timeQ))
b += struct.pack("<" + "d" * nQ, *map(float, sch.q))
n_q = len(sch.timeQ)
b += struct.pack("<II", int(sch.sectionIndex) & 0xFFFFFFFF, n_q & 0xFFFFFFFF)
b += struct.pack("<" + "d" * n_q, *map(float, sch.timeQ))
b += struct.pack("<" + "d" * n_q, *map(float, sch.q))
return b
@ -232,8 +234,8 @@ def _qmc_unit(n: int, d: int, method: str, seed: int | None) -> np.ndarray:
if method == "sobol":
sampler = qmc.Sobol(d=d, scramble=True, seed=seed)
m = int(np.ceil(np.log2(max(n, 1))))
U = sampler.random_base2(m=m)
return U[:n]
unit_samples = sampler.random_base2(m=m)
return unit_samples[:n]
if method == "lhs":
sampler = qmc.LatinHypercube(d=d, seed=seed)
return sampler.random(n=n)
@ -267,11 +269,11 @@ def generate_params_dataset(cfg: Config, n_samples: int, method: str | None = No
method = (method or cfg.raw["params"].get("sampling_method", "sobol")).lower()
active_names = list(cfg.raw["params"]["active_param_names"])
log_params = set(cfg.raw["params"]["log_params"])
U = _qmc_unit(n_samples, len(active_names), method, random_seed)
unit_samples = _qmc_unit(n_samples, len(active_names), method, random_seed)
out: list[Params] = []
seen = set()
for row in U:
for row in unit_samples:
sampled_vals: Dict[str, float] = {}
for i, name in enumerate(active_names):
u = float(row[i])

@ -9,6 +9,8 @@
供数据集生成候选参数评分和调试工具复用
"""
# pylint: disable=import-error,too-many-instance-attributes,broad-exception-caught,consider-using-with,invalid-name
from __future__ import annotations
import os
@ -112,6 +114,7 @@ class CppRunner:
encoding="utf-8",
errors="ignore",
env=self._subproc_env(),
check=False,
)
if result.returncode != 0:
@ -234,6 +237,7 @@ class CppRunner:
encoding="utf-8",
errors="ignore",
env=self._subproc_env(),
check=False,
)
return bool(result.returncode == 0 and self.result_bin.exists())
@ -252,22 +256,28 @@ def read_result_bin(result_bin_path: Path) -> Optional[Dict[str, Any]]:
if magic != expected_magic or version != 1:
return None
nWells, nSteps = struct.unpack("<II", f.read(8))
t = list(struct.unpack("<" + "d" * nSteps, f.read(8 * nSteps)))
n_wells, n_steps = struct.unpack("<II", f.read(8))
t = list(struct.unpack("<" + "d" * n_steps, f.read(8 * n_steps)))
pw = []
for _ in range(nWells):
arr = list(struct.unpack("<" + "d" * nSteps, f.read(8 * nSteps)))
for _ in range(n_wells):
arr = list(struct.unpack("<" + "d" * n_steps, f.read(8 * n_steps)))
pw.append(arr)
loglog = []
for _ in range(nWells):
(nLogLog,) = struct.unpack("<I", f.read(4))
loglog_t = list(struct.unpack("<" + "d" * nLogLog, f.read(8 * nLogLog)))
loglog_p = list(struct.unpack("<" + "d" * nLogLog, f.read(8 * nLogLog)))
loglog_deriv = list(struct.unpack("<" + "d" * nLogLog, f.read(8 * nLogLog)))
for _ in range(n_wells):
(n_loglog,) = struct.unpack("<I", f.read(4))
loglog_t = list(struct.unpack("<" + "d" * n_loglog, f.read(8 * n_loglog)))
loglog_p = list(struct.unpack("<" + "d" * n_loglog, f.read(8 * n_loglog)))
loglog_deriv = list(struct.unpack("<" + "d" * n_loglog, f.read(8 * n_loglog)))
loglog.append({"t": loglog_t, "p": loglog_p, "deriv": loglog_deriv})
return {"nWells": nWells, "nSteps": nSteps, "t": t, "pw": pw, "loglog": loglog}
return {
"nWells": n_wells,
"nSteps": n_steps,
"t": t,
"pw": pw,
"loglog": loglog,
}
except Exception:
return None

@ -6,14 +6,18 @@
流量跳变和关井标记等辅助通道
编码结果由 x_sched x_sec 两部分组成前者描述整条制度在时间网格上的形态
后者描述当前 sectionIndex 及总时长等段级上下文
后者描述当前 section_index 及总时长等段级上下文
"""
# pylint: disable=import-error
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Optional, Tuple
import numpy as np
from src.common.config import Config
@ -22,80 +26,126 @@ class EncodedSchedule:
"""固定长度流量制度编码结果。
x_sched 是按时间网格展开后的多通道制度特征通常包含 q(t)累计产量流量跳变
和关井标记x_sec 是段级上下文特征描述当前 sectionIndex 在整条制度中的位置
和关井标记x_sec 是段级上下文特征描述当前 section_index 在整条制度中的位置
制度总段数和总时长等信息
"""
x_sched: np.ndarray
x_sec: np.ndarray
# pylint: disable=too-many-locals
def canonicalize_schedule(
    cfg: Config,
    time_q: List[float],
    q: List[float],
) -> Tuple[np.ndarray, np.ndarray, bool]:
    """Normalize a variable-length flow-rate schedule.

    Both inputs are truncated to ``schedule.max_points`` entries (default
    512), durations are clamped to at least 1e-12 and rates to at least 0.0.
    Depending on ``cfg.raw["schedule"]["canonicalize_for_model"]``, adjacent
    segments whose rates are nearly identical are merged and shut-in
    segments are dropped, so that one schedule does not produce different
    features merely because it was split into needlessly fine segments.

    Args:
        cfg: Configuration object exposing ``get(...)`` and a ``raw`` mapping.
        time_q: Duration of each schedule segment.
        q: Flow rate of each schedule segment.

    Returns:
        ``(dt, qq, has_shutin)`` where ``dt`` and ``qq`` are float64 arrays
        and ``has_shutin`` tells whether any rate was at or below ``q_thr``
        before merging/removal. When the two inputs differ in length after
        truncation, merging and shut-in removal are skipped and the
        mismatched arrays are returned unchanged so that the caller can
        reject the schedule.
    """
    max_points = int(cfg.get("schedule", "max_points", default=512))
    dt = np.asarray(list(map(float, time_q)), dtype=np.float64).reshape(-1)[:max_points]
    qq = np.asarray(list(map(float, q)), dtype=np.float64).reshape(-1)[:max_points]

    dt = np.maximum(dt, 1e-12)
    qq = np.maximum(qq, 0.0)

    canonical_cfg = cfg.raw["schedule"].get("canonicalize_for_model", {}) or {}
    q_threshold = float(canonical_cfg.get("q_thr", 1e-6))
    merge_same_q = bool(canonical_cfg.get("merge_same_q", True))
    merge_rel_tol = float(canonical_cfg.get("merge_rel_tol", 1e-4))
    remove_shutin = bool(canonical_cfg.get("remove_shutin", False))

    has_shutin = bool(np.any(qq <= q_threshold))

    # Pairwise processing below only makes sense when every duration has a
    # matching rate; otherwise indexing would fail or silently drop rates.
    lengths_match = dt.size == qq.size

    if merge_same_q and lengths_match and dt.size >= 2:
        # Merge adjacent segments with almost the same rate so the encoder
        # only sees meaningful rate changes.
        def close(value_a: float, value_b: float) -> bool:
            """Return True when two rates belong to the same plateau."""
            denominator = max(abs(value_a), abs(value_b), 1.0)
            return abs(value_a - value_b) / denominator <= merge_rel_tol

        new_dt: List[float] = []
        new_q: List[float] = []
        current_dt = float(dt[0])
        current_q = float(qq[0])

        for idx in range(1, dt.size):
            if close(float(qq[idx]), current_q):
                current_dt += float(dt[idx])
            else:
                new_dt.append(current_dt)
                new_q.append(current_q)
                current_dt = float(dt[idx])
                current_q = float(qq[idx])

        new_dt.append(current_dt)
        new_q.append(current_q)

        dt = np.asarray(new_dt, dtype=np.float64)
        qq = np.asarray(new_q, dtype=np.float64)

    if remove_shutin and lengths_match and dt.size >= 2:
        valid_mask = qq > q_threshold
        # Keep the schedule untouched unless at least two flowing segments remain.
        if int(np.sum(valid_mask)) >= 2:
            dt = dt[valid_mask]
            qq = qq[valid_mask]

    return dt, qq, has_shutin
def _make_u_grid(
    cfg: Config,
    total_time: float,
    n_u_points: int,
    mode: str,
) -> np.ndarray:
    """Build the fixed time grid used to encode a schedule at fixed length.

    Args:
        cfg: Configuration whose ``raw["schedule"]["obs_window"]["t_min"]``
            gives the grid start (default 1e-6).
        total_time: Grid end; raised to at least ``2 * t_min``.
        n_u_points: Number of grid points.
        mode: ``"linear"`` selects even spacing; any other value selects
            geometric (log) spacing.

    Returns:
        A float64 array of ``n_u_points`` values from ``t_min`` to the
        (clamped) total time.
    """
    obs_window_cfg = cfg.raw["schedule"].get("obs_window", {}) or {}
    t_min = float(obs_window_cfg.get("t_min", 1e-6))
    # Guarantee a non-degenerate interval even for very short schedules.
    total_time = max(float(total_time), t_min * 2.0)

    if mode == "linear":
        return np.linspace(t_min, total_time, n_u_points, dtype=np.float64)

    return np.geomspace(t_min, total_time, n_u_points, dtype=np.float64)
# pylint: disable=too-many-locals,too-many-statements
def encode_schedule_to_timegrid(
cfg: Config,
sectionIndex: int,
timeQ: List[float],
section_index: int,
time_q: List[float],
q: List[float],
n_sections: Optional[int] = None,
) -> EncodedSchedule:
@ -105,84 +155,246 @@ def encode_schedule_to_timegrid(
流量积分相邻网格流量差分关井标记等通道最后再拼接 5 section 级特征
让模型知道当前样本对应哪一个生产/关井分段而不只看到整条制度形状
"""
dt, qq, has_shutin = canonicalize_schedule(cfg, timeQ, q)
dt, qq, has_shutin = canonicalize_schedule(
cfg,
time_q,
q,
)
if dt.size < 2 or qq.size != dt.size:
raise ValueError("invalid schedule after canonicalize")
sec = int(max(1, sectionIndex))
N = int(n_sections) if n_sections is not None else int(dt.size)
N = max(N, int(dt.size))
T_total = float(np.sum(dt))
enc = cfg.raw["timegrid_encoding"]
Nu = int(enc.get("n_u_points", 256))
grid_mode = str(enc.get("grid", "log")).lower()
include_cum = bool(enc.get("include_cum", True))
include_dq = bool(enc.get("include_dq", True))
include_shutin = bool(enc.get("include_shutin", False))
q_eps = float(enc.get("q_eps", 1e-12))
t_edges = np.concatenate([[0.0], np.cumsum(dt)], axis=0)
u = _make_u_grid(cfg, T_total=T_total, Nu=Nu, mode=("linear" if grid_mode == "linear" else "log"))
idx = np.searchsorted(t_edges[1:], u, side="right")
idx = np.clip(idx, 0, dt.size - 1)
q_u = qq[idx].astype(np.float64)
# q(t) 是主通道cum(t)、dq(t) 和关井标志是可选辅助通道,
# 用于帮助代理模型区分不同流量制度形态。
section_index = int(max(1, section_index))
n_sections_total = (
int(n_sections)
if n_sections is not None
else int(dt.size)
)
n_sections_total = max(n_sections_total, int(dt.size))
total_time = float(np.sum(dt))
encoding_cfg = cfg.raw["timegrid_encoding"]
n_u_points = int(encoding_cfg.get("n_u_points", 256))
grid_mode = str(
encoding_cfg.get("grid", "log"),
).lower()
include_cum = bool(
encoding_cfg.get("include_cum", True),
)
include_dq = bool(
encoding_cfg.get("include_dq", True),
)
include_shutin = bool(
encoding_cfg.get("include_shutin", False),
)
q_eps = float(
encoding_cfg.get("q_eps", 1e-12),
)
time_edges = np.concatenate(
[[0.0], np.cumsum(dt)],
axis=0,
)
u_grid = _make_u_grid(
cfg,
total_time=total_time,
n_u_points=n_u_points,
mode=(
"linear"
if grid_mode == "linear"
else "log"
),
)
interval_idx = np.searchsorted(
time_edges[1:],
u_grid,
side="right",
)
interval_idx = np.clip(
interval_idx,
0,
dt.size - 1,
)
q_u = qq[interval_idx].astype(np.float64)
cum_u = None
if include_cum:
prefix = np.concatenate([[0.0], np.cumsum(dt * qq)], axis=0)
cum_u = prefix[idx] + (u - t_edges[:-1][idx]) * qq[idx]
prefix = np.concatenate(
[[0.0], np.cumsum(dt * qq)],
axis=0,
)
cum_u = (
prefix[interval_idx]
+ (
u_grid
- time_edges[:-1][interval_idx]
)
* qq[interval_idx]
)
dq_u = None
if include_dq:
dq_u = np.zeros_like(u, dtype=np.float64)
dq_u = np.zeros_like(
u_grid,
dtype=np.float64,
)
dq_u[1:] = np.diff(q_u)
dq_u[0] = dq_u[1] if dq_u.size > 1 else 0.0
dq_u[0] = (
dq_u[1]
if dq_u.size > 1
else 0.0
)
shut_u = None
if include_shutin:
thr = float(enc.get("shutin_thr", 1e-6))
shut_u = (q_u <= thr).astype(np.float64)
norm_mode = str(enc.get("normalize_mode", "global")).lower()
if norm_mode == "per_sample":
q_scale = max(float(np.max(q_u)), q_eps)
cum_scale = max(float(cum_u.max()) if cum_u is not None else 1.0, q_eps)
shutin_thr = float(
encoding_cfg.get("shutin_thr", 1e-6),
)
shut_u = (
q_u <= shutin_thr
).astype(np.float64)
normalize_mode = str(
encoding_cfg.get(
"normalize_mode",
"global",
)
).lower()
if normalize_mode == "per_sample":
q_scale = max(
float(np.max(q_u)),
q_eps,
)
cum_scale = max(
(
float(cum_u.max())
if cum_u is not None
else 1.0
),
q_eps,
)
else:
q_scale = float(enc.get("q_global_max", 1.0))
cum_scale = float(enc.get("cum_global_max", 1.0))
q_scale = float(
encoding_cfg.get(
"q_global_max",
1.0,
)
)
cum_scale = float(
encoding_cfg.get(
"cum_global_max",
1.0,
)
)
q_u = q_u / max(q_scale, q_eps)
if cum_u is not None:
cum_u = cum_u / max(cum_scale, q_eps)
cum_u = cum_u / max(
cum_scale,
q_eps,
)
if dq_u is not None:
dq_u = dq_u / max(q_scale, q_eps)
dq_u = dq_u / max(
q_scale,
q_eps,
)
channels = [q_u]
chans = [q_u]
if include_cum:
chans.append(cum_u if cum_u is not None else np.zeros_like(q_u))
channels.append(
cum_u
if cum_u is not None
else np.zeros_like(q_u)
)
if include_dq:
chans.append(dq_u if dq_u is not None else np.zeros_like(q_u))
channels.append(
dq_u
if dq_u is not None
else np.zeros_like(q_u)
)
if include_shutin:
chans.append(shut_u if shut_u is not None else np.zeros_like(q_u))
channels.append(
shut_u
if shut_u is not None
else np.zeros_like(q_u)
)
x_grid = np.stack(
channels,
axis=1,
)
X = np.stack(chans, axis=1)
x_sched = X.reshape(-1).astype(np.float32)
x_sched = x_grid.reshape(-1).astype(np.float32)
# 时间网格之外额外拼接 5 个流动段级别特征,
# 用于告诉模型当前预测的是哪个流动段对应的双对数曲线。
n_sec = int(max(1, N))
sec_clamped = int(np.clip(sec, 1, n_sec))
section_pos = float((sec_clamped - 1) / max(n_sec - 1, 1))
sectionIndex_norm = float(sec_clamped / max(n_sec, 1))
n_sections_norm = float(n_sec / max(12, 1))
log_T_total = float(np.log(max(T_total, 1e-12)))
n_sections_total = int(
max(1, n_sections_total),
)
section_clamped = int(
np.clip(
section_index,
1,
n_sections_total,
)
)
section_pos = float(
(section_clamped - 1)
/ max(n_sections_total - 1, 1)
)
section_index_norm = float(
section_clamped
/ max(n_sections_total, 1)
)
n_sections_norm = float(
n_sections_total
/ max(12, 1)
)
log_total_time = float(
np.log(max(total_time, 1e-12))
)
x_sec = np.array(
[sectionIndex_norm, section_pos, n_sections_norm, float(has_shutin), log_T_total],
[
section_index_norm,
section_pos,
n_sections_norm,
float(has_shutin),
log_total_time,
],
dtype=np.float32,
)
return EncodedSchedule(x_sched=x_sched, x_sec=x_sec)
return EncodedSchedule(
x_sched=x_sched,
x_sec=x_sec,
)

@ -9,6 +9,8 @@
sectionIndex 分组分析模型误差
"""
# pylint: disable=too-many-locals,import-error
from __future__ import annotations
from typing import List

@ -12,10 +12,28 @@ ForwardSurrogate 输入标准化后的物理参数特征和可选的流量制度
from __future__ import annotations
# pylint: disable=import-error,duplicate-code,too-many-arguments,too-many-positional-arguments
from dataclasses import dataclass, field
import torch
import torch.nn as nn
from torch import nn
@dataclass(slots=True)
class ForwardSurrogateConfig:
"""ForwardSurrogate 的结构配置。
使用配置对象可以避免模型构造函数参数过多同时让训练脚本中的超参数更集中
"""
param_dim: int
schedule_dim: int
curve_dim: int
hidden_dim: int = 128
fusion_hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
dropout: float = 0.0
use_schedule: bool = True
def build_mlp(
@ -26,14 +44,15 @@ def build_mlp(
) -> nn.Sequential:
"""按隐藏层列表搭建 Linear-ReLU-Dropout 组成的多层感知机。"""
layers: list[nn.Module] = []
prev = in_dim
for h in hidden_dims:
layers.append(nn.Linear(prev, h))
prev_dim = in_dim
for hidden_dim in hidden_dims:
layers.append(nn.Linear(prev_dim, hidden_dim))
layers.append(nn.ReLU())
if dropout > 0:
layers.append(nn.Dropout(dropout))
prev = h
layers.append(nn.Linear(prev, out_dim))
prev_dim = hidden_dim
layers.append(nn.Linear(prev_dim, out_dim))
return nn.Sequential(*layers)
@ -53,7 +72,6 @@ class ScheduleEncoder(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""把流量制度统计特征映射到与参数分支同宽度的隐藏表示。"""
# 该分支只处理制度向量,便于后续与地层参数特征拼接融合。
return self.net(x)
@ -73,7 +91,6 @@ class ParamEncoder(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""把变换后的地层和井筒参数映射为隐藏表示。"""
# 参数特征通常来自 log/asinh 等尺度变换,先编码再与制度分支融合。
return self.net(x)
@ -82,19 +99,21 @@ class ForwardSurrogate(nn.Module):
输入:
params_x: 标准化后的物理参数特征形状 [B, param_dim]
schedule_x: 标准化后的流量制度向量形状 [B, schedule_dim] use_schedule=False
时该输入可为空
schedule_x: 标准化后的流量制度向量形状 [B, schedule_dim]
use_schedule=False 时该输入可为空
输出:
curve_pred: 形状 [B, curve_dim] log_pressurelog_derivativeslope 三段
顺序拼接curve_dim 必须能被 3 整除以便每段拥有相同时间点数
curve_pred: 形状 [B, curve_dim] log_pressurelog_derivativeslope
三段顺序拼接curve_dim 必须能被 3 整除以便每段拥有相同时间点数
"""
def __init__(
self,
param_dim: int,
schedule_dim: int,
curve_dim: int,
config: ForwardSurrogateConfig | None = None,
*,
param_dim: int | None = None,
schedule_dim: int | None = None,
curve_dim: int | None = None,
hidden_dim: int = 128,
fusion_hidden_dims: list[int] | None = None,
dropout: float = 0.0,
@ -102,71 +121,168 @@ class ForwardSurrogate(nn.Module):
):
"""构建参数分支、可选流量制度分支、融合主干和三组曲线输出头。"""
super().__init__()
config = self._coerce_config(
config=config,
param_dim=param_dim,
schedule_dim=schedule_dim,
curve_dim=curve_dim,
hidden_dim=hidden_dim,
fusion_hidden_dims=fusion_hidden_dims,
dropout=dropout,
use_schedule=use_schedule,
)
self._validate_config(config)
self.config = config
self.encoders = self._build_encoders()
self.trunk = self._build_trunk()
self.heads = self._build_heads()
@staticmethod
def _coerce_config(
config: ForwardSurrogateConfig | None,
*,
param_dim: int | None,
schedule_dim: int | None,
curve_dim: int | None,
hidden_dim: int,
fusion_hidden_dims: list[int] | None,
dropout: float,
use_schedule: bool,
) -> ForwardSurrogateConfig:
"""兼容配置对象式构造和旧版关键字参数式构造。"""
if config is not None:
return config
if param_dim is None or schedule_dim is None or curve_dim is None:
raise TypeError(
"ForwardSurrogate requires either a ForwardSurrogateConfig or "
"param_dim, schedule_dim and curve_dim keyword arguments"
)
return ForwardSurrogateConfig(
param_dim=int(param_dim),
schedule_dim=int(schedule_dim),
curve_dim=int(curve_dim),
hidden_dim=int(hidden_dim),
fusion_hidden_dims=fusion_hidden_dims or [256, 256],
dropout=float(dropout),
use_schedule=bool(use_schedule),
)
if curve_dim % 3 != 0:
raise ValueError(f"curve_dim={curve_dim} 不能被 3 整除;期望为 pressure/derivative/slope 三段")
@property
def curve_dim(self) -> int:
"""曲线拼接后的总维度。"""
return self.config.curve_dim
if fusion_hidden_dims is None:
fusion_hidden_dims = [256, 256]
@property
def part_dim(self) -> int:
"""压力、导数和 slope 每一段的时间点数量。"""
return self.config.curve_dim // 3
self.curve_dim = curve_dim
self.part_dim = curve_dim // 3
self.use_schedule = bool(use_schedule)
@property
def use_schedule(self) -> bool:
"""是否启用流量制度分支。"""
return bool(self.config.use_schedule)
# 参数和流量制度的物理含义与尺度差异较大,因此采用两个分支分别编码。
self.param_encoder = ParamEncoder(param_dim, hidden_dim, dropout=dropout)
@staticmethod
def _validate_config(config: ForwardSurrogateConfig) -> None:
"""检查模型配置是否满足网络结构约束。"""
if config.curve_dim % 3 != 0:
msg = (
f"curve_dim={config.curve_dim} 不能被 3 整除;"
"期望为 pressure/derivative/slope 三段"
)
raise ValueError(msg)
if not config.fusion_hidden_dims:
raise ValueError("fusion_hidden_dims 不能为空")
def _build_encoders(self) -> nn.ModuleDict:
"""构建参数分支和可选流量制度分支。"""
encoders = nn.ModuleDict(
{
"param": ParamEncoder(
self.config.param_dim,
self.config.hidden_dim,
dropout=self.config.dropout,
)
}
)
if self.use_schedule:
self.schedule_encoder = ScheduleEncoder(schedule_dim, hidden_dim, dropout=dropout)
trunk_in_dim = hidden_dim * 2
else:
self.schedule_encoder = None
trunk_in_dim = hidden_dim
trunk_out_dim = fusion_hidden_dims[-1]
self.trunk = build_mlp(
encoders["schedule"] = ScheduleEncoder(
self.config.schedule_dim,
self.config.hidden_dim,
dropout=self.config.dropout,
)
return encoders
def _build_trunk(self) -> nn.Sequential:
"""构建融合主干网络。"""
trunk_in_dim = self.config.hidden_dim * 2 if self.use_schedule else self.config.hidden_dim
trunk_out_dim = self.config.fusion_hidden_dims[-1]
return build_mlp(
in_dim=trunk_in_dim,
hidden_dims=fusion_hidden_dims,
hidden_dims=self.config.fusion_hidden_dims,
out_dim=trunk_out_dim,
dropout=dropout,
dropout=self.config.dropout,
)
# 压力曲线拆成 level + centered shape
# level 学习整体纵向偏移shape 学习局部曲线形态。
self.pressure_level_head = build_mlp(
in_dim=trunk_out_dim,
hidden_dims=[128],
out_dim=1,
dropout=dropout,
)
self.pressure_shape_head = build_mlp(
in_dim=trunk_out_dim,
hidden_dims=[128],
out_dim=self.part_dim,
dropout=dropout,
def _build_heads(self) -> nn.ModuleDict:
"""构建压力、导数和 slope 输出头。"""
trunk_out_dim = self.config.fusion_hidden_dims[-1]
return nn.ModuleDict(
{
"pressure_level": self._build_single_head(trunk_out_dim, 1),
"pressure_shape": self._build_single_head(trunk_out_dim, self.part_dim),
"derivative_level": self._build_single_head(trunk_out_dim, 1),
"derivative_shape": self._build_single_head(trunk_out_dim, self.part_dim),
"slope": self._build_single_head(trunk_out_dim, self.part_dim),
}
)
# 导数曲线同样拆分为 level + shape因为平台、谷值和过渡段
# 对自动拟合筛选非常重要。
self.derivative_level_head = build_mlp(
in_dim=trunk_out_dim,
def _build_single_head(self, in_dim: int, out_dim: int) -> nn.Sequential:
"""构建一个曲线输出头。"""
return build_mlp(
in_dim=in_dim,
hidden_dims=[128],
out_dim=1,
dropout=dropout,
)
self.derivative_shape_head = build_mlp(
in_dim=trunk_out_dim,
hidden_dims=[128],
out_dim=self.part_dim,
dropout=dropout,
out_dim=out_dim,
dropout=self.config.dropout,
)
# slope 是辅助输出,主要用于保持数据布局兼容。
self.slope_head = build_mlp(
in_dim=trunk_out_dim,
hidden_dims=[128],
out_dim=self.part_dim,
dropout=dropout,
@staticmethod
def _upgrade_state_dict_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""把旧版 checkpoint 键名转换为当前 ModuleDict 结构键名。"""
prefix_map = {
"param_encoder.": "encoders.param.",
"schedule_encoder.": "encoders.schedule.",
"pressure_level_head.": "heads.pressure_level.",
"pressure_shape_head.": "heads.pressure_shape.",
"derivative_level_head.": "heads.derivative_level.",
"derivative_shape_head.": "heads.derivative_shape.",
"slope_head.": "heads.slope.",
}
upgraded: dict[str, torch.Tensor] = {}
for key, value in state_dict.items():
new_key = key
for old_prefix, new_prefix in prefix_map.items():
if key.startswith(old_prefix):
new_key = new_prefix + key[len(old_prefix) :]
break
upgraded[new_key] = value
return upgraded
def load_state_dict(
    self,
    state_dict: dict[str, torch.Tensor],
    strict: bool = True,
    assign: bool = False,
):
    """Load a current or legacy ForwardSurrogate checkpoint.

    The keys are first passed through ``_upgrade_state_dict_keys`` and the
    result is handed to the parent implementation together with ``strict``
    and ``assign``; the parent's return value is returned unchanged.
    """
    return super().load_state_dict(
        self._upgrade_state_dict_keys(state_dict),
        strict=strict,
        assign=assign,
    )
@staticmethod
@ -174,40 +290,58 @@ class ForwardSurrogate(nn.Module):
"""去除每个样本 shape 分支的均值,让 level 分支专门学习整体偏移。"""
return x - x.mean(dim=1, keepdim=True)
def _encode_features(
    self,
    params_x: torch.Tensor,
    schedule_x: torch.Tensor | None,
) -> torch.Tensor:
    """Encode parameters and schedule separately, then fuse in latent space.

    Returns the parameter features alone when the schedule branch is
    disabled; otherwise the parameter and schedule features concatenated
    along the last dimension.

    Raises:
        ValueError: If the schedule branch is enabled but ``schedule_x`` is
            None.
    """
    param_feat = self.encoders["param"](params_x)
    if not self.use_schedule:
        return param_feat

    if schedule_x is None:
        raise ValueError("use_schedule=True但 forward 没有传入 schedule_x")

    schedule_feat = self.encoders["schedule"](schedule_x)
    return torch.cat([param_feat, schedule_feat], dim=-1)

def _predict_level_shape(
    self,
    trunk_feat: torch.Tensor,
    level_head: str,
    shape_head: str,
) -> torch.Tensor:
    """Produce one curve part as ``level + centered shape``.

    ``level_head`` and ``shape_head`` are keys into ``self.heads``. The
    shape output is passed through ``center_shape`` before the level is
    added.
    """
    level = self.heads[level_head](trunk_feat)
    shape = self.center_shape(self.heads[shape_head](trunk_feat))
    return level + shape

def forward(
    self,
    params_x: torch.Tensor,
    schedule_x: torch.Tensor | None = None,
) -> torch.Tensor:
    """Run one forward prediction.

    The fused encoder features go through the trunk. Pressure and
    derivative are each produced by a level head plus a centered shape
    head, while slope comes directly from its own head. The three parts
    are concatenated along dimension 1 in the order pressure, derivative,
    slope.
    """
    fused_feat = self._encode_features(params_x, schedule_x)
    trunk_feat = self.trunk(fused_feat)

    pressure_pred = self._predict_level_shape(
        trunk_feat,
        "pressure_level",
        "pressure_shape",
    )
    derivative_pred = self._predict_level_shape(
        trunk_feat,
        "derivative_level",
        "derivative_shape",
    )
    slope_pred = self.heads["slope"](trunk_feat)

    return torch.cat([pressure_pred, derivative_pred, slope_pred], dim=1)

@ -11,12 +11,30 @@ TimeConditionedSurrogate 不一次性输出完整曲线,而是把“物理参
from __future__ import annotations
# pylint: disable=import-error,duplicate-code,too-many-arguments,too-many-positional-arguments
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch import nn
@dataclass(frozen=True)
class TimeConditionedSurrogateConfig:
    """Immutable structural configuration of the time-conditioned surrogate."""

    param_dim: int
    schedule_dim: int
    time_dim: int
    hidden_dim: int = 256
    n_blocks: int = 4
    dropout: float = 0.05
    use_schedule: bool = True
class ResidualBlock(nn.Module):
"""时间条件模型使用的全连接残差块,用于在较深网络中稳定传播特征。"""
def __init__(self, dim: int, dropout: float = 0.0):
"""构造两层全连接残差块,并根据 dropout 参数决定是否启用随机失活。"""
super().__init__()
@ -45,49 +63,97 @@ class TimeConditionedSurrogate(nn.Module):
def __init__(
self,
param_dim: int,
schedule_dim: int,
time_dim: int,
config: TimeConditionedSurrogateConfig | None = None,
*,
param_dim: int | None = None,
schedule_dim: int | None = None,
time_dim: int | None = None,
hidden_dim: int = 256,
n_blocks: int = 4,
dropout: float = 0.05,
use_schedule: bool = True,
):
"""输入维度和隐藏层宽度组装时间条件代理模型的编码、融合和输出层。"""
"""配置组装时间条件代理模型的编码、融合和输出层。"""
super().__init__()
self.use_schedule = bool(use_schedule)
self.param_encoder = nn.Sequential(
nn.Linear(param_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
)
self.time_encoder = nn.Sequential(
nn.Linear(time_dim, hidden_dim // 2),
nn.LayerNorm(hidden_dim // 2),
nn.GELU(),
config = self._coerce_config(
config=config,
param_dim=param_dim,
schedule_dim=schedule_dim,
time_dim=time_dim,
hidden_dim=hidden_dim,
n_blocks=n_blocks,
dropout=dropout,
use_schedule=use_schedule,
)
hidden_dim = int(config.hidden_dim)
time_hidden_dim = hidden_dim // 2
self.use_schedule = bool(config.use_schedule)
self.param_encoder = self._build_encoder(config.param_dim, hidden_dim)
self.time_encoder = self._build_encoder(config.time_dim, time_hidden_dim)
if self.use_schedule:
self.schedule_encoder = nn.Sequential(
nn.Linear(schedule_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.GELU(),
)
fusion_dim = hidden_dim * 2 + hidden_dim // 2
self.schedule_encoder = self._build_encoder(config.schedule_dim, hidden_dim)
fusion_dim = hidden_dim * 2 + time_hidden_dim
else:
self.schedule_encoder = None
fusion_dim = hidden_dim + hidden_dim // 2
fusion_dim = hidden_dim + time_hidden_dim
self.input_proj = nn.Sequential(
nn.Linear(fusion_dim, hidden_dim),
nn.GELU(),
)
self.blocks = nn.Sequential(*[ResidualBlock(hidden_dim, dropout=dropout) for _ in range(int(n_blocks))])
self.blocks = nn.Sequential(
*[
ResidualBlock(hidden_dim, dropout=config.dropout)
for _ in range(int(config.n_blocks))
]
)
self.head = nn.Sequential(
nn.LayerNorm(hidden_dim),
nn.Linear(hidden_dim, hidden_dim // 2),
nn.Linear(hidden_dim, time_hidden_dim),
nn.GELU(),
nn.Linear(time_hidden_dim, 2),
)
@staticmethod
def _coerce_config(
config: TimeConditionedSurrogateConfig | None,
*,
param_dim: int | None,
schedule_dim: int | None,
time_dim: int | None,
hidden_dim: int,
n_blocks: int,
dropout: float,
use_schedule: bool,
) -> TimeConditionedSurrogateConfig:
"""兼容配置对象式构造和旧版关键字参数式构造。"""
if config is not None:
return config
if param_dim is None or schedule_dim is None or time_dim is None:
raise TypeError(
"TimeConditionedSurrogate requires either a TimeConditionedSurrogateConfig "
"or param_dim, schedule_dim and time_dim keyword arguments"
)
return TimeConditionedSurrogateConfig(
param_dim=int(param_dim),
schedule_dim=int(schedule_dim),
time_dim=int(time_dim),
hidden_dim=int(hidden_dim),
n_blocks=int(n_blocks),
dropout=float(dropout),
use_schedule=bool(use_schedule),
)
@staticmethod
def _build_encoder(in_dim: int, out_dim: int) -> nn.Sequential:
"""构建 Linear-LayerNorm-GELU 编码器。"""
return nn.Sequential(
nn.Linear(in_dim, out_dim),
nn.LayerNorm(out_dim),
nn.GELU(),
nn.Linear(hidden_dim // 2, 2),
)
def forward(
@ -103,16 +169,25 @@ class TimeConditionedSurrogate(nn.Module):
log_pressure log_derivative
"""
# params_x 和 schedule_x 是样本级特征time_x 是展开后的点级特征。
p = self.param_encoder(params_x)
t = self.time_encoder(time_x)
param_feat = self.param_encoder(params_x)
time_feat = self.time_encoder(time_x)
if self.use_schedule:
if schedule_x is None:
raise ValueError("use_schedule=True but schedule_x is None")
s = self.schedule_encoder(schedule_x)
x = torch.cat([p, s, t], dim=-1)
raise ValueError("use_schedule=True,但 forward 没有传入 schedule_x")
schedule_feat = self.schedule_encoder(schedule_x)
fused = torch.cat([param_feat, schedule_feat, time_feat], dim=-1)
else:
x = torch.cat([p, t], dim=-1)
fused = torch.cat([param_feat, time_feat], dim=-1)
# 融合后的特征经过残差主干,输出 log_pressure 和 log_derivative 两个通道。
x = self.input_proj(x)
x = self.blocks(x)
return self.head(x)
hidden = self.input_proj(fused)
hidden = self.blocks(hidden)
return self.head(hidden)
def build_time_conditioned_surrogate(
    config: TimeConditionedSurrogateConfig,
) -> TimeConditionedSurrogate:
    """Create a ``TimeConditionedSurrogate`` from the given configuration."""
    return TimeConditionedSurrogate(config)

File diff suppressed because it is too large Load Diff

@ -11,10 +11,13 @@ log_pressure 和 log_derivative。
from __future__ import annotations
# pylint: disable=import-error,duplicate-code
import json
import random
from dataclasses import dataclass
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
import joblib
import numpy as np
@ -23,40 +26,45 @@ import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from src.data.param_features import inverse_transform_param_features
from src.models.time_conditioned_surrogate import TimeConditionedSurrogate
from src.models.time_conditioned_surrogate import (
TimeConditionedSurrogate,
TimeConditionedSurrogateConfig,
)
from src.training.train_forward import get_part_slices, infer_curve_layout
@dataclass
class PointCurveArrays:
    """Arrays needed to build the point-wise (per time step) dataset."""

    params_x: np.ndarray
    schedule_x: np.ndarray
    time_x: np.ndarray
    curve_y: np.ndarray
    layout: dict
    # Optional per-sample weights; None when no weighting is supplied.
    sample_weight: np.ndarray | None = None
class PointCurveDataset(Dataset):
"""把完整曲线展开为逐时间点训练样本。
原始数据形状是 N 条曲线每条曲线 T 个时间点 Dataset 的长度为 N*T
__getitem__ 会根据一维索引反推出 sample_idx time_idx返回该时间点对应的
参数特征制度特征时间特征双通道目标值以及样本级权重
"""
def __init__(
self,
params_x: np.ndarray,
schedule_x: np.ndarray,
time_x: np.ndarray,
curve_y: np.ndarray,
layout: dict,
sample_weight: np.ndarray | None = None,
):
"""把完整曲线展开为逐时间点训练样本。"""
def __init__(self, arrays: PointCurveArrays):
"""保存逐点训练所需的参数、流量制度、时间特征、目标曲线和样本权重。"""
self.params_x = torch.tensor(params_x, dtype=torch.float32)
self.schedule_x = torch.tensor(schedule_x, dtype=torch.float32)
self.time_x = torch.tensor(time_x, dtype=torch.float32)
self.params_x = torch.tensor(arrays.params_x, dtype=torch.float32)
self.schedule_x = torch.tensor(arrays.schedule_x, dtype=torch.float32)
self.time_x = torch.tensor(arrays.time_x, dtype=torch.float32)
slices = get_part_slices(layout)
p = curve_y[:, slices["log_pressure"]]
d = curve_y[:, slices["log_derivative"]]
self.y = torch.tensor(np.stack([p, d], axis=-1), dtype=torch.float32)
slices = get_part_slices(arrays.layout)
pressure_y = arrays.curve_y[:, slices["log_pressure"]]
derivative_y = arrays.curve_y[:, slices["log_derivative"]]
self.y = torch.tensor(np.stack([pressure_y, derivative_y], axis=-1), dtype=torch.float32)
self.n_samples = int(self.params_x.shape[0])
self.n_time = int(self.time_x.shape[1])
sample_weight = arrays.sample_weight
if sample_weight is None:
sample_weight = np.ones((self.n_samples,), dtype=np.float32)
sample_weight = np.asarray(sample_weight, dtype=np.float32).reshape(-1)
if sample_weight.shape[0] != self.n_samples:
raise ValueError(f"sample_weight length mismatch: {sample_weight.shape[0]} != {self.n_samples}")
@ -80,30 +88,88 @@ class PointCurveDataset(Dataset):
@dataclass
class TimeModelConfig:
    """Structural settings of the time-conditioned surrogate model."""

    hidden_dim: int = 256
    n_blocks: int = 4
    dropout: float = 0.05
    use_schedule: bool = True


@dataclass
class TimeOptimConfig:
    """Epoch count and optimizer settings."""

    batch_size: int = 4096
    epochs: int = 120
    lr: float = 1.0e-3
    weight_decay: float = 1.0e-4


@dataclass
class TimeLossConfig:
    """Point-wise loss settings of the time-conditioned model."""

    w_pressure: float = 1.0
    w_derivative: float = 2.0
    huber_beta: float = 0.05


@dataclass
class RiskWeightConfig:
    """Sample-weighting settings for risk regions."""

    mode: str = "none"
    risk_weight: float = 2.5
    skin_lt_minus8_weight: float = 3.5
    weight_min: float = 1.0
    weight_max: float = 4.0


@dataclass
class TimeRuntimeConfig:
    """Runtime settings of a training run."""

    seed: int = 42
    # Evaluated once at import time: CUDA when available, otherwise CPU.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


@dataclass
class TimeConditionedTrainConfig:
    """Training configuration of the time-conditioned surrogate.

    The two paths are required; each settings group defaults to a fresh
    instance of its own dataclass.
    """

    processed_path: Path
    output_dir: Path
    runtime: TimeRuntimeConfig = field(default_factory=TimeRuntimeConfig)
    optim: TimeOptimConfig = field(default_factory=TimeOptimConfig)
    model: TimeModelConfig = field(default_factory=TimeModelConfig)
    loss: TimeLossConfig = field(default_factory=TimeLossConfig)
    risk_weight: RiskWeightConfig = field(default_factory=RiskWeightConfig)
@dataclass
class DataBundle:
    """Train/validation/test data loaders together with the input dimensions."""

    train_loader: DataLoader
    val_loader: DataLoader
    test_loader: DataLoader
    param_dim: int
    schedule_dim: int
    time_dim: int
@dataclass
class TrainArtifacts:
    """Data that has to be passed between functions during training."""

    data: dict
    curve_layout: dict
    train_weight_summary: dict[str, Any]
    bundle: DataBundle
def set_global_seed(seed: int) -> None:
"""设置 Python、NumPy 和 PyTorch 随机种子,并在 CUDA 可用时同步设置 GPU 随机种子。"""
random.seed(seed)
@ -121,215 +187,366 @@ def _smooth_l1_vector(pred: torch.Tensor, target: torch.Tensor, beta: float) ->
def _loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    loss_cfg: TimeLossConfig,
    sample_weight: torch.Tensor | None = None,
) -> torch.Tensor:
    """Compute the point-wise loss, optionally averaged with sample weights.

    Column 0 of ``pred``/``target`` is treated as pressure and column 1 as
    derivative; both use ``_smooth_l1_vector`` with ``loss_cfg.huber_beta``
    and are combined with ``w_pressure`` and ``w_derivative``.

    Returns:
        The plain mean when ``sample_weight`` is None, otherwise the
        weighted mean using weights clamped to be non-negative.
    """
    loss_p = _smooth_l1_vector(pred[:, 0], target[:, 0], beta=float(loss_cfg.huber_beta))
    loss_d = _smooth_l1_vector(pred[:, 1], target[:, 1], beta=float(loss_cfg.huber_beta))
    loss_vec = float(loss_cfg.w_pressure) * loss_p + float(loss_cfg.w_derivative) * loss_d

    if sample_weight is None:
        return loss_vec.mean()

    weight = sample_weight.to(loss_vec.device).reshape(-1).clamp_min(0.0)
    # The denominator is floored so an all-zero weight vector cannot divide by zero.
    return (loss_vec * weight).sum() / torch.clamp(weight.sum(), min=1.0e-12)
def _move_batch_to_device(
    batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    device: str,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Move every tensor of one batch to the target device.

    Args:
        batch: Five tensors in the order (params, schedule, time, target,
            sample weight).
        device: Device specifier passed to ``Tensor.to``.

    Returns:
        The same five tensors, in the same order, on ``device``.
    """
    params_x, schedule_x, time_x, target_y, sample_weight = batch
    return (
        params_x.to(device),
        schedule_x.to(device),
        time_x.to(device),
        target_y.to(device),
        sample_weight.to(device),
    )
def _evaluate(model: TimeConditionedSurrogate, loader: DataLoader, cfg: TimeConditionedTrainConfig) -> float:
"""在验证集上评估时间条件模型的平均损失。"""
model.eval()
def _forward_batch(
    model: TimeConditionedSurrogate,
    batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
    cfg: TimeConditionedTrainConfig,
    use_weight: bool,
) -> tuple[torch.Tensor, int]:
    """Run the forward pass for one batch and return its loss and size.

    Args:
        model: Surrogate called as ``model(params, time, schedule_or_None)``.
        batch: Five tensors (params, schedule, time, target, sample weight).
        cfg: Training configuration; supplies the device, the
            ``use_schedule`` switch and the loss settings.
        use_weight: When ``False`` the batch's sample weights are ignored
            and the loss is a plain mean.

    Returns:
        ``(loss, n)`` where ``n`` is the first dimension of the target tensor.
    """
    params_x, schedule_x, time_x, target_y, sample_weight = _move_batch_to_device(
        batch,
        cfg.runtime.device,
    )
    # The schedule branch is optional; the model receives None when it is disabled.
    schedule_input = schedule_x if cfg.model.use_schedule else None
    pred = model(params_x, time_x, schedule_input)
    weight = sample_weight if use_weight else None
    loss = _loss(pred, target_y, cfg.loss, sample_weight=weight)
    return loss, int(target_y.shape[0])
def _run_loader(
    model: TimeConditionedSurrogate,
    loader: DataLoader,
    cfg: TimeConditionedTrainConfig,
    optimizer: torch.optim.Optimizer | None = None,
) -> float:
    """Run one training or evaluation epoch over ``loader``.

    Passing an optimizer selects training mode: gradients are enabled,
    sample weights are applied, gradients are clipped and the optimizer
    steps after every batch. Without an optimizer the model is put in eval
    mode and the loop runs under ``torch.no_grad()`` with unweighted loss.

    Args:
        model: Surrogate model to train or evaluate.
        loader: Data loader yielding five-tensor batches.
        cfg: Training configuration forwarded to ``_forward_batch``.
        optimizer: Optimizer to step; ``None`` means evaluation only.

    Returns:
        The batch-size-weighted mean loss over the loader (0.0 for an
        empty loader).
    """
    is_train = optimizer is not None
    model.train(mode=is_train)
    total = 0.0
    total_n = 0
    grad_context = torch.enable_grad() if is_train else torch.no_grad()
    with grad_context:
        for batch in loader:
            if is_train:
                optimizer.zero_grad()
            loss, batch_size = _forward_batch(model, batch, cfg, use_weight=is_train)
            if is_train:
                loss.backward()
                # Clip the gradient norm so an occasional high-error batch cannot spike the update.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
            total += float(loss.detach().cpu()) * batch_size
            total_n += batch_size
    # max(..., 1) guards against division by zero when the loader is empty.
    return total / max(total_n, 1)
def _evaluate(
    model: TimeConditionedSurrogate,
    loader: DataLoader,
    cfg: TimeConditionedTrainConfig,
) -> float:
    """Return the mean loss of the model on a validation or test loader.

    Delegates to ``_run_loader`` without an optimizer, which selects its
    evaluation path.
    """
    return _run_loader(model, loader, cfg, optimizer=None)
def _raw_params_from_processed_split(data: dict, split: str) -> dict[str, np.ndarray]:
    """Recover the raw parameters of one split, used to build sample weights.

    The scaled features ``X_params_<split>`` are passed through
    ``scaler_params.inverse_transform`` and then through
    ``inverse_transform_param_features`` with the transform recorded in
    ``data["meta"]``.

    Args:
        data: Preprocessed dataset dictionary.
        split: Split suffix, e.g. ``"train"``.

    Returns:
        A mapping from parameter name to a float64 column. Names come from
        ``meta["param_names"]`` or a built-in default list, truncated to
        the number of available columns.
    """
    key = f"X_params_{split}"
    features = data["scaler_params"].inverse_transform(data[key])
    transform = data.get("meta", {}).get("param_feature_transform")
    raw = inverse_transform_param_features(features, transform)
    default_names = ["k", "skin", "wellboreC", "phi", "h", "Cf"]
    names = list(data.get("meta", {}).get("param_names") or default_names)
    return {
        name: raw[:, idx].astype(np.float64)
        for idx, name in enumerate(names[: raw.shape[1]])
    }
def _build_sample_weight(
    data: dict,
    cfg: TimeConditionedTrainConfig,
    split: str = "train",
) -> np.ndarray:
    """Build per-sample weights from the raw physical parameters.

    Args:
        data: Preprocessed dataset dictionary.
        cfg: Training configuration; only ``cfg.risk_weight`` is read.
        split: Split suffix whose samples are weighted.

    Returns:
        A float32 vector with one weight per sample. All ones when the
        mode is ``none``/``off``/``false`` (or empty); for ``risk_region``
        the weights are raised in the risk regions and clipped to
        ``[weight_min, weight_max]``.

    Raises:
        ValueError: If the configured mode is not recognised.
    """
    mode = str(cfg.risk_weight.mode or "none").lower()
    n_samples = int(data[f"X_params_{split}"].shape[0])
    if mode in {"none", "off", "false"}:
        return np.ones((n_samples,), dtype=np.float32)
    if mode != "risk_region":
        raise ValueError(f"Unknown sample_weight_mode={cfg.risk_weight.mode!r}")

    params = _raw_params_from_processed_split(data, split)
    weight = np.ones((n_samples,), dtype=np.float32)
    # Two overlapping regions; np.maximum keeps the larger weight where both apply.
    risk = (params["skin"] < -5.0) & (params["wellboreC"] > 0.1)
    skin_extreme = params["skin"] < -8.0
    weight[risk] = np.maximum(weight[risk], float(cfg.risk_weight.risk_weight))
    weight[skin_extreme] = np.maximum(
        weight[skin_extreme],
        float(cfg.risk_weight.skin_lt_minus8_weight),
    )
    return np.clip(
        weight,
        float(cfg.risk_weight.weight_min),
        float(cfg.risk_weight.weight_max),
    ).astype(np.float32)
def _summarize_sample_weight(sample_weight: np.ndarray) -> dict[str, Any]:
    """Summarise sample weights so the weighting strength can be inspected.

    Args:
        sample_weight: Array of weights of any shape; it is flattened.

    Returns:
        A dictionary with the ``min``, ``mean``, ``median`` and ``max`` of
        the weights as floats, plus the integer counts of weights strictly
        above 1 (``n_weight_gt_1``) and strictly below 1 (``n_weight_lt_1``).
    """
    weight = np.asarray(sample_weight, dtype=np.float32).reshape(-1)
    return {
        "min": float(np.min(weight)),
        "mean": float(np.mean(weight)),
        "median": float(np.median(weight)),
        "max": float(np.max(weight)),
        "n_weight_gt_1": int(np.sum(weight > 1.0)),
        "n_weight_lt_1": int(np.sum(weight < 1.0)),
    }
def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None:
"""训练时间条件代理模型并保存训练产物。
输入数据必须由新版预处理流程生成包含 X_time_train/val/test训练时只有训练集
打乱顺序验证和测试保持固定顺序以便复现指标最佳模型按验证损失保存最终
写出 history.json metrics.json用于查看训练趋势和测试集性能
"""
cfg.output_dir.mkdir(parents=True, exist_ok=True)
set_global_seed(int(cfg.seed))
data = joblib.load(cfg.processed_path)
def _load_processed_data(path: Path) -> dict:
    """Load the preprocessed dataset and check the time-conditioned fields.

    Args:
        path: Location of the joblib file holding the preprocessed data.

    Returns:
        The loaded dataset dictionary.

    Raises:
        KeyError: If any of ``X_time_train``, ``X_time_val`` or
            ``X_time_test`` is missing.
    """
    data = joblib.load(path)
    required = ["X_time_train", "X_time_val", "X_time_test"]
    missing = [key for key in required if key not in data]
    if missing:
        # Time-conditioned training needs per-point time features; their
        # absence indicates a mismatched preprocessing version.
        raise KeyError(f"processed dataset is missing time-conditioned fields: {missing}")
    return data
def _make_point_dataset(
    data: dict,
    split: str,
    curve_layout: dict,
    sample_weight: np.ndarray | None = None,
) -> PointCurveDataset:
    """Build the point-level dataset for one data split.

    Args:
        data: Preprocessed dataset dictionary containing
            ``X_params_<split>``, ``X_schedule_<split>``, ``X_time_<split>``
            and ``Y_curve_<split>``.
        split: Split suffix such as ``"train"``, ``"val"`` or ``"test"``.
        curve_layout: Curve layout forwarded to the dataset.
        sample_weight: Optional per-sample weights forwarded to the dataset.

    Returns:
        A ``PointCurveDataset`` wrapping the arrays of the requested split.
    """
    arrays = PointCurveArrays(
        params_x=data[f"X_params_{split}"],
        schedule_x=data[f"X_schedule_{split}"],
        time_x=data[f"X_time_{split}"],
        curve_y=data[f"Y_curve_{split}"],
        layout=curve_layout,
        sample_weight=sample_weight,
    )
    return PointCurveDataset(arrays)
val_ds = PointCurveDataset(data["X_params_val"], data["X_schedule_val"], data["X_time_val"], data["Y_curve_val"], curve_layout)
test_ds = PointCurveDataset(data["X_params_test"], data["X_schedule_test"], data["X_time_test"], data["Y_curve_test"], curve_layout)
def _build_dataloaders(
    data: dict,
    curve_layout: dict,
    train_weight: np.ndarray,
    cfg: TimeConditionedTrainConfig,
) -> DataBundle:
    """Build the training, validation and test data loaders.

    Args:
        data: Preprocessed dataset dictionary.
        curve_layout: Curve layout shared by all three datasets.
        train_weight: Per-sample weights attached to the training set only.
        cfg: Training configuration; supplies the seed and the batch size.

    Returns:
        A ``DataBundle`` with the three loaders and the parameter,
        schedule and time input dimensions read from the training arrays.
    """
    train_ds = _make_point_dataset(data, "train", curve_layout, train_weight)
    val_ds = _make_point_dataset(data, "val", curve_layout)
    test_ds = _make_point_dataset(data, "test", curve_layout)

    # A dedicated, seeded generator drives the training shuffle order.
    generator = torch.Generator()
    generator.manual_seed(int(cfg.runtime.seed))

    # Only the training set is shuffled; validation/test keep a fixed order.
    train_loader = DataLoader(
        train_ds,
        batch_size=cfg.optim.batch_size,
        shuffle=True,
        generator=generator,
    )
    val_loader = DataLoader(val_ds, batch_size=cfg.optim.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=cfg.optim.batch_size, shuffle=False)

    return DataBundle(
        train_loader=train_loader,
        val_loader=val_loader,
        test_loader=test_loader,
        param_dim=int(data["X_params_train"].shape[1]),
        schedule_dim=int(data["X_schedule_train"].shape[1]),
        time_dim=int(data["X_time_train"].shape[-1]),
    )
def _prepare_training_artifacts(cfg: TimeConditionedTrainConfig) -> TrainArtifacts:
    """Load the data, infer the curve layout, and build weights and loaders.

    Args:
        cfg: Training configuration; ``cfg.processed_path`` names the
            dataset to load.

    Returns:
        A ``TrainArtifacts`` holding the dataset, its curve layout, the
        summary of the training sample weights and the data loaders.
    """
    data = _load_processed_data(cfg.processed_path)
    curve_layout = infer_curve_layout(data)
    train_weight = _build_sample_weight(data, cfg, split="train")
    return TrainArtifacts(
        data=data,
        curve_layout=curve_layout,
        train_weight_summary=_summarize_sample_weight(train_weight),
        bundle=_build_dataloaders(data, curve_layout, train_weight, cfg),
    )
def _build_model(bundle: DataBundle, cfg: TimeConditionedTrainConfig) -> TimeConditionedSurrogate:
    """Create the surrogate model and move it to the configured device.

    The input dimensions come from ``bundle``; the architecture settings
    (``hidden_dim``, ``n_blocks``, ``dropout``, ``use_schedule``) come from
    ``cfg.model``. Both are packed into a ``TimeConditionedSurrogateConfig``
    that is handed to the model constructor.

    Args:
        bundle: Data bundle providing ``param_dim``, ``schedule_dim`` and
            ``time_dim``.
        cfg: Training configuration.

    Returns:
        The model placed on ``cfg.runtime.device``.
    """
    model_cfg = TimeConditionedSurrogateConfig(
        param_dim=bundle.param_dim,
        schedule_dim=bundle.schedule_dim,
        time_dim=bundle.time_dim,
        hidden_dim=int(cfg.model.hidden_dim),
        n_blocks=int(cfg.model.n_blocks),
        dropout=float(cfg.model.dropout),
        use_schedule=bool(cfg.model.use_schedule),
    )
    return TimeConditionedSurrogate(model_cfg).to(cfg.runtime.device)
def _build_optimizer(
    model: TimeConditionedSurrogate,
    cfg: TimeConditionedTrainConfig,
) -> torch.optim.Optimizer:
    """Create the AdamW optimizer for all model parameters.

    Args:
        model: Model whose parameters are optimised.
        cfg: Training configuration; ``cfg.optim.lr`` and
            ``cfg.optim.weight_decay`` are used.

    Returns:
        A ``torch.optim.AdamW`` instance.
    """
    return torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg.optim.lr),
        weight_decay=float(cfg.optim.weight_decay),
    )
def _print_training_config(
    cfg: TimeConditionedTrainConfig,
    artifacts: TrainArtifacts,
) -> None:
    """Print a summary of the time-conditioned training configuration.

    Args:
        cfg: Training configuration (paths, device, batch size, epochs,
            sample-weight mode).
        artifacts: Prepared training artifacts (input dimensions, dataset
            metadata, sample-weight summary).
    """
    meta = artifacts.data.get("meta", {})
    print("Time-conditioned training config:")
    print(f" processed={cfg.processed_path}")
    print(f" output_dir={cfg.output_dir}")
    print(
        f" device={cfg.runtime.device}, batch_size={cfg.optim.batch_size}, "
        f"epochs={cfg.optim.epochs}"
    )
    print(
        f" dims: param={artifacts.bundle.param_dim}, "
        f"schedule={artifacts.bundle.schedule_dim}, time={artifacts.bundle.time_dim}"
    )
    print(f" curve_time_source={meta.get('curve_time_source', 'unknown')}")
    print(
        f" sample_weight_mode={cfg.risk_weight.mode}, "
        f"sample_weight={artifacts.train_weight_summary}"
    )
def _checkpoint_payload(
    model: TimeConditionedSurrogate,
    cfg: TimeConditionedTrainConfig,
    artifacts: TrainArtifacts,
) -> dict[str, Any]:
    """Assemble the dictionary that is saved as a checkpoint.

    Besides the model weights, the payload records the input dimensions,
    the architecture settings, the curve layout, the dataset path, the
    seed, the sample-weight information and the model/loss/risk-weight
    configurations converted with ``asdict``.

    Args:
        model: Model whose ``state_dict`` is stored.
        cfg: Training configuration.
        artifacts: Prepared training artifacts.

    Returns:
        The checkpoint dictionary.
    """
    return {
        "model_state_dict": model.state_dict(),
        "param_dim": artifacts.bundle.param_dim,
        "schedule_dim": artifacts.bundle.schedule_dim,
        "time_dim": artifacts.bundle.time_dim,
        "hidden_dim": int(cfg.model.hidden_dim),
        "n_blocks": int(cfg.model.n_blocks),
        "dropout": float(cfg.model.dropout),
        "use_schedule": bool(cfg.model.use_schedule),
        "curve_layout": artifacts.curve_layout,
        "processed_path": str(cfg.processed_path),
        "seed": int(cfg.runtime.seed),
        "sample_weight_mode": str(cfg.risk_weight.mode),
        "sample_weight_summary": artifacts.train_weight_summary,
        "model_config": asdict(cfg.model),
        "loss_config": asdict(cfg.loss),
        "risk_weight_config": asdict(cfg.risk_weight),
    }
def _save_json(path: Path, payload: dict | list) -> None:
"""写出 JSON 文件。"""
path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8")
def _train_epochs(
    model: TimeConditionedSurrogate,
    cfg: TimeConditionedTrainConfig,
    artifacts: TrainArtifacts,
) -> tuple[float, Path, list[dict]]:
    """Run the training loop and save the best model by validation loss.

    Args:
        model: Model to train in place.
        cfg: Training configuration (epoch count, output directory).
        artifacts: Prepared training artifacts (loaders, checkpoint metadata).

    Returns:
        ``(best_val, best_path, history)``: the lowest validation loss
        seen, the checkpoint path, and one ``{"epoch", "train_loss",
        "val_loss"}`` record per epoch. The checkpoint file is only
        written in epochs where the validation loss improves.
    """
    optimizer = _build_optimizer(model, cfg)
    best_val = float("inf")
    best_path = cfg.output_dir / "time_conditioned_surrogate_best.pt"
    history: list[dict] = []

    for epoch in range(1, int(cfg.optim.epochs) + 1):
        train_loss = _run_loader(model, artifacts.bundle.train_loader, cfg, optimizer=optimizer)
        val_loss = _evaluate(model, artifacts.bundle.val_loader, cfg)
        history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})
        print(f"[Epoch {epoch:03d}] train={train_loss:.6f} val={val_loss:.6f}")

        if val_loss < best_val:
            best_val = val_loss
            # The payload stores the curve layout and input dimensions next to the weights.
            torch.save(_checkpoint_payload(model, cfg, artifacts), best_path)
            print(f" -> best model saved to: {best_path}")

    return best_val, best_path, history
def _write_final_outputs(
    cfg: TimeConditionedTrainConfig,
    best_val: float,
    test_loss: float,
    history: list[dict],
    artifacts: TrainArtifacts,
) -> None:
    """Save ``history.json`` and ``metrics.json`` into the output directory.

    Args:
        cfg: Training configuration; supplies the output directory and the
            configurations embedded in the metrics file.
        best_val: Best validation loss reached during training.
        test_loss: Loss measured on the test loader.
        history: Per-epoch loss records written to ``history.json``.
        artifacts: Prepared training artifacts (sample-weight summary).
    """
    _save_json(cfg.output_dir / "history.json", history)
    metrics = {
        "best_val_loss": best_val,
        "test_loss": test_loss,
        "sample_weight_mode": str(cfg.risk_weight.mode),
        "sample_weight_summary": artifacts.train_weight_summary,
        "model_config": asdict(cfg.model),
        "loss_config": asdict(cfg.loss),
        "risk_weight_config": asdict(cfg.risk_weight),
    }
    _save_json(cfg.output_dir / "metrics.json", metrics)
def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None:
    """Train the time-conditioned surrogate model and save its artifacts.

    Steps: create the output directory, seed the random generators,
    prepare data and loaders, build the model, print the configuration,
    train, reload the best checkpoint, evaluate it on the test loader and
    write ``history.json`` / ``metrics.json``.

    Args:
        cfg: Complete training configuration.
    """
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    set_global_seed(int(cfg.runtime.seed))

    artifacts = _prepare_training_artifacts(cfg)
    model = _build_model(artifacts.bundle, cfg)
    _print_training_config(cfg, artifacts)

    best_val, best_path, history = _train_epochs(model, cfg, artifacts)

    # Test metrics are computed with the best-validation weights, not the last epoch's.
    checkpoint = torch.load(best_path, map_location=cfg.runtime.device)
    model.load_state_dict(checkpoint["model_state_dict"])
    test_loss = _evaluate(model, artifacts.bundle.test_loader, cfg)

    _write_final_outputs(cfg, best_val, test_loss, history, artifacts)
    print(f"[Final] test={test_loss:.6f}")

Loading…
Cancel
Save