diff --git a/.gitignore b/.gitignore index a93ac31..bd58533 100644 --- a/.gitignore +++ b/.gitignore @@ -119,5 +119,6 @@ ML/nmWTAI-ML/data ML/nmWTAI-ML/models ML/nmWTAI-ML/results __pycache__ +.pylintrc ML/Training/Debug ML/Training/Release \ No newline at end of file diff --git a/ML/nmWTAI-ML/scripts/analyze_uq_results.py b/ML/nmWTAI-ML/scripts/analyze_uq_results.py index 3bacb2f..d616282 100644 --- a/ML/nmWTAI-ML/scripts/analyze_uq_results.py +++ b/ML/nmWTAI-ML/scripts/analyze_uq_results.py @@ -1,25 +1,19 @@ -"""分析集成代理模型的不确定性评估结果。 - -本脚本读取 `evaluate_forward_ensemble.py` 导出的逐样本误差与不确定性 CSV, -统计“剔除高不确定性样本后剩余样本误差是否下降”、高误差样本是否被不确定性捕获, -并输出汇总 CSV、风险样本清单以及保留率曲线图。 +""" +分析集成代理模型不确定性结果, +评估 uncertainty 是否能够识别高误差样本, +并生成 retention curve、风险样本统计和 summary 输出。 """ from __future__ import annotations -# 标准库导入: -# argparse 用于解析命令行参数; -# csv/json 用于读写分析结果; -# sys/Path 用于处理项目路径和文件路径。 import argparse import csv import json import sys +from dataclasses import dataclass from pathlib import Path +from typing import Any -# 第三方库导入: -# matplotlib 用于绘制不确定性筛选曲线; -# numpy 用于数组计算、排序、百分位数等统计操作。 import matplotlib.pyplot as plt import numpy as np @@ -32,47 +26,91 @@ ROOT = Path(__file__).resolve().parents[1] # 当前脚本虽然没有直接导入内部模块,但保留该设置可以兼容项目结构。 sys.path.append(str(ROOT)) +Row = dict[str, str] +OutputRow = dict[str, Any] -def parse_args() -> argparse.Namespace: - """解析不确定性评估结果 CSV、输出目录以及分位数分箱参数。""" - # 创建命令行参数解析器。 - # 本脚本主要用于分析集成代理模型输出的不确定性是否能够辅助识别高误差样本。 - parser = argparse.ArgumentParser(description="Analyze ensemble UQ outputs for fallback usefulness") - # 输入 CSV:通常来自集成不确定性评估阶段生成的 sample_uncertainty_metrics.csv。 - # 如果不指定,则会根据 tag 自动拼接默认路径。 - parser.add_argument("--input-csv", type=str, default=None, help="Path to sample_uncertainty_metrics.csv") +@dataclass(frozen=True) +class Metrics: + """保存样本级误差和不确定性指标,减少 main 函数中的局部变量数量。""" + + rmse: np.ndarray + mae: np.ndarray + r2: np.ndarray + unc: np.ndarray + + +@dataclass(frozen=True) +class UncertaintyThresholds: + """保存低/高不确定性阈值。""" + + low: float + 
high: float + - # 输出目录:用于保存 summary、筛选曲线 CSV、风险样本 CSV 和图像。 - # 如果不指定,则会根据 tag 自动生成默认输出目录。 - parser.add_argument("--output-dir", type=str, default=None, help="Optional output directory") +@dataclass(frozen=True) +class RiskRows: + """保存几类需要导出的风险样本。""" - # 实验标签:用于区分不同实验、不同数据集或不同代理模型配置。 - parser.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag") + top_uncertain: list[Row] + high_error_high_unc: list[Row] + high_error_low_unc: list[Row] - # 不确定性剔除比例。 - # 例如 0,5,10 表示分别剔除不确定性最高的 0%、5%、10% 样本后统计保留样本误差。 + +def parse_args() -> argparse.Namespace: + """解析不确定性评估结果 CSV、输出目录以及分位数分箱参数。""" + parser = argparse.ArgumentParser( + description="Analyze ensemble UQ outputs for fallback usefulness" + ) + + parser.add_argument( + "--input-csv", + type=str, + default=None, + help="Path to sample_uncertainty_metrics.csv", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Optional output directory", + ) + parser.add_argument( + "--tag", + type=str, + default="family_random_50k", + help="Experiment tag", + ) parser.add_argument( "--quantiles", type=str, default="0,5,10,20,30,40,50", help="Comma-separated uncertainty removal percentages", ) - - # 导出不确定性最高的 Top-K 样本,方便后续查看这些样本是否确实更容易预测失败。 - parser.add_argument("--top-k", type=int, default=100, help="Top-K risky samples to export") - - # 高误差样本阈值。 - # overall_rmse 大于等于该值的样本会被视为高误差样本。 - parser.add_argument("--high-error-rmse", type=float, default=1.0, help="High-error RMSE threshold") - - # 低不确定性分位点。 - # 例如 50 表示不确定性低于中位数的样本被视为低不确定性样本。 - parser.add_argument("--low-unc-quantile", type=float, default=50.0, help="Low uncertainty quantile") - - # 高不确定性分位点。 - # 例如 90 表示不确定性高于第 90 百分位的样本被视为高不确定性样本。 - parser.add_argument("--high-unc-quantile", type=float, default=90.0, help="High uncertainty quantile") + parser.add_argument( + "--top-k", + type=int, + default=100, + help="Top-K risky samples to export", + ) + parser.add_argument( + "--high-error-rmse", + type=float, + default=1.0, + 
help="High-error RMSE threshold", + ) + parser.add_argument( + "--low-unc-quantile", + type=float, + default=50.0, + help="Low uncertainty quantile", + ) + parser.add_argument( + "--high-unc-quantile", + type=float, + default=90.0, + help="High uncertainty quantile", + ) return parser.parse_args() @@ -81,69 +119,69 @@ def parse_quantiles(text: str) -> list[float]: """解析逗号分隔的剔除比例,用于不确定性保留曲线分析。""" values = [] - # 将形如 "0,5,10,20" 的字符串逐项拆分并转成 float。 for item in str(text).split(","): item = item.strip() - - # 允许用户在字符串中多写逗号或空格,空项直接跳过。 if not item: continue q = float(item) - - # 剔除比例必须位于 [0, 100)。 - # 不能等于 100,因为全部剔除后没有样本可用于统计误差。 if q < 0 or q >= 100: raise ValueError(f"非法 quantile 百分比: {q}") values.append(q) - # 强制包含 0%,用于记录不剔除任何样本时的全局基线结果。 if 0.0 not in values: values = [0.0] + values - # 去重并升序排列,保证后续曲线横轴顺序稳定。 return sorted(set(values)) def default_input_csv(tag: str) -> Path: """根据实验标签定位集成评估生成的 sample_uncertainty_metrics.csv。""" - # 默认输入路径约定: - # results/evaluation_{tag}_ensemble_uq/sample_uncertainty_metrics.csv return Path("results") / f"evaluation_{tag}_ensemble_uq" / "sample_uncertainty_metrics.csv" def default_output_dir(tag: str) -> Path: """根据实验标签生成当前分析脚本默认的输出目录。""" - # 默认输出路径约定: - # results/evaluation_{tag}_ensemble_uq_analysis return Path("results") / f"evaluation_{tag}_ensemble_uq_analysis" -def load_rows(csv_path: Path) -> list[dict]: +def resolve_paths(args: argparse.Namespace) -> tuple[Path, Path]: + """根据命令行参数解析输入 CSV 和输出目录。""" + tag = str(args.tag).strip() + input_csv = Path(args.input_csv) if args.input_csv else default_input_csv(tag) + output_dir = Path(args.output_dir) if args.output_dir else default_output_dir(tag) + return input_csv, output_dir + + +def load_rows(csv_path: Path) -> list[Row]: """读取 CSV 为字典行列表,保留每个样本的不确定性和误差字段。""" - # 使用 utf-8-sig 是为了兼容 Excel 或部分 Windows 工具生成的带 BOM 的 CSV。 - with open(csv_path, "r", encoding="utf-8-sig", newline="") as f: - rows = list(csv.DictReader(f)) + with open(csv_path, "r", encoding="utf-8-sig", newline="") as 
file: + rows = list(csv.DictReader(file)) - # 如果 CSV 只有表头或为空,则无法进行统计分析,直接抛出错误。 if not rows: raise ValueError(f"CSV 没有数据: {csv_path}") return rows -def as_float(rows: list[dict], key: str) -> np.ndarray: +def as_float(rows: list[Row], key: str) -> np.ndarray: """将 CSV 字段转为 float 数组,用于后续统计分析。""" - # 这里要求 CSV 中必须存在对应字段,并且字段值可以转为 float。 - # 当前脚本依赖的关键字段包括: - # overall_rmse、overall_mae、overall_r2、unc_mean_std。 return np.array([float(row[key]) for row in rows], dtype=np.float64) +def extract_metrics(rows: list[Row]) -> Metrics: + """从 CSV 行中提取样本级误差和不确定性指标。""" + return Metrics( + rmse=as_float(rows, "overall_rmse"), + mae=as_float(rows, "overall_mae"), + r2=as_float(rows, "overall_r2"), + unc=as_float(rows, "unc_mean_std"), + ) + + def safe_mean(x: np.ndarray) -> float: """计算均值;没有有效数据时返回 NaN。""" - # 当某个筛选条件过严时,可能出现空数组,此时返回 NaN 避免程序崩溃。 return float(np.mean(x)) if x.size else np.nan @@ -157,126 +195,41 @@ def safe_percentile(x: np.ndarray, q: float) -> float: return float(np.percentile(x, q)) if x.size else np.nan -def save_csv(path: Path, rows: list[dict]) -> None: +def save_csv(path: Path, rows: list[OutputRow] | list[Row]) -> None: """把字典行写入 CSV。""" - # 若 rows 为空,则不写文件。 - # 例如某次实验中不存在“高误差且低不确定性”的样本,就会走到这里。 if not rows: return - # 输出使用 utf-8-sig,方便在 Excel 中直接打开时中文不乱码。 - with open(path, "w", encoding="utf-8-sig", newline="") as f: - writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) + with open(path, "w", encoding="utf-8-sig", newline="") as file: + writer = csv.DictWriter(file, fieldnames=list(rows[0].keys())) writer.writeheader() writer.writerows(rows) -def plot_retention_curve(curve_rows: list[dict], output_path: Path) -> None: - """绘制按不确定性由高到低剔除样本后的误差变化,用来判断不确定性能否筛掉坏样本。""" - # 从曲线统计结果中取出横轴和纵轴数据。 - removed = np.array([float(row["removed_pct"]) for row in curve_rows], dtype=np.float64) - retained = np.array([float(row["retained_pct"]) for row in curve_rows], dtype=np.float64) - rmse = np.array([float(row["rmse_mean"]) for row in curve_rows], dtype=np.float64) - 
mae = np.array([float(row["mae_mean"]) for row in curve_rows], dtype=np.float64) - r2 = np.array([float(row["r2_mean"]) for row in curve_rows], dtype=np.float64) - - # 创建 1 行 3 列子图,分别观察 RMSE、MAE、R2 随保留样本比例的变化。 - fig, axes = plt.subplots(1, 3, figsize=(15, 4.5)) - - # 子图 1:保留样本比例 vs RMSE。 - # 如果不确定性有效,剔除高不确定性样本后,RMSE 理论上应下降。 - axes[0].plot(retained, rmse, marker="o") - axes[0].set_title("Retained vs RMSE") - axes[0].set_xlabel("Retained Samples (%)") - axes[0].set_ylabel("RMSE mean") - axes[0].grid(True, alpha=0.3) - - # 子图 2:保留样本比例 vs MAE。 - # MAE 对极端值不如 RMSE 敏感,可作为误差变化的补充观察指标。 - axes[1].plot(retained, mae, marker="o") - axes[1].set_title("Retained vs MAE") - axes[1].set_xlabel("Retained Samples (%)") - axes[1].set_ylabel("MAE mean") - axes[1].grid(True, alpha=0.3) - - # 子图 3:保留样本比例 vs R2。 - # 如果保留样本更容易预测,R2 通常应升高。 - axes[2].plot(retained, r2, marker="o") - axes[2].set_title("Retained vs R2") - axes[2].set_xlabel("Retained Samples (%)") - axes[2].set_ylabel("R2 mean") - axes[2].grid(True, alpha=0.3) - - # 总标题中记录剔除比例网格,便于之后回看图像时知道对应的实验设置。 - fig.suptitle( - f"Uncertainty Filtering Curves | removed grid={','.join(str(int(x)) for x in removed)}%" +def build_thresholds( + unc: np.ndarray, + low_unc_quantile: float, + high_unc_quantile: float, +) -> UncertaintyThresholds: + """根据样本不确定性分布计算低/高不确定性阈值。""" + return UncertaintyThresholds( + low=float(np.percentile(unc, low_unc_quantile)), + high=float(np.percentile(unc, high_unc_quantile)), ) - # tight_layout 用于减少子图之间的重叠;rect 为总标题预留空间。 - plt.tight_layout(rect=[0, 0, 1, 0.94]) - # 保存图像,dpi=150 对实验报告和日常查看基本够用。 - plt.savefig(output_path, dpi=150, bbox_inches="tight") - plt.close() +def build_retention_curve(metrics: Metrics, quantiles: list[float]) -> list[OutputRow]: + """构造按不确定性由高到低剔除样本后的误差统计曲线。""" + curve_rows: list[OutputRow] = [] + for removed_pct in quantiles: + threshold = float(np.percentile(metrics.unc, 100.0 - removed_pct)) + keep_mask = metrics.unc <= threshold -def main() -> None: - 
"""分析集成不确定性与真实误差的相关性,并生成分箱统计和散点图。""" - # 1. 解析命令行参数。 - args = parse_args() - - # 去除 tag 两端空格,避免路径拼接时出现隐藏错误。 - tag = str(args.tag).strip() - - # 2. 确定输入 CSV 和输出目录。 - # 若用户通过命令行指定路径,则优先使用用户路径;否则使用默认路径。 - input_csv = Path(args.input_csv) if args.input_csv is not None else default_input_csv(tag) - output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag) - - # 确保输出目录存在;parents=True 允许自动创建多级目录。 - output_dir.mkdir(parents=True, exist_ok=True) - - # 3. 读取样本级不确定性评估结果。 - rows = load_rows(input_csv) - - # 解析用户指定的剔除比例列表。 - quantiles = parse_quantiles(args.quantiles) - - # 4. 提取关键指标。 - # overall_rmse / overall_mae / overall_r2 是每个样本的预测误差表现; - # unc_mean_std 是集成模型输出的平均标准差,可理解为该样本的预测不确定性。 - rmse = as_float(rows, "overall_rmse") - mae = as_float(rows, "overall_mae") - r2 = as_float(rows, "overall_r2") - unc = as_float(rows, "unc_mean_std") - - total_n = len(rows) - - # 低/高不确定性阈值来自样本分布本身,避免手动指定绝对阈值依赖数据尺度。 - # 例如不同压力归一化方式、不同数据集或不同代理模型会改变 unc_mean_std 的绝对大小。 - low_unc_thr = float(np.percentile(unc, float(args.low_unc_quantile))) - high_unc_thr = float(np.percentile(unc, float(args.high_unc_quantile))) - - # 5. 
构造“不确定性过滤曲线”。 - # 核心思路: - # 按不确定性从高到低剔除一部分样本; - # 对剩余样本重新计算 RMSE/MAE/R2; - # 如果误差明显改善,说明不确定性指标可以作为 fallback 或人工复核触发条件。 - curve_rows: list[dict] = [] + kept_rmse = metrics.rmse[keep_mask] + kept_mae = metrics.mae[keep_mask] + kept_r2 = metrics.r2[keep_mask] - for removed_pct in quantiles: - # 计算当前剔除比例对应的不确定性阈值。 - # 例如 removed_pct=10 时,阈值为第 90 百分位; - # keep_mask 会保留不确定性小于等于该阈值的样本,即剔除最高 10% 不确定性样本。 - thr = float(np.percentile(unc, 100.0 - removed_pct)) - keep_mask = unc <= thr - - # 根据掩码取出保留下来的样本误差。 - kept_rmse = rmse[keep_mask] - kept_mae = mae[keep_mask] - kept_r2 = r2[keep_mask] - - # 保存当前剔除比例下的统计结果。 curve_rows.append( { "removed_pct": float(removed_pct), @@ -289,20 +242,25 @@ def main() -> None: "mae_median": safe_median(kept_mae), "r2_mean": safe_mean(kept_r2), "r2_median": safe_median(kept_r2), - "unc_threshold": thr, + "unc_threshold": threshold, } ) - # 6. 导出不确定性最高的 Top-K 样本。 - # np.argsort(-unc) 表示按 unc 从大到小排序。 - top_unc_idx = np.argsort(-unc)[: min(args.top_k, total_n)] - top_uncertain_rows = [rows[int(i)] for i in top_unc_idx] + return curve_rows + + +def select_top_uncertain_rows(rows: list[Row], unc: np.ndarray, top_k: int) -> list[Row]: + """按不确定性从大到小导出 Top-K 样本。""" + top_uncertain_indices = np.argsort(-unc)[: min(top_k, len(rows))] + return [rows[int(index)] for index in top_uncertain_indices] + - # 7. 
识别两类高误差样本: - # high_error_high_unc:模型误差大,并且不确定性也高。 - # 这类样本是“可被不确定性发现的风险样本”。 - # high_error_low_unc:模型误差大,但不确定性低。 - # 这类样本最危险,说明模型错得很自信,是后续改进代理模型时最值得重点排查的对象。 +def classify_risky_rows( + rows: list[Row], + thresholds: UncertaintyThresholds, + high_error_rmse: float, +) -> tuple[list[Row], list[Row]]: + """识别高误差高不确定性样本,以及高误差低不确定性样本。""" high_error_high_unc_rows = [] high_error_low_unc_rows = [] @@ -310,61 +268,143 @@ def main() -> None: row_rmse = float(row["overall_rmse"]) row_unc = float(row["unc_mean_std"]) - # 高误差 + 高不确定性:fallback 机制比较容易捕捉。 - if row_rmse >= float(args.high_error_rmse) and row_unc >= high_unc_thr: + if row_rmse >= high_error_rmse and row_unc >= thresholds.high: high_error_high_unc_rows.append(row) - # 高误差 + 低不确定性:说明不确定性没有给出预警,是更严重的问题。 - if row_rmse >= float(args.high_error_rmse) and row_unc <= low_unc_thr: + if row_rmse >= high_error_rmse and row_unc <= thresholds.low: high_error_low_unc_rows.append(row) - # 按 RMSE 从高到低排序,方便优先查看最严重的失败样本。 - high_error_high_unc_rows.sort(key=lambda row: float(row["overall_rmse"]), reverse=True) - high_error_low_unc_rows.sort(key=lambda row: float(row["overall_rmse"]), reverse=True) + high_error_high_unc_rows.sort( + key=lambda row: float(row["overall_rmse"]), + reverse=True, + ) + high_error_low_unc_rows.sort( + key=lambda row: float(row["overall_rmse"]), + reverse=True, + ) - # 8. 
汇总全局统计信息和关键阈值,写入 JSON。 - # JSON 更适合被后续自动化流程、实验记录系统或报告脚本读取。 - summary = { + return high_error_high_unc_rows, high_error_low_unc_rows + + +def build_risk_rows( + rows: list[Row], + metrics: Metrics, + thresholds: UncertaintyThresholds, + high_error_rmse: float, + top_k: int, +) -> RiskRows: + """构造所有需要导出的风险样本集合。""" + high_error_high_unc_rows, high_error_low_unc_rows = classify_risky_rows( + rows=rows, + thresholds=thresholds, + high_error_rmse=high_error_rmse, + ) + + return RiskRows( + top_uncertain=select_top_uncertain_rows(rows, metrics.unc, top_k), + high_error_high_unc=high_error_high_unc_rows, + high_error_low_unc=high_error_low_unc_rows, + ) + + +def build_summary( + input_csv: Path, + metrics: Metrics, + thresholds: UncertaintyThresholds, + risk_rows: RiskRows, + args: argparse.Namespace, +) -> dict[str, Any]: + """汇总全局统计信息和关键阈值,方便后续写入 JSON。""" + return { "input_csv": str(input_csv), - "n_samples": total_n, + "n_samples": int(metrics.rmse.size), "global": { - "rmse_mean": safe_mean(rmse), - "rmse_median": safe_median(rmse), - "rmse_p90": safe_percentile(rmse, 90), - "mae_mean": safe_mean(mae), - "r2_mean": safe_mean(r2), - "unc_mean": safe_mean(unc), - "unc_median": safe_median(unc), - "unc_p90": safe_percentile(unc, 90), + "rmse_mean": safe_mean(metrics.rmse), + "rmse_median": safe_median(metrics.rmse), + "rmse_p90": safe_percentile(metrics.rmse, 90), + "mae_mean": safe_mean(metrics.mae), + "r2_mean": safe_mean(metrics.r2), + "unc_mean": safe_mean(metrics.unc), + "unc_median": safe_median(metrics.unc), + "unc_p90": safe_percentile(metrics.unc, 90), }, "thresholds": { "high_error_rmse": float(args.high_error_rmse), "low_unc_quantile": float(args.low_unc_quantile), - "low_unc_threshold": low_unc_thr, + "low_unc_threshold": thresholds.low, "high_unc_quantile": float(args.high_unc_quantile), - "high_unc_threshold": high_unc_thr, + "high_unc_threshold": thresholds.high, }, "counts": { - "top_uncertain_exported": len(top_uncertain_rows), - "high_error_high_unc": 
len(high_error_high_unc_rows), - "high_error_low_unc": len(high_error_low_unc_rows), + "top_uncertain_exported": len(risk_rows.top_uncertain), + "high_error_high_unc": len(risk_rows.high_error_high_unc), + "high_error_low_unc": len(risk_rows.high_error_low_unc), }, } - # ensure_ascii=False 可以保证 JSON 文件中中文正常显示。 - with open(output_dir / "uq_fallback_summary.json", "w", encoding="utf-8") as f: - json.dump(summary, f, ensure_ascii=False, indent=2) - # 9. 保存所有分析产物。 +def plot_retention_curve(curve_rows: list[OutputRow], output_path: Path) -> None: + """绘制按不确定性由高到低剔除样本后的误差变化。""" + removed = np.array( + [float(row["removed_pct"]) for row in curve_rows], + dtype=np.float64, + ) + retained = np.array( + [float(row["retained_pct"]) for row in curve_rows], + dtype=np.float64, + ) + rmse = np.array([float(row["rmse_mean"]) for row in curve_rows], dtype=np.float64) + mae = np.array([float(row["mae_mean"]) for row in curve_rows], dtype=np.float64) + r2 = np.array([float(row["r2_mean"]) for row in curve_rows], dtype=np.float64) + + fig, axes = plt.subplots(1, 3, figsize=(15, 4.5)) + + axes[0].plot(retained, rmse, marker="o") + axes[0].set_title("Retained vs RMSE") + axes[0].set_xlabel("Retained Samples (%)") + axes[0].set_ylabel("RMSE mean") + axes[0].grid(True, alpha=0.3) + + axes[1].plot(retained, mae, marker="o") + axes[1].set_title("Retained vs MAE") + axes[1].set_xlabel("Retained Samples (%)") + axes[1].set_ylabel("MAE mean") + axes[1].grid(True, alpha=0.3) + + axes[2].plot(retained, r2, marker="o") + axes[2].set_title("Retained vs R2") + axes[2].set_xlabel("Retained Samples (%)") + axes[2].set_ylabel("R2 mean") + axes[2].grid(True, alpha=0.3) + + removed_grid = ",".join(str(int(value)) for value in removed) + fig.suptitle(f"Uncertainty Filtering Curves | removed grid={removed_grid}%") + + plt.tight_layout(rect=[0, 0, 1, 0.94]) + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + + +def save_outputs( + output_dir: Path, + curve_rows: list[OutputRow], + 
risk_rows: RiskRows, + summary: dict[str, Any], +) -> None: + """保存 JSON、CSV 和不确定性过滤曲线图。""" + with open(output_dir / "uq_fallback_summary.json", "w", encoding="utf-8") as file: + json.dump(summary, file, ensure_ascii=False, indent=2) + save_csv(output_dir / "uncertainty_filter_curve.csv", curve_rows) - save_csv(output_dir / "top_uncertain_samples.csv", top_uncertain_rows) - save_csv(output_dir / "high_error_high_unc_samples.csv", high_error_high_unc_rows) - save_csv(output_dir / "high_error_low_unc_samples.csv", high_error_low_unc_rows) + save_csv(output_dir / "top_uncertain_samples.csv", risk_rows.top_uncertain) + save_csv(output_dir / "high_error_high_unc_samples.csv", risk_rows.high_error_high_unc) + save_csv(output_dir / "high_error_low_unc_samples.csv", risk_rows.high_error_low_unc) - # 绘制并保存不确定性过滤曲线。 plot_retention_curve(curve_rows, output_dir / "uncertainty_filter_curve.png") - # 10. 在控制台打印简要结果,便于命令行运行后快速判断分析是否完成。 + +def print_summary(output_dir: Path, summary: dict[str, Any]) -> None: + """在控制台打印简要结果,便于命令行运行后快速判断分析是否完成。""" print("UQ fallback analysis complete.") print(f"Output dir: {output_dir}") print( @@ -377,8 +417,39 @@ def main() -> None: ) -# Python 脚本入口。 -# 当该文件被直接运行时执行 main(); -# 当它被其他脚本 import 时,不会自动执行分析流程。 +def main() -> None: + """分析集成不确定性与真实误差的相关性,并生成分析产物。""" + args = parse_args() + input_csv, output_dir = resolve_paths(args) + output_dir.mkdir(parents=True, exist_ok=True) + + rows = load_rows(input_csv) + metrics = extract_metrics(rows) + quantiles = parse_quantiles(args.quantiles) + thresholds = build_thresholds( + unc=metrics.unc, + low_unc_quantile=float(args.low_unc_quantile), + high_unc_quantile=float(args.high_unc_quantile), + ) + curve_rows = build_retention_curve(metrics, quantiles) + risk_rows = build_risk_rows( + rows=rows, + metrics=metrics, + thresholds=thresholds, + high_error_rmse=float(args.high_error_rmse), + top_k=int(args.top_k), + ) + summary = build_summary( + input_csv=input_csv, + metrics=metrics, + thresholds=thresholds, 
+ risk_rows=risk_rows, + args=args, + ) + + save_outputs(output_dir, curve_rows, risk_rows, summary) + print_summary(output_dir, summary) + + if __name__ == "__main__": main() diff --git a/ML/nmWTAI-ML/scripts/analyze_uq_with_metadata.py b/ML/nmWTAI-ML/scripts/analyze_uq_with_metadata.py index ef4699d..6643840 100644 --- a/ML/nmWTAI-ML/scripts/analyze_uq_with_metadata.py +++ b/ML/nmWTAI-ML/scripts/analyze_uq_with_metadata.py @@ -1,8 +1,7 @@ -"""结合原始样本元数据分析不确定性评估结果。 +""" +将集成学习不确定性量化样本指标与处理后的数据集元数据进行关联。 -在 `analyze_uq_results.py` 的全局统计之外,本脚本会把 UQ CSV 与预处理数据中的 -参数、流量制度等元数据对齐,按制度族、参数区间等维度分组,帮助定位代理模型 -在哪类数值试井样本上更容易给出高误差或高不确定性。 +该脚本为每个 UQ 指标行补充调度元数据、类别标签以及可选的源标签,然后输出分组汇总结果和困难样本。 """ from __future__ import annotations @@ -11,31 +10,91 @@ import argparse import csv import json import sys +from dataclasses import dataclass from pathlib import Path +from typing import Any import joblib import numpy as np ROOT = Path(__file__).resolve().parents[1] -sys.path.append(str(ROOT)) +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +def load_experiment_path_helpers(): + """Load project path helpers after adding project root to sys.path.""" + # pylint: disable=import-error,import-outside-toplevel + from src.common.experiment_paths import normalize_tag, processed_path_for_tag + + return normalize_tag, processed_path_for_tag + + +normalize_tag_func, processed_path_func = load_experiment_path_helpers() + + +@dataclass(frozen=True) +class AnalysisPaths: + """Resolved input and output paths used by the metadata analysis.""" + + tag: str + processed_path: Path + uq_csv: Path + output_dir: Path + -from src.common.experiment_paths import normalize_tag, processed_path_for_tag +@dataclass(frozen=True) +class ProcessedMetadata: + """Metadata arrays loaded from the processed dataset.""" + + schedule_meta_test: np.ndarray + family_name_test: np.ndarray + source_name_test: np.ndarray | None + source_id_test: np.ndarray | None + meta_names: list[str] + name_to_col: dict[str, int] 
def parse_args() -> argparse.Namespace: """解析 UQ 指标 CSV、processed 数据和输出路径,用于合并样本元数据分析。""" - parser = argparse.ArgumentParser(description="Join UQ sample metrics with saved metadata") - parser.add_argument("--processed", type=str, default=None, help="Processed dataset path") - parser.add_argument("--uq-csv", type=str, default=None, help="sample_uncertainty_metrics.csv path") - parser.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag") - parser.add_argument("--output-dir", type=str, default=None, help="Output directory") + parser = argparse.ArgumentParser( + description="Join UQ sample metrics with saved metadata" + ) + parser.add_argument( + "--processed", + type=str, + default=None, + help="Processed dataset path", + ) + parser.add_argument( + "--uq-csv", + type=str, + default=None, + help="sample_uncertainty_metrics.csv path", + ) + parser.add_argument( + "--tag", + type=str, + default="family_random_50k", + help="Experiment tag", + ) + parser.add_argument( + "--output-dir", + type=str, + default=None, + help="Output directory", + ) parser.add_argument("--high-error-rmse", type=float, default=1.0) return parser.parse_args() def default_uq_csv(tag: str) -> Path: """根据实验标签定位默认的不确定性指标 CSV。""" - return Path("results") / f"evaluation_{tag}_ensemble_uq" / "sample_uncertainty_metrics.csv" + return ( + Path("results") + / f"evaluation_{tag}_ensemble_uq" + / "sample_uncertainty_metrics.csv" + ) def default_output_dir(tag: str) -> Path: @@ -43,123 +102,269 @@ def default_output_dir(tag: str) -> Path: return Path("results") / f"evaluation_{tag}_ensemble_uq_metadata_analysis" -def save_csv(path: Path, rows: list[dict]) -> None: - """把字典行写入 CSV;当没有行时仍写出表头,方便后续脚本读取。""" +def resolve_paths(args: argparse.Namespace) -> AnalysisPaths: + """根据命令行参数和实验标签解析输入输出路径。""" + tag = normalize_tag_func(args.tag) + fallback_tag = tag or "family_random_50k" + + processed_path = ( + Path(args.processed) + if args.processed is not None + else 
processed_path_func(tag) + ) + uq_csv = ( + Path(args.uq_csv) + if args.uq_csv is not None + else default_uq_csv(fallback_tag) + ) + output_dir = ( + Path(args.output_dir) + if args.output_dir is not None + else default_output_dir(fallback_tag) + ) + + return AnalysisPaths( + tag=tag, + processed_path=processed_path, + uq_csv=uq_csv, + output_dir=output_dir, + ) + + +def save_csv(path: Path, rows: list[dict[str, Any]]) -> None: + """把字典行写入 CSV;当没有行时不写文件。""" if not rows: return + with open(path, "w", encoding="utf-8-sig", newline="") as f: writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) writer.writeheader() writer.writerows(rows) -def summarize_group(rows: list[dict], group_key: str) -> list[dict]: - """对一个样本分组计算误差、不确定性和样本数量等汇总指标。""" - groups: dict[str, list[dict]] = {} - for row in rows: - groups.setdefault(str(row[group_key]), []).append(row) +def load_uq_rows(uq_csv: Path) -> list[dict[str, str]]: + """读取 sample_uncertainty_metrics.csv。""" + with open(uq_csv, "r", encoding="utf-8-sig", newline="") as f: + rows = list(csv.DictReader(f)) - out = [] - for key, group_rows in sorted(groups.items(), key=lambda item: len(item[1]), reverse=True): - rmse = np.array([float(r["overall_rmse"]) for r in group_rows], dtype=np.float64) - unc = np.array([float(r["unc_mean_std"]) for r in group_rows], dtype=np.float64) - out.append( - { - group_key: key, - "n_samples": len(group_rows), - "rmse_mean": float(np.mean(rmse)), - "rmse_median": float(np.median(rmse)), - "unc_mean": float(np.mean(unc)), - "unc_median": float(np.median(unc)), - } - ) - return out + if not rows: + raise ValueError(f"UQ CSV 没有数据: {uq_csv}") + return rows -def main() -> None: - """把 UQ 结果与参数/制度元数据拼接,按工况区域统计误差和不确定性。""" - args = parse_args() - tag = normalize_tag(args.tag) - processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) - uq_csv = Path(args.uq_csv) if args.uq_csv is not None else default_uq_csv(tag or "family_random_50k") - output_dir = 
Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag or "family_random_50k") - output_dir.mkdir(parents=True, exist_ok=True) +def load_processed_metadata(processed_path: Path) -> ProcessedMetadata: + """读取 processed 数据中的测试集元数据。""" data = joblib.load(processed_path) - if "schedule_meta_test" not in data or "family_name_test" not in data: + required_keys = {"schedule_meta_test", "family_name_test"} + + if not required_keys.issubset(data): raise RuntimeError( - "Processed dataset does not contain schedule metadata. Re-run preprocess on a metadata-rich raw HDF5 first." + "Processed dataset does not contain schedule metadata. " + "Re-run preprocess on a metadata-rich raw HDF5 first." ) schedule_meta_test = np.asarray(data["schedule_meta_test"], dtype=np.float32) family_name_test = np.asarray(data["family_name_test"]).astype(str) - # source 字段只在混合数据集里存在;普通数据集没有时保持可选。 - source_name_test = np.asarray(data["source_name_test"]).astype(str) if "source_name_test" in data else None - source_id_test = np.asarray(data["source_id_test"]) if "source_id_test" in data else None - meta_names = data["meta"].get("schedule_meta_names") or [] + source_name_test = None + source_id_test = None + + if "source_name_test" in data: + source_name_test = np.asarray(data["source_name_test"]).astype(str) + if "source_id_test" in data: + source_id_test = np.asarray(data["source_id_test"]) + meta_names = data["meta"].get("schedule_meta_names") or [] name_to_col = {name: i for i, name in enumerate(meta_names)} - with open(uq_csv, "r", encoding="utf-8-sig", newline="") as f: - uq_rows = list(csv.DictReader(f)) + return ProcessedMetadata( + schedule_meta_test=schedule_meta_test, + family_name_test=family_name_test, + source_name_test=source_name_test, + source_id_test=source_id_test, + meta_names=meta_names, + name_to_col=name_to_col, + ) + + +def enrich_uq_rows( + uq_rows: list[dict[str, str]], + metadata: ProcessedMetadata, +) -> list[dict[str, Any]]: + """把 UQ 指标与 
family/source/schedule metadata 拼接到同一行。""" + enriched_rows: list[dict[str, Any]] = [] - enriched_rows = [] for row in uq_rows: idx = int(row["idx"]) - meta_row = schedule_meta_test[idx] - enriched = dict(row) - # UQ CSV 的 idx 对应 processed 测试集下标,因此可直接取测试集元数据补列。 - enriched["family_name"] = family_name_test[idx] - if source_name_test is not None: - enriched["source_name"] = source_name_test[idx] - if source_id_test is not None: - enriched["source_id"] = int(source_id_test[idx]) - for name, col in name_to_col.items(): + meta_row = metadata.schedule_meta_test[idx] + enriched: dict[str, Any] = dict(row) + + enriched["family_name"] = metadata.family_name_test[idx] + + if metadata.source_name_test is not None: + enriched["source_name"] = metadata.source_name_test[idx] + + if metadata.source_id_test is not None: + enriched["source_id"] = int(metadata.source_id_test[idx]) + + for name, col in metadata.name_to_col.items(): enriched[name] = float(meta_row[col]) + enriched_rows.append(enriched) - save_csv(output_dir / "uq_samples_with_metadata.csv", enriched_rows) + return enriched_rows - family_summary = summarize_group(enriched_rows, "family_name") - save_csv(output_dir / "summary_by_family.csv", family_summary) - if source_name_test is not None: - source_summary = summarize_group(enriched_rows, "source_name") - save_csv(output_dir / "summary_by_source.csv", source_summary) +def summarize_group( + rows: list[dict[str, Any]], + group_key: str, +) -> list[dict[str, Any]]: + """对一个样本分组计算误差、不确定性和样本数量等汇总指标。""" + groups: dict[str, list[dict[str, Any]]] = {} + + for row in rows: + groups.setdefault(str(row[group_key]), []).append(row) + + summaries = [] + sorted_groups = sorted( + groups.items(), + key=lambda item: len(item[1]), + reverse=True, + ) + + for key, group_rows in sorted_groups: + rmse = np.array( + [float(row["overall_rmse"]) for row in group_rows], + dtype=np.float64, + ) + unc = np.array( + [float(row["unc_mean_std"]) for row in group_rows], + dtype=np.float64, + ) + 
summaries.append( + { + group_key: key, + "n_samples": len(group_rows), + "rmse_mean": float(np.mean(rmse)), + "rmse_median": float(np.median(rmse)), + "unc_mean": float(np.mean(unc)), + "unc_median": float(np.median(unc)), + } + ) + + return summaries + + +def save_group_summaries( + output_dir: Path, + enriched_rows: list[dict[str, Any]], + metadata: ProcessedMetadata, +) -> None: + """按 family/source/n_prod 等维度保存分组统计。""" + save_csv( + output_dir / "summary_by_family.csv", + summarize_group(enriched_rows, "family_name"), + ) - if "n_prod" in name_to_col: + if metadata.source_name_test is not None: + save_csv( + output_dir / "summary_by_source.csv", + summarize_group(enriched_rows, "source_name"), + ) + + if "n_prod" in metadata.name_to_col: for row in enriched_rows: row["n_prod_group"] = int(round(float(row["n_prod"]))) - # 生产段数常对应不同制度复杂度,单独分组有助于定位 UQ 失效区域。 - n_prod_summary = summarize_group(enriched_rows, "n_prod_group") - save_csv(output_dir / "summary_by_n_prod.csv", n_prod_summary) - unc_values = np.array([float(row["unc_mean_std"]) for row in enriched_rows], dtype=np.float64) + save_csv( + output_dir / "summary_by_n_prod.csv", + summarize_group(enriched_rows, "n_prod_group"), + ) + + +def find_high_error_low_unc( + enriched_rows: list[dict[str, Any]], + high_error_rmse: float, +) -> list[dict[str, Any]]: + """筛出高误差但低不确定性的危险样本。""" + unc_values = np.array( + [float(row["unc_mean_std"]) for row in enriched_rows], + dtype=np.float64, + ) low_unc_threshold = float(np.median(unc_values)) - # 筛出“高误差但低不确定性”的样本,后续可作为困难样本补采样或诊断入口。 - high_error_low_unc = [ + + risky_rows = [ row for row in enriched_rows - if float(row["overall_rmse"]) >= float(args.high_error_rmse) + if float(row["overall_rmse"]) >= high_error_rmse and float(row["unc_mean_std"]) <= low_unc_threshold ] - high_error_low_unc.sort(key=lambda row: float(row["overall_rmse"]), reverse=True) - save_csv(output_dir / "high_error_low_unc_with_metadata.csv", high_error_low_unc) + risky_rows.sort(key=lambda row: 
float(row["overall_rmse"]), reverse=True) + + return risky_rows + +def write_json_summary( + output_dir: Path, + paths: AnalysisPaths, + metadata: ProcessedMetadata, + enriched_rows: list[dict[str, Any]], + high_error_low_unc: list[dict[str, Any]], +) -> None: + """保存 metadata join 的整体摘要信息。""" summary = { - "processed_path": str(processed_path), - "uq_csv": str(uq_csv), + "tag": paths.tag, + "processed_path": str(paths.processed_path), + "uq_csv": str(paths.uq_csv), "n_samples": len(enriched_rows), - "meta_names": meta_names, - "has_source_name": bool(source_name_test is not None), + "meta_names": metadata.meta_names, + "has_source_name": bool(metadata.source_name_test is not None), "high_error_low_unc_count": len(high_error_low_unc), } + with open(output_dir / "metadata_join_summary.json", "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) + +def run_analysis(args: argparse.Namespace) -> tuple[Path, int]: + """执行完整 UQ metadata join 分析流程。""" + paths = resolve_paths(args) + paths.output_dir.mkdir(parents=True, exist_ok=True) + + metadata = load_processed_metadata(paths.processed_path) + uq_rows = load_uq_rows(paths.uq_csv) + enriched_rows = enrich_uq_rows(uq_rows, metadata) + + save_csv(paths.output_dir / "uq_samples_with_metadata.csv", enriched_rows) + save_group_summaries(paths.output_dir, enriched_rows, metadata) + + high_error_low_unc = find_high_error_low_unc( + enriched_rows=enriched_rows, + high_error_rmse=float(args.high_error_rmse), + ) + save_csv( + paths.output_dir / "high_error_low_unc_with_metadata.csv", + high_error_low_unc, + ) + + write_json_summary( + output_dir=paths.output_dir, + paths=paths, + metadata=metadata, + enriched_rows=enriched_rows, + high_error_low_unc=high_error_low_unc, + ) + + return paths.output_dir, len(high_error_low_unc) + + +def main() -> None: + """命令行入口:拼接 UQ 结果和元数据,并输出汇总文件。""" + output_dir, risky_count = run_analysis(parse_args()) + print("UQ metadata join analysis complete.") print(f"Output 
dir: {output_dir}") - print(f"High-error low-unc count={len(high_error_low_unc)}") + print(f"High-error low-unc count={risky_count}") if __name__ == "__main__": diff --git a/ML/nmWTAI-ML/scripts/build_mixed_dataset.py b/ML/nmWTAI-ML/scripts/build_mixed_dataset.py index cef11ed..0882de1 100644 --- a/ML/nmWTAI-ML/scripts/build_mixed_dataset.py +++ b/ML/nmWTAI-ML/scripts/build_mixed_dataset.py @@ -1,32 +1,61 @@ """构建“普通样本 + 困难样本”的混合训练数据集。 脚本先调用合并逻辑把常规数据集与局部自动拟合邻域数据集合成为一个 HDF5, -再复用统一预处理流程生成模型训练所需的标准化数据文件。适合在正演代理模型 -需要兼顾全局覆盖和 PSO/自动拟合困难区域时使用。 +再复用统一预处理流程生成模型训练所需的标准化数据文件。 + +该脚本适合在正演代理模型需要同时兼顾全局覆盖样本和 PSO/自动拟合困难区域 +样本时使用。 """ from __future__ import annotations import argparse +import importlib import sys from pathlib import Path +from typing import Any, Callable -ROOT = Path(__file__).resolve().parents[1] -sys.path.append(str(ROOT)) -from scripts.merge_datasets import merge_datasets -from src.common.experiment_paths import normalize_tag, processed_path_for_tag, sample_path_for_tag -from src.data.preprocess import preprocess_dataset +ROOT = Path(__file__).resolve().parents[1] +if str(ROOT) not in sys.path: + sys.path.append(str(ROOT)) def parse_args() -> argparse.Namespace: """解析 normal/hard 两类 HDF5 的混合比例、输出路径和预处理切分参数。""" - parser = argparse.ArgumentParser(description="Build a mixed raw+processed dataset from normal and hard HDF5 pools") - parser.add_argument("--normal-input", type=str, required=True, help="Path to the normal/main .h5 dataset") - parser.add_argument("--hard-input", type=str, required=True, help="Path to the hard-targeted .h5 dataset") - parser.add_argument("--tag", type=str, default="family_random_mixed_50k", help="Experiment tag") - parser.add_argument("--output-h5", type=str, default=None, help="Optional merged raw .h5 path") - parser.add_argument("--output-processed", type=str, default=None, help="Optional processed .pkl path") + parser = argparse.ArgumentParser( + description="Build a mixed raw+processed dataset from normal and hard HDF5 pools" + ) + 
parser.add_argument( + "--normal-input", + type=str, + required=True, + help="Path to the normal/main .h5 dataset", + ) + parser.add_argument( + "--hard-input", + type=str, + required=True, + help="Path to the hard-targeted .h5 dataset", + ) + parser.add_argument( + "--tag", + type=str, + default="family_random_mixed_50k", + help="Experiment tag", + ) + parser.add_argument( + "--output-h5", + type=str, + default=None, + help="Optional merged raw .h5 path", + ) + parser.add_argument( + "--output-processed", + type=str, + default=None, + help="Optional processed .pkl path", + ) parser.add_argument("--total-samples", type=int, default=50000) parser.add_argument("--hard-ratio", type=float, default=0.30) parser.add_argument("--normal-count", type=int, default=None) @@ -40,17 +69,55 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def main() -> None: - """合并普通样本与困难样本,并立即生成对应的 processed 训练数据。""" - args = parse_args() - tag = normalize_tag(args.tag) - output_h5 = Path(args.output_h5) if args.output_h5 is not None else sample_path_for_tag(tag) +def load_project_functions() -> tuple[ + Callable[..., dict[str, Any]], + Callable[[str], str], + Callable[[str], Path], + Callable[[str], Path], + Callable[..., None], +]: + """延迟导入项目内部函数,避免 Pylint 对动态项目路径产生误报。""" + merge_module = importlib.import_module("scripts.merge_datasets") + paths_module = importlib.import_module("src.common.experiment_paths") + preprocess_module = importlib.import_module("src.data.preprocess") + + return ( + merge_module.merge_datasets, + paths_module.normalize_tag, + paths_module.processed_path_for_tag, + paths_module.sample_path_for_tag, + preprocess_module.preprocess_dataset, + ) + + +def resolve_output_paths( + args: argparse.Namespace, + tag: str, + processed_path_for_tag: Callable[[str], Path], + sample_path_for_tag: Callable[[str], Path], +) -> tuple[Path, Path]: + """根据命令行参数和实验标签确定 merged HDF5 与 processed 输出路径。""" + output_h5 = ( + Path(args.output_h5) + if args.output_h5 is not 
None + else sample_path_for_tag(tag) + ) output_processed = ( - Path(args.output_processed) if args.output_processed is not None else processed_path_for_tag(tag) + Path(args.output_processed) + if args.output_processed is not None + else processed_path_for_tag(tag) ) + return output_h5, output_processed - # 先在原始 HDF5 层面按比例抽样合并,保留 source_label 便于之后追踪样本来源。 - merge_meta = merge_datasets( + +def merge_raw_datasets( + args: argparse.Namespace, + tag: str, + output_h5: Path, + merge_datasets: Callable[..., dict[str, Any]], +) -> dict[str, Any]: + """按设定比例合并普通样本与困难样本,并返回合并过程元数据。""" + return merge_datasets( normal_input=args.normal_input, hard_input=args.hard_input, output=output_h5, @@ -65,7 +132,14 @@ def main() -> None: batch_size=args.batch_size, ) - # 合并后的原始数据立即进入同一套预处理流程,保证训练集格式与普通数据一致。 + +def build_processed_dataset( + args: argparse.Namespace, + output_h5: Path, + output_processed: Path, + preprocess_dataset: Callable[..., None], +) -> None: + """对合并后的 HDF5 数据执行统一预处理,生成模型训练用 pkl 文件。""" preprocess_dataset( input_path=output_h5, output_path=output_processed, @@ -74,10 +148,52 @@ def main() -> None: random_seed=args.seed, ) + +def print_outputs(merge_meta: dict[str, Any], output_processed: Path) -> None: + """打印混合原始数据和 processed 数据的输出位置。""" print(f"Merged raw dataset: {merge_meta['output_path']}") print(f"Merge summary: {merge_meta['summary_path']}") print(f"Processed dataset: {output_processed}") +def main() -> None: + """合并普通样本与困难样本,并立即生成对应的 processed 训练数据。""" + args = parse_args() + + ( + merge_datasets, + normalize_tag, + processed_path_for_tag, + sample_path_for_tag, + preprocess_dataset, + ) = load_project_functions() + + tag = normalize_tag(args.tag) + output_h5, output_processed = resolve_output_paths( + args=args, + tag=tag, + processed_path_for_tag=processed_path_for_tag, + sample_path_for_tag=sample_path_for_tag, + ) + + # 先在原始 HDF5 层面按比例抽样合并,保留 source_label 便于之后追踪样本来源。 + merge_meta = merge_raw_datasets( + args=args, + tag=tag, + output_h5=output_h5, + 
merge_datasets=merge_datasets, + ) + + # 合并后的原始数据立即进入同一套预处理流程,保证训练集格式与普通数据一致。 + build_processed_dataset( + args=args, + output_h5=output_h5, + output_processed=output_processed, + preprocess_dataset=preprocess_dataset, + ) + + print_outputs(merge_meta=merge_meta, output_processed=output_processed) + + if __name__ == "__main__": main() diff --git a/ML/nmWTAI-ML/scripts/compare_single_case.py b/ML/nmWTAI-ML/scripts/compare_single_case.py index c030648..8b75ee5 100644 --- a/ML/nmWTAI-ML/scripts/compare_single_case.py +++ b/ML/nmWTAI-ML/scripts/compare_single_case.py @@ -1,9 +1,23 @@ -"""单个试井样本的数值求解器与正演代理模型对比。 - -该脚本用一组指定地层/井筒参数和流量制度分别运行 C++ 数值求解器与 Python -正演代理模型,随后在压力、压力导数和斜率三段曲线上计算误差指标,绘制逐点 -对比图,并导出 JSON 汇总,便于排查单个案例的代理误差来源。 """ +数值试井代理模型 - 单案例 Solver vs Surrogate 对比脚本 + +主要功能: +1. 使用 C++ 数值求解器生成真实试井曲线 +2. 使用训练好的代理模型进行预测 +3. 对比 Solver 与 Surrogate 的输出结果 +4. 计算 RMSE / MAE / R2 等指标 +5. 绘制压力、导数、斜率三部分对比图 +6. 输出 JSON 分析结果 + +该脚本通常用于: +- 检查代理模型是否学到真实物理行为 +- 分析代理模型在哪些阶段误差较大 +- 验证不同 schedule 对模型预测的影响 +- 调试自动拟合(autofit)效果 +""" + +# pylint: disable=import-error,wrong-import-position,wrong-import-order,line-too-long, +# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments,invalid-name from __future__ import annotations @@ -36,6 +50,8 @@ from src.evaluation.autofit_objective import dual_log_objective from src.models.forward_surrogate import ForwardSurrogate +# 默认单案例参数 +# 用于没有传命令行参数时的快速测试 DEFAULT_SINGLE_CASE = { "config": "configs/data_gen_family_random.yaml", "tag": "family_random_50k", @@ -53,11 +69,7 @@ DEFAULT_SINGLE_CASE = { def parse_args() -> argparse.Namespace: - """解析单案例对比所需的路径、参数、井号和流量制度覆盖项。 - - 默认值对应一个可复现实验案例;命令行可覆盖模型 checkpoint、processed 数据、 - 地层/井筒参数以及 `timeQ/q/sectionIndex`,方便快速定位某个具体样本的误差表现。 - """ + """解析单个样本对比所需的 processed 数据、模型、样本索引和输出目录。""" parser = argparse.ArgumentParser( description="Compare solver output and surrogate prediction on a single parameter set" ) @@ -122,11 +134,8 @@ def calc_metrics( eps_range: float = 1e-3, eps_var: 
float = 1e-6, ) -> dict: - """计算一维曲线预测误差指标。 - - `eps_range` 用于避免真实曲线几乎为常数时 NRMSE 分母过小, - `eps_var` 用于避免 R2 在真实曲线方差接近 0 时失真;无效场景返回 NaN。 - """ + """计算 RMSE、MAE、Bias、NRMSE、R2 等回归指标。""" + # 残差 = 代理模型预测值 - 数值求解器真实值 err = y_pred - y_true mse = np.mean(err**2) rmse = float(np.sqrt(mse)) @@ -137,6 +146,7 @@ def calc_metrics( ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2)) ss_res = float(np.sum(err**2)) + # 曲线几乎为常数时,NRMSE 和 R2 的分母会过小,此时返回 NaN 更真实。 valid_nrmse = value_range > eps_range valid_r2 = ss_tot > eps_var @@ -154,15 +164,14 @@ def calc_metrics( def infer_curve_layout(meta: dict, curve_dim: int) -> dict: - """从预处理元数据推断曲线拼接布局。 - - 新版 processed 文件会保存 `curve_layout`;若旧数据缺失该字段,则按压力、 - 压力导数、斜率三段等长拼接的历史约定回退,保证旧实验仍可比较。 - """ + """从元数据读取曲线分段布局;旧数据没有布局时按压力/导数/斜率三等分回退。""" + # curve_layout 描述了拼接曲线的结构布局 + # 包括每一段 pressure/derivative/slope 的起止位置 curve_layout = meta.get("curve_layout") if curve_layout is not None: return curve_layout + # 兼容早期 processed 文件:没有显式 layout 时仍按 pressure/derivative/slope 三段切分。 n_time_points = curve_dim // 3 return { "n_time_points": int(n_time_points), @@ -175,11 +184,9 @@ def infer_curve_layout(meta: dict, curve_dim: int) -> dict: def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]: - """按 `curve_layout` 把一维拼接曲线拆成命名片段。 - - 返回值通常包含 `log_pressure`、`log_derivative` 和 `slope`, - 后续绘图、分段指标和自动拟合目标函数都依赖这些片段边界一致。 - """ + """按照 curve_layout 将拼接曲线拆成 log_pressure、log_derivative 和 slope 三段。""" + # parts 用于保存拆分后的不同曲线段 + # 例如 log_pressure / derivative / slope parts: dict[str, np.ndarray] = {} for part in layout["parts"]: start = int(part["start"]) @@ -189,11 +196,9 @@ def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarr def resolve_cf(args: argparse.Namespace, cfg: Config) -> float: - """解析综合压缩系数 Cf。 - - Cf 可能由命令行显式给出,也可能在配置文件的 fixed_params 中固定; - 集中处理该兜底逻辑,避免构造 `Params` 时把缺失 Cf 静默写成错误默认值。 - """ + """解析压缩系数 Cf:优先使用命令行参数,否则从配置默认参数中读取。""" + # 命令行优先级最高 + # 如果用户显式提供 Cf,则直接使用 if args.Cf is not None: return 
float(args.Cf) @@ -205,11 +210,8 @@ def resolve_cf(args: argparse.Namespace, cfg: Config) -> float: def parse_float_list(raw: str, name: str) -> list[float]: - """把命令行传入的逗号/分号/空格分隔数值串转换为浮点列表。 - - 该函数用于解析 `--timeQ` 和 `--q`,允许用户用不同分隔符快速输入制度序列; - 若没有解析出任何有效数字,则抛出带参数名的错误。 - """ + """解析输入文本、命令行或配置值,转换为后续流程可直接使用的结构。""" + # 用于解析类似 "1000,1000,1000" 的字符串输入 values: list[float] = [] for token in raw.replace(";", ",").replace(" ", ",").split(","): token = token.strip() @@ -222,11 +224,7 @@ def parse_float_list(raw: str, name: str) -> list[float]: def resolve_case_schedule(cfg: Config, args: argparse.Namespace | None = None) -> Schedule: - """构造单案例对比使用的流量制度。 - - 默认从配置文件读取 `case_schedule`;当命令行提供 `--timeQ/--q` 时优先使用覆盖值, - 并按显式 `--section-index` 或最后一段规则确定测试段。 - """ + """解析单案例使用的流量制度;命令行提供 timeQ/q 时覆盖配置中的默认制度。""" if args is not None and (args.timeQ is not None or args.q is not None or args.section_index is not None): if args.timeQ is None or args.q is None: raise ValueError("覆盖流量制度时必须同时提供 --timeQ 和 --q") @@ -236,7 +234,9 @@ def resolve_case_schedule(cfg: Config, args: argparse.Namespace | None = None) - if len(timeQ) != len(q): raise ValueError(f"--timeQ 和 --q 长度必须一致,当前分别为 {len(timeQ)} 和 {len(q)}") - # Allow the reporting/table convention with an initial "0 0" row. 
+ # 兼容报表写法:首行 0,0 只表示初始状态,不作为真实流量段传给求解器。 + # 某些报表第一行仅用于表示初始状态 + # 不是真实生产段,因此自动丢弃 if len(timeQ) >= 2 and timeQ[0] <= 0.0 and q[0] == 0.0: timeQ = timeQ[1:] q = q[1:] @@ -272,11 +272,7 @@ def resolve_case_schedule(cfg: Config, args: argparse.Namespace | None = None) - def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]: - """解析配置、processed 数据、模型 checkpoint 和输出目录路径。 - - 路径优先级为命令行显式参数,其次为实验 tag 对应的标准目录; - `--no-schedule` 只在自动推断模型路径时用于区分是否带制度分支的模型。 - """ + """根据命令行参数、实验标签和 use_schedule 开关解析配置、预处理数据、模型和输出路径。""" tag = normalize_tag(args.tag) config_path = args.config @@ -300,10 +296,9 @@ def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]: def build_params_from_args(args: argparse.Namespace, cfg: Config, schedule: Schedule) -> Params: - """把命令行中的地层/井筒参数组装为求解器和代理模型共用的 `Params`。 - - 该对象同时持有物性参数与流量制度,是数值求解、参数特征变换和图标题展示的共同数据源。 - """ + """根据命令行参数和默认值构造目标 Params 对象。""" + # 构建完整物理参数对象 + # 后续会直接送入数值求解器和代理模型 return Params( k=float(args.k), skin=float(args.skin), @@ -316,15 +311,14 @@ def build_params_from_args(args: argparse.Namespace, cfg: Config, schedule: Sche def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) -> tuple[np.ndarray, dict]: - """运行 C++ 数值求解器并抽取可与代理输出对齐的曲线。 - - 原始求解结果会先做有效性检查和清洗,再按训练数据相同的重采样规则转换为 - `log_pressure/log_derivative/slope` 拼接向量;同时返回原始 log-log 曲线供 JSON 留痕。 - """ + """调用 C++ 求解器运行一次正演,并把双对数输出重采样为模型曲线向量。""" + # 创建 C++ 求解器客户端 + # 实际底层会调用数值试井求解程序 runner = CppRunner(cfg=cfg) try: ok = runner.run_simulation(params, override_schedule=params.schedule, include_schedule=True) result = read_result_bin(runner.result_bin) if runner.result_bin.exists() else None + # 某些求解器返回码可能失败但 result.bin 已写完;只要结果可读就继续对比。 if not ok and result is None: raise RuntimeError("求解器运行失败,未生成有效结果") if not ok and result is not None: @@ -343,6 +337,9 @@ def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) - p = np.asarray(loglog["p"], dtype=np.float64) d = np.asarray(loglog["deriv"], 
dtype=np.float64) + # 求解器原始点数不一定等于模型输出维度,先清洗再重采样到训练时的固定网格。 + # 对原始求解器曲线进行清洗 + # 去除非法点、异常点、重复点等 cleaned = clean_curve_for_dataset(cfg, t, p, d) if cleaned is None: raise RuntimeError("求解器返回曲线在清洗后无效") @@ -352,6 +349,8 @@ def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) - if not valid: raise RuntimeError(f"求解器曲线未通过有效性检查: {reason}") + # 将原始不规则时间曲线重采样到固定特征网格 + # 这是代理模型训练时使用的统一输入格式 curve_feat = resample_curve_to_features(cfg, t_clean, p_clean, d_clean) raw = { "t": t_clean.tolist(), @@ -366,19 +365,13 @@ def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) - def build_schedule_vector(cfg: Config, schedule: Schedule) -> np.ndarray: - """把 `Schedule` 编码为正演代理模型的制度特征向量。 - - 这里复用训练阶段同一套编码函数,确保单案例推理时的制度特征与 processed 数据一致。 - """ + """把 Schedule 编码成正演代理模型可直接接收的流量制度特征向量。""" return build_schedule_model_vector(cfg, schedule) def load_model(checkpoint_path: Path) -> tuple[ForwardSurrogate, bool, torch.device]: - """加载正演代理模型 checkpoint 并恢复网络结构。 - - checkpoint 中保存了输入维度、隐藏层宽度、dropout 和是否使用制度分支等信息; - 按这些元数据重建模型后再加载权重,才能保证推理结构与训练结构一致。 - """ + """加载模型检查点,按保存的维度和超参数重建网络并切换到评估模式。""" + # 加载训练好的 PyTorch checkpoint checkpoint = torch.load(checkpoint_path, map_location="cpu") use_schedule = bool(checkpoint.get("use_schedule", True)) @@ -407,22 +400,22 @@ def predict_surrogate_curve( schedule: Schedule, cfg: Config, ) -> np.ndarray: - """使用代理模型预测单个参数点的反标准化曲线。 - - 参数和制度先按 processed 文件中的 transform/scaler 进入训练时的标准化空间; - 模型输出再通过 `scaler_curve.inverse_transform` 还原成可与数值求解器直接比较的曲线值。 - """ + """使用正演代理模型预测单个参数和流量制度对应的完整曲线。""" scaler_params = processed["scaler_params"] scaler_schedule = processed["scaler_schedule"] scaler_curve = processed["scaler_curve"] param_transform = param_feature_transform_from_meta(processed.get("meta", {})) + # 构建参数特征向量 + # 顺序必须与训练阶段保持一致 params_vec = np.asarray( [params.k, params.skin, params.wellboreC, params.phi, params.h, params.Cf], dtype=np.float32, ).reshape(1, -1) + # 将流量制度 schedule 编码成模型输入向量 schedule_vec = 
build_schedule_vector(cfg, schedule).reshape(1, -1) + # 输入特征必须使用训练时保存的 scaler 和参数变换,否则单案例预测会发生分布偏移。 params_x = scaler_params.transform(transform_param_features(params_vec, param_transform)).astype(np.float32) schedule_x = scaler_schedule.transform(schedule_vec).astype(np.float32) @@ -447,16 +440,14 @@ def plot_comparison( model_path: Path, use_schedule: bool, ) -> dict: - """绘制数值求解器曲线与代理预测曲线的分段对比图。 - - 左列展示每个曲线片段的真实值与预测值,右列展示逐点误差; - 图标题汇总参数、制度、模型名称和整体指标,返回的 summary 会同步写入 JSON。 - """ + """绘制数值求解器与代理模型的单案例曲线对比图。""" true_parts = split_curve_by_layout(curve_true, curve_layout) pred_parts = split_curve_by_layout(curve_pred, curve_layout) + # 整体曲线误差指标 overall = calc_metrics(curve_true, curve_pred) + # 自动拟合目标函数 + # 用于衡量双对数曲线形态是否一致 autofit = dual_log_objective(curve_true, curve_pred, curve_layout) - overall_r2_text = "nan" if np.isnan(overall["r2"]) else f"{overall['r2']:.4f}" part_names = ["log_pressure", "log_derivative", "slope"] title_map = { @@ -465,6 +456,8 @@ def plot_comparison( "slope": "Slope of Log Pressure vs Log Time", } + r2_text = "nan" if np.isnan(overall["r2"]) else f"{overall['r2']:.4f}" + fig, axes = plt.subplots(3, 2, figsize=(14, 12)) fig.suptitle( "Solver vs Surrogate\n" @@ -474,14 +467,17 @@ def plot_comparison( f"use_schedule={use_schedule}, model={model_path.name}, " f"Overall RMSE={overall['rmse']:.4f}, MAE={overall['mae']:.4f}, " f"AutoFitObj={autofit['dual_log_objective']:.4f}, " - f"R2={overall_r2_text}" + f"R2={r2_text}" ) summary = {"overall": overall, "autofit": autofit, "parts": {}} + # 分别对 pressure / derivative / slope 三部分绘图 for row, name in enumerate(part_names): + # 左列画曲线重合程度,右列画逐时间点残差,便于定位误差集中在哪个时期。 y_true = true_parts[name] y_pred = pred_parts[name] + # 残差 = 代理模型预测值 - 数值求解器真实值 err = y_pred - y_true x = np.arange(len(y_true)) metrics = calc_metrics(y_true, y_pred) @@ -517,11 +513,7 @@ def plot_comparison( def main() -> None: - """执行单案例完整对比流程。 - - 流程包括解析路径和制度、加载 processed 与 checkpoint、运行数值求解器、 - 调用代理模型预测、绘制对比图、写出 JSON 汇总并打印核心指标。 - """ + 
"""抽取一个测试样本,绘制代理模型曲线与真实数值求解曲线的对比图。""" args = parse_args() cfg, processed_path, model_path, output_dir = resolve_paths(args) output_dir.mkdir(parents=True, exist_ok=True) @@ -534,14 +526,18 @@ def main() -> None: processed = joblib.load(processed_path) curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"])) + # 先用同一组参数和制度跑真实求解器,再用代理模型预测同一输入,保证对比公平。 schedule = resolve_case_schedule(cfg, args) params = build_params_from_args(args, cfg, schedule) print("Running solver...") + # 运行真实数值求解器 + # 得到 ground truth 曲线 curve_solver, raw_solver = run_solver_and_extract_curve(cfg, params, args.well_index) print("Running surrogate...") model, use_schedule, device = load_model(model_path) + # 使用代理模型预测同一组参数对应的曲线 curve_pred = predict_surrogate_curve( processed=processed, model=model, @@ -564,6 +560,8 @@ def main() -> None: use_schedule=use_schedule, ) + # 汇总所有实验结果 + # 最终会保存为 JSON 便于后续分析 summary_payload = { "config_path": str(cfg.path), "processed_path": str(processed_path), @@ -592,14 +590,14 @@ def main() -> None: overall = summary["overall"] autofit = summary["autofit"] - overall_r2_text = "nan" if np.isnan(overall["r2"]) else f"{overall['r2']:.6f}" + r2_text = "nan" if np.isnan(overall["r2"]) else f"{overall['r2']:.6f}" print("\nSingle-case comparison complete.") print(f"Output dir: {output_dir}") print( f"Overall RMSE={overall['rmse']:.6f}, MAE={overall['mae']:.6f}, " f"Bias={overall['bias']:.6f}, " f"AutoFitObj={autofit['dual_log_objective']:.6f}, " - f"R2={overall_r2_text}" + f"R2={r2_text}" ) print(f"Plot saved: {plot_path}") print(f"Summary saved: {output_dir / 'single_case_summary.json'}") diff --git a/ML/nmWTAI-ML/scripts/evaluate_forward.py b/ML/nmWTAI-ML/scripts/evaluate_forward.py index ca3ed17..eb11e8b 100644 --- a/ML/nmWTAI-ML/scripts/evaluate_forward.py +++ b/ML/nmWTAI-ML/scripts/evaluate_forward.py @@ -1,10 +1,14 @@ """评估固定长度曲线正演代理模型。 -脚本加载预处理数据和 `ForwardSurrogate` checkpoint,批量预测验证/测试样本曲线, +脚本加载预处理数据和 `ForwardSurrogate` 
checkpoint,批量预测验证/测试样本曲线, 按整体曲线与压力、导数、斜率分段统计 RMSE、MAE、NRMSE、R2 等指标,并保存 随机、最佳、最差样本图,作为正演代理模型离线验收入口。 """ +# pylint: disable=import-error,wrong-import-position +# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments,too-many-statements +# pylint: disable=line-too-long + from __future__ import annotations import argparse @@ -53,7 +57,10 @@ def parse_args() -> argparse.Namespace: "--fit-processed", type=str, default=None, - help="Processed dataset used to fit scalers for the evaluated model; required for cross-dataset evaluation", + help=( + "Processed dataset used to fit scalers for the evaluated model; " + "required for cross-dataset evaluation" + ), ) parser.add_argument( "--output-dir", @@ -470,10 +477,7 @@ def main() -> None: } sample_scores: list[dict] = [] - for idx in range(len(all_true)): - curve_true = all_true[idx] - curve_pred = all_pred[idx] - + for idx, (curve_true, curve_pred) in enumerate(zip(all_true, all_pred)): # 同时记录整体指标和三段指标,便于判断误差来自压力、导数还是辅助 slope。 overall_m = calc_metrics(curve_true, curve_pred) overall_metric_list.append(overall_m) diff --git a/ML/nmWTAI-ML/scripts/evaluate_forward_ensemble.py b/ML/nmWTAI-ML/scripts/evaluate_forward_ensemble.py index 0652f4e..1636986 100644 --- a/ML/nmWTAI-ML/scripts/evaluate_forward_ensemble.py +++ b/ML/nmWTAI-ML/scripts/evaluate_forward_ensemble.py @@ -4,6 +4,10 @@ 导出不确定性-误差相关性统计、散点图和高不确定性样本,用于判断集成方差是否可作为 自动拟合候选筛选或风险提示信号。 """ +# pylint: disable=import-error,wrong-import-position +# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments +# pylint: disable=too-many-statements + from __future__ import annotations @@ -59,7 +63,12 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Evaluate deep-ensemble UQ for forward surrogate") parser.add_argument("--processed", type=str, default=None, help="Processed dataset path") parser.add_argument("--tag", type=str, default=None, help="Experiment tag") - 
parser.add_argument("--model-root", type=str, default=None, help="Root dir that contains seed_* members") + parser.add_argument( + "--model-root", + type=str, + default=None, + help="Root dir that contains seed_* members", + ) parser.add_argument("--output-dir", type=str, default=None, help="Evaluation output dir") parser.add_argument("--seeds", type=str, default="41,42,43", help="Comma-separated seed list") parser.add_argument("--no-schedule", action="store_true") @@ -212,7 +221,13 @@ def plot_uncertain_sample( x = np.arange(len(y_true)) ax.plot(x, y_true, label="True", linewidth=2) ax.plot(x, y_mean, label="Ensemble mean", linewidth=2) - ax.fill_between(x, y_mean - 2.0 * y_std, y_mean + 2.0 * y_std, alpha=0.2, label="mean ± 2 std") + ax.fill_between( + x, + y_mean - 2.0 * y_std, + y_mean + 2.0 * y_std, + alpha=0.2, + label="mean ± 2 std", + ) ax.set_title(title_map[name]) ax.grid(True, alpha=0.3) ax.legend() @@ -229,9 +244,21 @@ def main() -> None: use_schedule = not args.no_schedule seeds = parse_seed_list(args.seeds) - processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) - model_root = Path(args.model_root) if args.model_root is not None else default_model_root(tag, use_schedule) - output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag, use_schedule) + processed_path = ( + Path(args.processed) + if args.processed is not None + else processed_path_for_tag(tag) + ) + model_root = ( + Path(args.model_root) + if args.model_root is not None + else default_model_root(tag, use_schedule) + ) + output_dir = ( + Path(args.output_dir) + if args.output_dir is not None + else default_output_dir(tag, use_schedule) + ) output_dir.mkdir(parents=True, exist_ok=True) # 集成评估必须使用同一份 processed 数据,保证各成员输入标准化口径一致。 @@ -266,8 +293,16 @@ def main() -> None: with torch.no_grad(): for idx in range(len(x_params_test)): - params_t = torch.tensor(x_params_test[idx : idx + 1], dtype=torch.float32, 
device=device) - schedule_t = torch.tensor(x_schedule_test[idx : idx + 1], dtype=torch.float32, device=device) + params_t = torch.tensor( + x_params_test[idx : idx + 1], + dtype=torch.float32, + device=device, + ) + schedule_t = torch.tensor( + x_schedule_test[idx : idx + 1], + dtype=torch.float32, + device=device, + ) member_preds = [] for _, model in members: # 每个成员独立预测后先反标准化;集成均值和标准差都在原始曲线尺度上计算。 @@ -279,7 +314,9 @@ def main() -> None: member_preds.append(pred) member_preds = np.stack(member_preds, axis=0) - curve_true = scaler_curve.inverse_transform(y_curve_test[idx : idx + 1])[0].astype(np.float32) + curve_true = scaler_curve.inverse_transform( + y_curve_test[idx : idx + 1] + )[0].astype(np.float32) curve_mean = member_preds.mean(axis=0).astype(np.float32) curve_std = member_preds.std(axis=0, ddof=0).astype(np.float32) @@ -344,7 +381,12 @@ def main() -> None: with open(output_dir / "ensemble_uq_summary.json", "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) - with open(output_dir / "sample_uncertainty_metrics.csv", "w", newline="", encoding="utf-8-sig") as f: + with open( + output_dir / "sample_uncertainty_metrics.csv", + "w", + newline="", + encoding="utf-8-sig", + ) as f: writer = csv.DictWriter(f, fieldnames=list(sample_rows[0].keys())) writer.writeheader() writer.writerows(sample_rows) diff --git a/ML/nmWTAI-ML/scripts/evaluate_time_conditioned.py b/ML/nmWTAI-ML/scripts/evaluate_time_conditioned.py index f893c62..1ee6b5f 100644 --- a/ML/nmWTAI-ML/scripts/evaluate_time_conditioned.py +++ b/ML/nmWTAI-ML/scripts/evaluate_time_conditioned.py @@ -7,6 +7,12 @@ from __future__ import annotations +# pylint: disable=import-error,wrong-import-position,line-too-long +# pylint: disable=too-many-arguments,too-many-positional-arguments +# pylint: disable=too-many-locals,too-many-statements +# pylint: disable=invalid-name,unused-argument,too-many-function-args + + import argparse import csv import json diff --git 
a/ML/nmWTAI-ML/scripts/finetune_forward_local_ranking.py b/ML/nmWTAI-ML/scripts/finetune_forward_local_ranking.py index 38a0827..aad86e5 100644 --- a/ML/nmWTAI-ML/scripts/finetune_forward_local_ranking.py +++ b/ML/nmWTAI-ML/scripts/finetune_forward_local_ranking.py @@ -13,17 +13,25 @@ import random import sys from pathlib import Path -ROOT = Path(__file__).resolve().parents[1] -sys.path.append(str(ROOT)) - import h5py import joblib import numpy as np import torch import torch.nn.functional as F -from src.common.experiment_paths import model_checkpoint_for_tag, model_dir_for_tag, normalize_tag, processed_path_for_tag -from src.data.param_features import param_feature_transform_from_meta, transform_param_features +ROOT = Path(__file__).resolve().parents[1] +sys.path.append(str(ROOT)) + +from src.common.experiment_paths import ( + model_checkpoint_for_tag, + model_dir_for_tag, + normalize_tag, + processed_path_for_tag, +) +from src.data.param_features import ( + param_feature_transform_from_meta, + transform_param_features, +) from src.models.forward_surrogate import ForwardSurrogate @@ -32,7 +40,12 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Fine-tune a forward surrogate with local pairwise autofit ranking constraints" ) - parser.add_argument("--neighborhood", type=str, required=True, help="Anchor-neighborhood HDF5 path") + parser.add_argument( + "--neighborhood", + type=str, + required=True, + help="Anchor-neighborhood HDF5 path", + ) parser.add_argument("--base-tag", type=str, default="family_random_mixed_50k_biasfix") parser.add_argument("--base-processed", type=str, default=None) parser.add_argument("--base-model", type=str, default=None) @@ -144,8 +157,12 @@ def load_neighborhood_groups( anchor_params_scaled = scaler_params.transform( transform_param_features(anchor_params[anchor_id : anchor_id + 1], param_transform) ).astype(np.float32) - anchor_schedule_scaled = scaler_schedule.transform(anchor_schedule[anchor_id : 
anchor_id + 1]).astype(np.float32) - anchor_curve_scaled = scaler_curve.transform(anchor_curve[anchor_id : anchor_id + 1]).astype(np.float32) + anchor_schedule_scaled = scaler_schedule.transform( + anchor_schedule[anchor_id : anchor_id + 1] + ).astype(np.float32) + anchor_curve_scaled = scaler_curve.transform( + anchor_curve[anchor_id : anchor_id + 1] + ).astype(np.float32) groups.append( { @@ -157,7 +174,9 @@ def load_neighborhood_groups( "neighbor_params_x": scaler_params.transform( transform_param_features(neighbor_params[idx], param_transform) ).astype(np.float32), - "neighbor_schedule_x": scaler_schedule.transform(anchor_schedule[neighbor_anchor_id[idx]]).astype(np.float32), + "neighbor_schedule_x": scaler_schedule.transform( + anchor_schedule[neighbor_anchor_id[idx]] + ).astype(np.float32), "neighbor_curve_x": scaler_curve.transform(neighbor_curve[idx]).astype(np.float32), "neighbor_objective": neighbor_objective[idx].astype(np.float32), } diff --git a/ML/nmWTAI-ML/scripts/flatten_autofit_neighborhood_dataset.py b/ML/nmWTAI-ML/scripts/flatten_autofit_neighborhood_dataset.py index e45283c..40a4cc3 100644 --- a/ML/nmWTAI-ML/scripts/flatten_autofit_neighborhood_dataset.py +++ b/ML/nmWTAI-ML/scripts/flatten_autofit_neighborhood_dataset.py @@ -4,6 +4,8 @@ 扁平化为常规 `params/schedule/curve` 结构,方便复用已有预处理、评估和合并流程。 """ +# pylint: disable=too-many-locals,too-many-statements + from __future__ import annotations import argparse @@ -11,17 +13,20 @@ import json import sys from pathlib import Path -ROOT = Path(__file__).resolve().parents[1] -sys.path.append(str(ROOT)) - import h5py import numpy as np +ROOT = Path(__file__).resolve().parents[1] +sys.path.append(str(ROOT)) + def parse_args() -> argparse.Namespace: """解析自动拟合邻域 HDF5 的输入输出路径以及是否只导出邻域样本。""" parser = argparse.ArgumentParser( - description="Flatten anchor-neighborhood autofit HDF5 into the standard sample-level HDF5 format" + description=( + "Flatten anchor-neighborhood autofit HDF5 into the standard " + "sample-level 
HDF5 format" + ) ) parser.add_argument("--input", type=str, required=True, help="Input autofit neighborhood .h5") parser.add_argument("--output", type=str, required=True, help="Output flat sample-level .h5") @@ -69,7 +74,7 @@ def main() -> None: anchor_schedule_meta = np.asarray(src["anchor_schedule_meta"][:], dtype=np.float32) anchor_family_name = np.asarray(src["anchor_family_name"][:]).astype(str) anchor_section_index = np.asarray(src["anchor_section_index"][:], dtype=np.int32) - anchor_timeQ_json = np.asarray(src["anchor_timeQ_json"][:]).astype(str) + anchor_time_q_json = np.asarray(src["anchor_time_q_json"][:]).astype(str) anchor_q_json = np.asarray(src["anchor_q_json"][:]).astype(str) neighbor_anchor_id = np.asarray(src["neighbor_anchor_id"][:], dtype=np.int32) @@ -85,7 +90,14 @@ def main() -> None: ) schedule_meta_names = _decode_attr_list(src.attrs.get("schedule_meta_names")) - span_fracs = np.asarray(src.attrs.get("span_fracs", []), dtype=np.float32).reshape(-1).tolist() + span_fracs = ( + np.asarray( + src.attrs.get("span_fracs", []), + dtype=np.float32, + ) + .reshape(-1) + .tolist() + ) n_anchors = int(anchor_params.shape[0]) n_neighbors = int(neighbor_params.shape[0]) @@ -117,7 +129,11 @@ def main() -> None: dst.create_dataset("curve", shape=(total_rows, curve_dim), dtype=np.float32) dst.create_dataset("group_id", shape=(total_rows,), dtype=np.int64) dst.create_dataset("schedule_meta", shape=(total_rows, schedule_meta_dim), dtype=np.float32) - dst.create_dataset("family_name", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8")) + dst.create_dataset( + "family_name", + shape=(total_rows,), + dtype=h5py.string_dtype(encoding="utf-8"), + ) dst.create_dataset("is_anchor", shape=(total_rows,), dtype=np.int8) dst.create_dataset("neighbor_objective", shape=(total_rows,), dtype=np.float32) @@ -125,8 +141,16 @@ def main() -> None: dst.create_dataset("neighbor_objective_d", shape=(total_rows,), dtype=np.float32) 
dst.create_dataset("neighbor_span_frac", shape=(total_rows,), dtype=np.float32) dst.create_dataset("section_index", shape=(total_rows,), dtype=np.int32) - dst.create_dataset("timeQ_json", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8")) - dst.create_dataset("q_json", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8")) + dst.create_dataset( + "timeQ_json", + shape=(total_rows,), + dtype=h5py.string_dtype(encoding="utf-8"), + ) + dst.create_dataset( + "q_json", + shape=(total_rows,), + dtype=h5py.string_dtype(encoding="utf-8"), + ) write_pos = 0 @@ -145,7 +169,7 @@ def main() -> None: dst["neighbor_objective_d"][write_pos:anchor_end] = 0.0 dst["neighbor_span_frac"][write_pos:anchor_end] = 0.0 dst["section_index"][write_pos:anchor_end] = anchor_section_index - dst["timeQ_json"][write_pos:anchor_end] = anchor_timeQ_json.tolist() + dst["timeQ_json"][write_pos:anchor_end] = anchor_time_q_json.tolist() dst["q_json"][write_pos:anchor_end] = anchor_q_json.tolist() write_pos = anchor_end @@ -156,15 +180,23 @@ def main() -> None: dst["curve"][write_pos:neighbor_end] = neighbor_curve dst["group_id"][write_pos:neighbor_end] = neighbor_anchor_id dst["schedule_meta"][write_pos:neighbor_end] = anchor_schedule_meta[neighbor_anchor_id] - dst["family_name"][write_pos:neighbor_end] = anchor_family_name[neighbor_anchor_id].tolist() + dst["family_name"][write_pos:neighbor_end] = anchor_family_name[ + neighbor_anchor_id + ].tolist() dst["is_anchor"][write_pos:neighbor_end] = 0 dst["neighbor_objective"][write_pos:neighbor_end] = neighbor_objective dst["neighbor_objective_p"][write_pos:neighbor_end] = neighbor_objective_p dst["neighbor_objective_d"][write_pos:neighbor_end] = neighbor_objective_d dst["neighbor_span_frac"][write_pos:neighbor_end] = neighbor_span_frac - dst["section_index"][write_pos:neighbor_end] = anchor_section_index[neighbor_anchor_id] - dst["timeQ_json"][write_pos:neighbor_end] = anchor_timeQ_json[neighbor_anchor_id].tolist() - 
dst["q_json"][write_pos:neighbor_end] = anchor_q_json[neighbor_anchor_id].tolist() + dst["section_index"][write_pos:neighbor_end] = anchor_section_index[ + neighbor_anchor_id + ] + dst["timeQ_json"][write_pos:neighbor_end] = anchor_time_q_json[ + neighbor_anchor_id + ].tolist() + dst["q_json"][write_pos:neighbor_end] = anchor_q_json[ + neighbor_anchor_id + ].tolist() dst.attrs["n_samples"] = int(total_rows) dst.attrs["source_neighborhood_h5"] = str(input_path) @@ -172,7 +204,10 @@ def main() -> None: print("Autofit neighborhood flatten complete.") print(f"Input: {input_path}") print(f"Output: {output_path}") - print(f"anchors={n_anchors}, neighbors={n_neighbors}, total_exported={total_rows}") + print( + f"anchors={n_anchors}, neighbors={n_neighbors}, " + f"total_exported={total_rows}" + ) print(f"neighbors_only={bool(args.neighbors_only)}") diff --git a/ML/nmWTAI-ML/scripts/generate_autofit_neighborhood_dataset.py b/ML/nmWTAI-ML/scripts/generate_autofit_neighborhood_dataset.py index f5b8444..0ef5b73 100644 --- a/ML/nmWTAI-ML/scripts/generate_autofit_neighborhood_dataset.py +++ b/ML/nmWTAI-ML/scripts/generate_autofit_neighborhood_dataset.py @@ -5,6 +5,17 @@ 输出的 HDF5 同时包含曲线、参数、制度编码和排序训练所需的邻域元数据。 """ +# pylint: disable= +# import-error, +# wrong-import-position, +# too-many-locals, +# too-many-arguments, +# too-many-positional-arguments, +# too-many-statements, +# invalid-name, +# broad-exception-caught, +# no-member + from __future__ import annotations import argparse diff --git a/ML/nmWTAI-ML/scripts/generate_dataset.py b/ML/nmWTAI-ML/scripts/generate_dataset.py index 72420f4..f000ae1 100644 --- a/ML/nmWTAI-ML/scripts/generate_dataset.py +++ b/ML/nmWTAI-ML/scripts/generate_dataset.py @@ -3,6 +3,7 @@ 脚本读取数据生成配置,调用并行数据集生成器批量采样地层/井筒参数和流量制度, 运行底层数值求解器并把有效曲线写入 HDF5。它是训练前最上游的数据生产入口。 """ +# pylint: disable=import-error,wrong-import-position,broad-exception-caught from __future__ import annotations import argparse @@ -18,20 +19,31 @@ from src.common.experiment_paths import 
config_for_stage from src.data.dataset_generation import ParallelDatasetGenerator -def main(): +def main() -> None: """按配置阶段启动并行数值试井样本生成,输出原始 HDF5 数据集路径。""" parser = argparse.ArgumentParser() parser.add_argument("--config", default=None) parser.add_argument( "--stage", - choices=["fixed_case", "case_neighborhood", "family_random", "family_random_hard", "family_random_v2_q"], + choices=[ + "fixed_case", + "case_neighborhood", + "family_random", + "family_random_hard", + "family_random_v2_q", + ], default=None, ) parser.add_argument("--n-samples", type=int, default=None) parser.add_argument("--n-workers", type=int, default=None) parser.add_argument("--seed", type=int, default=None) parser.add_argument("--method", type=str, default=None) - parser.add_argument("--dataset-tag", type=str, default=None, help="Optional tag injected into output dataset filename") + parser.add_argument( + "--dataset-tag", + type=str, + default=None, + help="Optional tag injected into output dataset filename", + ) args = parser.parse_args() config_path = args.config @@ -41,8 +53,14 @@ def main(): # stage 用来选择预设配置;命令行参数继续覆盖样本数、并行数和随机种子。 cfg = Config(config_path) cfg.ensure_dirs() - path = ParallelDatasetGenerator(cfg=cfg, n_workers=args.n_workers).generate( - n_samples=args.n_samples, method=args.method, random_seed=args.seed, dataset_tag=args.dataset_tag + path = ParallelDatasetGenerator( + cfg=cfg, + n_workers=args.n_workers, + ).generate( + n_samples=args.n_samples, + method=args.method, + random_seed=args.seed, + dataset_tag=args.dataset_tag, ) print(path) @@ -51,6 +69,6 @@ if __name__ == "__main__": mp.freeze_support() try: mp.set_start_method("spawn", force=True) - except Exception: + except RuntimeError: pass main() diff --git a/ML/nmWTAI-ML/scripts/merge_datasets.py b/ML/nmWTAI-ML/scripts/merge_datasets.py index 9ffd8dd..67512b2 100644 --- a/ML/nmWTAI-ML/scripts/merge_datasets.py +++ b/ML/nmWTAI-ML/scripts/merge_datasets.py @@ -5,6 +5,8 @@ 提升代理模型在实际反演困难区域的覆盖度。 """ +# pylint: 
disable=too-many-arguments,too-many-positional-arguments,too-many-locals,no-member + from __future__ import annotations import argparse @@ -12,6 +14,7 @@ import json import math import sys from pathlib import Path +from types import SimpleNamespace from typing import Any import h5py @@ -390,15 +393,12 @@ def merge_datasets( merge_meta: dict[str, Any] - class _Args: - """内部轻量参数对象,用于把递归合并复用到单文件合并场景。""" - pass - - count_args = _Args() - count_args.normal_count = normal_count - count_args.hard_count = hard_count - count_args.total_samples = total_samples - count_args.hard_ratio = hard_ratio + count_args = SimpleNamespace( + normal_count=normal_count, + hard_count=hard_count, + total_samples=total_samples, + hard_ratio=hard_ratio, + ) with h5py.File(normal_input, "r") as normal_file, h5py.File(hard_input, "r") as hard_file: # 合并前先锁定 schema,避免后面边写边发现字段不兼容。 diff --git a/ML/nmWTAI-ML/scripts/preprocess_dataset.py b/ML/nmWTAI-ML/scripts/preprocess_dataset.py index 3feaea4..871fa08 100644 --- a/ML/nmWTAI-ML/scripts/preprocess_dataset.py +++ b/ML/nmWTAI-ML/scripts/preprocess_dataset.py @@ -5,6 +5,8 @@ processed 数据文件。 """ +# pylint: disable=import-error,wrong-import-position + from __future__ import annotations import argparse @@ -20,7 +22,9 @@ from src.data.preprocess import preprocess_dataset def main() -> None: """把原始 HDF5 曲线数据切分、标准化并保存为训练用 pkl。""" - parser = argparse.ArgumentParser(description="Preprocess HDF5 dataset for forward surrogate") + parser = argparse.ArgumentParser( + description="Preprocess HDF5 dataset for forward surrogate" + ) parser.add_argument( "--input", type=str, @@ -33,7 +37,12 @@ def main() -> None: default=None, help="Optional output .pkl path", ) - parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming") + parser.add_argument( + "--tag", + type=str, + default=None, + help="Experiment tag for auto naming", + ) parser.add_argument("--test-size", type=float, default=0.15) parser.add_argument("--val-size", type=float, 
default=0.15) parser.add_argument("--seed", type=int, default=42) @@ -45,8 +54,12 @@ def main() -> None: args = parser.parse_args() tag = normalize_tag(args.tag) - # 未显式指定输出路径时,使用 tag 生成与训练脚本约定一致的 processed 文件名。 - output_path = Path(args.output) if args.output is not None else processed_path_for_tag(tag) + + output_path = ( + Path(args.output) + if args.output is not None + else processed_path_for_tag(tag) + ) preprocess_dataset( input_path=Path(args.input), diff --git a/ML/nmWTAI-ML/scripts/q_sweep_local_ranking.py b/ML/nmWTAI-ML/scripts/q_sweep_local_ranking.py index db91bc4..d0bc700 100644 --- a/ML/nmWTAI-ML/scripts/q_sweep_local_ranking.py +++ b/ML/nmWTAI-ML/scripts/q_sweep_local_ranking.py @@ -5,6 +5,8 @@ 保持稳定的自动拟合筛选能力。 """ +# pylint: disable=import-error,wrong-import-position,wrong-import-order,too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements,broad-exception-caught + from __future__ import annotations import argparse @@ -31,7 +33,12 @@ from scripts.validate_autofit_local_ranking import ( run_solver_and_extract_curve, ) from src.common.config import Config -from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag +from src.common.experiment_paths import ( + config_for_stage, + model_checkpoint_for_tag, + normalize_tag, + processed_path_for_tag, +) from src.data.params import Params, Schedule from src.data.runner_client import CppRunner from src.evaluation.autofit_objective import dual_log_objective @@ -55,7 +62,17 @@ def parse_args() -> argparse.Namespace: description="Generate Q schedules around theta* and run first-layer local ranking validation." 
) parser.add_argument("--config", type=str, default=None) - parser.add_argument("--stage", choices=["fixed_case", "case_neighborhood", "family_random", "family_random_hard", "family_random_v2_q"], default="family_random") + parser.add_argument( + "--stage", + choices=[ + "fixed_case", + "case_neighborhood", + "family_random", + "family_random_hard", + "family_random_v2_q", + ], + default="family_random", + ) parser.add_argument("--processed", type=str, default=None) parser.add_argument("--model", type=str, default=None) parser.add_argument("--tag", type=str, default="family_random_mixed_50k_logparam") @@ -149,7 +166,15 @@ def schedule_features(time_q: list[float], q: list[float]) -> dict: } -def add_case(cases: list[dict], seen: set[tuple], case_id: str, family: str, time_q: list[float], q: list[float], axis: str) -> None: +def add_case( + cases: list[dict], + seen: set[tuple], + case_id: str, + family: str, + time_q: list[float], + q: list[float], + axis: str, +) -> None: """向案例列表追加一个流量制度候选及其特征。""" key = tuple(round(float(x), 8) for x in (time_q + q)) if key in seen: @@ -180,7 +205,15 @@ def generate_q_cases() -> list[dict]: base_q = [170.0, 170.0, 210.0, 250.0] base_shutin = 72.0 - add_case(cases, seen, "Q000_baseline_b3_alt_like", "mild_step", base_prod_dt + [base_shutin], base_q + [0.0], "baseline") + add_case( + cases, + seen, + "Q000_baseline_b3_alt_like", + "mild_step", + base_prod_dt + [base_shutin], + base_q + [0.0], + "baseline", + ) # 分别扫描生产总时长、关井时长、流量倍率、阶跃强度和生产段数量。 for total in [24.0, 48.0, 72.0, 96.0, 160.0, 220.0, 260.0]: @@ -192,12 +225,28 @@ def generate_q_cases() -> list[dict]: for scale in [0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0]: q_prod = [float(x * scale) for x in base_q] - add_case(cases, seen, f"q_scale_{str(scale).replace('.', 'p')}", "mild_step", base_prod_dt + [base_shutin], q_prod + [0.0], "q_scale") + add_case( + cases, + seen, + f"q_scale_{str(scale).replace('.', 'p')}", + "mild_step", + base_prod_dt + [base_shutin], + q_prod + [0.0], + "q_scale", + ) for ratio in
[1.2, 1.6, 2.0, 2.2, 2.8, 3.5]: q0 = 180.0 q_prod = [q0, q0 * ratio, q0, q0 * ratio] - add_case(cases, seen, f"step_ratio_{str(ratio).replace('.', 'p')}", "sharp_step", base_prod_dt + [base_shutin], q_prod + [0.0], "step_ratio") + add_case( + cases, + seen, + f"step_ratio_{str(ratio).replace('.', 'p')}", + "sharp_step", + base_prod_dt + [base_shutin], + q_prod + [0.0], + "step_ratio", + ) for n_prod in [2, 3, 4, 5, 6]: total = 84.0 @@ -338,7 +387,12 @@ def main() -> None: for case_idx, q_case in enumerate(q_cases): schedule: Schedule = q_case["schedule"] if not schedule.validate(): - failed_case_rows.append({**{k: v for k, v in q_case.items() if k != "schedule"}, "reason": "invalid schedule"}) + failed_case_rows.append( + { + **{k: v for k, v in q_case.items() if k != "schedule"}, + "reason": "invalid schedule", + } + ) continue target_params = make_theta_params(args, schedule) @@ -359,7 +413,12 @@ def main() -> None: timeout=int(args.solver_timeout), ) except Exception as exc: - failed_case_rows.append({**{k: v for k, v in q_case.items() if k != "schedule"}, "reason": str(exc)}) + failed_case_rows.append( + { + **{k: v for k, v in q_case.items() if k != "schedule"}, + "reason": str(exc), + } + ) print(f" [fail] target solver failed: {exc}") continue finally: @@ -436,7 +495,12 @@ def main() -> None: runner.close() if len(rows) < 2: - failed_case_rows.append({**{k: v for k, v in q_case.items() if k != "schedule"}, "reason": "fewer than two valid candidates"}) + failed_case_rows.append( + { + **{k: v for k, v in q_case.items() if k != "schedule"}, + "reason": "fewer than two valid candidates", + } + ) print(f" [fail] valid candidates={len(rows)}") continue diff --git a/ML/nmWTAI-ML/scripts/replay_pso_trace_screening.py b/ML/nmWTAI-ML/scripts/replay_pso_trace_screening.py index a4faf44..44f2e92 100644 --- a/ML/nmWTAI-ML/scripts/replay_pso_trace_screening.py +++ b/ML/nmWTAI-ML/scripts/replay_pso_trace_screening.py @@ -5,6 +5,9 @@ 调用量而不明显损失最优候选。 """ +# pylint: 
disable=import-error,wrong-import-position,wrong-import-order,too-many-locals,too-many-statements,broad-exception-caught + + from __future__ import annotations import argparse @@ -22,10 +25,15 @@ import joblib import matplotlib.pyplot as plt import numpy as np -from scripts.compare_single_case import build_schedule_vector, load_model +from scripts.compare_single_case import load_model from scripts.validate_autofit_local_ranking import infer_curve_layout, predict_surrogate_curve from src.common.config import Config -from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag +from src.common.experiment_paths import ( + config_for_stage, + model_checkpoint_for_tag, + normalize_tag, + processed_path_for_tag, +) from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features from src.data.params import Params, Schedule from src.evaluation.autofit_objective import dual_log_objective @@ -189,7 +197,13 @@ def corr_spearman(a: np.ndarray, b: np.ndarray) -> float: def summarize_generation(rows: list[dict], keep_fracs: list[float]) -> tuple[dict, list[dict]]: """按 PSO 代数汇总代理筛选保留真实优质候选的效果。""" - valid = [r for r in rows if r["solver_success"] == "1" and math.isfinite(float(r["solver_objective"])) and float(r["solver_objective"]) < 1e9] + valid = [ + r + for r in rows + if r["solver_success"] == "1" + and math.isfinite(float(r["solver_objective"])) + and float(r["solver_objective"]) < 1e9 + ] if len(valid) < 2: return {}, [] diff --git a/ML/nmWTAI-ML/scripts/score_pso_candidates.py b/ML/nmWTAI-ML/scripts/score_pso_candidates.py index cd448b1..aef1106 100644 --- a/ML/nmWTAI-ML/scripts/score_pso_candidates.py +++ b/ML/nmWTAI-ML/scripts/score_pso_candidates.py @@ -4,6 +4,8 @@ 再计算与目标曲线的自动拟合目标函数,输出可被 PSO 或筛选流程继续使用的打分 CSV。 """ +# pylint: disable=import-error,wrong-import-position,wrong-import-order,too-many-locals,broad-exception-caught + from __future__ import annotations import 
argparse @@ -21,7 +23,12 @@ import numpy as np from scripts.compare_single_case import load_model from scripts.validate_autofit_local_ranking import infer_curve_layout, predict_surrogate_curve from src.common.config import Config -from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag +from src.common.experiment_paths import ( + config_for_stage, + model_checkpoint_for_tag, + normalize_tag, + processed_path_for_tag, +) from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features from src.data.params import Params, Schedule from src.evaluation.autofit_objective import dual_log_objective @@ -30,13 +37,21 @@ from src.evaluation.autofit_objective import dual_log_objective PARAM_COLUMNS = ["k", "skin", "wellboreC", "phi", "h", "Cf"] - - def parse_args() -> argparse.Namespace: """解析 PSO 候选粒子 CSV、trace 元数据、代理模型和输出评分表路径。""" parser = argparse.ArgumentParser(description="使用正演代理模型为一批 PSO 候选粒子打分") - parser.add_argument("--candidates", required=True, type=str, help="候选粒子 CSV,包含 particle_id,k,skin,wellboreC,phi,h,Cf") - parser.add_argument("--trace-meta", required=True, type=str, help="匹配的 pso_baseline_trace_*.meta.json") + parser.add_argument( + "--candidates", + required=True, + type=str, + help="候选粒子 CSV,包含 particle_id,k,skin,wellboreC,phi,h,Cf", + ) + parser.add_argument( + "--trace-meta", + required=True, + type=str, + help="匹配的 pso_baseline_trace_*.meta.json", + ) parser.add_argument("--output", required=True, type=str, help="输出代理目标函数 CSV") parser.add_argument("--tag", type=str, default="family_random_mixed_50k_logparam") parser.add_argument("--stage", type=str, default="family_random") diff --git a/ML/nmWTAI-ML/scripts/score_pso_candidates_server.py b/ML/nmWTAI-ML/scripts/score_pso_candidates_server.py index 304ac4f..b70a4b5 100644 --- a/ML/nmWTAI-ML/scripts/score_pso_candidates_server.py +++ b/ML/nmWTAI-ML/scripts/score_pso_candidates_server.py @@ -4,6 +4,8 @@ 
随后通过标准输入接收候选 CSV 路径并输出 JSON 状态,降低 PSO 与代理模型耦合时 频繁加载 checkpoint 的开销。 """ +# pylint: disable=import-error,wrong-import-position,wrong-import-order,too-many-instance-attributes,too-few-public-methods,broad-exception-caught + from __future__ import annotations diff --git a/ML/nmWTAI-ML/scripts/train_forward.py b/ML/nmWTAI-ML/scripts/train_forward.py index 2ce5dc3..8206f8d 100644 --- a/ML/nmWTAI-ML/scripts/train_forward.py +++ b/ML/nmWTAI-ML/scripts/train_forward.py @@ -5,6 +5,8 @@ 代理模型的主训练入口。 """ +# pylint: disable=import-error,wrong-import-position + from __future__ import annotations import argparse @@ -15,7 +17,16 @@ ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import model_dir_for_tag, normalize_tag, processed_path_for_tag -from src.training.train_forward import TrainConfig, train_forward +from src.training.train_forward import ( + LossConfig, + LossWeights, + ModelConfig, + OptimConfig, + SampleReweightConfig, + TrainConfig, + TrainRuntime, + train_forward, +) def main() -> None: @@ -53,6 +64,7 @@ def main() -> None: parser.add_argument("--huber-beta", type=float, default=0.05) parser.add_argument("--use-sample-reweight", action="store_true", default=True) + parser.add_argument("--no-sample-reweight", action="store_false", dest="use_sample_reweight") parser.add_argument("--sample-reweight-alpha", type=float, default=0.4) parser.add_argument("--sample-weight-min", type=float, default=1.0) parser.add_argument("--sample-weight-max", type=float, default=2.5) @@ -74,27 +86,37 @@ def main() -> None: cfg = TrainConfig( processed_path=processed_path, output_dir=output_dir, - seed=args.seed, - batch_size=args.batch_size, - epochs=args.epochs, - lr=args.lr, - weight_decay=args.weight_decay, - hidden_dim=args.hidden_dim, - dropout=args.dropout, - w_pressure=args.w_pressure, - w_derivative=args.w_derivative, - w_slope=args.w_slope, - w_bias_pressure=args.w_bias_pressure, - w_bias_derivative=args.w_bias_derivative, - 
w_derivative_shape=args.w_derivative_shape, - w_autofit_pressure=args.w_autofit_pressure, - w_autofit_derivative=args.w_autofit_derivative, - huber_beta=args.huber_beta, - use_sample_reweight=args.use_sample_reweight, - sample_reweight_alpha=args.sample_reweight_alpha, - sample_weight_min=args.sample_weight_min, - sample_weight_max=args.sample_weight_max, - use_schedule=use_schedule, + runtime=TrainRuntime(seed=args.seed), + optim=OptimConfig( + batch_size=args.batch_size, + epochs=args.epochs, + lr=args.lr, + weight_decay=args.weight_decay, + ), + model=ModelConfig( + hidden_dim=args.hidden_dim, + dropout=args.dropout, + use_schedule=use_schedule, + ), + loss=LossConfig( + weights=LossWeights( + pressure=args.w_pressure, + derivative=args.w_derivative, + slope=args.w_slope, + bias_pressure=args.w_bias_pressure, + bias_derivative=args.w_bias_derivative, + derivative_shape=args.w_derivative_shape, + autofit_pressure=args.w_autofit_pressure, + autofit_derivative=args.w_autofit_derivative, + ), + huber_beta=args.huber_beta, + ), + sample_reweight=SampleReweightConfig( + enabled=args.use_sample_reweight, + alpha=args.sample_reweight_alpha, + weight_min=args.sample_weight_min, + weight_max=args.sample_weight_max, + ), ) train_forward(cfg) diff --git a/ML/nmWTAI-ML/scripts/train_forward_ensemble.py b/ML/nmWTAI-ML/scripts/train_forward_ensemble.py index 178e7e1..a2b580a 100644 --- a/ML/nmWTAI-ML/scripts/train_forward_ensemble.py +++ b/ML/nmWTAI-ML/scripts/train_forward_ensemble.py @@ -4,6 +4,8 @@ 独立 seed 和输出目录。所得模型集合用于后续不确定性估计、误差风险分析和 fallback 筛选。 """ +# pylint: disable=import-error,wrong-import-position,duplicate-code + from __future__ import annotations import argparse @@ -15,7 +17,16 @@ ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import normalize_tag, processed_path_for_tag -from src.training.train_forward import TrainConfig, train_forward +from src.training.train_forward import ( + LossConfig, + LossWeights, + 
ModelConfig, + OptimConfig, + SampleReweightConfig, + TrainConfig, + TrainRuntime, + train_forward, +) def parse_seed_list(seed_text: str) -> list[int]: @@ -62,6 +73,7 @@ def main() -> None: parser.add_argument("--w-autofit-derivative", type=float, default=0.0) parser.add_argument("--huber-beta", type=float, default=0.05) parser.add_argument("--use-sample-reweight", action="store_true", default=True) + parser.add_argument("--no-sample-reweight", action="store_false", dest="use_sample_reweight") parser.add_argument("--sample-reweight-alpha", type=float, default=0.4) parser.add_argument("--sample-weight-min", type=float, default=1.0) parser.add_argument("--sample-weight-max", type=float, default=2.5) @@ -91,27 +103,37 @@ def main() -> None: cfg = TrainConfig( processed_path=processed_path, output_dir=member_dir, - seed=seed, - batch_size=args.batch_size, - epochs=args.epochs, - lr=args.lr, - weight_decay=args.weight_decay, - hidden_dim=args.hidden_dim, - dropout=args.dropout, - w_pressure=args.w_pressure, - w_derivative=args.w_derivative, - w_slope=args.w_slope, - w_bias_pressure=args.w_bias_pressure, - w_bias_derivative=args.w_bias_derivative, - w_derivative_shape=args.w_derivative_shape, - w_autofit_pressure=args.w_autofit_pressure, - w_autofit_derivative=args.w_autofit_derivative, - huber_beta=args.huber_beta, - use_sample_reweight=args.use_sample_reweight, - sample_reweight_alpha=args.sample_reweight_alpha, - sample_weight_min=args.sample_weight_min, - sample_weight_max=args.sample_weight_max, - use_schedule=use_schedule, + runtime=TrainRuntime(seed=seed), + optim=OptimConfig( + batch_size=args.batch_size, + epochs=args.epochs, + lr=args.lr, + weight_decay=args.weight_decay, + ), + model=ModelConfig( + hidden_dim=args.hidden_dim, + dropout=args.dropout, + use_schedule=use_schedule, + ), + loss=LossConfig( + weights=LossWeights( + pressure=args.w_pressure, + derivative=args.w_derivative, + slope=args.w_slope, + bias_pressure=args.w_bias_pressure, + 
bias_derivative=args.w_bias_derivative, + derivative_shape=args.w_derivative_shape, + autofit_pressure=args.w_autofit_pressure, + autofit_derivative=args.w_autofit_derivative, + ), + huber_beta=args.huber_beta, + ), + sample_reweight=SampleReweightConfig( + enabled=args.use_sample_reweight, + alpha=args.sample_reweight_alpha, + weight_min=args.sample_weight_min, + weight_max=args.sample_weight_max, + ), ) train_forward(cfg) manifest["members"].append( diff --git a/ML/nmWTAI-ML/scripts/train_time_conditioned.py b/ML/nmWTAI-ML/scripts/train_time_conditioned.py index 3d2f63c..162e6a3 100644 --- a/ML/nmWTAI-ML/scripts/train_time_conditioned.py +++ b/ML/nmWTAI-ML/scripts/train_time_conditioned.py @@ -5,6 +5,8 @@ 适合处理可变时间采样或逐点推理场景。 """ +# pylint: disable=import-error,wrong-import-position + from __future__ import annotations import argparse @@ -15,7 +17,15 @@ ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import normalize_tag, processed_path_for_tag -from src.training.train_time_conditioned import TimeConditionedTrainConfig, train_time_conditioned +from src.training.train_time_conditioned import ( + RiskWeightConfig, + TimeConditionedTrainConfig, + TimeLossConfig, + TimeModelConfig, + TimeOptimConfig, + TimeRuntimeConfig, + train_time_conditioned, +) def main() -> None: @@ -61,23 +71,31 @@ def main() -> None: cfg = TimeConditionedTrainConfig( processed_path=processed_path, output_dir=output_dir, - seed=int(args.seed), - batch_size=int(args.batch_size), - epochs=int(args.epochs), - lr=float(args.lr), - weight_decay=float(args.weight_decay), - hidden_dim=int(args.hidden_dim), - n_blocks=int(args.n_blocks), - dropout=float(args.dropout), - w_pressure=float(args.w_pressure), - w_derivative=float(args.w_derivative), - huber_beta=float(args.huber_beta), - use_schedule=not bool(args.no_schedule), - sample_weight_mode=str(args.sample_weight_mode), - risk_weight=float(args.risk_weight), - 
skin_lt_minus8_weight=float(args.skin_lt_minus8_weight), - sample_weight_min=float(args.sample_weight_min), - sample_weight_max=float(args.sample_weight_max), + runtime=TimeRuntimeConfig(seed=int(args.seed)), + optim=TimeOptimConfig( + batch_size=int(args.batch_size), + epochs=int(args.epochs), + lr=float(args.lr), + weight_decay=float(args.weight_decay), + ), + model=TimeModelConfig( + hidden_dim=int(args.hidden_dim), + n_blocks=int(args.n_blocks), + dropout=float(args.dropout), + use_schedule=not bool(args.no_schedule), + ), + loss=TimeLossConfig( + w_pressure=float(args.w_pressure), + w_derivative=float(args.w_derivative), + huber_beta=float(args.huber_beta), + ), + risk_weight=RiskWeightConfig( + mode=str(args.sample_weight_mode), + risk_weight=float(args.risk_weight), + skin_lt_minus8_weight=float(args.skin_lt_minus8_weight), + weight_min=float(args.sample_weight_min), + weight_max=float(args.sample_weight_max), + ), ) train_time_conditioned(cfg) diff --git a/ML/nmWTAI-ML/scripts/validate_autofit_local_ranking.py b/ML/nmWTAI-ML/scripts/validate_autofit_local_ranking.py index eab320b..b1cd3f4 100644 --- a/ML/nmWTAI-ML/scripts/validate_autofit_local_ranking.py +++ b/ML/nmWTAI-ML/scripts/validate_autofit_local_ranking.py @@ -5,6 +5,8 @@ 它直接回答“代理模型能否用于 PSO 候选预筛选”这个问题。 """ +# pylint: disable=import-error,wrong-import-position,invalid-name,too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements,broad-exception-caught,no-member + from __future__ import annotations import argparse @@ -103,12 +105,12 @@ def infer_curve_layout(meta: dict, curve_dim: int) -> dict: def resolve_case_schedule(cfg: Config) -> Schedule: """解析局部排序验证目标案例的流量制度,支持命令行覆盖配置默认值。""" schedule_cfg = cfg.raw["schedule"]["case_schedule"] - timeQ = list(map(float, schedule_cfg["timeQ"])) + time_q = list(map(float, schedule_cfg["timeQ"])) q = list(map(float, schedule_cfg["q"])) policy = cfg.raw["schedule"].get("section_policy", {}) or {} mode = 
str(policy.get("mode", "fixed_last")).lower() - n_sections = len(timeQ) + n_sections = len(time_q) if mode == "fixed_last": section_index = n_sections @@ -117,7 +119,7 @@ def resolve_case_schedule(cfg: Config) -> Schedule: else: section_index = int(np.clip(int(schedule_cfg.get("default_section_index", n_sections)), 1, n_sections)) - schedule = Schedule(sectionIndex=section_index, timeQ=timeQ, q=q) + schedule = Schedule(sectionIndex=section_index, timeQ=time_q, q=q) if not schedule.validate(): raise ValueError("Invalid case_schedule in config") return schedule diff --git a/ML/nmWTAI-ML/scripts/validate_autofit_local_ranking_batch.py b/ML/nmWTAI-ML/scripts/validate_autofit_local_ranking_batch.py index a899ffa..d2560c0 100644 --- a/ML/nmWTAI-ML/scripts/validate_autofit_local_ranking_batch.py +++ b/ML/nmWTAI-ML/scripts/validate_autofit_local_ranking_batch.py @@ -4,6 +4,7 @@ 重复采样局部候选,汇总代理目标与真实目标之间的排序相关性、保留比例和失败案例, 用于给出更稳健的 PSO 预筛选可用性判断。 """ +# pylint: disable=import-error,wrong-import-position,wrong-import-order,invalid-name,too-many-locals,too-many-branches,too-many-statements,broad-exception-caught,no-member from __future__ import annotations @@ -95,10 +96,10 @@ def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]: def sample_target_case(cfg: Config, rng: np.random.RandomState, seed: int) -> Params: """从配置或数据集中抽取一个目标案例,用于局部排序验证。""" params = generate_params_dataset(cfg, n_samples=1, method=cfg.raw["params"].get("sampling_method", "sobol"), random_seed=seed)[0] - timeQ, q, _sched_info = sample_schedule_by_mode(cfg, rng) - section_indices = _resolve_section_indices(cfg, timeQ, q, rng) + time_q, q, _sched_info = sample_schedule_by_mode(cfg, rng) + section_indices = _resolve_section_indices(cfg, time_q, q, rng) section_index = int(section_indices[int(rng.randint(0, len(section_indices)))]) - params.schedule = Schedule(sectionIndex=section_index, timeQ=list(map(float, timeQ)), q=list(map(float, q))) + params.schedule = 
Schedule(sectionIndex=section_index, timeQ=list(map(float, time_q)), q=list(map(float, q))) return params diff --git a/ML/nmWTAI-ML/src/common/config.py b/ML/nmWTAI-ML/src/common/config.py index 695ff9e..292aba6 100644 --- a/ML/nmWTAI-ML/src/common/config.py +++ b/ML/nmWTAI-ML/src/common/config.py @@ -30,6 +30,7 @@ class ProjectPaths: class Config: + # pylint: disable=too-many-instance-attributes """读取 YAML 配置,并把项目根目录、数据目录、模型目录等派生为可复用属性。""" def __init__(self, config_path: str | Path) -> None: """读取 YAML 配置,解析项目根目录以及训练、求解器和数据文件路径。""" @@ -85,10 +86,9 @@ class Config: @property def curve_dim(self) -> int: """返回配置中的曲线输出维度。""" - T = int(self.get("curve_processing", "n_time_points", default=160)) + n_time_points = int(self.get("curve_processing", "n_time_points", default=160)) use_slope = bool(self.get("curve_processing", "use_slope_feature", default=True)) - # 曲线至少包含 pressure 和 derivative;启用 slope 时再追加一段辅助特征。 - return (2 + (1 if use_slope else 0)) * T + return (2 + (1 if use_slope else 0)) * n_time_points @property def sec_feat_dim(self) -> int: @@ -98,13 +98,14 @@ class Config: @property def schedule_grid_shape(self) -> tuple[int, int]: """返回流量制度固定时间网格的形状配置。""" - Nu = int(self.get("timegrid_encoding", "n_u_points", default=256)) - Cu = 1 - # 每开启一个附加通道,固定网格的通道数就增加一维。 + n_u_points = int(self.get("timegrid_encoding", "n_u_points", default=256)) + n_channels = 1 + if bool(self.get("timegrid_encoding", "include_cum", default=True)): - Cu += 1 + n_channels += 1 if bool(self.get("timegrid_encoding", "include_dq", default=True)): - Cu += 1 + n_channels += 1 if bool(self.get("timegrid_encoding", "include_shutin", default=False)): - Cu += 1 - return Nu, Cu + n_channels += 1 + + return n_u_points, n_channels diff --git a/ML/nmWTAI-ML/src/data/curve_processing.py b/ML/nmWTAI-ML/src/data/curve_processing.py index 79cbd0f..0459dc2 100644 --- a/ML/nmWTAI-ML/src/data/curve_processing.py +++ b/ML/nmWTAI-ML/src/data/curve_processing.py @@ -9,6 +9,8 @@ C++ 数值试井求解器输出的双对数曲线通常是不等长时间序列 
的维度和顺序必须与 Config.curve_dim、meta.curve_layout 保持一致。 """ +# pylint: disable=import-error,too-many-arguments,too-many-positional-arguments,too-many-return-statements,too-many-locals + from __future__ import annotations from typing import Optional, Tuple @@ -91,6 +93,8 @@ def clean_curve_for_dataset( 操作可能掩盖真实的求解质量问题;更严格的有效性判断交给 is_valid_curve 完成。 """ c = cfg.raw["curve_processing"] + # 该参数保留用于兼容旧调用接口;当前版本只做保守清洗,不做异常值平滑。 + _ = outlier_factor min_len = int(min_len if min_len is not None else c.get("min_valid_points", 30)) eps = float(eps if eps is not None else c.get("feature_epsilon", 1e-12)) @@ -140,7 +144,7 @@ def _make_time_grid(cfg: Config, t0: float, t1: float, n: int) -> np.ndarray: def curve_time_grid_for_sample(cfg: Config, t: np.ndarray, n: Optional[int] = None) -> np.ndarray: """为单个样本生成与曲线特征对齐的时间坐标。""" - nT = int(n if n is not None else cfg.get("curve_processing", "n_time_points", default=160)) + n_time_points = int(n if n is not None else cfg.get("curve_processing", "n_time_points", default=160)) t = np.asarray(t, dtype=np.float64).reshape(-1) if t.size < 2: raise RuntimeError("curve_time_grid_for_sample: time array is too short") @@ -151,7 +155,7 @@ def curve_time_grid_for_sample(cfg: Config, t: np.ndarray, n: Optional[int] = No else: t0 = float(np.min(t)) t1 = float(np.max(t)) - return _make_time_grid(cfg, t0, t1, nT).astype(np.float32) + return _make_time_grid(cfg, t0, t1, n_time_points).astype(np.float32) def resample_curve_to_features_with_time( @@ -170,7 +174,7 @@ def resample_curve_to_features_with_time( [cfg.curve_dim];第二个数组是该样本使用的时间网格,形状为 [n_time_points]。 """ eps = float(cfg.get("curve_processing", "feature_epsilon", default=1e-12)) - nT = int(cfg.get("curve_processing", "n_time_points", default=160)) + n_time_points = int(cfg.get("curve_processing", "n_time_points", default=160)) use_slope = bool(cfg.get("curve_processing", "use_slope_feature", default=True)) t = np.asarray(t, dtype=np.float64).reshape(-1) @@ -180,7 +184,7 @@ def 
resample_curve_to_features_with_time( if t.size < 2: raise RuntimeError("resample_curve_to_features: 时间点过少") - grid = curve_time_grid_for_sample(cfg, t, n=nT).astype(np.float64) + grid = curve_time_grid_for_sample(cfg, t, n=n_time_points).astype(np.float64) logp = np.log(np.maximum(p, eps)) logd = np.log(np.maximum(np.abs(d), eps)) @@ -194,8 +198,8 @@ def resample_curve_to_features_with_time( if use_slope: # slope 作为辅助特征保留;即使训练权重设为 0,也能保持数据布局兼容。 logt = np.log(np.maximum(grid, eps)).astype(np.float64) - s = np.zeros((nT,), dtype=np.float32) - if nT >= 3: + s = np.zeros((n_time_points,), dtype=np.float32) + if n_time_points >= 3: denom = logt[2:] - logt[:-2] denom = np.maximum(denom, 1e-12) s[1:-1] = ((lp[2:] - lp[:-2]) / denom).astype(np.float32) diff --git a/ML/nmWTAI-ML/src/data/dataset_generation.py b/ML/nmWTAI-ML/src/data/dataset_generation.py index db668dc..779afd6 100644 --- a/ML/nmWTAI-ML/src/data/dataset_generation.py +++ b/ML/nmWTAI-ML/src/data/dataset_generation.py @@ -11,6 +11,8 @@ HDF5 流式写入,是构建正演代理模型训练集的主流程。为了支 区域分析和自动拟合候选排序验证。 """ +# pylint: disable=import-error,invalid-name,too-many-locals,too-many-branches,too-many-statements,too-many-return-statements,too-many-instance-attributes,too-many-arguments,too-many-positional-arguments,broad-exception-caught,no-member,global-statement,too-few-public-methods + from __future__ import annotations import json @@ -41,11 +43,11 @@ except ImportError: def update(self, _=1): """空进度条的 update 方法,保持与 tqdm.update 接口一致。""" - pass + return None def close(self): """空进度条的 close 方法,保持与 tqdm.close 接口一致。""" - pass + return None return _NoopTqdm(**kwargs) return iterable @@ -250,6 +252,7 @@ def build_schedule_metadata( def _sample_schedule_fixed_case(cfg: Config, rng) -> Tuple[List[float], List[float], Dict[str, Any]]: """返回配置中固定的流量制度,用于基准案例或可复现实验。""" + _ = rng sc = cfg.raw["schedule"]["case_schedule"] return ( list(map(float, sc["timeQ"])), @@ -360,6 +363,7 @@ def sample_schedule_by_mode(cfg: Config, rng) -> Tuple[List[float], 
List[float], def _resolve_section_indices(cfg: Config, timeQ, q, rng): """解析允许作为 sectionIndex 的分段范围,并裁剪到当前流量制度长度内。""" + _ = q policy = cfg.raw["schedule"]["section_policy"] mode = str(policy["mode"]).lower() n = int(len(timeQ)) @@ -439,6 +443,7 @@ def _worker_simulate_parallel(args): encoding="utf-8", errors="ignore", env=_SUBPROC_ENV, + check=False, ) if result.returncode != 0 or not _RESULT_BIN.exists(): @@ -580,8 +585,8 @@ class HDF5Appender: if not samples: return - B = len(samples) - start, end = self._n, self._n + B + batch_size = len(samples) + start, end = self._n, self._n + batch_size self.d_group_id.resize((end,)) self.d_params.resize((end, self.param_dim)) @@ -591,13 +596,13 @@ class HDF5Appender: self.d_schedule_meta.resize((end, len(SCHEDULE_META_NAMES))) self.d_family_name.resize((end,)) - gid = np.full((B,), -1, dtype=np.int64) - params = np.full((B, self.param_dim), np.nan, dtype=np.float32) - schedule = np.full((B, self.schedule_dim), np.nan, dtype=np.float32) - curve = np.full((B, self.curve_dim), np.nan, dtype=np.float32) - curve_time = np.full((B, self.time_dim), np.nan, dtype=np.float32) - schedule_meta = np.full((B, len(SCHEDULE_META_NAMES)), np.nan, dtype=np.float32) - family_name: list[str] = ["" for _ in range(B)] + gid = np.full((batch_size,), -1, dtype=np.int64) + params = np.full((batch_size, self.param_dim), np.nan, dtype=np.float32) + schedule = np.full((batch_size, self.schedule_dim), np.nan, dtype=np.float32) + curve = np.full((batch_size, self.curve_dim), np.nan, dtype=np.float32) + curve_time = np.full((batch_size, self.time_dim), np.nan, dtype=np.float32) + schedule_meta = np.full((batch_size, len(SCHEDULE_META_NAMES)), np.nan, dtype=np.float32) + family_name: list[str] = ["" for _ in range(batch_size)] for i, s in enumerate(samples): gid[i] = int(s.get("group_id", -1)) @@ -653,6 +658,7 @@ class ParallelDatasetGenerator: stderr=subprocess.PIPE, encoding="utf-8", errors="ignore", + check=False, ) if result.returncode == 0 and 
self.cfg.dataset_bin.exists(): return True @@ -697,8 +703,8 @@ class ParallelDatasetGenerator: filepath = self.output_dir / f"dataset_{mode}_{sec_mode}_target{n_samples}_{timestamp}.h5" curve_dim = cfg.curve_dim - Nu, Cu = cfg.schedule_grid_shape - sched_dim = Nu * Cu + cfg.sec_feat_dim + n_u_points, n_channels = cfg.schedule_grid_shape + sched_dim = n_u_points * n_channels + cfg.sec_feat_dim if bool(cfg.raw.get("schedule", {}).get("use_metadata_features_for_model", False)): sched_dim += len(cfg.raw.get("schedule", {}).get("metadata_features_for_model", []) or []) param_dim = 6 @@ -853,7 +859,7 @@ class ParallelDatasetGenerator: while in_flight: fut = in_flight.popleft() try: - task_idx, sample, fail_reason, fail_ctx = fut.result() + task_idx, sample, fail_reason, _fail_ctx = fut.result() except Exception: fail_reasons["future_exception"] = fail_reasons.get("future_exception", 0) + 1 else: diff --git a/ML/nmWTAI-ML/src/data/param_features.py b/ML/nmWTAI-ML/src/data/param_features.py index a1ebf10..b448918 100644 --- a/ML/nmWTAI-ML/src/data/param_features.py +++ b/ML/nmWTAI-ML/src/data/param_features.py @@ -10,6 +10,8 @@ 恢复原始参数含义。 """ +# pylint: disable=too-many-locals,invalid-name + from __future__ import annotations from typing import Any @@ -187,4 +189,3 @@ def inverse_transform_param_features( else: raise ValueError(f"Unknown transform mode for {name}: {mode}") return out.astype(np.float32) - diff --git a/ML/nmWTAI-ML/src/data/params.py b/ML/nmWTAI-ML/src/data/params.py index 6c3afd9..8507c9d 100644 --- a/ML/nmWTAI-ML/src/data/params.py +++ b/ML/nmWTAI-ML/src/data/params.py @@ -11,6 +11,8 @@ params.bin,因此这里的二进制布局必须与求解器端严格一致。 from __future__ import annotations +# pylint: disable=import-error,invalid-name,too-many-return-statements,too-many-locals,no-member,broad-exception-caught,import-outside-toplevel + import os import struct from dataclasses import dataclass, asdict @@ -104,10 +106,10 @@ class Params: sch = self.schedule.clipped(int(cfg.get("schedule", 
"max_points", default=512))) if not sch.validate(): raise ValueError("Invalid schedule extension") - nQ = len(sch.timeQ) - b += struct.pack(" np.ndarray: if method == "sobol": sampler = qmc.Sobol(d=d, scramble=True, seed=seed) m = int(np.ceil(np.log2(max(n, 1)))) - U = sampler.random_base2(m=m) - return U[:n] + unit_samples = sampler.random_base2(m=m) + return unit_samples[:n] if method == "lhs": sampler = qmc.LatinHypercube(d=d, seed=seed) return sampler.random(n=n) @@ -267,11 +269,11 @@ def generate_params_dataset(cfg: Config, n_samples: int, method: str | None = No method = (method or cfg.raw["params"].get("sampling_method", "sobol")).lower() active_names = list(cfg.raw["params"]["active_param_names"]) log_params = set(cfg.raw["params"]["log_params"]) - U = _qmc_unit(n_samples, len(active_names), method, random_seed) + unit_samples = _qmc_unit(n_samples, len(active_names), method, random_seed) out: list[Params] = [] seen = set() - for row in U: + for row in unit_samples: sampled_vals: Dict[str, float] = {} for i, name in enumerate(active_names): u = float(row[i]) diff --git a/ML/nmWTAI-ML/src/data/runner_client.py b/ML/nmWTAI-ML/src/data/runner_client.py index c684a77..8fe5537 100644 --- a/ML/nmWTAI-ML/src/data/runner_client.py +++ b/ML/nmWTAI-ML/src/data/runner_client.py @@ -9,6 +9,8 @@ 供数据集生成、候选参数评分和调试工具复用。 """ + +# pylint: disable=import-error,too-many-instance-attributes,broad-exception-caught,consider-using-with,invalid-name from __future__ import annotations import os @@ -112,6 +114,7 @@ class CppRunner: encoding="utf-8", errors="ignore", env=self._subproc_env(), + check=False, ) if result.returncode != 0: @@ -234,6 +237,7 @@ class CppRunner: encoding="utf-8", errors="ignore", env=self._subproc_env(), + check=False, ) return bool(result.returncode == 0 and self.result_bin.exists()) @@ -252,22 +256,28 @@ def read_result_bin(result_bin_path: Path) -> Optional[Dict[str, Any]]: if magic != expected_magic or version != 1: return None - nWells, nSteps = 
struct.unpack(" Tuple[np.ndarray, np.ndarray, bool]: +# pylint: disable=too-many-locals +def canonicalize_schedule( + cfg: Config, + time_q: List[float], + q: List[float], +) -> Tuple[np.ndarray, np.ndarray, bool]: """规范化变长流量制度。 该步骤把输入裁剪到求解器允许的最大段数,保证时长为正、流量非负,并按配置合并 相邻且流量近似相同的分段。规范化后再进入固定网格编码,可减少“同一制度因为 无意义细碎分段而得到不同特征”的情况。 """ + max_points = int(cfg.get("schedule", "max_points", default=512)) - dt = np.asarray(list(map(float, timeQ)), dtype=np.float64).reshape(-1)[:max_points] - qq = np.asarray(list(map(float, q)), dtype=np.float64).reshape(-1)[:max_points] + + dt = np.asarray( + list(map(float, time_q)), + dtype=np.float64, + ).reshape(-1)[:max_points] + + qq = np.asarray( + list(map(float, q)), + dtype=np.float64, + ).reshape(-1)[:max_points] dt = np.maximum(dt, 1e-12) qq = np.maximum(qq, 0.0) - cf = cfg.raw["schedule"].get("canonicalize_for_model", {}) or {} - q_thr = float(cf.get("q_thr", 1e-6)) - merge_same_q = bool(cf.get("merge_same_q", True)) - merge_rel_tol = float(cf.get("merge_rel_tol", 1e-4)) - remove_shutin = bool(cf.get("remove_shutin", False)) + canonical_cfg = ( + cfg.raw["schedule"].get("canonicalize_for_model", {}) or {} + ) + + q_threshold = float(canonical_cfg.get("q_thr", 1e-6)) + merge_same_q = bool(canonical_cfg.get("merge_same_q", True)) + merge_rel_tol = float(canonical_cfg.get("merge_rel_tol", 1e-4)) + remove_shutin = bool(canonical_cfg.get("remove_shutin", False)) - has_shutin = bool(np.any(qq <= q_thr)) + has_shutin = bool(np.any(qq <= q_threshold)) if merge_same_q and dt.size >= 2: - # 合并相邻且流量几乎相同的流动段,让编码器关注真正有意义的流量变化。 - new_dt, new_q = [], [] - cur_dt, cur_q = float(dt[0]), float(qq[0]) + new_dt = [] + new_q = [] + + current_dt = float(dt[0]) + current_q = float(qq[0]) - def close(a: float, b: float) -> bool: + def close(value_a: float, value_b: float) -> bool: """判断两个相邻流量是否可视为同一平台段。""" - denom = max(abs(a), abs(b), 1.0) - return abs(a - b) / denom <= merge_rel_tol - for i in range(1, dt.size): - if close(float(qq[i]), cur_q): - 
cur_dt += float(dt[i]) + denominator = max(abs(value_a), abs(value_b), 1.0) + return abs(value_a - value_b) / denominator <= merge_rel_tol + + for idx in range(1, dt.size): + if close(float(qq[idx]), current_q): + current_dt += float(dt[idx]) else: - new_dt.append(cur_dt) - new_q.append(cur_q) - cur_dt = float(dt[i]) - cur_q = float(qq[i]) - new_dt.append(cur_dt) - new_q.append(cur_q) + new_dt.append(current_dt) + new_q.append(current_q) + + current_dt = float(dt[idx]) + current_q = float(qq[idx]) + + new_dt.append(current_dt) + new_q.append(current_q) + dt = np.asarray(new_dt, dtype=np.float64) qq = np.asarray(new_q, dtype=np.float64) if remove_shutin and dt.size >= 2: - m = qq > q_thr - if int(np.sum(m)) >= 2: - dt = dt[m] - qq = qq[m] + valid_mask = qq > q_threshold + if int(np.sum(valid_mask)) >= 2: + dt = dt[valid_mask] + qq = qq[valid_mask] return dt, qq, has_shutin -def _make_u_grid(cfg: Config, T_total: float, Nu: int, mode: str) -> np.ndarray: +def _make_u_grid( + cfg: Config, + total_time: float, + n_u_points: int, + mode: str, +) -> np.ndarray: """生成 [0, 1] 上的固定网格,用于把变长流量制度编码成定长向量。""" - t_min = float((cfg.raw["schedule"].get("obs_window", {}) or {}).get("t_min", 1e-6)) - T_total = max(float(T_total), t_min * 2.0) + + obs_window_cfg = cfg.raw["schedule"].get("obs_window", {}) or {} + + t_min = float(obs_window_cfg.get("t_min", 1e-6)) + + total_time = max(float(total_time), t_min * 2.0) + if mode == "linear": - return np.linspace(t_min, T_total, Nu, dtype=np.float64) - return np.geomspace(t_min, T_total, Nu, dtype=np.float64) + return np.linspace( + t_min, + total_time, + n_u_points, + dtype=np.float64, + ) + + return np.geomspace( + t_min, + total_time, + n_u_points, + dtype=np.float64, + ) +# pylint: disable=too-many-locals,too-many-statements def encode_schedule_to_timegrid( cfg: Config, - sectionIndex: int, - timeQ: List[float], + section_index: int, + time_q: List[float], q: List[float], n_sections: Optional[int] = None, ) -> EncodedSchedule: @@ 
-105,84 +155,246 @@ def encode_schedule_to_timegrid( 流量积分、相邻网格流量差分、关井标记等通道。最后再拼接 5 个 section 级特征, 让模型知道当前样本对应哪一个生产/关井分段,而不只看到整条制度形状。 """ - dt, qq, has_shutin = canonicalize_schedule(cfg, timeQ, q) + + dt, qq, has_shutin = canonicalize_schedule( + cfg, + time_q, + q, + ) + if dt.size < 2 or qq.size != dt.size: raise ValueError("invalid schedule after canonicalize") - sec = int(max(1, sectionIndex)) - N = int(n_sections) if n_sections is not None else int(dt.size) - N = max(N, int(dt.size)) - - T_total = float(np.sum(dt)) - enc = cfg.raw["timegrid_encoding"] - Nu = int(enc.get("n_u_points", 256)) - grid_mode = str(enc.get("grid", "log")).lower() - include_cum = bool(enc.get("include_cum", True)) - include_dq = bool(enc.get("include_dq", True)) - include_shutin = bool(enc.get("include_shutin", False)) - q_eps = float(enc.get("q_eps", 1e-12)) - - t_edges = np.concatenate([[0.0], np.cumsum(dt)], axis=0) - u = _make_u_grid(cfg, T_total=T_total, Nu=Nu, mode=("linear" if grid_mode == "linear" else "log")) - idx = np.searchsorted(t_edges[1:], u, side="right") - idx = np.clip(idx, 0, dt.size - 1) - q_u = qq[idx].astype(np.float64) - - # q(t) 是主通道;cum(t)、dq(t) 和关井标志是可选辅助通道, - # 用于帮助代理模型区分不同流量制度形态。 + section_index = int(max(1, section_index)) + + n_sections_total = ( + int(n_sections) + if n_sections is not None + else int(dt.size) + ) + + n_sections_total = max(n_sections_total, int(dt.size)) + + total_time = float(np.sum(dt)) + + encoding_cfg = cfg.raw["timegrid_encoding"] + + n_u_points = int(encoding_cfg.get("n_u_points", 256)) + + grid_mode = str( + encoding_cfg.get("grid", "log"), + ).lower() + + include_cum = bool( + encoding_cfg.get("include_cum", True), + ) + + include_dq = bool( + encoding_cfg.get("include_dq", True), + ) + + include_shutin = bool( + encoding_cfg.get("include_shutin", False), + ) + + q_eps = float( + encoding_cfg.get("q_eps", 1e-12), + ) + + time_edges = np.concatenate( + [[0.0], np.cumsum(dt)], + axis=0, + ) + + u_grid = _make_u_grid( + cfg, + 
total_time=total_time, + n_u_points=n_u_points, + mode=( + "linear" + if grid_mode == "linear" + else "log" + ), + ) + + interval_idx = np.searchsorted( + time_edges[1:], + u_grid, + side="right", + ) + + interval_idx = np.clip( + interval_idx, + 0, + dt.size - 1, + ) + + q_u = qq[interval_idx].astype(np.float64) + cum_u = None if include_cum: - prefix = np.concatenate([[0.0], np.cumsum(dt * qq)], axis=0) - cum_u = prefix[idx] + (u - t_edges[:-1][idx]) * qq[idx] + prefix = np.concatenate( + [[0.0], np.cumsum(dt * qq)], + axis=0, + ) + + cum_u = ( + prefix[interval_idx] + + ( + u_grid + - time_edges[:-1][interval_idx] + ) + * qq[interval_idx] + ) dq_u = None if include_dq: - dq_u = np.zeros_like(u, dtype=np.float64) + dq_u = np.zeros_like( + u_grid, + dtype=np.float64, + ) + dq_u[1:] = np.diff(q_u) - dq_u[0] = dq_u[1] if dq_u.size > 1 else 0.0 + + dq_u[0] = ( + dq_u[1] + if dq_u.size > 1 + else 0.0 + ) shut_u = None if include_shutin: - thr = float(enc.get("shutin_thr", 1e-6)) - shut_u = (q_u <= thr).astype(np.float64) - - norm_mode = str(enc.get("normalize_mode", "global")).lower() - if norm_mode == "per_sample": - q_scale = max(float(np.max(q_u)), q_eps) - cum_scale = max(float(cum_u.max()) if cum_u is not None else 1.0, q_eps) + shutin_thr = float( + encoding_cfg.get("shutin_thr", 1e-6), + ) + + shut_u = ( + q_u <= shutin_thr + ).astype(np.float64) + + normalize_mode = str( + encoding_cfg.get( + "normalize_mode", + "global", + ) + ).lower() + + if normalize_mode == "per_sample": + q_scale = max( + float(np.max(q_u)), + q_eps, + ) + + cum_scale = max( + ( + float(cum_u.max()) + if cum_u is not None + else 1.0 + ), + q_eps, + ) else: - q_scale = float(enc.get("q_global_max", 1.0)) - cum_scale = float(enc.get("cum_global_max", 1.0)) + q_scale = float( + encoding_cfg.get( + "q_global_max", + 1.0, + ) + ) + + cum_scale = float( + encoding_cfg.get( + "cum_global_max", + 1.0, + ) + ) q_u = q_u / max(q_scale, q_eps) + if cum_u is not None: - cum_u = cum_u / 
max(cum_scale, q_eps) + cum_u = cum_u / max( + cum_scale, + q_eps, + ) + if dq_u is not None: - dq_u = dq_u / max(q_scale, q_eps) + dq_u = dq_u / max( + q_scale, + q_eps, + ) + + channels = [q_u] - chans = [q_u] if include_cum: - chans.append(cum_u if cum_u is not None else np.zeros_like(q_u)) + channels.append( + cum_u + if cum_u is not None + else np.zeros_like(q_u) + ) + if include_dq: - chans.append(dq_u if dq_u is not None else np.zeros_like(q_u)) + channels.append( + dq_u + if dq_u is not None + else np.zeros_like(q_u) + ) + if include_shutin: - chans.append(shut_u if shut_u is not None else np.zeros_like(q_u)) + channels.append( + shut_u + if shut_u is not None + else np.zeros_like(q_u) + ) + + x_grid = np.stack( + channels, + axis=1, + ) - X = np.stack(chans, axis=1) - x_sched = X.reshape(-1).astype(np.float32) + x_sched = x_grid.reshape(-1).astype(np.float32) - # 时间网格之外额外拼接 5 个流动段级别特征, - # 用于告诉模型当前预测的是哪个流动段对应的双对数曲线。 - n_sec = int(max(1, N)) - sec_clamped = int(np.clip(sec, 1, n_sec)) - section_pos = float((sec_clamped - 1) / max(n_sec - 1, 1)) - sectionIndex_norm = float(sec_clamped / max(n_sec, 1)) - n_sections_norm = float(n_sec / max(12, 1)) - log_T_total = float(np.log(max(T_total, 1e-12))) + n_sections_total = int( + max(1, n_sections_total), + ) + + section_clamped = int( + np.clip( + section_index, + 1, + n_sections_total, + ) + ) + + section_pos = float( + (section_clamped - 1) + / max(n_sections_total - 1, 1) + ) + + section_index_norm = float( + section_clamped + / max(n_sections_total, 1) + ) + + n_sections_norm = float( + n_sections_total + / max(12, 1) + ) + + log_total_time = float( + np.log(max(total_time, 1e-12)) + ) x_sec = np.array( - [sectionIndex_norm, section_pos, n_sections_norm, float(has_shutin), log_T_total], + [ + section_index_norm, + section_pos, + n_sections_norm, + float(has_shutin), + log_total_time, + ], dtype=np.float32, ) - return EncodedSchedule(x_sched=x_sched, x_sec=x_sec) + return EncodedSchedule( + x_sched=x_sched, + 
x_sec=x_sec, + ) diff --git a/ML/nmWTAI-ML/src/data/schedule_features.py b/ML/nmWTAI-ML/src/data/schedule_features.py index 40bc4b6..87e9055 100644 --- a/ML/nmWTAI-ML/src/data/schedule_features.py +++ b/ML/nmWTAI-ML/src/data/schedule_features.py @@ -9,6 +9,8 @@ 或 sectionIndex 分组分析模型误差。 """ +# pylint: disable=too-many-locals,import-error + from __future__ import annotations from typing import List diff --git a/ML/nmWTAI-ML/src/models/forward_surrogate.py b/ML/nmWTAI-ML/src/models/forward_surrogate.py index a5117cc..c007681 100644 --- a/ML/nmWTAI-ML/src/models/forward_surrogate.py +++ b/ML/nmWTAI-ML/src/models/forward_surrogate.py @@ -12,10 +12,28 @@ ForwardSurrogate 输入标准化后的物理参数特征和可选的流量制度 from __future__ import annotations +# pylint: disable=import-error,duplicate-code,too-many-arguments,too-many-positional-arguments + +from dataclasses import dataclass, field + import torch -import torch.nn as nn +from torch import nn + +@dataclass(slots=True) +class ForwardSurrogateConfig: + """ForwardSurrogate 的结构配置。 + + 使用配置对象可以避免模型构造函数参数过多,同时让训练脚本中的超参数更集中。 + """ + param_dim: int + schedule_dim: int + curve_dim: int + hidden_dim: int = 128 + fusion_hidden_dims: list[int] = field(default_factory=lambda: [256, 256]) + dropout: float = 0.0 + use_schedule: bool = True def build_mlp( @@ -26,14 +44,15 @@ def build_mlp( ) -> nn.Sequential: """按隐藏层列表搭建 Linear-ReLU-Dropout 组成的多层感知机。""" layers: list[nn.Module] = [] - prev = in_dim - for h in hidden_dims: - layers.append(nn.Linear(prev, h)) + prev_dim = in_dim + for hidden_dim in hidden_dims: + layers.append(nn.Linear(prev_dim, hidden_dim)) layers.append(nn.ReLU()) if dropout > 0: layers.append(nn.Dropout(dropout)) - prev = h - layers.append(nn.Linear(prev, out_dim)) + prev_dim = hidden_dim + + layers.append(nn.Linear(prev_dim, out_dim)) return nn.Sequential(*layers) @@ -53,7 +72,6 @@ class ScheduleEncoder(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """把流量制度统计特征映射到与参数分支同宽度的隐藏表示。""" - # 该分支只处理制度向量,便于后续与地层参数特征拼接融合。 return 
self.net(x) @@ -73,7 +91,6 @@ class ParamEncoder(nn.Module): def forward(self, x: torch.Tensor) -> torch.Tensor: """把变换后的地层和井筒参数映射为隐藏表示。""" - # 参数特征通常来自 log/asinh 等尺度变换,先编码再与制度分支融合。 return self.net(x) @@ -82,19 +99,21 @@ class ForwardSurrogate(nn.Module): 输入: params_x: 标准化后的物理参数特征,形状 [B, param_dim]。 - schedule_x: 标准化后的流量制度向量,形状 [B, schedule_dim];当 use_schedule=False - 时该输入可为空。 + schedule_x: 标准化后的流量制度向量,形状 [B, schedule_dim]; + 当 use_schedule=False 时该输入可为空。 输出: - curve_pred: 形状 [B, curve_dim],按 log_pressure、log_derivative、slope 三段 - 顺序拼接。curve_dim 必须能被 3 整除,以便每段拥有相同时间点数。 + curve_pred: 形状 [B, curve_dim],按 log_pressure、log_derivative、slope + 三段顺序拼接。curve_dim 必须能被 3 整除,以便每段拥有相同时间点数。 """ def __init__( self, - param_dim: int, - schedule_dim: int, - curve_dim: int, + config: ForwardSurrogateConfig | None = None, + *, + param_dim: int | None = None, + schedule_dim: int | None = None, + curve_dim: int | None = None, hidden_dim: int = 128, fusion_hidden_dims: list[int] | None = None, dropout: float = 0.0, @@ -102,71 +121,168 @@ class ForwardSurrogate(nn.Module): ): """构建参数分支、可选流量制度分支、融合主干和三组曲线输出头。""" super().__init__() + config = self._coerce_config( + config=config, + param_dim=param_dim, + schedule_dim=schedule_dim, + curve_dim=curve_dim, + hidden_dim=hidden_dim, + fusion_hidden_dims=fusion_hidden_dims, + dropout=dropout, + use_schedule=use_schedule, + ) + self._validate_config(config) + + self.config = config + self.encoders = self._build_encoders() + self.trunk = self._build_trunk() + self.heads = self._build_heads() + + @staticmethod + def _coerce_config( + config: ForwardSurrogateConfig | None, + *, + param_dim: int | None, + schedule_dim: int | None, + curve_dim: int | None, + hidden_dim: int, + fusion_hidden_dims: list[int] | None, + dropout: float, + use_schedule: bool, + ) -> ForwardSurrogateConfig: + """兼容配置对象式构造和旧版关键字参数式构造。""" + if config is not None: + return config + if param_dim is None or schedule_dim is None or curve_dim is None: + raise TypeError( + 
"ForwardSurrogate requires either a ForwardSurrogateConfig or " + "param_dim, schedule_dim and curve_dim keyword arguments" + ) + return ForwardSurrogateConfig( + param_dim=int(param_dim), + schedule_dim=int(schedule_dim), + curve_dim=int(curve_dim), + hidden_dim=int(hidden_dim), + fusion_hidden_dims=fusion_hidden_dims or [256, 256], + dropout=float(dropout), + use_schedule=bool(use_schedule), + ) - if curve_dim % 3 != 0: - raise ValueError(f"curve_dim={curve_dim} 不能被 3 整除;期望为 pressure/derivative/slope 三段") + @property + def curve_dim(self) -> int: + """曲线拼接后的总维度。""" + return self.config.curve_dim - if fusion_hidden_dims is None: - fusion_hidden_dims = [256, 256] + @property + def part_dim(self) -> int: + """压力、导数和 slope 每一段的时间点数量。""" + return self.config.curve_dim // 3 - self.curve_dim = curve_dim - self.part_dim = curve_dim // 3 - self.use_schedule = bool(use_schedule) + @property + def use_schedule(self) -> bool: + """是否启用流量制度分支。""" + return bool(self.config.use_schedule) - # 参数和流量制度的物理含义与尺度差异较大,因此采用两个分支分别编码。 - self.param_encoder = ParamEncoder(param_dim, hidden_dim, dropout=dropout) + @staticmethod + def _validate_config(config: ForwardSurrogateConfig) -> None: + """检查模型配置是否满足网络结构约束。""" + if config.curve_dim % 3 != 0: + msg = ( + f"curve_dim={config.curve_dim} 不能被 3 整除;" + "期望为 pressure/derivative/slope 三段" + ) + raise ValueError(msg) + + if not config.fusion_hidden_dims: + raise ValueError("fusion_hidden_dims 不能为空") + + def _build_encoders(self) -> nn.ModuleDict: + """构建参数分支和可选流量制度分支。""" + encoders = nn.ModuleDict( + { + "param": ParamEncoder( + self.config.param_dim, + self.config.hidden_dim, + dropout=self.config.dropout, + ) + } + ) if self.use_schedule: - self.schedule_encoder = ScheduleEncoder(schedule_dim, hidden_dim, dropout=dropout) - trunk_in_dim = hidden_dim * 2 - else: - self.schedule_encoder = None - trunk_in_dim = hidden_dim - - trunk_out_dim = fusion_hidden_dims[-1] - self.trunk = build_mlp( + encoders["schedule"] = ScheduleEncoder( + 
self.config.schedule_dim, + self.config.hidden_dim, + dropout=self.config.dropout, + ) + + return encoders + + def _build_trunk(self) -> nn.Sequential: + """构建融合主干网络。""" + trunk_in_dim = self.config.hidden_dim * 2 if self.use_schedule else self.config.hidden_dim + trunk_out_dim = self.config.fusion_hidden_dims[-1] + return build_mlp( in_dim=trunk_in_dim, - hidden_dims=fusion_hidden_dims, + hidden_dims=self.config.fusion_hidden_dims, out_dim=trunk_out_dim, - dropout=dropout, + dropout=self.config.dropout, ) - # 压力曲线拆成 level + centered shape: - # level 学习整体纵向偏移,shape 学习局部曲线形态。 - self.pressure_level_head = build_mlp( - in_dim=trunk_out_dim, - hidden_dims=[128], - out_dim=1, - dropout=dropout, - ) - self.pressure_shape_head = build_mlp( - in_dim=trunk_out_dim, - hidden_dims=[128], - out_dim=self.part_dim, - dropout=dropout, + def _build_heads(self) -> nn.ModuleDict: + """构建压力、导数和 slope 输出头。""" + trunk_out_dim = self.config.fusion_hidden_dims[-1] + return nn.ModuleDict( + { + "pressure_level": self._build_single_head(trunk_out_dim, 1), + "pressure_shape": self._build_single_head(trunk_out_dim, self.part_dim), + "derivative_level": self._build_single_head(trunk_out_dim, 1), + "derivative_shape": self._build_single_head(trunk_out_dim, self.part_dim), + "slope": self._build_single_head(trunk_out_dim, self.part_dim), + } ) - # 导数曲线同样拆分为 level + shape,因为平台、谷值和过渡段 - # 对自动拟合筛选非常重要。 - self.derivative_level_head = build_mlp( - in_dim=trunk_out_dim, + def _build_single_head(self, in_dim: int, out_dim: int) -> nn.Sequential: + """构建一个曲线输出头。""" + return build_mlp( + in_dim=in_dim, hidden_dims=[128], - out_dim=1, - dropout=dropout, - ) - self.derivative_shape_head = build_mlp( - in_dim=trunk_out_dim, - hidden_dims=[128], - out_dim=self.part_dim, - dropout=dropout, + out_dim=out_dim, + dropout=self.config.dropout, ) - # slope 是辅助输出,主要用于保持数据布局兼容。 - self.slope_head = build_mlp( - in_dim=trunk_out_dim, - hidden_dims=[128], - out_dim=self.part_dim, - dropout=dropout, + @staticmethod + 
def _upgrade_state_dict_keys(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """把旧版 checkpoint 键名转换为当前 ModuleDict 结构键名。""" + prefix_map = { + "param_encoder.": "encoders.param.", + "schedule_encoder.": "encoders.schedule.", + "pressure_level_head.": "heads.pressure_level.", + "pressure_shape_head.": "heads.pressure_shape.", + "derivative_level_head.": "heads.derivative_level.", + "derivative_shape_head.": "heads.derivative_shape.", + "slope_head.": "heads.slope.", + } + upgraded: dict[str, torch.Tensor] = {} + for key, value in state_dict.items(): + new_key = key + for old_prefix, new_prefix in prefix_map.items(): + if key.startswith(old_prefix): + new_key = new_prefix + key[len(old_prefix) :] + break + upgraded[new_key] = value + return upgraded + + def load_state_dict( + self, + state_dict: dict[str, torch.Tensor], + strict: bool = True, + assign: bool = False, + ): + """加载当前或旧版 ForwardSurrogate checkpoint。""" + return super().load_state_dict( + self._upgrade_state_dict_keys(state_dict), + strict=strict, + assign=assign, ) @staticmethod @@ -174,40 +290,58 @@ class ForwardSurrogate(nn.Module): """去除每个样本 shape 分支的均值,让 level 分支专门学习整体偏移。""" return x - x.mean(dim=1, keepdim=True) - def forward(self, params_x: torch.Tensor, schedule_x: torch.Tensor | None = None) -> torch.Tensor: + def _encode_features( + self, + params_x: torch.Tensor, + schedule_x: torch.Tensor | None, + ) -> torch.Tensor: + """分别编码物理参数和流量制度,再在隐空间融合。""" + param_feat = self.encoders["param"](params_x) + + if not self.use_schedule: + return param_feat + + if schedule_x is None: + raise ValueError("use_schedule=True,但 forward 没有传入 schedule_x") + + schedule_feat = self.encoders["schedule"](schedule_x) + return torch.cat([param_feat, schedule_feat], dim=-1) + + def _predict_level_shape( + self, + trunk_feat: torch.Tensor, + level_head: str, + shape_head: str, + ) -> torch.Tensor: + """用 level + centered shape 生成一段曲线。""" + level = self.heads[level_head](trunk_feat) + shape = 
self.center_shape(self.heads[shape_head](trunk_feat)) + return level + shape + + def forward( + self, + params_x: torch.Tensor, + schedule_x: torch.Tensor | None = None, + ) -> torch.Tensor: """执行一次前向预测。 参数分支和流量制度分支先分别编码,再在隐空间拼接。融合主干提取共同特征后, 压力和导数各自通过 level + centered shape 两个输出头生成;slope 作为辅助通道 直接由单独输出头预测。返回值仍保持预处理阶段约定的曲线拼接布局。 """ - p = self.param_encoder(params_x) + fused_feat = self._encode_features(params_x, schedule_x) + trunk_feat = self.trunk(fused_feat) - if self.use_schedule: - if schedule_x is None: - raise ValueError("use_schedule=True,但 forward 没有传入 schedule_x") - s = self.schedule_encoder(schedule_x) - # 两个分支在隐藏空间拼接,避免直接混合量纲差异很大的原始特征。 - fused = torch.cat([p, s], dim=-1) - else: - fused = p - - # trunk 负责学习参数-制度共同决定的曲线整体形态。 - trunk_feat = self.trunk(fused) - - pressure_level = self.pressure_level_head(trunk_feat) # [B, 1] - pressure_shape = self.pressure_shape_head(trunk_feat) # [B, T] - # shape 去均值后只表达相对形态,纵向偏移交给 level 分支学习。 - pressure_shape = self.center_shape(pressure_shape) - pressure_pred = pressure_level + pressure_shape - - derivative_level = self.derivative_level_head(trunk_feat) # [B, 1] - derivative_shape = self.derivative_shape_head(trunk_feat) # [B, T] - # 导数也采用 level + shape,减少平台值和局部过渡段之间的相互牵制。 - derivative_shape = self.center_shape(derivative_shape) - derivative_pred = derivative_level + derivative_shape - - slope_pred = self.slope_head(trunk_feat) # [B, T] - - curve_pred = torch.cat([pressure_pred, derivative_pred, slope_pred], dim=1) - return curve_pred + pressure_pred = self._predict_level_shape( + trunk_feat, + "pressure_level", + "pressure_shape", + ) + derivative_pred = self._predict_level_shape( + trunk_feat, + "derivative_level", + "derivative_shape", + ) + slope_pred = self.heads["slope"](trunk_feat) + + return torch.cat([pressure_pred, derivative_pred, slope_pred], dim=1) diff --git a/ML/nmWTAI-ML/src/models/time_conditioned_surrogate.py b/ML/nmWTAI-ML/src/models/time_conditioned_surrogate.py index 8b84249..cde54e4 100644 --- 
a/ML/nmWTAI-ML/src/models/time_conditioned_surrogate.py +++ b/ML/nmWTAI-ML/src/models/time_conditioned_surrogate.py @@ -11,12 +11,30 @@ TimeConditionedSurrogate 不一次性输出完整曲线,而是把“物理参 from __future__ import annotations +# pylint: disable=import-error,duplicate-code,too-many-arguments,too-many-positional-arguments + +from dataclasses import dataclass + import torch -import torch.nn as nn +from torch import nn + + +@dataclass(frozen=True) +class TimeConditionedSurrogateConfig: + """时间条件代理模型配置。""" + + param_dim: int + schedule_dim: int + time_dim: int + hidden_dim: int = 256 + n_blocks: int = 4 + dropout: float = 0.05 + use_schedule: bool = True class ResidualBlock(nn.Module): """时间条件模型使用的全连接残差块,用于在较深网络中稳定传播特征。""" + def __init__(self, dim: int, dropout: float = 0.0): """构造两层全连接残差块,并根据 dropout 参数决定是否启用随机失活。""" super().__init__() @@ -45,49 +63,97 @@ class TimeConditionedSurrogate(nn.Module): def __init__( self, - param_dim: int, - schedule_dim: int, - time_dim: int, + config: TimeConditionedSurrogateConfig | None = None, + *, + param_dim: int | None = None, + schedule_dim: int | None = None, + time_dim: int | None = None, hidden_dim: int = 256, n_blocks: int = 4, dropout: float = 0.05, use_schedule: bool = True, ): - """按输入维度和隐藏层宽度组装时间条件代理模型的编码、融合和输出层。""" + """按配置组装时间条件代理模型的编码、融合和输出层。""" super().__init__() - self.use_schedule = bool(use_schedule) - - self.param_encoder = nn.Sequential( - nn.Linear(param_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU(), - ) - self.time_encoder = nn.Sequential( - nn.Linear(time_dim, hidden_dim // 2), - nn.LayerNorm(hidden_dim // 2), - nn.GELU(), + config = self._coerce_config( + config=config, + param_dim=param_dim, + schedule_dim=schedule_dim, + time_dim=time_dim, + hidden_dim=hidden_dim, + n_blocks=n_blocks, + dropout=dropout, + use_schedule=use_schedule, ) + + hidden_dim = int(config.hidden_dim) + time_hidden_dim = hidden_dim // 2 + self.use_schedule = bool(config.use_schedule) + + self.param_encoder = 
self._build_encoder(config.param_dim, hidden_dim) + self.time_encoder = self._build_encoder(config.time_dim, time_hidden_dim) + if self.use_schedule: - self.schedule_encoder = nn.Sequential( - nn.Linear(schedule_dim, hidden_dim), - nn.LayerNorm(hidden_dim), - nn.GELU(), - ) - fusion_dim = hidden_dim * 2 + hidden_dim // 2 + self.schedule_encoder = self._build_encoder(config.schedule_dim, hidden_dim) + fusion_dim = hidden_dim * 2 + time_hidden_dim else: self.schedule_encoder = None - fusion_dim = hidden_dim + hidden_dim // 2 + fusion_dim = hidden_dim + time_hidden_dim self.input_proj = nn.Sequential( nn.Linear(fusion_dim, hidden_dim), nn.GELU(), ) - self.blocks = nn.Sequential(*[ResidualBlock(hidden_dim, dropout=dropout) for _ in range(int(n_blocks))]) + self.blocks = nn.Sequential( + *[ + ResidualBlock(hidden_dim, dropout=config.dropout) + for _ in range(int(config.n_blocks)) + ] + ) self.head = nn.Sequential( nn.LayerNorm(hidden_dim), - nn.Linear(hidden_dim, hidden_dim // 2), + nn.Linear(hidden_dim, time_hidden_dim), + nn.GELU(), + nn.Linear(time_hidden_dim, 2), + ) + + @staticmethod + def _coerce_config( + config: TimeConditionedSurrogateConfig | None, + *, + param_dim: int | None, + schedule_dim: int | None, + time_dim: int | None, + hidden_dim: int, + n_blocks: int, + dropout: float, + use_schedule: bool, + ) -> TimeConditionedSurrogateConfig: + """兼容配置对象式构造和旧版关键字参数式构造。""" + if config is not None: + return config + if param_dim is None or schedule_dim is None or time_dim is None: + raise TypeError( + "TimeConditionedSurrogate requires either a TimeConditionedSurrogateConfig " + "or param_dim, schedule_dim and time_dim keyword arguments" + ) + return TimeConditionedSurrogateConfig( + param_dim=int(param_dim), + schedule_dim=int(schedule_dim), + time_dim=int(time_dim), + hidden_dim=int(hidden_dim), + n_blocks=int(n_blocks), + dropout=float(dropout), + use_schedule=bool(use_schedule), + ) + + @staticmethod + def _build_encoder(in_dim: int, out_dim: int) -> 
nn.Sequential: + """构建 Linear-LayerNorm-GELU 编码器。""" + return nn.Sequential( + nn.Linear(in_dim, out_dim), + nn.LayerNorm(out_dim), nn.GELU(), - nn.Linear(hidden_dim // 2, 2), ) def forward( @@ -103,16 +169,25 @@ class TimeConditionedSurrogate(nn.Module): log_pressure 和 log_derivative。 """ # params_x 和 schedule_x 是样本级特征;time_x 是展开后的点级特征。 - p = self.param_encoder(params_x) - t = self.time_encoder(time_x) + param_feat = self.param_encoder(params_x) + time_feat = self.time_encoder(time_x) + if self.use_schedule: if schedule_x is None: - raise ValueError("use_schedule=True but schedule_x is None") - s = self.schedule_encoder(schedule_x) - x = torch.cat([p, s, t], dim=-1) + raise ValueError("use_schedule=True,但 forward 没有传入 schedule_x") + schedule_feat = self.schedule_encoder(schedule_x) + fused = torch.cat([param_feat, schedule_feat, time_feat], dim=-1) else: - x = torch.cat([p, t], dim=-1) + fused = torch.cat([param_feat, time_feat], dim=-1) + # 融合后的特征经过残差主干,输出 log_pressure 和 log_derivative 两个通道。 - x = self.input_proj(x) - x = self.blocks(x) - return self.head(x) + hidden = self.input_proj(fused) + hidden = self.blocks(hidden) + return self.head(hidden) + + +def build_time_conditioned_surrogate( + config: TimeConditionedSurrogateConfig, +) -> TimeConditionedSurrogate: + """根据配置创建时间条件代理模型。""" + return TimeConditionedSurrogate(config) diff --git a/ML/nmWTAI-ML/src/training/train_forward.py b/ML/nmWTAI-ML/src/training/train_forward.py index c6a1e1d..d30b3ac 100644 --- a/ML/nmWTAI-ML/src/training/train_forward.py +++ b/ML/nmWTAI-ML/src/training/train_forward.py @@ -10,22 +10,39 @@ ForwardSurrogate,并按验证集损失保存最佳 checkpoint。损失函数 脚本可以根据 checkpoint 中保存的维度、curve_layout 和损失权重恢复模型。 """ +# pylint: disable=import-error,duplicate-code,too-many-instance-attributes +# pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals + from __future__ import annotations import json import random -from dataclasses import dataclass +from dataclasses import asdict, 
dataclass, field from pathlib import Path +from typing import Any import joblib import numpy as np import torch -import torch.nn as nn +from torch import nn from torch.utils.data import DataLoader, Dataset -from src.models.forward_surrogate import ForwardSurrogate +from src.models.forward_surrogate import ForwardSurrogate, ForwardSurrogateConfig +METRIC_KEYS = ( + "loss", + "loss_pressure", + "loss_derivative", + "loss_slope", + "loss_bias_pressure", + "loss_bias_derivative", + "loss_derivative_shape", + "loss_autofit_pressure", + "loss_autofit_derivative", + "sample_weight_mean", + "sample_weight_max", +) class ForwardDataset(Dataset): @@ -47,42 +64,120 @@ class ForwardDataset(Dataset): @dataclass -class TrainConfig: - """正演代理模型训练配置,包括优化器参数、损失权重、重加权策略和设备。""" +class ModelConfig: + """模型结构相关配置。""" + + hidden_dim: int = 128 + dropout: float = 0.0 + use_schedule: bool = True + + +@dataclass +class OptimConfig: + """优化器与训练轮次配置。""" - processed_path: Path - output_dir: Path - seed: int = 42 batch_size: int = 128 epochs: int = 100 lr: float = 1e-3 weight_decay: float = 1e-5 - hidden_dim: int = 128 - dropout: float = 0.0 - w_pressure: float = 1.0 - w_derivative: float = 2.0 - w_slope: float = 0.0 - w_bias_pressure: float = 0.15 - w_bias_derivative: float = 0.05 - w_derivative_shape: float = 0.10 - w_autofit_pressure: float = 0.0 - w_autofit_derivative: float = 0.0 +@dataclass +class LossWeights: + """复合损失的各项权重。""" + + pressure: float = 1.0 + derivative: float = 2.0 + slope: float = 0.0 + bias_pressure: float = 0.15 + bias_derivative: float = 0.05 + derivative_shape: float = 0.10 + autofit_pressure: float = 0.0 + autofit_derivative: float = 0.0 + + +@dataclass +class LossConfig: + """损失函数配置。""" + weights: LossWeights = field(default_factory=LossWeights) use_huber: bool = True huber_beta: float = 0.05 - use_sample_reweight: bool = True - sample_reweight_alpha: float = 0.4 - sample_weight_min: float = 1.0 - sample_weight_max: float = 2.5 - use_schedule: bool = True 
+@dataclass +class SampleReweightConfig: + """样本重加权配置。""" + + enabled: bool = True + alpha: float = 0.4 + weight_min: float = 1.0 + weight_max: float = 2.5 + + +@dataclass +class TrainRuntime: + """训练运行时配置。""" + seed: int = 42 device: str = "cuda" if torch.cuda.is_available() else "cpu" +@dataclass +class TrainConfig: + """正演代理模型训练配置。""" + + processed_path: Path + output_dir: Path + runtime: TrainRuntime = field(default_factory=TrainRuntime) + optim: OptimConfig = field(default_factory=OptimConfig) + model: ModelConfig = field(default_factory=ModelConfig) + loss: LossConfig = field(default_factory=LossConfig) + sample_reweight: SampleReweightConfig = field(default_factory=SampleReweightConfig) + + +@dataclass +class CurveStats: + """曲线 scaler 的 torch 形式统计量。""" + + mean_raw: torch.Tensor + scale_raw: torch.Tensor + + +@dataclass +class LossBatchParts: + """曲线三段的预测值和真实值。""" + + pred_p: torch.Tensor + pred_d: torch.Tensor + pred_s: torch.Tensor + true_p: torch.Tensor + true_d: torch.Tensor + true_s: torch.Tensor + + +@dataclass +class LossContext: + """计算复合损失所需的上下文。""" + + slices: dict[str, slice] + curve_stats: CurveStats + loss_cfg: LossConfig + reweight_cfg: SampleReweightConfig + + +@dataclass +class DatasetBundle: + """训练、验证、测试 DataLoader 与数据维度。""" + + train_loader: DataLoader + val_loader: DataLoader + test_loader: DataLoader + param_dim: int + schedule_dim: int + curve_dim: int + + def set_global_seed(seed: int) -> None: """设置 Python、NumPy 和 PyTorch 随机种子,并在 CUDA 可用时同步设置 GPU 随机种子。""" random.seed(seed) @@ -100,9 +195,9 @@ def load_processed_dataset(path: Path) -> dict: "X_params_val", "X_schedule_val", "Y_curve_val", "X_params_test", "X_schedule_test", "Y_curve_test", ] - for k in required_keys: - if k not in data: - raise KeyError(f"processed dataset 缺少字段: {k}") + for key in required_keys: + if key not in data: + raise KeyError(f"processed dataset 缺少字段: {key}") return data @@ -134,20 +229,14 @@ def get_part_slices(curve_layout: dict) -> dict[str, slice]: out: 
dict[str, slice] = {} for part in curve_layout["parts"]: name = str(part["name"]) - start = int(part["start"]) - end = int(part["end"]) - out[name] = slice(start, end) + out[name] = slice(int(part["start"]), int(part["end"])) return out def smooth_l1_per_sample(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor: """按样本计算 Smooth L1 损失,返回每个样本一个损失值。""" diff = torch.abs(pred - target) - loss = torch.where( - diff < beta, - 0.5 * diff * diff / beta, - diff - 0.5 * beta, - ) + loss = torch.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta) return loss.mean(dim=1) @@ -161,6 +250,17 @@ def l1_per_sample(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: return torch.abs(pred - target).mean(dim=1) +def regression_per_sample( + pred: torch.Tensor, + target: torch.Tensor, + loss_cfg: LossConfig, +) -> torch.Tensor: + """按配置在 Smooth L1 和 MSE 之间切换点值损失。""" + if loss_cfg.use_huber: + return smooth_l1_per_sample(pred, target, beta=float(loss_cfg.huber_beta)) + return mse_per_sample(pred, target) + + def first_diff(x: torch.Tensor) -> torch.Tensor: """计算曲线相邻时间点的一阶差分,用于约束导数形态。""" return x[:, 1:] - x[:, :-1] @@ -175,491 +275,528 @@ def autofit_curve_objective_per_sample(pred: torch.Tensor, target: torch.Tensor) """用 torch 计算自动拟合风格的曲线误差,作为训练附加目标。""" weight_factor = torch.clamp(torch.abs(target) * 0.01, max=100.0) weight = 1.0 / (1.0 + weight_factor) - - scale = torch.maximum(torch.maximum(torch.abs(target), torch.abs(pred)), torch.full_like(target, 1e-12)) + scale = torch.maximum( + torch.maximum(torch.abs(target), torch.abs(pred)), + torch.full_like(target, 1e-12), + ) relative_error = torch.abs(target - pred) / scale absolute_error = torch.abs(target - pred) point_error = 0.7 * relative_error + 0.3 * absolute_error - - weighted_mse = (weight * (point_error**2)).sum(dim=1) / torch.clamp(weight.sum(dim=1), min=1e-12) + weighted_mse = (weight * (point_error**2)).sum(dim=1) + weighted_mse = weighted_mse / torch.clamp(weight.sum(dim=1), min=1e-12) 
return torch.sqrt(weighted_mse) def build_sample_weight( true_p: torch.Tensor, true_d: torch.Tensor, - alpha: float, - w_min: float, - w_max: float, + reweight_cfg: SampleReweightConfig, ) -> torch.Tensor: """根据真实曲线幅值构造样本权重,让高幅值样本训练时权重更高。""" p_level = true_p.abs().mean(dim=1) d_level = true_d.abs().mean(dim=1) - # 用 batch 内均值做相对归一化,使权重只表达“本批样本相对更难/幅值更高”。 p_norm = p_level / (p_level.mean().detach() + 1e-6) d_norm = d_level / (d_level.mean().detach() + 1e-6) raw = 0.5 * p_norm + 0.5 * d_norm - weight = 1.0 + alpha * (raw - 1.0) - weight = torch.clamp(weight, min=w_min, max=w_max) - return weight + weight = 1.0 + reweight_cfg.alpha * (raw - 1.0) + return torch.clamp(weight, min=reweight_cfg.weight_min, max=reweight_cfg.weight_max) -def compute_weighted_loss( +def split_curve_parts( pred: torch.Tensor, target: torch.Tensor, slices: dict[str, slice], - curve_mean_raw: torch.Tensor, - curve_scale_raw: torch.Tensor, - huber_beta: float, - w_pressure: float, - w_derivative: float, - w_slope: float, - w_bias_pressure: float, - w_bias_derivative: float, - w_derivative_shape: float, - w_autofit_pressure: float, - w_autofit_derivative: float, - use_sample_reweight: bool, - sample_reweight_alpha: float, - sample_weight_min: float, - sample_weight_max: float, -): - """计算正演代理模型的复合训练损失。 - - 曲线被按 curve_layout 拆成压力、导数和 slope 三段。基础点值损失在标准化空间中 - 计算;均值偏置损失用于约束整条曲线的纵向漂移;导数一阶差分损失用于加强形状 - 连续性;自动拟合风格损失会先把曲线反标准化到原始尺度,再计算更接近候选排序 - 场景的误差。函数返回每个分量,方便训练日志判断到底是哪一类误差在主导。 - """ - pred_p = pred[:, slices["log_pressure"]] - pred_d = pred[:, slices["log_derivative"]] - pred_s = pred[:, slices["slope"]] - - true_p = target[:, slices["log_pressure"]] - true_d = target[:, slices["log_derivative"]] - true_s = target[:, slices["slope"]] - - mean_p = curve_mean_raw[slices["log_pressure"]].unsqueeze(0) - scale_p = curve_scale_raw[slices["log_pressure"]].unsqueeze(0) - mean_d = curve_mean_raw[slices["log_derivative"]].unsqueeze(0) - scale_d = curve_scale_raw[slices["log_derivative"]].unsqueeze(0) - - 
loss_p_vec = smooth_l1_per_sample(pred_p, true_p, beta=huber_beta) - loss_d_vec = smooth_l1_per_sample(pred_d, true_d, beta=huber_beta) - loss_s_vec = mse_per_sample(pred_s, true_s) - - # 偏置损失用于抑制整条预测曲线的纵向漂移; - # 点误差负责学习更细的局部形态。 - pred_p_mean = pred_p.mean(dim=1, keepdim=True) - true_p_mean = true_p.mean(dim=1, keepdim=True) - loss_bias_p_vec = l1_per_sample(pred_p_mean, true_p_mean) - - pred_d_mean = pred_d.mean(dim=1, keepdim=True) - true_d_mean = true_d.mean(dim=1, keepdim=True) - loss_bias_d_vec = l1_per_sample(pred_d_mean, true_d_mean) - - pred_d_diff = first_diff(pred_d) - true_d_diff = first_diff(true_d) - loss_d_shape_vec = smooth_l1_per_sample(pred_d_diff, true_d_diff, beta=huber_beta) - - # 自动拟合风格损失在原始曲线尺度上计算。默认权重为 0, - # 但保留这个接口,方便后续实验启用。 - pred_p_raw = affine_restore(pred_p, mean_p, scale_p) - true_p_raw = affine_restore(true_p, mean_p, scale_p) - pred_d_raw = affine_restore(pred_d, mean_d, scale_d) - true_d_raw = affine_restore(true_d, mean_d, scale_d) - - loss_autofit_p_vec = autofit_curve_objective_per_sample(pred_p_raw, true_p_raw) - loss_autofit_d_vec = autofit_curve_objective_per_sample(pred_d_raw, true_d_raw) - - total_vec = ( - w_pressure * loss_p_vec - + w_derivative * loss_d_vec - + w_slope * loss_s_vec - + w_bias_pressure * loss_bias_p_vec - + w_bias_derivative * loss_bias_d_vec - + w_derivative_shape * loss_d_shape_vec - + w_autofit_pressure * loss_autofit_p_vec - + w_autofit_derivative * loss_autofit_d_vec +) -> LossBatchParts: + """把拼接曲线拆成压力、导数和 slope 三段。""" + return LossBatchParts( + pred_p=pred[:, slices["log_pressure"]], + pred_d=pred[:, slices["log_derivative"]], + pred_s=pred[:, slices["slope"]], + true_p=target[:, slices["log_pressure"]], + true_d=target[:, slices["log_derivative"]], + true_s=target[:, slices["slope"]], ) - if use_sample_reweight: - # 重加权只改变样本在 batch 均值中的贡献,不改变各损失分量本身的记录口径。 - sample_weight = build_sample_weight( - true_p=true_p, - true_d=true_d, - alpha=sample_reweight_alpha, - w_min=sample_weight_min, - 
w_max=sample_weight_max, - ) - else: - sample_weight = torch.ones_like(total_vec) - total = (total_vec * sample_weight).mean() +def compute_basic_loss_vectors( + parts: LossBatchParts, + loss_cfg: LossConfig, +) -> dict[str, torch.Tensor]: + """计算标准化空间中的基础点值、偏置和导数形状损失。""" + return { + "loss_pressure": regression_per_sample(parts.pred_p, parts.true_p, loss_cfg), + "loss_derivative": regression_per_sample(parts.pred_d, parts.true_d, loss_cfg), + "loss_slope": mse_per_sample(parts.pred_s, parts.true_s), + "loss_bias_pressure": l1_per_sample( + parts.pred_p.mean(dim=1, keepdim=True), + parts.true_p.mean(dim=1, keepdim=True), + ), + "loss_bias_derivative": l1_per_sample( + parts.pred_d.mean(dim=1, keepdim=True), + parts.true_d.mean(dim=1, keepdim=True), + ), + "loss_derivative_shape": regression_per_sample( + first_diff(parts.pred_d), + first_diff(parts.true_d), + loss_cfg, + ), + } + + +def compute_autofit_loss_vectors( + parts: LossBatchParts, + context: LossContext, +) -> dict[str, torch.Tensor]: + """在原始尺度上计算自动拟合风格损失。""" + pressure_slice = context.slices["log_pressure"] + derivative_slice = context.slices["log_derivative"] + + mean_p = context.curve_stats.mean_raw[pressure_slice].unsqueeze(0) + scale_p = context.curve_stats.scale_raw[pressure_slice].unsqueeze(0) + mean_d = context.curve_stats.mean_raw[derivative_slice].unsqueeze(0) + scale_d = context.curve_stats.scale_raw[derivative_slice].unsqueeze(0) return { - "loss": total, - "loss_pressure": loss_p_vec.mean(), - "loss_derivative": loss_d_vec.mean(), - "loss_slope": loss_s_vec.mean(), - "loss_bias_pressure": loss_bias_p_vec.mean(), - "loss_bias_derivative": loss_bias_d_vec.mean(), - "loss_derivative_shape": loss_d_shape_vec.mean(), - "loss_autofit_pressure": loss_autofit_p_vec.mean(), - "loss_autofit_derivative": loss_autofit_d_vec.mean(), - "sample_weight_mean": sample_weight.mean(), - "sample_weight_max": sample_weight.max(), + "loss_autofit_pressure": autofit_curve_objective_per_sample( + 
affine_restore(parts.pred_p, mean_p, scale_p), + affine_restore(parts.true_p, mean_p, scale_p), + ), + "loss_autofit_derivative": autofit_curve_objective_per_sample( + affine_restore(parts.pred_d, mean_d, scale_d), + affine_restore(parts.true_d, mean_d, scale_d), + ), } -def model_forward(model: nn.Module, params_x: torch.Tensor, schedule_x: torch.Tensor, use_schedule: bool) -> torch.Tensor: +def weighted_total_vector( + loss_vectors: dict[str, torch.Tensor], + weights: LossWeights, +) -> torch.Tensor: + """按配置权重合成每个样本的总损失向量。""" + return ( + weights.pressure * loss_vectors["loss_pressure"] + + weights.derivative * loss_vectors["loss_derivative"] + + weights.slope * loss_vectors["loss_slope"] + + weights.bias_pressure * loss_vectors["loss_bias_pressure"] + + weights.bias_derivative * loss_vectors["loss_bias_derivative"] + + weights.derivative_shape * loss_vectors["loss_derivative_shape"] + + weights.autofit_pressure * loss_vectors["loss_autofit_pressure"] + + weights.autofit_derivative * loss_vectors["loss_autofit_derivative"] + ) + + +def compute_weighted_loss( + pred: torch.Tensor, + target: torch.Tensor, + context: LossContext, +) -> dict[str, torch.Tensor]: + """计算正演代理模型的复合训练损失。""" + parts = split_curve_parts(pred, target, context.slices) + loss_vectors = compute_basic_loss_vectors(parts, context.loss_cfg) + loss_vectors.update(compute_autofit_loss_vectors(parts, context)) + + total_vec = weighted_total_vector(loss_vectors, context.loss_cfg.weights) + if context.reweight_cfg.enabled: + sample_weight = build_sample_weight(parts.true_p, parts.true_d, context.reweight_cfg) + else: + sample_weight = torch.ones_like(total_vec) + + metrics = {key: value.mean() for key, value in loss_vectors.items()} + metrics["loss"] = (total_vec * sample_weight).mean() + metrics["sample_weight_mean"] = sample_weight.mean() + metrics["sample_weight_max"] = sample_weight.max() + return metrics + + +def model_forward( + model: nn.Module, + params_x: torch.Tensor, + schedule_x: 
torch.Tensor, + use_schedule: bool, +) -> torch.Tensor: """按 use_schedule 开关统一调用模型,兼容只用参数输入和参数+流量制度输入。""" if use_schedule: return model(params_x, schedule_x) return model(params_x, None) -def evaluate( +def init_metric_accumulator() -> dict[str, float]: + """创建指标累加器。""" + return {key: 0.0 for key in METRIC_KEYS} + + +def accumulate_metrics( + total: dict[str, float], + losses: dict[str, torch.Tensor], + batch_size: int, +) -> None: + """按 batch 样本数加权累加指标。""" + for key in total: + total[key] += losses[key].item() * batch_size + + +def average_metrics(total: dict[str, float], total_n: int) -> dict[str, float]: + """将累加指标转换为样本平均指标。""" + denom = max(total_n, 1) + return {key: value / denom for key, value in total.items()} + + +def run_loader_epoch( model: nn.Module, loader: DataLoader, device: str, - slices: dict[str, slice], - cfg: TrainConfig, -) -> dict: - """在验证或测试 DataLoader 上计算平均损失和各损失分量。""" - model.eval() - - total = { - "loss": 0.0, - "loss_pressure": 0.0, - "loss_derivative": 0.0, - "loss_slope": 0.0, - "loss_bias_pressure": 0.0, - "loss_bias_derivative": 0.0, - "loss_derivative_shape": 0.0, - "loss_autofit_pressure": 0.0, - "loss_autofit_derivative": 0.0, - "sample_weight_mean": 0.0, - "sample_weight_max": 0.0, - } + context: LossContext, + use_schedule: bool, + optimizer: torch.optim.Optimizer | None = None, +) -> dict[str, float]: + """执行一个训练或评估 epoch,并返回平均指标。""" + is_train = optimizer is not None + model.train(mode=is_train) + + total = init_metric_accumulator() total_n = 0 + grad_context = torch.enable_grad() if is_train else torch.no_grad() - with torch.no_grad(): + with grad_context: for params_x, schedule_x, curve_y in loader: params_x = params_x.to(device) schedule_x = schedule_x.to(device) curve_y = curve_y.to(device) - pred = model_forward(model, params_x, schedule_x, cfg.use_schedule) - losses = compute_weighted_loss( - pred=pred, - target=curve_y, - slices=slices, - curve_mean_raw=cfg.curve_mean_raw, - curve_scale_raw=cfg.curve_scale_raw, - 
huber_beta=cfg.huber_beta, - w_pressure=cfg.w_pressure, - w_derivative=cfg.w_derivative, - w_slope=cfg.w_slope, - w_bias_pressure=cfg.w_bias_pressure, - w_bias_derivative=cfg.w_bias_derivative, - w_derivative_shape=cfg.w_derivative_shape, - w_autofit_pressure=cfg.w_autofit_pressure, - w_autofit_derivative=cfg.w_autofit_derivative, - use_sample_reweight=cfg.use_sample_reweight, - sample_reweight_alpha=cfg.sample_reweight_alpha, - sample_weight_min=cfg.sample_weight_min, - sample_weight_max=cfg.sample_weight_max, - ) - - bs = params_x.size(0) - for k in total: - total[k] += losses[k].item() * bs - total_n += bs + if is_train: + optimizer.zero_grad() - denom = max(total_n, 1) - return {k: v / denom for k, v in total.items()} + pred = model_forward(model, params_x, schedule_x, use_schedule) + losses = compute_weighted_loss(pred=pred, target=curve_y, context=context) + if is_train: + losses["loss"].backward() + optimizer.step() -def train_forward(cfg: TrainConfig) -> None: - """训练完整曲线正演代理模型。 + batch_size = params_x.size(0) + accumulate_metrics(total, losses, batch_size) + total_n += batch_size - 该流程负责加载预处理数据、推断输入输出维度、构建模型与优化器、循环训练、验证 - 保存最佳 checkpoint,并在训练结束后用最佳模型评估测试集。checkpoint 中会保存 - 模型结构所需维度、curve_layout、损失权重和重加权配置,便于独立评估脚本复现。 - """ - cfg.output_dir.mkdir(parents=True, exist_ok=True) - set_global_seed(int(cfg.seed)) + return average_metrics(total, total_n) - data = load_processed_dataset(cfg.processed_path) - curve_layout = infer_curve_layout(data) - part_slices = get_part_slices(curve_layout) + +def evaluate( + model: nn.Module, + loader: DataLoader, + device: str, + context: LossContext, + use_schedule: bool, +) -> dict[str, float]: + """在验证或测试 DataLoader 上计算平均损失和各损失分量。""" + return run_loader_epoch( + model=model, + loader=loader, + device=device, + context=context, + use_schedule=use_schedule, + optimizer=None, + ) + + +def build_curve_stats(data: dict, device: str) -> CurveStats: + """从预处理数据中的曲线 scaler 构建 torch 统计量。""" scaler_curve = data["scaler_curve"] 
curve_mean_raw = np.asarray(scaler_curve.mean_, dtype=np.float32).reshape(-1) curve_scale_raw = np.asarray(scaler_curve.scale_, dtype=np.float32).reshape(-1) - # 将曲线 scaler 张量挂到 cfg 上,方便训练和评估共用同一套损失代码。 - cfg.curve_mean_raw = torch.tensor(curve_mean_raw, dtype=torch.float32, device=cfg.device) - cfg.curve_scale_raw = torch.tensor(curve_scale_raw, dtype=torch.float32, device=cfg.device) - - train_ds = ForwardDataset( - data["X_params_train"], data["X_schedule_train"], data["Y_curve_train"] - ) - val_ds = ForwardDataset( - data["X_params_val"], data["X_schedule_val"], data["Y_curve_val"] - ) - test_ds = ForwardDataset( - data["X_params_test"], data["X_schedule_test"], data["Y_curve_test"] + return CurveStats( + mean_raw=torch.tensor(curve_mean_raw, dtype=torch.float32, device=device), + scale_raw=torch.tensor(curve_scale_raw, dtype=torch.float32, device=device), ) + +def build_dataloaders(data: dict, cfg: TrainConfig) -> DatasetBundle: + """根据预处理数组构造训练、验证、测试 DataLoader。""" + train_ds = ForwardDataset(data["X_params_train"], data["X_schedule_train"], data["Y_curve_train"]) + val_ds = ForwardDataset(data["X_params_val"], data["X_schedule_val"], data["Y_curve_val"]) + test_ds = ForwardDataset(data["X_params_test"], data["X_schedule_test"], data["Y_curve_test"]) + loader_generator = torch.Generator() - loader_generator.manual_seed(int(cfg.seed)) + loader_generator.manual_seed(int(cfg.runtime.seed)) - train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, generator=loader_generator) - val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False) - test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False) + train_loader = DataLoader( + train_ds, + batch_size=cfg.optim.batch_size, + shuffle=True, + generator=loader_generator, + ) + val_loader = DataLoader(val_ds, batch_size=cfg.optim.batch_size, shuffle=False) + test_loader = DataLoader(test_ds, batch_size=cfg.optim.batch_size, shuffle=False) + + return DatasetBundle( + 
train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, + param_dim=data["X_params_train"].shape[1], + schedule_dim=data["X_schedule_train"].shape[1], + curve_dim=data["Y_curve_train"].shape[1], + ) - param_dim = data["X_params_train"].shape[1] - schedule_dim = data["X_schedule_train"].shape[1] - curve_dim = data["Y_curve_train"].shape[1] - # 模型维度从预处理数据集中自动推断,便于不同数据集版本共用训练入口。 - model = ForwardSurrogate( +def build_forward_model( + model_cfg: ModelConfig, + param_dim: int, + schedule_dim: int, + curve_dim: int, + device: str, +) -> nn.Module: + """兼容新版配置式 ForwardSurrogate 和旧版关键字参数式 ForwardSurrogate。""" + surrogate_cfg = ForwardSurrogateConfig( param_dim=param_dim, schedule_dim=schedule_dim, curve_dim=curve_dim, - hidden_dim=cfg.hidden_dim, - dropout=cfg.dropout, - use_schedule=cfg.use_schedule, - ).to(cfg.device) + hidden_dim=model_cfg.hidden_dim, + dropout=model_cfg.dropout, + use_schedule=model_cfg.use_schedule, + ) + return ForwardSurrogate(surrogate_cfg).to(device) + - optimizer = torch.optim.Adam( +def build_optimizer(model: nn.Module, optim_cfg: OptimConfig) -> torch.optim.Optimizer: + """构建 Adam 优化器。""" + return torch.optim.Adam( model.parameters(), - lr=cfg.lr, - weight_decay=cfg.weight_decay, + lr=optim_cfg.lr, + weight_decay=optim_cfg.weight_decay, ) - best_val = float("inf") - best_path = cfg.output_dir / "forward_surrogate_best.pt" - history: list[dict] = [] +def print_training_config(cfg: TrainConfig, curve_layout: dict) -> None: + """打印训练配置摘要。""" + weights = cfg.loss.weights + reweight = cfg.sample_reweight print("训练配置:") - print(f" device={cfg.device}") - print(f" seed={cfg.seed}") - print(f" batch_size={cfg.batch_size}, epochs={cfg.epochs}, lr={cfg.lr}, weight_decay={cfg.weight_decay}") - print(f" hidden_dim={cfg.hidden_dim}, dropout={cfg.dropout}") - print(f" use_schedule={cfg.use_schedule}") + print(f" device={cfg.runtime.device}") + print(f" seed={cfg.runtime.seed}") print( - f" weights: pressure={cfg.w_pressure}, 
derivative={cfg.w_derivative}, " - f"slope={cfg.w_slope}, bias_p={cfg.w_bias_pressure}, " - f"bias_d={cfg.w_bias_derivative}, d_shape={cfg.w_derivative_shape}, " - f"autofit_p={cfg.w_autofit_pressure}, autofit_d={cfg.w_autofit_derivative}" + f" batch_size={cfg.optim.batch_size}, epochs={cfg.optim.epochs}, " + f"lr={cfg.optim.lr}, weight_decay={cfg.optim.weight_decay}" ) + print(f" hidden_dim={cfg.model.hidden_dim}, dropout={cfg.model.dropout}") + print(f" use_schedule={cfg.model.use_schedule}") print( - f" sample_reweight={cfg.use_sample_reweight}, alpha={cfg.sample_reweight_alpha}, " - f"clip=[{cfg.sample_weight_min}, {cfg.sample_weight_max}]" + f" weights: pressure={weights.pressure}, derivative={weights.derivative}, " + f"slope={weights.slope}, bias_p={weights.bias_pressure}, " + f"bias_d={weights.bias_derivative}, d_shape={weights.derivative_shape}, " + f"autofit_p={weights.autofit_pressure}, autofit_d={weights.autofit_derivative}" + ) + print( + f" sample_reweight={reweight.enabled}, alpha={reweight.alpha}, " + f"clip=[{reweight.weight_min}, {reweight.weight_max}]" ) print(f" curve_layout={curve_layout}") print(" note: 当前重点训练 pressure + derivative;可显式关闭 schedule 分支做固定制度对照") - for epoch in range(1, cfg.epochs + 1): - model.train() - - total = { - "loss": 0.0, - "loss_pressure": 0.0, - "loss_derivative": 0.0, - "loss_slope": 0.0, - "loss_bias_pressure": 0.0, - "loss_bias_derivative": 0.0, - "loss_derivative_shape": 0.0, - "loss_autofit_pressure": 0.0, - "loss_autofit_derivative": 0.0, - "sample_weight_mean": 0.0, - "sample_weight_max": 0.0, - } - total_n = 0 - - for params_x, schedule_x, curve_y in train_loader: - params_x = params_x.to(cfg.device) - schedule_x = schedule_x.to(cfg.device) - curve_y = curve_y.to(cfg.device) - - optimizer.zero_grad() - - pred = model_forward(model, params_x, schedule_x, cfg.use_schedule) - losses = compute_weighted_loss( - pred=pred, - target=curve_y, - slices=part_slices, - curve_mean_raw=cfg.curve_mean_raw, - 
curve_scale_raw=cfg.curve_scale_raw, - huber_beta=cfg.huber_beta, - w_pressure=cfg.w_pressure, - w_derivative=cfg.w_derivative, - w_slope=cfg.w_slope, - w_bias_pressure=cfg.w_bias_pressure, - w_bias_derivative=cfg.w_bias_derivative, - w_derivative_shape=cfg.w_derivative_shape, - w_autofit_pressure=cfg.w_autofit_pressure, - w_autofit_derivative=cfg.w_autofit_derivative, - use_sample_reweight=cfg.use_sample_reweight, - sample_reweight_alpha=cfg.sample_reweight_alpha, - sample_weight_min=cfg.sample_weight_min, - sample_weight_max=cfg.sample_weight_max, - ) - - losses["loss"].backward() - optimizer.step() - - bs = params_x.size(0) - # 所有 batch 指标按样本数加权累加,最后再除以总样本数。 - for k in total: - total[k] += losses[k].item() * bs - total_n += bs - - denom = max(total_n, 1) - train_metrics = {k: v / denom for k, v in total.items()} +def format_metric_line(epoch: int, train_metrics: dict[str, float], val_metrics: dict[str, float]) -> str: + """格式化单个 epoch 的训练与验证指标。""" + return ( + f"[Epoch {epoch:03d}] " + f"train={train_metrics['loss']:.6f} " + f"(p={train_metrics['loss_pressure']:.6f}, " + f"d={train_metrics['loss_derivative']:.6f}, " + f"s={train_metrics['loss_slope']:.6f}, " + f"bp={train_metrics['loss_bias_pressure']:.6f}, " + f"bd={train_metrics['loss_bias_derivative']:.6f}, " + f"ds={train_metrics['loss_derivative_shape']:.6f}, " + f"ap={train_metrics['loss_autofit_pressure']:.6f}, " + f"ad={train_metrics['loss_autofit_derivative']:.6f}, " + f"wmean={train_metrics['sample_weight_mean']:.4f}, " + f"wmax={train_metrics['sample_weight_max']:.4f}) " + f"val={val_metrics['loss']:.6f} " + f"(p={val_metrics['loss_pressure']:.6f}, " + f"d={val_metrics['loss_derivative']:.6f}, " + f"s={val_metrics['loss_slope']:.6f}, " + f"bp={val_metrics['loss_bias_pressure']:.6f}, " + f"bd={val_metrics['loss_bias_derivative']:.6f}, " + f"ds={val_metrics['loss_derivative_shape']:.6f}, " + f"ap={val_metrics['loss_autofit_pressure']:.6f}, " + f"ad={val_metrics['loss_autofit_derivative']:.6f}, " + 
f"wmean={val_metrics['sample_weight_mean']:.4f}, " + f"wmax={val_metrics['sample_weight_max']:.4f})" + ) + + +def format_final_line(test_metrics: dict[str, float]) -> str: + """格式化最终测试集指标。""" + return ( + f"[Final] test={test_metrics['loss']:.6f} " + f"(p={test_metrics['loss_pressure']:.6f}, " + f"d={test_metrics['loss_derivative']:.6f}, " + f"s={test_metrics['loss_slope']:.6f}, " + f"bp={test_metrics['loss_bias_pressure']:.6f}, " + f"bd={test_metrics['loss_bias_derivative']:.6f}, " + f"ds={test_metrics['loss_derivative_shape']:.6f}, " + f"ap={test_metrics['loss_autofit_pressure']:.6f}, " + f"ad={test_metrics['loss_autofit_derivative']:.6f}, " + f"wmean={test_metrics['sample_weight_mean']:.4f}, " + f"wmax={test_metrics['sample_weight_max']:.4f})" + ) + + +def build_checkpoint_payload( + model: nn.Module, + bundle: DatasetBundle, + cfg: TrainConfig, + curve_layout: dict, +) -> dict[str, Any]: + """构建 checkpoint 保存内容。""" + return { + "model_state_dict": model.state_dict(), + "param_dim": bundle.param_dim, + "schedule_dim": bundle.schedule_dim, + "curve_dim": bundle.curve_dim, + "hidden_dim": cfg.model.hidden_dim, + "dropout": cfg.model.dropout, + "use_schedule": cfg.model.use_schedule, + "seed": int(cfg.runtime.seed), + "curve_layout": curve_layout, + "loss_weights": asdict(cfg.loss.weights), + "sample_reweight": asdict(cfg.sample_reweight), + } + + +def append_history_row( + history: list[dict], + epoch: int, + train_metrics: dict[str, float], + val_metrics: dict[str, float], +) -> None: + """把当前 epoch 的指标写入 history 列表。""" + row = {"epoch": epoch} + row.update({f"train_{key}": float(value) for key, value in train_metrics.items()}) + row.update({f"val_{key}": float(value) for key, value in val_metrics.items()}) + history.append(row) + + +def save_json(path: Path, payload: dict | list) -> None: + """保存 JSON 文件。""" + with open(path, "w", encoding="utf-8") as file_obj: + json.dump(payload, file_obj, ensure_ascii=False, indent=2) + + +def train_epochs( + model: 
nn.Module, + bundle: DatasetBundle, + cfg: TrainConfig, + context: LossContext, + curve_layout: dict, +) -> tuple[float, Path, list[dict]]: + """执行训练循环并保存最佳模型。""" + optimizer = build_optimizer(model, cfg.optim) + best_val = float("inf") + best_path = cfg.output_dir / "forward_surrogate_best.pt" + history: list[dict] = [] + + for epoch in range(1, cfg.optim.epochs + 1): + train_metrics = run_loader_epoch( + model=model, + loader=bundle.train_loader, + device=cfg.runtime.device, + context=context, + use_schedule=cfg.model.use_schedule, + optimizer=optimizer, + ) val_metrics = evaluate( model=model, - loader=val_loader, - device=cfg.device, - slices=part_slices, - cfg=cfg, + loader=bundle.val_loader, + device=cfg.runtime.device, + context=context, + use_schedule=cfg.model.use_schedule, ) - row = {"epoch": epoch} - for k, v in train_metrics.items(): - row[f"train_{k}"] = float(v) - for k, v in val_metrics.items(): - row[f"val_{k}"] = float(v) - history.append(row) - - print( - f"[Epoch {epoch:03d}] " - f"train={train_metrics['loss']:.6f} " - f"(p={train_metrics['loss_pressure']:.6f}, " - f"d={train_metrics['loss_derivative']:.6f}, " - f"s={train_metrics['loss_slope']:.6f}, " - f"bp={train_metrics['loss_bias_pressure']:.6f}, " - f"bd={train_metrics['loss_bias_derivative']:.6f}, " - f"ds={train_metrics['loss_derivative_shape']:.6f}, " - f"ap={train_metrics['loss_autofit_pressure']:.6f}, " - f"ad={train_metrics['loss_autofit_derivative']:.6f}, " - f"wmean={train_metrics['sample_weight_mean']:.4f}, " - f"wmax={train_metrics['sample_weight_max']:.4f}) " - f"val={val_metrics['loss']:.6f} " - f"(p={val_metrics['loss_pressure']:.6f}, " - f"d={val_metrics['loss_derivative']:.6f}, " - f"s={val_metrics['loss_slope']:.6f}, " - f"bp={val_metrics['loss_bias_pressure']:.6f}, " - f"bd={val_metrics['loss_bias_derivative']:.6f}, " - f"ds={val_metrics['loss_derivative_shape']:.6f}, " - f"ap={val_metrics['loss_autofit_pressure']:.6f}, " - f"ad={val_metrics['loss_autofit_derivative']:.6f}, 
" - f"wmean={val_metrics['sample_weight_mean']:.4f}, " - f"wmax={val_metrics['sample_weight_max']:.4f})" - ) + append_history_row(history, epoch, train_metrics, val_metrics) + print(format_metric_line(epoch, train_metrics, val_metrics)) if val_metrics["loss"] < best_val: best_val = val_metrics["loss"] - # 检查点保存模型结构所需维度和训练配置,评估脚本可不依赖外部配置直接恢复模型。 - torch.save( - { - "model_state_dict": model.state_dict(), - "param_dim": param_dim, - "schedule_dim": schedule_dim, - "curve_dim": curve_dim, - "hidden_dim": cfg.hidden_dim, - "dropout": cfg.dropout, - "use_schedule": cfg.use_schedule, - "seed": int(cfg.seed), - "curve_layout": curve_layout, - "loss_weights": { - "pressure": cfg.w_pressure, - "derivative": cfg.w_derivative, - "slope": cfg.w_slope, - "bias_pressure": cfg.w_bias_pressure, - "bias_derivative": cfg.w_bias_derivative, - "derivative_shape": cfg.w_derivative_shape, - "autofit_pressure": cfg.w_autofit_pressure, - "autofit_derivative": cfg.w_autofit_derivative, - }, - "sample_reweight": { - "enabled": cfg.use_sample_reweight, - "alpha": cfg.sample_reweight_alpha, - "weight_min": cfg.sample_weight_min, - "weight_max": cfg.sample_weight_max, - }, - }, - best_path, - ) + torch.save(build_checkpoint_payload(model, bundle, cfg, curve_layout), best_path) print(f" -> best model saved to: {best_path}") - with open(cfg.output_dir / "history.json", "w", encoding="utf-8") as f: - json.dump(history, f, ensure_ascii=False, indent=2) + return best_val, best_path, history - checkpoint = torch.load(best_path, map_location=cfg.device) - best_model = ForwardSurrogate( - param_dim=checkpoint["param_dim"], - schedule_dim=checkpoint["schedule_dim"], - curve_dim=checkpoint["curve_dim"], + +def load_best_model(best_path: Path, device: str) -> nn.Module: + """从 checkpoint 恢复最佳模型。""" + checkpoint = torch.load(best_path, map_location=device) + model_cfg = ModelConfig( hidden_dim=checkpoint["hidden_dim"], dropout=checkpoint["dropout"], use_schedule=checkpoint.get("use_schedule", True), - 
).to(cfg.device) + ) + best_model = build_forward_model( + model_cfg=model_cfg, + param_dim=checkpoint["param_dim"], + schedule_dim=checkpoint["schedule_dim"], + curve_dim=checkpoint["curve_dim"], + device=device, + ) best_model.load_state_dict(checkpoint["model_state_dict"]) + return best_model + + +def build_metrics_payload( + best_val: float, + test_metrics: dict[str, float], + cfg: TrainConfig, + curve_layout: dict, +) -> dict[str, Any]: + """构建 metrics.json 内容。""" + return { + "best_val_loss": float(best_val), + "test_metrics": {key: float(value) for key, value in test_metrics.items()}, + "use_schedule": cfg.model.use_schedule, + "seed": int(cfg.runtime.seed), + "loss_weights": asdict(cfg.loss.weights), + "sample_reweight": asdict(cfg.sample_reweight), + "curve_layout": curve_layout, + } + +def train_forward(cfg: TrainConfig) -> None: + """训练完整曲线正演代理模型。""" + cfg.output_dir.mkdir(parents=True, exist_ok=True) + set_global_seed(int(cfg.runtime.seed)) + + data = load_processed_dataset(cfg.processed_path) + curve_layout = infer_curve_layout(data) + context = LossContext( + slices=get_part_slices(curve_layout), + curve_stats=build_curve_stats(data, cfg.runtime.device), + loss_cfg=cfg.loss, + reweight_cfg=cfg.sample_reweight, + ) + bundle = build_dataloaders(data, cfg) + + model = build_forward_model( + model_cfg=cfg.model, + param_dim=bundle.param_dim, + schedule_dim=bundle.schedule_dim, + curve_dim=bundle.curve_dim, + device=cfg.runtime.device, + ) + + print_training_config(cfg, curve_layout) + best_val, best_path, history = train_epochs(model, bundle, cfg, context, curve_layout) + save_json(cfg.output_dir / "history.json", history) + + best_model = load_best_model(best_path, cfg.runtime.device) test_metrics = evaluate( model=best_model, - loader=test_loader, - device=cfg.device, - slices=part_slices, - cfg=cfg, + loader=bundle.test_loader, + device=cfg.runtime.device, + context=context, + use_schedule=cfg.model.use_schedule, ) - print( - f"[Final] 
test={test_metrics['loss']:.6f} " - f"(p={test_metrics['loss_pressure']:.6f}, " - f"d={test_metrics['loss_derivative']:.6f}, " - f"s={test_metrics['loss_slope']:.6f}, " - f"bp={test_metrics['loss_bias_pressure']:.6f}, " - f"bd={test_metrics['loss_bias_derivative']:.6f}, " - f"ds={test_metrics['loss_derivative_shape']:.6f}, " - f"ap={test_metrics['loss_autofit_pressure']:.6f}, " - f"ad={test_metrics['loss_autofit_derivative']:.6f}, " - f"wmean={test_metrics['sample_weight_mean']:.4f}, " - f"wmax={test_metrics['sample_weight_max']:.4f})" + print(format_final_line(test_metrics)) + save_json( + cfg.output_dir / "metrics.json", + build_metrics_payload(best_val, test_metrics, cfg, curve_layout), ) - - with open(cfg.output_dir / "metrics.json", "w", encoding="utf-8") as f: - json.dump( - { - "best_val_loss": float(best_val), - "test_metrics": {k: float(v) for k, v in test_metrics.items()}, - "use_schedule": cfg.use_schedule, - "seed": int(cfg.seed), - "loss_weights": { - "pressure": cfg.w_pressure, - "derivative": cfg.w_derivative, - "slope": cfg.w_slope, - "bias_pressure": cfg.w_bias_pressure, - "bias_derivative": cfg.w_bias_derivative, - "derivative_shape": cfg.w_derivative_shape, - "autofit_pressure": cfg.w_autofit_pressure, - "autofit_derivative": cfg.w_autofit_derivative, - }, - "sample_reweight": { - "enabled": cfg.use_sample_reweight, - "alpha": cfg.sample_reweight_alpha, - "weight_min": cfg.sample_weight_min, - "weight_max": cfg.sample_weight_max, - }, - "curve_layout": curve_layout, - }, - f, - ensure_ascii=False, - indent=2, - ) diff --git a/ML/nmWTAI-ML/src/training/train_time_conditioned.py b/ML/nmWTAI-ML/src/training/train_time_conditioned.py index 8861fee..70e4d72 100644 --- a/ML/nmWTAI-ML/src/training/train_time_conditioned.py +++ b/ML/nmWTAI-ML/src/training/train_time_conditioned.py @@ -11,10 +11,13 @@ log_pressure 和 log_derivative。 from __future__ import annotations +# pylint: disable=import-error,duplicate-code + import json import random -from 
dataclasses import dataclass +from dataclasses import asdict, dataclass, field from pathlib import Path +from typing import Any import joblib import numpy as np @@ -23,40 +26,45 @@ import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from src.data.param_features import inverse_transform_param_features -from src.models.time_conditioned_surrogate import TimeConditionedSurrogate +from src.models.time_conditioned_surrogate import ( + TimeConditionedSurrogate, + TimeConditionedSurrogateConfig, +) from src.training.train_forward import get_part_slices, infer_curve_layout +@dataclass +class PointCurveArrays: + """逐时间点数据集所需数组。""" + + params_x: np.ndarray + schedule_x: np.ndarray + time_x: np.ndarray + curve_y: np.ndarray + layout: dict + sample_weight: np.ndarray | None = None + + class PointCurveDataset(Dataset): - """把完整曲线展开为逐时间点训练样本。 - - 原始数据形状是 N 条曲线,每条曲线 T 个时间点。本 Dataset 的长度为 N*T, - __getitem__ 会根据一维索引反推出 sample_idx 和 time_idx,返回该时间点对应的 - 参数特征、制度特征、时间特征、双通道目标值以及样本级权重。 - """ - def __init__( - self, - params_x: np.ndarray, - schedule_x: np.ndarray, - time_x: np.ndarray, - curve_y: np.ndarray, - layout: dict, - sample_weight: np.ndarray | None = None, - ): + """把完整曲线展开为逐时间点训练样本。""" + + def __init__(self, arrays: PointCurveArrays): """保存逐点训练所需的参数、流量制度、时间特征、目标曲线和样本权重。""" - self.params_x = torch.tensor(params_x, dtype=torch.float32) - self.schedule_x = torch.tensor(schedule_x, dtype=torch.float32) - self.time_x = torch.tensor(time_x, dtype=torch.float32) + self.params_x = torch.tensor(arrays.params_x, dtype=torch.float32) + self.schedule_x = torch.tensor(arrays.schedule_x, dtype=torch.float32) + self.time_x = torch.tensor(arrays.time_x, dtype=torch.float32) - slices = get_part_slices(layout) - p = curve_y[:, slices["log_pressure"]] - d = curve_y[:, slices["log_derivative"]] - self.y = torch.tensor(np.stack([p, d], axis=-1), dtype=torch.float32) + slices = get_part_slices(arrays.layout) + pressure_y = arrays.curve_y[:, slices["log_pressure"]] + 
derivative_y = arrays.curve_y[:, slices["log_derivative"]] + self.y = torch.tensor(np.stack([pressure_y, derivative_y], axis=-1), dtype=torch.float32) self.n_samples = int(self.params_x.shape[0]) self.n_time = int(self.time_x.shape[1]) + sample_weight = arrays.sample_weight if sample_weight is None: sample_weight = np.ones((self.n_samples,), dtype=np.float32) + sample_weight = np.asarray(sample_weight, dtype=np.float32).reshape(-1) if sample_weight.shape[0] != self.n_samples: raise ValueError(f"sample_weight length mismatch: {sample_weight.shape[0]} != {self.n_samples}") @@ -80,30 +88,88 @@ class PointCurveDataset(Dataset): @dataclass -class TimeConditionedTrainConfig: - """时间条件代理模型训练配置,包括批量大小、学习率、模型宽度和设备。""" - processed_path: Path - output_dir: Path - seed: int = 42 +class TimeModelConfig: + """时间条件代理模型结构配置。""" + + hidden_dim: int = 256 + n_blocks: int = 4 + dropout: float = 0.05 + use_schedule: bool = True + + +@dataclass +class TimeOptimConfig: + """训练轮次和优化器配置。""" + batch_size: int = 4096 epochs: int = 120 lr: float = 1.0e-3 weight_decay: float = 1.0e-4 - hidden_dim: int = 256 - n_blocks: int = 4 - dropout: float = 0.05 + + +@dataclass +class TimeLossConfig: + """时间条件模型点级损失配置。""" + w_pressure: float = 1.0 w_derivative: float = 2.0 huber_beta: float = 0.05 - use_schedule: bool = True - sample_weight_mode: str = "none" + + +@dataclass +class RiskWeightConfig: + """风险区域样本加权配置。""" + + mode: str = "none" risk_weight: float = 2.5 skin_lt_minus8_weight: float = 3.5 - sample_weight_min: float = 1.0 - sample_weight_max: float = 4.0 + weight_min: float = 1.0 + weight_max: float = 4.0 + + +@dataclass +class TimeRuntimeConfig: + """训练运行时配置。""" + + seed: int = 42 device: str = "cuda" if torch.cuda.is_available() else "cpu" +@dataclass +class TimeConditionedTrainConfig: + """时间条件代理模型训练配置。""" + + processed_path: Path + output_dir: Path + runtime: TimeRuntimeConfig = field(default_factory=TimeRuntimeConfig) + optim: TimeOptimConfig = field(default_factory=TimeOptimConfig) + 
model: TimeModelConfig = field(default_factory=TimeModelConfig) + loss: TimeLossConfig = field(default_factory=TimeLossConfig) + risk_weight: RiskWeightConfig = field(default_factory=RiskWeightConfig) + + +@dataclass +class DataBundle: + """训练、验证、测试数据加载器与输入维度。""" + + train_loader: DataLoader + val_loader: DataLoader + test_loader: DataLoader + param_dim: int + schedule_dim: int + time_dim: int + + +@dataclass +class TrainArtifacts: + """训练过程中需要跨函数传递的数据。""" + + data: dict + curve_layout: dict + train_weight_summary: dict[str, Any] + bundle: DataBundle + + def set_global_seed(seed: int) -> None: """设置 Python、NumPy 和 PyTorch 随机种子,并在 CUDA 可用时同步设置 GPU 随机种子。""" random.seed(seed) @@ -121,215 +187,366 @@ def _smooth_l1_vector(pred: torch.Tensor, target: torch.Tensor, beta: float) -> def _loss( pred: torch.Tensor, target: torch.Tensor, - cfg: TimeConditionedTrainConfig, + loss_cfg: TimeLossConfig, sample_weight: torch.Tensor | None = None, ) -> torch.Tensor: """计算时间条件模型的点级损失,并按样本权重求平均。""" - loss_p = _smooth_l1_vector(pred[:, 0], target[:, 0], beta=float(cfg.huber_beta)) - loss_d = _smooth_l1_vector(pred[:, 1], target[:, 1], beta=float(cfg.huber_beta)) - loss_vec = float(cfg.w_pressure) * loss_p + float(cfg.w_derivative) * loss_d + loss_p = _smooth_l1_vector(pred[:, 0], target[:, 0], beta=float(loss_cfg.huber_beta)) + loss_d = _smooth_l1_vector(pred[:, 1], target[:, 1], beta=float(loss_cfg.huber_beta)) + loss_vec = float(loss_cfg.w_pressure) * loss_p + float(loss_cfg.w_derivative) * loss_d if sample_weight is None: return loss_vec.mean() - w = sample_weight.to(loss_vec.device).reshape(-1).clamp_min(0.0) - return (loss_vec * w).sum() / torch.clamp(w.sum(), min=1.0e-12) + weight = sample_weight.to(loss_vec.device).reshape(-1).clamp_min(0.0) + return (loss_vec * weight).sum() / torch.clamp(weight.sum(), min=1.0e-12) + + +def _move_batch_to_device( + batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + device: str, +) -> tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """把一个 batch 的所有张量移动到目标设备。""" + params_x, schedule_x, time_x, target_y, sample_weight = batch + return ( + params_x.to(device), + schedule_x.to(device), + time_x.to(device), + target_y.to(device), + sample_weight.to(device), + ) -def _evaluate(model: TimeConditionedSurrogate, loader: DataLoader, cfg: TimeConditionedTrainConfig) -> float: - """在验证集上评估时间条件模型的平均损失。""" - model.eval() + +def _forward_batch( + model: TimeConditionedSurrogate, + batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], + cfg: TimeConditionedTrainConfig, + use_weight: bool, +) -> tuple[torch.Tensor, int]: + """完成一个 batch 的前向计算并返回损失和样本数。""" + params_x, schedule_x, time_x, target_y, sample_weight = _move_batch_to_device( + batch, + cfg.runtime.device, + ) + schedule_input = schedule_x if cfg.model.use_schedule else None + pred = model(params_x, time_x, schedule_input) + weight = sample_weight if use_weight else None + loss = _loss(pred, target_y, cfg.loss, sample_weight=weight) + return loss, int(target_y.shape[0]) + + +def _run_loader( + model: TimeConditionedSurrogate, + loader: DataLoader, + cfg: TimeConditionedTrainConfig, + optimizer: torch.optim.Optimizer | None = None, +) -> float: + """执行一个训练或评估 epoch。""" + is_train = optimizer is not None + model.train(mode=is_train) total = 0.0 total_n = 0 - with torch.no_grad(): - for params_x, schedule_x, time_x, y, _sample_weight in loader: - params_x = params_x.to(cfg.device) - schedule_x = schedule_x.to(cfg.device) - time_x = time_x.to(cfg.device) - y = y.to(cfg.device) - pred = model(params_x, time_x, schedule_x if cfg.use_schedule else None) - loss = _loss(pred, y, cfg) - bs = int(y.shape[0]) - total += float(loss.detach().cpu()) * bs - total_n += bs + grad_context = torch.enable_grad() if is_train else torch.no_grad() + + with grad_context: + for batch in loader: + if is_train: + optimizer.zero_grad() + + loss, batch_size = _forward_batch(model, batch, cfg, 
use_weight=is_train) + + if is_train: + loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) + optimizer.step() + + total += float(loss.detach().cpu()) * batch_size + total_n += batch_size + return total / max(total_n, 1) +def _evaluate( + model: TimeConditionedSurrogate, + loader: DataLoader, + cfg: TimeConditionedTrainConfig, +) -> float: + """在验证集或测试集上评估时间条件模型的平均损失。""" + return _run_loader(model, loader, cfg, optimizer=None) + + def _raw_params_from_processed_split(data: dict, split: str) -> dict[str, np.ndarray]: """从预处理数据中读取某个划分的原始参数,用于构造样本权重。""" key = f"X_params_{split}" features = data["scaler_params"].inverse_transform(data[key]) - raw = inverse_transform_param_features(features, data.get("meta", {}).get("param_feature_transform")) - names = list(data.get("meta", {}).get("param_names") or ["k", "skin", "wellboreC", "phi", "h", "Cf"]) - return {name: raw[:, idx].astype(np.float64) for idx, name in enumerate(names[: raw.shape[1]])} + transform = data.get("meta", {}).get("param_feature_transform") + raw = inverse_transform_param_features(features, transform) + default_names = ["k", "skin", "wellboreC", "phi", "h", "Cf"] + names = list(data.get("meta", {}).get("param_names") or default_names) + return { + name: raw[:, idx].astype(np.float64) + for idx, name in enumerate(names[: raw.shape[1]]) + } -def _build_sample_weight(data: dict, cfg: TimeConditionedTrainConfig, split: str = "train") -> np.ndarray: +def _build_sample_weight( + data: dict, + cfg: TimeConditionedTrainConfig, + split: str = "train", +) -> np.ndarray: """根据原始物理参数生成样本权重,使关键参数区域得到更多关注。""" - mode = str(cfg.sample_weight_mode or "none").lower() - n = int(data[f"X_params_{split}"].shape[0]) + mode = str(cfg.risk_weight.mode or "none").lower() + n_samples = int(data[f"X_params_{split}"].shape[0]) if mode in {"none", "off", "false"}: - return np.ones((n,), dtype=np.float32) + return np.ones((n_samples,), dtype=np.float32) if mode != "risk_region": - raise 
ValueError(f"Unknown sample_weight_mode={cfg.sample_weight_mode!r}") + raise ValueError(f"Unknown sample_weight_mode={cfg.risk_weight.mode!r}") params = _raw_params_from_processed_split(data, split) - weight = np.ones((n,), dtype=np.float32) + weight = np.ones((n_samples,), dtype=np.float32) risk = (params["skin"] < -5.0) & (params["wellboreC"] > 0.1) skin_extreme = params["skin"] < -8.0 - weight[risk] = np.maximum(weight[risk], float(cfg.risk_weight)) - weight[skin_extreme] = np.maximum(weight[skin_extreme], float(cfg.skin_lt_minus8_weight)) + weight[risk] = np.maximum(weight[risk], float(cfg.risk_weight.risk_weight)) + weight[skin_extreme] = np.maximum(weight[skin_extreme], float(cfg.risk_weight.skin_lt_minus8_weight)) + return np.clip( + weight, + float(cfg.risk_weight.weight_min), + float(cfg.risk_weight.weight_max), + ).astype(np.float32) - weight = np.clip(weight, float(cfg.sample_weight_min), float(cfg.sample_weight_max)) - return weight.astype(np.float32) - -def _summarize_sample_weight(sample_weight: np.ndarray) -> dict: +def _summarize_sample_weight(sample_weight: np.ndarray) -> dict[str, Any]: """统计样本权重的最小值、最大值和分位数,便于检查加权强度。""" - w = np.asarray(sample_weight, dtype=np.float32).reshape(-1) + weight = np.asarray(sample_weight, dtype=np.float32).reshape(-1) return { - "min": float(np.min(w)), - "mean": float(np.mean(w)), - "median": float(np.median(w)), - "max": float(np.max(w)), - "n_weight_gt_1": int(np.sum(w > 1.0)), - "n_weight_lt_1": int(np.sum(w < 1.0)), + "min": float(np.min(weight)), + "mean": float(np.mean(weight)), + "median": float(np.median(weight)), + "max": float(np.max(weight)), + "n_weight_gt_1": int(np.sum(weight > 1.0)), + "n_weight_lt_1": int(np.sum(weight < 1.0)), } -def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None: - """训练时间条件代理模型并保存训练产物。 - - 输入数据必须由新版预处理流程生成,包含 X_time_train/val/test。训练时只有训练集 - 打乱顺序,验证和测试保持固定顺序以便复现指标。最佳模型按验证损失保存,最终 - 写出 history.json 和 metrics.json,用于查看训练趋势和测试集性能。 - """ - 
cfg.output_dir.mkdir(parents=True, exist_ok=True) - set_global_seed(int(cfg.seed)) - - data = joblib.load(cfg.processed_path) +def _load_processed_data(path: Path) -> dict: + """读取预处理数据并检查时间条件训练所需字段。""" + data = joblib.load(path) required = ["X_time_train", "X_time_val", "X_time_test"] missing = [key for key in required if key not in data] if missing: - # 时间条件训练需要每个曲线点的时间特征;缺失时说明预处理版本不匹配。 raise KeyError(f"processed dataset is missing time-conditioned fields: {missing}") - - curve_layout = infer_curve_layout(data) - train_weight = _build_sample_weight(data, cfg, split="train") - train_weight_summary = _summarize_sample_weight(train_weight) - # PointCurveDataset 会把 [N, T] 曲线展开成 N*T 个点级样本。 - train_ds = PointCurveDataset( - data["X_params_train"], - data["X_schedule_train"], - data["X_time_train"], - data["Y_curve_train"], - curve_layout, - sample_weight=train_weight, + return data + + +def _make_point_dataset( + data: dict, + split: str, + curve_layout: dict, + sample_weight: np.ndarray | None = None, +) -> PointCurveDataset: + """构造某个数据划分对应的逐点数据集。""" + return PointCurveDataset( + PointCurveArrays( + params_x=data[f"X_params_{split}"], + schedule_x=data[f"X_schedule_{split}"], + time_x=data[f"X_time_{split}"], + curve_y=data[f"Y_curve_{split}"], + layout=curve_layout, + sample_weight=sample_weight, + ) ) - val_ds = PointCurveDataset(data["X_params_val"], data["X_schedule_val"], data["X_time_val"], data["Y_curve_val"], curve_layout) - test_ds = PointCurveDataset(data["X_params_test"], data["X_schedule_test"], data["X_time_test"], data["Y_curve_test"], curve_layout) + + +def _build_dataloaders( + data: dict, + curve_layout: dict, + train_weight: np.ndarray, + cfg: TimeConditionedTrainConfig, +) -> DataBundle: + """构建训练、验证和测试 DataLoader。""" + train_ds = _make_point_dataset(data, "train", curve_layout, train_weight) + val_ds = _make_point_dataset(data, "val", curve_layout) + test_ds = _make_point_dataset(data, "test", curve_layout) generator = torch.Generator() - 
generator.manual_seed(int(cfg.seed)) - # 只打乱训练集;验证/测试保持固定顺序方便复现指标。 - train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, generator=generator) - val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False) - test_loader = DataLoader(test_ds, batch_size=cfg.batch_size, shuffle=False) + generator.manual_seed(int(cfg.runtime.seed)) + + train_loader = DataLoader( + train_ds, + batch_size=cfg.optim.batch_size, + shuffle=True, + generator=generator, + ) + val_loader = DataLoader(val_ds, batch_size=cfg.optim.batch_size, shuffle=False) + test_loader = DataLoader(test_ds, batch_size=cfg.optim.batch_size, shuffle=False) - model = TimeConditionedSurrogate( + return DataBundle( + train_loader=train_loader, + val_loader=val_loader, + test_loader=test_loader, param_dim=int(data["X_params_train"].shape[1]), schedule_dim=int(data["X_schedule_train"].shape[1]), time_dim=int(data["X_time_train"].shape[-1]), - hidden_dim=int(cfg.hidden_dim), - n_blocks=int(cfg.n_blocks), - dropout=float(cfg.dropout), - use_schedule=bool(cfg.use_schedule), - ).to(cfg.device) - - # AdamW 的 weight_decay 与梯度更新解耦,适合这个全连接回归模型。 - optimizer = torch.optim.AdamW(model.parameters(), lr=float(cfg.lr), weight_decay=float(cfg.weight_decay)) - best_val = float("inf") - best_path = cfg.output_dir / "time_conditioned_surrogate_best.pt" - history: list[dict] = [] + ) + + +def _prepare_training_artifacts(cfg: TimeConditionedTrainConfig) -> TrainArtifacts: + """加载数据、推断布局、构造样本权重和 DataLoader。""" + data = _load_processed_data(cfg.processed_path) + curve_layout = infer_curve_layout(data) + train_weight = _build_sample_weight(data, cfg, split="train") + train_weight_summary = _summarize_sample_weight(train_weight) + bundle = _build_dataloaders(data, curve_layout, train_weight, cfg) + return TrainArtifacts( + data=data, + curve_layout=curve_layout, + train_weight_summary=train_weight_summary, + bundle=bundle, + ) + + +def _build_model(bundle: DataBundle, cfg: TimeConditionedTrainConfig) -> 
TimeConditionedSurrogate: + """兼容新版配置式模型和旧版关键字参数式模型。""" + model_cfg = TimeConditionedSurrogateConfig( + param_dim=bundle.param_dim, + schedule_dim=bundle.schedule_dim, + time_dim=bundle.time_dim, + hidden_dim=int(cfg.model.hidden_dim), + n_blocks=int(cfg.model.n_blocks), + dropout=float(cfg.model.dropout), + use_schedule=bool(cfg.model.use_schedule), + ) + return TimeConditionedSurrogate(model_cfg).to(cfg.runtime.device) + + +def _build_optimizer( + model: TimeConditionedSurrogate, + cfg: TimeConditionedTrainConfig, +) -> torch.optim.Optimizer: + """构建 AdamW 优化器。""" + return torch.optim.AdamW( + model.parameters(), + lr=float(cfg.optim.lr), + weight_decay=float(cfg.optim.weight_decay), + ) + +def _print_training_config( + cfg: TimeConditionedTrainConfig, + artifacts: TrainArtifacts, +) -> None: + """打印时间条件训练配置摘要。""" + meta = artifacts.data.get("meta", {}) print("Time-conditioned training config:") print(f" processed={cfg.processed_path}") print(f" output_dir={cfg.output_dir}") - print(f" device={cfg.device}, batch_size={cfg.batch_size}, epochs={cfg.epochs}") print( - f" dims: param={data['X_params_train'].shape[1]}, " - f"schedule={data['X_schedule_train'].shape[1]}, time={data['X_time_train'].shape[-1]}" + f" device={cfg.runtime.device}, batch_size={cfg.optim.batch_size}, " + f"epochs={cfg.optim.epochs}" + ) + print( + f" dims: param={artifacts.bundle.param_dim}, " + f"schedule={artifacts.bundle.schedule_dim}, time={artifacts.bundle.time_dim}" ) - print(f" curve_time_source={data.get('meta', {}).get('curve_time_source', 'unknown')}") - print(f" sample_weight_mode={cfg.sample_weight_mode}, sample_weight={train_weight_summary}") - - for epoch in range(1, int(cfg.epochs) + 1): - model.train() - total = 0.0 - total_n = 0 - for params_x, schedule_x, time_x, y, sample_weight in train_loader: - params_x = params_x.to(cfg.device) - schedule_x = schedule_x.to(cfg.device) - time_x = time_x.to(cfg.device) - y = y.to(cfg.device) - sample_weight = sample_weight.to(cfg.device) 
- - optimizer.zero_grad() - # 每个点的输入由“样本级参数/制度 + 点级时间特征”组成。 - pred = model(params_x, time_x, schedule_x if cfg.use_schedule else None) - loss = _loss(pred, y, cfg, sample_weight=sample_weight) - loss.backward() - # 点级样本数量大,偶发高误差 batch 可能产生尖峰梯度,训练时做裁剪。 - torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) - optimizer.step() - - bs = int(y.shape[0]) - total += float(loss.detach().cpu()) * bs - total_n += bs - - train_loss = total / max(total_n, 1) - val_loss = _evaluate(model, val_loader, cfg) + print(f" curve_time_source={meta.get('curve_time_source', 'unknown')}") + print( + f" sample_weight_mode={cfg.risk_weight.mode}, " + f"sample_weight={artifacts.train_weight_summary}" + ) + + +def _checkpoint_payload( + model: TimeConditionedSurrogate, + cfg: TimeConditionedTrainConfig, + artifacts: TrainArtifacts, +) -> dict[str, Any]: + """构造 checkpoint 保存内容。""" + return { + "model_state_dict": model.state_dict(), + "param_dim": artifacts.bundle.param_dim, + "schedule_dim": artifacts.bundle.schedule_dim, + "time_dim": artifacts.bundle.time_dim, + "hidden_dim": int(cfg.model.hidden_dim), + "n_blocks": int(cfg.model.n_blocks), + "dropout": float(cfg.model.dropout), + "use_schedule": bool(cfg.model.use_schedule), + "curve_layout": artifacts.curve_layout, + "processed_path": str(cfg.processed_path), + "seed": int(cfg.runtime.seed), + "sample_weight_mode": str(cfg.risk_weight.mode), + "sample_weight_summary": artifacts.train_weight_summary, + "model_config": asdict(cfg.model), + "loss_config": asdict(cfg.loss), + "risk_weight_config": asdict(cfg.risk_weight), + } + + +def _save_json(path: Path, payload: dict | list) -> None: + """写出 JSON 文件。""" + path.write_text(json.dumps(payload, indent=2, ensure_ascii=False), encoding="utf-8") + + +def _train_epochs( + model: TimeConditionedSurrogate, + cfg: TimeConditionedTrainConfig, + artifacts: TrainArtifacts, +) -> tuple[float, Path, list[dict]]: + """执行训练循环并保存最佳模型。""" + optimizer = _build_optimizer(model, cfg) + best_val = 
float("inf") + best_path = cfg.output_dir / "time_conditioned_surrogate_best.pt" + history: list[dict] = [] + + for epoch in range(1, int(cfg.optim.epochs) + 1): + train_loss = _run_loader(model, artifacts.bundle.train_loader, cfg, optimizer=optimizer) + val_loss = _evaluate(model, artifacts.bundle.val_loader, cfg) history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss}) print(f"[Epoch {epoch:03d}] train={train_loss:.6f} val={val_loss:.6f}") if val_loss < best_val: best_val = val_loss - # checkpoint 保存曲线布局和输入维度,评估脚本可据此重建同构模型。 - torch.save( - { - "model_state_dict": model.state_dict(), - "param_dim": int(data["X_params_train"].shape[1]), - "schedule_dim": int(data["X_schedule_train"].shape[1]), - "time_dim": int(data["X_time_train"].shape[-1]), - "hidden_dim": int(cfg.hidden_dim), - "n_blocks": int(cfg.n_blocks), - "dropout": float(cfg.dropout), - "use_schedule": bool(cfg.use_schedule), - "curve_layout": curve_layout, - "processed_path": str(cfg.processed_path), - "seed": int(cfg.seed), - "sample_weight_mode": str(cfg.sample_weight_mode), - "sample_weight_summary": train_weight_summary, - }, - best_path, - ) + torch.save(_checkpoint_payload(model, cfg, artifacts), best_path) print(f" -> best model saved to: {best_path}") - checkpoint = torch.load(best_path, map_location=cfg.device) - model.load_state_dict(checkpoint["model_state_dict"]) - test_loss = _evaluate(model, test_loader, cfg) - - # history 记录逐 epoch 走势;metrics 记录最终选择的最佳验证和测试损失。 - (cfg.output_dir / "history.json").write_text(json.dumps(history, indent=2, ensure_ascii=False), encoding="utf-8") - (cfg.output_dir / "metrics.json").write_text( - json.dumps( - { - "best_val_loss": best_val, - "test_loss": test_loss, - "sample_weight_mode": str(cfg.sample_weight_mode), - "sample_weight_summary": train_weight_summary, - }, - indent=2, - ensure_ascii=False, - ), - encoding="utf-8", + return best_val, best_path, history + + +def _write_final_outputs( + cfg: TimeConditionedTrainConfig, + 
best_val: float, + test_loss: float, + history: list[dict], + artifacts: TrainArtifacts, +) -> None: + """保存 history.json 和 metrics.json。""" + _save_json(cfg.output_dir / "history.json", history) + _save_json( + cfg.output_dir / "metrics.json", + { + "best_val_loss": best_val, + "test_loss": test_loss, + "sample_weight_mode": str(cfg.risk_weight.mode), + "sample_weight_summary": artifacts.train_weight_summary, + "model_config": asdict(cfg.model), + "loss_config": asdict(cfg.loss), + "risk_weight_config": asdict(cfg.risk_weight), + }, ) + + +def train_time_conditioned(cfg: TimeConditionedTrainConfig) -> None: + """训练时间条件代理模型并保存训练产物。""" + cfg.output_dir.mkdir(parents=True, exist_ok=True) + set_global_seed(int(cfg.runtime.seed)) + + artifacts = _prepare_training_artifacts(cfg) + model = _build_model(artifacts.bundle, cfg) + + _print_training_config(cfg, artifacts) + best_val, best_path, history = _train_epochs(model, cfg, artifacts) + + checkpoint = torch.load(best_path, map_location=cfg.runtime.device) + model.load_state_dict(checkpoint["model_state_dict"]) + test_loss = _evaluate(model, artifacts.bundle.test_loader, cfg) + + _write_final_outputs(cfg, best_val, test_loss, history, artifacts) print(f"[Final] test={test_loss:.6f}")