You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/evaluate_forward_ensemble.py

378 lines
15 KiB
Python

"""评估正演代理模型集成并估计预测不确定性。
脚本按多个随机种子加载同结构模型,计算集成均值预测、成员间标准差和逐样本误差,
导出不确定性-误差相关性统计、散点图和高不确定性样本,用于判断集成方差是否可作为
自动拟合候选筛选或风险提示信号。
"""
from __future__ import annotations
import argparse
import csv
import json
import sys
from pathlib import Path
# Repository root (two levels above this script); appended to sys.path so the
# `src.*` imports below resolve when the script is run directly.
ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT))
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
from src.models.forward_surrogate import ForwardSurrogate
def parse_seed_list(seed_text: str) -> list[int]:
    """Parse a comma-separated list of random seeds for the model ensemble.

    Empty items (e.g. ``"1,,2"`` or a trailing comma) are skipped; each
    remaining item is stripped and converted with ``int``.

    Raises:
        ValueError: if no seed remains after parsing, or an item is not an int.
    """
    tokens = (piece.strip() for piece in str(seed_text).split(","))
    seeds = [int(tok) for tok in tokens if tok]
    if not seeds:
        raise ValueError("至少需要一个 seed")
    return seeds
def default_model_root(tag: str | None, use_schedule: bool) -> Path:
    """Return the default directory holding the ensemble's ``seed_*`` members.

    A truthy ``tag`` is embedded in the directory name; ``use_schedule=False``
    appends the ``_no_schedule`` suffix.
    """
    stem = f"forward_surrogate_{tag}_ensemble" if tag else "forward_surrogate_ensemble"
    suffix = "" if use_schedule else "_no_schedule"
    return Path("models") / (stem + suffix)
def default_output_dir(tag: str | None, use_schedule: bool) -> Path:
    """Return the default results directory for this ensemble-UQ evaluation.

    A truthy ``tag`` is embedded in the directory name; ``use_schedule=False``
    appends the ``_no_schedule`` suffix.
    """
    stem = f"evaluation_{tag}_ensemble_uq" if tag else "evaluation_ensemble_uq"
    suffix = "" if use_schedule else "_no_schedule"
    return Path("results") / (stem + suffix)
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse CLI options for the deep-ensemble evaluation.

    Args:
        argv: Argument list to parse. ``None`` (the default) reads
            ``sys.argv[1:]``, exactly like the original zero-argument call.

    Returns:
        Namespace with ``processed``, ``tag``, ``model_root``, ``output_dir``,
        ``seeds`` (raw comma-separated string), ``no_schedule`` and
        ``n_top_uncertain_plots``.

    Exits (via ``parser.error``) when ``--n-top-uncertain-plots`` is negative.
    """
    parser = argparse.ArgumentParser(description="Evaluate deep-ensemble UQ for forward surrogate")
    parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
    parser.add_argument("--tag", type=str, default=None, help="Experiment tag")
    parser.add_argument("--model-root", type=str, default=None, help="Root dir that contains seed_* members")
    parser.add_argument("--output-dir", type=str, default=None, help="Evaluation output dir")
    parser.add_argument("--seeds", type=str, default="41,42,43", help="Comma-separated seed list")
    parser.add_argument(
        "--no-schedule",
        action="store_true",
        help="Use the *_no_schedule default model root and output dir",
    )
    parser.add_argument(
        "--n-top-uncertain-plots",
        type=int,
        default=5,
        help="Number of highest-uncertainty samples to plot (>= 0)",
    )
    args = parser.parse_args(argv)
    # A negative count would later be used as a slice bound ([:-k]) and plot
    # almost every sample instead of the top-k; reject it up front.
    if args.n_top_uncertain_plots < 0:
        parser.error("--n-top-uncertain-plots must be >= 0")
    return args
def calc_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    eps_range: float = 1e-3,
    eps_var: float = 1e-6,
) -> dict:
    """Compute regression metrics between a true and a predicted curve.

    Returns a dict with ``rmse``, ``mae``, ``bias`` (mean of pred - true),
    ``abs_bias``, ``nrmse`` (RMSE / value range of ``y_true``) and ``r2``.
    ``nrmse`` is NaN when the range of ``y_true`` is not above ``eps_range``;
    ``r2`` is NaN when its total sum of squares is not above ``eps_var``.
    The ``valid_nrmse`` / ``valid_r2`` flags record which case applied.
    """
    residual = y_pred - y_true
    squared = residual**2
    rmse = float(np.sqrt(np.mean(squared)))
    mae = float(np.mean(np.abs(residual)))
    bias = float(np.mean(residual))

    span = float(np.max(y_true) - np.min(y_true))
    ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2))
    ss_res = float(np.sum(squared))

    # Guard the normalised metrics against (near-)constant reference curves.
    has_range = span > eps_range
    has_var = ss_tot > eps_var

    return {
        "rmse": rmse,
        "mae": mae,
        "bias": bias,
        "abs_bias": float(abs(bias)),
        "nrmse": float(rmse / span) if has_range else np.nan,
        "r2": float(1.0 - ss_res / ss_tot) if has_var else np.nan,
        "valid_nrmse": bool(has_range),
        "valid_r2": bool(has_var),
    }
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the curve segment layout stored in ``meta``, or a fallback.

    When ``meta["curve_layout"]`` is missing or None (older datasets), the
    concatenated curve is assumed to be three equal segments named
    log_pressure / log_derivative / slope of ``curve_dim // 3`` points each.
    """
    stored = meta.get("curve_layout")
    if stored is not None:
        return stored

    seg = curve_dim // 3
    bounds = (0, seg, 2 * seg, 3 * seg)
    names = ("log_pressure", "log_derivative", "slope")
    parts = [
        {"name": name, "start": bounds[k], "end": bounds[k + 1]}
        for k, name in enumerate(names)
    ]
    return {"n_time_points": int(seg), "parts": parts}
