You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/compare_single_case.py

528 lines
19 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
from src.common.config import Config
from src.common.experiment_paths import (
config_for_stage,
model_checkpoint_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features
from src.data.param_features import param_feature_transform_from_meta, transform_param_features
from src.data.params import Params, Schedule
from src.data.runner_client import CppRunner, read_result_bin
from src.data.schedule_features import build_schedule_model_vector
from src.evaluation.autofit_objective import dual_log_objective
from src.models.forward_surrogate import ForwardSurrogate
# Baseline scenario for a single-case comparison.  The "config", "tag",
# "well_index" and "params" entries are used as argparse defaults in
# parse_args().
DEFAULT_SINGLE_CASE = {
    "config": "configs/data_gen_family_random.yaml",
    "tag": "family_random_50k",
    "no_schedule": False,
    "well_index": 0,
    # Default values for the six scalar parameters exposed on the CLI.
    "params": {
        "k": 0.025,
        "skin": 0.0,
        "wellboreC": 0.01,
        "phi": 0.0245,
        "h": 9.144,
        "Cf": 0.0004315,
    },
}
def parse_args() -> argparse.Namespace:
    """Define the command-line options of this script and parse them.

    Defaults for ``--config``, ``--tag``, ``--well-index`` and the six scalar
    parameters come from ``DEFAULT_SINGLE_CASE``.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    defaults = DEFAULT_SINGLE_CASE
    param_defaults = defaults["params"]

    parser = argparse.ArgumentParser(
        description="Compare solver output and surrogate prediction on a single parameter set"
    )
    add = parser.add_argument

    # --- experiment / artefact location -----------------------------------
    add("--config", type=str, default=defaults["config"], help="Config yaml path")
    add(
        "--stage",
        choices=["fixed_case", "case_neighborhood", "family_random", "family_random_v2_q"],
        default=None,
        help="Optional stage to infer config path",
    )
    add("--processed", type=str, default=None, help="Processed dataset path")
    add("--model", type=str, default=None, help="Surrogate checkpoint path")
    add("--output-dir", type=str, default=None, help="Output directory")
    add(
        "--tag",
        type=str,
        default=defaults["tag"],
        help="Experiment tag for processed/model lookup",
    )
    add(
        "--no-schedule",
        action="store_true",
        help="When --model is omitted, infer the no-schedule checkpoint path",
    )

    # --- scalar parameters, registered in the original option order --------
    for param_name in ("k", "skin", "wellboreC", "phi", "h", "Cf"):
        add(f"--{param_name}", type=float, default=param_defaults[param_name])

    # --- solver output selection and schedule overrides ---------------------
    add(
        "--well-index",
        type=int,
        default=defaults["well_index"],
        help="Well index for solver output",
    )
    add(
        "--timeQ",
        type=str,
        default=None,
        help="Override schedule durations, e.g. '1000,1000,1000' or '0,1000,1000,1000'.",
    )
    add(
        "--q",
        type=str,
        default=None,
        help="Override schedule rates, e.g. '200,100,0' or '0,200,100,0'.",
    )
    add(
        "--section-index",
        type=int,
        default=None,
        help="Override tested section index. Defaults to the last section of the overridden schedule.",
    )
    return parser.parse_args()
def calc_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    eps_range: float = 1e-3,
    eps_var: float = 1e-6,
) -> dict:
    """Compute error metrics between a reference and a predicted array.

    Both inputs are converted to ``float64`` before any arithmetic so that
    float32 inputs do not accumulate their sums in single precision.

    Args:
        y_true: Reference values.
        y_pred: Predicted values; must have exactly the same shape as
            ``y_true`` (no broadcasting is allowed).
        eps_range: ``nrmse`` is reported as NaN unless
            ``max(y_true) - min(y_true)`` exceeds this threshold.
        eps_var: ``r2`` is reported as NaN unless the total sum of squares of
            ``y_true`` exceeds this threshold.

    Returns:
        A dict of Python floats with keys ``rmse``, ``mae``, ``bias``
        (mean of ``y_pred - y_true``), ``abs_bias``, ``nrmse`` (RMSE divided
        by the range of ``y_true``) and ``r2``.

    Raises:
        ValueError: If the shapes differ or the inputs are empty.
    """
    truth = np.asarray(y_true, dtype=np.float64)
    pred = np.asarray(y_pred, dtype=np.float64)

    # Refuse silently-broadcastable inputs such as (N,) vs (N, 1): the
    # subtraction below would otherwise yield an (N, N) error matrix and
    # meaningless metrics.
    if truth.shape != pred.shape:
        raise ValueError(
            f"calc_metrics: y_true 与 y_pred 形状不一致: {truth.shape} vs {pred.shape}"
        )
    if truth.size == 0:
        raise ValueError("calc_metrics: 输入数组不能为空")

    err = pred - truth
    rmse = float(np.sqrt(np.mean(err**2)))
    mae = float(np.mean(np.abs(err)))
    bias = float(np.mean(err))

    value_range = float(np.max(truth) - np.min(truth))
    ss_tot = float(np.sum((truth - np.mean(truth)) ** 2))
    ss_res = float(np.sum(err**2))

    # Normalised metrics are undefined for (nearly) constant references.
    nrmse = float(rmse / value_range) if value_range > eps_range else float("nan")
    r2 = float(1.0 - ss_res / ss_tot) if ss_tot > eps_var else float("nan")

    return {
        "rmse": rmse,
        "mae": mae,
        "bias": bias,
        "abs_bias": float(abs(bias)),
        "nrmse": nrmse,
        "r2": r2,
    }
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the curve layout stored in ``meta`` or derive a default one.

    When ``meta`` has a non-None ``"curve_layout"`` entry it is returned
    unchanged.  Otherwise the curve vector is assumed to consist of three
    consecutive equal-length parts (``log_pressure``, ``log_derivative``,
    ``slope``), each ``curve_dim // 3`` entries long.

    Args:
        meta: Metadata dict, possibly containing ``"curve_layout"``.
        curve_dim: Total length of the curve feature vector.

    Returns:
        A dict with ``"n_time_points"`` and a ``"parts"`` list of
        ``{"name", "start", "end"}`` entries.
    """
    stored_layout = meta.get("curve_layout")
    if stored_layout is not None:
        return stored_layout

    n_time_points = curve_dim // 3
    part_names = ("log_pressure", "log_derivative", "slope")
    # Consecutive boundaries of the three parts: [0, n), [n, 2n), [2n, 3n).
    edges = (0, n_time_points, 2 * n_time_points, 3 * n_time_points)
    return {
        "n_time_points": int(n_time_points),
        "parts": [
            {"name": part_name, "start": lower, "end": upper}
            for part_name, lower, upper in zip(part_names, edges, edges[1:])
        ],
    }
def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]:
    """Slice ``curve`` into named segments as described by ``layout``.

    Args:
        curve: Sliceable sequence holding the concatenated curve parts.
        layout: Dict whose ``"parts"`` list has ``name``/``start``/``end``
            entries.

    Returns:
        Mapping from part name (as ``str``) to ``curve[start:end]``.  If a
        name occurs more than once, the last occurrence wins.
    """
    return {
        str(spec["name"]): curve[int(spec["start"]):int(spec["end"])]
        for spec in layout["parts"]
    }
