""" 将集成学习不确定性量化样本指标与处理后的数据集元数据进行关联。 该脚本为每个 UQ 指标行补充调度元数据、类别标签以及可选的源标签,然后输出分组汇总结果和困难样本。 """ from __future__ import annotations import argparse import csv import json import sys from dataclasses import dataclass from pathlib import Path from typing import Any import joblib import numpy as np ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) def load_experiment_path_helpers(): """Load project path helpers after adding project root to sys.path.""" # pylint: disable=import-error,import-outside-toplevel from src.common.experiment_paths import normalize_tag, processed_path_for_tag return normalize_tag, processed_path_for_tag normalize_tag_func, processed_path_func = load_experiment_path_helpers() @dataclass(frozen=True) class AnalysisPaths: """Resolved input and output paths used by the metadata analysis.""" tag: str processed_path: Path uq_csv: Path output_dir: Path @dataclass(frozen=True) class ProcessedMetadata: """Metadata arrays loaded from the processed dataset.""" schedule_meta_test: np.ndarray family_name_test: np.ndarray source_name_test: np.ndarray | None source_id_test: np.ndarray | None meta_names: list[str] name_to_col: dict[str, int] def parse_args() -> argparse.Namespace: """解析 UQ 指标 CSV、processed 数据和输出路径,用于合并样本元数据分析。""" parser = argparse.ArgumentParser( description="Join UQ sample metrics with saved metadata" ) parser.add_argument( "--processed", type=str, default=None, help="Processed dataset path", ) parser.add_argument( "--uq-csv", type=str, default=None, help="sample_uncertainty_metrics.csv path", ) parser.add_argument( "--tag", type=str, default="family_random_50k", help="Experiment tag", ) parser.add_argument( "--output-dir", type=str, default=None, help="Output directory", ) parser.add_argument("--high-error-rmse", type=float, default=1.0) return parser.parse_args() def default_uq_csv(tag: str) -> Path: """根据实验标签定位默认的不确定性指标 CSV。""" return ( Path("results") / f"evaluation_{tag}_ensemble_uq" / "sample_uncertainty_metrics.csv" ) def default_output_dir(tag: str) -> Path: """根据实验标签生成当前分析脚本默认的输出目录。""" return Path("results") / f"evaluation_{tag}_ensemble_uq_metadata_analysis" def resolve_paths(args: argparse.Namespace) -> AnalysisPaths: """根据命令行参数和实验标签解析输入输出路径。""" tag = normalize_tag_func(args.tag) fallback_tag = tag or "family_random_50k" processed_path = ( Path(args.processed) if args.processed is not None else processed_path_func(tag) ) uq_csv = ( Path(args.uq_csv) if args.uq_csv is not None else default_uq_csv(fallback_tag) ) output_dir = ( Path(args.output_dir) if args.output_dir is not None else default_output_dir(fallback_tag) ) return AnalysisPaths( tag=tag, processed_path=processed_path, uq_csv=uq_csv, output_dir=output_dir, ) def save_csv(path: Path, rows: list[dict[str, Any]]) -> None: """把字典行写入 CSV;当没有行时不写文件。""" if not rows: return with open(path, "w", encoding="utf-8-sig", newline="") as f: writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) writer.writeheader() writer.writerows(rows) def load_uq_rows(uq_csv: Path) -> list[dict[str, str]]: """读取 sample_uncertainty_metrics.csv。""" with open(uq_csv, "r", encoding="utf-8-sig", newline="") as f: rows = list(csv.DictReader(f)) if not rows: raise ValueError(f"UQ CSV 没有数据: {uq_csv}") return rows def load_processed_metadata(processed_path: Path) -> ProcessedMetadata: """读取 processed 数据中的测试集元数据。""" data = joblib.load(processed_path) required_keys = {"schedule_meta_test", "family_name_test"} if not required_keys.issubset(data): raise RuntimeError( "Processed dataset does not contain schedule metadata. " "Re-run preprocess on a metadata-rich raw HDF5 first." ) schedule_meta_test = np.asarray(data["schedule_meta_test"], dtype=np.float32) family_name_test = np.asarray(data["family_name_test"]).astype(str) source_name_test = None source_id_test = None if "source_name_test" in data: source_name_test = np.asarray(data["source_name_test"]).astype(str) if "source_id_test" in data: source_id_test = np.asarray(data["source_id_test"]) meta_names = data["meta"].get("schedule_meta_names") or [] name_to_col = {name: i for i, name in enumerate(meta_names)} return ProcessedMetadata( schedule_meta_test=schedule_meta_test, family_name_test=family_name_test, source_name_test=source_name_test, source_id_test=source_id_test, meta_names=meta_names, name_to_col=name_to_col, ) def enrich_uq_rows( uq_rows: list[dict[str, str]], metadata: ProcessedMetadata, ) -> list[dict[str, Any]]: """把 UQ 指标与 family/source/schedule metadata 拼接到同一行。""" enriched_rows: list[dict[str, Any]] = [] for row in uq_rows: idx = int(row["idx"]) meta_row = metadata.schedule_meta_test[idx] enriched: dict[str, Any] = dict(row) enriched["family_name"] = metadata.family_name_test[idx] if metadata.source_name_test is not None: enriched["source_name"] = metadata.source_name_test[idx] if metadata.source_id_test is not None: enriched["source_id"] = int(metadata.source_id_test[idx]) for name, col in metadata.name_to_col.items(): enriched[name] = float(meta_row[col]) enriched_rows.append(enriched) return enriched_rows def summarize_group( rows: list[dict[str, Any]], group_key: str, ) -> list[dict[str, Any]]: """对一个样本分组计算误差、不确定性和样本数量等汇总指标。""" groups: dict[str, list[dict[str, Any]]] = {} for row in rows: groups.setdefault(str(row[group_key]), []).append(row) summaries = [] sorted_groups = sorted( groups.items(), key=lambda item: len(item[1]), reverse=True, ) for key, group_rows in sorted_groups: rmse = np.array( [float(row["overall_rmse"]) for row in group_rows], dtype=np.float64, ) unc = np.array( [float(row["unc_mean_std"]) for row in group_rows], dtype=np.float64, ) summaries.append( { group_key: key, "n_samples": len(group_rows), "rmse_mean": float(np.mean(rmse)), "rmse_median": float(np.median(rmse)), "unc_mean": float(np.mean(unc)), "unc_median": float(np.median(unc)), } ) return summaries def save_group_summaries( output_dir: Path, enriched_rows: list[dict[str, Any]], metadata: ProcessedMetadata, ) -> None: """按 family/source/n_prod 等维度保存分组统计。""" save_csv( output_dir / "summary_by_family.csv", summarize_group(enriched_rows, "family_name"), ) if metadata.source_name_test is not None: save_csv( output_dir / "summary_by_source.csv", summarize_group(enriched_rows, "source_name"), ) if "n_prod" in metadata.name_to_col: for row in enriched_rows: row["n_prod_group"] = int(round(float(row["n_prod"]))) save_csv( output_dir / "summary_by_n_prod.csv", summarize_group(enriched_rows, "n_prod_group"), ) def find_high_error_low_unc( enriched_rows: list[dict[str, Any]], high_error_rmse: float, ) -> list[dict[str, Any]]: """筛出高误差但低不确定性的危险样本。""" unc_values = np.array( [float(row["unc_mean_std"]) for row in enriched_rows], dtype=np.float64, ) low_unc_threshold = float(np.median(unc_values)) risky_rows = [ row for row in enriched_rows if float(row["overall_rmse"]) >= high_error_rmse and float(row["unc_mean_std"]) <= low_unc_threshold ] risky_rows.sort(key=lambda row: float(row["overall_rmse"]), reverse=True) return risky_rows def write_json_summary( output_dir: Path, paths: AnalysisPaths, metadata: ProcessedMetadata, enriched_rows: list[dict[str, Any]], high_error_low_unc: list[dict[str, Any]], ) -> None: """保存 metadata join 的整体摘要信息。""" summary = { "tag": paths.tag, "processed_path": str(paths.processed_path), "uq_csv": str(paths.uq_csv), "n_samples": len(enriched_rows), "meta_names": metadata.meta_names, "has_source_name": bool(metadata.source_name_test is not None), "high_error_low_unc_count": len(high_error_low_unc), } with open(output_dir / "metadata_join_summary.json", "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) def run_analysis(args: argparse.Namespace) -> tuple[Path, int]: """执行完整 UQ metadata join 分析流程。""" paths = resolve_paths(args) paths.output_dir.mkdir(parents=True, exist_ok=True) metadata = load_processed_metadata(paths.processed_path) uq_rows = load_uq_rows(paths.uq_csv) enriched_rows = enrich_uq_rows(uq_rows, metadata) save_csv(paths.output_dir / "uq_samples_with_metadata.csv", enriched_rows) save_group_summaries(paths.output_dir, enriched_rows, metadata) high_error_low_unc = find_high_error_low_unc( enriched_rows=enriched_rows, high_error_rmse=float(args.high_error_rmse), ) save_csv( paths.output_dir / "high_error_low_unc_with_metadata.csv", high_error_low_unc, ) write_json_summary( output_dir=paths.output_dir, paths=paths, metadata=metadata, enriched_rows=enriched_rows, high_error_low_unc=high_error_low_unc, ) return paths.output_dir, len(high_error_low_unc) def main() -> None: """命令行入口:拼接 UQ 结果和元数据,并输出汇总文件。""" output_dir, risky_count = run_analysis(parse_args()) print("UQ metadata join analysis complete.") print(f"Output dir: {output_dir}") print(f"High-error low-unc count={risky_count}") if __name__ == "__main__": main()