You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/analyze_uq_with_metadata.py

372 lines
11 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""
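Join ensemble uncertainty-quantification (UQ) sample metrics with the metadata of the processed dataset.
The script enriches every UQ metric row with schedule metadata, the family label and, when available, source labels, then writes grouped summaries and the hard samples.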
"""
from __future__ import annotations
import argparse
import csv
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import joblib
import numpy as np
# Project root: one level above the directory that contains this script.
ROOT = Path(__file__).resolve().parents[1]

# Put the project root at the front of the import search path (only once) so
# that project-local packages can be imported by the code below.
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))
def load_experiment_path_helpers():
    """Import and return the project's tag and path helper functions.

    The import happens inside the function so that it only runs once the
    project root has been added to ``sys.path``.

    Returns:
        A ``(normalize_tag, processed_path_for_tag)`` tuple taken from
        ``src.common.experiment_paths``.
    """
    # pylint: disable=import-error,import-outside-toplevel
    from src.common.experiment_paths import (
        normalize_tag as tag_normalizer,
        processed_path_for_tag as processed_path_resolver,
    )

    return tag_normalizer, processed_path_resolver
# Bind the project helpers to module-level names; ``resolve_paths`` uses both.
normalize_tag_func, processed_path_func = load_experiment_path_helpers()
@dataclass(frozen=True)
class AnalysisPaths:
    """Immutable bundle of the locations one analysis run reads and writes.

    Attributes:
        tag: Experiment tag associated with the run.
        processed_path: Location of the processed dataset file.
        uq_csv: Location of the per-sample UQ metrics CSV.
        output_dir: Directory that receives the analysis outputs.
    """

    tag: str
    processed_path: Path
    uq_csv: Path
    output_dir: Path
@dataclass(frozen=True)
class ProcessedMetadata:
    """Test-split metadata read from the processed dataset.

    Attributes:
        schedule_meta_test: Schedule metadata values, indexed first by sample
            and then by metadata column.
        family_name_test: Family label of each sample.
        source_name_test: Source label of each sample, or ``None`` when the
            dataset does not provide it.
        source_id_test: Source identifier of each sample, or ``None`` when
            the dataset does not provide it.
        meta_names: Names of the schedule metadata columns.
        name_to_col: Column index in ``schedule_meta_test`` for each name.
    """

    schedule_meta_test: np.ndarray
    family_name_test: np.ndarray
    source_name_test: np.ndarray | None
    source_id_test: np.ndarray | None
    meta_names: list[str]
    name_to_col: dict[str, int]
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the UQ/metadata join analysis.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets
            ``argparse`` read ``sys.argv[1:]``, which is the previous
            behaviour; passing a list allows programmatic use and testing.

    Returns:
        Namespace with the attributes ``processed``, ``uq_csv``, ``tag``,
        ``output_dir`` and ``high_error_rmse``.
    """
    parser = argparse.ArgumentParser(
        description="Join UQ sample metrics with saved metadata"
    )
    parser.add_argument(
        "--processed",
        type=str,
        default=None,
        help="Processed dataset path",
    )
    parser.add_argument(
        "--uq-csv",
        type=str,
        default=None,
        help="sample_uncertainty_metrics.csv path",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="family_random_50k",
        help="Experiment tag",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory",
    )
    parser.add_argument(
        "--high-error-rmse",
        type=float,
        default=1.0,
        help=(
            "Minimum overall RMSE for a sample to count as high-error "
            "(default: %(default)s)"
        ),
    )
    return parser.parse_args(argv)
def default_uq_csv(tag: str) -> Path:
    """Return the default location of the UQ metrics CSV for ``tag``."""
    evaluation_dir = f"evaluation_{tag}_ensemble_uq"
    return Path("results", evaluation_dir, "sample_uncertainty_metrics.csv")
def default_output_dir(tag: str) -> Path:
    """Return the default output directory of this analysis for ``tag``."""
    folder_name = f"evaluation_{tag}_ensemble_uq_metadata_analysis"
    return Path("results").joinpath(folder_name)
def resolve_paths(args: argparse.Namespace) -> AnalysisPaths:
    """Turn the parsed command-line options into concrete paths.

    Explicit ``processed``, ``uq_csv`` and ``output_dir`` values are used as
    given; any of them left as ``None`` is derived from the normalized
    experiment tag.

    Args:
        args: Namespace providing ``tag``, ``processed``, ``uq_csv`` and
            ``output_dir``.

    Returns:
        The resolved paths together with the normalized tag.
    """
    tag = normalize_tag_func(args.tag)
    # The results-folder defaults use "family_random_50k" when the normalized
    # tag is falsy; the processed-path helper still receives the tag as is.
    fallback_tag = tag or "family_random_50k"

    if args.processed is None:
        processed_path = processed_path_func(tag)
    else:
        processed_path = Path(args.processed)

    if args.uq_csv is None:
        uq_csv = default_uq_csv(fallback_tag)
    else:
        uq_csv = Path(args.uq_csv)

    if args.output_dir is None:
        output_dir = default_output_dir(fallback_tag)
    else:
        output_dir = Path(args.output_dir)

    return AnalysisPaths(
        tag=tag,
        processed_path=processed_path,
        uq_csv=uq_csv,
        output_dir=output_dir,
    )
def save_csv(path: Path, rows: list[dict[str, Any]]) -> None:
    """Write dictionary rows to a CSV file encoded as UTF-8 with a BOM.

    The header is the union of the keys of all rows in first-seen order.
    A row carrying a key that the first row lacks therefore no longer makes
    ``csv.DictWriter`` raise ``ValueError``; cells for keys a row does not
    have are left empty. When all rows share the same keys the output is
    the same as with a header taken from the first row.

    Args:
        path: Destination file.
        rows: Rows to write. When empty, nothing is written and any file
            already present at ``path`` is left untouched.
    """
    if not rows:
        return
    # dict.fromkeys removes duplicates while keeping insertion order.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with open(path, "w", encoding="utf-8-sig", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
def load_uq_rows(uq_csv: Path) -> list[dict[str, str]]:
    """Read every data row of the UQ sample metrics CSV.

    Args:
        uq_csv: CSV file to read (a UTF-8 BOM, if present, is skipped).

    Returns:
        One dictionary per data row, keyed by the CSV header.

    Raises:
        ValueError: If the file contains no data rows.
    """
    records: list[dict[str, str]] = []
    with open(uq_csv, "r", encoding="utf-8-sig", newline="") as handle:
        records.extend(csv.DictReader(handle))
    if records:
        return records
    raise ValueError(f"UQ CSV 没有数据: {uq_csv}")
def load_processed_metadata(processed_path: Path) -> ProcessedMetadata:
    """Load the test-split metadata stored in the processed dataset.

    Args:
        processed_path: File readable by ``joblib.load``.

    Returns:
        The schedule metadata, family labels, optional source labels/ids and
        the schedule metadata column names with their column indices.

    Raises:
        RuntimeError: If ``schedule_meta_test`` or ``family_name_test`` is
            missing from the dataset.
    """
    # NOTE(security): joblib.load unpickles the file and can therefore run
    # arbitrary code; only load datasets from a trusted source.
    data = joblib.load(processed_path)
    required_keys = {"schedule_meta_test", "family_name_test"}
    if not required_keys.issubset(data):
        raise RuntimeError(
            "Processed dataset does not contain schedule metadata. "
            "Re-run preprocess on a metadata-rich raw HDF5 first."
        )
    schedule_meta_test = np.asarray(data["schedule_meta_test"], dtype=np.float32)
    family_name_test = np.asarray(data["family_name_test"]).astype(str)
    source_name_test = None
    source_id_test = None
    if "source_name_test" in data:
        source_name_test = np.asarray(data["source_name_test"]).astype(str)
    if "source_id_test" in data:
        source_id_test = np.asarray(data["source_id_test"])
    # "meta" is not one of the required keys, so a dataset without it yields
    # an empty name list instead of failing with a bare KeyError.
    meta = data["meta"] if "meta" in data else None
    raw_names = meta.get("schedule_meta_names") if meta is not None else None
    # Test against None rather than truthiness: names stored as a NumPy
    # array with several elements have no unambiguous truth value.
    meta_names = [] if raw_names is None else [str(name) for name in raw_names]
    name_to_col = {name: i for i, name in enumerate(meta_names)}
    return ProcessedMetadata(
        schedule_meta_test=schedule_meta_test,
        family_name_test=family_name_test,
        source_name_test=source_name_test,
        source_id_test=source_id_test,
        meta_names=meta_names,
        name_to_col=name_to_col,
    )
def enrich_uq_rows(
    uq_rows: list[dict[str, str]],
    metadata: ProcessedMetadata,
) -> list[dict[str, Any]]:
    """Attach family, source and schedule metadata to every UQ metric row.

    Each row's ``idx`` value selects the sample whose metadata is added.

    Args:
        uq_rows: UQ metric rows; every row must contain an ``idx`` entry.
        metadata: Test-split metadata to join on.

    Returns:
        New row dictionaries holding the original columns plus
        ``family_name``, ``source_name``/``source_id`` (when available) and
        one float column per schedule metadata name. The input rows are not
        modified.

    Raises:
        IndexError: If a row's ``idx`` lies outside the metadata arrays.
    """
    n_samples = len(metadata.schedule_meta_test)
    enriched_rows: list[dict[str, Any]] = []
    for row in uq_rows:
        idx = int(row["idx"])
        # A negative index would wrap around under NumPy indexing and join
        # the wrong sample's metadata without any error, so reject it here.
        if not 0 <= idx < n_samples:
            raise IndexError(
                f"UQ row idx {idx} is outside the metadata range "
                f"[0, {n_samples})"
            )
        meta_row = metadata.schedule_meta_test[idx]
        enriched: dict[str, Any] = dict(row)
        # Convert NumPy scalars to plain Python values, as is already done
        # for source_id and the schedule metadata columns.
        enriched["family_name"] = str(metadata.family_name_test[idx])
        if metadata.source_name_test is not None:
            enriched["source_name"] = str(metadata.source_name_test[idx])
        if metadata.source_id_test is not None:
            enriched["source_id"] = int(metadata.source_id_test[idx])
        for name, col in metadata.name_to_col.items():
            enriched[name] = float(meta_row[col])
        enriched_rows.append(enriched)
    return enriched_rows
def summarize_group(
    rows: list[dict[str, Any]],
    group_key: str,
) -> list[dict[str, Any]]:
    """Aggregate error and uncertainty statistics per value of ``group_key``.

    Rows are grouped by the string form of ``row[group_key]``. For each
    group the sample count and the mean/median of ``overall_rmse`` and
    ``unc_mean_std`` are reported.

    Args:
        rows: Rows containing ``group_key``, ``overall_rmse`` and
            ``unc_mean_std``.
        group_key: Name of the column to group on.

    Returns:
        One summary dictionary per group, largest group first.
    """
    buckets: dict[str, list[dict[str, Any]]] = {}
    for record in rows:
        label = str(record[group_key])
        if label not in buckets:
            buckets[label] = []
        buckets[label].append(record)

    def _column(members: list[dict[str, Any]], field: str) -> np.ndarray:
        # Collect one numeric column of a group as a float64 array.
        return np.array([float(m[field]) for m in members], dtype=np.float64)

    # Largest groups first; the sort is stable, so groups of equal size keep
    # the order in which they first appeared.
    ordered = sorted(
        buckets.items(),
        key=lambda pair: len(pair[1]),
        reverse=True,
    )

    summaries = []
    for label, members in ordered:
        rmse_values = _column(members, "overall_rmse")
        unc_values = _column(members, "unc_mean_std")
        entry = {
            group_key: label,
            "n_samples": len(members),
            "rmse_mean": float(np.mean(rmse_values)),
            "rmse_median": float(np.median(rmse_values)),
            "unc_mean": float(np.mean(unc_values)),
            "unc_median": float(np.median(unc_values)),
        }
        summaries.append(entry)
    return summaries
def save_group_summaries(
    output_dir: Path,
    enriched_rows: list[dict[str, Any]],
    metadata: ProcessedMetadata,
) -> None:
    """Write the per-group summary CSV files.

    ``summary_by_family.csv`` is always produced;
    ``summary_by_source.csv`` only when source names are available and
    ``summary_by_n_prod.csv`` only when ``n_prod`` is one of the schedule
    metadata columns.

    Args:
        output_dir: Existing directory that receives the CSV files.
        enriched_rows: Rows carrying the metrics and joined metadata. They
            are not modified.
        metadata: Metadata describing which optional columns exist.
    """
    save_csv(
        output_dir / "summary_by_family.csv",
        summarize_group(enriched_rows, "family_name"),
    )
    if metadata.source_name_test is not None:
        save_csv(
            output_dir / "summary_by_source.csv",
            summarize_group(enriched_rows, "source_name"),
        )
    if "n_prod" in metadata.name_to_col:
        # Group on separate lightweight rows rather than writing an
        # "n_prod_group" key into the caller's rows: that hidden side effect
        # leaked an extra column into every CSV saved afterwards.
        n_prod_rows = [
            {
                "n_prod_group": int(round(float(row["n_prod"]))),
                "overall_rmse": row["overall_rmse"],
                "unc_mean_std": row["unc_mean_std"],
            }
            for row in enriched_rows
        ]
        save_csv(
            output_dir / "summary_by_n_prod.csv",
            summarize_group(n_prod_rows, "n_prod_group"),
        )
def find_high_error_low_unc(
    enriched_rows: list[dict[str, Any]],
    high_error_rmse: float,
) -> list[dict[str, Any]]:
    """Select samples with high error but low predicted uncertainty.

    A row is selected when its ``overall_rmse`` is at least
    ``high_error_rmse`` and its ``unc_mean_std`` does not exceed the median
    ``unc_mean_std`` of all rows.

    Args:
        enriched_rows: Rows containing ``overall_rmse`` and ``unc_mean_std``.
        high_error_rmse: Lower bound on ``overall_rmse`` for a high-error row.

    Returns:
        The selected rows, ordered by decreasing ``overall_rmse``.
    """
    uncertainties = np.array(
        [float(record["unc_mean_std"]) for record in enriched_rows],
        dtype=np.float64,
    )
    low_unc_threshold = float(np.median(uncertainties))

    def _is_risky(record: dict[str, Any]) -> bool:
        # Written as "not >=" so that a NaN error never qualifies.
        if not float(record["overall_rmse"]) >= high_error_rmse:
            return False
        return float(record["unc_mean_std"]) <= low_unc_threshold

    selected = [record for record in enriched_rows if _is_risky(record)]
    return sorted(
        selected,
        key=lambda record: float(record["overall_rmse"]),
        reverse=True,
    )
def write_json_summary(
    output_dir: Path,
    paths: AnalysisPaths,
    metadata: ProcessedMetadata,
    enriched_rows: list[dict[str, Any]],
    high_error_low_unc: list[dict[str, Any]],
) -> None:
    """Write ``metadata_join_summary.json`` describing the finished join.

    Args:
        output_dir: Existing directory that receives the JSON file.
        paths: Resolved paths; the tag and both input paths are recorded.
        metadata: Supplies the metadata names and whether source names exist.
        enriched_rows: All joined rows; only their count is recorded.
        high_error_low_unc: The risky rows; only their count is recorded.
    """
    has_source_name = metadata.source_name_test is not None
    summary = {
        "tag": paths.tag,
        "processed_path": str(paths.processed_path),
        "uq_csv": str(paths.uq_csv),
        "n_samples": len(enriched_rows),
        "meta_names": metadata.meta_names,
        "has_source_name": has_source_name,
        "high_error_low_unc_count": len(high_error_low_unc),
    }
    target = output_dir / "metadata_join_summary.json"
    serialized = json.dumps(summary, ensure_ascii=False, indent=2)
    with open(target, "w", encoding="utf-8") as handle:
        handle.write(serialized)
def run_analysis(args: argparse.Namespace) -> tuple[Path, int]:
    """Run the complete UQ/metadata join and write all output files.

    Args:
        args: Parsed command-line options.

    Returns:
        The output directory and the number of high-error, low-uncertainty
        samples found.
    """
    paths = resolve_paths(args)
    out_dir = paths.output_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # Load both inputs, then join them row by row.
    metadata = load_processed_metadata(paths.processed_path)
    enriched_rows = enrich_uq_rows(load_uq_rows(paths.uq_csv), metadata)

    save_csv(out_dir / "uq_samples_with_metadata.csv", enriched_rows)
    save_group_summaries(out_dir, enriched_rows, metadata)

    risky_rows = find_high_error_low_unc(
        enriched_rows=enriched_rows,
        high_error_rmse=float(args.high_error_rmse),
    )
    save_csv(out_dir / "high_error_low_unc_with_metadata.csv", risky_rows)

    write_json_summary(
        output_dir=out_dir,
        paths=paths,
        metadata=metadata,
        enriched_rows=enriched_rows,
        high_error_low_unc=risky_rows,
    )
    return out_dir, len(risky_rows)
def main() -> None:
    """Command-line entry point: run the analysis and report the outcome."""
    parsed_args = parse_args()
    output_dir, risky_count = run_analysis(parsed_args)
    print(
        "UQ metadata join analysis complete.",
        f"Output dir: {output_dir}",
        f"High-error low-unc count={risky_count}",
        sep="\n",
    )


# Run the analysis only when the file is executed as a script.
if __name__ == "__main__":
    main()