def resolve_cf(args: argparse.Namespace, cfg: Config) -> float:
    """Pick the ``Cf`` value from the CLI, falling back to the config file.

    Args:
        args: Parsed arguments; ``args.Cf`` takes precedence when not None.
        cfg: Config whose ``raw["params"]["fixed_params"]["Cf"]`` entry may
            carry ``enabled`` and ``value`` keys.

    Returns:
        The resolved ``Cf`` as a float.

    Raises:
        ValueError: If ``args.Cf`` is None and the config has no enabled
            fixed ``Cf``.
    """
    cli_value = args.Cf
    if cli_value is not None:
        return float(cli_value)

    params_section = cfg.raw.get("params", {})
    fixed_section = params_section.get("fixed_params", {})
    # A null "Cf" entry in the config is treated like a missing one.
    cf_entry = fixed_section.get("Cf", {}) or {}
    if not bool(cf_entry.get("enabled", False)):
        raise ValueError("Cf 未在命令行提供,且配置文件里也没有启用固定 Cf")
    return float(cf_entry["value"])
def parse_float_list(raw: str, name: str) -> list[float]:
    """Parse a delimited string of numbers into a list of floats.

    Commas, semicolons and spaces all act as separators; empty tokens are
    skipped.

    Args:
        raw: Text such as ``"1000,1000;1000"`` or ``"0 200 100"``.
        name: Option name used in the error message.

    Returns:
        The parsed values, in input order.

    Raises:
        ValueError: If a token is not a valid float, or no values remain.
    """
    # Normalise every supported separator to a comma before splitting.
    normalized = raw.replace(";", ",").replace(" ", ",")
    tokens = (piece.strip() for piece in normalized.split(","))
    values = [float(token) for token in tokens if token]
    if not values:
        raise ValueError(f"{name} 不能为空")
    return values
