You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/compare_single_case.py

608 lines
23 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
数值试井代理模型 - 单案例 Solver vs Surrogate 对比脚本
主要功能:
1. 使用 C++ 数值求解器生成真实试井曲线
2. 使用训练好的代理模型进行预测
3. 对比 Solver 与 Surrogate 的输出结果
4. 计算 RMSE / MAE / R2 等指标
5. 绘制压力、导数、斜率三部分对比图
6. 输出 JSON 分析结果
该脚本通常用于:
- 检查代理模型是否学到真实物理行为
- 分析代理模型在哪些阶段误差较大
- 验证不同 schedule 对模型预测的影响
- 调试自动拟合(autofit)效果
"""
# pylint: disable=import-error,wrong-import-position,wrong-import-order,line-too-long,
# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments,invalid-name
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
from src.common.config import Config
from src.common.experiment_paths import (
config_for_stage,
model_checkpoint_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features
from src.data.param_features import param_feature_transform_from_meta, transform_param_features
from src.data.params import Params, Schedule
from src.data.runner_client import CppRunner, read_result_bin
from src.data.schedule_features import build_schedule_model_vector
from src.evaluation.autofit_objective import dual_log_objective
from src.models.forward_surrogate import ForwardSurrogate
# Default single-case inputs. They serve as the CLI defaults, so the script can
# be run without any arguments for a quick check.
DEFAULT_SINGLE_CASE = dict(
    config="configs/data_gen_family_random.yaml",
    tag="family_random_50k",
    no_schedule=False,
    well_index=0,
    # Physical parameters of the test case, passed to both solver and surrogate.
    params=dict(
        k=0.025,
        skin=0.0,
        wellboreC=0.01,
        phi=0.0245,
        h=9.144,
        Cf=0.0004315,
    ),
)
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for a single solver-vs-surrogate comparison.

    Physical parameters, tag and well index default to ``DEFAULT_SINGLE_CASE``.

    Config resolution:
    - an explicit ``--config`` always wins;
    - if only ``--stage`` is given, ``args.config`` stays ``None``, so
      ``resolve_paths`` infers the config from the stage;
    - with neither, ``DEFAULT_SINGLE_CASE["config"]`` is used.
    """
    parser = argparse.ArgumentParser(
        description="Compare solver output and surrogate prediction on a single parameter set"
    )
    # Default is None (filled in after parsing) so that --stage is not silently
    # shadowed by a hard-coded config default.
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help=f"Config yaml path (default: inferred from --stage, else {DEFAULT_SINGLE_CASE['config']})",
    )
    parser.add_argument(
        "--stage",
        choices=["fixed_case", "case_neighborhood", "family_random", "family_random_v2_q"],
        default=None,
        help="Optional stage to infer config path",
    )
    parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
    parser.add_argument("--model", type=str, default=None, help="Surrogate checkpoint path")
    parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
    parser.add_argument(
        "--tag",
        type=str,
        default=DEFAULT_SINGLE_CASE["tag"],
        help="Experiment tag for processed/model lookup",
    )
    parser.add_argument(
        "--no-schedule",
        action="store_true",
        help="When --model is omitted, infer the no-schedule checkpoint path",
    )
    parser.add_argument("--k", type=float, default=DEFAULT_SINGLE_CASE["params"]["k"])
    parser.add_argument("--skin", type=float, default=DEFAULT_SINGLE_CASE["params"]["skin"])
    parser.add_argument("--wellboreC", type=float, default=DEFAULT_SINGLE_CASE["params"]["wellboreC"])
    parser.add_argument("--phi", type=float, default=DEFAULT_SINGLE_CASE["params"]["phi"])
    parser.add_argument("--h", type=float, default=DEFAULT_SINGLE_CASE["params"]["h"])
    # NOTE(review): --Cf always has a non-None default, so resolve_cf's fallback
    # to the config's fixed Cf is never reached from the command line.
    parser.add_argument("--Cf", type=float, default=DEFAULT_SINGLE_CASE["params"]["Cf"])
    parser.add_argument(
        "--well-index",
        type=int,
        default=DEFAULT_SINGLE_CASE["well_index"],
        help="Well index for solver output",
    )
    parser.add_argument(
        "--timeQ",
        type=str,
        default=None,
        help="Override schedule durations, e.g. '1000,1000,1000' or '0,1000,1000,1000'.",
    )
    parser.add_argument(
        "--q",
        type=str,
        default=None,
        help="Override schedule rates, e.g. '200,100,0' or '0,200,100,0'.",
    )
    parser.add_argument(
        "--section-index",
        type=int,
        default=None,
        help="Override tested section index. Defaults to the last section of the overridden schedule.",
    )
    args = parser.parse_args()
    # Only fall back to the default config when no stage was requested; with a
    # stage, leave config as None so the stage-specific config is inferred.
    if args.config is None and args.stage is None:
        args.config = DEFAULT_SINGLE_CASE["config"]
    return args
