You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
150 lines
5.9 KiB
Python
150 lines
5.9 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import joblib
|
|
import numpy as np
|
|
|
|
# Absolute path of this script, with symlinks resolved.
_SCRIPT_PATH = Path(__file__).resolve()
# Repository root: the parent of the directory that contains this script.
ROOT = _SCRIPT_PATH.parents[1]
# Put the repository root on the import path so the ``src`` package imported
# below resolves when the script is executed directly.
sys.path.append(str(ROOT))
|
|
|
|
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command line for the UQ/metadata join.

    ``--processed``, ``--uq-csv`` and ``--output-dir`` are ``None`` when not
    given; ``main`` then derives them from ``--tag``. ``--high-error-rmse`` is
    the per-sample RMSE at or above which ``main`` treats a sample as high error.
    """
    cli = argparse.ArgumentParser(description="Join UQ sample metrics with saved metadata")
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_table = (
        ("--processed", {"type": str, "default": None, "help": "Processed dataset path"}),
        ("--uq-csv", {"type": str, "default": None, "help": "sample_uncertainty_metrics.csv path"}),
        ("--tag", {"type": str, "default": "family_random_50k", "help": "Experiment tag"}),
        ("--output-dir", {"type": str, "default": None, "help": "Output directory"}),
        ("--high-error-rmse", {"type": float, "default": 1.0}),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
|
|
def default_uq_csv(tag: str) -> Path:
    """Return the conventional per-sample UQ metrics CSV location for *tag*."""
    run_dir = f"evaluation_{tag}_ensemble_uq"
    return Path("results", run_dir, "sample_uncertainty_metrics.csv")
|
|
|
|
|
|
def default_output_dir(tag: str) -> Path:
    """Return the default directory, under ``results/``, for this script's outputs."""
    folder_name = f"evaluation_{tag}_ensemble_uq_metadata_analysis"
    return Path("results").joinpath(folder_name)
|
|
|
|
|
|
def save_csv(path: Path, rows: list[dict]) -> None:
    """Write *rows* to *path* as CSV encoded ``utf-8-sig`` (UTF-8 with a BOM).

    The header is the union of every row's keys in first-seen order, so rows
    that carry extra columns are written instead of making ``csv.DictWriter``
    raise ``ValueError``; cells a row lacks are written as empty strings.

    When *rows* is empty nothing is written: no file is created, and an
    existing file at *path* is left as it is.
    """
    if not rows:
        return
    # dict.fromkeys dedupes while keeping insertion order -> ordered key union.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with open(path, "w", encoding="utf-8-sig", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames, restval="")
        writer.writeheader()
        writer.writerows(rows)
|
|
|
|
|
|
def summarize_group(rows: list[dict], group_key: str) -> list[dict]:
    """Aggregate error/uncertainty statistics per group.

    Rows are bucketed by ``str(row[group_key])``. For each bucket the result
    holds the group label (under *group_key*), the sample count, and the mean
    and median of the ``overall_rmse`` and ``unc_mean_std`` columns (parsed
    with ``float``). Buckets are returned largest first; equally sized
    buckets keep the order in which their label was first seen.
    """
    buckets: dict[str, list[dict]] = {}
    for record in rows:
        label = str(record[group_key])
        if label not in buckets:
            buckets[label] = []
        buckets[label].append(record)

    def mean_and_median(members: list[dict], column: str) -> tuple[float, float]:
        # Parse one numeric column of the bucket and reduce it to (mean, median).
        values = np.array([float(member[column]) for member in members], dtype=np.float64)
        return float(np.mean(values)), float(np.median(values))

    # sorted() is stable, so ties stay in first-seen order.
    largest_first = sorted(buckets.items(), key=lambda pair: len(pair[1]), reverse=True)

    summaries: list[dict] = []
    for label, members in largest_first:
        rmse_mean, rmse_median = mean_and_median(members, "overall_rmse")
        unc_mean, unc_median = mean_and_median(members, "unc_mean_std")
        summaries.append(
            {
                group_key: label,
                "n_samples": len(members),
                "rmse_mean": rmse_mean,
                "rmse_median": rmse_median,
                "unc_mean": unc_mean,
                "unc_median": unc_median,
            }
        )
    return summaries
