You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/analyze_uq_results.py

385 lines
17 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""分析集成代理模型的不确定性评估结果。
本脚本读取 `evaluate_forward_ensemble.py` 导出的逐样本误差与不确定性 CSV
统计“剔除高不确定性样本后剩余样本误差是否下降”、高误差样本是否被不确定性捕获,
并输出汇总 CSV、风险样本清单以及保留率曲线图。
"""
from __future__ import annotations
# 标准库导入:
# argparse 用于解析命令行参数;
# csv/json 用于读写分析结果;
# sys/Path 用于处理项目路径和文件路径。
import argparse
import csv
import json
import sys
from pathlib import Path
# 第三方库导入:
# matplotlib 用于绘制不确定性筛选曲线;
# numpy 用于数组计算、排序、百分位数等统计操作。
import matplotlib.pyplot as plt
import numpy as np
# 获取当前脚本所在目录的上一级项目根目录。
# 这里假设脚本位于项目的某个子目录中,因此 parents[1] 指向项目根目录。
ROOT = Path(__file__).resolve().parents[1]
# 将项目根目录加入 Python 搜索路径,方便后续导入项目内部模块。
# 当前脚本虽然没有直接导入内部模块,但保留该设置可以兼容项目结构。
sys.path.append(str(ROOT))
def parse_args() -> argparse.Namespace:
"""解析不确定性评估结果 CSV、输出目录以及分位数分箱参数。"""
# 创建命令行参数解析器。
# 本脚本主要用于分析集成代理模型输出的不确定性是否能够辅助识别高误差样本。
parser = argparse.ArgumentParser(description="Analyze ensemble UQ outputs for fallback usefulness")
# 输入 CSV通常来自集成不确定性评估阶段生成的 sample_uncertainty_metrics.csv。
# 如果不指定,则会根据 tag 自动拼接默认路径。
parser.add_argument("--input-csv", type=str, default=None, help="Path to sample_uncertainty_metrics.csv")
# 输出目录:用于保存 summary、筛选曲线 CSV、风险样本 CSV 和图像。
# 如果不指定,则会根据 tag 自动生成默认输出目录。
parser.add_argument("--output-dir", type=str, default=None, help="Optional output directory")
# 实验标签:用于区分不同实验、不同数据集或不同代理模型配置。
parser.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag")
# 不确定性剔除比例。
# 例如 0,5,10 表示分别剔除不确定性最高的 0%、5%、10% 样本后统计保留样本误差。
parser.add_argument(
"--quantiles",
type=str,
default="0,5,10,20,30,40,50",
help="Comma-separated uncertainty removal percentages",
)
# 导出不确定性最高的 Top-K 样本,方便后续查看这些样本是否确实更容易预测失败。
parser.add_argument("--top-k", type=int, default=100, help="Top-K risky samples to export")
# 高误差样本阈值。
# overall_rmse 大于等于该值的样本会被视为高误差样本。
parser.add_argument("--high-error-rmse", type=float, default=1.0, help="High-error RMSE threshold")
# 低不确定性分位点。
# 例如 50 表示不确定性低于中位数的样本被视为低不确定性样本。
parser.add_argument("--low-unc-quantile", type=float, default=50.0, help="Low uncertainty quantile")
# 高不确定性分位点。
# 例如 90 表示不确定性高于第 90 百分位的样本被视为高不确定性样本。
parser.add_argument("--high-unc-quantile", type=float, default=90.0, help="High uncertainty quantile")
return parser.parse_args()
def parse_quantiles(text: str) -> list[float]:
"""解析逗号分隔的剔除比例,用于不确定性保留曲线分析。"""
values = []
# 将形如 "0,5,10,20" 的字符串逐项拆分并转成 float。
for item in str(text).split(","):
item = item.strip()
# 允许用户在字符串中多写逗号或空格,空项直接跳过。
if not item:
continue
q = float(item)
# 剔除比例必须位于 [0, 100)。
# 不能等于 100因为全部剔除后没有样本可用于统计误差。
if q < 0 or q >= 100:
raise ValueError(f"非法 quantile 百分比: {q}")
values.append(q)
# 强制包含 0%,用于记录不剔除任何样本时的全局基线结果。
if 0.0 not in values:
values = [0.0] + values
# 去重并升序排列,保证后续曲线横轴顺序稳定。
return sorted(set(values))
def default_input_csv(tag: str) -> Path:
"""根据实验标签定位集成评估生成的 sample_uncertainty_metrics.csv。"""
# 默认输入路径约定:
# results/evaluation_{tag}_ensemble_uq/sample_uncertainty_metrics.csv
return Path("results") / f"evaluation_{tag}_ensemble_uq" / "sample_uncertainty_metrics.csv"
def default_output_dir(tag: str) -> Path:
"""根据实验标签生成当前分析脚本默认的输出目录。"""
# 默认输出路径约定:
# results/evaluation_{tag}_ensemble_uq_analysis
return Path("results") / f"evaluation_{tag}_ensemble_uq_analysis"
def load_rows(csv_path: Path) -> list[dict]:
"""读取 CSV 为字典行列表,保留每个样本的不确定性和误差字段。"""
# 使用 utf-8-sig 是为了兼容 Excel 或部分 Windows 工具生成的带 BOM 的 CSV。
with open(csv_path, "r", encoding="utf-8-sig", newline="") as f:
rows = list(csv.DictReader(f))
# 如果 CSV 只有表头或为空,则无法进行统计分析,直接抛出错误。
if not rows:
raise ValueError(f"CSV 没有数据: {csv_path}")
return rows
def as_float(rows: list[dict], key: str) -> np.ndarray:
"""将 CSV 字段转为 float 数组,用于后续统计分析。"""
# 这里要求 CSV 中必须存在对应字段,并且字段值可以转为 float。
# 当前脚本依赖的关键字段包括:
# overall_rmse、overall_mae、overall_r2、unc_mean_std。
return np.array([float(row[key]) for row in rows], dtype=np.float64)
def safe_mean(x: np.ndarray) -> float:
"""计算均值;没有有效数据时返回 NaN。"""
# 当某个筛选条件过严时,可能出现空数组,此时返回 NaN 避免程序崩溃。
return float(np.mean(x)) if x.size else np.nan
def safe_median(x: np.ndarray) -> float:
"""计算中位数;没有有效数据时返回 NaN。"""
return float(np.median(x)) if x.size else np.nan
def safe_percentile(x: np.ndarray, q: float) -> float:
"""计算百分位数;没有有效数据时返回 NaN。"""
return float(np.percentile(x, q)) if x.size else np.nan
def save_csv(path: Path, rows: list[dict]) -> None:
"""把字典行写入 CSV。"""
# 若 rows 为空,则不写文件。
# 例如某次实验中不存在“高误差且低不确定性”的样本,就会走到这里。
if not rows:
return
# 输出使用 utf-8-sig方便在 Excel 中直接打开时中文不乱码。
with open(path, "w", encoding="utf-8-sig", newline="") as f:
writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
writer.writeheader()
writer.writerows(rows)
def plot_retention_curve(curve_rows: list[dict], output_path: Path) -> None:
"""绘制按不确定性由高到低剔除样本后的误差变化,用来判断不确定性能否筛掉坏样本。"""
# 从曲线统计结果中取出横轴和纵轴数据。
removed = np.array([float(row["removed_pct"]) for row in curve_rows], dtype=np.float64)
retained = np.array([float(row["retained_pct"]) for row in curve_rows], dtype=np.float64)
rmse = np.array([float(row["rmse_mean"]) for row in curve_rows], dtype=np.float64)
mae = np.array([float(row["mae_mean"]) for row in curve_rows], dtype=np.float64)
r2 = np.array([float(row["r2_mean"]) for row in curve_rows], dtype=np.float64)
# 创建 1 行 3 列子图,分别观察 RMSE、MAE、R2 随保留样本比例的变化。
fig, axes = plt.subplots(1, 3, figsize=(15, 4.5))
# 子图 1保留样本比例 vs RMSE。
# 如果不确定性有效剔除高不确定性样本后RMSE 理论上应下降。
axes[0].plot(retained, rmse, marker="o")
axes[0].set_title("Retained vs RMSE")
axes[0].set_xlabel("Retained Samples (%)")
axes[0].set_ylabel("RMSE mean")
axes[0].grid(True, alpha=0.3)
# 子图 2保留样本比例 vs MAE。
# MAE 对极端值不如 RMSE 敏感,可作为误差变化的补充观察指标。
axes[1].plot(retained, mae, marker="o")
axes[1].set_title("Retained vs MAE")
axes[1].set_xlabel("Retained Samples (%)")
axes[1].set_ylabel("MAE mean")
axes[1].grid(True, alpha=0.3)
# 子图 3保留样本比例 vs R2。
# 如果保留样本更容易预测R2 通常应升高。
axes[2].plot(retained, r2, marker="o")
axes[2].set_title("Retained vs R2")
axes[2].set_xlabel("Retained Samples (%)")
axes[2].set_ylabel("R2 mean")
axes[2].grid(True, alpha=0.3)
# 总标题中记录剔除比例网格,便于之后回看图像时知道对应的实验设置。
fig.suptitle(
f"Uncertainty Filtering Curves | removed grid={','.join(str(int(x)) for x in removed)}%"
)
# tight_layout 用于减少子图之间的重叠rect 为总标题预留空间。
plt.tight_layout(rect=[0, 0, 1, 0.94])
# 保存图像dpi=150 对实验报告和日常查看基本够用。
plt.savefig(output_path, dpi=150, bbox_inches="tight")
plt.close()
def main() -> None:
"""分析集成不确定性与真实误差的相关性,并生成分箱统计和散点图。"""
# 1. 解析命令行参数。
args = parse_args()
# 去除 tag 两端空格,避免路径拼接时出现隐藏错误。
tag = str(args.tag).strip()
# 2. 确定输入 CSV 和输出目录。
# 若用户通过命令行指定路径,则优先使用用户路径;否则使用默认路径。
input_csv = Path(args.input_csv) if args.input_csv is not None else default_input_csv(tag)
output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag)
# 确保输出目录存在parents=True 允许自动创建多级目录。
output_dir.mkdir(parents=True, exist_ok=True)
# 3. 读取样本级不确定性评估结果。
rows = load_rows(input_csv)
# 解析用户指定的剔除比例列表。
quantiles = parse_quantiles(args.quantiles)
# 4. 提取关键指标。
# overall_rmse / overall_mae / overall_r2 是每个样本的预测误差表现;
# unc_mean_std 是集成模型输出的平均标准差,可理解为该样本的预测不确定性。
rmse = as_float(rows, "overall_rmse")
mae = as_float(rows, "overall_mae")
r2 = as_float(rows, "overall_r2")
unc = as_float(rows, "unc_mean_std")
total_n = len(rows)
# 低/高不确定性阈值来自样本分布本身,避免手动指定绝对阈值依赖数据尺度。
# 例如不同压力归一化方式、不同数据集或不同代理模型会改变 unc_mean_std 的绝对大小。
low_unc_thr = float(np.percentile(unc, float(args.low_unc_quantile)))
high_unc_thr = float(np.percentile(unc, float(args.high_unc_quantile)))
# 5. 构造“不确定性过滤曲线”。
# 核心思路:
# 按不确定性从高到低剔除一部分样本;
# 对剩余样本重新计算 RMSE/MAE/R2
# 如果误差明显改善,说明不确定性指标可以作为 fallback 或人工复核触发条件。
curve_rows: list[dict] = []
for removed_pct in quantiles:
# 计算当前剔除比例对应的不确定性阈值。
# 例如 removed_pct=10 时,阈值为第 90 百分位;
# keep_mask 会保留不确定性小于等于该阈值的样本,即剔除最高 10% 不确定性样本。
thr = float(np.percentile(unc, 100.0 - removed_pct))
keep_mask = unc <= thr
# 根据掩码取出保留下来的样本误差。
kept_rmse = rmse[keep_mask]
kept_mae = mae[keep_mask]
kept_r2 = r2[keep_mask]
# 保存当前剔除比例下的统计结果。
curve_rows.append(
{
"removed_pct": float(removed_pct),
"retained_pct": float(100.0 * np.mean(keep_mask)),
"n_retained": int(np.sum(keep_mask)),
"rmse_mean": safe_mean(kept_rmse),
"rmse_median": safe_median(kept_rmse),
"rmse_p90": safe_percentile(kept_rmse, 90),
"mae_mean": safe_mean(kept_mae),
"mae_median": safe_median(kept_mae),
"r2_mean": safe_mean(kept_r2),
"r2_median": safe_median(kept_r2),
"unc_threshold": thr,
}
)
# 6. 导出不确定性最高的 Top-K 样本。
# np.argsort(-unc) 表示按 unc 从大到小排序。
top_unc_idx = np.argsort(-unc)[: min(args.top_k, total_n)]
top_uncertain_rows = [rows[int(i)] for i in top_unc_idx]
# 7. 识别两类高误差样本:
# high_error_high_unc模型误差大并且不确定性也高。
# 这类样本是“可被不确定性发现的风险样本”。
# high_error_low_unc模型误差大但不确定性低。
# 这类样本最危险,说明模型错得很自信,是后续改进代理模型时最值得重点排查的对象。
high_error_high_unc_rows = []
high_error_low_unc_rows = []
for row in rows:
row_rmse = float(row["overall_rmse"])
row_unc = float(row["unc_mean_std"])
# 高误差 + 高不确定性fallback 机制比较容易捕捉。
if row_rmse >= float(args.high_error_rmse) and row_unc >= high_unc_thr:
high_error_high_unc_rows.append(row)
# 高误差 + 低不确定性:说明不确定性没有给出预警,是更严重的问题。
if row_rmse >= float(args.high_error_rmse) and row_unc <= low_unc_thr:
high_error_low_unc_rows.append(row)
# 按 RMSE 从高到低排序,方便优先查看最严重的失败样本。
high_error_high_unc_rows.sort(key=lambda row: float(row["overall_rmse"]), reverse=True)
high_error_low_unc_rows.sort(key=lambda row: float(row["overall_rmse"]), reverse=True)
# 8. 汇总全局统计信息和关键阈值,写入 JSON。
# JSON 更适合被后续自动化流程、实验记录系统或报告脚本读取。
summary = {
"input_csv": str(input_csv),
"n_samples": total_n,
"global": {
"rmse_mean": safe_mean(rmse),
"rmse_median": safe_median(rmse),
"rmse_p90": safe_percentile(rmse, 90),
"mae_mean": safe_mean(mae),
"r2_mean": safe_mean(r2),
"unc_mean": safe_mean(unc),
"unc_median": safe_median(unc),
"unc_p90": safe_percentile(unc, 90),
},
"thresholds": {
"high_error_rmse": float(args.high_error_rmse),
"low_unc_quantile": float(args.low_unc_quantile),
"low_unc_threshold": low_unc_thr,
"high_unc_quantile": float(args.high_unc_quantile),
"high_unc_threshold": high_unc_thr,
},
"counts": {
"top_uncertain_exported": len(top_uncertain_rows),
"high_error_high_unc": len(high_error_high_unc_rows),
"high_error_low_unc": len(high_error_low_unc_rows),
},
}
# ensure_ascii=False 可以保证 JSON 文件中中文正常显示。
with open(output_dir / "uq_fallback_summary.json", "w", encoding="utf-8") as f:
json.dump(summary, f, ensure_ascii=False, indent=2)
# 9. 保存所有分析产物。
save_csv(output_dir / "uncertainty_filter_curve.csv", curve_rows)
save_csv(output_dir / "top_uncertain_samples.csv", top_uncertain_rows)
save_csv(output_dir / "high_error_high_unc_samples.csv", high_error_high_unc_rows)
save_csv(output_dir / "high_error_low_unc_samples.csv", high_error_low_unc_rows)
# 绘制并保存不确定性过滤曲线。
plot_retention_curve(curve_rows, output_dir / "uncertainty_filter_curve.png")
# 10. 在控制台打印简要结果,便于命令行运行后快速判断分析是否完成。
print("UQ fallback analysis complete.")
print(f"Output dir: {output_dir}")
print(
f"Global RMSE mean={summary['global']['rmse_mean']:.6f}, "
f"unc mean={summary['global']['unc_mean']:.6f}"
)
print(
f"High-error & high-unc count={summary['counts']['high_error_high_unc']}, "
f"high-error & low-unc count={summary['counts']['high_error_low_unc']}"
)
# Python 脚本入口。
# 当该文件被直接运行时执行 main()
# 当它被其他脚本 import 时,不会自动执行分析流程。
if __name__ == "__main__":
main()