from __future__ import annotations import argparse import csv import json import random import sys from pathlib import Path import joblib import matplotlib.pyplot as plt import numpy as np import torch ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import ( evaluation_dir_for_tag, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag, ) from src.models.forward_surrogate import ForwardSurrogate DEFAULT_RANDOM_SEED = 42 def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Evaluate forward surrogate model") parser.add_argument( "--processed", type=str, default=None, help="Processed dataset path", ) parser.add_argument( "--model", type=str, default=None, help="Model checkpoint path", ) parser.add_argument( "--fit-processed", type=str, default=None, help="Processed dataset used to fit scalers for the evaluated model; required for cross-dataset evaluation", ) parser.add_argument( "--output-dir", type=str, default=None, help="Optional evaluation output directory", ) parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming") parser.add_argument( "--no-schedule", action="store_true", help="When --model is omitted, infer the no-schedule checkpoint path", ) parser.add_argument("--seed", type=int, default=DEFAULT_RANDOM_SEED) parser.add_argument("--n-random-plots", type=int, default=5) parser.add_argument("--n-best-plots", type=int, default=5) parser.add_argument("--n-worst-plots", type=int, default=5) return parser.parse_args() def calc_metrics( y_true: np.ndarray, y_pred: np.ndarray, eps_range: float = 1e-3, eps_var: float = 1e-6, ) -> dict: err = y_pred - y_true mse = np.mean(err**2) rmse = float(np.sqrt(mse)) mae = float(np.mean(np.abs(err))) bias = float(np.mean(err)) value_range = float(np.max(y_true) - np.min(y_true)) ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2)) ss_res = float(np.sum(err**2)) valid_nrmse = value_range > eps_range valid_r2 = ss_tot > eps_var nrmse = float(rmse / value_range) if valid_nrmse else np.nan r2 = float(1.0 - ss_res / ss_tot) if valid_r2 else np.nan return { "rmse": rmse, "mae": mae, "bias": bias, "abs_bias": float(abs(bias)), "nrmse": nrmse, "r2": r2, "range_true": value_range, "ss_tot": ss_tot, "valid_nrmse": bool(valid_nrmse), "valid_r2": bool(valid_r2), } def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]: parts: dict[str, np.ndarray] = {} for part in layout["parts"]: start = int(part["start"]) end = int(part["end"]) parts[str(part["name"])] = curve[start:end] return parts def safe_percentile(x: np.ndarray, q: float) -> float: if x.size == 0: return np.nan return float(np.percentile(x, q)) def safe_mean(x: np.ndarray) -> float: if x.size == 0: return np.nan return float(np.mean(x)) def safe_median(x: np.ndarray) -> float: if x.size == 0: return np.nan return float(np.median(x)) def summarize_metric_dicts(metric_dicts: list[dict], prefix: str) -> dict: rmse = np.array([m["rmse"] for m in metric_dicts], dtype=np.float64) mae = np.array([m["mae"] for m in metric_dicts], dtype=np.float64) bias = np.array([m["bias"] for m in metric_dicts], dtype=np.float64) abs_bias = np.array([m["abs_bias"] for m in metric_dicts], dtype=np.float64) ranges = np.array([m["range_true"] for m in metric_dicts], dtype=np.float64) ss_tot = np.array([m["ss_tot"] for m in metric_dicts], dtype=np.float64) nrmse_valid = np.array([m["nrmse"] for m in metric_dicts if m["valid_nrmse"]], dtype=np.float64) r2_valid = np.array([m["r2"] for m in metric_dicts if m["valid_r2"]], dtype=np.float64) valid_nrmse_ratio = len(nrmse_valid) / max(len(metric_dicts), 1) valid_r2_ratio = len(r2_valid) / max(len(metric_dicts), 1) low_range_count = int(np.sum(ranges <= 1e-3)) low_var_count = int(np.sum(ss_tot <= 1e-6)) print(f"\n[{prefix}]") print( f" RMSE mean={safe_mean(rmse):.6f}, median={safe_median(rmse):.6f}, " f"p90={safe_percentile(rmse, 90):.6f}" ) print( f" MAE mean={safe_mean(mae):.6f}, median={safe_median(mae):.6f}, " f"p90={safe_percentile(mae, 90):.6f}" ) print( f" Bias mean={safe_mean(bias):.6f}, median={safe_median(bias):.6f}, " f"|mean|={safe_mean(abs_bias):.6f}" ) print( f" TrueRange mean={safe_mean(ranges):.6f}, median={safe_median(ranges):.6f}, " f"p10={safe_percentile(ranges, 10):.6f}" ) print( f" NRMSE(valid only) mean={safe_mean(nrmse_valid):.6f}, " f"median={safe_median(nrmse_valid):.6f}, p90={safe_percentile(nrmse_valid, 90):.6f}, " f"valid_ratio={valid_nrmse_ratio:.4f}, low_range_count={low_range_count}" ) print( f" R2(valid only) mean={safe_mean(r2_valid):.6f}, " f"median={safe_median(r2_valid):.6f}, p10={safe_percentile(r2_valid, 10):.6f}, " f"valid_ratio={valid_r2_ratio:.4f}, low_var_count={low_var_count}" ) return { "rmse_mean": safe_mean(rmse), "rmse_median": safe_median(rmse), "rmse_p90": safe_percentile(rmse, 90), "mae_mean": safe_mean(mae), "mae_median": safe_median(mae), "mae_p90": safe_percentile(mae, 90), "bias_mean": safe_mean(bias), "bias_median": safe_median(bias), "abs_bias_mean": safe_mean(abs_bias), "nrmse_mean_valid": safe_mean(nrmse_valid), "nrmse_median_valid": safe_median(nrmse_valid), "nrmse_p90_valid": safe_percentile(nrmse_valid, 90), "nrmse_valid_ratio": valid_nrmse_ratio, "r2_mean_valid": safe_mean(r2_valid), "r2_median_valid": safe_median(r2_valid), "r2_p10_valid": safe_percentile(r2_valid, 10), "r2_valid_ratio": valid_r2_ratio, "low_range_count": low_range_count, "low_var_count": low_var_count, } def build_composite_score(overall_m: dict, part_ms: dict) -> float: return float( 1.0 * overall_m["rmse"] + 0.5 * overall_m["mae"] + 0.8 * part_ms["log_pressure"]["abs_bias"] + 0.8 * part_ms["log_derivative"]["rmse"] + 0.2 * part_ms["slope"]["rmse"] ) def plot_sample( idx: int, curve_true: np.ndarray, curve_pred: np.ndarray, curve_layout: dict, output_dir: Path, title_prefix: str, ) -> None: true_parts = split_curve_by_layout(curve_true, curve_layout) pred_parts = split_curve_by_layout(curve_pred, curve_layout) overall = calc_metrics(curve_true, curve_pred) fig, axes = plt.subplots(3, 2, figsize=(14, 12)) fig.suptitle( f"{title_prefix} | Sample #{idx} | " f"Overall RMSE={overall['rmse']:.4f}, MAE={overall['mae']:.4f}, " f"Bias={overall['bias']:.4f}, " f"NRMSE={'nan' if np.isnan(overall['nrmse']) else f'{overall['nrmse']:.4f}'}, " f"R2={'nan' if np.isnan(overall['r2']) else f'{overall['r2']:.4f}'}" ) plot_order = ["log_pressure", "log_derivative", "slope"] title_map = { "log_pressure": "Log Pressure", "log_derivative": "Log |Derivative|", "slope": "Slope of Log Pressure vs Log Time", } for row, name in enumerate(plot_order): y_true = true_parts[name] y_pred = pred_parts[name] err = y_pred - y_true x = np.arange(len(y_true)) m = calc_metrics(y_true, y_pred) nrmse_text = "nan" if np.isnan(m["nrmse"]) else f"{m['nrmse']:.4f}" r2_text = "nan" if np.isnan(m["r2"]) else f"{m['r2']:.4f}" ax_l = axes[row, 0] ax_l.plot(x, y_true, label="True", linewidth=2, alpha=0.85) ax_l.plot(x, y_pred, label="Predicted", linewidth=2, alpha=0.85) ax_l.set_title( f"{title_map[name]} | RMSE={m['rmse']:.4f}, MAE={m['mae']:.4f}, " f"Bias={m['bias']:.4f}, NRMSE={nrmse_text}, R2={r2_text}" ) ax_l.set_xlabel("Resampled Time Index") ax_l.set_ylabel("Value") ax_l.grid(True, alpha=0.3) ax_l.legend() ax_r = axes[row, 1] ax_r.plot(x, err, linewidth=1.5, alpha=0.85) ax_r.axhline(0.0, linestyle="--", linewidth=1) ax_r.set_title(f"{title_map[name]} Error") ax_r.set_xlabel("Resampled Time Index") ax_r.set_ylabel("Pred - True") ax_r.grid(True, alpha=0.3) plt.tight_layout(rect=[0, 0, 1, 0.97]) save_path = output_dir / f"{title_prefix.lower()}_sample_{idx:04d}.png" plt.savefig(save_path, dpi=150, bbox_inches="tight") plt.close() print(f"Plot saved: {save_path}") def save_sample_metrics_csv(sample_scores: list[dict], path: Path) -> None: if not sample_scores: return fieldnames = list(sample_scores[0].keys()) with open(path, "w", newline="", encoding="utf-8-sig") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() writer.writerows(sample_scores) print(f"Sample metrics saved: {path}") def infer_curve_layout(meta: dict, curve_dim: int) -> dict: curve_layout = meta.get("curve_layout") if curve_layout is not None: return curve_layout n_time_points = curve_dim // 3 return { "n_time_points": int(n_time_points), "parts": [ {"name": "log_pressure", "start": 0, "end": n_time_points}, {"name": "log_derivative", "start": n_time_points, "end": 2 * n_time_points}, {"name": "slope", "start": 2 * n_time_points, "end": 3 * n_time_points}, ], } def choose_output_dir(arg_output_dir: str | None, tag: str | None, use_schedule: bool) -> Path: if arg_output_dir is not None: return Path(arg_output_dir) return evaluation_dir_for_tag(tag, use_schedule) def model_predict( model: ForwardSurrogate, params_x: torch.Tensor, schedule_x: torch.Tensor, use_schedule: bool, ) -> torch.Tensor: if use_schedule: return model(params_x, schedule_x) return model(params_x, None) def prepare_eval_arrays(eval_data: dict, fit_data: dict | None) -> tuple[np.ndarray, np.ndarray, np.ndarray, object]: x_params_test = np.asarray(eval_data["X_params_test"], dtype=np.float32) x_schedule_test = np.asarray(eval_data["X_schedule_test"], dtype=np.float32) y_curve_test = np.asarray(eval_data["Y_curve_test"], dtype=np.float32) eval_scaler_params = eval_data["scaler_params"] eval_scaler_schedule = eval_data["scaler_schedule"] eval_scaler_curve = eval_data["scaler_curve"] if fit_data is None: return x_params_test, x_schedule_test, y_curve_test, eval_scaler_curve fit_scaler_params = fit_data["scaler_params"] fit_scaler_schedule = fit_data["scaler_schedule"] fit_scaler_curve = fit_data["scaler_curve"] eval_param_transform = (eval_data.get("meta", {}) or {}).get("param_feature_transform") fit_param_transform = (fit_data.get("meta", {}) or {}).get("param_feature_transform") if eval_param_transform != fit_param_transform: raise ValueError( "Cross-dataset evaluation requires matching param_feature_transform metadata. " "Re-preprocess both datasets with the same transform setting." ) # Recover parameter features in the benchmark dataset's fitted scale, # then map them into the feature scale expected by the trained model. x_params_raw = eval_scaler_params.inverse_transform(x_params_test) x_schedule_raw = eval_scaler_schedule.inverse_transform(x_schedule_test) x_params_fit = fit_scaler_params.transform(x_params_raw).astype(np.float32) x_schedule_fit = fit_scaler_schedule.transform(x_schedule_raw).astype(np.float32) y_true_raw = eval_scaler_curve.inverse_transform(y_curve_test).astype(np.float32) return x_params_fit, x_schedule_fit, y_true_raw, fit_scaler_curve def main() -> None: args = parse_args() tag = normalize_tag(args.tag) processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) model_path = ( Path(args.model) if args.model is not None else model_checkpoint_for_tag(tag, use_schedule=not args.no_schedule) ) print("Loading processed dataset...") data = joblib.load(processed_path) fit_processed_path = Path(args.fit_processed) if args.fit_processed is not None else None fit_data = joblib.load(fit_processed_path) if fit_processed_path is not None else None x_params_test, x_schedule_test, y_curve_test, pred_scaler_curve = prepare_eval_arrays(data, fit_data) meta = data["meta"] param_dim = int(meta["param_dim"]) schedule_dim = int(meta["schedule_dim"]) curve_dim = int(meta["curve_dim"]) curve_layout = infer_curve_layout(meta, curve_dim) print(f"Test size: {len(x_params_test)}") print(f"param_dim={param_dim}, schedule_dim={schedule_dim}, curve_dim={curve_dim}") print(f"curve_layout={curve_layout}") print("Loading trained model...") checkpoint = torch.load(model_path, map_location="cpu") use_schedule = bool(checkpoint.get("use_schedule", True)) model = ForwardSurrogate( param_dim=int(checkpoint["param_dim"]), schedule_dim=int(checkpoint["schedule_dim"]), curve_dim=int(checkpoint["curve_dim"]), hidden_dim=int(checkpoint["hidden_dim"]), dropout=float(checkpoint["dropout"]), use_schedule=use_schedule, ) model.load_state_dict(checkpoint["model_state_dict"]) model.eval() device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) output_dir = choose_output_dir(args.output_dir, tag, use_schedule) output_dir.mkdir(parents=True, exist_ok=True) print(f"Using device: {device}") print(f"use_schedule={use_schedule}") print(f"output_dir={output_dir}") if fit_processed_path is not None: print(f"fit_processed={fit_processed_path}") random.seed(args.seed) np.random.seed(args.seed) torch.manual_seed(args.seed) all_true = [] all_pred = [] print("\nRunning inference on full test split...") with torch.no_grad(): for idx in range(len(x_params_test)): params_x = torch.tensor(x_params_test[idx : idx + 1], dtype=torch.float32).to(device) schedule_x = torch.tensor(x_schedule_test[idx : idx + 1], dtype=torch.float32).to(device) curve_pred_scaled = model_predict(model, params_x, schedule_x, use_schedule).cpu().numpy() if fit_processed_path is None: curve_true = pred_scaler_curve.inverse_transform(y_curve_test[idx : idx + 1])[0] else: curve_true = y_curve_test[idx] curve_pred = pred_scaler_curve.inverse_transform(curve_pred_scaled)[0] all_true.append(curve_true.astype(np.float32)) all_pred.append(curve_pred.astype(np.float32)) all_true = np.stack(all_true, axis=0) all_pred = np.stack(all_pred, axis=0) print("Inference complete.") overall_metric_list: list[dict] = [] part_metric_lists = { "log_pressure": [], "log_derivative": [], "slope": [], } sample_scores: list[dict] = [] for idx in range(len(all_true)): curve_true = all_true[idx] curve_pred = all_pred[idx] overall_m = calc_metrics(curve_true, curve_pred) overall_metric_list.append(overall_m) true_parts = split_curve_by_layout(curve_true, curve_layout) pred_parts = split_curve_by_layout(curve_pred, curve_layout) part_ms: dict[str, dict] = {} for name in ["log_pressure", "log_derivative", "slope"]: part_m = calc_metrics(true_parts[name], pred_parts[name]) part_metric_lists[name].append(part_m) part_ms[name] = part_m sample_scores.append( { "idx": idx, "composite_score": build_composite_score(overall_m, part_ms), "overall_rmse": overall_m["rmse"], "overall_mae": overall_m["mae"], "overall_bias": overall_m["bias"], "overall_abs_bias": overall_m["abs_bias"], "overall_nrmse": overall_m["nrmse"], "overall_r2": overall_m["r2"], "log_pressure_rmse": part_ms["log_pressure"]["rmse"], "log_pressure_mae": part_ms["log_pressure"]["mae"], "log_pressure_bias": part_ms["log_pressure"]["bias"], "log_pressure_abs_bias": part_ms["log_pressure"]["abs_bias"], "log_pressure_nrmse": part_ms["log_pressure"]["nrmse"], "log_pressure_r2": part_ms["log_pressure"]["r2"], "log_derivative_rmse": part_ms["log_derivative"]["rmse"], "log_derivative_mae": part_ms["log_derivative"]["mae"], "log_derivative_bias": part_ms["log_derivative"]["bias"], "log_derivative_abs_bias": part_ms["log_derivative"]["abs_bias"], "log_derivative_nrmse": part_ms["log_derivative"]["nrmse"], "log_derivative_r2": part_ms["log_derivative"]["r2"], "slope_rmse": part_ms["slope"]["rmse"], "slope_mae": part_ms["slope"]["mae"], "slope_bias": part_ms["slope"]["bias"], "slope_abs_bias": part_ms["slope"]["abs_bias"], "slope_nrmse": part_ms["slope"]["nrmse"], "slope_r2": part_ms["slope"]["r2"], "overall_valid_nrmse": overall_m["valid_nrmse"], "overall_valid_r2": overall_m["valid_r2"], "log_pressure_valid_nrmse": part_ms["log_pressure"]["valid_nrmse"], "log_pressure_valid_r2": part_ms["log_pressure"]["valid_r2"], "log_derivative_valid_nrmse": part_ms["log_derivative"]["valid_nrmse"], "log_derivative_valid_r2": part_ms["log_derivative"]["valid_r2"], "slope_valid_nrmse": part_ms["slope"]["valid_nrmse"], "slope_valid_r2": part_ms["slope"]["valid_r2"], } ) print("\n" + "=" * 60) print("Full test split summary") summary = { "overall": summarize_metric_dicts(overall_metric_list, "overall"), "log_pressure": summarize_metric_dicts(part_metric_lists["log_pressure"], "log_pressure"), "log_derivative": summarize_metric_dicts(part_metric_lists["log_derivative"], "log_derivative"), "slope": summarize_metric_dicts(part_metric_lists["slope"], "slope"), "checkpoint": { "model_path": str(model_path), "processed_path": str(processed_path), "fit_processed_path": str(fit_processed_path) if fit_processed_path is not None else str(processed_path), "use_schedule": use_schedule, }, } with open(output_dir / "summary_metrics.json", "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) print(f"\nSummary saved: {output_dir / 'summary_metrics.json'}") save_sample_metrics_csv(sample_scores, output_dir / "sample_metrics.csv") score_composite = np.array([row["composite_score"] for row in sample_scores], dtype=np.float64) n_random = min(args.n_random_plots, len(sample_scores)) n_best = min(args.n_best_plots, len(sample_scores)) n_worst = min(args.n_worst_plots, len(sample_scores)) best_indices = np.argsort(score_composite)[:n_best].tolist() worst_indices = np.argsort(score_composite)[-n_worst:].tolist() random_indices = random.sample(range(len(sample_scores)), n_random) print("\nBest sample indices:", best_indices) print("Worst sample indices:", worst_indices) print("Random sample indices:", random_indices) for idx in random_indices: plot_sample(idx, all_true[idx], all_pred[idx], curve_layout, output_dir, "random") for idx in best_indices: plot_sample(idx, all_true[idx], all_pred[idx], curve_layout, output_dir, "best") for idx in worst_indices: plot_sample(idx, all_true[idx], all_pred[idx], curve_layout, output_dir, "worst") print("\nArtifacts written to:", output_dir) print("1. summary_metrics.json") print("2. sample_metrics.csv") print("3. best/random/worst sample plots") if __name__ == "__main__": main()