def resolve_case_schedule(cfg: Config, args: argparse.Namespace | None = None) -> Schedule:
    """Build the schedule for the single case from the CLI or the config.

    If ``args`` is given and any of ``timeQ``, ``q`` or ``section_index`` is
    set, the schedule comes from the command line: both ``--timeQ`` and
    ``--q`` are then required and must have equal length, and a leading row
    with ``timeQ <= 0`` and ``q == 0`` is dropped.  Otherwise the schedule is
    read from ``cfg.raw["schedule"]["case_schedule"]`` and the section index
    follows ``cfg.raw["schedule"]["section_policy"]``.

    Args:
        cfg: Loaded configuration.
        args: Parsed CLI arguments, or None to use the config only.

    Returns:
        A ``Schedule`` for which ``validate()`` returned a truthy value.

    Raises:
        ValueError: On incomplete, mismatched or invalid schedule input.
    """
    cli_override = args is not None and (
        args.timeQ is not None or args.q is not None or args.section_index is not None
    )

    if cli_override:
        if args.timeQ is None or args.q is None:
            raise ValueError("覆盖流量制度时必须同时提供 --timeQ 和 --q")
        timeQ = parse_float_list(args.timeQ, "--timeQ")
        q = parse_float_list(args.q, "--q")
        if len(timeQ) != len(q):
            # The two lengths are separated explicitly; printing them back to
            # back made e.g. 3 and 4 read as "34".
            raise ValueError(
                f"--timeQ 和 --q 长度必须一致,当前分别为 {len(timeQ)} 和 {len(q)}"
            )
        # Allow the reporting/table convention with an initial "0 0" row.
        if len(timeQ) >= 2 and timeQ[0] <= 0.0 and q[0] == 0.0:
            timeQ = timeQ[1:]
            q = q[1:]
        # Without an explicit --section-index, test the last section.
        section_index = int(args.section_index) if args.section_index is not None else len(timeQ)
        schedule = Schedule(sectionIndex=section_index, timeQ=timeQ, q=q)
        if not schedule.validate():
            raise ValueError(
                "命令行覆盖的流量制度非法。注意 compare 脚本里的 timeQ 表示每段持续时间,"
                "如果有初始行请写成 0,0;有效段必须 timeQ>0 且 q>=0。"
            )
        return schedule

    schedule_cfg = cfg.raw["schedule"]["case_schedule"]
    timeQ = list(map(float, schedule_cfg["timeQ"]))
    q = list(map(float, schedule_cfg["q"]))
    policy = cfg.raw["schedule"].get("section_policy", {}) or {}
    mode = str(policy.get("mode", "fixed_last")).lower()
    n_sections = len(timeQ)

    if mode == "fixed_last":
        section_index = n_sections
    elif mode == "fixed_value":
        # Clamp the configured index into [1, n_sections].
        section_index = int(np.clip(int(policy.get("fixed_value", n_sections)), 1, n_sections))
    else:
        # Any other mode falls back to the schedule's own default index.
        section_index = int(
            np.clip(int(schedule_cfg.get("default_section_index", n_sections)), 1, n_sections)
        )

    schedule = Schedule(sectionIndex=section_index, timeQ=timeQ, q=q)
    if not schedule.validate():
        raise ValueError("配置文件中的 case_schedule 非法,无法用于单样本对比")
    return schedule
