"""Analyze ensemble uncertainty-quantification (UQ) outputs for fallback usefulness.

Reads a per-sample metrics CSV (by default
``results/evaluation_<tag>_ensemble_uq/sample_uncertainty_metrics.csv``) and:

* builds an uncertainty-filtering curve: for each removal percentage, the most
  uncertain samples are dropped and error statistics are reported on the rest;
* exports the top-K samples with the largest ``unc_mean_std``;
* exports high-error samples whose uncertainty is at/above a high quantile and
  those whose uncertainty is at/below a low quantile;
* writes a JSON summary and a PNG of the filtering curves.
"""

from __future__ import annotations

import argparse
import csv
import json
import sys
from pathlib import Path

import numpy as np

ROOT = Path(__file__).resolve().parents[1]
# NOTE(review): nothing in this file imports from ROOT; kept so project-level
# modules remain importable — confirm whether this is still needed.
sys.path.append(str(ROOT))

# Columns of the input CSV that main() reads.
REQUIRED_COLUMNS = ("overall_rmse", "overall_mae", "overall_r2", "unc_mean_std")


def parse_args() -> argparse.Namespace:
    """Define and parse the command-line interface."""
    parser = argparse.ArgumentParser(description="Analyze ensemble UQ outputs for fallback usefulness")
    parser.add_argument("--input-csv", type=str, default=None, help="Path to sample_uncertainty_metrics.csv")
    parser.add_argument("--output-dir", type=str, default=None, help="Optional output directory")
    parser.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag")
    parser.add_argument(
        "--quantiles",
        type=str,
        default="0,5,10,20,30,40,50",
        help="Comma-separated uncertainty removal percentages",
    )
    parser.add_argument("--top-k", type=int, default=100, help="Top-K risky samples to export")
    parser.add_argument("--high-error-rmse", type=float, default=1.0, help="High-error RMSE threshold")
    parser.add_argument("--low-unc-quantile", type=float, default=50.0, help="Low uncertainty quantile")
    parser.add_argument("--high-unc-quantile", type=float, default=90.0, help="High uncertainty quantile")
    return parser.parse_args()


def parse_quantiles(text: str) -> list[float]:
    """Parse a comma-separated list of removal percentages.

    Blank items are ignored. ``0`` (the unfiltered baseline) is always
    included. The result is de-duplicated and sorted ascending.

    Raises:
        ValueError: if an item is not a number or is outside ``[0, 100)``.
    """
    values: set[float] = {0.0}
    for item in str(text).split(","):
        item = item.strip()
        if not item:
            continue
        q = float(item)
        # A chained comparison also rejects NaN, which fails every comparison;
        # ``q < 0 or q >= 100`` let NaN through to np.percentile.
        if not 0.0 <= q < 100.0:
            raise ValueError(f"非法 quantile 百分比: {q}")
        values.add(q)
    return sorted(values)


def default_input_csv(tag: str) -> Path:
    """Return the default per-sample metrics CSV path for experiment ``tag``."""
    return Path("results") / f"evaluation_{tag}_ensemble_uq" / "sample_uncertainty_metrics.csv"


def default_output_dir(tag: str) -> Path:
    """Return the default analysis output directory for experiment ``tag``."""
    return Path("results") / f"evaluation_{tag}_ensemble_uq_analysis"


def load_rows(csv_path: Path) -> list[dict]:
    """Read all rows of ``csv_path`` as dicts (string values).

    ``utf-8-sig`` strips a leading BOM if one is present.

    Raises:
        ValueError: if the CSV has a header but no data rows.
    """
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
        rows = list(csv.DictReader(f))
    if not rows:
        raise ValueError(f"CSV 没有数据: {csv_path}")
    return rows


def as_float(rows: list[dict], key: str) -> np.ndarray:
    """Convert column ``key`` of ``rows`` into a float64 array.

    Raises:
        KeyError: if a row has no ``key`` entry.
        ValueError: if a cell cannot be parsed as a float (the message names
            the row and column).
    """
    values = np.empty(len(rows), dtype=np.float64)
    for i, row in enumerate(rows):
        raw = row[key]
        try:
            values[i] = float(raw)
        except (TypeError, ValueError) as err:
            # TypeError covers short CSV rows, where DictReader fills cells with None.
            raise ValueError(f"row {i}: column {key!r} has non-numeric value {raw!r}") from err
    return values


def safe_mean(x: np.ndarray) -> float:
    """Return the mean of ``x``, or NaN when ``x`` is empty."""
    return float(np.mean(x)) if x.size else np.nan


def safe_median(x: np.ndarray) -> float:
    """Return the median of ``x``, or NaN when ``x`` is empty."""
    return float(np.median(x)) if x.size else np.nan


