"""结合原始样本元数据分析不确定性评估结果。 在 `analyze_uq_results.py` 的全局统计之外,本脚本会把 UQ CSV 与预处理数据中的 参数、流量制度等元数据对齐,按制度族、参数区间等维度分组,帮助定位代理模型 在哪类数值试井样本上更容易给出高误差或高不确定性。 """ from __future__ import annotations import argparse import csv import json import sys from pathlib import Path import joblib import numpy as np ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import normalize_tag, processed_path_for_tag def parse_args() -> argparse.Namespace: """解析 UQ 指标 CSV、processed 数据和输出路径,用于合并样本元数据分析。""" parser = argparse.ArgumentParser(description="Join UQ sample metrics with saved metadata") parser.add_argument("--processed", type=str, default=None, help="Processed dataset path") parser.add_argument("--uq-csv", type=str, default=None, help="sample_uncertainty_metrics.csv path") parser.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag") parser.add_argument("--output-dir", type=str, default=None, help="Output directory") parser.add_argument("--high-error-rmse", type=float, default=1.0) return parser.parse_args() def default_uq_csv(tag: str) -> Path: """根据实验标签定位默认的不确定性指标 CSV。""" return Path("results") / f"evaluation_{tag}_ensemble_uq" / "sample_uncertainty_metrics.csv" def default_output_dir(tag: str) -> Path: """根据实验标签生成当前分析脚本默认的输出目录。""" return Path("results") / f"evaluation_{tag}_ensemble_uq_metadata_analysis" def save_csv(path: Path, rows: list[dict]) -> None: """把字典行写入 CSV;当没有行时仍写出表头,方便后续脚本读取。""" if not rows: return with open(path, "w", encoding="utf-8-sig", newline="") as f: writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) writer.writeheader() writer.writerows(rows) def summarize_group(rows: list[dict], group_key: str) -> list[dict]: """对一个样本分组计算误差、不确定性和样本数量等汇总指标。""" groups: dict[str, list[dict]] = {} for row in rows: groups.setdefault(str(row[group_key]), []).append(row) out = [] for key, group_rows in sorted(groups.items(), key=lambda item: len(item[1]), reverse=True): rmse = np.array([float(r["overall_rmse"]) for r in group_rows], dtype=np.float64) unc = np.array([float(r["unc_mean_std"]) for r in group_rows], dtype=np.float64) out.append( { group_key: key, "n_samples": len(group_rows), "rmse_mean": float(np.mean(rmse)), "rmse_median": float(np.median(rmse)), "unc_mean": float(np.mean(unc)), "unc_median": float(np.median(unc)), } ) return out def main() -> None: """把 UQ 结果与参数/制度元数据拼接,按工况区域统计误差和不确定性。""" args = parse_args() tag = normalize_tag(args.tag) processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) uq_csv = Path(args.uq_csv) if args.uq_csv is not None else default_uq_csv(tag or "family_random_50k") output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag or "family_random_50k") output_dir.mkdir(parents=True, exist_ok=True) data = joblib.load(processed_path) if "schedule_meta_test" not in data or "family_name_test" not in data: raise RuntimeError( "Processed dataset does not contain schedule metadata. Re-run preprocess on a metadata-rich raw HDF5 first." ) schedule_meta_test = np.asarray(data["schedule_meta_test"], dtype=np.float32) family_name_test = np.asarray(data["family_name_test"]).astype(str) # source 字段只在混合数据集里存在;普通数据集没有时保持可选。 source_name_test = np.asarray(data["source_name_test"]).astype(str) if "source_name_test" in data else None source_id_test = np.asarray(data["source_id_test"]) if "source_id_test" in data else None meta_names = data["meta"].get("schedule_meta_names") or [] name_to_col = {name: i for i, name in enumerate(meta_names)} with open(uq_csv, "r", encoding="utf-8-sig", newline="") as f: uq_rows = list(csv.DictReader(f)) enriched_rows = [] for row in uq_rows: idx = int(row["idx"]) meta_row = schedule_meta_test[idx] enriched = dict(row) # UQ CSV 的 idx 对应 processed 测试集下标,因此可直接取测试集元数据补列。 enriched["family_name"] = family_name_test[idx] if source_name_test is not None: enriched["source_name"] = source_name_test[idx] if source_id_test is not None: enriched["source_id"] = int(source_id_test[idx]) for name, col in name_to_col.items(): enriched[name] = float(meta_row[col]) enriched_rows.append(enriched) save_csv(output_dir / "uq_samples_with_metadata.csv", enriched_rows) family_summary = summarize_group(enriched_rows, "family_name") save_csv(output_dir / "summary_by_family.csv", family_summary) if source_name_test is not None: source_summary = summarize_group(enriched_rows, "source_name") save_csv(output_dir / "summary_by_source.csv", source_summary) if "n_prod" in name_to_col: for row in enriched_rows: row["n_prod_group"] = int(round(float(row["n_prod"]))) # 生产段数常对应不同制度复杂度,单独分组有助于定位 UQ 失效区域。 n_prod_summary = summarize_group(enriched_rows, "n_prod_group") save_csv(output_dir / "summary_by_n_prod.csv", n_prod_summary) unc_values = np.array([float(row["unc_mean_std"]) for row in enriched_rows], dtype=np.float64) low_unc_threshold = float(np.median(unc_values)) # 筛出“高误差但低不确定性”的样本,后续可作为困难样本补采样或诊断入口。 high_error_low_unc = [ row for row in enriched_rows if float(row["overall_rmse"]) >= float(args.high_error_rmse) and float(row["unc_mean_std"]) <= low_unc_threshold ] high_error_low_unc.sort(key=lambda row: float(row["overall_rmse"]), reverse=True) save_csv(output_dir / "high_error_low_unc_with_metadata.csv", high_error_low_unc) summary = { "processed_path": str(processed_path), "uq_csv": str(uq_csv), "n_samples": len(enriched_rows), "meta_names": meta_names, "has_source_name": bool(source_name_test is not None), "high_error_low_unc_count": len(high_error_low_unc), } with open(output_dir / "metadata_join_summary.json", "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) print("UQ metadata join analysis complete.") print(f"Output dir: {output_dir}") print(f"High-error low-unc count={len(high_error_low_unc)}") if __name__ == "__main__": main()