def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]:
    """Resolve the config, processed-dataset, checkpoint and output paths.

    Config selection:
        * an explicit ``--config`` that differs from the built-in default is
          used as given;
        * otherwise, when ``--stage`` is set, the config is inferred with
          ``config_for_stage``;
        * otherwise the built-in default config path is used.

    Previously ``--stage`` could never take effect, because ``--config``
    always carries a non-None default.  Note that an explicitly passed
    ``--config`` equal to the default path cannot be told apart from the
    default itself, so ``--stage`` wins in that case too.

    Args:
        args: Parsed CLI arguments (``config``, ``stage``, ``tag``,
            ``processed``, ``model``, ``output_dir``, ``no_schedule``).

    Returns:
        ``(Config, processed_path, model_path, output_dir)``.
    """
    tag = normalize_tag(args.tag)

    default_config = DEFAULT_SINGLE_CASE["config"]
    config_path = args.config
    stage_overrides_default = args.stage is not None and config_path == default_config
    if config_path is None or stage_overrides_default:
        config_path = str(config_for_stage(args.stage) or Path(default_config))

    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    model_path = (
        Path(args.model)
        if args.model is not None
        else model_checkpoint_for_tag(tag, use_schedule=not args.no_schedule)
    )

    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
    else:
        suffix = "" if not args.no_schedule else "_no_schedule"
        output_dir = Path("results") / (f"single_compare_{tag}{suffix}" if tag else f"single_compare{suffix}")

    return Config(config_path), processed_path, model_path, output_dir
def build_params_from_args(args: argparse.Namespace, cfg: Config, schedule: Schedule) -> Params:
    """Assemble a ``Params`` object from the parsed CLI values.

    Args:
        args: Parsed arguments providing ``k``, ``skin``, ``wellboreC``,
            ``phi`` and ``h`` (and ``Cf``, via ``resolve_cf``).
        cfg: Config passed on to ``resolve_cf`` for the ``Cf`` fallback.
        schedule: Schedule attached to the resulting parameters.

    Returns:
        The populated ``Params`` instance.
    """
    scalar_names = ("k", "skin", "wellboreC", "phi", "h")
    scalars = {field: float(getattr(args, field)) for field in scalar_names}
    return Params(**scalars, Cf=resolve_cf(args, cfg), schedule=schedule)
def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) -> tuple[np.ndarray, dict]:
    """Run the solver once and turn one well's log-log curve into features.

    The runner is always closed, whether or not an error occurs.

    Args:
        cfg: Configuration handed to the runner and the curve helpers.
        params: Parameters to simulate; ``params.schedule`` is passed as the
            override schedule.
        well_index: Index into the solver's ``loglog`` list.

    Returns:
        ``(curve_feat, raw)``, where ``curve_feat`` is the output of
        ``resample_curve_to_features`` and ``raw`` holds the cleaned
        ``t``/``p``/``d`` lists plus ``n_steps`` and ``n_wells``.

    Raises:
        RuntimeError: If no usable result or curve is available.
        IndexError: If ``well_index`` is out of range.
    """
    runner = CppRunner(cfg=cfg)
    try:
        succeeded = runner.run_simulation(params, override_schedule=params.schedule, include_schedule=True)

        # NOTE(review): if result_bin can survive from an earlier run, a
        # failed run would reuse it here - confirm the runner clears it.
        result = None
        if runner.result_bin.exists():
            result = read_result_bin(runner.result_bin)

        if result is None:
            if not succeeded:
                raise RuntimeError("求解器运行失败,未生成有效结果")
            raise RuntimeError("无法读取 result.bin")
        if not succeeded:
            # A readable result is still used when the runner reports failure.
            print("Warning: runner_client 返回失败,但 result.bin 可读取,继续使用求解器结果。")

        wells = result["loglog"]
        if not wells:
            raise RuntimeError("求解器未返回 loglog 数据")
        if not 0 <= well_index < len(wells):
            raise IndexError(f"well-index={well_index} 超出范围,可用井数={len(wells)}")

        well_curve = wells[well_index]
        t, p, d = (
            np.asarray(well_curve[key], dtype=np.float64) for key in ("t", "p", "deriv")
        )

        cleaned = clean_curve_for_dataset(cfg, t, p, d)
        if cleaned is None:
            raise RuntimeError("求解器返回曲线在清洗后无效")
        t_clean, p_clean, d_clean = cleaned

        valid, reason = is_valid_curve(cfg, t_clean, p_clean, d_clean)
        if not valid:
            raise RuntimeError(f"求解器曲线未通过有效性检查: {reason}")

        curve_feat = resample_curve_to_features(cfg, t_clean, p_clean, d_clean)
        raw = {
            "t": t_clean.tolist(),
            "p": p_clean.tolist(),
            "d": d_clean.tolist(),
            "n_steps": int(result["nSteps"]),
            "n_wells": int(result["nWells"]),
        }
        return curve_feat, raw
    finally:
        runner.close()
