You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/analyze_uq_with_metadata.py

167 lines
7.2 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""结合原始样本元数据分析不确定性评估结果。
在 `analyze_uq_results.py` 的全局统计之外,本脚本会把 UQ CSV 与预处理数据中的
参数、流量制度等元数据对齐,按制度族、参数区间等维度分组,帮助定位代理模型
在哪类数值试井样本上更容易给出高误差或高不确定性。
"""
from __future__ import annotations
import argparse
import csv
import json
import sys
from pathlib import Path
import joblib
import numpy as np
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the UQ/metadata join script.

    Returns:
        Namespace with ``processed``, ``uq_csv``, ``tag``, ``output_dir``
        (strings, ``None`` when not given except ``tag``) and
        ``high_error_rmse`` (float, default 1.0).
    """
    parser = argparse.ArgumentParser(description="Join UQ sample metrics with saved metadata")

    # (flag, default, help) for every string-valued option, in display order.
    string_options = (
        ("--processed", None, "Processed dataset path"),
        ("--uq-csv", None, "sample_uncertainty_metrics.csv path"),
        ("--tag", "family_random_50k", "Experiment tag"),
        ("--output-dir", None, "Output directory"),
    )
    for flag, default_value, help_text in string_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)

    # RMSE at or above this value marks a sample as "high error" downstream.
    parser.add_argument("--high-error-rmse", type=float, default=1.0)

    namespace = parser.parse_args()
    return namespace
def default_uq_csv(tag: str) -> Path:
    """Return the default per-sample uncertainty metrics CSV path for ``tag``.

    The path is relative: ``results/evaluation_<tag>_ensemble_uq/sample_uncertainty_metrics.csv``.
    """
    evaluation_dir = f"evaluation_{tag}_ensemble_uq"
    return Path("results", evaluation_dir, "sample_uncertainty_metrics.csv")
def default_output_dir(tag: str) -> Path:
    """Return the default output directory of this analysis script for ``tag``.

    The path is relative: ``results/evaluation_<tag>_ensemble_uq_metadata_analysis``.
    """
    analysis_dir = f"evaluation_{tag}_ensemble_uq_metadata_analysis"
    return Path("results", analysis_dir)
def save_csv(path: Path, rows: list[dict], fieldnames: list[str] | None = None) -> None:
    """Write dictionary rows to a CSV file, always (re)creating the file.

    The previous implementation returned early on empty input, which left a
    CSV from an earlier run in place and made it look like current output.
    The file is now rewritten on every call.

    Args:
        path: Destination CSV path; an existing file is overwritten.
        rows: Rows to write, one dict per CSV record.
        fieldnames: Optional explicit column order. When omitted, the columns
            are collected from all rows in first-seen order.

    With no rows, the file contains only the header when ``fieldnames`` is
    given, and is empty otherwise (no column names are known in that case).
    """
    if fieldnames is None:
        # dict.fromkeys de-duplicates while keeping first-seen order, so a key
        # that only appears in a later row still gets a column.
        columns = list(dict.fromkeys(key for row in rows for key in row))
    else:
        columns = list(fieldnames)

    # Opening in "w" mode truncates any stale file before anything is written.
    with open(path, "w", encoding="utf-8-sig", newline="") as f:
        if not columns:
            return
        writer = csv.DictWriter(f, fieldnames=columns)
        writer.writeheader()
        writer.writerows(rows)
def summarize_group(rows: list[dict], group_key: str) -> list[dict]:
    """Aggregate error and uncertainty statistics per value of ``group_key``.

    Args:
        rows: Sample rows; each must provide ``group_key``, ``overall_rmse``
            and ``unc_mean_std`` (values convertible with ``float``).
        group_key: Column whose stringified value defines the group.

    Returns:
        One dict per group, largest group first, holding the group label
        under ``group_key`` plus ``n_samples``, ``rmse_mean``,
        ``rmse_median``, ``unc_mean`` and ``unc_median``.
    """
    buckets: dict[str, list[dict]] = {}
    for record in rows:
        label = str(record[group_key])
        if label not in buckets:
            buckets[label] = []
        buckets[label].append(record)

    # Stable sort: groups of equal size keep their first-seen order.
    ordered = sorted(buckets.items(), key=lambda pair: len(pair[1]), reverse=True)

    summaries = []
    for label, members in ordered:
        rmse_values = np.array([float(m["overall_rmse"]) for m in members], dtype=np.float64)
        unc_values = np.array([float(m["unc_mean_std"]) for m in members], dtype=np.float64)
        stats = {
            group_key: label,
            "n_samples": len(members),
            "rmse_mean": float(np.mean(rmse_values)),
            "rmse_median": float(np.median(rmse_values)),
            "unc_mean": float(np.mean(unc_values)),
            "unc_median": float(np.median(unc_values)),
        }
        summaries.append(stats)
    return summaries
