You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/analyze_uq_results.py

456 lines
14 KiB
Python

"""
分析集成代理模型不确定性结果
评估 uncertainty 是否能够识别高误差样本
并生成 retention curve风险样本统计和 summary 输出
"""
from __future__ import annotations
import argparse
import csv
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import matplotlib.pyplot as plt
import numpy as np
# Resolve the directory one level above the directory containing this script.
# This assumes the script lives in a subdirectory of the project, so that
# parents[1] points at the project root.
ROOT = Path(__file__).resolve().parents[1]
# Append the project root to the Python search path so project-internal
# modules can be imported later. This script does not import any of them
# directly; the entry is kept for compatibility with the project layout.
sys.path.append(str(ROOT))
# One input record as read by csv.DictReader: column name -> raw cell text.
Row = dict[str, str]
# One output record destined for a CSV file: column name -> any value.
OutputRow = dict[str, Any]
@dataclass(frozen=True)
class Metrics:
    """Sample-level error and uncertainty metrics bundled into one object.

    Grouping the arrays keeps the number of local variables in ``main`` small.
    Instances are immutable (``frozen=True``); the field order is part of the
    positional constructor signature.
    """

    # Values of the "overall_rmse" CSV column (see extract_metrics).
    rmse: np.ndarray
    # Values of the "overall_mae" CSV column.
    mae: np.ndarray
    # Values of the "overall_r2" CSV column.
    r2: np.ndarray
    # Values of the "unc_mean_std" CSV column, used as the uncertainty score.
    unc: np.ndarray
@dataclass(frozen=True)
class UncertaintyThresholds:
    """Immutable pair of low / high uncertainty cut-off values."""

    # Samples with uncertainty <= low are treated as "low uncertainty".
    low: float
    # Samples with uncertainty >= high are treated as "high uncertainty".
    high: float