def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]:
    """Slice a concatenated curve into named segments per ``layout["parts"]``.

    Each part supplies ``name``, ``start`` and ``end``; the result maps
    ``str(name)`` to ``curve[start:end]`` in layout order.
    """
    pieces: dict[str, np.ndarray] = {}
    for spec in layout["parts"]:
        lo, hi = int(spec["start"]), int(spec["end"])
        pieces[str(spec["name"])] = curve[lo:hi]
    return pieces
def load_member(checkpoint_path: Path, device: torch.device) -> tuple[ForwardSurrogate, dict]:
    """Load one ensemble member from a checkpoint.

    The model is rebuilt from the hyper-parameters stored in the checkpoint
    (``use_schedule`` defaults to True when absent), moved to ``device``,
    loaded with ``model_state_dict`` and switched to eval mode.

    Returns:
        ``(model, checkpoint)``; the raw checkpoint dict is returned so callers
        can read extra fields such as ``use_schedule``.
    """
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    state = torch.load(checkpoint_path, map_location="cpu")
    hparams = dict(
        param_dim=int(state["param_dim"]),
        schedule_dim=int(state["schedule_dim"]),
        curve_dim=int(state["curve_dim"]),
        hidden_dim=int(state["hidden_dim"]),
        dropout=float(state["dropout"]),
        use_schedule=bool(state.get("use_schedule", True)),
    )
    net = ForwardSurrogate(**hparams).to(device)
    net.load_state_dict(state["model_state_dict"])
    net.eval()
    return net, state
def safe_mean(x: np.ndarray) -> float:
    """Return the mean of ``x`` ignoring NaN entries; NaN if nothing remains.

    The input is flattened, so multi-dimensional arrays yield a single scalar.
    """
    values = np.asarray(x, dtype=np.float64).ravel()
    # Drop NaNs explicitly, as the docstring promises (plain np.mean would
    # propagate a single NaN into the whole summary).
    values = values[~np.isnan(values)]
    return float(np.mean(values)) if values.size else np.nan
def safe_median(x: np.ndarray) -> float:
    """Return the median of ``x`` ignoring NaN entries; NaN if nothing remains.

    The input is flattened, so multi-dimensional arrays yield a single scalar.
    """
    values = np.asarray(x, dtype=np.float64).ravel()
    # np.median would return NaN as soon as one NaN is present.
    values = values[~np.isnan(values)]
    return float(np.median(values)) if values.size else np.nan
def safe_percentile(x: np.ndarray, q: float) -> float:
    """Return the ``q``-th percentile of ``x`` ignoring NaN; NaN if nothing remains.

    ``q`` is in percent (0-100), as for ``np.percentile``. The input is
    flattened, so multi-dimensional arrays yield a single scalar.
    """
    values = np.asarray(x, dtype=np.float64).ravel()
    # np.percentile would return NaN as soon as one NaN is present.
    values = values[~np.isnan(values)]
    return float(np.percentile(values, q)) if values.size else np.nan
