from __future__ import annotations import argparse import json import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) import joblib import matplotlib.pyplot as plt import numpy as np import torch from src.common.config import Config from src.common.experiment_paths import ( config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag, ) from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features from src.data.param_features import param_feature_transform_from_meta, transform_param_features from src.data.params import Params, Schedule from src.data.runner_client import CppRunner, read_result_bin from src.data.schedule_features import build_schedule_model_vector from src.evaluation.autofit_objective import dual_log_objective from src.models.forward_surrogate import ForwardSurrogate DEFAULT_SINGLE_CASE = { "config": "configs/data_gen_family_random.yaml", "tag": "family_random_50k", "no_schedule": False, "well_index": 0, "params": { "k": 0.025, "skin": 0.0, "wellboreC": 0.01, "phi": 0.0245, "h": 9.144, "Cf": 0.0004315, }, } def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Compare solver output and surrogate prediction on a single parameter set" ) parser.add_argument("--config", type=str, default=DEFAULT_SINGLE_CASE["config"], help="Config yaml path") parser.add_argument( "--stage", choices=["fixed_case", "case_neighborhood", "family_random", "family_random_v2_q"], default=None, help="Optional stage to infer config path", ) parser.add_argument("--processed", type=str, default=None, help="Processed dataset path") parser.add_argument("--model", type=str, default=None, help="Surrogate checkpoint path") parser.add_argument("--output-dir", type=str, default=None, help="Output directory") parser.add_argument( "--tag", type=str, default=DEFAULT_SINGLE_CASE["tag"], help="Experiment tag for processed/model lookup", ) parser.add_argument( "--no-schedule", action="store_true", help="When --model is omitted, infer the no-schedule checkpoint path", ) parser.add_argument("--k", type=float, default=DEFAULT_SINGLE_CASE["params"]["k"]) parser.add_argument("--skin", type=float, default=DEFAULT_SINGLE_CASE["params"]["skin"]) parser.add_argument("--wellboreC", type=float, default=DEFAULT_SINGLE_CASE["params"]["wellboreC"]) parser.add_argument("--phi", type=float, default=DEFAULT_SINGLE_CASE["params"]["phi"]) parser.add_argument("--h", type=float, default=DEFAULT_SINGLE_CASE["params"]["h"]) parser.add_argument("--Cf", type=float, default=DEFAULT_SINGLE_CASE["params"]["Cf"]) parser.add_argument( "--well-index", type=int, default=DEFAULT_SINGLE_CASE["well_index"], help="Well index for solver output", ) parser.add_argument( "--timeQ", type=str, default=None, help="Override schedule durations, e.g. '1000,1000,1000' or '0,1000,1000,1000'.", ) parser.add_argument( "--q", type=str, default=None, help="Override schedule rates, e.g. '200,100,0' or '0,200,100,0'.", ) parser.add_argument( "--section-index", type=int, default=None, help="Override tested section index. Defaults to the last section of the overridden schedule.", ) return parser.parse_args() def calc_metrics( y_true: np.ndarray, y_pred: np.ndarray, eps_range: float = 1e-3, eps_var: float = 1e-6, ) -> dict: err = y_pred - y_true mse = np.mean(err**2) rmse = float(np.sqrt(mse)) mae = float(np.mean(np.abs(err))) bias = float(np.mean(err)) value_range = float(np.max(y_true) - np.min(y_true)) ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2)) ss_res = float(np.sum(err**2)) valid_nrmse = value_range > eps_range valid_r2 = ss_tot > eps_var nrmse = float(rmse / value_range) if valid_nrmse else np.nan r2 = float(1.0 - ss_res / ss_tot) if valid_r2 else np.nan return { "rmse": rmse, "mae": mae, "bias": bias, "abs_bias": float(abs(bias)), "nrmse": nrmse, "r2": r2, } def infer_curve_layout(meta: dict, curve_dim: int) -> dict: curve_layout = meta.get("curve_layout") if curve_layout is not None: return curve_layout n_time_points = curve_dim // 3 return { "n_time_points": int(n_time_points), "parts": [ {"name": "log_pressure", "start": 0, "end": n_time_points}, {"name": "log_derivative", "start": n_time_points, "end": 2 * n_time_points}, {"name": "slope", "start": 2 * n_time_points, "end": 3 * n_time_points}, ], } def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]: parts: dict[str, np.ndarray] = {} for part in layout["parts"]: start = int(part["start"]) end = int(part["end"]) parts[str(part["name"])] = curve[start:end] return parts def resolve_cf(args: argparse.Namespace, cfg: Config) -> float: if args.Cf is not None: return float(args.Cf) fixed_cfg = cfg.raw.get("params", {}).get("fixed_params", {}).get("Cf", {}) or {} if bool(fixed_cfg.get("enabled", False)): return float(fixed_cfg["value"]) raise ValueError("Cf 未在命令行提供,且配置文件里也没有启用固定 Cf") def parse_float_list(raw: str, name: str) -> list[float]: values: list[float] = [] for token in raw.replace(";", ",").replace(" ", ",").split(","): token = token.strip() if not token: continue values.append(float(token)) if not values: raise ValueError(f"{name} 不能为空") return values def resolve_case_schedule(cfg: Config, args: argparse.Namespace | None = None) -> Schedule: if args is not None and (args.timeQ is not None or args.q is not None or args.section_index is not None): if args.timeQ is None or args.q is None: raise ValueError("覆盖流量制度时必须同时提供 --timeQ 和 --q") timeQ = parse_float_list(args.timeQ, "--timeQ") q = parse_float_list(args.q, "--q") if len(timeQ) != len(q): raise ValueError(f"--timeQ 和 --q 长度必须一致,当前分别为 {len(timeQ)} 和 {len(q)}") # Allow the reporting/table convention with an initial "0 0" row. if len(timeQ) >= 2 and timeQ[0] <= 0.0 and q[0] == 0.0: timeQ = timeQ[1:] q = q[1:] section_index = int(args.section_index) if args.section_index is not None else len(timeQ) schedule = Schedule(sectionIndex=section_index, timeQ=timeQ, q=q) if not schedule.validate(): raise ValueError( "命令行覆盖的流量制度非法。注意 compare 脚本里的 timeQ 表示每段持续时间," "如果有初始行请写成 0,0;有效段必须 timeQ>0 且 q>=0。" ) return schedule schedule_cfg = cfg.raw["schedule"]["case_schedule"] timeQ = list(map(float, schedule_cfg["timeQ"])) q = list(map(float, schedule_cfg["q"])) policy = cfg.raw["schedule"].get("section_policy", {}) or {} mode = str(policy.get("mode", "fixed_last")).lower() n_sections = len(timeQ) if mode == "fixed_last": section_index = n_sections elif mode == "fixed_value": section_index = int(np.clip(int(policy.get("fixed_value", n_sections)), 1, n_sections)) else: section_index = int(np.clip(int(schedule_cfg.get("default_section_index", n_sections)), 1, n_sections)) schedule = Schedule(sectionIndex=section_index, timeQ=timeQ, q=q) if not schedule.validate(): raise ValueError("配置文件中的 case_schedule 非法,无法用于单样本对比") return schedule def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]: tag = normalize_tag(args.tag) config_path = args.config if config_path is None: config_path = str(config_for_stage(args.stage) or Path("configs/data_gen_family_random.yaml")) processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) model_path = ( Path(args.model) if args.model is not None else model_checkpoint_for_tag(tag, use_schedule=not args.no_schedule) ) if args.output_dir is not None: output_dir = Path(args.output_dir) else: suffix = "" if not args.no_schedule else "_no_schedule" output_dir = Path("results") / (f"single_compare_{tag}{suffix}" if tag else f"single_compare{suffix}") return Config(config_path), processed_path, model_path, output_dir def build_params_from_args(args: argparse.Namespace, cfg: Config, schedule: Schedule) -> Params: return Params( k=float(args.k), skin=float(args.skin), wellboreC=float(args.wellboreC), phi=float(args.phi), h=float(args.h), Cf=resolve_cf(args, cfg), schedule=schedule, ) def run_solver_and_extract_curve(cfg: Config, params: Params, well_index: int) -> tuple[np.ndarray, dict]: runner = CppRunner(cfg=cfg) try: ok = runner.run_simulation(params, override_schedule=params.schedule, include_schedule=True) result = read_result_bin(runner.result_bin) if runner.result_bin.exists() else None if not ok and result is None: raise RuntimeError("求解器运行失败,未生成有效结果") if not ok and result is not None: print("Warning: runner_client 返回失败,但 result.bin 可读取,继续使用求解器结果。") if result is None: raise RuntimeError("无法读取 result.bin") if not result["loglog"]: raise RuntimeError("求解器未返回 loglog 数据") if well_index < 0 or well_index >= len(result["loglog"]): raise IndexError(f"well-index={well_index} 超出范围,可用井数={len(result['loglog'])}") loglog = result["loglog"][well_index] t = np.asarray(loglog["t"], dtype=np.float64) p = np.asarray(loglog["p"], dtype=np.float64) d = np.asarray(loglog["deriv"], dtype=np.float64) cleaned = clean_curve_for_dataset(cfg, t, p, d) if cleaned is None: raise RuntimeError("求解器返回曲线在清洗后无效") t_clean, p_clean, d_clean = cleaned valid, reason = is_valid_curve(cfg, t_clean, p_clean, d_clean) if not valid: raise RuntimeError(f"求解器曲线未通过有效性检查: {reason}") curve_feat = resample_curve_to_features(cfg, t_clean, p_clean, d_clean) raw = { "t": t_clean.tolist(), "p": p_clean.tolist(), "d": d_clean.tolist(), "n_steps": int(result["nSteps"]), "n_wells": int(result["nWells"]), } return curve_feat, raw finally: runner.close() def build_schedule_vector(cfg: Config, schedule: Schedule) -> np.ndarray: return build_schedule_model_vector(cfg, schedule) def load_model(checkpoint_path: Path) -> tuple[ForwardSurrogate, bool, torch.device]: checkpoint = torch.load(checkpoint_path, map_location="cpu") use_schedule = bool(checkpoint.get("use_schedule", True)) model = ForwardSurrogate( param_dim=int(checkpoint["param_dim"]), schedule_dim=int(checkpoint["schedule_dim"]), curve_dim=int(checkpoint["curve_dim"]), hidden_dim=int(checkpoint["hidden_dim"]), dropout=float(checkpoint["dropout"]), use_schedule=use_schedule, ) model.load_state_dict(checkpoint["model_state_dict"]) model.eval() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) return model, use_schedule, device def predict_surrogate_curve( processed: dict, model: ForwardSurrogate, device: torch.device, use_schedule: bool, params: Params, schedule: Schedule, cfg: Config, ) -> np.ndarray: scaler_params = processed["scaler_params"] scaler_schedule = processed["scaler_schedule"] scaler_curve = processed["scaler_curve"] param_transform = param_feature_transform_from_meta(processed.get("meta", {})) params_vec = np.asarray( [params.k, params.skin, params.wellboreC, params.phi, params.h, params.Cf], dtype=np.float32, ).reshape(1, -1) schedule_vec = build_schedule_vector(cfg, schedule).reshape(1, -1) params_x = scaler_params.transform(transform_param_features(params_vec, param_transform)).astype(np.float32) schedule_x = scaler_schedule.transform(schedule_vec).astype(np.float32) with torch.no_grad(): params_t = torch.tensor(params_x, dtype=torch.float32, device=device) if use_schedule: schedule_t = torch.tensor(schedule_x, dtype=torch.float32, device=device) pred_scaled = model(params_t, schedule_t).cpu().numpy() else: pred_scaled = model(params_t, None).cpu().numpy() return scaler_curve.inverse_transform(pred_scaled)[0].astype(np.float32) def plot_comparison( curve_true: np.ndarray, curve_pred: np.ndarray, curve_layout: dict, output_path: Path, params: Params, schedule: Schedule, model_path: Path, use_schedule: bool, ) -> dict: true_parts = split_curve_by_layout(curve_true, curve_layout) pred_parts = split_curve_by_layout(curve_pred, curve_layout) overall = calc_metrics(curve_true, curve_pred) autofit = dual_log_objective(curve_true, curve_pred, curve_layout) part_names = ["log_pressure", "log_derivative", "slope"] title_map = { "log_pressure": "Log Pressure", "log_derivative": "Log |Derivative|", "slope": "Slope of Log Pressure vs Log Time", } fig, axes = plt.subplots(3, 2, figsize=(14, 12)) fig.suptitle( "Solver vs Surrogate\n" f"k={params.k:.6g}, skin={params.skin:.6g}, wellboreC={params.wellboreC:.6g}, " f"phi={params.phi:.6g}, h={params.h:.6g}, Cf={params.Cf:.6g}\n" f"sectionIndex={schedule.sectionIndex}, timeQ={schedule.timeQ}, q={schedule.q}\n" f"use_schedule={use_schedule}, model={model_path.name}, " f"Overall RMSE={overall['rmse']:.4f}, MAE={overall['mae']:.4f}, " f"AutoFitObj={autofit['dual_log_objective']:.4f}, " f"R2={'nan' if np.isnan(overall['r2']) else f'{overall['r2']:.4f}'}" ) summary = {"overall": overall, "autofit": autofit, "parts": {}} for row, name in enumerate(part_names): y_true = true_parts[name] y_pred = pred_parts[name] err = y_pred - y_true x = np.arange(len(y_true)) metrics = calc_metrics(y_true, y_pred) summary["parts"][name] = metrics nrmse_text = "nan" if np.isnan(metrics["nrmse"]) else f"{metrics['nrmse']:.4f}" r2_text = "nan" if np.isnan(metrics["r2"]) else f"{metrics['r2']:.4f}" ax_l = axes[row, 0] ax_l.plot(x, y_true, label="Solver", linewidth=2, alpha=0.85) ax_l.plot(x, y_pred, label="Surrogate", linewidth=2, alpha=0.85) ax_l.set_title( f"{title_map[name]} | RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, " f"Bias={metrics['bias']:.4f}, NRMSE={nrmse_text}, R2={r2_text}" ) ax_l.set_xlabel("Resampled Time Index") ax_l.set_ylabel("Value") ax_l.grid(True, alpha=0.3) ax_l.legend() ax_r = axes[row, 1] ax_r.plot(x, err, linewidth=1.5, alpha=0.85) ax_r.axhline(0.0, linestyle="--", linewidth=1) ax_r.set_title(f"{title_map[name]} Error") ax_r.set_xlabel("Resampled Time Index") ax_r.set_ylabel("Surrogate - Solver") ax_r.grid(True, alpha=0.3) plt.tight_layout(rect=[0, 0, 1, 0.94]) plt.savefig(output_path, dpi=150, bbox_inches="tight") plt.close() return summary def main() -> None: args = parse_args() cfg, processed_path, model_path, output_dir = resolve_paths(args) output_dir.mkdir(parents=True, exist_ok=True) if not processed_path.exists(): raise FileNotFoundError(f"Processed 数据不存在: {processed_path}") if not model_path.exists(): raise FileNotFoundError(f"模型 checkpoint 不存在: {model_path}") processed = joblib.load(processed_path) curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"])) schedule = resolve_case_schedule(cfg, args) params = build_params_from_args(args, cfg, schedule) print("Running solver...") curve_solver, raw_solver = run_solver_and_extract_curve(cfg, params, args.well_index) print("Running surrogate...") model, use_schedule, device = load_model(model_path) curve_pred = predict_surrogate_curve( processed=processed, model=model, device=device, use_schedule=use_schedule, params=params, schedule=schedule, cfg=cfg, ) plot_path = output_dir / "single_case_comparison.png" summary = plot_comparison( curve_true=curve_solver, curve_pred=curve_pred, curve_layout=curve_layout, output_path=plot_path, params=params, schedule=schedule, model_path=model_path, use_schedule=use_schedule, ) summary_payload = { "config_path": str(cfg.path), "processed_path": str(processed_path), "model_path": str(model_path), "use_schedule": use_schedule, "device": str(device), "params": { "k": params.k, "skin": params.skin, "wellboreC": params.wellboreC, "phi": params.phi, "h": params.h, "Cf": params.Cf, }, "schedule": { "sectionIndex": schedule.sectionIndex, "timeQ": schedule.timeQ, "q": schedule.q, }, "metrics": summary, "solver_raw_loglog": raw_solver, } with open(output_dir / "single_case_summary.json", "w", encoding="utf-8") as f: json.dump(summary_payload, f, ensure_ascii=False, indent=2) overall = summary["overall"] autofit = summary["autofit"] print("\nSingle-case comparison complete.") print(f"Output dir: {output_dir}") print( f"Overall RMSE={overall['rmse']:.6f}, MAE={overall['mae']:.6f}, " f"Bias={overall['bias']:.6f}, " f"AutoFitObj={autofit['dual_log_objective']:.6f}, " f"R2={'nan' if np.isnan(overall['r2']) else f'{overall['r2']:.6f}'}" ) print(f"Plot saved: {plot_path}") print(f"Summary saved: {output_dir / 'single_case_summary.json'}") if __name__ == "__main__": main()