You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/analyze_uq_results.py

233 lines
8.5 KiB
Python

from __future__ import annotations
import argparse
import csv
import json
import sys
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
# Directory two levels above this script file; it is appended to sys.path
# (lowest priority) so that modules located there can be imported.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(f"{ROOT}")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Analyze ensemble UQ outputs for fallback usefulness")
parser.add_argument("--input-csv", type=str, default=None, help="Path to sample_uncertainty_metrics.csv")
parser.add_argument("--output-dir", type=str, default=None, help="Optional output directory")
parser.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag")
parser.add_argument(
"--quantiles",
type=str,
default="0,5,10,20,30,40,50",
help="Comma-separated uncertainty removal percentages",
)
parser.add_argument("--top-k", type=int, default=100, help="Top-K risky samples to export")
parser.add_argument("--high-error-rmse", type=float, default=1.0, help="High-error RMSE threshold")
parser.add_argument("--low-unc-quantile", type=float, default=50.0, help="Low uncertainty quantile")
parser.add_argument("--high-unc-quantile", type=float, default=90.0, help="High uncertainty quantile")
return parser.parse_args()
def parse_quantiles(text: str) -> list[float]:
    """Parse a comma-separated list of uncertainty-removal percentages.

    Empty items are ignored. 0 is always included so the resulting curve has a
    "nothing removed" baseline point. Duplicates are removed.

    Args:
        text: e.g. ``"0,5,10,20"``.

    Returns:
        Sorted, de-duplicated percentages, each in ``[0, 100)``.

    Raises:
        ValueError: if an item is not a number or lies outside ``[0, 100)``
            (message text: "invalid quantile percentage").
    """
    # Seeding with 0.0 guarantees the baseline and also folds "-0" into 0.0,
    # because -0.0 == 0.0 and both hash alike.
    values: set[float] = {0.0}
    for item in str(text).split(","):
        item = item.strip()
        if not item:
            continue
        q = float(item)
        # Written as a positive range check so NaN (which fails every
        # comparison) is rejected instead of slipping through.
        if not 0.0 <= q < 100.0:
            raise ValueError(f"非法 quantile 百分比: {q}")
        values.add(q)
    return sorted(values)
def default_input_csv(tag: str) -> Path:
    """Return the default per-sample metrics CSV path for experiment ``tag``.

    The path is relative to the current working directory.
    """
    run_dir = f"evaluation_{tag}_ensemble_uq"
    return Path("results", run_dir, "sample_uncertainty_metrics.csv")
def default_output_dir(tag: str) -> Path:
    """Return the default analysis output directory for experiment ``tag``.

    The path is relative to the current working directory.
    """
    folder_name = f"evaluation_{tag}_ensemble_uq_analysis"
    return Path("results", folder_name)
def load_rows(csv_path: Path) -> list[dict]:
    """Read every data row of ``csv_path`` as a dict keyed by the header row.

    The file is decoded as UTF-8; a leading BOM, if present, is stripped.

    Raises:
        ValueError: if the file contains no data rows (message: "CSV has no data").
    """
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as handle:
        records = [*csv.DictReader(handle)]
    if len(records) == 0:
        raise ValueError(f"CSV 没有数据: {csv_path}")
    return records
def as_float(rows: list[dict], key: str) -> np.ndarray:
    """Return column ``key`` of ``rows`` as a 1-D float64 array, one entry per row.

    Raises KeyError if a row lacks ``key`` and ValueError if a value is not numeric.
    """
    column_values = (float(record[key]) for record in rows)
    return np.fromiter(column_values, dtype=np.float64, count=len(rows))
def safe_mean(x: np.ndarray) -> float:
    """Return the mean of ``x`` as a Python float, or NaN when ``x`` is empty."""
    if x.size == 0:
        return np.nan
    return float(np.mean(x))
def safe_median(x: np.ndarray) -> float:
    """Return the median of ``x`` as a Python float, or NaN when ``x`` is empty."""
    if x.size == 0:
        return np.nan
    return float(np.median(x))
def safe_percentile(x: np.ndarray, q: float) -> float:
    """Return the ``q``-th percentile (0-100) of ``x`` as a float, or NaN if ``x`` is empty."""
    if x.size == 0:
        return np.nan
    return float(np.percentile(x, q))