def pearson_corr(x: np.ndarray, y: np.ndarray) -> float:
    """Return the Pearson correlation of ``x`` and ``y``.

    NaN is returned when either array is empty or when either array's
    standard deviation is numerically zero (``np.isclose`` default tolerances),
    because the coefficient is undefined in those cases.
    """
    if not (x.size and y.size):
        return np.nan
    degenerate = any(np.isclose(np.std(arr), 0.0) for arr in (x, y))
    if degenerate:
        return np.nan
    corr_matrix = np.corrcoef(x, y)
    return float(corr_matrix[0, 1])
def plot_uncertainty_scatter(sample_rows: list[dict], output_path: Path) -> None:
    """Save a scatter of per-sample uncertainty vs. error to ``output_path``.

    The x-axis is each row's ``unc_mean_std`` and the y-axis its
    ``overall_rmse``; the Pearson correlation of the two is shown in the title.
    """
    errors = np.array([r["overall_rmse"] for r in sample_rows], dtype=np.float64)
    spreads = np.array([r["unc_mean_std"] for r in sample_rows], dtype=np.float64)
    corr = pearson_corr(spreads, errors)

    fig, ax = plt.subplots(figsize=(7, 5))
    ax.scatter(spreads, errors, s=10, alpha=0.35)
    ax.set_xlabel("Predictive Uncertainty (mean std)")
    ax.set_ylabel("Overall RMSE")
    ax.set_title(f"Uncertainty vs Error | Pearson={corr:.4f}")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
def plot_uncertain_sample(
    idx: int,
    curve_true: np.ndarray,
    curve_mean: np.ndarray,
    curve_std: np.ndarray,
    curve_layout: dict,
    output_dir: Path,
    unc_score: float,
    rmse: float,
) -> None:
    """Plot truth, ensemble mean and a mean ± 2·std band for one sample.

    One subplot is drawn per segment in ``curve_layout["parts"]`` (layout
    order), so the figure follows the layout instead of a hard-coded list of
    three names. Known segment names get a readable title; others use their
    raw name. The figure is saved as
    ``output_dir / f"top_uncertain_sample_{idx:04d}.png"``.

    Raises:
        ValueError: if the layout contains no segments to plot.
    """
    true_parts = split_curve_by_layout(curve_true, curve_layout)
    mean_parts = split_curve_by_layout(curve_mean, curve_layout)
    std_parts = split_curve_by_layout(curve_std, curve_layout)
    title_map = {
        "log_pressure": "Log Pressure",
        "log_derivative": "Log |Derivative|",
        "slope": "Slope of Log Pressure vs Log Time",
    }
    names = list(true_parts)
    if not names:
        raise ValueError("curve_layout has no parts to plot")

    n_rows = len(names)
    # Keep the original 12x10 figure for the standard three-segment layout and
    # scale the height proportionally for other segment counts.
    fig, axes = plt.subplots(n_rows, 1, figsize=(12, 10.0 * n_rows / 3.0), squeeze=False)
    try:
        fig.suptitle(
            f"High-Uncertainty Sample #{idx} | unc_mean_std={unc_score:.4f}, overall_rmse={rmse:.4f}"
        )
        for ax, name in zip(axes[:, 0], names):
            y_true = true_parts[name]
            y_mean = mean_parts[name]
            y_std = std_parts[name]
            x = np.arange(len(y_true))
            ax.plot(x, y_true, label="True", linewidth=2)
            ax.plot(x, y_mean, label="Ensemble mean", linewidth=2)
            ax.fill_between(x, y_mean - 2.0 * y_std, y_mean + 2.0 * y_std, alpha=0.2, label="mean ± 2 std")
            ax.set_title(title_map.get(name, name))
            ax.grid(True, alpha=0.3)
            ax.legend()
        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(output_dir / f"top_uncertain_sample_{idx:04d}.png", dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)
def _load_ensemble_members(
    model_root: Path,
    seeds: list[int],
    device: torch.device,
) -> tuple[list[tuple[int, ForwardSurrogate]], list[str], bool]:
    """Load ``seed_<seed>/forward_surrogate_best.pt`` under ``model_root`` for every seed.

    Returns:
        ``(members, member_paths, use_schedule)``: the ``(seed, model)`` pairs,
        the checkpoint paths as strings, and the ``use_schedule`` flag shared by
        all checkpoints (missing flag counts as True).

    Raises:
        RuntimeError: if members disagree on ``use_schedule``.
    """
    members: list[tuple[int, ForwardSurrogate]] = []
    member_paths: list[str] = []
    first_use_schedule: bool | None = None
    for seed in seeds:
        ckpt_path = model_root / f"seed_{seed}" / "forward_surrogate_best.pt"
        model, checkpoint = load_member(ckpt_path, device)
        member_paths.append(str(ckpt_path))
        members.append((seed, model))
        cur_use_schedule = bool(checkpoint.get("use_schedule", True))
        if first_use_schedule is None:
            first_use_schedule = cur_use_schedule
        elif first_use_schedule != cur_use_schedule:
            # Members with and without schedule input interpret their inputs
            # differently, so their mean and std would not be meaningful.
            raise RuntimeError("Ensemble 成员的 use_schedule 设置不一致")
    return members, member_paths, bool(first_use_schedule)


