|
|
|
|
|
"""
|
|
|
|
|
|
分析集成代理模型不确定性结果,
|
|
|
|
|
|
评估 uncertainty 是否能够识别高误差样本,
|
|
|
|
|
|
并生成 retention curve、风险样本统计和 summary 输出。
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
|
import csv
|
|
|
|
|
|
import json
|
|
|
|
|
|
import sys
|
|
|
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
from typing import Any
|
|
|
|
|
|
|
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Project root directory: the parent of the directory that contains this script.
# This assumes the script lives in a direct subdirectory of the project, so
# parents[1] of the resolved script path is the project root.
# NOTE(review): the default input/output paths built later in this script are
# relative to the current working directory, not to ROOT.
ROOT = Path(__file__).resolve().parents[1]

# Append the project root to the module search path so project-internal modules
# could be imported. This script does not import any project module directly;
# the entry is kept for compatibility with the project layout.
sys.path.append(str(ROOT))
|
|
|
|
|
|
|
|
|
|
|
|
# Rows this script builds itself (e.g. retention-curve points); values are
# numbers rather than raw CSV strings.
OutputRow = dict[str, Any]

# One record as returned by csv.DictReader: column name -> raw string cell.
Row = dict[str, str]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True, eq=False)
class Metrics:
    """Per-sample error and uncertainty arrays, aligned index-by-index.

    Bundling the four arrays keeps helper signatures (and ``main``) short.

    ``eq=False`` is deliberate. The generated field-wise ``__eq__`` would
    compare NumPy arrays, and an element-wise ``==`` result cannot be reduced
    to a single bool. As a result ``Metrics(...) == Metrics(...)`` raised
    ``ValueError``, and the generated ``__hash__`` raised ``TypeError``.
    Instances now use identity equality and identity hashing instead.
    """

    # Per-sample RMSE (the ``overall_rmse`` CSV column).
    rmse: np.ndarray
    # Per-sample MAE (the ``overall_mae`` CSV column).
    mae: np.ndarray
    # Per-sample R2 (the ``overall_r2`` CSV column).
    r2: np.ndarray
    # Per-sample ensemble uncertainty (the ``unc_mean_std`` CSV column).
    unc: np.ndarray
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True, frozen=True)
class UncertaintyThresholds:
    """Immutable pair of uncertainty cut-offs used to classify samples.

    Both values are percentiles of the per-sample uncertainty distribution.
    """

    # A sample with uncertainty <= low counts as "low uncertainty".
    low: float
    # A sample with uncertainty >= high counts as "high uncertainty".
    high: float
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass(init=True, repr=True, eq=True, frozen=True)
class RiskRows:
    """Groups of CSV rows that are exported as risk-sample reports."""

    # Rows with the largest uncertainty (at most top-k of them).
    top_uncertain: list[Row]
    # High-error rows whose uncertainty reached the high threshold.
    high_error_high_unc: list[Row]
    # High-error rows whose uncertainty stayed at or below the low threshold.
    high_error_low_unc: list[Row]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    The options select the uncertainty CSV, the output directory, the
    experiment tag, the retention-curve removal grid, and the thresholds used
    to classify risky samples.
    """
    # (flag, type, default, help) for every option, in help-display order.
    option_specs = (
        ("--input-csv", str, None, "Path to sample_uncertainty_metrics.csv"),
        ("--output-dir", str, None, "Optional output directory"),
        ("--tag", str, "family_random_50k", "Experiment tag"),
        (
            "--quantiles",
            str,
            "0,5,10,20,30,40,50",
            "Comma-separated uncertainty removal percentages",
        ),
        ("--top-k", int, 100, "Top-K risky samples to export"),
        ("--high-error-rmse", float, 1.0, "High-error RMSE threshold"),
        ("--low-unc-quantile", float, 50.0, "Low uncertainty quantile"),
        ("--high-unc-quantile", float, 90.0, "High uncertainty quantile"),
    )

    cli = argparse.ArgumentParser(
        description="Analyze ensemble UQ outputs for fallback usefulness"
    )
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    return cli.parse_args()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_quantiles(text: str) -> list[float]:
    """Parse the comma-separated removal percentages for the retention curve.

    Blank items are ignored, duplicates are collapsed, and 0 (the
    "nothing removed" baseline) is always included.

    Args:
        text: Comma-separated percentages, e.g. ``"0,5,10"``.

    Returns:
        Sorted, de-duplicated percentages, each in ``[0, 100)``.

    Raises:
        ValueError: If an item is not a number, or lies outside ``[0, 100)``.
            NaN is rejected too.
    """
    # Seeding the set with 0.0 guarantees the baseline point and de-duplicates
    # in one step.
    removal_pcts = {0.0}
    for item in str(text).split(","):
        item = item.strip()
        if not item:
            continue

        q = float(item)
        # Written as a positive range check so NaN, which fails every
        # comparison, is rejected. The old negated form let NaN through.
        if not 0.0 <= q < 100.0:
            raise ValueError(f"非法 quantile 百分比: {q}")
        removal_pcts.add(q)

    return sorted(removal_pcts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_input_csv(tag: str) -> Path:
    """Return the default per-sample uncertainty CSV for experiment ``tag``.

    The path is relative to the current working directory:
    ``results/evaluation_<tag>_ensemble_uq/sample_uncertainty_metrics.csv``.
    """
    run_dir_name = f"evaluation_{tag}_ensemble_uq"
    return Path("results", run_dir_name, "sample_uncertainty_metrics.csv")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_output_dir(tag: str) -> Path:
    """Return the default analysis output directory for experiment ``tag``.

    The path is relative to the current working directory:
    ``results/evaluation_<tag>_ensemble_uq_analysis``.
    """
    analysis_dir_name = f"evaluation_{tag}_ensemble_uq_analysis"
    return Path("results", analysis_dir_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resolve_paths(args: argparse.Namespace) -> tuple[Path, Path]:
    """Pick the input CSV and output directory for this run.

    An explicit ``--input-csv`` / ``--output-dir`` wins. An empty or missing
    value falls back to the tag-based default, and the tag is whitespace-
    stripped first.

    Returns:
        ``(input_csv, output_dir)``.
    """
    cleaned_tag = str(args.tag).strip()

    if args.input_csv:
        csv_path = Path(args.input_csv)
    else:
        csv_path = default_input_csv(cleaned_tag)

    if args.output_dir:
        out_dir = Path(args.output_dir)
    else:
        out_dir = default_output_dir(cleaned_tag)

    return csv_path, out_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_rows(csv_path: Path) -> list[Row]:
    """Read ``csv_path`` into a list of per-row dicts keyed by column name.

    The file is decoded as UTF-8, and a leading BOM is tolerated.

    Raises:
        ValueError: If the CSV has no data rows.
    """
    with open(csv_path, "r", encoding="utf-8-sig", newline="") as handle:
        reader = csv.DictReader(handle)
        records = list(reader)

    if records:
        return records
    raise ValueError(f"CSV 没有数据: {csv_path}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def as_float(rows: list[Row], key: str) -> np.ndarray:
    """Collect column ``key`` from every row as a float64 array.

    Raises ``KeyError`` for a missing column and ``ValueError`` for a cell
    that ``float`` cannot parse.
    """
    values = (float(record[key]) for record in rows)
    return np.fromiter(values, dtype=np.float64)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_metrics(rows: list[Row]) -> Metrics:
    """Convert the per-sample CSV columns into float arrays bundled as ``Metrics``.

    Column mapping: ``overall_rmse`` -> ``rmse``, ``overall_mae`` -> ``mae``,
    ``overall_r2`` -> ``r2``, ``unc_mean_std`` -> ``unc``. A missing column
    (``KeyError``) or a non-numeric cell (``ValueError``) propagates from
    ``as_float``.
    """
    column_for_field = {
        "rmse": "overall_rmse",
        "mae": "overall_mae",
        "r2": "overall_r2",
        "unc": "unc_mean_std",
    }
    arrays = {field: as_float(rows, column) for field, column in column_for_field.items()}
    return Metrics(**arrays)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def safe_mean(x: np.ndarray) -> float:
    """Return the mean of ``x`` as a Python float, or NaN when ``x`` is empty."""
    if x.size == 0:
        return np.nan
    return float(np.mean(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def safe_median(x: np.ndarray) -> float:
    """Return the median of ``x`` as a Python float, or NaN when ``x`` is empty."""
    return np.nan if x.size == 0 else float(np.median(x))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def safe_percentile(x: np.ndarray, q: float) -> float:
    """Return the ``q``-th percentile of ``x`` as a float, or NaN when ``x`` is empty."""
    is_empty = x.size == 0
    if is_empty:
        return np.nan
    return float(np.percentile(x, q))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_csv(path: Path, rows: list[OutputRow] | list[Row]) -> None:
    """Write dict rows to ``path`` as a CSV file (UTF-8 with BOM).

    The header is the union of all row keys in first-seen order. Rows with
    extra or missing keys are therefore written instead of making
    ``csv.DictWriter`` raise, and missing cells are left empty.

    When ``rows`` is empty no CSV is written. Any file already at ``path`` is
    removed, so a result from an earlier run is not mistaken for the current
    one.

    Args:
        path: Destination CSV file.
        rows: Records to write, one dict per CSV row.
    """
    target = Path(path)
    if not rows:
        # Nothing to write. Drop a stale file from a previous run rather than
        # leaving outdated samples next to fresh outputs.
        target.unlink(missing_ok=True)
        return

    # dict.fromkeys de-duplicates keys while keeping first-seen order.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with open(target, "w", encoding="utf-8-sig", newline="") as file:
        writer = csv.DictWriter(file, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_thresholds(
    unc: np.ndarray,
    low_unc_quantile: float,
    high_unc_quantile: float,
) -> UncertaintyThresholds:
    """Derive the low/high uncertainty cut-offs from the uncertainty distribution.

    Args:
        unc: Per-sample uncertainty values.
        low_unc_quantile: Percentile (0-100) that becomes ``low``.
        high_unc_quantile: Percentile (0-100) that becomes ``high``.
    """
    low_cut, high_cut = (
        float(np.percentile(unc, pct)) for pct in (low_unc_quantile, high_unc_quantile)
    )
    return UncertaintyThresholds(low=low_cut, high=high_cut)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_retention_curve(metrics: Metrics, quantiles: list[float]) -> list[OutputRow]:
    """Tabulate error statistics after discarding the most uncertain samples.

    For each removal percentage ``p`` in ``quantiles``, samples whose
    uncertainty exceeds the ``(100 - p)``-th percentile of ``metrics.unc`` are
    dropped, and error statistics are computed on the samples that remain.
    With tied uncertainty values more than ``100 - p`` percent may be kept;
    ``retained_pct`` and ``n_retained`` report the actual share.

    Args:
        metrics: Per-sample error and uncertainty arrays.
        quantiles: Removal percentages, expected in ``[0, 100]``.

    Returns:
        One dict per removal percentage, in the order of ``quantiles``.
    """
    points: list[OutputRow] = []
    for pct in quantiles:
        cutoff = float(np.percentile(metrics.unc, 100.0 - pct))
        mask = metrics.unc <= cutoff
        rmse_kept, mae_kept, r2_kept = (
            column[mask] for column in (metrics.rmse, metrics.mae, metrics.r2)
        )

        point: OutputRow = {}
        point["removed_pct"] = float(pct)
        point["retained_pct"] = float(100.0 * np.mean(mask))
        point["n_retained"] = int(np.sum(mask))
        point["rmse_mean"] = safe_mean(rmse_kept)
        point["rmse_median"] = safe_median(rmse_kept)
        point["rmse_p90"] = safe_percentile(rmse_kept, 90)
        point["mae_mean"] = safe_mean(mae_kept)
        point["mae_median"] = safe_median(mae_kept)
        point["r2_mean"] = safe_mean(r2_kept)
        point["r2_median"] = safe_median(r2_kept)
        point["unc_threshold"] = cutoff
        points.append(point)

    return points
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def select_top_uncertain_rows(rows: list[Row], unc: np.ndarray, top_k: int) -> list[Row]:
    """Return up to ``top_k`` rows with the largest uncertainty, largest first.

    Args:
        rows: CSV rows aligned index-by-index with ``unc``.
        unc: Per-sample uncertainty values.
        top_k: Maximum number of rows to return. Values <= 0 yield an empty
            list; values above ``len(rows)`` return every row.

    Returns:
        The selected rows in descending uncertainty order. Ties keep their
        original row order.
    """
    # Clamp to [0, len(rows)]. A negative bound would otherwise become a
    # negative slice stop and silently return almost every row.
    limit = max(0, min(int(top_k), len(rows)))
    # A stable sort of the negated values gives descending order with ties
    # broken by original position, so the export is reproducible.
    order = np.argsort(-np.asarray(unc), kind="stable")[:limit]
    return [rows[int(index)] for index in order]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def classify_risky_rows(
    rows: list[Row],
    thresholds: UncertaintyThresholds,
    high_error_rmse: float,
) -> tuple[list[Row], list[Row]]:
    """Split high-error rows by whether their uncertainty flagged them.

    A row is "high error" when its ``overall_rmse`` is at least
    ``high_error_rmse``. High-error rows whose ``unc_mean_std`` is at least
    ``thresholds.high`` go into the first list. Those whose ``unc_mean_std``
    is at most ``thresholds.low`` go into the second list. A row can appear in
    both lists only if ``thresholds.low >= thresholds.high``.

    Returns:
        ``(high_error_high_unc, high_error_low_unc)``, each sorted by
        ``overall_rmse`` in descending order (ties keep input order).
    """
    # Parse both numeric fields of every row up front, in row order.
    scored = [
        (row, float(row["overall_rmse"]), float(row["unc_mean_std"]))
        for row in rows
    ]
    high_error = [entry for entry in scored if entry[1] >= high_error_rmse]

    flagged = [entry for entry in high_error if entry[2] >= thresholds.high]
    missed = [entry for entry in high_error if entry[2] <= thresholds.low]

    def worst_first(entries):
        # Stable sort on the parsed RMSE, largest first.
        ranked = sorted(entries, key=lambda entry: entry[1], reverse=True)
        return [entry[0] for entry in ranked]

    return worst_first(flagged), worst_first(missed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_risk_rows(
    rows: list[Row],
    metrics: Metrics,
    thresholds: UncertaintyThresholds,
    high_error_rmse: float,
    top_k: int,
) -> RiskRows:
    """Assemble every risk-sample group that gets exported.

    Groups: the ``top_k`` most uncertain rows, high-error rows with high
    uncertainty, and high-error rows with low uncertainty.
    """
    hits, misses = classify_risky_rows(rows, thresholds, high_error_rmse)
    most_uncertain = select_top_uncertain_rows(rows, metrics.unc, top_k)
    return RiskRows(most_uncertain, hits, misses)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def build_summary(
    input_csv: Path,
    metrics: Metrics,
    thresholds: UncertaintyThresholds,
    risk_rows: RiskRows,
    args: argparse.Namespace,
) -> dict[str, Any]:
    """Collect global statistics, thresholds and group sizes for the JSON report.

    Sections, in output order: ``input_csv``, ``n_samples``, ``global``
    (error/uncertainty aggregates), ``thresholds`` (CLI settings plus derived
    cut-offs) and ``counts`` (sizes of the exported risk groups).
    """
    summary: dict[str, Any] = {
        "input_csv": str(input_csv),
        "n_samples": int(metrics.rmse.size),
    }

    summary["global"] = {
        "rmse_mean": safe_mean(metrics.rmse),
        "rmse_median": safe_median(metrics.rmse),
        "rmse_p90": safe_percentile(metrics.rmse, 90),
        "mae_mean": safe_mean(metrics.mae),
        "r2_mean": safe_mean(metrics.r2),
        "unc_mean": safe_mean(metrics.unc),
        "unc_median": safe_median(metrics.unc),
        "unc_p90": safe_percentile(metrics.unc, 90),
    }

    summary["thresholds"] = {
        "high_error_rmse": float(args.high_error_rmse),
        "low_unc_quantile": float(args.low_unc_quantile),
        "low_unc_threshold": thresholds.low,
        "high_unc_quantile": float(args.high_unc_quantile),
        "high_unc_threshold": thresholds.high,
    }

    summary["counts"] = {
        "top_uncertain_exported": len(risk_rows.top_uncertain),
        "high_error_high_unc": len(risk_rows.high_error_high_unc),
        "high_error_low_unc": len(risk_rows.high_error_low_unc),
    }

    return summary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_retention_curve(curve_rows: list[OutputRow], output_path: Path) -> None:
    """Plot mean RMSE, MAE and R2 against the share of samples retained.

    Each point is one row of ``curve_rows``. The figure has three side-by-side
    panels, and the title lists the removal grid.

    Args:
        curve_rows: Retention-curve rows. Each must provide ``removed_pct``,
            ``retained_pct``, ``rmse_mean``, ``mae_mean`` and ``r2_mean``.
        output_path: Destination image file.
    """

    def column(key: str) -> np.ndarray:
        # One curve column as a float64 array, in row order.
        return np.array([float(row[key]) for row in curve_rows], dtype=np.float64)

    removed = column("removed_pct")
    retained = column("retained_pct")
    # (curve column, label used in the panel title and y-axis label)
    panels = (
        ("rmse_mean", "RMSE"),
        ("mae_mean", "MAE"),
        ("r2_mean", "R2"),
    )

    fig, axes = plt.subplots(1, len(panels), figsize=(15, 4.5))
    try:
        for ax, (key, label) in zip(axes, panels):
            ax.plot(retained, column(key), marker="o")
            ax.set_title(f"Retained vs {label}")
            ax.set_xlabel("Retained Samples (%)")
            ax.set_ylabel(f"{label} mean")
            ax.grid(True, alpha=0.3)

        # ``:g`` keeps fractional removal percentages (e.g. 2.5), which int()
        # used to truncate, while whole numbers still print as e.g. "10".
        removed_grid = ",".join(f"{value:g}" for value in removed)
        fig.suptitle(f"Uncertainty Filtering Curves | removed grid={removed_grid}%")

        fig.tight_layout(rect=(0, 0, 1, 0.94))
        fig.savefig(output_path, dpi=150, bbox_inches="tight")
    finally:
        # Release this specific figure even if layout or saving fails.
        plt.close(fig)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_outputs(
    output_dir: Path,
    curve_rows: list[OutputRow],
    risk_rows: RiskRows,
    summary: dict[str, Any],
) -> None:
    """Write the JSON summary, the CSV tables and the filter-curve plot.

    Files are written into ``output_dir`` in this order: the summary JSON,
    the four CSV tables, then the PNG plot.
    """
    summary_path = output_dir / "uq_fallback_summary.json"
    with open(summary_path, "w", encoding="utf-8") as handle:
        json.dump(summary, handle, ensure_ascii=False, indent=2)

    csv_outputs = (
        ("uncertainty_filter_curve.csv", curve_rows),
        ("top_uncertain_samples.csv", risk_rows.top_uncertain),
        ("high_error_high_unc_samples.csv", risk_rows.high_error_high_unc),
        ("high_error_low_unc_samples.csv", risk_rows.high_error_low_unc),
    )
    for file_name, table in csv_outputs:
        save_csv(output_dir / file_name, table)

    plot_retention_curve(curve_rows, output_dir / "uncertainty_filter_curve.png")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_summary(output_dir: Path, summary: dict[str, Any]) -> None:
    """Print a short completion report to stdout.

    The report shows the output directory, the global mean RMSE and
    uncertainty, and the sizes of the two high-error groups.
    """
    print("UQ fallback analysis complete.", f"Output dir: {output_dir}", sep="\n")

    stats = summary["global"]
    rmse_text = f"{stats['rmse_mean']:.6f}"
    unc_text = f"{stats['unc_mean']:.6f}"
    print(f"Global RMSE mean={rmse_text}, unc mean={unc_text}")

    tallies = summary["counts"]
    flagged = tallies["high_error_high_unc"]
    missed = tallies["high_error_low_unc"]
    print(f"High-error & high-unc count={flagged}, high-error & low-unc count={missed}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Relate ensemble uncertainty to true error and write the analysis artifacts.

    Steps: parse the CLI and resolve paths, then validate the removal grid and
    load the per-sample CSV. Next, compute thresholds, the retention curve,
    risk groups and the summary. Finally, write everything to the output
    directory and print a short report.
    """
    args = parse_args()
    input_csv, output_dir = resolve_paths(args)

    # Validate the cheap CLI input before reading a possibly large CSV, so a
    # malformed --quantiles fails fast.
    quantiles = parse_quantiles(args.quantiles)

    rows = load_rows(input_csv)
    metrics = extract_metrics(rows)
    thresholds = build_thresholds(
        unc=metrics.unc,
        low_unc_quantile=float(args.low_unc_quantile),
        high_unc_quantile=float(args.high_unc_quantile),
    )
    curve_rows = build_retention_curve(metrics, quantiles)
    risk_rows = build_risk_rows(
        rows=rows,
        metrics=metrics,
        thresholds=thresholds,
        high_error_rmse=float(args.high_error_rmse),
        top_k=int(args.top_k),
    )
    summary = build_summary(
        input_csv=input_csv,
        metrics=metrics,
        thresholds=thresholds,
        risk_rows=risk_rows,
        args=args,
    )

    # Create the output directory only once there is something to write, so a
    # failed run (missing input, empty CSV, bad arguments) leaves no empty
    # directory behind.
    output_dir.mkdir(parents=True, exist_ok=True)
    save_outputs(output_dir, curve_rows, risk_rows, summary)
    print_summary(output_dir, summary)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|