def calc_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    eps_range: float = 1e-3,
    eps_var: float = 1e-6,
) -> dict:
    """Compute regression metrics of ``y_pred`` against ``y_true``.

    Residuals are defined as surrogate minus solver (``y_pred - y_true``), so a
    positive ``bias`` means the surrogate over-predicts.

    Returns:
        Dict with ``rmse``, ``mae``, ``bias``, ``abs_bias``, ``nrmse`` and ``r2``.

        - ``nrmse`` is NaN unless the max-min range of ``y_true`` exceeds
          ``eps_range``.
        - ``r2`` is NaN unless the total sum of squares exceeds ``eps_var``.

        On near-constant curves those denominators are tiny, and the values
        would be meaningless.
    """
    residual = y_pred - y_true
    sq_residual = residual**2
    rmse = float(np.sqrt(np.mean(sq_residual)))
    bias = float(np.mean(residual))

    spread = float(np.max(y_true) - np.min(y_true))
    centered = y_true - np.mean(y_true)
    total_ss = float(np.sum(centered**2))
    residual_ss = float(np.sum(sq_residual))

    nrmse = float(rmse / spread) if spread > eps_range else np.nan
    r2 = float(1.0 - residual_ss / total_ss) if total_ss > eps_var else np.nan

    return {
        "rmse": rmse,
        "mae": float(np.mean(np.abs(residual))),
        "bias": bias,
        "abs_bias": float(abs(bias)),
        "nrmse": nrmse,
        "r2": r2,
    }
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the segment layout of the concatenated curve vector.

    If ``meta["curve_layout"]`` is present and not None, it is returned as-is.
    Older processed files lack it. In that case the vector is assumed to hold
    three equal-length parts (``log_pressure``, ``log_derivative``, ``slope``)
    of ``curve_dim // 3`` points each.
    """
    stored = meta.get("curve_layout")
    if stored is not None:
        return stored

    # Legacy fallback: equal thirds in pressure / derivative / slope order.
    n = curve_dim // 3
    names = ("log_pressure", "log_derivative", "slope")
    return {
        "n_time_points": int(n),
        "parts": [
            {"name": label, "start": i * n, "end": (i + 1) * n}
            for i, label in enumerate(names)
        ],
    }
def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]:
    """Slice the concatenated curve into named segments.

    Each entry of ``layout["parts"]`` contributes ``curve[start:end]`` under the
    key ``str(name)``. A later part with a duplicate name overwrites an earlier one.
    """
    return {
        str(segment["name"]): curve[int(segment["start"]) : int(segment["end"])]
        for segment in layout["parts"]
    }
def resolve_cf(args: argparse.Namespace, cfg: Config) -> float:
    """Resolve the compressibility ``Cf``.

    Priority:
    1. an explicit ``args.Cf`` (command line);
    2. ``params.fixed_params.Cf.value`` from the config, when that entry has a
       truthy ``enabled`` flag.

    Raises:
        ValueError: if neither source provides a value.
    """
    if args.Cf is not None:
        return float(args.Cf)
    # YAML sections may exist but be empty (null); treat them as missing instead
    # of calling .get() on None.
    params_cfg = cfg.raw.get("params") or {}
    fixed_params = params_cfg.get("fixed_params") or {}
    fixed_cfg = fixed_params.get("Cf") or {}
    if bool(fixed_cfg.get("enabled", False)):
        return float(fixed_cfg["value"])
    raise ValueError("Cf 未在命令行提供,且配置文件里也没有启用固定 Cf")
def parse_float_list(raw: str, name: str) -> list[float]:
    """Parse a delimited list of numbers such as ``"1000,1000,1000"``.

    Commas, semicolons and spaces are all accepted as separators, and empty
    fields are skipped. ``name`` is only used in the error message.

    Raises:
        ValueError: if no number is found, or if a field is not a valid float.
    """
    normalized = raw.replace(";", ",").replace(" ", ",")
    fields = (piece.strip() for piece in normalized.split(","))
    values = [float(piece) for piece in fields if piece]
    if not values:
        raise ValueError(f"{name} 不能为空")
    return values
def resolve_case_schedule(cfg: Config, args: argparse.Namespace | None = None) -> Schedule:
    """Resolve the rate schedule used for the single-case comparison.

    Command-line override: if ``args`` sets any of ``--timeQ``, ``--q`` or
    ``--section-index``, the schedule comes from the command line. Both
    ``--timeQ`` (per-section durations) and ``--q`` (rates) are then required.
    The section index defaults to the last section.

    Otherwise ``schedule.case_schedule`` from the config is used. Its section
    index is chosen by ``schedule.section_policy.mode``:
    - ``fixed_last``: the last section;
    - ``fixed_value``: ``fixed_value``, clipped to ``[1, n_sections]``;
    - anything else: ``case_schedule.default_section_index``, clipped likewise.

    Raises:
        ValueError: on incomplete or mismatched overrides, or when the resulting
            schedule fails ``Schedule.validate()``.
    """
    if args is not None and (args.timeQ is not None or args.q is not None or args.section_index is not None):
        if args.timeQ is None or args.q is None:
            raise ValueError("覆盖流量制度时必须同时提供 --timeQ 和 --q")
        timeQ = parse_float_list(args.timeQ, "--timeQ")
        q = parse_float_list(args.q, "--q")
        if len(timeQ) != len(q):
            # Separate the two lengths explicitly; concatenating them (e.g. "34")
            # made the message ambiguous.
            raise ValueError(f"--timeQ 和 --q 长度必须一致,当前分别为 {len(timeQ)} 和 {len(q)}")
        # Report-style input may start with a "0,0" row. It only marks the initial
        # state and is not a real rate section, so drop it before calling the solver.
        if len(timeQ) >= 2 and timeQ[0] <= 0.0 and q[0] == 0.0:
            timeQ = timeQ[1:]
            q = q[1:]
        section_index = int(args.section_index) if args.section_index is not None else len(timeQ)
        schedule = Schedule(sectionIndex=section_index, timeQ=timeQ, q=q)
        if not schedule.validate():
            raise ValueError(
                "命令行覆盖的流量制度非法。注意 compare 脚本里的 timeQ 表示每段持续时间,"
                "如果有初始行请写成 0,0。有效段必须 timeQ>0 且 q>=0。"
            )
        return schedule

    schedule_cfg = cfg.raw["schedule"]["case_schedule"]
    timeQ = list(map(float, schedule_cfg["timeQ"]))
    q = list(map(float, schedule_cfg["q"]))
    policy = cfg.raw["schedule"].get("section_policy", {}) or {}
    mode = str(policy.get("mode", "fixed_last")).lower()
    n_sections = len(timeQ)
    if mode == "fixed_last":
        section_index = n_sections
    elif mode == "fixed_value":
        section_index = int(np.clip(int(policy.get("fixed_value", n_sections)), 1, n_sections))
    else:
        section_index = int(np.clip(int(schedule_cfg.get("default_section_index", n_sections)), 1, n_sections))
    schedule = Schedule(sectionIndex=section_index, timeQ=timeQ, q=q)
    if not schedule.validate():
        raise ValueError("配置文件中的 case_schedule 非法,无法用于单样本对比")
    return schedule
def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]:
    """Resolve the config, processed dataset, checkpoint and output directory.

    Explicit command-line paths always win. Otherwise:
    - the processed dataset and checkpoint are looked up from the normalized
      experiment tag, and the checkpoint variant depends on ``--no-schedule``;
    - a missing config falls back to the ``--stage`` config, or else to
      ``configs/data_gen_family_random.yaml``;
    - the output directory defaults to
      ``results/single_compare[_<tag>][_no_schedule]``.

    Returns:
        ``(Config, processed_path, model_path, output_dir)``.
    """
    tag = normalize_tag(args.tag)
    use_schedule = not args.no_schedule

    config_path = args.config
    if config_path is None:
        config_path = str(config_for_stage(args.stage) or Path("configs/data_gen_family_random.yaml"))

    if args.processed is None:
        processed_path = processed_path_for_tag(tag)
    else:
        processed_path = Path(args.processed)

    if args.model is None:
        model_path = model_checkpoint_for_tag(tag, use_schedule=use_schedule)
    else:
        model_path = Path(args.model)

    if args.output_dir is None:
        suffix = "_no_schedule" if args.no_schedule else ""
        stem = f"single_compare_{tag}" if tag else "single_compare"
        output_dir = Path("results") / f"{stem}{suffix}"
    else:
        output_dir = Path(args.output_dir)

    return Config(config_path), processed_path, model_path, output_dir
def build_params_from_args(args: argparse.Namespace, cfg: Config, schedule: Schedule) -> Params:
    """Assemble the ``Params`` object that is fed to both the solver and the surrogate.

    ``k``, ``skin``, ``wellboreC``, ``phi`` and ``h`` are read from the parsed
    arguments and cast to float. ``Cf`` is resolved via ``resolve_cf``, and
    ``schedule`` is attached unchanged.
    """
    physical = {name: float(getattr(args, name)) for name in ("k", "skin", "wellboreC", "phi", "h")}
    return Params(**physical, Cf=resolve_cf(args, cfg), schedule=schedule)
def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) -> tuple[np.ndarray, dict]:
    """Run one forward simulation with the C++ solver and turn its curve into model features.

    The selected well's log-log output is processed in three steps:
    cleaned (``clean_curve_for_dataset``), checked (``is_valid_curve``), and
    resampled onto the surrogate's feature grid (``resample_curve_to_features``).
    The runner is always closed, even on error.

    Returns:
        ``(curve_feat, raw)``. ``raw`` holds the cleaned ``t``/``p``/``d`` as
        lists plus the solver's ``n_steps`` and ``n_wells``.

    Raises:
        RuntimeError: if there is no readable result, no log-log data, or the
            curve is invalid after cleaning.
        IndexError: if ``well_index`` is out of range.
    """
    runner = CppRunner(cfg=cfg)
    try:
        succeeded = runner.run_simulation(params, override_schedule=params.schedule, include_schedule=True)
        result = read_result_bin(runner.result_bin) if runner.result_bin.exists() else None
        if result is None:
            if not succeeded:
                raise RuntimeError("求解器运行失败,未生成有效结果")
            raise RuntimeError("无法读取 result.bin")
        if not succeeded:
            # A failing return code can still leave a complete result.bin behind;
            # as long as it is readable, keep going with the comparison.
            print("Warning: runner_client 返回失败,但 result.bin 可读取,继续使用求解器结果。")

        wells = result["loglog"]
        if not wells:
            raise RuntimeError("求解器未返回 loglog 数据")
        if not 0 <= well_index < len(wells):
            raise IndexError(f"well-index={well_index} 超出范围,可用井数={len(wells)}")

        well = wells[well_index]
        t_raw, p_raw, d_raw = (np.asarray(well[key], dtype=np.float64) for key in ("t", "p", "deriv"))

        # The solver's point count need not match the model's curve dimension:
        # clean first, then resample onto the fixed training grid.
        cleaned = clean_curve_for_dataset(cfg, t_raw, p_raw, d_raw)
        if cleaned is None:
            raise RuntimeError("求解器返回曲线在清洗后无效")
        t_clean, p_clean, d_clean = cleaned

        is_ok, reason = is_valid_curve(cfg, t_clean, p_clean, d_clean)
        if not is_ok:
            raise RuntimeError(f"求解器曲线未通过有效性检查: {reason}")

        curve_feat = resample_curve_to_features(cfg, t_clean, p_clean, d_clean)
        raw = {
            "t": t_clean.tolist(),
            "p": p_clean.tolist(),
            "d": d_clean.tolist(),
            "n_steps": int(result["nSteps"]),
            "n_wells": int(result["nWells"]),
        }
        return curve_feat, raw
    finally:
        runner.close()
def build_schedule_vector(cfg: Config, schedule: Schedule) -> np.ndarray:
    """Encode ``schedule`` as the surrogate's schedule feature vector.

    Thin wrapper that delegates to ``build_schedule_model_vector``.
    """
    encoded = build_schedule_model_vector(cfg, schedule)
    return encoded
def load_model(checkpoint_path: Path) -> tuple[ForwardSurrogate, bool, torch.device]:
    """Rebuild a ``ForwardSurrogate`` from a checkpoint, ready for inference.

    The network is reconstructed from the dimensions and hyper-parameters stored
    in the checkpoint. Weights are loaded on CPU, the model is switched to eval
    mode, and then moved to CUDA when available.

    Returns:
        ``(model, use_schedule, device)``. ``use_schedule`` defaults to True for
        checkpoints that do not record it.
    """
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    uses_schedule = bool(ckpt.get("use_schedule", True))
    arch = {
        "param_dim": int(ckpt["param_dim"]),
        "schedule_dim": int(ckpt["schedule_dim"]),
        "curve_dim": int(ckpt["curve_dim"]),
        "hidden_dim": int(ckpt["hidden_dim"]),
        "dropout": float(ckpt["dropout"]),
    }
    model = ForwardSurrogate(**arch, use_schedule=uses_schedule)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(target), uses_schedule, target
def predict_surrogate_curve(
    processed: dict,
    model: ForwardSurrogate,
    device: torch.device,
    use_schedule: bool,
    params: Params,
    schedule: Schedule,
    cfg: Config,
) -> np.ndarray:
    """Predict the full concatenated curve for one parameter set and schedule.

    Inputs are transformed the same way as in training:
    - parameters go through the parameter-feature transform recorded in
      ``processed["meta"]``, then the saved ``scaler_params``;
    - the schedule vector goes through ``scaler_schedule``.

    The scaled prediction is mapped back with ``scaler_curve``.

    Schedule features are only built when ``use_schedule`` is True. A
    no-schedule model is called with ``None`` for the schedule, so encoding and
    scaling it would be wasted work and could fail on a scaler the model never
    uses.

    Returns:
        The first row of the inverse-scaled prediction, as float32.
    """
    scaler_params = processed["scaler_params"]
    scaler_curve = processed["scaler_curve"]
    param_transform = param_feature_transform_from_meta(processed.get("meta", {}))

    # Feature order must match the order used when the training data was built.
    params_vec = np.asarray(
        [params.k, params.skin, params.wellboreC, params.phi, params.h, params.Cf],
        dtype=np.float32,
    ).reshape(1, -1)
    # Use the scaler and transform saved at training time; anything else shifts
    # the input distribution.
    params_x = scaler_params.transform(transform_param_features(params_vec, param_transform)).astype(np.float32)

    schedule_x = None
    if use_schedule:
        schedule_vec = build_schedule_vector(cfg, schedule).reshape(1, -1)
        schedule_x = processed["scaler_schedule"].transform(schedule_vec).astype(np.float32)

    with torch.no_grad():
        params_t = torch.tensor(params_x, dtype=torch.float32, device=device)
        schedule_t = None if schedule_x is None else torch.tensor(schedule_x, dtype=torch.float32, device=device)
        pred_scaled = model(params_t, schedule_t).cpu().numpy()
    return scaler_curve.inverse_transform(pred_scaled)[0].astype(np.float32)
def plot_comparison(
    curve_true: np.ndarray,
    curve_pred: np.ndarray,
    curve_layout: dict,
    output_path: Path,
    params: Params,
    schedule: Schedule,
    model_path: Path,
    use_schedule: bool,
) -> dict:
    """Plot solver vs surrogate curves for one case and return the metrics.

    Draws a 3x2 figure with one row per part (log pressure, log |derivative|,
    slope):
    - the left column overlays the solver and surrogate curves;
    - the right column shows the pointwise residual (surrogate - solver), which
      helps locate the period where errors concentrate.

    The figure is saved to ``output_path`` at 150 dpi and is always closed
    afterwards.

    Returns:
        ``{"overall": ..., "autofit": ..., "parts": {name: metrics}}``.
        ``overall`` and each ``parts`` entry come from ``calc_metrics``;
        ``autofit`` comes from ``dual_log_objective``.
    """
    true_parts = split_curve_by_layout(curve_true, curve_layout)
    pred_parts = split_curve_by_layout(curve_pred, curve_layout)
    # Error metrics over the whole concatenated curve.
    overall = calc_metrics(curve_true, curve_pred)
    # Autofit objective: measures whether the dual-log curve shapes agree.
    autofit = dual_log_objective(curve_true, curve_pred, curve_layout)
    part_names = ["log_pressure", "log_derivative", "slope"]
    title_map = {
        "log_pressure": "Log Pressure",
        "log_derivative": "Log |Derivative|",
        "slope": "Slope of Log Pressure vs Log Time",
    }

    def fmt(value: float) -> str:
        # calc_metrics reports NaN for NRMSE/R2 on near-constant curves.
        return "nan" if np.isnan(value) else f"{value:.4f}"

    summary = {"overall": overall, "autofit": autofit, "parts": {}}
    fig, axes = plt.subplots(3, 2, figsize=(14, 12))
    try:
        fig.suptitle(
            "Solver vs Surrogate\n"
            f"k={params.k:.6g}, skin={params.skin:.6g}, wellboreC={params.wellboreC:.6g}, "
            f"phi={params.phi:.6g}, h={params.h:.6g}, Cf={params.Cf:.6g}\n"
            f"sectionIndex={schedule.sectionIndex}, timeQ={schedule.timeQ}, q={schedule.q}\n"
            f"use_schedule={use_schedule}, model={model_path.name}, "
            f"Overall RMSE={overall['rmse']:.4f}, MAE={overall['mae']:.4f}, "
            f"AutoFitObj={autofit['dual_log_objective']:.4f}, "
            f"R2={fmt(overall['r2'])}"
        )
        for row, name in enumerate(part_names):
            y_true = true_parts[name]
            y_pred = pred_parts[name]
            # Residual = surrogate prediction - solver ground truth.
            err = y_pred - y_true
            x = np.arange(len(y_true))
            metrics = calc_metrics(y_true, y_pred)
            summary["parts"][name] = metrics

            ax_l = axes[row, 0]
            ax_l.plot(x, y_true, label="Solver", linewidth=2, alpha=0.85)
            ax_l.plot(x, y_pred, label="Surrogate", linewidth=2, alpha=0.85)
            ax_l.set_title(
                f"{title_map[name]} | RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, "
                f"Bias={metrics['bias']:.4f}, NRMSE={fmt(metrics['nrmse'])}, R2={fmt(metrics['r2'])}"
            )
            ax_l.set_xlabel("Resampled Time Index")
            ax_l.set_ylabel("Value")
            ax_l.grid(True, alpha=0.3)
            ax_l.legend()

            ax_r = axes[row, 1]
            ax_r.plot(x, err, linewidth=1.5, alpha=0.85)
            ax_r.axhline(0.0, linestyle="--", linewidth=1)
            ax_r.set_title(f"{title_map[name]} Error")
            ax_r.set_xlabel("Resampled Time Index")
            ax_r.set_ylabel("Surrogate - Solver")
            ax_r.grid(True, alpha=0.3)

        fig.tight_layout(rect=[0, 0, 1, 0.94])
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure, not just pyplot's "current" one, even if
        # plotting or saving raised; otherwise repeated calls leak figures.
        plt.close(fig)
    return summary
def _json_safe(value):
    """Recursively convert ``value`` into strict-JSON-serializable builtins.

    - NumPy arrays become lists; NumPy scalars become Python numbers/bools.
    - Tuples become lists.
    - Non-finite floats (NaN/inf, e.g. an undefined R2/NRMSE) become ``None``,
      so the output is valid JSON.

    Dict keys are kept unchanged.
    """
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    if isinstance(value, np.ndarray):
        return _json_safe(value.tolist())
    if isinstance(value, np.generic):
        value = value.item()
    if isinstance(value, float) and not np.isfinite(value):
        return None
    return value


def main() -> None:
    """Run one solver-vs-surrogate comparison and write a plot plus a JSON summary.

    Steps:
    1. Resolve paths and check that the processed dataset and checkpoint exist.
    2. Run the C++ solver and the surrogate on the same parameters and schedule.
    3. Plot the comparison.
    4. Save ``single_case_summary.json`` into the output directory.

    Raises:
        FileNotFoundError: if the processed dataset or the checkpoint is missing.
    """
    args = parse_args()
    cfg, processed_path, model_path, output_dir = resolve_paths(args)
    # Validate inputs before creating the output directory, so a bad path does
    # not leave an empty results folder behind.
    if not processed_path.exists():
        raise FileNotFoundError(f"Processed 数据不存在: {processed_path}")
    if not model_path.exists():
        raise FileNotFoundError(f"模型 checkpoint 不存在: {model_path}")
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE: joblib.load unpickles the file; only load trusted processed data.
    processed = joblib.load(processed_path)
    meta = processed["meta"]
    curve_layout = infer_curve_layout(meta, int(meta["curve_dim"]))

    # Feed the same parameters and schedule to both the solver and the
    # surrogate, so the comparison is like-for-like.
    schedule = resolve_case_schedule(cfg, args)
    params = build_params_from_args(args, cfg, schedule)

    print("Running solver...")
    # Ground-truth curve from the numerical solver.
    curve_solver, raw_solver = run_solver_and_extract_curve(cfg, params, args.well_index)

    print("Running surrogate...")
    model, use_schedule, device = load_model(model_path)
    curve_pred = predict_surrogate_curve(
        processed=processed,
        model=model,
        device=device,
        use_schedule=use_schedule,
        params=params,
        schedule=schedule,
        cfg=cfg,
    )

    plot_path = output_dir / "single_case_comparison.png"
    summary = plot_comparison(
        curve_true=curve_solver,
        curve_pred=curve_pred,
        curve_layout=curve_layout,
        output_path=plot_path,
        params=params,
        schedule=schedule,
        model_path=model_path,
        use_schedule=use_schedule,
    )

    # Collect all results and save them as JSON for later analysis.
    summary_payload = {
        "config_path": str(cfg.path),
        "processed_path": str(processed_path),
        "model_path": str(model_path),
        "use_schedule": use_schedule,
        "device": str(device),
        "params": {
            "k": params.k,
            "skin": params.skin,
            "wellboreC": params.wellboreC,
            "phi": params.phi,
            "h": params.h,
            "Cf": params.Cf,
        },
        "schedule": {
            "sectionIndex": schedule.sectionIndex,
            "timeQ": schedule.timeQ,
            "q": schedule.q,
        },
        "metrics": summary,
        "solver_raw_loglog": raw_solver,
    }
    summary_path = output_dir / "single_case_summary.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        # Undefined metrics (NaN) are written as null and NumPy values as
        # builtins, so the file is strict JSON that any parser can read.
        json.dump(_json_safe(summary_payload), f, ensure_ascii=False, indent=2, allow_nan=False)

    overall = summary["overall"]
    autofit = summary["autofit"]
    r2_text = "nan" if np.isnan(overall["r2"]) else f"{overall['r2']:.6f}"
    print("\nSingle-case comparison complete.")
    print(f"Output dir: {output_dir}")
    print(
        f"Overall RMSE={overall['rmse']:.6f}, MAE={overall['mae']:.6f}, "
        f"Bias={overall['bias']:.6f}, "
        f"AutoFitObj={autofit['dual_log_objective']:.6f}, "
        f"R2={r2_text}"
    )
    print(f"Plot saved: {plot_path}")
    print(f"Summary saved: {summary_path}")


if __name__ == "__main__":
    main()