|
|
|
|
|
|
def main() -> None:
    """Join per-sample UQ metrics with processed test-split metadata and summarize.

    Steps:
      1. Resolve the processed dataset, UQ CSV and output directory from the CLI,
         falling back to tag-based defaults.
      2. Load the processed dataset and read the test-split metadata arrays.
      3. For each UQ CSV row, attach the family name, source name/id (when
         present) and every named schedule-metadata column, looked up by the
         row's ``idx``.
      4. Write the joined rows, per-family / per-source / per-``n_prod``
         summaries, and the rows with high RMSE but at-or-below-median
         uncertainty.
      5. Write ``metadata_join_summary.json`` describing the run.

    Raises:
        RuntimeError: the processed dataset lacks schedule metadata.
        IndexError: a UQ row's ``idx`` does not address a sample of the test split.
    """
    args = parse_args()
    tag = normalize_tag(args.tag)
    fallback_tag = tag or "family_random_50k"
    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    uq_csv = Path(args.uq_csv) if args.uq_csv is not None else default_uq_csv(fallback_tag)
    output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(fallback_tag)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): joblib.load unpickles the file and can execute arbitrary
    # code; only point --processed at trusted, locally produced datasets.
    data = joblib.load(processed_path)
    if "schedule_meta_test" not in data or "family_name_test" not in data:
        raise RuntimeError(
            "Processed dataset does not contain schedule metadata. Re-run preprocess on a metadata-rich raw HDF5 first."
        )

    schedule_meta_test = np.asarray(data["schedule_meta_test"], dtype=np.float32)
    family_name_test = np.asarray(data["family_name_test"]).astype(str)
    source_name_test = np.asarray(data["source_name_test"]).astype(str) if "source_name_test" in data else None
    source_id_test = np.asarray(data["source_id_test"]) if "source_id_test" in data else None

    # Normalize to a plain list of str. This accepts lists, tuples and NumPy
    # arrays alike (an array's truth value is ambiguous, so ``x or []`` would
    # raise) and keeps the JSON summary below serializable.
    raw_meta_names = data["meta"].get("schedule_meta_names")
    meta_names = [] if raw_meta_names is None else [str(name) for name in raw_meta_names]
    name_to_col = {name: i for i, name in enumerate(meta_names)}

    # All test-split arrays are indexed by the same ``idx``; the shortest one
    # bounds the indices that are valid for every lookup below.
    n_test = min(
        len(arr)
        for arr in (schedule_meta_test, family_name_test, source_name_test, source_id_test)
        if arr is not None
    )

    with open(uq_csv, "r", encoding="utf-8-sig", newline="") as f:
        uq_rows = list(csv.DictReader(f))

    enriched_rows = []
    for row in uq_rows:
        idx = int(row["idx"])
        # Reject negative indices explicitly: NumPy would wrap them around and
        # silently attach another sample's metadata to this row.
        if not 0 <= idx < n_test:
            raise IndexError(
                f"UQ row idx={idx} in {uq_csv} is outside the processed test split "
                f"(size {n_test}); were both files produced for the same dataset?"
            )
        meta_row = schedule_meta_test[idx]
        enriched = dict(row)
        enriched["family_name"] = family_name_test[idx]
        if source_name_test is not None:
            enriched["source_name"] = source_name_test[idx]
        if source_id_test is not None:
            enriched["source_id"] = int(source_id_test[idx])
        for name, col in name_to_col.items():
            enriched[name] = float(meta_row[col])
        enriched_rows.append(enriched)

    save_csv(output_dir / "uq_samples_with_metadata.csv", enriched_rows)

    family_summary = summarize_group(enriched_rows, "family_name")
    save_csv(output_dir / "summary_by_family.csv", family_summary)

    if source_name_test is not None:
        source_summary = summarize_group(enriched_rows, "source_name")
        save_csv(output_dir / "summary_by_source.csv", source_summary)

    if "n_prod" in name_to_col:
        # Adds ``n_prod_group`` to the rows in place, after the joined CSV has
        # been written, so only later outputs (e.g. the high-error CSV) carry it.
        for row in enriched_rows:
            row["n_prod_group"] = int(round(float(row["n_prod"])))
        n_prod_summary = summarize_group(enriched_rows, "n_prod_group")
        save_csv(output_dir / "summary_by_n_prod.csv", n_prod_summary)

    # Samples at or below the median uncertainty count as "low uncertainty".
    # With no samples the median is undefined, and the selection stays empty.
    unc_values = np.array([float(row["unc_mean_std"]) for row in enriched_rows], dtype=np.float64)
    low_unc_threshold = float(np.median(unc_values)) if unc_values.size else None
    high_error_rmse = float(args.high_error_rmse)

    high_error_low_unc = []
    if low_unc_threshold is not None:
        high_error_low_unc = [
            row
            for row in enriched_rows
            if float(row["overall_rmse"]) >= high_error_rmse
            and float(row["unc_mean_std"]) <= low_unc_threshold
        ]
    high_error_low_unc.sort(key=lambda row: float(row["overall_rmse"]), reverse=True)
    save_csv(output_dir / "high_error_low_unc_with_metadata.csv", high_error_low_unc)

    summary = {
        "processed_path": str(processed_path),
        "uq_csv": str(uq_csv),
        "n_samples": len(enriched_rows),
        "meta_names": meta_names,
        "has_source_name": bool(source_name_test is not None),
        # Thresholds used for the high-error / low-uncertainty selection.
        "high_error_rmse": high_error_rmse,
        "low_unc_threshold": low_unc_threshold,
        "high_error_low_unc_count": len(high_error_low_unc),
    }
    with open(output_dir / "metadata_join_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print("UQ metadata join analysis complete.")
    print(f"Output dir: {output_dir}")
    print(f"High-error low-unc count={len(high_error_low_unc)}")


if __name__ == "__main__":
    main()
|