You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
233 lines
8.5 KiB
Python
233 lines
8.5 KiB
Python
|
3 weeks ago
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import csv
|
||
|
|
import json
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import matplotlib.pyplot as plt
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
# Directory two levels up from this file (the parent of the script's folder),
# appended to the import path. NOTE(review): presumably the repository root, added
# so project-local modules can be imported — no such import is visible here; confirm.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
|
||
|
|
|
||
|
|
|
||
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Build the command-line parser and parse arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before; passing an explicit list lets the
            function be driven from tests or other scripts.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Analyze ensemble UQ outputs for fallback usefulness")
    parser.add_argument("--input-csv", type=str, default=None, help="Path to sample_uncertainty_metrics.csv")
    parser.add_argument("--output-dir", type=str, default=None, help="Optional output directory")
    parser.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag")
    parser.add_argument(
        "--quantiles",
        type=str,
        default="0,5,10,20,30,40,50",
        help="Comma-separated uncertainty removal percentages",
    )
    parser.add_argument("--top-k", type=int, default=100, help="Top-K risky samples to export")
    parser.add_argument("--high-error-rmse", type=float, default=1.0, help="High-error RMSE threshold")
    parser.add_argument("--low-unc-quantile", type=float, default=50.0, help="Low uncertainty quantile")
    parser.add_argument("--high-unc-quantile", type=float, default=90.0, help="High uncertainty quantile")
    return parser.parse_args(argv)
|
||
|
|
|
||
|
|
|
||
|
|
def parse_quantiles(text: str) -> list[float]:
    """Parse a comma-separated list of uncertainty-removal percentages.

    Blank entries are skipped, ``0`` (the unfiltered baseline) is always
    included, and the result is de-duplicated and sorted ascending.

    Args:
        text: Comma-separated numbers, e.g. ``"0,5,10,20"``. Non-string input is
            converted with ``str`` first.

    Returns:
        Sorted unique percentages, each in ``[0, 100)``.

    Raises:
        ValueError: If an entry is not a number, or is NaN or outside ``[0, 100)``.
    """
    # Start from the baseline; using a set also folds ``-0.0`` into ``0.0``.
    values = {0.0}
    for item in str(text).split(","):
        item = item.strip()
        if not item:
            continue
        q = float(item)
        # Negated range check: NaN fails every comparison, so it is rejected here
        # instead of slipping through and breaking np.percentile later.
        if not 0.0 <= q < 100.0:
            raise ValueError(f"非法 quantile 百分比: {q}")
        values.add(q)
    return sorted(values)
|
||
|
|
|
||
|
|
|
||
|
|
def default_input_csv(tag: str) -> Path:
    """Return the conventional location of the per-sample UQ metrics CSV for *tag*.

    The path is relative: ``results/evaluation_<tag>_ensemble_uq/sample_uncertainty_metrics.csv``.
    """
    run_dir = Path("results", f"evaluation_{tag}_ensemble_uq")
    return run_dir.joinpath("sample_uncertainty_metrics.csv")
|
||
|
|
|
||
|
|
|
||
|
|
def default_output_dir(tag: str) -> Path:
    """Return the default (relative) directory for the analysis outputs of experiment *tag*."""
    folder_name = f"evaluation_{tag}_ensemble_uq_analysis"
    return Path("results").joinpath(folder_name)
|
||
|
|
|
||
|
|
|
||
|
|
def load_rows(csv_path: Path) -> list[dict]:
    """Read every data row of *csv_path* as a dict keyed by the header row.

    The file is decoded with ``utf-8-sig`` so a leading BOM is tolerated.

    Raises:
        ValueError: If the file contains no data rows.
    """
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as handle:
        reader = csv.DictReader(handle)
        records = list(reader)
    if records:
        return records
    raise ValueError(f"CSV 没有数据: {csv_path}")
|
||
|
|
|
||
|
|
|
||
|
|
def as_float(rows: list[dict], key: str) -> np.ndarray:
    """Extract column *key* from *rows* as a float64 array, one entry per row.

    Each value is converted with ``float``, so a missing key raises ``KeyError``
    and an unparsable value raises ``ValueError``.
    """
    column = (float(record[key]) for record in rows)
    return np.fromiter(column, dtype=np.float64, count=len(rows))
|
||
|
|
|
||
|
|
|
||
|
|
def safe_mean(x: np.ndarray) -> float:
    """Return the arithmetic mean of *x* as a Python float, or NaN if *x* is empty."""
    if x.size == 0:
        return np.nan
    return float(np.mean(x))
|
||
|
|
|
||
|
|
|
||
|
|
def safe_median(x: np.ndarray) -> float:
    """Return the median of *x* as a Python float, or NaN if *x* is empty."""
    is_empty = x.size == 0
    return np.nan if is_empty else float(np.median(x))
