"""Evaluate deep-ensemble uncertainty quantification for the forward surrogate.

The script loads one checkpoint per seed from ``<model-root>/seed_<seed>/``,
predicts every test sample with every ensemble member, and uses the
across-member mean/standard deviation as prediction/uncertainty. It writes a
JSON summary, a per-sample CSV, an uncertainty-vs-error scatter plot and
detail plots for the most uncertain samples into the output directory.
"""
from __future__ import annotations

import argparse
import csv
import json
import sys
from pathlib import Path

# The project root must be on sys.path *before* the ``src.*`` imports below so
# the script can be launched directly from its own directory.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))

import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch

from src.common.experiment_paths import normalize_tag, processed_path_for_tag
from src.models.forward_surrogate import ForwardSurrogate


def parse_seed_list(seed_text: str) -> list[int]:
    """Parse a comma-separated seed string such as ``"41, 42,43"``.

    Empty items (e.g. from a trailing comma) are skipped.

    Raises:
        ValueError: if an item is not an integer, or no seed remains.
    """
    stripped = (item.strip() for item in str(seed_text).split(","))
    seeds = [int(item) for item in stripped if item]
    if not seeds:
        raise ValueError("至少需要一个 seed")
    return seeds


def _schedule_suffix(use_schedule: bool) -> str:
    """Return the directory-name suffix marking runs without a schedule."""
    return "" if use_schedule else "_no_schedule"


def default_model_root(tag: str | None, use_schedule: bool) -> Path:
    """Return the default ensemble model directory for ``tag``."""
    suffix = _schedule_suffix(use_schedule)
    if tag:
        return Path("models") / f"forward_surrogate_{tag}_ensemble{suffix}"
    return Path("models") / f"forward_surrogate_ensemble{suffix}"


def default_output_dir(tag: str | None, use_schedule: bool) -> Path:
    """Return the default evaluation output directory for ``tag``."""
    suffix = _schedule_suffix(use_schedule)
    if tag:
        return Path("results") / f"evaluation_{tag}_ensemble_uq{suffix}"
    return Path("results") / f"evaluation_ensemble_uq{suffix}"


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Evaluate deep-ensemble UQ for forward surrogate")
    parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
    parser.add_argument("--tag", type=str, default=None, help="Experiment tag")
    parser.add_argument("--model-root", type=str, default=None, help="Root dir that contains seed_* members")
    parser.add_argument("--output-dir", type=str, default=None, help="Evaluation output dir")
    parser.add_argument("--seeds", type=str, default="41,42,43", help="Comma-separated seed list")
    parser.add_argument("--no-schedule", action="store_true")
    parser.add_argument("--n-top-uncertain-plots", type=int, default=5)
    return parser.parse_args()


def calc_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    eps_range: float = 1e-3,
    eps_var: float = 1e-6,
) -> dict:
    """Compute error metrics of ``y_pred`` against ``y_true``.

    Args:
        y_true: reference values.
        y_pred: predicted values, broadcast-compatible with ``y_true``.
        eps_range: the value range of ``y_true`` must exceed this for NRMSE
            to be reported.
        eps_var: the total sum of squares of ``y_true`` must exceed this for
            R2 to be reported.

    Returns:
        Dict with ``rmse``, ``mae``, ``bias``, ``abs_bias``, ``nrmse``, ``r2``
        and the flags ``valid_nrmse`` / ``valid_r2``. ``nrmse`` and ``r2`` are
        NaN when their flag is False (near-constant ``y_true``).
    """
    err = y_pred - y_true
    mse = np.mean(err**2)
    rmse = float(np.sqrt(mse))
    mae = float(np.mean(np.abs(err)))
    bias = float(np.mean(err))

    value_range = float(np.max(y_true) - np.min(y_true))
    ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2))
    ss_res = float(np.sum(err**2))

    # Guard the normalised metrics against division by (almost) zero.
    valid_nrmse = value_range > eps_range
    valid_r2 = ss_tot > eps_var
    nrmse = float(rmse / value_range) if valid_nrmse else np.nan
    r2 = float(1.0 - ss_res / ss_tot) if valid_r2 else np.nan

    return {
        "rmse": rmse,
        "mae": mae,
        "bias": bias,
        "abs_bias": float(abs(bias)),
        "nrmse": nrmse,
        "r2": r2,
        "valid_nrmse": bool(valid_nrmse),
        "valid_r2": bool(valid_r2),
    }