def save_csv(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 CSV with a leading BOM.

    The header is taken from the keys of the first row. When ``rows`` is empty,
    nothing is written and no file is created or modified.
    """
    if len(rows) == 0:
        return
    header = [*rows[0]]
    with open(path, "w", encoding="utf-8-sig", newline="") as handle:
        out = csv.DictWriter(handle, fieldnames=header)
        out.writeheader()
        for record in rows:
            out.writerow(record)
def plot_retention_curve(curve_rows: list[dict], output_path: Path) -> None:
    """Plot mean RMSE, MAE and R2 of the retained samples against the retained fraction.

    Args:
        curve_rows: Rows providing ``removed_pct``, ``retained_pct``,
            ``rmse_mean``, ``mae_mean`` and ``r2_mean`` (values convertible to float).
        output_path: Path of the image file to write.
    """

    def column(key: str) -> np.ndarray:
        # One float64 value per curve row for the requested field.
        return np.array([float(row[key]) for row in curve_rows], dtype=np.float64)

    removed = column("removed_pct")
    retained = column("retained_pct")
    # (axis label, curve_rows field) for each subplot, left to right.
    panels = (("RMSE", "rmse_mean"), ("MAE", "mae_mean"), ("R2", "r2_mean"))

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 4.5))
    try:
        for ax, (label, key) in zip(axes, panels):
            ax.plot(retained, column(key), marker="o")
            ax.set_title(f"Retained vs {label}")
            ax.set_xlabel("Retained Samples (%)")
            ax.set_ylabel(f"{label} mean")
            ax.grid(True, alpha=0.3)
        # `:g` prints integers as before ("5") but keeps fractional grid points
        # ("2.5") that int() would have truncated to a misleading label.
        grid = ",".join(f"{x:g}" for x in removed)
        fig.suptitle(f"Uncertainty Filtering Curves | removed grid={grid}%")
        fig.tight_layout(rect=[0, 0, 1, 0.94])
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if layout or saving fails.
        plt.close(fig)
def _filter_curve(
    rmse: np.ndarray,
    mae: np.ndarray,
    r2: np.ndarray,
    unc: np.ndarray,
    quantiles: list[float],
) -> list[dict]:
    """Summarize per-sample metrics after removing the most uncertain samples.

    For each ``removed_pct`` in ``quantiles``, samples whose uncertainty is above
    the ``(100 - removed_pct)``-th percentile of ``unc`` are dropped. The RMSE,
    MAE and R2 of the remaining samples are aggregated into one row.
    """
    curve_rows: list[dict] = []
    for removed_pct in quantiles:
        thr = float(np.percentile(unc, 100.0 - removed_pct))
        # `<=` keeps every sample tied at the threshold, so the reported
        # retained_pct can exceed 100 - removed_pct.
        keep_mask = unc <= thr
        kept_rmse = rmse[keep_mask]
        kept_mae = mae[keep_mask]
        kept_r2 = r2[keep_mask]
        curve_rows.append(
            {
                "removed_pct": float(removed_pct),
                "retained_pct": float(100.0 * np.mean(keep_mask)),
                "n_retained": int(np.sum(keep_mask)),
                "rmse_mean": safe_mean(kept_rmse),
                "rmse_median": safe_median(kept_rmse),
                "rmse_p90": safe_percentile(kept_rmse, 90),
                "mae_mean": safe_mean(kept_mae),
                "mae_median": safe_median(kept_mae),
                "r2_mean": safe_mean(kept_r2),
                "r2_median": safe_median(kept_r2),
                "unc_threshold": thr,
            }
        )
    return curve_rows


def _rows_by_rmse_desc(rows: list[dict], rmse: np.ndarray, mask: np.ndarray) -> list[dict]:
    """Return the rows selected by boolean ``mask``, highest RMSE first.

    ``list.sort`` stays stable with ``reverse=True``, so rows with equal RMSE
    keep their input order.
    """
    selected = [int(i) for i in np.flatnonzero(mask)]
    selected.sort(key=lambda i: rmse[i], reverse=True)
    return [rows[i] for i in selected]


def main() -> None:
    """Analyze per-sample ensemble UQ metrics and export fallback diagnostics.

    The input CSV must provide the columns ``overall_rmse``, ``overall_mae``,
    ``overall_r2`` and ``unc_mean_std``. The script writes these outputs to the
    output directory:

    - ``uq_fallback_summary.json``: global statistics, thresholds and counts.
    - ``uncertainty_filter_curve.csv`` / ``.png``: metrics vs. removed fraction.
    - ``top_uncertain_samples.csv``: the ``--top-k`` most uncertain rows.
    - ``high_error_high_unc_samples.csv`` / ``high_error_low_unc_samples.csv``:
      rows with RMSE >= ``--high-error-rmse`` whose uncertainty is at/above the
      high, or at/below the low, percentile threshold.

    A sample CSV is only present when it has rows. An empty category removes the
    file left by an earlier run, so the directory always matches the summary.
    """
    args = parse_args()
    tag = str(args.tag).strip()
    input_csv = Path(args.input_csv) if args.input_csv is not None else default_input_csv(tag)
    output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag)

    # Read and validate every input before touching the output directory, so
    # a bad path or option value does not leave an empty directory behind.
    rows = load_rows(input_csv)
    quantiles = parse_quantiles(args.quantiles)
    rmse = as_float(rows, "overall_rmse")
    mae = as_float(rows, "overall_mae")
    r2 = as_float(rows, "overall_r2")
    unc = as_float(rows, "unc_mean_std")
    total_n = len(rows)
    high_error_rmse = float(args.high_error_rmse)
    low_unc_thr = float(np.percentile(unc, float(args.low_unc_quantile)))
    high_unc_thr = float(np.percentile(unc, float(args.high_unc_quantile)))
    output_dir.mkdir(parents=True, exist_ok=True)

    curve_rows = _filter_curve(rmse, mae, r2, unc, quantiles)

    # Clamp to [0, total_n]. A negative --top-k would otherwise become a
    # negative slice bound and export everything except the least uncertain tail.
    top_k = max(0, min(args.top_k, total_n))
    # Stable sort: samples with equal uncertainty keep their input order.
    top_unc_idx = np.argsort(-unc, kind="stable")[:top_k]
    top_uncertain_rows = [rows[int(i)] for i in top_unc_idx]

    high_error = rmse >= high_error_rmse
    high_error_high_unc_rows = _rows_by_rmse_desc(rows, rmse, high_error & (unc >= high_unc_thr))
    high_error_low_unc_rows = _rows_by_rmse_desc(rows, rmse, high_error & (unc <= low_unc_thr))

    summary = {
        "input_csv": str(input_csv),
        "n_samples": total_n,
        "global": {
            "rmse_mean": safe_mean(rmse),
            "rmse_median": safe_median(rmse),
            "rmse_p90": safe_percentile(rmse, 90),
            "mae_mean": safe_mean(mae),
            "r2_mean": safe_mean(r2),
            "unc_mean": safe_mean(unc),
            "unc_median": safe_median(unc),
            "unc_p90": safe_percentile(unc, 90),
        },
        "thresholds": {
            "high_error_rmse": high_error_rmse,
            "low_unc_quantile": float(args.low_unc_quantile),
            "low_unc_threshold": low_unc_thr,
            "high_unc_quantile": float(args.high_unc_quantile),
            "high_unc_threshold": high_unc_thr,
        },
        "counts": {
            "top_uncertain_exported": len(top_uncertain_rows),
            "high_error_high_unc": len(high_error_high_unc_rows),
            "high_error_low_unc": len(high_error_low_unc_rows),
        },
    }
    with open(output_dir / "uq_fallback_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    # The curve always has at least the 0% row, so it is never empty.
    save_csv(output_dir / "uncertainty_filter_curve.csv", curve_rows)
    for file_name, sample_rows in (
        ("top_uncertain_samples.csv", top_uncertain_rows),
        ("high_error_high_unc_samples.csv", high_error_high_unc_rows),
        ("high_error_low_unc_samples.csv", high_error_low_unc_rows),
    ):
        path = output_dir / file_name
        if sample_rows:
            save_csv(path, sample_rows)
        elif path.exists():
            # save_csv writes nothing for an empty list. Drop this file from an
            # earlier run so it cannot be mistaken for this run's result.
            path.unlink()
    plot_retention_curve(curve_rows, output_dir / "uncertainty_filter_curve.png")

    print("UQ fallback analysis complete.")
    print(f"Output dir: {output_dir}")
    print(
        f"Global RMSE mean={summary['global']['rmse_mean']:.6f}, "
        f"unc mean={summary['global']['unc_mean']:.6f}"
    )
    print(
        f"High-error & high-unc count={summary['counts']['high_error_high_unc']}, "
        f"high-error & low-unc count={summary['counts']['high_error_low_unc']}"
    )


if __name__ == "__main__":
    main()