"""Evaluate a deep ensemble of forward surrogate models and estimate predictive uncertainty.

Structurally identical members are loaded for several random seeds. For every test
sample the script computes the ensemble-mean prediction, the across-member standard
deviation and the per-sample error of the mean. It then exports uncertainty-vs-error
correlation statistics, a scatter plot and plots of the most uncertain samples. The
outputs are used to judge whether ensemble spread can serve as a signal for screening
automatic-fitting candidates or for flagging risky predictions.
"""
# pylint: disable=import-error,wrong-import-position
# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments
# pylint: disable=too-many-statements

from __future__ import annotations

import argparse
import csv
import json
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))

import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch

from src.common.experiment_paths import normalize_tag, processed_path_for_tag
from src.models.forward_surrogate import ForwardSurrogate


def parse_seed_list(seed_text: str) -> list[int]:
    """Parse a comma-separated list of random seeds identifying ensemble members.

    Empty items (e.g. from a trailing comma) are skipped.

    Raises:
        ValueError: if an item is not an integer, if no seed is given, or if a seed
            is repeated. A repeated seed would load the same member twice and bias
            both the ensemble mean and the member spread.
    """
    seeds: list[int] = []
    for item in str(seed_text).split(","):
        item = item.strip()
        if not item:
            continue
        seed = int(item)
        if seed in seeds:
            raise ValueError(f"seed 重复: {seed}")
        seeds.append(seed)
    if not seeds:
        raise ValueError("至少需要一个 seed")
    return seeds


def _tagged_dir(base: str, head: str, tail: str, tag: str | None, use_schedule: bool) -> Path:
    """Build ``<base>/<head>[_<tag>]<tail>[_no_schedule]``; an empty tag counts as no tag."""
    suffix = "" if use_schedule else "_no_schedule"
    tag_part = f"_{tag}" if tag else ""
    return Path(base) / f"{head}{tag_part}{tail}{suffix}"


def default_model_root(tag: str | None, use_schedule: bool) -> Path:
    """Return the default root directory holding the ``seed_*`` ensemble members."""
    return _tagged_dir("models", "forward_surrogate", "_ensemble", tag, use_schedule)


def default_output_dir(tag: str | None, use_schedule: bool) -> Path:
    """Return the default output directory for this analysis script."""
    return _tagged_dir("results", "evaluation", "_ensemble_uq", tag, use_schedule)


def parse_args() -> argparse.Namespace:
    """Parse CLI options: dataset, member location, seed list and number of plots."""
    parser = argparse.ArgumentParser(description="Evaluate deep-ensemble UQ for forward surrogate")
    parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
    parser.add_argument("--tag", type=str, default=None, help="Experiment tag")
    parser.add_argument(
        "--model-root",
        type=str,
        default=None,
        help="Root dir that contains seed_* members",
    )
    parser.add_argument("--output-dir", type=str, default=None, help="Evaluation output dir")
    parser.add_argument("--seeds", type=str, default="41,42,43", help="Comma-separated seed list")
    parser.add_argument("--no-schedule", action="store_true")
    parser.add_argument("--n-top-uncertain-plots", type=int, default=5)
    return parser.parse_args()


def calc_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    eps_range: float = 1e-3,
    eps_var: float = 1e-6,
) -> dict:
    """Compute regression metrics (RMSE, MAE, bias, NRMSE, R2) of ``y_pred`` vs ``y_true``.

    Args:
        y_true: Reference values.
        y_pred: Predicted values with the same shape as ``y_true``.
        eps_range: NRMSE is only reported when ``max(y_true) - min(y_true)`` exceeds this.
        eps_var: R2 is only reported when the total sum of squares exceeds this.

    Returns:
        Dict with ``rmse``, ``mae``, ``bias``, ``abs_bias``, ``nrmse``, ``r2`` and the
        ``valid_nrmse`` / ``valid_r2`` flags. NRMSE/R2 are NaN when not valid.
    """
    err = y_pred - y_true
    mse = np.mean(err**2)
    rmse = float(np.sqrt(mse))
    mae = float(np.mean(np.abs(err)))
    bias = float(np.mean(err))
    value_range = float(np.max(y_true) - np.min(y_true))
    ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2))
    ss_res = float(np.sum(err**2))
    # Near-constant targets make NRMSE / R2 meaningless; mark them invalid instead.
    valid_nrmse = value_range > eps_range
    valid_r2 = ss_tot > eps_var
    nrmse = float(rmse / value_range) if valid_nrmse else np.nan
    r2 = float(1.0 - ss_res / ss_tot) if valid_r2 else np.nan
    return {
        "rmse": rmse,
        "mae": mae,
        "bias": bias,
        "abs_bias": float(abs(bias)),
        "nrmse": nrmse,
        "r2": r2,
        "valid_nrmse": bool(valid_nrmse),
        "valid_r2": bool(valid_r2),
    }


