You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/compare_single_case.py

610 lines
23 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""单个试井样本的数值求解器与正演代理模型对比。
该脚本用一组指定地层/井筒参数和流量制度分别运行 C++ 数值求解器与 Python
正演代理模型,随后在压力、压力导数和斜率三段曲线上计算误差指标,绘制逐点
对比图,并导出 JSON 汇总,便于排查单个案例的代理误差来源。
"""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
from src.common.config import Config
from src.common.experiment_paths import (
config_for_stage,
model_checkpoint_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features
from src.data.param_features import param_feature_transform_from_meta, transform_param_features
from src.data.params import Params, Schedule
from src.data.runner_client import CppRunner, read_result_bin
from src.data.schedule_features import build_schedule_model_vector
from src.evaluation.autofit_objective import dual_log_objective
from src.models.forward_surrogate import ForwardSurrogate
# Default single-case setup. parse_args() reads "config", "tag", "well_index"
# and every entry of "params" as the defaults of the matching command-line
# options. The "no_schedule" entry is not read by the code visible in this
# file (the --no-schedule flag is a plain store_true option).
DEFAULT_SINGLE_CASE = {
    "config": "configs/data_gen_family_random.yaml",
    "tag": "family_random_50k",
    "no_schedule": False,
    "well_index": 0,
    # Formation / wellbore parameters forwarded to Params.
    "params": {
        "k": 0.025,
        "skin": 0.0,
        "wellboreC": 0.01,
        "phi": 0.0245,
        "h": 9.144,
        "Cf": 0.0004315,
    },
}
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the single-case comparison.

    Defaults for the config path, tag, well index and the six physical
    parameters come from ``DEFAULT_SINGLE_CASE``. The remaining options
    (processed dataset, checkpoint, output directory, stage and the
    ``timeQ`` / ``q`` / ``section-index`` schedule overrides) default to
    ``None`` and are resolved later by the caller.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    defaults = DEFAULT_SINGLE_CASE
    cli = argparse.ArgumentParser(
        description="Compare solver output and surrogate prediction on a single parameter set"
    )
    add = cli.add_argument

    # Experiment location: config, data, checkpoint, output.
    add("--config", type=str, default=defaults["config"], help="Config yaml path")
    add(
        "--stage",
        choices=["fixed_case", "case_neighborhood", "family_random", "family_random_v2_q"],
        default=None,
        help="Optional stage to infer config path",
    )
    add("--processed", type=str, default=None, help="Processed dataset path")
    add("--model", type=str, default=None, help="Surrogate checkpoint path")
    add("--output-dir", type=str, default=None, help="Output directory")
    add(
        "--tag",
        type=str,
        default=defaults["tag"],
        help="Experiment tag for processed/model lookup",
    )
    add(
        "--no-schedule",
        action="store_true",
        help="When --model is omitted, infer the no-schedule checkpoint path",
    )

    # Physical parameters: one float option per entry, in a fixed order.
    for param_name in ("k", "skin", "wellboreC", "phi", "h", "Cf"):
        add(f"--{param_name}", type=float, default=defaults["params"][param_name])

    add(
        "--well-index",
        type=int,
        default=defaults["well_index"],
        help="Well index for solver output",
    )

    # Optional schedule overrides, parsed later from their raw strings.
    add(
        "--timeQ",
        type=str,
        default=None,
        help="Override schedule durations, e.g. '1000,1000,1000' or '0,1000,1000,1000'.",
    )
    add(
        "--q",
        type=str,
        default=None,
        help="Override schedule rates, e.g. '200,100,0' or '0,200,100,0'.",
    )
    add(
        "--section-index",
        type=int,
        default=None,
        help="Override tested section index. Defaults to the last section of the overridden schedule.",
    )
    return cli.parse_args()
def calc_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    eps_range: float = 1e-3,
    eps_var: float = 1e-6,
) -> dict:
    """Compute error metrics between a reference curve and a prediction.

    Both inputs may be any array-like (lists, tuples, arrays of any float
    width); they are converted to ``float64`` before any reduction so that
    the metrics do not depend on the input dtype.

    Args:
        y_true: Reference values.
        y_pred: Predicted values, broadcastable against ``y_true``.
        eps_range: Minimum ``max(y_true) - min(y_true)`` for NRMSE to be
            reported; below or at it NRMSE is NaN (near-constant reference).
        eps_var: Minimum total sum of squares of ``y_true`` for R2 to be
            reported; below or at it R2 is NaN.

    Returns:
        A dict of Python floats with keys ``rmse``, ``mae``, ``bias``,
        ``abs_bias``, ``nrmse`` and ``r2``.

    Raises:
        ValueError: If either input is empty.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    pred = np.asarray(y_pred, dtype=np.float64)
    # Fail with a clear message instead of numpy warnings followed by an
    # opaque "zero-size array" reduction error.
    if truth.size == 0 or pred.size == 0:
        raise ValueError("calc_metrics 的输入曲线不能为空")

    err = pred - truth
    sq_err = err**2
    rmse = float(np.sqrt(np.mean(sq_err)))
    mae = float(np.mean(np.abs(err)))
    bias = float(np.mean(err))

    value_range = float(np.max(truth) - np.min(truth))
    ss_tot = float(np.sum((truth - np.mean(truth)) ** 2))
    ss_res = float(np.sum(sq_err))

    # Normalised metrics are undefined when the reference is (almost) flat.
    nrmse = float(rmse / value_range) if value_range > eps_range else float("nan")
    r2 = float(1.0 - ss_res / ss_tot) if ss_tot > eps_var else float("nan")

    return {
        "rmse": rmse,
        "mae": mae,
        "bias": bias,
        "abs_bias": float(abs(bias)),
        "nrmse": nrmse,
        "r2": r2,
    }
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the layout describing how the concatenated curve is segmented.

    If ``meta`` carries a non-None ``curve_layout`` it is returned as is.
    Otherwise the curve is assumed to be three equal-length segments
    (``log_pressure``, ``log_derivative``, ``slope``) of ``curve_dim // 3``
    points each, laid out back to back starting at index 0.

    Args:
        meta: Preprocessing metadata; only the ``curve_layout`` key is read.
        curve_dim: Total length of the concatenated curve vector.

    Returns:
        A dict with ``n_time_points`` and a ``parts`` list of
        ``{"name", "start", "end"}`` entries.
    """
    stored_layout = meta.get("curve_layout")
    if stored_layout is not None:
        return stored_layout

    # Fallback: three equal thirds in a fixed order.
    seg_len = curve_dim // 3
    segment_names = ("log_pressure", "log_derivative", "slope")
    bounds = [0, seg_len, 2 * seg_len, 3 * seg_len]
    segments = [
        {"name": seg_name, "start": lo, "end": hi}
        for seg_name, lo, hi in zip(segment_names, bounds, bounds[1:])
    ]
    return {"n_time_points": int(seg_len), "parts": segments}
def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]:
    """Slice a concatenated curve into its named segments.

    Args:
        curve: The concatenated curve; anything that supports slicing.
        layout: A layout dict whose ``parts`` entries each provide ``name``,
            ``start`` and ``end``.

    Returns:
        A dict mapping ``str(name)`` to ``curve[start:end]`` for every part,
        in the order the parts are listed.
    """
    return {
        str(segment["name"]): curve[int(segment["start"]) : int(segment["end"])]
        for segment in layout["parts"]
    }
def resolve_cf(args: argparse.Namespace, cfg: Config) -> float:
    """Resolve the ``Cf`` value used to build ``Params``.

    The command-line value wins when it is set. Otherwise the value is read
    from ``cfg.raw["params"]["fixed_params"]["Cf"]``, but only if that entry
    has a truthy ``enabled`` flag.

    Args:
        args: Parsed arguments; only ``args.Cf`` is read.
        cfg: Config object; only the ``cfg.raw`` mapping is read.

    Returns:
        ``Cf`` as a float.

    Raises:
        ValueError: If ``Cf`` is neither given on the command line nor
            enabled as a fixed parameter in the config.
    """
    if args.Cf is not None:
        return float(args.Cf)

    # Any level may be missing or explicitly None; normalise each one so a
    # sparse config ends in the ValueError below instead of AttributeError.
    params_cfg = cfg.raw.get("params") or {}
    fixed_params = params_cfg.get("fixed_params") or {}
    fixed_cfg = fixed_params.get("Cf") or {}

    if bool(fixed_cfg.get("enabled", False)):
        return float(fixed_cfg["value"])
    raise ValueError("Cf 未在命令行提供,且配置文件里也没有启用固定 Cf")
def parse_float_list(raw: str, name: str) -> list[float]:
    """Parse a delimited string of numbers into a list of floats.

    Commas, semicolons and any run of whitespace (spaces, tabs, newlines)
    are accepted as separators; empty tokens are ignored.

    Args:
        raw: The raw option string, e.g. ``"1000,1000;1000"``.
        name: Option name used in error messages (e.g. ``"--timeQ"``).

    Returns:
        The parsed values in input order.

    Raises:
        ValueError: If a token is not a valid float or no value is found;
            the message always contains ``name``.
    """
    values: list[float] = []
    # Turning both punctuation separators into spaces lets the no-argument
    # str.split() handle every kind of whitespace and drop empty tokens.
    for token in raw.replace(";", " ").replace(",", " ").split():
        try:
            values.append(float(token))
        except ValueError as err:
            raise ValueError(f"{name} 包含无法解析的数值: {token!r}") from err
    if not values:
        raise ValueError(f"{name} 不能为空")
    return values
def resolve_case_schedule(cfg: Config, args: argparse.Namespace | None = None) -> Schedule:
    """Build the rate schedule used for the single-case comparison.

    Command-line path: taken when ``args`` is given and any of ``timeQ``,
    ``q`` or ``section_index`` is set. Both ``timeQ`` and ``q`` are then
    required and must have equal length. A leading row with ``timeQ <= 0``
    and ``q == 0`` is dropped (table convention with an initial "0 0" row).
    The tested section is ``section_index`` when given, otherwise the number
    of remaining sections.

    Config path: otherwise the schedule is read from
    ``cfg.raw["schedule"]["case_schedule"]``. The section index follows
    ``section_policy.mode``: ``fixed_last`` selects the last section,
    ``fixed_value`` selects ``fixed_value`` clipped to ``[1, n_sections]``,
    and any other mode selects the schedule's ``default_section_index`` with
    the same clipping.

    Args:
        cfg: Config object; only ``cfg.raw["schedule"]`` is read.
        args: Parsed command-line arguments, or ``None`` to use the config.

    Returns:
        A ``Schedule`` that passed its own ``validate()`` check.

    Raises:
        ValueError: If only one of ``--timeQ`` / ``--q`` is given, their
            lengths differ, or the resulting schedule fails validation.
    """
    has_override = args is not None and (
        args.timeQ is not None or args.q is not None or args.section_index is not None
    )
    if has_override:
        if args.timeQ is None or args.q is None:
            raise ValueError("覆盖流量制度时必须同时提供 --timeQ 和 --q")
        timeQ = parse_float_list(args.timeQ, "--timeQ")
        q = parse_float_list(args.q, "--q")
        if len(timeQ) != len(q):
            # Keep the two lengths visibly separated; printing them back to
            # back would read as a single number.
            raise ValueError(
                f"--timeQ 和 --q 长度必须一致,当前分别为 {len(timeQ)} 和 {len(q)}"
            )
        # Allow the reporting/table convention with an initial "0 0" row.
        if len(timeQ) >= 2 and timeQ[0] <= 0.0 and q[0] == 0.0:
            timeQ = timeQ[1:]
            q = q[1:]
        section_index = int(args.section_index) if args.section_index is not None else len(timeQ)
        schedule = Schedule(sectionIndex=section_index, timeQ=timeQ, q=q)
        if not schedule.validate():
            raise ValueError(
                "命令行覆盖的流量制度非法。注意 compare 脚本里的 timeQ 表示每段持续时间,"
                "如果有初始行请写成 0,0;有效段必须 timeQ>0 且 q>=0。"
            )
        return schedule

    schedule_cfg = cfg.raw["schedule"]["case_schedule"]
    timeQ = list(map(float, schedule_cfg["timeQ"]))
    q = list(map(float, schedule_cfg["q"]))
    policy = cfg.raw["schedule"].get("section_policy", {}) or {}
    mode = str(policy.get("mode", "fixed_last")).lower()
    n_sections = len(timeQ)
    if mode == "fixed_last":
        section_index = n_sections
    elif mode == "fixed_value":
        section_index = int(np.clip(int(policy.get("fixed_value", n_sections)), 1, n_sections))
    else:
        # Every other mode falls back to the schedule's own default index.
        section_index = int(np.clip(int(schedule_cfg.get("default_section_index", n_sections)), 1, n_sections))
    schedule = Schedule(sectionIndex=section_index, timeQ=timeQ, q=q)
    if not schedule.validate():
        raise ValueError("配置文件中的 case_schedule 非法,无法用于单样本对比")
    return schedule
def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]:
    """Resolve the config, processed dataset, checkpoint and output directory.

    Explicit command-line paths take precedence. Missing ones are derived
    from the normalised experiment tag; ``--no-schedule`` selects which
    checkpoint is inferred and adds a ``_no_schedule`` suffix to the default
    output directory.

    Args:
        args: Parsed command-line arguments.

    Returns:
        ``(Config(config_path), processed_path, model_path, output_dir)``.
    """
    tag = normalize_tag(args.tag)

    config_path = args.config
    # NOTE(review): parse_args() in this file gives --config a non-None
    # default, so this fallback (and --stage) looks reachable only for
    # callers that build the namespace themselves; confirm intent.
    if config_path is None:
        config_path = str(config_for_stage(args.stage) or Path("configs/data_gen_family_random.yaml"))

    if args.processed is None:
        processed_path = processed_path_for_tag(tag)
    else:
        processed_path = Path(args.processed)

    if args.model is None:
        model_path = model_checkpoint_for_tag(tag, use_schedule=not args.no_schedule)
    else:
        model_path = Path(args.model)

    if args.output_dir is None:
        suffix = "_no_schedule" if args.no_schedule else ""
        dir_name = f"single_compare_{tag}{suffix}" if tag else f"single_compare{suffix}"
        output_dir = Path("results") / dir_name
    else:
        output_dir = Path(args.output_dir)

    return Config(config_path), processed_path, model_path, output_dir
def build_params_from_args(args: argparse.Namespace, cfg: Config, schedule: Schedule) -> Params:
    """Assemble the ``Params`` object from the command-line values.

    ``k``, ``skin``, ``wellboreC``, ``phi`` and ``h`` are taken from ``args``
    and cast to float; ``Cf`` is resolved through ``resolve_cf`` and the
    given schedule is attached unchanged.

    Args:
        args: Parsed command-line arguments.
        cfg: Config object, forwarded to ``resolve_cf``.
        schedule: Schedule stored on the returned ``Params``.

    Returns:
        The populated ``Params`` instance.
    """
    direct_fields = ("k", "skin", "wellboreC", "phi", "h")
    values = {field: float(getattr(args, field)) for field in direct_fields}
    return Params(**values, Cf=resolve_cf(args, cfg), schedule=schedule)
def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) -> tuple[np.ndarray, dict]:
    """Run the C++ solver and turn one well's log-log output into features.

    The runner is always closed, even on failure. The selected well's
    ``t`` / ``p`` / ``deriv`` arrays are cleaned, validated and resampled
    with the project helpers before being returned.

    Args:
        cfg: Config passed to the runner and the curve helpers.
        params: Parameters to simulate; ``params.schedule`` is used as the
            schedule override.
        well_index: Index into the solver's ``loglog`` list.

    Returns:
        ``(curve_feat, raw)`` where ``curve_feat`` is the output of
        ``resample_curve_to_features`` and ``raw`` holds the cleaned
        ``t`` / ``p`` / ``d`` lists plus ``n_steps`` and ``n_wells``.

    Raises:
        RuntimeError: If no result can be read, the result has no log-log
            data, or the curve is rejected by cleaning/validation.
        IndexError: If ``well_index`` is out of range.
    """
    runner = CppRunner(cfg=cfg)
    try:
        ok = runner.run_simulation(params, override_schedule=params.schedule, include_schedule=True)
        result = read_result_bin(runner.result_bin) if runner.result_bin.exists() else None

        if result is None:
            if not ok:
                raise RuntimeError("求解器运行失败,未生成有效结果")
            raise RuntimeError("无法读取 result.bin")
        if not ok:
            # NOTE(review): a readable result.bin is trusted even though the
            # run reported failure; whether it can be left over from an
            # earlier run depends on CppRunner. Confirm.
            print("Warning: runner_client 返回失败,但 result.bin 可读取,继续使用求解器结果。")

        wells = result["loglog"]
        if not wells:
            raise RuntimeError("求解器未返回 loglog 数据")
        if not 0 <= well_index < len(wells):
            raise IndexError(f"well-index={well_index} 超出范围,可用井数={len(wells)}")

        selected = wells[well_index]
        t, p, d = (np.asarray(selected[key], dtype=np.float64) for key in ("t", "p", "deriv"))

        cleaned = clean_curve_for_dataset(cfg, t, p, d)
        if cleaned is None:
            raise RuntimeError("求解器返回曲线在清洗后无效")
        t_clean, p_clean, d_clean = cleaned

        valid, reason = is_valid_curve(cfg, t_clean, p_clean, d_clean)
        if not valid:
            raise RuntimeError(f"求解器曲线未通过有效性检查: {reason}")

        curve_feat = resample_curve_to_features(cfg, t_clean, p_clean, d_clean)
        raw = {
            "t": t_clean.tolist(),
            "p": p_clean.tolist(),
            "d": d_clean.tolist(),
            "n_steps": int(result["nSteps"]),
            "n_wells": int(result["nWells"]),
        }
        return curve_feat, raw
    finally:
        runner.close()
def build_schedule_vector(cfg: Config, schedule: Schedule) -> np.ndarray:
    """Encode a ``Schedule`` as the surrogate model's schedule feature vector.

    Thin wrapper that delegates to ``build_schedule_model_vector`` with the
    same ``cfg`` and ``schedule`` and returns its result unchanged.
    """
    return build_schedule_model_vector(cfg, schedule)
def load_model(checkpoint_path: Path) -> tuple[ForwardSurrogate, bool, torch.device]:
    """Rebuild the forward surrogate from a checkpoint and prepare it for inference.

    The checkpoint must provide ``param_dim``, ``schedule_dim``,
    ``curve_dim``, ``hidden_dim``, ``dropout`` and ``model_state_dict``;
    ``use_schedule`` is optional and defaults to ``True``. The model is put
    in eval mode and moved to CUDA when available, otherwise CPU.

    Args:
        checkpoint_path: Path to the checkpoint file.

    Returns:
        ``(model, use_schedule, device)``.
    """
    # NOTE(review): torch.load unpickles the file, so only trusted
    # checkpoints should be loaded here.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    use_schedule = bool(ckpt.get("use_schedule", True))

    # Integer architecture sizes stored alongside the weights.
    size_keys = ("param_dim", "schedule_dim", "curve_dim", "hidden_dim")
    sizes = {key: int(ckpt[key]) for key in size_keys}

    model = ForwardSurrogate(
        **sizes,
        dropout=float(ckpt["dropout"]),
        use_schedule=use_schedule,
    )
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(device), use_schedule, device
def predict_surrogate_curve(
    processed: dict,
    model: ForwardSurrogate,
    device: torch.device,
    use_schedule: bool,
    params: Params,
    schedule: Schedule,
    cfg: Config,
) -> np.ndarray:
    """Predict the de-normalised curve for a single parameter set.

    The six physical parameters are feature-transformed and scaled with
    ``processed["scaler_params"]``. The schedule is encoded and scaled with
    ``processed["scaler_schedule"]`` only when ``use_schedule`` is true, so a
    model without the schedule branch does not depend on schedule features
    or on the schedule scaler. The network output is mapped back with
    ``processed["scaler_curve"].inverse_transform``.

    Args:
        processed: Loaded processed-dataset dict holding the scalers and,
            optionally, ``meta``.
        model: The surrogate network, already on ``device``.
        device: Device on which input tensors are created.
        use_schedule: Whether the model consumes schedule features.
        params: Physical parameters of the case.
        schedule: Rate schedule of the case.
        cfg: Config forwarded to the schedule encoder.

    Returns:
        The predicted curve as a 1-D ``float32`` array.
    """
    scaler_params = processed["scaler_params"]
    scaler_curve = processed["scaler_curve"]
    param_transform = param_feature_transform_from_meta(processed.get("meta", {}))

    params_vec = np.asarray(
        [params.k, params.skin, params.wellboreC, params.phi, params.h, params.Cf],
        dtype=np.float32,
    ).reshape(1, -1)
    params_x = scaler_params.transform(transform_param_features(params_vec, param_transform)).astype(np.float32)

    # Only build and scale schedule features when the model actually uses
    # them; otherwise a missing or mismatched schedule scaler would abort a
    # prediction that never reads the schedule.
    schedule_x = None
    if use_schedule:
        schedule_vec = build_schedule_vector(cfg, schedule).reshape(1, -1)
        schedule_x = processed["scaler_schedule"].transform(schedule_vec).astype(np.float32)

    with torch.no_grad():
        params_t = torch.tensor(params_x, dtype=torch.float32, device=device)
        if schedule_x is not None:
            schedule_t = torch.tensor(schedule_x, dtype=torch.float32, device=device)
            pred_scaled = model(params_t, schedule_t).cpu().numpy()
        else:
            pred_scaled = model(params_t, None).cpu().numpy()

    return scaler_curve.inverse_transform(pred_scaled)[0].astype(np.float32)
def plot_comparison(
    curve_true: np.ndarray,
    curve_pred: np.ndarray,
    curve_layout: dict,
    output_path: Path,
    params: Params,
    schedule: Schedule,
    model_path: Path,
    use_schedule: bool,
) -> dict:
    """Plot solver vs. surrogate curves per segment and collect the metrics.

    A 3x2 figure is written to ``output_path``: the left column overlays the
    solver and surrogate values of ``log_pressure``, ``log_derivative`` and
    ``slope``; the right column shows the point-wise error
    (surrogate minus solver). The figure title lists the parameters, the
    schedule, the model file name and the overall metrics.

    Args:
        curve_true: Concatenated solver curve.
        curve_pred: Concatenated surrogate curve.
        curve_layout: Layout used to split both curves into segments.
        output_path: Destination image file.
        params: Parameters shown in the title.
        schedule: Schedule shown in the title.
        model_path: Checkpoint path; only its file name is shown.
        use_schedule: Flag shown in the title.

    Returns:
        ``{"overall": ..., "autofit": ..., "parts": {name: metrics}}``.
    """

    def _four_decimals(value: float) -> str:
        # NaN is printed as the literal "nan" instead of a formatted number.
        return "nan" if np.isnan(value) else f"{value:.4f}"

    segments_true = split_curve_by_layout(curve_true, curve_layout)
    segments_pred = split_curve_by_layout(curve_pred, curve_layout)
    overall = calc_metrics(curve_true, curve_pred)
    autofit = dual_log_objective(curve_true, curve_pred, curve_layout)
    overall_r2_text = _four_decimals(overall["r2"])

    part_names = ["log_pressure", "log_derivative", "slope"]
    title_map = {
        "log_pressure": "Log Pressure",
        "log_derivative": "Log |Derivative|",
        "slope": "Slope of Log Pressure vs Log Time",
    }

    fig, axes = plt.subplots(3, 2, figsize=(14, 12))
    fig.suptitle(
        "Solver vs Surrogate\n"
        f"k={params.k:.6g}, skin={params.skin:.6g}, wellboreC={params.wellboreC:.6g}, "
        f"phi={params.phi:.6g}, h={params.h:.6g}, Cf={params.Cf:.6g}\n"
        f"sectionIndex={schedule.sectionIndex}, timeQ={schedule.timeQ}, q={schedule.q}\n"
        f"use_schedule={use_schedule}, model={model_path.name}, "
        f"Overall RMSE={overall['rmse']:.4f}, MAE={overall['mae']:.4f}, "
        f"AutoFitObj={autofit['dual_log_objective']:.4f}, "
        f"R2={overall_r2_text}"
    )

    summary = {"overall": overall, "autofit": autofit, "parts": {}}
    for (ax_values, ax_error), name in zip(axes, part_names):
        y_true = segments_true[name]
        y_pred = segments_pred[name]
        residual = y_pred - y_true
        x = np.arange(len(y_true))

        metrics = calc_metrics(y_true, y_pred)
        summary["parts"][name] = metrics
        nrmse_text = _four_decimals(metrics["nrmse"])
        r2_text = _four_decimals(metrics["r2"])

        # Left panel: both curves on the resampled index.
        ax_values.plot(x, y_true, label="Solver", linewidth=2, alpha=0.85)
        ax_values.plot(x, y_pred, label="Surrogate", linewidth=2, alpha=0.85)
        ax_values.set_title(
            f"{title_map[name]} | RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, "
            f"Bias={metrics['bias']:.4f}, NRMSE={nrmse_text}, R2={r2_text}"
        )
        ax_values.set_xlabel("Resampled Time Index")
        ax_values.set_ylabel("Value")
        ax_values.grid(True, alpha=0.3)
        ax_values.legend()

        # Right panel: point-wise error with a zero reference line.
        ax_error.plot(x, residual, linewidth=1.5, alpha=0.85)
        ax_error.axhline(0.0, linestyle="--", linewidth=1)
        ax_error.set_title(f"{title_map[name]} Error")
        ax_error.set_xlabel("Resampled Time Index")
        ax_error.set_ylabel("Surrogate - Solver")
        ax_error.grid(True, alpha=0.3)

    # Leave room at the top for the multi-line figure title.
    plt.tight_layout(rect=[0, 0, 1, 0.94])
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
    return summary
def main() -> None:
    """Run the full single-case comparison from the command line.

    Steps: parse arguments and resolve paths, create the output directory,
    check that the processed dataset and checkpoint exist, load the processed
    data, build the schedule and parameters, run the solver, run the
    surrogate, write the comparison plot and the JSON summary, then print
    the headline metrics and output locations.

    Raises:
        FileNotFoundError: If the processed dataset or checkpoint is missing.
    """
    args = parse_args()
    cfg, processed_path, model_path, output_dir = resolve_paths(args)
    output_dir.mkdir(parents=True, exist_ok=True)

    if not processed_path.exists():
        raise FileNotFoundError(f"Processed 数据不存在: {processed_path}")
    if not model_path.exists():
        raise FileNotFoundError(f"模型 checkpoint 不存在: {model_path}")

    processed = joblib.load(processed_path)
    meta = processed["meta"]
    curve_layout = infer_curve_layout(meta, int(meta["curve_dim"]))

    schedule = resolve_case_schedule(cfg, args)
    params = build_params_from_args(args, cfg, schedule)

    print("Running solver...")
    curve_solver, raw_solver = run_solver_and_extract_curve(cfg, params, args.well_index)

    print("Running surrogate...")
    model, use_schedule, device = load_model(model_path)
    curve_pred = predict_surrogate_curve(
        processed=processed,
        model=model,
        device=device,
        use_schedule=use_schedule,
        params=params,
        schedule=schedule,
        cfg=cfg,
    )

    plot_path = output_dir / "single_case_comparison.png"
    summary_path = output_dir / "single_case_summary.json"
    summary = plot_comparison(
        curve_true=curve_solver,
        curve_pred=curve_pred,
        curve_layout=curve_layout,
        output_path=plot_path,
        params=params,
        schedule=schedule,
        model_path=model_path,
        use_schedule=use_schedule,
    )

    param_names = ("k", "skin", "wellboreC", "phi", "h", "Cf")
    summary_payload = {
        "config_path": str(cfg.path),
        "processed_path": str(processed_path),
        "model_path": str(model_path),
        "use_schedule": use_schedule,
        "device": str(device),
        "params": {field: getattr(params, field) for field in param_names},
        "schedule": {
            "sectionIndex": schedule.sectionIndex,
            "timeQ": schedule.timeQ,
            "q": schedule.q,
        },
        "metrics": summary,
        "solver_raw_loglog": raw_solver,
    }
    # NOTE(review): metrics may hold NaN, which json.dump writes as the
    # non-standard token NaN; strict JSON parsers will reject such a file.
    with summary_path.open("w", encoding="utf-8") as f:
        json.dump(summary_payload, f, ensure_ascii=False, indent=2)

    overall = summary["overall"]
    autofit = summary["autofit"]
    overall_r2_text = "nan" if np.isnan(overall["r2"]) else f"{overall['r2']:.6f}"

    print("\nSingle-case comparison complete.")
    print(f"Output dir: {output_dir}")
    print(
        f"Overall RMSE={overall['rmse']:.6f}, MAE={overall['mae']:.6f}, "
        f"Bias={overall['bias']:.6f}, "
        f"AutoFitObj={autofit['dual_log_objective']:.6f}, "
        f"R2={overall_r2_text}"
    )
    print(f"Plot saved: {plot_path}")
    print(f"Summary saved: {summary_path}")
# Run the comparison only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()