def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the curve layout stored in ``meta`` or a three-part fallback.

    If ``meta`` has a non-None ``"curve_layout"`` entry it is returned as is.
    Otherwise the curve is assumed to consist of three consecutive segments
    of ``curve_dim // 3`` points each: ``log_pressure``, ``log_derivative``
    and ``slope``. When ``curve_dim`` is not divisible by 3 the trailing
    remainder is not covered by any part.
    """
    curve_layout = meta.get("curve_layout")
    if curve_layout is not None:
        return curve_layout

    n_time_points = curve_dim // 3
    return {
        "n_time_points": int(n_time_points),
        "parts": [
            {"name": "log_pressure", "start": 0, "end": n_time_points},
            {"name": "log_derivative", "start": n_time_points, "end": 2 * n_time_points},
            {"name": "slope", "start": 2 * n_time_points, "end": 3 * n_time_points},
        ],
    }


def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]:
    """Slice ``curve`` into named parts according to ``layout["parts"]``.

    Each part contributes ``curve[start:end]`` under its ``name``.
    """
    return {
        str(part["name"]): curve[int(part["start"]) : int(part["end"])]
        for part in layout["parts"]
    }


def load_member(checkpoint_path: Path, device: torch.device) -> tuple[ForwardSurrogate, dict]:
    """Load one ensemble member from ``checkpoint_path`` in eval mode.

    The checkpoint must provide ``param_dim``, ``schedule_dim``,
    ``curve_dim``, ``hidden_dim``, ``dropout`` and ``model_state_dict``;
    ``use_schedule`` defaults to True when absent.

    Returns:
        The model (moved to ``device``, ``eval()`` applied) and the raw
        checkpoint dict.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model = ForwardSurrogate(
        param_dim=int(checkpoint["param_dim"]),
        schedule_dim=int(checkpoint["schedule_dim"]),
        curve_dim=int(checkpoint["curve_dim"]),
        hidden_dim=int(checkpoint["hidden_dim"]),
        dropout=float(checkpoint["dropout"]),
        use_schedule=bool(checkpoint.get("use_schedule", True)),
    ).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, checkpoint


def safe_mean(x: np.ndarray) -> float:
    """Return the mean of ``x``, or NaN for an empty array."""
    return float(np.mean(x)) if x.size else np.nan


def safe_median(x: np.ndarray) -> float:
    """Return the median of ``x``, or NaN for an empty array."""
    return float(np.median(x)) if x.size else np.nan


def safe_percentile(x: np.ndarray, q: float) -> float:
    """Return the ``q``-th percentile of ``x``, or NaN for an empty array."""
    return float(np.percentile(x, q)) if x.size else np.nan


def pearson_corr(x: np.ndarray, y: np.ndarray) -> float:
    """Return the Pearson correlation of ``x`` and ``y``.

    NaN is returned when either input is empty or (numerically) constant,
    because the coefficient is undefined in those cases.
    """
    if x.size == 0 or y.size == 0:
        return np.nan
    if np.allclose(np.std(x), 0.0) or np.allclose(np.std(y), 0.0):
        return np.nan
    return float(np.corrcoef(x, y)[0, 1])


def plot_uncertainty_scatter(sample_rows: list[dict], output_path: Path) -> None:
    """Save a scatter plot of per-sample uncertainty against overall RMSE.

    Reads ``unc_mean_std`` and ``overall_rmse`` from every row and puts their
    Pearson correlation in the title.
    """
    rmse = np.array([row["overall_rmse"] for row in sample_rows], dtype=np.float64)
    unc = np.array([row["unc_mean_std"] for row in sample_rows], dtype=np.float64)

    fig, ax = plt.subplots(figsize=(7, 5))
    try:
        ax.scatter(unc, rmse, s=10, alpha=0.35)
        ax.set_xlabel("Predictive Uncertainty (mean std)")
        ax.set_ylabel("Overall RMSE")
        ax.set_title(f"Uncertainty vs Error | Pearson={pearson_corr(unc, rmse):.4f}")
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)