@dataclass(frozen=True)
class RiskRows:
    """Immutable container for the groups of risky samples that get exported."""

    # Rows with the largest uncertainty values.
    top_uncertain: list[Row]
    # Rows whose error is high and whose uncertainty is high as well.
    high_error_high_unc: list[Row]
    # Rows whose error is high although their uncertainty is low.
    high_error_low_unc: list[Row]
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the UQ analysis script.

    The options cover the input metrics CSV, the output directory, the
    experiment tag, the uncertainty removal percentages, the Top-K export
    size, the high-error RMSE threshold and the low / high uncertainty
    quantiles.

    Returns:
        The namespace built by ``argparse`` from the process arguments.
    """
    # Declarative option table: (flag, value type, default, help text).
    option_table = (
        ("--input-csv", str, None, "Path to sample_uncertainty_metrics.csv"),
        ("--output-dir", str, None, "Optional output directory"),
        ("--tag", str, "family_random_50k", "Experiment tag"),
        (
            "--quantiles",
            str,
            "0,5,10,20,30,40,50",
            "Comma-separated uncertainty removal percentages",
        ),
        ("--top-k", int, 100, "Top-K risky samples to export"),
        ("--high-error-rmse", float, 1.0, "High-error RMSE threshold"),
        ("--low-unc-quantile", float, 50.0, "Low uncertainty quantile"),
        ("--high-unc-quantile", float, 90.0, "High uncertainty quantile"),
    )
    cli = argparse.ArgumentParser(
        description="Analyze ensemble UQ outputs for fallback usefulness"
    )
    for flag, value_type, fallback, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)
    return cli.parse_args()
def parse_quantiles(text: str) -> list[float]:
    """Parse comma-separated removal percentages for the retention curve.

    Blank items are skipped, duplicates are collapsed, and ``0`` is always
    part of the result so the curve contains the "nothing removed" baseline.

    Args:
        text: Comma-separated percentages, e.g. ``"0,5,10"``.

    Returns:
        The distinct percentages in ascending order, always including ``0.0``.

    Raises:
        ValueError: If an item is not a number, is NaN, or lies outside the
            half-open range ``[0, 100)``.
    """
    # Seeding the set with 0.0 guarantees the baseline entry.
    percentages: set[float] = {0.0}
    for token in str(text).split(","):
        token = token.strip()
        if not token:
            continue
        q = float(token)
        # A chained comparison is False for NaN, so NaN is rejected here too;
        # `q < 0 or q >= 100` would let it through.
        if not 0 <= q < 100:
            raise ValueError(f"非法 quantile 百分比: {q}")
        percentages.add(q)
    return sorted(percentages)
def default_input_csv(tag: str) -> Path:
    """Return the default path of the ensemble-evaluation metrics CSV.

    Args:
        tag: Experiment tag embedded in the evaluation directory name.

    Returns:
        ``results/evaluation_<tag>_ensemble_uq/sample_uncertainty_metrics.csv``
        as a relative path.
    """
    evaluation_dir = f"evaluation_{tag}_ensemble_uq"
    return Path("results", evaluation_dir, "sample_uncertainty_metrics.csv")
def default_output_dir(tag: str) -> Path:
    """Return the default output directory of this analysis script.

    Args:
        tag: Experiment tag embedded in the directory name.

    Returns:
        ``results/evaluation_<tag>_ensemble_uq_analysis`` as a relative path.
    """
    analysis_dir = f"evaluation_{tag}_ensemble_uq_analysis"
    return Path("results", analysis_dir)
def resolve_paths(args: argparse.Namespace) -> tuple[Path, Path]:
    """Determine the input CSV and the output directory from the CLI options.

    An explicitly given (non-empty) ``input_csv`` / ``output_dir`` wins;
    otherwise the tag-based default location is used.

    Args:
        args: Parsed options providing ``tag``, ``input_csv`` and
            ``output_dir``.

    Returns:
        ``(input_csv, output_dir)`` as ``Path`` objects.
    """
    experiment_tag = str(args.tag).strip()

    if args.input_csv:
        metrics_csv = Path(args.input_csv)
    else:
        metrics_csv = default_input_csv(experiment_tag)

    if args.output_dir:
        analysis_dir = Path(args.output_dir)
    else:
        analysis_dir = default_output_dir(experiment_tag)

    return metrics_csv, analysis_dir
def load_rows(csv_path: Path) -> list[Row]:
    """Read a CSV file into a list of per-row dictionaries.

    The file is decoded as ``utf-8-sig``; every cell is kept as raw text.

    Args:
        csv_path: Location of the CSV file to read.

    Returns:
        One dictionary per data row, keyed by the header names.

    Raises:
        ValueError: If the file contains no data rows.
    """
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as handle:
        reader = csv.DictReader(handle)
        records = list(reader)
    if records:
        return records
    raise ValueError(f"CSV 没有数据: {csv_path}")
def as_float(rows: list[Row], key: str) -> np.ndarray:
    """Convert one CSV column into a float64 array for statistical analysis.

    Args:
        rows: CSV records as dictionaries.
        key: Name of the column to convert.

    Returns:
        A one-dimensional ``float64`` array with one entry per row.
    """
    # Cells are pulled and converted lazily, row by row.
    raw_cells = (record[key] for record in rows)
    return np.array(list(map(float, raw_cells)), dtype=np.float64)
def extract_metrics(rows: list[Row]) -> Metrics:
    """Pull the sample-level error and uncertainty columns out of the CSV rows.

    Args:
        rows: CSV records containing the ``overall_rmse``, ``overall_mae``,
            ``overall_r2`` and ``unc_mean_std`` columns.

    Returns:
        A ``Metrics`` instance holding one float array per column.
    """
    # Metrics field name -> CSV column it is read from.
    column_by_field = {
        "rmse": "overall_rmse",
        "mae": "overall_mae",
        "r2": "overall_r2",
        "unc": "unc_mean_std",
    }
    arrays = {
        field: as_float(rows, column)
        for field, column in column_by_field.items()
    }
    return Metrics(**arrays)
def safe_mean(x: np.ndarray) -> float:
    """Return the mean of ``x``, or NaN when ``x`` has no elements."""
    if x.size == 0:
        return np.nan
    return float(np.mean(x))
def safe_median(x: np.ndarray) -> float:
    """Return the median of ``x``, or NaN when ``x`` has no elements."""
    if x.size == 0:
        return np.nan
    return float(np.median(x))
def safe_percentile(x: np.ndarray, q: float) -> float:
    """Return the ``q``-th percentile of ``x``, or NaN when ``x`` is empty."""
    if x.size == 0:
        return np.nan
    return float(np.percentile(x, q))
def save_csv(path: Path, rows: list[OutputRow] | list[Row]) -> None:
    """Write dictionary rows to a CSV file.

    The header is taken from the keys of the first row. Nothing is written,
    and no file is created, when ``rows`` is empty.

    Args:
        path: Destination file; written as ``utf-8-sig``.
        rows: Records to write.
    """
    if rows:
        with open(path, "w", encoding="utf-8-sig", newline="") as handle:
            header = list(rows[0].keys())
            sink = csv.DictWriter(handle, fieldnames=header)
            sink.writeheader()
            sink.writerows(rows)
def build_thresholds(
    unc: np.ndarray,
    low_unc_quantile: float,
    high_unc_quantile: float,
) -> UncertaintyThresholds:
    """Derive the low / high uncertainty cut-offs from the sample distribution.

    Args:
        unc: Uncertainty values of the samples.
        low_unc_quantile: Percentile (0-100) defining the low cut-off.
        high_unc_quantile: Percentile (0-100) defining the high cut-off.

    Returns:
        The two cut-off values as plain floats.
    """
    low_cutoff = float(np.percentile(unc, low_unc_quantile))
    high_cutoff = float(np.percentile(unc, high_unc_quantile))
    return UncertaintyThresholds(low=low_cutoff, high=high_cutoff)
def build_retention_curve(metrics: Metrics, quantiles: list[float]) -> list[OutputRow]:
    """Build the error statistics obtained when the most uncertain samples are dropped.

    For every removal percentage, the uncertainty value at the
    ``100 - removed`` percentile becomes the cut-off; samples whose
    uncertainty is at or below it are retained and summarised.

    Args:
        metrics: Sample-level error and uncertainty arrays.
        quantiles: Removal percentages to evaluate.

    Returns:
        One record per removal percentage, in the order given.
    """

    def summarize(removed_pct: float) -> OutputRow:
        """Compute the statistics of the samples retained at one removal level."""
        cutoff = float(np.percentile(metrics.unc, 100.0 - removed_pct))
        retained = metrics.unc <= cutoff
        rmse_kept = metrics.rmse[retained]
        mae_kept = metrics.mae[retained]
        r2_kept = metrics.r2[retained]
        return {
            "removed_pct": float(removed_pct),
            # Actual retained share measured from the mask.
            "retained_pct": float(100.0 * np.mean(retained)),
            "n_retained": int(np.sum(retained)),
            "rmse_mean": safe_mean(rmse_kept),
            "rmse_median": safe_median(rmse_kept),
            "rmse_p90": safe_percentile(rmse_kept, 90),
            "mae_mean": safe_mean(mae_kept),
            "mae_median": safe_median(mae_kept),
            "r2_mean": safe_mean(r2_kept),
            "r2_median": safe_median(r2_kept),
            "unc_threshold": cutoff,
        }

    return [summarize(removed_pct) for removed_pct in quantiles]
def select_top_uncertain_rows(rows: list[Row], unc: np.ndarray, top_k: int) -> list[Row]:
    """Return the ``top_k`` rows with the largest uncertainty, largest first.

    Args:
        rows: CSV records, aligned index-by-index with ``unc``.
        unc: Uncertainty value of each row.
        top_k: Maximum number of rows to return. Values larger than
            ``len(rows)`` return every row; zero or negative values return
            an empty list.

    Returns:
        The selected rows ordered by descending uncertainty. Rows with equal
        uncertainty keep their original relative order.
    """
    # Clamp to [0, len(rows)]: a negative count would otherwise turn the
    # slice below into "all but the last k" instead of "none".
    limit = max(0, min(top_k, len(rows)))
    if limit == 0:
        return []
    # A stable sort makes the order of tied uncertainties deterministic.
    ranking = np.argsort(-unc, kind="stable")[:limit]
    return [rows[int(position)] for position in ranking]
def classify_risky_rows(
    rows: list[Row],
    thresholds: UncertaintyThresholds,
    high_error_rmse: float,
) -> tuple[list[Row], list[Row]]:
    """Split the high-error rows into high-uncertainty and low-uncertainty groups.

    A row counts as high-error when its ``overall_rmse`` is at least
    ``high_error_rmse``. It joins the first group when its ``unc_mean_std`` is
    at least ``thresholds.high`` and the second group when it is at most
    ``thresholds.low``; the two conditions are checked independently.

    Args:
        rows: CSV records with ``overall_rmse`` and ``unc_mean_std`` columns.
        thresholds: Low / high uncertainty cut-offs.
        high_error_rmse: RMSE value from which a row counts as high-error.

    Returns:
        ``(high_error_high_unc, high_error_low_unc)``, each sorted by
        descending RMSE.
    """
    # Each bucket collects (rmse, row) pairs so the RMSE is parsed only once.
    confident_misses: list[tuple[float, Row]] = []
    flagged_misses: list[tuple[float, Row]] = []
    for record in rows:
        rmse_value = float(record["overall_rmse"])
        unc_value = float(record["unc_mean_std"])
        if not rmse_value >= high_error_rmse:
            continue
        if unc_value >= thresholds.high:
            flagged_misses.append((rmse_value, record))
        if unc_value <= thresholds.low:
            confident_misses.append((rmse_value, record))

    def by_rmse_descending(bucket: list[tuple[float, Row]]) -> list[Row]:
        """Return the rows of ``bucket`` ordered from largest to smallest RMSE."""
        ordered = sorted(bucket, key=lambda pair: pair[0], reverse=True)
        return [record for _, record in ordered]

    return by_rmse_descending(flagged_misses), by_rmse_descending(confident_misses)
def build_risk_rows(
    rows: list[Row],
    metrics: Metrics,
    thresholds: UncertaintyThresholds,
    high_error_rmse: float,
    top_k: int,
) -> RiskRows:
    """Assemble every group of risky samples that is exported.

    Args:
        rows: CSV records of all samples.
        metrics: Sample-level arrays; only the uncertainty array is used here.
        thresholds: Low / high uncertainty cut-offs.
        high_error_rmse: RMSE value from which a row counts as high-error.
        top_k: Number of most-uncertain rows to include.

    Returns:
        The Top-K uncertain rows together with the two high-error groups.
    """
    flagged_misses, confident_misses = classify_risky_rows(
        rows=rows,
        thresholds=thresholds,
        high_error_rmse=high_error_rmse,
    )
    most_uncertain = select_top_uncertain_rows(rows, metrics.unc, top_k)
    return RiskRows(
        top_uncertain=most_uncertain,
        high_error_high_unc=flagged_misses,
        high_error_low_unc=confident_misses,
    )
def build_summary(
    input_csv: Path,
    metrics: Metrics,
    thresholds: UncertaintyThresholds,
    risk_rows: RiskRows,
    args: argparse.Namespace,
) -> dict[str, Any]:
    """Collect global statistics, thresholds and counts for the JSON summary.

    Args:
        input_csv: Path of the analysed CSV file.
        metrics: Sample-level error and uncertainty arrays.
        thresholds: Low / high uncertainty cut-off values.
        risk_rows: Exported risk groups, used for their sizes.
        args: Parsed options supplying the configured RMSE threshold and
            uncertainty quantiles.

    Returns:
        A nested dictionary with the keys ``input_csv``, ``n_samples``,
        ``global``, ``thresholds`` and ``counts``.
    """
    global_stats = {
        "rmse_mean": safe_mean(metrics.rmse),
        "rmse_median": safe_median(metrics.rmse),
        "rmse_p90": safe_percentile(metrics.rmse, 90),
        "mae_mean": safe_mean(metrics.mae),
        "r2_mean": safe_mean(metrics.r2),
        "unc_mean": safe_mean(metrics.unc),
        "unc_median": safe_median(metrics.unc),
        "unc_p90": safe_percentile(metrics.unc, 90),
    }
    threshold_info = {
        "high_error_rmse": float(args.high_error_rmse),
        "low_unc_quantile": float(args.low_unc_quantile),
        "low_unc_threshold": thresholds.low,
        "high_unc_quantile": float(args.high_unc_quantile),
        "high_unc_threshold": thresholds.high,
    }
    export_counts = {
        "top_uncertain_exported": len(risk_rows.top_uncertain),
        "high_error_high_unc": len(risk_rows.high_error_high_unc),
        "high_error_low_unc": len(risk_rows.high_error_low_unc),
    }
    return {
        "input_csv": str(input_csv),
        "n_samples": int(metrics.rmse.size),
        "global": global_stats,
        "thresholds": threshold_info,
        "counts": export_counts,
    }
def plot_retention_curve(curve_rows: list[OutputRow], output_path: Path) -> None:
    """Plot how the mean errors change as the most uncertain samples are dropped.

    Three panels (mean RMSE, mean MAE, mean R2) are drawn against the
    retained-sample percentage and saved as one image.

    Args:
        curve_rows: Retention-curve records providing ``removed_pct``,
            ``retained_pct``, ``rmse_mean``, ``mae_mean`` and ``r2_mean``.
        output_path: Destination of the image file.
    """

    def column(key: str) -> np.ndarray:
        """Extract one numeric column of the curve as a float64 array."""
        return np.array([float(row[key]) for row in curve_rows], dtype=np.float64)

    # (column key, panel title, y-axis label) for each of the three panels.
    panels = (
        ("rmse_mean", "Retained vs RMSE", "RMSE mean"),
        ("mae_mean", "Retained vs MAE", "MAE mean"),
        ("r2_mean", "Retained vs R2", "R2 mean"),
    )
    # Parse every column before a figure exists, so malformed rows fail early
    # without leaving a figure open.
    removed = column("removed_pct")
    retained = column("retained_pct")
    series = [column(key) for key, _, _ in panels]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
    try:
        for axis, values, (_, title, y_label) in zip(axes, series, panels):
            axis.plot(retained, values, marker="o")
            axis.set_title(title)
            axis.set_xlabel("Retained Samples (%)")
            axis.set_ylabel(y_label)
            axis.grid(True, alpha=0.3)
        # ":g" keeps fractional percentages (2.5 -> "2.5") while whole numbers
        # still print without a decimal point; int() would truncate them.
        removed_grid = ",".join(f"{value:g}" for value in removed)
        fig.suptitle(f"Uncertainty Filtering Curves | removed grid={removed_grid}%")
        # Leave head-room at the top of the figure for the super-title.
        fig.tight_layout(rect=[0, 0, 1, 0.94])
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure even when drawing or saving fails.
        plt.close(fig)
def save_outputs(
    output_dir: Path,
    curve_rows: list[OutputRow],
    risk_rows: RiskRows,
    summary: dict[str, Any],
) -> None:
    """Write the JSON summary, the CSV tables and the filtering-curve image.

    Args:
        output_dir: Directory that receives every output file.
        curve_rows: Retention-curve records.
        risk_rows: Risk groups, each exported to its own CSV file.
        summary: Summary dictionary serialised to JSON.
    """
    summary_path = output_dir / "uq_fallback_summary.json"
    with open(summary_path, "w", encoding="utf-8") as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)

    # (file name, records) for each CSV table, written in this order.
    csv_exports = (
        ("uncertainty_filter_curve.csv", curve_rows),
        ("top_uncertain_samples.csv", risk_rows.top_uncertain),
        ("high_error_high_unc_samples.csv", risk_rows.high_error_high_unc),
        ("high_error_low_unc_samples.csv", risk_rows.high_error_low_unc),
    )
    for file_name, records in csv_exports:
        save_csv(output_dir / file_name, records)

    plot_retention_curve(curve_rows, output_dir / "uncertainty_filter_curve.png")
def print_summary(output_dir: Path, summary: dict[str, Any]) -> None:
    """Print a short report so a command-line run can be checked at a glance.

    Args:
        output_dir: Directory the outputs were written to.
        summary: Summary dictionary; its ``global`` and ``counts`` sections
            are read.
    """
    print("UQ fallback analysis complete.")
    print(f"Output dir: {output_dir}")

    global_stats = summary["global"]
    rmse_mean = global_stats["rmse_mean"]
    unc_mean = global_stats["unc_mean"]
    print(f"Global RMSE mean={rmse_mean:.6f}, unc mean={unc_mean:.6f}")

    counts = summary["counts"]
    n_high_unc = counts["high_error_high_unc"]
    n_low_unc = counts["high_error_low_unc"]
    print(
        f"High-error & high-unc count={n_high_unc}, "
        f"high-error & low-unc count={n_low_unc}"
    )
def main() -> None:
    """Analyse how ensemble uncertainty relates to the true error and write the artefacts.

    Steps: parse the options, create the output directory, load the metrics
    CSV, compute the uncertainty thresholds, the retention curve and the risk
    groups, then save every output and print a short report.
    """
    cli_args = parse_args()
    metrics_csv, analysis_dir = resolve_paths(cli_args)
    analysis_dir.mkdir(parents=True, exist_ok=True)

    records = load_rows(metrics_csv)
    sample_metrics = extract_metrics(records)
    removal_grid = parse_quantiles(cli_args.quantiles)

    unc_thresholds = build_thresholds(
        unc=sample_metrics.unc,
        low_unc_quantile=float(cli_args.low_unc_quantile),
        high_unc_quantile=float(cli_args.high_unc_quantile),
    )
    retention_rows = build_retention_curve(sample_metrics, removal_grid)
    flagged = build_risk_rows(
        rows=records,
        metrics=sample_metrics,
        thresholds=unc_thresholds,
        high_error_rmse=float(cli_args.high_error_rmse),
        top_k=int(cli_args.top_k),
    )
    report = build_summary(
        input_csv=metrics_csv,
        metrics=sample_metrics,
        thresholds=unc_thresholds,
        risk_rows=flagged,
        args=cli_args,
    )

    save_outputs(analysis_dir, retention_rows, flagged, report)
    print_summary(analysis_dir, report)
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()