From b7721f7cc1d6f59eb372ee9a78c7cceb9d2cb48a Mon Sep 17 00:00:00 2001 From: 1294271022 <1294271022@qq.com> Date: Tue, 26 May 2026 15:07:23 +0800 Subject: [PATCH] =?UTF-8?q?1=E3=80=81=E6=96=B0=E5=A2=9E=E6=97=B6=E9=97=B4?= =?UTF-8?q?=E8=8C=83=E5=9B=B4=E8=AF=84=E4=BC=B0=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../scripts/evaluate_time_conditioned.py | 643 ++++++++++++++++++ 1 file changed, 643 insertions(+) create mode 100644 ML/nmWTAI-ML/scripts/evaluate_time_conditioned.py diff --git a/ML/nmWTAI-ML/scripts/evaluate_time_conditioned.py b/ML/nmWTAI-ML/scripts/evaluate_time_conditioned.py new file mode 100644 index 0000000..021a407 --- /dev/null +++ b/ML/nmWTAI-ML/scripts/evaluate_time_conditioned.py @@ -0,0 +1,643 @@ +from __future__ import annotations + +import argparse +import csv +import json +import random +import sys +from pathlib import Path +from typing import Iterable + +import joblib +import matplotlib.pyplot as plt +import numpy as np +import torch + +ROOT = Path(__file__).resolve().parents[1] +sys.path.append(str(ROOT)) + +from src.common.experiment_paths import normalize_tag, processed_path_for_tag +from src.data.param_features import inverse_transform_param_features +from src.models.time_conditioned_surrogate import TimeConditionedSurrogate +from src.training.train_forward import get_part_slices, infer_curve_layout + + +DEFAULT_RANDOM_SEED = 42 +DEFAULT_PSO_DOMAIN = { + "k_min": 0.001, + "k_max": 10.0, + "skin_min": -10.0, + "skin_max": 10.0, + "wellboreC_min": 1.0e-4, + "wellboreC_max": 2.0, + "phi_min": 0.01, + "phi_max": 0.5, + "h_min": 2.0, + "h_max": 50.0, +} + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Evaluate a time-conditioned point-wise surrogate") + parser.add_argument("--processed", type=str, default=None, help="Processed dataset path") + parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming") + parser.add_argument("--model", type=str, default=None, help="Model checkpoint path") + parser.add_argument("--output-dir", type=str, default=None, help="Optional evaluation output directory") + parser.add_argument("--batch-size", type=int, default=65536, help="Point batch size for inference") + parser.add_argument("--device", type=str, default=None, help="Override device, e.g. cpu or cuda") + parser.add_argument("--seed", type=int, default=DEFAULT_RANDOM_SEED) + parser.add_argument("--n-random-plots", type=int, default=5) + parser.add_argument("--n-best-plots", type=int, default=5) + parser.add_argument("--n-worst-plots", type=int, default=10) + parser.add_argument("--top-k-analysis", type=int, default=300) + parser.add_argument("--pso-k-min", type=float, default=DEFAULT_PSO_DOMAIN["k_min"]) + parser.add_argument("--pso-k-max", type=float, default=DEFAULT_PSO_DOMAIN["k_max"]) + parser.add_argument("--pso-h-min", type=float, default=DEFAULT_PSO_DOMAIN["h_min"]) + parser.add_argument("--pso-h-max", type=float, default=DEFAULT_PSO_DOMAIN["h_max"]) + parser.add_argument("--pso-skin-min", type=float, default=DEFAULT_PSO_DOMAIN["skin_min"]) + parser.add_argument("--pso-skin-max", type=float, default=DEFAULT_PSO_DOMAIN["skin_max"]) + parser.add_argument("--pso-wellboreC-min", type=float, default=DEFAULT_PSO_DOMAIN["wellboreC_min"]) + parser.add_argument("--pso-wellboreC-max", type=float, default=DEFAULT_PSO_DOMAIN["wellboreC_max"]) + parser.add_argument("--pso-phi-min", type=float, default=DEFAULT_PSO_DOMAIN["phi_min"]) + parser.add_argument("--pso-phi-max", type=float, default=DEFAULT_PSO_DOMAIN["phi_max"]) + return parser.parse_args() + + +def default_model_path(tag: str | None) -> Path: + if tag: + return Path("models") / f"time_conditioned_surrogate_{tag}" / "time_conditioned_surrogate_best.pt" + return Path("models/time_conditioned_surrogate/time_conditioned_surrogate_best.pt") + + +def default_output_dir(tag: str | None) -> Path: + if tag: + return Path("results") / f"evaluation_time_conditioned_{tag}" + return Path("results/evaluation_time_conditioned") + + +def percentile_summary(values: np.ndarray) -> dict: + x = np.asarray(values, dtype=np.float64).reshape(-1) + if x.size == 0: + return { + "min": None, + "p05": None, + "p25": None, + "median": None, + "p75": None, + "p90": None, + "p95": None, + "max": None, + } + return { + "min": float(np.min(x)), + "p05": float(np.percentile(x, 5)), + "p25": float(np.percentile(x, 25)), + "median": float(np.percentile(x, 50)), + "p75": float(np.percentile(x, 75)), + "p90": float(np.percentile(x, 90)), + "p95": float(np.percentile(x, 95)), + "max": float(np.max(x)), + } + + +def point_metrics(true: np.ndarray, pred: np.ndarray) -> dict: + err = np.asarray(pred, dtype=np.float64) - np.asarray(true, dtype=np.float64) + abs_err = np.abs(err) + return { + "rmse": float(np.sqrt(np.mean(err**2))), + "mae": float(np.mean(abs_err)), + "bias": float(np.mean(err)), + "p90_abs": float(np.percentile(abs_err, 90)), + "p95_abs": float(np.percentile(abs_err, 95)), + } + + +def sample_metrics(true_p: np.ndarray, pred_p: np.ndarray, true_d: np.ndarray, pred_d: np.ndarray) -> list[dict]: + rows: list[dict] = [] + for idx in range(true_p.shape[0]): + p_err = pred_p[idx] - true_p[idx] + d_err = pred_d[idx] - true_d[idx] + rmse_p = float(np.sqrt(np.mean(p_err**2))) + rmse_d = float(np.sqrt(np.mean(d_err**2))) + mae_p = float(np.mean(np.abs(p_err))) + mae_d = float(np.mean(np.abs(d_err))) + rows.append( + { + "idx": idx, + "rmse_p": rmse_p, + "rmse_d": rmse_d, + "mae_p": mae_p, + "mae_d": mae_d, + "score": float(rmse_p + 2.0 * rmse_d), + } + ) + return rows + + +def write_csv(path: Path, rows: list[dict], fieldnames: list[str] | None = None) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + if not rows: + path.write_text("", encoding="utf-8-sig") + return + names = fieldnames or list(rows[0].keys()) + with path.open("w", newline="", encoding="utf-8-sig") as f: + writer = csv.DictWriter(f, fieldnames=names, extrasaction="ignore") + writer.writeheader() + writer.writerows(rows) + + +def iter_batches(total: int, batch_size: int) -> Iterable[tuple[int, int]]: + batch = max(1, int(batch_size)) + for start in range(0, int(total), batch): + yield start, min(start + batch, int(total)) + + +def load_model(model_path: Path, device: torch.device) -> tuple[TimeConditionedSurrogate, dict]: + checkpoint = torch.load(model_path, map_location="cpu") + model = TimeConditionedSurrogate( + param_dim=int(checkpoint["param_dim"]), + schedule_dim=int(checkpoint["schedule_dim"]), + time_dim=int(checkpoint["time_dim"]), + hidden_dim=int(checkpoint["hidden_dim"]), + n_blocks=int(checkpoint["n_blocks"]), + dropout=float(checkpoint["dropout"]), + use_schedule=bool(checkpoint.get("use_schedule", True)), + ) + model.load_state_dict(checkpoint["model_state_dict"]) + model.eval() + model.to(device) + return model, checkpoint + + +def predict_scaled_points( + model: TimeConditionedSurrogate, + params_x: np.ndarray, + schedule_x: np.ndarray, + time_x: np.ndarray, + device: torch.device, + batch_size: int, +) -> np.ndarray: + n_samples, n_time, time_dim = time_x.shape + params_flat = np.repeat(params_x, n_time, axis=0) + schedule_flat = np.repeat(schedule_x, n_time, axis=0) + time_flat = time_x.reshape(n_samples * n_time, time_dim) + + pred_flat = np.empty((n_samples * n_time, 2), dtype=np.float32) + use_schedule = bool(model.use_schedule) + with torch.no_grad(): + for start, end in iter_batches(len(time_flat), batch_size): + params_t = torch.tensor(params_flat[start:end], dtype=torch.float32, device=device) + time_t = torch.tensor(time_flat[start:end], dtype=torch.float32, device=device) + if use_schedule: + schedule_t = torch.tensor(schedule_flat[start:end], dtype=torch.float32, device=device) + else: + schedule_t = None + pred_flat[start:end] = model(params_t, time_t, schedule_t).detach().cpu().numpy() + return pred_flat.reshape(n_samples, n_time, 2) + + +def inverse_curve_part(values_scaled: np.ndarray, scaler_curve: object, part_slice: slice) -> np.ndarray: + mean = np.asarray(scaler_curve.mean_[part_slice], dtype=np.float32) + scale = np.asarray(scaler_curve.scale_[part_slice], dtype=np.float32) + return values_scaled.astype(np.float32) * scale.reshape(1, -1) + mean.reshape(1, -1) + + +def recover_raw_params(data: dict) -> dict[str, np.ndarray]: + meta = data.get("meta", {}) or {} + features = data["scaler_params"].inverse_transform(data["X_params_test"]) + raw = inverse_transform_param_features(features, meta.get("param_feature_transform")) + names = list(meta.get("param_names") or ["k", "skin", "wellboreC", "phi", "h", "Cf"]) + return {name: raw[:, idx].astype(np.float64) for idx, name in enumerate(names[: raw.shape[1]])} + + +def build_pso_mask(params: dict[str, np.ndarray], args: argparse.Namespace) -> np.ndarray: + return ( + (params["k"] >= float(args.pso_k_min)) + & (params["k"] <= float(args.pso_k_max)) + & (params["skin"] >= float(args.pso_skin_min)) + & (params["skin"] <= float(args.pso_skin_max)) + & (params["wellboreC"] >= float(args.pso_wellboreC_min)) + & (params["wellboreC"] <= float(args.pso_wellboreC_max)) + & (params["phi"] >= float(args.pso_phi_min)) + & (params["phi"] <= float(args.pso_phi_max)) + & (params["h"] >= float(args.pso_h_min)) + & (params["h"] <= float(args.pso_h_max)) + ) + + +def summarize_group(score: np.ndarray, rmse_p: np.ndarray, rmse_d: np.ndarray, mask: np.ndarray) -> dict: + m = np.asarray(mask, dtype=bool) + return { + "n": int(np.sum(m)), + "score": percentile_summary(score[m]), + "rmse_p": percentile_summary(rmse_p[m]), + "rmse_d": percentile_summary(rmse_d[m]), + "score_gt_1_ratio": float(np.mean(score[m] > 1.0)) if np.any(m) else None, + "score_gt_2_ratio": float(np.mean(score[m] > 2.0)) if np.any(m) else None, + "score_gt_5_ratio": float(np.mean(score[m] > 5.0)) if np.any(m) else None, + } + + +def build_domain_summary(sample_rows: list[dict], params: dict[str, np.ndarray], pso_mask: np.ndarray) -> dict: + score = np.asarray([r["score"] for r in sample_rows], dtype=np.float64) + rmse_p = np.asarray([r["rmse_p"] for r in sample_rows], dtype=np.float64) + rmse_d = np.asarray([r["rmse_d"] for r in sample_rows], dtype=np.float64) + skin = params["skin"] + wellboreC = params["wellboreC"] + + order = np.argsort(-score) + top100 = order[: min(100, order.size)] + return { + "all": summarize_group(score, rmse_p, rmse_d, np.ones_like(pso_mask, dtype=bool)), + "pso_domain": summarize_group(score, rmse_p, rmse_d, pso_mask), + "outside_pso_domain": summarize_group(score, rmse_p, rmse_d, ~pso_mask), + "pso_skin_lt_minus_5": summarize_group(score, rmse_p, rmse_d, pso_mask & (skin < -5.0)), + "pso_skin_lt_minus_8": summarize_group(score, rmse_p, rmse_d, pso_mask & (skin < -8.0)), + "pso_skin_lt_minus_5_wellboreC_gt_0_1": summarize_group( + score, + rmse_p, + rmse_d, + pso_mask & (skin < -5.0) & (wellboreC > 0.1), + ), + "top100": { + "outside_pso_domain": int(np.sum(~pso_mask[top100])), + "k_lt_0_001": int(np.sum(params["k"][top100] < 0.001)), + "k_gt_10": int(np.sum(params["k"][top100] > 10.0)), + "h_gt_50": int(np.sum(params["h"][top100] > 50.0)), + "pso_skin_lt_minus_5_wellboreC_gt_0_1": int( + np.sum((pso_mask & (skin < -5.0) & (wellboreC > 0.1))[top100]) + ), + }, + } + + +def summarize_params_for_indices(params: dict[str, np.ndarray], indices: np.ndarray) -> dict: + return { + name: percentile_summary(values[np.asarray(indices, dtype=int)]) + for name, values in params.items() + if name in {"k", "skin", "wellboreC", "phi", "h", "Cf"} + } + + +def build_worst_case_summary( + sample_rows: list[dict], + params: dict[str, np.ndarray], + pso_mask: np.ndarray, + top_k: int, +) -> dict: + score = np.asarray([r["score"] for r in sample_rows], dtype=np.float64) + order_worst = np.argsort(-score) + order_best = np.argsort(score) + top = order_worst[: min(int(top_k), order_worst.size)] + worst100 = order_worst[: min(100, order_worst.size)] + best100 = order_best[: min(100, order_best.size)] + + return { + "top_k": int(top.size), + "metrics": { + "score": percentile_summary(score), + "n_score_gt_1": int(np.sum(score > 1.0)), + "n_score_gt_2": int(np.sum(score > 2.0)), + "n_score_gt_5": int(np.sum(score > 5.0)), + }, + "pso_domain": { + "n_inside": int(np.sum(pso_mask)), + "n_outside": int(np.sum(~pso_mask)), + "top100_outside": int(np.sum(~pso_mask[worst100])), + "top100_k_lt_0_001": int(np.sum(params["k"][worst100] < 0.001)), + "top100_k_gt_10": int(np.sum(params["k"][worst100] > 10.0)), + "top100_h_gt_50": int(np.sum(params["h"][worst100] > 50.0)), + }, + "params": { + "all": summarize_params_for_indices(params, np.arange(score.size)), + "worst_top_k": summarize_params_for_indices(params, top), + "worst100": summarize_params_for_indices(params, worst100), + "best100": summarize_params_for_indices(params, best100), + }, + } + + +def build_worst_case_rows( + sample_rows: list[dict], + params: dict[str, np.ndarray], + data: dict, + true_p: np.ndarray, + true_d: np.ndarray, + pred_p: np.ndarray, + pred_d: np.ndarray, + pso_mask: np.ndarray, + top_k: int, +) -> tuple[list[dict], list[dict]]: + score = np.asarray([r["score"] for r in sample_rows], dtype=np.float64) + order = np.argsort(-score)[: min(int(top_k), len(sample_rows))] + family = data.get("family_name_test") + schedule_meta = data.get("schedule_meta_test") + schedule_meta_names = list((data.get("meta", {}) or {}).get("schedule_meta_names") or []) + + case_rows: list[dict] = [] + residual_rows: list[dict] = [] + for rank, idx in enumerate(order, 1): + p_res = pred_p[idx] - true_p[idx] + d_res = pred_d[idx] - true_d[idx] + p_rmse = float(np.sqrt(np.mean(p_res**2))) + d_rmse = float(np.sqrt(np.mean(d_res**2))) + p_mean = float(np.mean(p_res)) + d_mean = float(np.mean(d_res)) + p_std = float(np.std(p_res)) + d_std = float(np.std(d_res)) + + row = { + "rank": rank, + "idx": int(idx), + "score": float(score[idx]), + "rmse_p": float(sample_rows[idx]["rmse_p"]), + "rmse_d": float(sample_rows[idx]["rmse_d"]), + "mae_p": float(sample_rows[idx]["mae_p"]), + "mae_d": float(sample_rows[idx]["mae_d"]), + "in_pso_domain": int(bool(pso_mask[idx])), + "family": str(family[idx]) if family is not None else "", + } + for name in ["k", "skin", "wellboreC", "phi", "h", "Cf"]: + if name in params: + row[name] = float(params[name][idx]) + if schedule_meta is not None: + for midx, name in enumerate(schedule_meta_names): + if midx < schedule_meta.shape[1]: + row[f"sched_{name}"] = float(schedule_meta[idx, midx]) + case_rows.append(row) + + residual_rows.append( + { + "rank": rank, + "idx": int(idx), + "score": float(score[idx]), + "p_res_mean": p_mean, + "p_res_std": p_std, + "p_res_rmse": p_rmse, + "p_shift_ratio": float(abs(p_mean) / max(p_rmse, 1.0e-12)), + "d_res_mean": d_mean, + "d_res_std": d_std, + "d_res_rmse": d_rmse, + "d_shift_ratio": float(abs(d_mean) / max(d_rmse, 1.0e-12)), + "p_res_first": float(p_res[0]), + "p_res_last": float(p_res[-1]), + "d_res_first": float(d_res[0]), + "d_res_last": float(d_res[-1]), + } + ) + return case_rows, residual_rows + + +def plot_sample( + output_path: Path, + idx: int, + t: np.ndarray, + true_p: np.ndarray, + pred_p: np.ndarray, + true_d: np.ndarray, + pred_d: np.ndarray, + title: str, +) -> None: + x = np.asarray(t, dtype=np.float64) + fig, axes = plt.subplots(2, 2, figsize=(13, 8)) + fig.suptitle(title) + + axes[0, 0].plot(x, true_p, label="True", linewidth=2) + axes[0, 0].plot(x, pred_p, label="Pred", linewidth=2) + axes[0, 0].set_title("Log Pressure") + axes[0, 0].set_xscale("log") + axes[0, 0].grid(True, alpha=0.3) + axes[0, 0].legend() + + axes[0, 1].plot(x, pred_p - true_p, linewidth=1.5) + axes[0, 1].axhline(0.0, linestyle="--", linewidth=1) + axes[0, 1].set_title("Pressure Residual") + axes[0, 1].set_xscale("log") + axes[0, 1].grid(True, alpha=0.3) + + axes[1, 0].plot(x, true_d, label="True", linewidth=2) + axes[1, 0].plot(x, pred_d, label="Pred", linewidth=2) + axes[1, 0].set_title("Log Derivative") + axes[1, 0].set_xscale("log") + axes[1, 0].grid(True, alpha=0.3) + axes[1, 0].legend() + + axes[1, 1].plot(x, pred_d - true_d, linewidth=1.5) + axes[1, 1].axhline(0.0, linestyle="--", linewidth=1) + axes[1, 1].set_title("Derivative Residual") + axes[1, 1].set_xscale("log") + axes[1, 1].grid(True, alpha=0.3) + + for ax in axes.ravel(): + ax.set_xlabel("Time") + + output_path.parent.mkdir(parents=True, exist_ok=True) + plt.tight_layout(rect=[0, 0, 1, 0.95]) + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close(fig) + + +def write_plots( + output_dir: Path, + sample_rows: list[dict], + t_curve: np.ndarray, + true_p: np.ndarray, + true_d: np.ndarray, + pred_p: np.ndarray, + pred_d: np.ndarray, + args: argparse.Namespace, +) -> None: + plot_dir = output_dir / "plots" + score = np.asarray([r["score"] for r in sample_rows], dtype=np.float64) + random.seed(int(args.seed)) + n_random = min(int(args.n_random_plots), len(sample_rows)) + n_best = min(int(args.n_best_plots), len(sample_rows)) + n_worst = min(int(args.n_worst_plots), len(sample_rows)) + best = np.argsort(score)[:n_best].tolist() + worst = np.argsort(-score)[:n_worst].tolist() + random_idx = random.sample(range(len(sample_rows)), n_random) + + for idx in random_idx: + plot_sample( + plot_dir / f"sample_{idx:04d}.png", + idx, + t_curve[idx], + true_p[idx], + pred_p[idx], + true_d[idx], + pred_d[idx], + f"Random sample {idx} | score={score[idx]:.4f}", + ) + for idx in best: + plot_sample( + plot_dir / f"best_sample_{idx:04d}.png", + idx, + t_curve[idx], + true_p[idx], + pred_p[idx], + true_d[idx], + pred_d[idx], + f"Best sample {idx} | score={score[idx]:.4f}", + ) + for idx in worst: + plot_sample( + plot_dir / f"worst_sample_{idx:04d}.png", + idx, + t_curve[idx], + true_p[idx], + pred_p[idx], + true_d[idx], + pred_d[idx], + f"Worst sample {idx} | score={score[idx]:.4f}", + ) + + +def main() -> None: + args = parse_args() + tag = normalize_tag(args.tag) + processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) + model_path = Path(args.model) if args.model is not None else default_model_path(tag) + output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag) + output_dir.mkdir(parents=True, exist_ok=True) + + device_name = args.device or ("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device(device_name) + + print("Loading processed dataset...") + data = joblib.load(processed_path) + required = ["X_params_test", "X_schedule_test", "X_time_test", "Y_curve_test"] + missing = [key for key in required if key not in data] + if missing: + raise KeyError(f"processed dataset is missing time-conditioned fields: {missing}") + + print("Loading model...") + model, checkpoint = load_model(model_path, device) + curve_layout = checkpoint.get("curve_layout") or infer_curve_layout(data) + slices = get_part_slices(curve_layout) + + x_params = np.asarray(data["X_params_test"], dtype=np.float32) + x_schedule = np.asarray(data["X_schedule_test"], dtype=np.float32) + x_time = np.asarray(data["X_time_test"], dtype=np.float32) + y_curve = np.asarray(data["Y_curve_test"], dtype=np.float32) + scaler_curve = data["scaler_curve"] + + print( + f"test={x_params.shape[0]}, n_time={x_time.shape[1]}, " + f"param_dim={x_params.shape[1]}, schedule_dim={x_schedule.shape[1]}, time_dim={x_time.shape[-1]}" + ) + print(f"device={device}, batch_size={args.batch_size}") + + pred_scaled = predict_scaled_points( + model=model, + params_x=x_params, + schedule_x=x_schedule, + time_x=x_time, + device=device, + batch_size=int(args.batch_size), + ) + + p_slice = slices["log_pressure"] + d_slice = slices["log_derivative"] + true_p_scaled = y_curve[:, p_slice] + true_d_scaled = y_curve[:, d_slice] + pred_p_scaled = pred_scaled[:, :, 0] + pred_d_scaled = pred_scaled[:, :, 1] + + true_p = inverse_curve_part(true_p_scaled, scaler_curve, p_slice) + true_d = inverse_curve_part(true_d_scaled, scaler_curve, d_slice) + pred_p = inverse_curve_part(pred_p_scaled, scaler_curve, p_slice) + pred_d = inverse_curve_part(pred_d_scaled, scaler_curve, d_slice) + + summary = { + "processed_path": str(processed_path), + "model_path": str(model_path), + "device": str(device), + "checkpoint": { + "hidden_dim": int(checkpoint["hidden_dim"]), + "n_blocks": int(checkpoint["n_blocks"]), + "dropout": float(checkpoint["dropout"]), + "use_schedule": bool(checkpoint.get("use_schedule", True)), + }, + "scaled_log_pressure": point_metrics(true_p_scaled, pred_p_scaled), + "scaled_log_derivative": point_metrics(true_d_scaled, pred_d_scaled), + "raw_log_pressure": point_metrics(true_p, pred_p), + "raw_log_derivative": point_metrics(true_d, pred_d), + } + + rows = sample_metrics(true_p=true_p, pred_p=pred_p, true_d=true_d, pred_d=pred_d) + params = recover_raw_params(data) + pso_mask = build_pso_mask(params, args) + domain_summary = build_domain_summary(rows, params, pso_mask) + summary["pso_domain"] = { + "bounds": { + "k": [float(args.pso_k_min), float(args.pso_k_max)], + "skin": [float(args.pso_skin_min), float(args.pso_skin_max)], + "wellboreC": [float(args.pso_wellboreC_min), float(args.pso_wellboreC_max)], + "phi": [float(args.pso_phi_min), float(args.pso_phi_max)], + "h": [float(args.pso_h_min), float(args.pso_h_max)], + }, + "metrics": domain_summary, + } + + case_rows, residual_rows = build_worst_case_rows( + sample_rows=rows, + params=params, + data=data, + true_p=true_p, + true_d=true_d, + pred_p=pred_p, + pred_d=pred_d, + pso_mask=pso_mask, + top_k=int(args.top_k_analysis), + ) + worst_case_summary = build_worst_case_summary( + sample_rows=rows, + params=params, + pso_mask=pso_mask, + top_k=int(args.top_k_analysis), + ) + + (output_dir / "summary_metrics.json").write_text( + json.dumps(summary, indent=2, ensure_ascii=False), + encoding="utf-8", + ) + write_csv(output_dir / "sample_metrics.csv", rows) + write_csv(output_dir / "worst_case_analysis.csv", case_rows) + write_csv(output_dir / "worst_residual_analysis.csv", residual_rows) + (output_dir / "worst_case_summary.json").write_text( + json.dumps(worst_case_summary, indent=2, ensure_ascii=False), + encoding="utf-8", + ) + + residual_summary = { + "top_k": int(len(residual_rows)), + "top300_p_shift_ratio_median": float(np.median([r["p_shift_ratio"] for r in residual_rows])), + "top300_d_shift_ratio_median": float(np.median([r["d_shift_ratio"] for r in residual_rows])), + "top100_p_shift_ratio_median": float(np.median([r["p_shift_ratio"] for r in residual_rows[:100]])), + "top100_d_shift_ratio_median": float(np.median([r["d_shift_ratio"] for r in residual_rows[:100]])), + "top20": residual_rows[:20], + } + (output_dir / "worst_residual_summary.json").write_text( + json.dumps(residual_summary, indent=2, ensure_ascii=False), + encoding="utf-8", + ) + + t_curve = np.asarray(data.get("T_curve_test"), dtype=np.float32) + if t_curve.ndim == 2 and t_curve.shape == true_p.shape: + write_plots(output_dir, rows, t_curve, true_p, true_d, pred_p, pred_d, args) + + print("\nEvaluation complete.") + print(f"raw_log_pressure RMSE={summary['raw_log_pressure']['rmse']:.6f}, MAE={summary['raw_log_pressure']['mae']:.6f}") + print(f"raw_log_derivative RMSE={summary['raw_log_derivative']['rmse']:.6f}, MAE={summary['raw_log_derivative']['mae']:.6f}") + print( + "PSO-domain: " + f"n={domain_summary['pso_domain']['n']}, " + f"median={domain_summary['pso_domain']['score']['median']:.6f}, " + f"p95={domain_summary['pso_domain']['score']['p95']:.6f}, " + f"score>1={100.0 * domain_summary['pso_domain']['score_gt_1_ratio']:.3f}%" + ) + print(f"Artifacts written to: {output_dir}") + + +if __name__ == "__main__": + main()