def plot_uncertain_sample(
    idx: int,
    curve_true: np.ndarray,
    curve_mean: np.ndarray,
    curve_std: np.ndarray,
    curve_layout: dict,
    output_dir: Path,
    unc_score: float,
    rmse: float,
) -> None:
    """Plot true curve, ensemble mean and a +/- 2 std band for one sample.

    One subplot is drawn for each of the parts ``log_pressure``,
    ``log_derivative`` and ``slope`` (all three must exist in
    ``curve_layout``). The figure is saved as
    ``output_dir / "top_uncertain_sample_<idx>.png"``.
    """
    true_parts = split_curve_by_layout(curve_true, curve_layout)
    mean_parts = split_curve_by_layout(curve_mean, curve_layout)
    std_parts = split_curve_by_layout(curve_std, curve_layout)

    title_map = {
        "log_pressure": "Log Pressure",
        "log_derivative": "Log |Derivative|",
        "slope": "Slope of Log Pressure vs Log Time",
    }

    fig, axes = plt.subplots(3, 1, figsize=(12, 10))
    try:
        fig.suptitle(
            f"High-Uncertainty Sample #{idx} | unc_mean_std={unc_score:.4f}, overall_rmse={rmse:.4f}"
        )
        for ax, name in zip(axes, ["log_pressure", "log_derivative", "slope"]):
            y_true = true_parts[name]
            y_mean = mean_parts[name]
            y_std = std_parts[name]
            x = np.arange(len(y_true))
            ax.plot(x, y_true, label="True", linewidth=2)
            ax.plot(x, y_mean, label="Ensemble mean", linewidth=2)
            ax.fill_between(x, y_mean - 2.0 * y_std, y_mean + 2.0 * y_std, alpha=0.2, label="mean ± 2 std")
            ax.set_title(title_map[name])
            ax.grid(True, alpha=0.3)
            ax.legend()
        # Leave headroom at the top for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(output_dir / f"top_uncertain_sample_{idx:04d}.png", dpi=150, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)


def _load_ensemble(
    model_root: Path,
    seeds: list[int],
    device: torch.device,
) -> tuple[list[tuple[int, ForwardSurrogate]], list[str], bool]:
    """Load every ``seed_<seed>/forward_surrogate_best.pt`` under ``model_root``.

    Returns:
        ``(members, member_paths, use_schedule)`` where ``members`` is a list
        of ``(seed, model)`` pairs and ``use_schedule`` is the setting shared
        by all checkpoints.

    Raises:
        RuntimeError: if the members disagree on ``use_schedule``.
    """
    members: list[tuple[int, ForwardSurrogate]] = []
    member_paths: list[str] = []
    shared_use_schedule: bool | None = None

    for seed in seeds:
        ckpt_path = model_root / f"seed_{seed}" / "forward_surrogate_best.pt"
        model, checkpoint = load_member(ckpt_path, device)
        member_paths.append(str(ckpt_path))
        members.append((seed, model))

        cur_use_schedule = bool(checkpoint.get("use_schedule", True))
        if shared_use_schedule is None:
            shared_use_schedule = cur_use_schedule
        elif shared_use_schedule != cur_use_schedule:
            raise RuntimeError("Ensemble 成员的 use_schedule 设置不一致")

    return members, member_paths, bool(shared_use_schedule)