def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the curve segment layout stored in the dataset metadata.

    Older processed datasets have no ``curve_layout``. For those, the concatenated
    curve is assumed to be log_pressure / log_derivative / slope in three equal parts.
    NOTE(review): in that fallback, if ``curve_dim`` is not divisible by 3, the
    trailing ``curve_dim % 3`` points belong to no part.
    """
    curve_layout = meta.get("curve_layout")
    if curve_layout is not None:
        return curve_layout

    n_time_points = curve_dim // 3
    return {
        "n_time_points": int(n_time_points),
        "parts": [
            {"name": "log_pressure", "start": 0, "end": n_time_points},
            {"name": "log_derivative", "start": n_time_points, "end": 2 * n_time_points},
            {"name": "slope", "start": 2 * n_time_points, "end": 3 * n_time_points},
        ],
    }


def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]:
    """Slice a concatenated curve into named segments using ``layout["parts"]``."""
    parts: dict[str, np.ndarray] = {}
    for part in layout["parts"]:
        start = int(part["start"])
        end = int(part["end"])
        parts[str(part["name"])] = curve[start:end]
    return parts


def load_member(checkpoint_path: Path, device: torch.device) -> tuple[ForwardSurrogate, dict]:
    """Load one ensemble member from a checkpoint and switch it to eval mode.

    The checkpoint is read onto the CPU first. The model is then rebuilt from the
    hyper-parameters stored in it, moved to ``device`` and has its weights restored.

    Returns:
        ``(model, checkpoint)``. The raw checkpoint dict is returned so callers can
        inspect stored settings such as ``use_schedule``.
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model = ForwardSurrogate(
        param_dim=int(checkpoint["param_dim"]),
        schedule_dim=int(checkpoint["schedule_dim"]),
        curve_dim=int(checkpoint["curve_dim"]),
        hidden_dim=int(checkpoint["hidden_dim"]),
        dropout=float(checkpoint["dropout"]),
        use_schedule=bool(checkpoint.get("use_schedule", True)),
    ).to(device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    return model, checkpoint


def _drop_nan(x: np.ndarray) -> np.ndarray:
    """Return the non-NaN entries of ``x`` as a flat float64 array."""
    values = np.asarray(x, dtype=np.float64).ravel()
    return values[~np.isnan(values)]


def safe_mean(x: np.ndarray) -> float:
    """Mean of ``x`` ignoring NaN; NaN if no valid values remain."""
    values = _drop_nan(x)
    return float(np.mean(values)) if values.size else np.nan


def safe_median(x: np.ndarray) -> float:
    """Median of ``x`` ignoring NaN; NaN if no valid values remain."""
    values = _drop_nan(x)
    return float(np.median(values)) if values.size else np.nan


def safe_percentile(x: np.ndarray, q: float) -> float:
    """``q``-th percentile of ``x`` ignoring NaN; NaN if no valid values remain."""
    values = _drop_nan(x)
    return float(np.percentile(values, q)) if values.size else np.nan


def pearson_corr(x: np.ndarray, y: np.ndarray) -> float:
    """Pearson correlation of ``x`` and ``y`` over pairs where neither value is NaN.

    Returns NaN when either input is empty, fewer than two valid pairs remain, or
    either side has (near-)zero standard deviation.
    """
    x = np.asarray(x, dtype=np.float64).ravel()
    y = np.asarray(y, dtype=np.float64).ravel()
    if x.size == 0 or y.size == 0:
        return np.nan
    keep = ~(np.isnan(x) | np.isnan(y))
    x = x[keep]
    y = y[keep]
    if x.size < 2:
        return np.nan
    if np.allclose(np.std(x), 0.0) or np.allclose(np.std(y), 0.0):
        return np.nan
    return float(np.corrcoef(x, y)[0, 1])


def plot_uncertainty_scatter(sample_rows: list[dict], output_path: Path) -> None:
    """Scatter per-sample uncertainty (mean std) against RMSE to inspect their correlation."""
    rmse = np.array([row["overall_rmse"] for row in sample_rows], dtype=np.float64)
    unc = np.array([row["unc_mean_std"] for row in sample_rows], dtype=np.float64)

    plt.figure(figsize=(7, 5))
    plt.scatter(unc, rmse, s=10, alpha=0.35)
    plt.xlabel("Predictive Uncertainty (mean std)")
    plt.ylabel("Overall RMSE")
    plt.title(f"Uncertainty vs Error | Pearson={pearson_corr(unc, rmse):.4f}")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()


def plot_uncertain_sample(
    idx: int,
    curve_true: np.ndarray,
    curve_mean: np.ndarray,
    curve_std: np.ndarray,
    curve_layout: dict,
    output_dir: Path,
    unc_score: float,
    rmse: float,
) -> None:
    """Plot one high-uncertainty sample: true curve, ensemble mean and a mean ± 2 std band.

    One subplot is drawn per segment (log_pressure, log_derivative, slope). The figure
    is saved as ``top_uncertain_sample_<idx>.png`` in ``output_dir``.
    """
    true_parts = split_curve_by_layout(curve_true, curve_layout)
    mean_parts = split_curve_by_layout(curve_mean, curve_layout)
    std_parts = split_curve_by_layout(curve_std, curve_layout)

    title_map = {
        "log_pressure": "Log Pressure",
        "log_derivative": "Log |Derivative|",
        "slope": "Slope of Log Pressure vs Log Time",
    }

    fig, axes = plt.subplots(3, 1, figsize=(12, 10))
    fig.suptitle(
        f"High-Uncertainty Sample #{idx} | unc_mean_std={unc_score:.4f}, overall_rmse={rmse:.4f}"
    )

    for ax, name in zip(axes, ["log_pressure", "log_derivative", "slope"]):
        y_true = true_parts[name]
        y_mean = mean_parts[name]
        y_std = std_parts[name]
        x = np.arange(len(y_true))

        ax.plot(x, y_true, label="True", linewidth=2)
        ax.plot(x, y_mean, label="Ensemble mean", linewidth=2)
        ax.fill_between(
            x,
            y_mean - 2.0 * y_std,
            y_mean + 2.0 * y_std,
            alpha=0.2,
            label="mean ± 2 std",
        )
        ax.set_title(title_map[name])
        ax.grid(True, alpha=0.3)
        ax.legend()

    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(output_dir / f"top_uncertain_sample_{idx:04d}.png", dpi=150, bbox_inches="tight")
    plt.close()


def _load_ensemble(
    model_root: Path,
    seeds: list[int],
    device: torch.device,
) -> tuple[list[tuple[int, ForwardSurrogate]], list[str], bool]:
    """Load ``<model_root>/seed_<seed>/forward_surrogate_best.pt`` for every seed.

    Returns:
        ``(members, member_paths, use_schedule)``, where ``members`` is a list of
        ``(seed, model)`` pairs.

    Raises:
        RuntimeError: if members disagree on ``use_schedule``. Mixing models with and
            without the schedule input would average predictions whose inputs mean
            different things.
    """
    members: list[tuple[int, ForwardSurrogate]] = []
    member_paths: list[str] = []
    use_schedule: bool | None = None
    for seed in seeds:
        ckpt_path = model_root / f"seed_{seed}" / "forward_surrogate_best.pt"
        model, checkpoint = load_member(ckpt_path, device)
        member_paths.append(str(ckpt_path))
        members.append((seed, model))

        member_use_schedule = bool(checkpoint.get("use_schedule", True))
        if use_schedule is None:
            use_schedule = member_use_schedule
        elif use_schedule != member_use_schedule:
            raise RuntimeError("Ensemble 成员的 use_schedule 设置不一致")
    return members, member_paths, bool(use_schedule)


def _evaluate_test_set(
    members: list[tuple[int, ForwardSurrogate]],
    use_schedule: bool,
    x_params_test: np.ndarray,
    x_schedule_test: np.ndarray,
    y_curve_test: np.ndarray,
    scaler_curve,
    curve_layout: dict,
    device: torch.device,
) -> tuple[list[dict], list[dict], np.ndarray, np.ndarray, np.ndarray]:
    """Run every member on each test sample and collect error/uncertainty per sample.

    Each member's prediction is inverse-transformed with ``scaler_curve`` before
    aggregation, so the ensemble mean and std are computed in the original curve scale.

    Returns:
        ``(sample_rows, sample_metrics, all_true, all_mean, all_std)``:
        - ``sample_rows``: the per-sample CSV rows.
        - ``sample_metrics``: the ``calc_metrics`` dict for each sample.
        - ``all_true``, ``all_mean``, ``all_std``: the stacked true / mean / std curves.

    Raises:
        ValueError: if the test split is empty.
    """
    n_samples = len(x_params_test)
    if n_samples == 0:
        raise ValueError("测试集为空, 无法评估 ensemble")

    sample_rows: list[dict] = []
    sample_metrics: list[dict] = []
    all_true: list[np.ndarray] = []
    all_mean: list[np.ndarray] = []
    all_std: list[np.ndarray] = []

    with torch.no_grad():
        for idx in range(n_samples):
            params_t = torch.tensor(
                x_params_test[idx : idx + 1],
                dtype=torch.float32,
                device=device,
            )
            schedule_t = torch.tensor(
                x_schedule_test[idx : idx + 1],
                dtype=torch.float32,
                device=device,
            )
            schedule_in = schedule_t if use_schedule else None

            # Inverse-transform each member's prediction before aggregating.
            member_preds = np.stack(
                [
                    scaler_curve.inverse_transform(model(params_t, schedule_in).cpu().numpy())[
                        0
                    ].astype(np.float32)
                    for _, model in members
                ],
                axis=0,
            )

            curve_true = scaler_curve.inverse_transform(y_curve_test[idx : idx + 1])[0].astype(
                np.float32
            )
            curve_mean = member_preds.mean(axis=0).astype(np.float32)
            curve_std = member_preds.std(axis=0, ddof=0).astype(np.float32)

            metrics = calc_metrics(curve_true, curve_mean)
            parts_std = split_curve_by_layout(curve_std, curve_layout)

            # Member std is the empirical uncertainty score later correlated with RMSE.
            sample_rows.append(
                {
                    "idx": idx,
                    "overall_rmse": metrics["rmse"],
                    "overall_mae": metrics["mae"],
                    "overall_bias": metrics["bias"],
                    "overall_r2": metrics["r2"],
                    "unc_mean_std": float(np.mean(curve_std)),
                    "unc_max_std": float(np.max(curve_std)),
                    "unc_log_pressure_mean_std": float(np.mean(parts_std["log_pressure"])),
                    "unc_log_derivative_mean_std": float(np.mean(parts_std["log_derivative"])),
                    "unc_slope_mean_std": float(np.mean(parts_std["slope"])),
                }
            )
            sample_metrics.append(metrics)
            all_true.append(curve_true)
            all_mean.append(curve_mean)
            all_std.append(curve_std)

    return (
        sample_rows,
        sample_metrics,
        np.stack(all_true, axis=0),
        np.stack(all_mean, axis=0),
        np.stack(all_std, axis=0),
    )


def main() -> None:
    """Evaluate ensemble error and uncertainty on the test split and write every output."""
    args = parse_args()
    tag = normalize_tag(args.tag)
    use_schedule = not args.no_schedule
    seeds = parse_seed_list(args.seeds)

    processed_path = (
        Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    )
    model_root = (
        Path(args.model_root) if args.model_root is not None else default_model_root(tag, use_schedule)
    )
    output_dir = (
        Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag, use_schedule)
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # All members are evaluated on the same processed dataset, so their inputs share
    # one standardisation.
    data = joblib.load(processed_path)
    x_params_test = data["X_params_test"]
    x_schedule_test = data["X_schedule_test"]
    y_curve_test = data["Y_curve_test"]
    scaler_curve = data["scaler_curve"]
    curve_layout = infer_curve_layout(data["meta"], int(data["meta"]["curve_dim"]))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    members, member_paths, ensemble_use_schedule = _load_ensemble(model_root, seeds, device)

    sample_rows, sample_metrics, all_true, all_mean, all_std = _evaluate_test_set(
        members,
        ensemble_use_schedule,
        x_params_test,
        x_schedule_test,
        y_curve_test,
        scaler_curve,
        curve_layout,
        device,
    )

    # Reuse the metrics computed in the loop instead of recomputing them per sample.
    rmse = np.array([m["rmse"] for m in sample_metrics], dtype=np.float64)
    mae = np.array([m["mae"] for m in sample_metrics], dtype=np.float64)
    r2_valid = np.array([m["r2"] for m in sample_metrics if m["valid_r2"]], dtype=np.float64)
    unc = np.array([row["unc_mean_std"] for row in sample_rows], dtype=np.float64)

    # The summary is split into prediction quality and uncertainty quality so the
    # value of the ensemble can be judged on each separately.
    summary = {
        "ensemble": {
            "member_count": len(members),
            "member_paths": member_paths,
            "use_schedule": bool(ensemble_use_schedule),
            "processed_path": str(processed_path),
        },
        "prediction": {
            "rmse_mean": safe_mean(rmse),
            "rmse_median": safe_median(rmse),
            "rmse_p90": safe_percentile(rmse, 90),
            "mae_mean": safe_mean(mae),
            "mae_median": safe_median(mae),
            "r2_mean_valid": safe_mean(r2_valid),
            "r2_median_valid": safe_median(r2_valid),
        },
        "uncertainty": {
            "unc_mean": safe_mean(unc),
            "unc_median": safe_median(unc),
            "unc_p90": safe_percentile(unc, 90),
            "unc_vs_rmse_pearson": pearson_corr(unc, rmse),
        },
    }

    with open(output_dir / "ensemble_uq_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    with open(
        output_dir / "sample_uncertainty_metrics.csv",
        "w",
        newline="",
        encoding="utf-8-sig",
    ) as f:
        writer = csv.DictWriter(f, fieldnames=list(sample_rows[0].keys()))
        writer.writeheader()
        writer.writerows(sample_rows)

    plot_uncertainty_scatter(sample_rows, output_dir / "uncertainty_vs_error.png")

    # Clamp at 0: a negative count would otherwise become a negative slice that
    # selects almost every sample.
    top_k = max(0, min(args.n_top_uncertain_plots, len(sample_rows)))
    top_uncertain = sorted(sample_rows, key=lambda row: row["unc_mean_std"], reverse=True)[:top_k]

    # Only the most uncertain samples are plotted, i.e. where members disagree most.
    for row in top_uncertain:
        idx = int(row["idx"])
        plot_uncertain_sample(
            idx=idx,
            curve_true=all_true[idx],
            curve_mean=all_mean[idx],
            curve_std=all_std[idx],
            curve_layout=curve_layout,
            output_dir=output_dir,
            unc_score=float(row["unc_mean_std"]),
            rmse=float(row["overall_rmse"]),
        )

    print("Ensemble UQ evaluation complete.")
    print(f"Output dir: {output_dir}")
    print(f"RMSE mean={summary['prediction']['rmse_mean']:.6f}")
    print(f"UQ-RMSE Pearson={summary['uncertainty']['unc_vs_rmse_pearson']:.6f}")


if __name__ == "__main__":
    main()