def build_schedule_vector(cfg: Config, schedule: Schedule) -> np.ndarray:
    """Return the model input vector for ``schedule``.

    Delegates directly to ``build_schedule_model_vector``.
    """
    schedule_vector = build_schedule_model_vector(cfg, schedule)
    return schedule_vector
def load_model(checkpoint_path: Path) -> tuple[ForwardSurrogate, bool, torch.device]:
    """Load a ``ForwardSurrogate`` checkpoint and prepare it for inference.

    The checkpoint is read onto the CPU, the model is rebuilt from the stored
    hyper-parameters, switched to eval mode and moved to CUDA when available.

    Args:
        checkpoint_path: Path of the checkpoint file.

    Returns:
        ``(model, use_schedule, device)``; ``use_schedule`` defaults to True
        when the checkpoint does not record it.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    schedule_enabled = bool(ckpt.get("use_schedule", True))

    # Integer hyper-parameters stored alongside the weights.
    dims = {
        field: int(ckpt[field])
        for field in ("param_dim", "schedule_dim", "curve_dim", "hidden_dim")
    }
    model = ForwardSurrogate(
        **dims,
        dropout=float(ckpt["dropout"]),
        use_schedule=schedule_enabled,
    )
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    model = model.to(device)
    return model, schedule_enabled, device
def predict_surrogate_curve(
    processed: dict,
    model: ForwardSurrogate,
    device: torch.device,
    use_schedule: bool,
    params: Params,
    schedule: Schedule,
    cfg: Config,
) -> np.ndarray:
    """Predict the curve feature vector for one parameter set.

    Args:
        processed: Processed-dataset dict providing ``scaler_params``,
            ``scaler_schedule``, ``scaler_curve`` and optionally ``meta``.
        model: Surrogate model, called as ``model(params, schedule_or_None)``.
        device: Device on which the input tensors are created.
        use_schedule: Whether the schedule tensor is passed to the model.
        params: Parameter set to evaluate.
        schedule: Schedule turned into a feature vector.
        cfg: Configuration used to build the schedule vector.

    Returns:
        The inverse-scaled prediction for the single sample, as float32.
    """
    scaler_params = processed["scaler_params"]
    scaler_schedule = processed["scaler_schedule"]
    scaler_curve = processed["scaler_curve"]
    transform_spec = param_feature_transform_from_meta(processed.get("meta", {}))

    # One-row feature matrices for the single sample.
    param_fields = ("k", "skin", "wellboreC", "phi", "h", "Cf")
    raw_params = np.asarray(
        [getattr(params, field) for field in param_fields],
        dtype=np.float32,
    ).reshape(1, -1)
    raw_schedule = build_schedule_vector(cfg, schedule).reshape(1, -1)

    transformed_params = transform_param_features(raw_params, transform_spec)
    scaled_params = scaler_params.transform(transformed_params).astype(np.float32)
    scaled_schedule = scaler_schedule.transform(raw_schedule).astype(np.float32)

    with torch.no_grad():
        params_t = torch.tensor(scaled_params, dtype=torch.float32, device=device)
        schedule_t = None
        if use_schedule:
            schedule_t = torch.tensor(scaled_schedule, dtype=torch.float32, device=device)
        pred_scaled = model(params_t, schedule_t).cpu().numpy()

    restored = scaler_curve.inverse_transform(pred_scaled)
    return restored[0].astype(np.float32)
def plot_comparison(
    curve_true: np.ndarray,
    curve_pred: np.ndarray,
    curve_layout: dict,
    output_path: Path,
    params: Params,
    schedule: Schedule,
    model_path: Path,
    use_schedule: bool,
) -> dict:
    """Plot solver vs surrogate curves part by part and collect metrics.

    A 3x2 figure is produced: the left column overlays the solver and
    surrogate values of ``log_pressure``, ``log_derivative`` and ``slope``,
    and the right column shows the error (surrogate minus solver).  The
    figure is saved to ``output_path`` and then closed.

    Args:
        curve_true: Solver curve feature vector.
        curve_pred: Surrogate curve feature vector.
        curve_layout: Layout used to split both vectors into named parts; it
            must contain the three part names listed above.
        output_path: Destination of the saved figure.
        params: Parameters shown in the figure title.
        schedule: Schedule shown in the figure title.
        model_path: Checkpoint path; only its file name is shown.
        use_schedule: Flag shown in the figure title.

    Returns:
        ``{"overall": ..., "autofit": ..., "parts": {name: metrics}}``, where
        the metrics dicts come from ``calc_metrics`` and ``autofit`` is the
        result of ``dual_log_objective``.
    """

    def fmt_or_nan(value: float) -> str:
        # Kept outside the f-strings below: a nested f-string that reuses the
        # enclosing quote character is a SyntaxError before Python 3.12.
        return "nan" if np.isnan(value) else f"{value:.4f}"

    true_parts = split_curve_by_layout(curve_true, curve_layout)
    pred_parts = split_curve_by_layout(curve_pred, curve_layout)
    overall = calc_metrics(curve_true, curve_pred)
    autofit = dual_log_objective(curve_true, curve_pred, curve_layout)
    part_names = ["log_pressure", "log_derivative", "slope"]
    title_map = {
        "log_pressure": "Log Pressure",
        "log_derivative": "Log |Derivative|",
        "slope": "Slope of Log Pressure vs Log Time",
    }

    overall_r2_text = fmt_or_nan(overall["r2"])
    fig, axes = plt.subplots(3, 2, figsize=(14, 12))
    fig.suptitle(
        "Solver vs Surrogate\n"
        f"k={params.k:.6g}, skin={params.skin:.6g}, wellboreC={params.wellboreC:.6g}, "
        f"phi={params.phi:.6g}, h={params.h:.6g}, Cf={params.Cf:.6g}\n"
        f"sectionIndex={schedule.sectionIndex}, timeQ={schedule.timeQ}, q={schedule.q}\n"
        f"use_schedule={use_schedule}, model={model_path.name}, "
        f"Overall RMSE={overall['rmse']:.4f}, MAE={overall['mae']:.4f}, "
        f"AutoFitObj={autofit['dual_log_objective']:.4f}, "
        f"R2={overall_r2_text}"
    )

    summary = {"overall": overall, "autofit": autofit, "parts": {}}
    for row, name in enumerate(part_names):
        y_true = true_parts[name]
        y_pred = pred_parts[name]
        err = y_pred - y_true
        x = np.arange(len(y_true))
        metrics = calc_metrics(y_true, y_pred)
        summary["parts"][name] = metrics
        nrmse_text = fmt_or_nan(metrics["nrmse"])
        r2_text = fmt_or_nan(metrics["r2"])

        # Left panel: solver and surrogate overlaid.
        ax_l = axes[row, 0]
        ax_l.plot(x, y_true, label="Solver", linewidth=2, alpha=0.85)
        ax_l.plot(x, y_pred, label="Surrogate", linewidth=2, alpha=0.85)
        ax_l.set_title(
            f"{title_map[name]} | RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, "
            f"Bias={metrics['bias']:.4f}, NRMSE={nrmse_text}, R2={r2_text}"
        )
        ax_l.set_xlabel("Resampled Time Index")
        ax_l.set_ylabel("Value")
        ax_l.grid(True, alpha=0.3)
        ax_l.legend()

        # Right panel: pointwise error with a zero reference line.
        ax_r = axes[row, 1]
        ax_r.plot(x, err, linewidth=1.5, alpha=0.85)
        ax_r.axhline(0.0, linestyle="--", linewidth=1)
        ax_r.set_title(f"{title_map[name]} Error")
        ax_r.set_xlabel("Resampled Time Index")
        ax_r.set_ylabel("Surrogate - Solver")
        ax_r.grid(True, alpha=0.3)

    # Operate on this figure explicitly instead of pyplot's "current figure",
    # so other open figures cannot be saved or closed by mistake.
    fig.tight_layout(rect=[0, 0, 1, 0.94])
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    return summary
def main() -> None:
    """Run the single-case solver vs surrogate comparison end to end.

    Steps: parse the CLI, resolve paths, load the processed dataset, run the
    solver, run the surrogate, save the comparison plot
    (``single_case_comparison.png``) and a JSON summary
    (``single_case_summary.json``) into the output directory, then print the
    headline metrics.

    Raises:
        FileNotFoundError: If the processed dataset or checkpoint is missing.
    """
    args = parse_args()
    cfg, processed_path, model_path, output_dir = resolve_paths(args)
    output_dir.mkdir(parents=True, exist_ok=True)

    if not processed_path.exists():
        raise FileNotFoundError(f"Processed 数据不存在: {processed_path}")
    if not model_path.exists():
        raise FileNotFoundError(f"模型 checkpoint 不存在: {model_path}")

    processed = joblib.load(processed_path)
    curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))
    schedule = resolve_case_schedule(cfg, args)
    params = build_params_from_args(args, cfg, schedule)

    print("Running solver...")
    curve_solver, raw_solver = run_solver_and_extract_curve(cfg, params, args.well_index)

    print("Running surrogate...")
    model, use_schedule, device = load_model(model_path)
    curve_pred = predict_surrogate_curve(
        processed=processed,
        model=model,
        device=device,
        use_schedule=use_schedule,
        params=params,
        schedule=schedule,
        cfg=cfg,
    )

    plot_path = output_dir / "single_case_comparison.png"
    summary_path = output_dir / "single_case_summary.json"
    summary = plot_comparison(
        curve_true=curve_solver,
        curve_pred=curve_pred,
        curve_layout=curve_layout,
        output_path=plot_path,
        params=params,
        schedule=schedule,
        model_path=model_path,
        use_schedule=use_schedule,
    )

    summary_payload = {
        "config_path": str(cfg.path),
        "processed_path": str(processed_path),
        "model_path": str(model_path),
        "use_schedule": use_schedule,
        "device": str(device),
        "params": {
            "k": params.k,
            "skin": params.skin,
            "wellboreC": params.wellboreC,
            "phi": params.phi,
            "h": params.h,
            "Cf": params.Cf,
        },
        "schedule": {
            "sectionIndex": schedule.sectionIndex,
            "timeQ": schedule.timeQ,
            "q": schedule.q,
        },
        "metrics": summary,
        "solver_raw_loglog": raw_solver,
    }
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary_payload, f, ensure_ascii=False, indent=2)

    overall = summary["overall"]
    autofit = summary["autofit"]
    # Formatted up front: nesting an f-string that reuses the enclosing quote
    # character is a SyntaxError before Python 3.12.
    r2_value = overall["r2"]
    r2_text = "nan" if np.isnan(r2_value) else f"{r2_value:.6f}"

    print("\nSingle-case comparison complete.")
    print(f"Output dir: {output_dir}")
    print(
        f"Overall RMSE={overall['rmse']:.6f}, MAE={overall['mae']:.6f}, "
        f"Bias={overall['bias']:.6f}, "
        f"AutoFitObj={autofit['dual_log_objective']:.6f}, "
        f"R2={r2_text}"
    )
    print(f"Plot saved: {plot_path}")
    print(f"Summary saved: {summary_path}")


if __name__ == "__main__":
    main()