|
|
|
|
|
"""结合原始样本元数据分析不确定性评估结果。
|
|
|
|
|
|
|
|
|
|
|
|
在 `analyze_uq_results.py` 的全局统计之外,本脚本会把 UQ CSV 与预处理数据中的
|
|
|
|
|
|
参数、流量制度等元数据对齐,按制度族、参数区间等维度分组,帮助定位代理模型
|
|
|
|
|
|
在哪类数值试井样本上更容易给出高误差或高不确定性。
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
|
import csv
|
|
|
|
|
|
import json
|
|
|
|
|
|
import sys
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
import joblib
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
# Presumably the repository root: parents[1] is the parent of the directory that
# contains this script. It is appended to sys.path so that the
# `src.common.experiment_paths` import that follows resolves when this file is
# run directly as a script.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
|
|
|
|
|
|
|
|
|
|
|
|
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command line for the UQ-metrics / sample-metadata join.

    Options:
        --processed: processed dataset path (None when omitted).
        --uq-csv: path to ``sample_uncertainty_metrics.csv`` (None when omitted).
        --tag: experiment tag, ``family_random_50k`` by default.
        --output-dir: output directory (None when omitted).
        --high-error-rmse: RMSE at or above which a sample is treated as
            high-error (default 1.0).

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(description="Join UQ sample metrics with saved metadata")

    # Both input paths are optional strings whose defaults are resolved later.
    optional_inputs = (
        ("--processed", "Processed dataset path"),
        ("--uq-csv", "sample_uncertainty_metrics.csv path"),
    )
    for flag, help_text in optional_inputs:
        cli.add_argument(flag, type=str, default=None, help=help_text)

    cli.add_argument("--tag", type=str, default="family_random_50k", help="Experiment tag")
    cli.add_argument("--output-dir", type=str, default=None, help="Output directory")
    cli.add_argument("--high-error-rmse", type=float, default=1.0)

    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_uq_csv(tag: str) -> Path:
    """Return the default ``sample_uncertainty_metrics.csv`` location for ``tag``.

    The path is relative (rooted at ``results/``), so it resolves against the
    current working directory.
    """
    run_dir = f"evaluation_{tag}_ensemble_uq"
    return Path("results", run_dir, "sample_uncertainty_metrics.csv")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def default_output_dir(tag: str) -> Path:
    """Return the default output directory of this metadata-join analysis.

    The folder name embeds ``tag`` and sits under a relative ``results/``
    directory, so it resolves against the current working directory.
    """
    folder_name = "evaluation_{}_ensemble_uq_metadata_analysis".format(tag)
    return Path("results").joinpath(folder_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_csv(path: Path, rows: list[dict], fieldnames: list[str] | None = None) -> None:
    """Write dict rows to ``path`` as a UTF-8 (with BOM) CSV file.

    The file is always (re)written, even when ``rows`` is empty. A stale file
    from an earlier run therefore cannot be mistaken for the current result.

    Args:
        path: Destination file; overwritten if it already exists.
        rows: Records to write. Rows may have different key sets. The header is
            then the union of all keys in first-seen order, and cells for
            missing keys are left empty.
        fieldnames: Optional explicit header used instead of the inferred one.
            It allows a header-only file when ``rows`` is empty. Without it, an
            empty ``rows`` produces an empty file.
    """
    if fieldnames is None:
        # dict.fromkeys de-duplicates while preserving first-seen key order.
        fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with open(path, "w", encoding="utf-8-sig", newline="") as f:
        if not fieldnames:
            # No rows and no header known: leave a truncated, empty file behind.
            return
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def summarize_group(rows: list[dict], group_key: str) -> list[dict]:
    """Aggregate per-sample error and uncertainty by the value of ``group_key``.

    Rows are bucketed by ``str(row[group_key])``. Each output record contains the
    bucket label (under ``group_key``), ``n_samples``, and the mean and median of
    ``overall_rmse`` and ``unc_mean_std``, both parsed as floats. Buckets are
    ordered largest first; equal-sized buckets keep the order in which their
    label was first encountered.
    """
    buckets: dict[str, list[dict]] = {}
    for row in rows:
        label = str(row[group_key])
        if label not in buckets:
            buckets[label] = []
        buckets[label].append(row)

    # Sorting on the negated size is stable, same as a reverse sort on size.
    ordered = sorted(buckets.items(), key=lambda pair: -len(pair[1]))

    summary = []
    for label, members in ordered:
        errors = np.asarray([float(m["overall_rmse"]) for m in members], dtype=np.float64)
        spreads = np.asarray([float(m["unc_mean_std"]) for m in members], dtype=np.float64)
        record = {group_key: label, "n_samples": len(members)}
        record["rmse_mean"] = float(errors.mean())
        record["rmse_median"] = float(np.median(errors))
        record["unc_mean"] = float(spreads.mean())
        record["unc_median"] = float(np.median(spreads))
        summary.append(record)
    return summary
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_test_metadata(processed_path: Path) -> dict:
    """Load the processed dataset and return its validated test-split metadata.

    Returns:
        A dict with these keys:
        ``schedule_meta``: float32 array of schedule parameters.
        ``family_name``: str array.
        ``source_name`` / ``source_id``: arrays, or ``None`` when absent.
        ``meta_names``: list of column names for ``schedule_meta``.

    Raises:
        RuntimeError: the schedule metadata or family names are missing, or
            ``schedule_meta`` has fewer columns than ``meta_names`` lists.
    """
    data = joblib.load(processed_path)
    if "schedule_meta_test" not in data or "family_name_test" not in data:
        raise RuntimeError(
            "Processed dataset does not contain schedule metadata. Re-run preprocess on a metadata-rich raw HDF5 first."
        )

    schedule_meta = np.asarray(data["schedule_meta_test"], dtype=np.float32)
    family_name = np.asarray(data["family_name_test"]).astype(str)
    # Source fields only exist in mixed datasets; keep them optional elsewhere.
    source_name = np.asarray(data["source_name_test"]).astype(str) if "source_name_test" in data else None
    source_id = np.asarray(data["source_id_test"]) if "source_id_test" in data else None

    # A missing "meta" entry simply means no named schedule-parameter columns.
    meta = data["meta"] if "meta" in data else {}
    meta_names = list(meta.get("schedule_meta_names") or [])

    # Fail loudly instead of raising an opaque IndexError deep inside the join loop.
    if meta_names and (schedule_meta.ndim != 2 or schedule_meta.shape[1] < len(meta_names)):
        raise RuntimeError(
            f"schedule_meta_test has shape {schedule_meta.shape}, which cannot hold "
            f"{len(meta_names)} schedule_meta_names columns."
        )

    return {
        "schedule_meta": schedule_meta,
        "family_name": family_name,
        "source_name": source_name,
        "source_id": source_id,
        "meta_names": meta_names,
    }


def _enrich_uq_rows(uq_rows: list[dict], test_meta: dict) -> list[dict]:
    """Copy each UQ row and append family/source/schedule-parameter columns looked up by its test-set ``idx``.

    Raises:
        IndexError: a row's ``idx`` lies outside the processed test split.
    """
    schedule_meta = test_meta["schedule_meta"]
    family_name = test_meta["family_name"]
    source_name = test_meta["source_name"]
    source_id = test_meta["source_id"]
    name_to_col = {name: i for i, name in enumerate(test_meta["meta_names"])}

    # Every metadata array is indexed by idx, so the smallest length bounds valid indices.
    n_test = min(len(arr) for arr in (schedule_meta, family_name, source_name, source_id) if arr is not None)

    enriched_rows = []
    for row in uq_rows:
        # The UQ CSV idx is the index into the processed test split.
        idx = int(row["idx"])
        if not 0 <= idx < n_test:
            # Negative indices would silently wrap around in NumPy and attach
            # another sample's metadata, so reject them along with overflow.
            raise IndexError(
                f"UQ row idx={idx} is outside the processed test split (size {n_test}); "
                "the UQ CSV and the processed dataset probably do not match."
            )
        enriched = dict(row)
        enriched["family_name"] = family_name[idx]
        if source_name is not None:
            enriched["source_name"] = source_name[idx]
        if source_id is not None:
            enriched["source_id"] = int(source_id[idx])
        meta_row = schedule_meta[idx]
        for name, col in name_to_col.items():
            enriched[name] = float(meta_row[col])
        enriched_rows.append(enriched)
    return enriched_rows


def _n_prod_group(value) -> int | str:
    """Round a stored ``n_prod`` value to an integer group label; non-finite values map to ``"nan"``."""
    as_float = float(value)
    if not np.isfinite(as_float):
        # round(nan) would raise; keep such samples in their own visible bucket.
        return "nan"
    return int(round(as_float))


def _select_high_error_low_unc(rows: list[dict], high_error_rmse: float) -> tuple[list[dict], float]:
    """Select rows with ``overall_rmse >= high_error_rmse`` and ``unc_mean_std`` at or below the median.

    Returns:
        The selected rows sorted by ``overall_rmse`` (highest first), and the
        median-uncertainty threshold that was used (NaN when ``rows`` is empty).
    """
    unc_values = np.array([float(row["unc_mean_std"]) for row in rows], dtype=np.float64)
    # np.median on an empty array warns and returns NaN; produce the NaN quietly.
    low_unc_threshold = float(np.median(unc_values)) if unc_values.size else float("nan")
    selected = [
        row
        for row in rows
        if float(row["overall_rmse"]) >= high_error_rmse and float(row["unc_mean_std"]) <= low_unc_threshold
    ]
    selected.sort(key=lambda row: float(row["overall_rmse"]), reverse=True)
    return selected, low_unc_threshold


def main() -> None:
    """Join UQ results with parameter/schedule metadata and summarize error and uncertainty by regime.

    Writes into the output directory:
    - ``uq_samples_with_metadata.csv``
    - ``summary_by_family.csv``
    - ``summary_by_source.csv`` (only when source names exist)
    - ``summary_by_n_prod.csv`` (only when an ``n_prod`` metadata column exists)
    - ``high_error_low_unc_with_metadata.csv``
    - ``metadata_join_summary.json``
    """
    args = parse_args()
    tag = normalize_tag(args.tag)
    # normalize_tag may apparently yield an empty value; fall back to the default tag for path defaults.
    path_tag = tag or "family_random_50k"
    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    uq_csv = Path(args.uq_csv) if args.uq_csv is not None else default_uq_csv(path_tag)
    output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(path_tag)

    test_meta = _load_test_metadata(processed_path)
    with open(uq_csv, "r", encoding="utf-8-sig", newline="") as f:
        uq_rows = list(csv.DictReader(f))
    enriched_rows = _enrich_uq_rows(uq_rows, test_meta)

    # Create the output dir only after the inputs loaded and validated, so a failed run leaves nothing behind.
    output_dir.mkdir(parents=True, exist_ok=True)
    save_csv(output_dir / "uq_samples_with_metadata.csv", enriched_rows)
    save_csv(output_dir / "summary_by_family.csv", summarize_group(enriched_rows, "family_name"))

    has_source_name = test_meta["source_name"] is not None
    if has_source_name:
        save_csv(output_dir / "summary_by_source.csv", summarize_group(enriched_rows, "source_name"))

    if "n_prod" in test_meta["meta_names"]:
        # The number of production segments tracks schedule complexity; grouping
        # by it helps locate regions where UQ breaks down.
        for row in enriched_rows:
            row["n_prod_group"] = _n_prod_group(row["n_prod"])
        save_csv(output_dir / "summary_by_n_prod.csv", summarize_group(enriched_rows, "n_prod_group"))

    # "High error but low uncertainty" samples are confidently wrong: candidates
    # for targeted resampling or further diagnosis.
    high_error_rmse = float(args.high_error_rmse)
    high_error_low_unc, low_unc_threshold = _select_high_error_low_unc(enriched_rows, high_error_rmse)
    save_csv(output_dir / "high_error_low_unc_with_metadata.csv", high_error_low_unc)

    summary = {
        "processed_path": str(processed_path),
        "uq_csv": str(uq_csv),
        "n_samples": len(enriched_rows),
        "meta_names": test_meta["meta_names"],
        "has_source_name": bool(has_source_name),
        "high_error_low_unc_count": len(high_error_low_unc),
        # Record the thresholds actually used so the selection is reproducible.
        "high_error_rmse": high_error_rmse,
        "low_unc_threshold": low_unc_threshold if np.isfinite(low_unc_threshold) else None,
    }
    with open(output_dir / "metadata_join_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print("UQ metadata join analysis complete.")
    print(f"Output dir: {output_dir}")
    print(f"High-error low-unc count={len(high_error_low_unc)}")


if __name__ == "__main__":
    main()
|