from __future__ import annotations import argparse import csv import json import sys from pathlib import Path import joblib import numpy as np ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import normalize_tag, processed_path_for_tag def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Join UQ sample metrics with saved metadata") parser.add_argument("--processed", type=str, default=None, help="Processed dataset path") parser.add_argument("--uq-csv", type=str, default=None, help="sample_uncertainty_metrics.csv path") parser.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag") parser.add_argument("--output-dir", type=str, default=None, help="Output directory") parser.add_argument("--high-error-rmse", type=float, default=1.0) return parser.parse_args() def default_uq_csv(tag: str) -> Path: return Path("results") / f"evaluation_{tag}_ensemble_uq" / "sample_uncertainty_metrics.csv" def default_output_dir(tag: str) -> Path: return Path("results") / f"evaluation_{tag}_ensemble_uq_metadata_analysis" def save_csv(path: Path, rows: list[dict]) -> None: if not rows: return with open(path, "w", encoding="utf-8-sig", newline="") as f: writer = csv.DictWriter(f, fieldnames=list(rows[0].keys())) writer.writeheader() writer.writerows(rows) def summarize_group(rows: list[dict], group_key: str) -> list[dict]: groups: dict[str, list[dict]] = {} for row in rows: groups.setdefault(str(row[group_key]), []).append(row) out = [] for key, group_rows in sorted(groups.items(), key=lambda item: len(item[1]), reverse=True): rmse = np.array([float(r["overall_rmse"]) for r in group_rows], dtype=np.float64) unc = np.array([float(r["unc_mean_std"]) for r in group_rows], dtype=np.float64) out.append( { group_key: key, "n_samples": len(group_rows), "rmse_mean": float(np.mean(rmse)), "rmse_median": float(np.median(rmse)), "unc_mean": float(np.mean(unc)), "unc_median": float(np.median(unc)), } ) return out def main() -> None: args = parse_args() tag = normalize_tag(args.tag) processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) uq_csv = Path(args.uq_csv) if args.uq_csv is not None else default_uq_csv(tag or "family_random_50k") output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag or "family_random_50k") output_dir.mkdir(parents=True, exist_ok=True) data = joblib.load(processed_path) if "schedule_meta_test" not in data or "family_name_test" not in data: raise RuntimeError( "Processed dataset does not contain schedule metadata. Re-run preprocess on a metadata-rich raw HDF5 first." ) schedule_meta_test = np.asarray(data["schedule_meta_test"], dtype=np.float32) family_name_test = np.asarray(data["family_name_test"]).astype(str) source_name_test = np.asarray(data["source_name_test"]).astype(str) if "source_name_test" in data else None source_id_test = np.asarray(data["source_id_test"]) if "source_id_test" in data else None meta_names = data["meta"].get("schedule_meta_names") or [] name_to_col = {name: i for i, name in enumerate(meta_names)} with open(uq_csv, "r", encoding="utf-8-sig", newline="") as f: uq_rows = list(csv.DictReader(f)) enriched_rows = [] for row in uq_rows: idx = int(row["idx"]) meta_row = schedule_meta_test[idx] enriched = dict(row) enriched["family_name"] = family_name_test[idx] if source_name_test is not None: enriched["source_name"] = source_name_test[idx] if source_id_test is not None: enriched["source_id"] = int(source_id_test[idx]) for name, col in name_to_col.items(): enriched[name] = float(meta_row[col]) enriched_rows.append(enriched) save_csv(output_dir / "uq_samples_with_metadata.csv", enriched_rows) family_summary = summarize_group(enriched_rows, "family_name") save_csv(output_dir / "summary_by_family.csv", family_summary) if source_name_test is not None: source_summary = summarize_group(enriched_rows, "source_name") save_csv(output_dir / "summary_by_source.csv", source_summary) if "n_prod" in name_to_col: for row in enriched_rows: row["n_prod_group"] = int(round(float(row["n_prod"]))) n_prod_summary = summarize_group(enriched_rows, "n_prod_group") save_csv(output_dir / "summary_by_n_prod.csv", n_prod_summary) unc_values = np.array([float(row["unc_mean_std"]) for row in enriched_rows], dtype=np.float64) low_unc_threshold = float(np.median(unc_values)) high_error_low_unc = [ row for row in enriched_rows if float(row["overall_rmse"]) >= float(args.high_error_rmse) and float(row["unc_mean_std"]) <= low_unc_threshold ] high_error_low_unc.sort(key=lambda row: float(row["overall_rmse"]), reverse=True) save_csv(output_dir / "high_error_low_unc_with_metadata.csv", high_error_low_unc) summary = { "processed_path": str(processed_path), "uq_csv": str(uq_csv), "n_samples": len(enriched_rows), "meta_names": meta_names, "has_source_name": bool(source_name_test is not None), "high_error_low_unc_count": len(high_error_low_unc), } with open(output_dir / "metadata_join_summary.json", "w", encoding="utf-8") as f: json.dump(summary, f, ensure_ascii=False, indent=2) print("UQ metadata join analysis complete.") print(f"Output dir: {output_dir}") print(f"High-error low-unc count={len(high_error_low_unc)}") if __name__ == "__main__": main()