1、新增时间范围评估脚本
parent
99e74efa8c
commit
b7721f7cc1
@ -0,0 +1,643 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Iterable
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[1]
|
||||||
|
sys.path.append(str(ROOT))
|
||||||
|
|
||||||
|
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
|
||||||
|
from src.data.param_features import inverse_transform_param_features
|
||||||
|
from src.models.time_conditioned_surrogate import TimeConditionedSurrogate
|
||||||
|
from src.training.train_forward import get_part_slices, infer_curve_layout
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_RANDOM_SEED = 42
|
||||||
|
DEFAULT_PSO_DOMAIN = {
|
||||||
|
"k_min": 0.001,
|
||||||
|
"k_max": 10.0,
|
||||||
|
"skin_min": -10.0,
|
||||||
|
"skin_max": 10.0,
|
||||||
|
"wellboreC_min": 1.0e-4,
|
||||||
|
"wellboreC_max": 2.0,
|
||||||
|
"phi_min": 0.01,
|
||||||
|
"phi_max": 0.5,
|
||||||
|
"h_min": 2.0,
|
||||||
|
"h_max": 50.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(description="Evaluate a time-conditioned point-wise surrogate")
|
||||||
|
parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
|
||||||
|
parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming")
|
||||||
|
parser.add_argument("--model", type=str, default=None, help="Model checkpoint path")
|
||||||
|
parser.add_argument("--output-dir", type=str, default=None, help="Optional evaluation output directory")
|
||||||
|
parser.add_argument("--batch-size", type=int, default=65536, help="Point batch size for inference")
|
||||||
|
parser.add_argument("--device", type=str, default=None, help="Override device, e.g. cpu or cuda")
|
||||||
|
parser.add_argument("--seed", type=int, default=DEFAULT_RANDOM_SEED)
|
||||||
|
parser.add_argument("--n-random-plots", type=int, default=5)
|
||||||
|
parser.add_argument("--n-best-plots", type=int, default=5)
|
||||||
|
parser.add_argument("--n-worst-plots", type=int, default=10)
|
||||||
|
parser.add_argument("--top-k-analysis", type=int, default=300)
|
||||||
|
parser.add_argument("--pso-k-min", type=float, default=DEFAULT_PSO_DOMAIN["k_min"])
|
||||||
|
parser.add_argument("--pso-k-max", type=float, default=DEFAULT_PSO_DOMAIN["k_max"])
|
||||||
|
parser.add_argument("--pso-h-min", type=float, default=DEFAULT_PSO_DOMAIN["h_min"])
|
||||||
|
parser.add_argument("--pso-h-max", type=float, default=DEFAULT_PSO_DOMAIN["h_max"])
|
||||||
|
parser.add_argument("--pso-skin-min", type=float, default=DEFAULT_PSO_DOMAIN["skin_min"])
|
||||||
|
parser.add_argument("--pso-skin-max", type=float, default=DEFAULT_PSO_DOMAIN["skin_max"])
|
||||||
|
parser.add_argument("--pso-wellboreC-min", type=float, default=DEFAULT_PSO_DOMAIN["wellboreC_min"])
|
||||||
|
parser.add_argument("--pso-wellboreC-max", type=float, default=DEFAULT_PSO_DOMAIN["wellboreC_max"])
|
||||||
|
parser.add_argument("--pso-phi-min", type=float, default=DEFAULT_PSO_DOMAIN["phi_min"])
|
||||||
|
parser.add_argument("--pso-phi-max", type=float, default=DEFAULT_PSO_DOMAIN["phi_max"])
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def default_model_path(tag: str | None) -> Path:
|
||||||
|
if tag:
|
||||||
|
return Path("models") / f"time_conditioned_surrogate_{tag}" / "time_conditioned_surrogate_best.pt"
|
||||||
|
return Path("models/time_conditioned_surrogate/time_conditioned_surrogate_best.pt")
|
||||||
|
|
||||||
|
|
||||||
|
def default_output_dir(tag: str | None) -> Path:
|
||||||
|
if tag:
|
||||||
|
return Path("results") / f"evaluation_time_conditioned_{tag}"
|
||||||
|
return Path("results/evaluation_time_conditioned")
|
||||||
|
|
||||||
|
|
||||||
|
def percentile_summary(values: np.ndarray) -> dict:
|
||||||
|
x = np.asarray(values, dtype=np.float64).reshape(-1)
|
||||||
|
if x.size == 0:
|
||||||
|
return {
|
||||||
|
"min": None,
|
||||||
|
"p05": None,
|
||||||
|
"p25": None,
|
||||||
|
"median": None,
|
||||||
|
"p75": None,
|
||||||
|
"p90": None,
|
||||||
|
"p95": None,
|
||||||
|
"max": None,
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
"min": float(np.min(x)),
|
||||||
|
"p05": float(np.percentile(x, 5)),
|
||||||
|
"p25": float(np.percentile(x, 25)),
|
||||||
|
"median": float(np.percentile(x, 50)),
|
||||||
|
"p75": float(np.percentile(x, 75)),
|
||||||
|
"p90": float(np.percentile(x, 90)),
|
||||||
|
"p95": float(np.percentile(x, 95)),
|
||||||
|
"max": float(np.max(x)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def point_metrics(true: np.ndarray, pred: np.ndarray) -> dict:
|
||||||
|
err = np.asarray(pred, dtype=np.float64) - np.asarray(true, dtype=np.float64)
|
||||||
|
abs_err = np.abs(err)
|
||||||
|
return {
|
||||||
|
"rmse": float(np.sqrt(np.mean(err**2))),
|
||||||
|
"mae": float(np.mean(abs_err)),
|
||||||
|
"bias": float(np.mean(err)),
|
||||||
|
"p90_abs": float(np.percentile(abs_err, 90)),
|
||||||
|
"p95_abs": float(np.percentile(abs_err, 95)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def sample_metrics(true_p: np.ndarray, pred_p: np.ndarray, true_d: np.ndarray, pred_d: np.ndarray) -> list[dict]:
|
||||||
|
rows: list[dict] = []
|
||||||
|
for idx in range(true_p.shape[0]):
|
||||||
|
p_err = pred_p[idx] - true_p[idx]
|
||||||
|
d_err = pred_d[idx] - true_d[idx]
|
||||||
|
rmse_p = float(np.sqrt(np.mean(p_err**2)))
|
||||||
|
rmse_d = float(np.sqrt(np.mean(d_err**2)))
|
||||||
|
mae_p = float(np.mean(np.abs(p_err)))
|
||||||
|
mae_d = float(np.mean(np.abs(d_err)))
|
||||||
|
rows.append(
|
||||||
|
{
|
||||||
|
"idx": idx,
|
||||||
|
"rmse_p": rmse_p,
|
||||||
|
"rmse_d": rmse_d,
|
||||||
|
"mae_p": mae_p,
|
||||||
|
"mae_d": mae_d,
|
||||||
|
"score": float(rmse_p + 2.0 * rmse_d),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return rows
|
||||||
|
|
||||||
|
|
||||||
|
def write_csv(path: Path, rows: list[dict], fieldnames: list[str] | None = None) -> None:
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
if not rows:
|
||||||
|
path.write_text("", encoding="utf-8-sig")
|
||||||
|
return
|
||||||
|
names = fieldnames or list(rows[0].keys())
|
||||||
|
with path.open("w", newline="", encoding="utf-8-sig") as f:
|
||||||
|
writer = csv.DictWriter(f, fieldnames=names, extrasaction="ignore")
|
||||||
|
writer.writeheader()
|
||||||
|
writer.writerows(rows)
|
||||||
|
|
||||||
|
|
||||||
|
def iter_batches(total: int, batch_size: int) -> Iterable[tuple[int, int]]:
|
||||||
|
batch = max(1, int(batch_size))
|
||||||
|
for start in range(0, int(total), batch):
|
||||||
|
yield start, min(start + batch, int(total))
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_path: Path, device: torch.device) -> tuple[TimeConditionedSurrogate, dict]:
|
||||||
|
checkpoint = torch.load(model_path, map_location="cpu")
|
||||||
|
model = TimeConditionedSurrogate(
|
||||||
|
param_dim=int(checkpoint["param_dim"]),
|
||||||
|
schedule_dim=int(checkpoint["schedule_dim"]),
|
||||||
|
time_dim=int(checkpoint["time_dim"]),
|
||||||
|
hidden_dim=int(checkpoint["hidden_dim"]),
|
||||||
|
n_blocks=int(checkpoint["n_blocks"]),
|
||||||
|
dropout=float(checkpoint["dropout"]),
|
||||||
|
use_schedule=bool(checkpoint.get("use_schedule", True)),
|
||||||
|
)
|
||||||
|
model.load_state_dict(checkpoint["model_state_dict"])
|
||||||
|
model.eval()
|
||||||
|
model.to(device)
|
||||||
|
return model, checkpoint
|
||||||
|
|
||||||
|
|
||||||
|
def predict_scaled_points(
|
||||||
|
model: TimeConditionedSurrogate,
|
||||||
|
params_x: np.ndarray,
|
||||||
|
schedule_x: np.ndarray,
|
||||||
|
time_x: np.ndarray,
|
||||||
|
device: torch.device,
|
||||||
|
batch_size: int,
|
||||||
|
) -> np.ndarray:
|
||||||
|
n_samples, n_time, time_dim = time_x.shape
|
||||||
|
params_flat = np.repeat(params_x, n_time, axis=0)
|
||||||
|
schedule_flat = np.repeat(schedule_x, n_time, axis=0)
|
||||||
|
time_flat = time_x.reshape(n_samples * n_time, time_dim)
|
||||||
|
|
||||||
|
pred_flat = np.empty((n_samples * n_time, 2), dtype=np.float32)
|
||||||
|
use_schedule = bool(model.use_schedule)
|
||||||
|
with torch.no_grad():
|
||||||
|
for start, end in iter_batches(len(time_flat), batch_size):
|
||||||
|
params_t = torch.tensor(params_flat[start:end], dtype=torch.float32, device=device)
|
||||||
|
time_t = torch.tensor(time_flat[start:end], dtype=torch.float32, device=device)
|
||||||
|
if use_schedule:
|
||||||
|
schedule_t = torch.tensor(schedule_flat[start:end], dtype=torch.float32, device=device)
|
||||||
|
else:
|
||||||
|
schedule_t = None
|
||||||
|
pred_flat[start:end] = model(params_t, time_t, schedule_t).detach().cpu().numpy()
|
||||||
|
return pred_flat.reshape(n_samples, n_time, 2)
|
||||||
|
|
||||||
|
|
||||||
|
def inverse_curve_part(values_scaled: np.ndarray, scaler_curve: object, part_slice: slice) -> np.ndarray:
|
||||||
|
mean = np.asarray(scaler_curve.mean_[part_slice], dtype=np.float32)
|
||||||
|
scale = np.asarray(scaler_curve.scale_[part_slice], dtype=np.float32)
|
||||||
|
return values_scaled.astype(np.float32) * scale.reshape(1, -1) + mean.reshape(1, -1)
|
||||||
|
|
||||||
|
|
||||||
|
def recover_raw_params(data: dict) -> dict[str, np.ndarray]:
|
||||||
|
meta = data.get("meta", {}) or {}
|
||||||
|
features = data["scaler_params"].inverse_transform(data["X_params_test"])
|
||||||
|
raw = inverse_transform_param_features(features, meta.get("param_feature_transform"))
|
||||||
|
names = list(meta.get("param_names") or ["k", "skin", "wellboreC", "phi", "h", "Cf"])
|
||||||
|
return {name: raw[:, idx].astype(np.float64) for idx, name in enumerate(names[: raw.shape[1]])}
|
||||||
|
|
||||||
|
|
||||||
|
def build_pso_mask(params: dict[str, np.ndarray], args: argparse.Namespace) -> np.ndarray:
|
||||||
|
return (
|
||||||
|
(params["k"] >= float(args.pso_k_min))
|
||||||
|
& (params["k"] <= float(args.pso_k_max))
|
||||||
|
& (params["skin"] >= float(args.pso_skin_min))
|
||||||
|
& (params["skin"] <= float(args.pso_skin_max))
|
||||||
|
& (params["wellboreC"] >= float(args.pso_wellboreC_min))
|
||||||
|
& (params["wellboreC"] <= float(args.pso_wellboreC_max))
|
||||||
|
& (params["phi"] >= float(args.pso_phi_min))
|
||||||
|
& (params["phi"] <= float(args.pso_phi_max))
|
||||||
|
& (params["h"] >= float(args.pso_h_min))
|
||||||
|
& (params["h"] <= float(args.pso_h_max))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_group(score: np.ndarray, rmse_p: np.ndarray, rmse_d: np.ndarray, mask: np.ndarray) -> dict:
|
||||||
|
m = np.asarray(mask, dtype=bool)
|
||||||
|
return {
|
||||||
|
"n": int(np.sum(m)),
|
||||||
|
"score": percentile_summary(score[m]),
|
||||||
|
"rmse_p": percentile_summary(rmse_p[m]),
|
||||||
|
"rmse_d": percentile_summary(rmse_d[m]),
|
||||||
|
"score_gt_1_ratio": float(np.mean(score[m] > 1.0)) if np.any(m) else None,
|
||||||
|
"score_gt_2_ratio": float(np.mean(score[m] > 2.0)) if np.any(m) else None,
|
||||||
|
"score_gt_5_ratio": float(np.mean(score[m] > 5.0)) if np.any(m) else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_domain_summary(sample_rows: list[dict], params: dict[str, np.ndarray], pso_mask: np.ndarray) -> dict:
|
||||||
|
score = np.asarray([r["score"] for r in sample_rows], dtype=np.float64)
|
||||||
|
rmse_p = np.asarray([r["rmse_p"] for r in sample_rows], dtype=np.float64)
|
||||||
|
rmse_d = np.asarray([r["rmse_d"] for r in sample_rows], dtype=np.float64)
|
||||||
|
skin = params["skin"]
|
||||||
|
wellboreC = params["wellboreC"]
|
||||||
|
|
||||||
|
order = np.argsort(-score)
|
||||||
|
top100 = order[: min(100, order.size)]
|
||||||
|
return {
|
||||||
|
"all": summarize_group(score, rmse_p, rmse_d, np.ones_like(pso_mask, dtype=bool)),
|
||||||
|
"pso_domain": summarize_group(score, rmse_p, rmse_d, pso_mask),
|
||||||
|
"outside_pso_domain": summarize_group(score, rmse_p, rmse_d, ~pso_mask),
|
||||||
|
"pso_skin_lt_minus_5": summarize_group(score, rmse_p, rmse_d, pso_mask & (skin < -5.0)),
|
||||||
|
"pso_skin_lt_minus_8": summarize_group(score, rmse_p, rmse_d, pso_mask & (skin < -8.0)),
|
||||||
|
"pso_skin_lt_minus_5_wellboreC_gt_0_1": summarize_group(
|
||||||
|
score,
|
||||||
|
rmse_p,
|
||||||
|
rmse_d,
|
||||||
|
pso_mask & (skin < -5.0) & (wellboreC > 0.1),
|
||||||
|
),
|
||||||
|
"top100": {
|
||||||
|
"outside_pso_domain": int(np.sum(~pso_mask[top100])),
|
||||||
|
"k_lt_0_001": int(np.sum(params["k"][top100] < 0.001)),
|
||||||
|
"k_gt_10": int(np.sum(params["k"][top100] > 10.0)),
|
||||||
|
"h_gt_50": int(np.sum(params["h"][top100] > 50.0)),
|
||||||
|
"pso_skin_lt_minus_5_wellboreC_gt_0_1": int(
|
||||||
|
np.sum((pso_mask & (skin < -5.0) & (wellboreC > 0.1))[top100])
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def summarize_params_for_indices(params: dict[str, np.ndarray], indices: np.ndarray) -> dict:
|
||||||
|
return {
|
||||||
|
name: percentile_summary(values[np.asarray(indices, dtype=int)])
|
||||||
|
for name, values in params.items()
|
||||||
|
if name in {"k", "skin", "wellboreC", "phi", "h", "Cf"}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_worst_case_summary(
|
||||||
|
sample_rows: list[dict],
|
||||||
|
params: dict[str, np.ndarray],
|
||||||
|
pso_mask: np.ndarray,
|
||||||
|
top_k: int,
|
||||||
|
) -> dict:
|
||||||
|
score = np.asarray([r["score"] for r in sample_rows], dtype=np.float64)
|
||||||
|
order_worst = np.argsort(-score)
|
||||||
|
order_best = np.argsort(score)
|
||||||
|
top = order_worst[: min(int(top_k), order_worst.size)]
|
||||||
|
worst100 = order_worst[: min(100, order_worst.size)]
|
||||||
|
best100 = order_best[: min(100, order_best.size)]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"top_k": int(top.size),
|
||||||
|
"metrics": {
|
||||||
|
"score": percentile_summary(score),
|
||||||
|
"n_score_gt_1": int(np.sum(score > 1.0)),
|
||||||
|
"n_score_gt_2": int(np.sum(score > 2.0)),
|
||||||
|
"n_score_gt_5": int(np.sum(score > 5.0)),
|
||||||
|
},
|
||||||
|
"pso_domain": {
|
||||||
|
"n_inside": int(np.sum(pso_mask)),
|
||||||
|
"n_outside": int(np.sum(~pso_mask)),
|
||||||
|
"top100_outside": int(np.sum(~pso_mask[worst100])),
|
||||||
|
"top100_k_lt_0_001": int(np.sum(params["k"][worst100] < 0.001)),
|
||||||
|
"top100_k_gt_10": int(np.sum(params["k"][worst100] > 10.0)),
|
||||||
|
"top100_h_gt_50": int(np.sum(params["h"][worst100] > 50.0)),
|
||||||
|
},
|
||||||
|
"params": {
|
||||||
|
"all": summarize_params_for_indices(params, np.arange(score.size)),
|
||||||
|
"worst_top_k": summarize_params_for_indices(params, top),
|
||||||
|
"worst100": summarize_params_for_indices(params, worst100),
|
||||||
|
"best100": summarize_params_for_indices(params, best100),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_worst_case_rows(
|
||||||
|
sample_rows: list[dict],
|
||||||
|
params: dict[str, np.ndarray],
|
||||||
|
data: dict,
|
||||||
|
true_p: np.ndarray,
|
||||||
|
true_d: np.ndarray,
|
||||||
|
pred_p: np.ndarray,
|
||||||
|
pred_d: np.ndarray,
|
||||||
|
pso_mask: np.ndarray,
|
||||||
|
top_k: int,
|
||||||
|
) -> tuple[list[dict], list[dict]]:
|
||||||
|
score = np.asarray([r["score"] for r in sample_rows], dtype=np.float64)
|
||||||
|
order = np.argsort(-score)[: min(int(top_k), len(sample_rows))]
|
||||||
|
family = data.get("family_name_test")
|
||||||
|
schedule_meta = data.get("schedule_meta_test")
|
||||||
|
schedule_meta_names = list((data.get("meta", {}) or {}).get("schedule_meta_names") or [])
|
||||||
|
|
||||||
|
case_rows: list[dict] = []
|
||||||
|
residual_rows: list[dict] = []
|
||||||
|
for rank, idx in enumerate(order, 1):
|
||||||
|
p_res = pred_p[idx] - true_p[idx]
|
||||||
|
d_res = pred_d[idx] - true_d[idx]
|
||||||
|
p_rmse = float(np.sqrt(np.mean(p_res**2)))
|
||||||
|
d_rmse = float(np.sqrt(np.mean(d_res**2)))
|
||||||
|
p_mean = float(np.mean(p_res))
|
||||||
|
d_mean = float(np.mean(d_res))
|
||||||
|
p_std = float(np.std(p_res))
|
||||||
|
d_std = float(np.std(d_res))
|
||||||
|
|
||||||
|
row = {
|
||||||
|
"rank": rank,
|
||||||
|
"idx": int(idx),
|
||||||
|
"score": float(score[idx]),
|
||||||
|
"rmse_p": float(sample_rows[idx]["rmse_p"]),
|
||||||
|
"rmse_d": float(sample_rows[idx]["rmse_d"]),
|
||||||
|
"mae_p": float(sample_rows[idx]["mae_p"]),
|
||||||
|
"mae_d": float(sample_rows[idx]["mae_d"]),
|
||||||
|
"in_pso_domain": int(bool(pso_mask[idx])),
|
||||||
|
"family": str(family[idx]) if family is not None else "",
|
||||||
|
}
|
||||||
|
for name in ["k", "skin", "wellboreC", "phi", "h", "Cf"]:
|
||||||
|
if name in params:
|
||||||
|
row[name] = float(params[name][idx])
|
||||||
|
if schedule_meta is not None:
|
||||||
|
for midx, name in enumerate(schedule_meta_names):
|
||||||
|
if midx < schedule_meta.shape[1]:
|
||||||
|
row[f"sched_{name}"] = float(schedule_meta[idx, midx])
|
||||||
|
case_rows.append(row)
|
||||||
|
|
||||||
|
residual_rows.append(
|
||||||
|
{
|
||||||
|
"rank": rank,
|
||||||
|
"idx": int(idx),
|
||||||
|
"score": float(score[idx]),
|
||||||
|
"p_res_mean": p_mean,
|
||||||
|
"p_res_std": p_std,
|
||||||
|
"p_res_rmse": p_rmse,
|
||||||
|
"p_shift_ratio": float(abs(p_mean) / max(p_rmse, 1.0e-12)),
|
||||||
|
"d_res_mean": d_mean,
|
||||||
|
"d_res_std": d_std,
|
||||||
|
"d_res_rmse": d_rmse,
|
||||||
|
"d_shift_ratio": float(abs(d_mean) / max(d_rmse, 1.0e-12)),
|
||||||
|
"p_res_first": float(p_res[0]),
|
||||||
|
"p_res_last": float(p_res[-1]),
|
||||||
|
"d_res_first": float(d_res[0]),
|
||||||
|
"d_res_last": float(d_res[-1]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return case_rows, residual_rows
|
||||||
|
|
||||||
|
|
||||||
|
def plot_sample(
|
||||||
|
output_path: Path,
|
||||||
|
idx: int,
|
||||||
|
t: np.ndarray,
|
||||||
|
true_p: np.ndarray,
|
||||||
|
pred_p: np.ndarray,
|
||||||
|
true_d: np.ndarray,
|
||||||
|
pred_d: np.ndarray,
|
||||||
|
title: str,
|
||||||
|
) -> None:
|
||||||
|
x = np.asarray(t, dtype=np.float64)
|
||||||
|
fig, axes = plt.subplots(2, 2, figsize=(13, 8))
|
||||||
|
fig.suptitle(title)
|
||||||
|
|
||||||
|
axes[0, 0].plot(x, true_p, label="True", linewidth=2)
|
||||||
|
axes[0, 0].plot(x, pred_p, label="Pred", linewidth=2)
|
||||||
|
axes[0, 0].set_title("Log Pressure")
|
||||||
|
axes[0, 0].set_xscale("log")
|
||||||
|
axes[0, 0].grid(True, alpha=0.3)
|
||||||
|
axes[0, 0].legend()
|
||||||
|
|
||||||
|
axes[0, 1].plot(x, pred_p - true_p, linewidth=1.5)
|
||||||
|
axes[0, 1].axhline(0.0, linestyle="--", linewidth=1)
|
||||||
|
axes[0, 1].set_title("Pressure Residual")
|
||||||
|
axes[0, 1].set_xscale("log")
|
||||||
|
axes[0, 1].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
axes[1, 0].plot(x, true_d, label="True", linewidth=2)
|
||||||
|
axes[1, 0].plot(x, pred_d, label="Pred", linewidth=2)
|
||||||
|
axes[1, 0].set_title("Log Derivative")
|
||||||
|
axes[1, 0].set_xscale("log")
|
||||||
|
axes[1, 0].grid(True, alpha=0.3)
|
||||||
|
axes[1, 0].legend()
|
||||||
|
|
||||||
|
axes[1, 1].plot(x, pred_d - true_d, linewidth=1.5)
|
||||||
|
axes[1, 1].axhline(0.0, linestyle="--", linewidth=1)
|
||||||
|
axes[1, 1].set_title("Derivative Residual")
|
||||||
|
axes[1, 1].set_xscale("log")
|
||||||
|
axes[1, 1].grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
for ax in axes.ravel():
|
||||||
|
ax.set_xlabel("Time")
|
||||||
|
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
plt.tight_layout(rect=[0, 0, 1, 0.95])
|
||||||
|
plt.savefig(output_path, dpi=150, bbox_inches="tight")
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
|
||||||
|
def write_plots(
|
||||||
|
output_dir: Path,
|
||||||
|
sample_rows: list[dict],
|
||||||
|
t_curve: np.ndarray,
|
||||||
|
true_p: np.ndarray,
|
||||||
|
true_d: np.ndarray,
|
||||||
|
pred_p: np.ndarray,
|
||||||
|
pred_d: np.ndarray,
|
||||||
|
args: argparse.Namespace,
|
||||||
|
) -> None:
|
||||||
|
plot_dir = output_dir / "plots"
|
||||||
|
score = np.asarray([r["score"] for r in sample_rows], dtype=np.float64)
|
||||||
|
random.seed(int(args.seed))
|
||||||
|
n_random = min(int(args.n_random_plots), len(sample_rows))
|
||||||
|
n_best = min(int(args.n_best_plots), len(sample_rows))
|
||||||
|
n_worst = min(int(args.n_worst_plots), len(sample_rows))
|
||||||
|
best = np.argsort(score)[:n_best].tolist()
|
||||||
|
worst = np.argsort(-score)[:n_worst].tolist()
|
||||||
|
random_idx = random.sample(range(len(sample_rows)), n_random)
|
||||||
|
|
||||||
|
for idx in random_idx:
|
||||||
|
plot_sample(
|
||||||
|
plot_dir / f"sample_{idx:04d}.png",
|
||||||
|
idx,
|
||||||
|
t_curve[idx],
|
||||||
|
true_p[idx],
|
||||||
|
pred_p[idx],
|
||||||
|
true_d[idx],
|
||||||
|
pred_d[idx],
|
||||||
|
f"Random sample {idx} | score={score[idx]:.4f}",
|
||||||
|
)
|
||||||
|
for idx in best:
|
||||||
|
plot_sample(
|
||||||
|
plot_dir / f"best_sample_{idx:04d}.png",
|
||||||
|
idx,
|
||||||
|
t_curve[idx],
|
||||||
|
true_p[idx],
|
||||||
|
pred_p[idx],
|
||||||
|
true_d[idx],
|
||||||
|
pred_d[idx],
|
||||||
|
f"Best sample {idx} | score={score[idx]:.4f}",
|
||||||
|
)
|
||||||
|
for idx in worst:
|
||||||
|
plot_sample(
|
||||||
|
plot_dir / f"worst_sample_{idx:04d}.png",
|
||||||
|
idx,
|
||||||
|
t_curve[idx],
|
||||||
|
true_p[idx],
|
||||||
|
pred_p[idx],
|
||||||
|
true_d[idx],
|
||||||
|
pred_d[idx],
|
||||||
|
f"Worst sample {idx} | score={score[idx]:.4f}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = parse_args()
|
||||||
|
tag = normalize_tag(args.tag)
|
||||||
|
processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
|
||||||
|
model_path = Path(args.model) if args.model is not None else default_model_path(tag)
|
||||||
|
output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag)
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
device_name = args.device or ("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
device = torch.device(device_name)
|
||||||
|
|
||||||
|
print("Loading processed dataset...")
|
||||||
|
data = joblib.load(processed_path)
|
||||||
|
required = ["X_params_test", "X_schedule_test", "X_time_test", "Y_curve_test"]
|
||||||
|
missing = [key for key in required if key not in data]
|
||||||
|
if missing:
|
||||||
|
raise KeyError(f"processed dataset is missing time-conditioned fields: {missing}")
|
||||||
|
|
||||||
|
print("Loading model...")
|
||||||
|
model, checkpoint = load_model(model_path, device)
|
||||||
|
curve_layout = checkpoint.get("curve_layout") or infer_curve_layout(data)
|
||||||
|
slices = get_part_slices(curve_layout)
|
||||||
|
|
||||||
|
x_params = np.asarray(data["X_params_test"], dtype=np.float32)
|
||||||
|
x_schedule = np.asarray(data["X_schedule_test"], dtype=np.float32)
|
||||||
|
x_time = np.asarray(data["X_time_test"], dtype=np.float32)
|
||||||
|
y_curve = np.asarray(data["Y_curve_test"], dtype=np.float32)
|
||||||
|
scaler_curve = data["scaler_curve"]
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"test={x_params.shape[0]}, n_time={x_time.shape[1]}, "
|
||||||
|
f"param_dim={x_params.shape[1]}, schedule_dim={x_schedule.shape[1]}, time_dim={x_time.shape[-1]}"
|
||||||
|
)
|
||||||
|
print(f"device={device}, batch_size={args.batch_size}")
|
||||||
|
|
||||||
|
pred_scaled = predict_scaled_points(
|
||||||
|
model=model,
|
||||||
|
params_x=x_params,
|
||||||
|
schedule_x=x_schedule,
|
||||||
|
time_x=x_time,
|
||||||
|
device=device,
|
||||||
|
batch_size=int(args.batch_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
p_slice = slices["log_pressure"]
|
||||||
|
d_slice = slices["log_derivative"]
|
||||||
|
true_p_scaled = y_curve[:, p_slice]
|
||||||
|
true_d_scaled = y_curve[:, d_slice]
|
||||||
|
pred_p_scaled = pred_scaled[:, :, 0]
|
||||||
|
pred_d_scaled = pred_scaled[:, :, 1]
|
||||||
|
|
||||||
|
true_p = inverse_curve_part(true_p_scaled, scaler_curve, p_slice)
|
||||||
|
true_d = inverse_curve_part(true_d_scaled, scaler_curve, d_slice)
|
||||||
|
pred_p = inverse_curve_part(pred_p_scaled, scaler_curve, p_slice)
|
||||||
|
pred_d = inverse_curve_part(pred_d_scaled, scaler_curve, d_slice)
|
||||||
|
|
||||||
|
summary = {
|
||||||
|
"processed_path": str(processed_path),
|
||||||
|
"model_path": str(model_path),
|
||||||
|
"device": str(device),
|
||||||
|
"checkpoint": {
|
||||||
|
"hidden_dim": int(checkpoint["hidden_dim"]),
|
||||||
|
"n_blocks": int(checkpoint["n_blocks"]),
|
||||||
|
"dropout": float(checkpoint["dropout"]),
|
||||||
|
"use_schedule": bool(checkpoint.get("use_schedule", True)),
|
||||||
|
},
|
||||||
|
"scaled_log_pressure": point_metrics(true_p_scaled, pred_p_scaled),
|
||||||
|
"scaled_log_derivative": point_metrics(true_d_scaled, pred_d_scaled),
|
||||||
|
"raw_log_pressure": point_metrics(true_p, pred_p),
|
||||||
|
"raw_log_derivative": point_metrics(true_d, pred_d),
|
||||||
|
}
|
||||||
|
|
||||||
|
rows = sample_metrics(true_p=true_p, pred_p=pred_p, true_d=true_d, pred_d=pred_d)
|
||||||
|
params = recover_raw_params(data)
|
||||||
|
pso_mask = build_pso_mask(params, args)
|
||||||
|
domain_summary = build_domain_summary(rows, params, pso_mask)
|
||||||
|
summary["pso_domain"] = {
|
||||||
|
"bounds": {
|
||||||
|
"k": [float(args.pso_k_min), float(args.pso_k_max)],
|
||||||
|
"skin": [float(args.pso_skin_min), float(args.pso_skin_max)],
|
||||||
|
"wellboreC": [float(args.pso_wellboreC_min), float(args.pso_wellboreC_max)],
|
||||||
|
"phi": [float(args.pso_phi_min), float(args.pso_phi_max)],
|
||||||
|
"h": [float(args.pso_h_min), float(args.pso_h_max)],
|
||||||
|
},
|
||||||
|
"metrics": domain_summary,
|
||||||
|
}
|
||||||
|
|
||||||
|
case_rows, residual_rows = build_worst_case_rows(
|
||||||
|
sample_rows=rows,
|
||||||
|
params=params,
|
||||||
|
data=data,
|
||||||
|
true_p=true_p,
|
||||||
|
true_d=true_d,
|
||||||
|
pred_p=pred_p,
|
||||||
|
pred_d=pred_d,
|
||||||
|
pso_mask=pso_mask,
|
||||||
|
top_k=int(args.top_k_analysis),
|
||||||
|
)
|
||||||
|
worst_case_summary = build_worst_case_summary(
|
||||||
|
sample_rows=rows,
|
||||||
|
params=params,
|
||||||
|
pso_mask=pso_mask,
|
||||||
|
top_k=int(args.top_k_analysis),
|
||||||
|
)
|
||||||
|
|
||||||
|
(output_dir / "summary_metrics.json").write_text(
|
||||||
|
json.dumps(summary, indent=2, ensure_ascii=False),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
write_csv(output_dir / "sample_metrics.csv", rows)
|
||||||
|
write_csv(output_dir / "worst_case_analysis.csv", case_rows)
|
||||||
|
write_csv(output_dir / "worst_residual_analysis.csv", residual_rows)
|
||||||
|
(output_dir / "worst_case_summary.json").write_text(
|
||||||
|
json.dumps(worst_case_summary, indent=2, ensure_ascii=False),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
residual_summary = {
|
||||||
|
"top_k": int(len(residual_rows)),
|
||||||
|
"top300_p_shift_ratio_median": float(np.median([r["p_shift_ratio"] for r in residual_rows])),
|
||||||
|
"top300_d_shift_ratio_median": float(np.median([r["d_shift_ratio"] for r in residual_rows])),
|
||||||
|
"top100_p_shift_ratio_median": float(np.median([r["p_shift_ratio"] for r in residual_rows[:100]])),
|
||||||
|
"top100_d_shift_ratio_median": float(np.median([r["d_shift_ratio"] for r in residual_rows[:100]])),
|
||||||
|
"top20": residual_rows[:20],
|
||||||
|
}
|
||||||
|
(output_dir / "worst_residual_summary.json").write_text(
|
||||||
|
json.dumps(residual_summary, indent=2, ensure_ascii=False),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
t_curve = np.asarray(data.get("T_curve_test"), dtype=np.float32)
|
||||||
|
if t_curve.ndim == 2 and t_curve.shape == true_p.shape:
|
||||||
|
write_plots(output_dir, rows, t_curve, true_p, true_d, pred_p, pred_d, args)
|
||||||
|
|
||||||
|
print("\nEvaluation complete.")
|
||||||
|
print(f"raw_log_pressure RMSE={summary['raw_log_pressure']['rmse']:.6f}, MAE={summary['raw_log_pressure']['mae']:.6f}")
|
||||||
|
print(f"raw_log_derivative RMSE={summary['raw_log_derivative']['rmse']:.6f}, MAE={summary['raw_log_derivative']['mae']:.6f}")
|
||||||
|
print(
|
||||||
|
"PSO-domain: "
|
||||||
|
f"n={domain_summary['pso_domain']['n']}, "
|
||||||
|
f"median={domain_summary['pso_domain']['score']['median']:.6f}, "
|
||||||
|
f"p95={domain_summary['pso_domain']['score']['p95']:.6f}, "
|
||||||
|
f"score>1={100.0 * domain_summary['pso_domain']['score_gt_1_ratio']:.3f}%"
|
||||||
|
)
|
||||||
|
print(f"Artifacts written to: {output_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Reference in New Issue