def main() -> None:
    """Run the ensemble UQ evaluation end to end (see module docstring)."""
    args = parse_args()
    tag = normalize_tag(args.tag)
    # The CLI flag only selects the default directories; whether the schedule
    # is fed to the models is decided by the checkpoints themselves.
    use_schedule = not args.no_schedule
    seeds = parse_seed_list(args.seeds)

    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    model_root = Path(args.model_root) if args.model_root is not None else default_model_root(tag, use_schedule)
    output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag, use_schedule)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): joblib.load unpickles the file; only load datasets from
    # trusted sources.
    data = joblib.load(processed_path)
    x_params_test = data["X_params_test"]
    x_schedule_test = data["X_schedule_test"]
    y_curve_test = data["Y_curve_test"]
    scaler_curve = data["scaler_curve"]
    curve_layout = infer_curve_layout(data["meta"], int(data["meta"]["curve_dim"]))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    members, member_paths, ensemble_uses_schedule = _load_ensemble(model_root, seeds, device)

    n_samples = len(x_params_test)
    if n_samples == 0:
        # Everything below (stacking, CSV header, plots) needs >= 1 sample.
        raise ValueError(f"Test split in {processed_path} is empty; nothing to evaluate")

    all_true = []
    all_mean = []
    all_std = []
    sample_rows = []
    sample_metrics = []

    with torch.no_grad():
        for idx in range(n_samples):
            params_t = torch.tensor(x_params_test[idx : idx + 1], dtype=torch.float32, device=device)
            # Only build the schedule tensor when the ensemble consumes it.
            schedule_t = (
                torch.tensor(x_schedule_test[idx : idx + 1], dtype=torch.float32, device=device)
                if ensemble_uses_schedule
                else None
            )

            member_preds = []
            for _, model in members:
                pred_scaled = model(params_t, schedule_t).cpu().numpy()
                pred = scaler_curve.inverse_transform(pred_scaled)[0].astype(np.float32)
                member_preds.append(pred)
            member_preds = np.stack(member_preds, axis=0)

            curve_true = scaler_curve.inverse_transform(y_curve_test[idx : idx + 1])[0].astype(np.float32)
            # Ensemble mean is the prediction; the population std across
            # members (ddof=0) is the uncertainty estimate.
            curve_mean = member_preds.mean(axis=0).astype(np.float32)
            curve_std = member_preds.std(axis=0, ddof=0).astype(np.float32)

            metrics = calc_metrics(curve_true, curve_mean)
            parts_std = split_curve_by_layout(curve_std, curve_layout)
            sample_metrics.append(metrics)
            sample_rows.append(
                {
                    "idx": idx,
                    "overall_rmse": metrics["rmse"],
                    "overall_mae": metrics["mae"],
                    "overall_bias": metrics["bias"],
                    "overall_r2": metrics["r2"],
                    "unc_mean_std": float(np.mean(curve_std)),
                    "unc_max_std": float(np.max(curve_std)),
                    "unc_log_pressure_mean_std": float(np.mean(parts_std["log_pressure"])),
                    "unc_log_derivative_mean_std": float(np.mean(parts_std["log_derivative"])),
                    "unc_slope_mean_std": float(np.mean(parts_std["slope"])),
                }
            )
            all_true.append(curve_true)
            all_mean.append(curve_mean)
            all_std.append(curve_std)

    all_true = np.stack(all_true, axis=0)
    all_mean = np.stack(all_mean, axis=0)
    all_std = np.stack(all_std, axis=0)

    # Reuse the per-sample metrics computed in the loop instead of running
    # calc_metrics over every sample a second time.
    rmse = np.array([m["rmse"] for m in sample_metrics], dtype=np.float64)
    mae = np.array([m["mae"] for m in sample_metrics], dtype=np.float64)
    r2_valid = np.array([m["r2"] for m in sample_metrics if m["valid_r2"]], dtype=np.float64)
    unc = np.array([row["unc_mean_std"] for row in sample_rows], dtype=np.float64)

    summary = {
        "ensemble": {
            "member_count": len(members),
            "member_paths": member_paths,
            "use_schedule": bool(ensemble_uses_schedule),
            "processed_path": str(processed_path),
        },
        "prediction": {
            "rmse_mean": safe_mean(rmse),
            "rmse_median": safe_median(rmse),
            "rmse_p90": safe_percentile(rmse, 90),
            "mae_mean": safe_mean(mae),
            "mae_median": safe_median(mae),
            "r2_mean_valid": safe_mean(r2_valid),
            "r2_median_valid": safe_median(r2_valid),
        },
        "uncertainty": {
            "unc_mean": safe_mean(unc),
            "unc_median": safe_median(unc),
            "unc_p90": safe_percentile(unc, 90),
            "unc_vs_rmse_pearson": pearson_corr(unc, rmse),
        },
    }

    # NOTE: undefined statistics are NaN and json.dump writes them as the
    # non-standard token ``NaN`` (allow_nan defaults to True).
    with open(output_dir / "ensemble_uq_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    with open(output_dir / "sample_uncertainty_metrics.csv", "w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=list(sample_rows[0].keys()))
        writer.writeheader()
        writer.writerows(sample_rows)

    plot_uncertainty_scatter(sample_rows, output_dir / "uncertainty_vs_error.png")

    # Clamp at zero: a negative count would otherwise turn the slice below
    # into "all but the last k" and plot almost every sample.
    top_k = max(0, min(args.n_top_uncertain_plots, len(sample_rows)))
    top_uncertain = sorted(sample_rows, key=lambda row: row["unc_mean_std"], reverse=True)[:top_k]
    for row in top_uncertain:
        idx = int(row["idx"])
        plot_uncertain_sample(
            idx=idx,
            curve_true=all_true[idx],
            curve_mean=all_mean[idx],
            curve_std=all_std[idx],
            curve_layout=curve_layout,
            output_dir=output_dir,
            unc_score=float(row["unc_mean_std"]),
            rmse=float(row["overall_rmse"]),
        )

    print("Ensemble UQ evaluation complete.")
    print(f"Output dir: {output_dir}")
    print(f"RMSE mean={summary['prediction']['rmse_mean']:.6f}")
    print(f"UQ-RMSE Pearson={summary['uncertainty']['unc_vs_rmse_pearson']:.6f}")


if __name__ == "__main__":
    main()