def main() -> None:
    """Join UQ sample metrics with test-set metadata and write grouped summaries.

    Steps: resolve input/output paths from the CLI, load the processed
    dataset, attach family/source/schedule metadata to every UQ CSV row by
    its ``idx`` column, then write the joined table, per-group summaries, the
    "high error but low uncertainty" subset and a JSON run summary into the
    output directory.

    Raises:
        RuntimeError: If the processed dataset lacks ``schedule_meta_test``
            or ``family_name_test``.
    """
    cli = parse_args()
    tag = normalize_tag(cli.tag)

    if cli.processed is None:
        processed_path = processed_path_for_tag(tag)
    else:
        processed_path = Path(cli.processed)

    if cli.uq_csv is None:
        uq_csv = default_uq_csv(tag or "family_random_50k")
    else:
        uq_csv = Path(cli.uq_csv)

    if cli.output_dir is None:
        output_dir = default_output_dir(tag or "family_random_50k")
    else:
        output_dir = Path(cli.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): joblib.load unpickles the file; only point this at trusted data.
    dataset = joblib.load(processed_path)
    required_keys = ("schedule_meta_test", "family_name_test")
    if any(key not in dataset for key in required_keys):
        raise RuntimeError(
            "Processed dataset does not contain schedule metadata. Re-run preprocess on a metadata-rich raw HDF5 first."
        )

    meta_matrix = np.asarray(dataset["schedule_meta_test"], dtype=np.float32)
    families = np.asarray(dataset["family_name_test"]).astype(str)

    # Source fields are optional: they are only used when the dataset has them
    # (per the original author's note, only mixed datasets carry them).
    sources = None
    if "source_name_test" in dataset:
        sources = np.asarray(dataset["source_name_test"]).astype(str)
    source_ids = None
    if "source_id_test" in dataset:
        source_ids = np.asarray(dataset["source_id_test"])

    meta_names = dataset["meta"].get("schedule_meta_names") or []
    column_of = {name: position for position, name in enumerate(meta_names)}

    with open(uq_csv, "r", encoding="utf-8-sig", newline="") as handle:
        uq_rows = list(csv.DictReader(handle))

    joined = []
    for uq_row in uq_rows:
        # The CSV's idx is used directly as an index into the test-set arrays.
        sample = int(uq_row["idx"])
        meta_values = meta_matrix[sample]
        record = dict(uq_row)
        record["family_name"] = families[sample]
        if sources is not None:
            record["source_name"] = sources[sample]
        if source_ids is not None:
            record["source_id"] = int(source_ids[sample])
        record.update((name, float(meta_values[col])) for name, col in column_of.items())
        joined.append(record)

    save_csv(output_dir / "uq_samples_with_metadata.csv", joined)
    save_csv(output_dir / "summary_by_family.csv", summarize_group(joined, "family_name"))

    if sources is not None:
        save_csv(output_dir / "summary_by_source.csv", summarize_group(joined, "source_name"))

    if "n_prod" in column_of:
        # Group by the rounded n_prod value; the original author notes it tracks
        # schedule complexity and helps localise where UQ breaks down.
        # This column is added after the joined table was saved above, so it
        # only shows up in the outputs written from here on.
        for record in joined:
            record["n_prod_group"] = int(round(float(record["n_prod"])))
        save_csv(output_dir / "summary_by_n_prod.csv", summarize_group(joined, "n_prod_group"))

    # "Low uncertainty" means at or below the median uncertainty of all samples.
    uncertainties = np.array([float(record["unc_mean_std"]) for record in joined], dtype=np.float64)
    low_unc_threshold = float(np.median(uncertainties))
    rmse_cutoff = float(cli.high_error_rmse)

    # Confidently wrong samples: high error yet low uncertainty, worst first.
    overconfident = sorted(
        (
            record
            for record in joined
            if float(record["overall_rmse"]) >= rmse_cutoff
            and float(record["unc_mean_std"]) <= low_unc_threshold
        ),
        key=lambda record: float(record["overall_rmse"]),
        reverse=True,
    )
    save_csv(output_dir / "high_error_low_unc_with_metadata.csv", overconfident)

    run_summary = {
        "processed_path": str(processed_path),
        "uq_csv": str(uq_csv),
        "n_samples": len(joined),
        "meta_names": meta_names,
        "has_source_name": sources is not None,
        "high_error_low_unc_count": len(overconfident),
    }
    with open(output_dir / "metadata_join_summary.json", "w", encoding="utf-8") as handle:
        json.dump(run_summary, handle, ensure_ascii=False, indent=2)

    print("UQ metadata join analysis complete.")
    print(f"Output dir: {output_dir}")
    print(f"High-error low-unc count={len(overconfident)}")
# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()