def _evaluate_test_split(
    data: dict,
    members: list[tuple[int, ForwardSurrogate]],
    use_schedule: bool,
    curve_layout: dict,
    device: torch.device,
) -> tuple[list[dict], list[dict], np.ndarray, np.ndarray, np.ndarray]:
    """Run every member on every test sample and collect per-sample statistics.

    Returns:
        ``(sample_rows, sample_metrics, all_true, all_mean, all_std)`` where
        ``sample_rows`` are the CSV rows, ``sample_metrics`` the full
        ``calc_metrics`` dicts for the same samples, and the three arrays stack
        the de-standardised true curve, ensemble mean and ensemble std
        (row ``i`` corresponds to test sample ``i``).

    Raises:
        RuntimeError: if the test split is empty.
    """
    x_params_test = data["X_params_test"]
    x_schedule_test = data["X_schedule_test"]
    y_curve_test = data["Y_curve_test"]
    scaler_curve = data["scaler_curve"]
    n_samples = len(x_params_test)
    if n_samples == 0:
        # Otherwise np.stack([]) / sample_rows[0] fail later with unhelpful errors.
        raise RuntimeError("Test split is empty; nothing to evaluate")

    sample_rows: list[dict] = []
    sample_metrics: list[dict] = []
    all_true: list[np.ndarray] = []
    all_mean: list[np.ndarray] = []
    all_std: list[np.ndarray] = []
    with torch.no_grad():
        for idx in range(n_samples):
            params_t = torch.tensor(x_params_test[idx : idx + 1], dtype=torch.float32, device=device)
            # Members trained without schedule input are called with None.
            schedule_t = (
                torch.tensor(x_schedule_test[idx : idx + 1], dtype=torch.float32, device=device)
                if use_schedule
                else None
            )
            member_preds = []
            for _, model in members:
                # De-standardise each member first so the ensemble mean and std
                # are computed on the original curve scale.
                pred_scaled = model(params_t, schedule_t).cpu().numpy()
                member_preds.append(scaler_curve.inverse_transform(pred_scaled)[0].astype(np.float32))
            member_preds = np.stack(member_preds, axis=0)

            curve_true = scaler_curve.inverse_transform(y_curve_test[idx : idx + 1])[0].astype(np.float32)
            curve_mean = member_preds.mean(axis=0).astype(np.float32)
            curve_std = member_preds.std(axis=0, ddof=0).astype(np.float32)
            metrics = calc_metrics(curve_true, curve_mean)
            parts_std = split_curve_by_layout(curve_std, curve_layout)
            # The member std is the empirical uncertainty score that is later
            # correlated with the true RMSE.
            sample_rows.append(
                {
                    "idx": idx,
                    "overall_rmse": metrics["rmse"],
                    "overall_mae": metrics["mae"],
                    "overall_bias": metrics["bias"],
                    "overall_r2": metrics["r2"],
                    "unc_mean_std": float(np.mean(curve_std)),
                    "unc_max_std": float(np.max(curve_std)),
                    "unc_log_pressure_mean_std": float(np.mean(parts_std["log_pressure"])),
                    "unc_log_derivative_mean_std": float(np.mean(parts_std["log_derivative"])),
                    "unc_slope_mean_std": float(np.mean(parts_std["slope"])),
                }
            )
            sample_metrics.append(metrics)
            all_true.append(curve_true)
            all_mean.append(curve_mean)
            all_std.append(curve_std)

    return (
        sample_rows,
        sample_metrics,
        np.stack(all_true, axis=0),
        np.stack(all_mean, axis=0),
        np.stack(all_std, axis=0),
    )