def safe_percentile(x: np.ndarray, q: float) -> float:
    """Return the ``q``-th percentile of ``x``, or NaN when ``x`` is empty."""
    return float(np.percentile(x, q)) if x.size else np.nan


def save_csv(path: Path, rows: list[dict], fieldnames: list[str] | None = None) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 CSV with a BOM.

    Args:
        path: destination file; it is overwritten if it exists.
        rows: records to write. Columns default to the keys of the first row.
        fieldnames: optional explicit column list. When given and ``rows`` is
            empty, a header-only file is written, so a stale file from an
            earlier run is not left in place.

    If ``rows`` is empty and no ``fieldnames`` are given, nothing is written.
    """
    if fieldnames is None:
        if not rows:
            return
        fieldnames = list(rows[0].keys())
    with open(path, "w", encoding="utf-8-sig", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def build_filter_curve(
    rmse: np.ndarray,
    mae: np.ndarray,
    r2: np.ndarray,
    unc: np.ndarray,
    quantiles: list[float],
) -> list[dict]:
    """Compute error statistics after dropping the most uncertain samples.

    For each removal percentage ``p`` in ``quantiles``, samples with
    ``unc <= percentile(unc, 100 - p)`` are kept. With tied uncertainties, the
    retained share can exceed ``100 - p``; ``retained_pct`` reports the actual share.
    """
    curve_rows: list[dict] = []
    for removed_pct in quantiles:
        thr = float(np.percentile(unc, 100.0 - removed_pct))
        keep_mask = unc <= thr
        kept_rmse = rmse[keep_mask]
        kept_mae = mae[keep_mask]
        kept_r2 = r2[keep_mask]
        curve_rows.append(
            {
                "removed_pct": float(removed_pct),
                "retained_pct": float(100.0 * np.mean(keep_mask)),
                "n_retained": int(np.sum(keep_mask)),
                "rmse_mean": safe_mean(kept_rmse),
                "rmse_median": safe_median(kept_rmse),
                "rmse_p90": safe_percentile(kept_rmse, 90),
                "mae_mean": safe_mean(kept_mae),
                "mae_median": safe_median(kept_mae),
                "r2_mean": safe_mean(kept_r2),
                "r2_median": safe_median(kept_r2),
                "unc_threshold": thr,
            }
        )
    return curve_rows


def top_uncertain_indices(unc: np.ndarray, top_k: int) -> np.ndarray:
    """Return indices of the ``top_k`` largest uncertainties, largest first.

    ``top_k`` is clamped to ``[0, len(unc)]``. A negative value would otherwise
    turn ``[:top_k]`` into "all but the last |top_k|". The sort is stable, so
    ties keep CSV order and exports are reproducible.
    """
    k = max(0, min(int(top_k), unc.size))
    return np.argsort(-unc, kind="stable")[:k]


def split_high_error(
    rmse: np.ndarray,
    unc: np.ndarray,
    rmse_threshold: float,
    low_unc_thr: float,
    high_unc_thr: float,
) -> tuple[np.ndarray, np.ndarray]:
    """Select high-error samples by uncertainty level.

    Returns:
        ``(high_unc_idx, low_unc_idx)``: indices of samples with
        ``rmse >= rmse_threshold`` and, respectively, ``unc >= high_unc_thr``
        or ``unc <= low_unc_thr``. Each array is ordered by descending RMSE,
        with ties kept in input order.
    """
    high_error = rmse >= rmse_threshold

    def worst_first(mask: np.ndarray) -> np.ndarray:
        idx = np.flatnonzero(mask)
        return idx[np.argsort(-rmse[idx], kind="stable")]

    return worst_first(high_error & (unc >= high_unc_thr)), worst_first(high_error & (unc <= low_unc_thr))


def plot_retention_curve(curve_rows: list[dict], output_path: Path) -> None:
    """Plot mean RMSE / MAE / R2 against retained-sample percentage and save a PNG."""
    # Imported lazily: matplotlib is only needed for this figure, so the
    # analysis helpers stay importable without a plotting stack.
    import matplotlib.pyplot as plt

    removed = [float(row["removed_pct"]) for row in curve_rows]
    retained = np.array([float(row["retained_pct"]) for row in curve_rows], dtype=np.float64)
    panels = (("rmse_mean", "RMSE"), ("mae_mean", "MAE"), ("r2_mean", "R2"))

    fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
    try:
        for ax, (key, label) in zip(axes, panels):
            values = np.array([float(row[key]) for row in curve_rows], dtype=np.float64)
            ax.plot(retained, values, marker="o")
            ax.set_title(f"Retained vs {label}")
            ax.set_xlabel("Retained Samples (%)")
            ax.set_ylabel(f"{label} mean")
            ax.grid(True, alpha=0.3)
        # ``:g`` keeps fractional percentages such as 2.5 (``int()`` truncated them).
        grid_text = ",".join(f"{x:g}" for x in removed)
        fig.suptitle(f"Uncertainty Filtering Curves | removed grid={grid_text}%")
        fig.tight_layout(rect=[0, 0, 1, 0.94])
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def main() -> None:
    """Run the UQ fallback analysis configured on the command line."""
    args = parse_args()
    tag = str(args.tag).strip()
    input_csv = Path(args.input_csv) if args.input_csv is not None else default_input_csv(tag)
    output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag)

    # Validate every input before touching the filesystem, so a bad
    # invocation does not leave an empty output directory behind.
    quantiles = parse_quantiles(args.quantiles)
    low_unc_q = float(args.low_unc_quantile)
    high_unc_q = float(args.high_unc_quantile)
    for flag, value in (("--low-unc-quantile", low_unc_q), ("--high-unc-quantile", high_unc_q)):
        if not 0.0 <= value <= 100.0:
            raise ValueError(f"{flag} must be within [0, 100], got {value}")
    rows = load_rows(input_csv)
    missing = [col for col in REQUIRED_COLUMNS if col not in rows[0]]
    if missing:
        raise ValueError(f"CSV is missing required columns {missing}: {input_csv}")
    output_dir.mkdir(parents=True, exist_ok=True)

    rmse = as_float(rows, "overall_rmse")
    mae = as_float(rows, "overall_mae")
    r2 = as_float(rows, "overall_r2")
    unc = as_float(rows, "unc_mean_std")
    total_n = len(rows)
    high_error_rmse = float(args.high_error_rmse)

    low_unc_thr = float(np.percentile(unc, low_unc_q))
    high_unc_thr = float(np.percentile(unc, high_unc_q))

    curve_rows = build_filter_curve(rmse, mae, r2, unc, quantiles)
    top_uncertain_rows = [rows[int(i)] for i in top_uncertain_indices(unc, args.top_k)]
    high_idx, low_idx = split_high_error(rmse, unc, high_error_rmse, low_unc_thr, high_unc_thr)
    high_error_high_unc_rows = [rows[int(i)] for i in high_idx]
    high_error_low_unc_rows = [rows[int(i)] for i in low_idx]

    summary = {
        "input_csv": str(input_csv),
        "n_samples": total_n,
        "global": {
            "rmse_mean": safe_mean(rmse),
            "rmse_median": safe_median(rmse),
            "rmse_p90": safe_percentile(rmse, 90),
            "mae_mean": safe_mean(mae),
            "r2_mean": safe_mean(r2),
            "unc_mean": safe_mean(unc),
            "unc_median": safe_median(unc),
            "unc_p90": safe_percentile(unc, 90),
        },
        "thresholds": {
            "high_error_rmse": high_error_rmse,
            "low_unc_quantile": low_unc_q,
            "low_unc_threshold": low_unc_thr,
            "high_unc_quantile": high_unc_q,
            "high_unc_threshold": high_unc_thr,
        },
        "counts": {
            "top_uncertain_exported": len(top_uncertain_rows),
            "high_error_high_unc": len(high_error_high_unc_rows),
            "high_error_low_unc": len(high_error_low_unc_rows),
        },
    }

    with open(output_dir / "uq_fallback_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    # Sample exports always get the input header, even when empty, so files
    # from a previous run are overwritten rather than left stale.
    sample_fields = list(rows[0].keys())
    save_csv(output_dir / "uncertainty_filter_curve.csv", curve_rows)
    save_csv(output_dir / "top_uncertain_samples.csv", top_uncertain_rows, fieldnames=sample_fields)
    save_csv(output_dir / "high_error_high_unc_samples.csv", high_error_high_unc_rows, fieldnames=sample_fields)
    save_csv(output_dir / "high_error_low_unc_samples.csv", high_error_low_unc_rows, fieldnames=sample_fields)
    plot_retention_curve(curve_rows, output_dir / "uncertainty_filter_curve.png")

    print("UQ fallback analysis complete.")
    print(f"Output dir: {output_dir}")
    print(
        f"Global RMSE mean={summary['global']['rmse_mean']:.6f}, "
        f"unc mean={summary['global']['unc_mean']:.6f}"
    )
    print(
        f"High-error & high-unc count={summary['counts']['high_error_high_unc']}, "
        f"high-error & low-unc count={summary['counts']['high_error_low_unc']}"
    )


if __name__ == "__main__":
    main()