|
||
|
|
|
||
|
|
|
||
|
|
def safe_percentile(x: np.ndarray, q: float) -> float:
    """Return the *q*-th percentile (q on a 0-100 scale) of *x*, or NaN if *x* is empty."""
    if not x.size:
        return np.nan
    value = np.percentile(x, q)
    return float(value)
|
||
|
|
|
||
|
|
|
||
|
|
def save_csv(path: Path, rows: list[dict]) -> None:
    """Write *rows* to *path* as a CSV file encoded with ``utf-8-sig`` (BOM included).

    The header is taken from the keys of the first row. When *rows* is empty,
    nothing is written and no file is created.
    """
    if rows:
        header = list(rows[0])
        with open(path, "w", encoding="utf-8-sig", newline="") as out:
            writer = csv.DictWriter(out, fieldnames=header)
            writer.writeheader()
            for record in rows:
                writer.writerow(record)
|
||
|
|
|
||
|
|
|
||
|
|
def plot_retention_curve(curve_rows: list[dict], output_path: Path) -> None:
    """Plot mean RMSE, MAE and R2 against the percentage of retained samples.

    Produces one figure with three side-by-side panels and saves it at 150 dpi.

    Args:
        curve_rows: One row per removal level. Each row must provide
            ``removed_pct``, ``retained_pct``, ``rmse_mean``, ``mae_mean`` and
            ``r2_mean`` values convertible to ``float``.
        output_path: Destination image file.
    """

    def column(key: str) -> np.ndarray:
        # Collect one numeric field across all curve rows.
        return np.array([float(row[key]) for row in curve_rows], dtype=np.float64)

    removed = column("removed_pct")
    retained = column("retained_pct")
    panels = (
        (column("rmse_mean"), "Retained vs RMSE", "RMSE mean"),
        (column("mae_mean"), "Retained vs MAE", "MAE mean"),
        (column("r2_mean"), "Retained vs R2", "R2 mean"),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
    try:
        for ax, (values, title, ylabel) in zip(axes, panels):
            ax.plot(retained, values, marker="o")
            ax.set_title(title)
            ax.set_xlabel("Retained Samples (%)")
            ax.set_ylabel(ylabel)
            ax.grid(True, alpha=0.3)

        # ``:g`` keeps fractional removal levels (e.g. 2.5) instead of truncating
        # them like ``int`` would, while whole numbers still print as "5", "10", ...
        removed_grid = ",".join(f"{x:g}" for x in removed)
        fig.suptitle(f"Uncertainty Filtering Curves | removed grid={removed_grid}%")
        fig.tight_layout(rect=[0, 0, 1, 0.94])
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure even when saving fails, so it is not leaked.
        plt.close(fig)
|
||
|
|
|
||
|
|
|
||
|
|
def _build_curve_rows(
    quantiles: list[float],
    unc: np.ndarray,
    rmse: np.ndarray,
    mae: np.ndarray,
    r2: np.ndarray,
) -> list[dict]:
    """Summarise the metrics of the samples kept at each uncertainty-removal percentage."""
    curve_rows: list[dict] = []
    for removed_pct in quantiles:
        # Keep samples whose uncertainty is at or below the (100 - removed_pct)th percentile.
        thr = float(np.percentile(unc, 100.0 - removed_pct))
        keep_mask = unc <= thr
        kept_rmse = rmse[keep_mask]
        kept_mae = mae[keep_mask]
        kept_r2 = r2[keep_mask]
        curve_rows.append(
            {
                "removed_pct": float(removed_pct),
                "retained_pct": float(100.0 * np.mean(keep_mask)),
                "n_retained": int(np.sum(keep_mask)),
                "rmse_mean": safe_mean(kept_rmse),
                "rmse_median": safe_median(kept_rmse),
                "rmse_p90": safe_percentile(kept_rmse, 90),
                "mae_mean": safe_mean(kept_mae),
                "mae_median": safe_median(kept_mae),
                "r2_mean": safe_mean(kept_r2),
                "r2_median": safe_median(kept_r2),
                "unc_threshold": thr,
            }
        )
    return curve_rows


def _rows_by_rmse_desc(rows: list[dict], rmse: np.ndarray, mask: np.ndarray) -> list[dict]:
    """Return the rows selected by *mask*, sorted by RMSE descending (ties keep file order)."""
    selected = np.flatnonzero(mask).tolist()
    # list.sort/sorted is stable even with reverse=True, so equal RMSEs keep input order.
    selected = sorted(selected, key=lambda i: rmse[i], reverse=True)
    return [rows[i] for i in selected]


def main() -> None:
    """Run the UQ fallback analysis end to end.

    Reads the per-sample CSV (columns ``overall_rmse``, ``overall_mae``,
    ``overall_r2`` and ``unc_mean_std``) and writes into the output directory:

    * ``uq_fallback_summary.json``: global statistics, thresholds and counts.
    * ``uncertainty_filter_curve.csv`` / ``.png``: metrics of the retained
      samples after removing the most uncertain ones at each requested percentage.
    * ``top_uncertain_samples.csv``: the ``--top-k`` most uncertain rows.
    * ``high_error_high_unc_samples.csv`` / ``high_error_low_unc_samples.csv``:
      rows with RMSE >= ``--high-error-rmse`` whose uncertainty is at/above the
      ``--high-unc-quantile`` percentile, or at/below the ``--low-unc-quantile`` one.

    CSV outputs with no rows are skipped by ``save_csv``.

    Raises:
        ValueError: If the input CSV has no data, a quantile argument is invalid,
            or the uncertainty column contains non-finite values.
    """
    args = parse_args()
    tag = str(args.tag).strip()
    input_csv = Path(args.input_csv) if args.input_csv is not None else default_input_csv(tag)
    output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag)

    # Load and validate everything before touching the filesystem, so bad input
    # does not leave an empty output directory behind.
    rows = load_rows(input_csv)
    quantiles = parse_quantiles(args.quantiles)

    rmse = as_float(rows, "overall_rmse")
    mae = as_float(rows, "overall_mae")
    r2 = as_float(rows, "overall_r2")
    unc = as_float(rows, "unc_mean_std")
    if not np.all(np.isfinite(unc)):
        # NaN uncertainties make np.percentile return NaN, which would silently
        # empty every filter and make all thresholds meaningless.
        raise ValueError(f"unc_mean_std contains non-finite values: {input_csv}")

    total_n = len(rows)
    high_error_rmse = float(args.high_error_rmse)
    low_unc_thr = float(np.percentile(unc, float(args.low_unc_quantile)))
    high_unc_thr = float(np.percentile(unc, float(args.high_unc_quantile)))

    output_dir.mkdir(parents=True, exist_ok=True)

    curve_rows = _build_curve_rows(quantiles, unc, rmse, mae, r2)

    # Clamp to [0, total_n]: a negative --top-k would otherwise turn the slice
    # into "all but the last k" rather than an empty selection.
    n_top = max(0, min(int(args.top_k), total_n))
    # Stable sort so samples with equal uncertainty keep file order (deterministic output).
    top_unc_idx = np.argsort(-unc, kind="stable")[:n_top]
    top_uncertain_rows = [rows[int(i)] for i in top_unc_idx]

    high_error = rmse >= high_error_rmse
    high_error_high_unc_rows = _rows_by_rmse_desc(rows, rmse, high_error & (unc >= high_unc_thr))
    high_error_low_unc_rows = _rows_by_rmse_desc(rows, rmse, high_error & (unc <= low_unc_thr))

    summary = {
        "input_csv": str(input_csv),
        "n_samples": total_n,
        "global": {
            "rmse_mean": safe_mean(rmse),
            "rmse_median": safe_median(rmse),
            "rmse_p90": safe_percentile(rmse, 90),
            "mae_mean": safe_mean(mae),
            "r2_mean": safe_mean(r2),
            "unc_mean": safe_mean(unc),
            "unc_median": safe_median(unc),
            "unc_p90": safe_percentile(unc, 90),
        },
        "thresholds": {
            "high_error_rmse": high_error_rmse,
            "low_unc_quantile": float(args.low_unc_quantile),
            "low_unc_threshold": low_unc_thr,
            "high_unc_quantile": float(args.high_unc_quantile),
            "high_unc_threshold": high_unc_thr,
        },
        "counts": {
            "top_uncertain_exported": len(top_uncertain_rows),
            "high_error_high_unc": len(high_error_high_unc_rows),
            "high_error_low_unc": len(high_error_low_unc_rows),
        },
    }

    with open(output_dir / "uq_fallback_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    save_csv(output_dir / "uncertainty_filter_curve.csv", curve_rows)
    save_csv(output_dir / "top_uncertain_samples.csv", top_uncertain_rows)
    save_csv(output_dir / "high_error_high_unc_samples.csv", high_error_high_unc_rows)
    save_csv(output_dir / "high_error_low_unc_samples.csv", high_error_low_unc_rows)
    plot_retention_curve(curve_rows, output_dir / "uncertainty_filter_curve.png")

    print("UQ fallback analysis complete.")
    print(f"Output dir: {output_dir}")
    print(
        f"Global RMSE mean={summary['global']['rmse_mean']:.6f}, "
        f"unc mean={summary['global']['unc_mean']:.6f}"
    )
    print(
        f"High-error & high-unc count={summary['counts']['high_error_high_unc']}, "
        f"high-error & low-unc count={summary['counts']['high_error_low_unc']}"
    )
|
||
|
|
|
||
|
|
|
||
|
|
# Run the analysis only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|