def _build_summary(
    member_paths: list[str],
    use_schedule: bool,
    processed_path: Path,
    sample_rows: list[dict],
    sample_metrics: list[dict],
) -> dict:
    """Aggregate prediction quality and uncertainty quality into one summary dict.

    Reuses the per-sample metrics computed during evaluation instead of
    recomputing ``calc_metrics`` on the stacked arrays.
    """
    rmse = np.array([m["rmse"] for m in sample_metrics], dtype=np.float64)
    mae = np.array([m["mae"] for m in sample_metrics], dtype=np.float64)
    r2_valid = np.array([m["r2"] for m in sample_metrics if m["valid_r2"]], dtype=np.float64)
    unc = np.array([row["unc_mean_std"] for row in sample_rows], dtype=np.float64)
    # Prediction and uncertainty sections are kept separate so the value of the
    # ensemble can be judged on each axis independently.
    return {
        "ensemble": {
            "member_count": len(member_paths),
            "member_paths": member_paths,
            "use_schedule": bool(use_schedule),
            "processed_path": str(processed_path),
        },
        "prediction": {
            "rmse_mean": safe_mean(rmse),
            "rmse_median": safe_median(rmse),
            "rmse_p90": safe_percentile(rmse, 90),
            "mae_mean": safe_mean(mae),
            "mae_median": safe_median(mae),
            "r2_mean_valid": safe_mean(r2_valid),
            "r2_median_valid": safe_median(r2_valid),
        },
        "uncertainty": {
            "unc_mean": safe_mean(unc),
            "unc_median": safe_median(unc),
            "unc_p90": safe_percentile(unc, 90),
            "unc_vs_rmse_pearson": pearson_corr(unc, rmse),
        },
    }


def main() -> None:
    """Evaluate the forward-surrogate deep ensemble on the test split.

    Averages the de-standardised predictions of every seed member, uses the
    member standard deviation as an uncertainty score, and writes
    ``ensemble_uq_summary.json``, ``sample_uncertainty_metrics.csv``,
    ``uncertainty_vs_error.png`` and ``top_uncertain_sample_XXXX.png`` plots
    into the output directory.
    """
    args = parse_args()
    tag = normalize_tag(args.tag)
    use_schedule = not args.no_schedule
    seeds = parse_seed_list(args.seeds)
    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    model_root = Path(args.model_root) if args.model_root is not None else default_model_root(tag, use_schedule)
    output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag, use_schedule)
    output_dir.mkdir(parents=True, exist_ok=True)

    # All members are evaluated on the same processed data so that they share
    # one input-standardisation convention.
    data = joblib.load(processed_path)
    curve_layout = infer_curve_layout(data["meta"], int(data["meta"]["curve_dim"]))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    members, member_paths, member_use_schedule = _load_ensemble_members(model_root, seeds, device)
    if member_use_schedule != use_schedule:
        # The checkpoints decide how inference runs; the CLI flag only selects
        # default directories. Surface the disagreement instead of ignoring it.
        print(
            f"Warning: ensemble checkpoints have use_schedule={member_use_schedule}, "
            f"but --no-schedule={args.no_schedule}; using the checkpoint setting.",
            file=sys.stderr,
        )

    sample_rows, sample_metrics, all_true, all_mean, all_std = _evaluate_test_split(
        data, members, member_use_schedule, curve_layout, device
    )
    summary = _build_summary(member_paths, member_use_schedule, processed_path, sample_rows, sample_metrics)

    with open(output_dir / "ensemble_uq_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    with open(output_dir / "sample_uncertainty_metrics.csv", "w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=list(sample_rows[0].keys()))
        writer.writeheader()
        writer.writerows(sample_rows)
    plot_uncertainty_scatter(sample_rows, output_dir / "uncertainty_vs_error.png")

    # Clamp to [0, n_samples]: a negative value would make the slice drop the
    # last |k| rows and plot nearly every sample instead of the top-k.
    top_k = max(0, min(args.n_top_uncertain_plots, len(sample_rows)))
    top_uncertain = sorted(sample_rows, key=lambda row: row["unc_mean_std"], reverse=True)[:top_k]
    # Only the most uncertain samples are plotted, where members disagree most.
    for row in top_uncertain:
        idx = int(row["idx"])
        plot_uncertain_sample(
            idx=idx,
            curve_true=all_true[idx],
            curve_mean=all_mean[idx],
            curve_std=all_std[idx],
            curve_layout=curve_layout,
            output_dir=output_dir,
            unc_score=float(row["unc_mean_std"]),
            rmse=float(row["overall_rmse"]),
        )

    print("Ensemble UQ evaluation complete.")
    print(f"Output dir: {output_dir}")
    print(f"RMSE mean={summary['prediction']['rmse_mean']:.6f}")
    print(f"UQ-RMSE Pearson={summary['uncertainty']['unc_vs_rmse_pearson']:.6f}")


if __name__ == "__main__":
    main()