You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
350 lines
13 KiB
Python
350 lines
13 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import csv
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Put the repository root (the parent of this script's directory) on sys.path
# so the `src.*` imports below resolve when this file is run as a script.
# These two statements must stay ahead of those imports.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
|
|
|
|
import joblib
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
import torch
|
|
|
|
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
|
|
from src.models.forward_surrogate import ForwardSurrogate
|
|
|
|
|
|
def parse_seed_list(seed_text: str) -> list[int]:
    """Convert a comma-separated seed string such as ``"41, 42,43"`` into ints.

    Whitespace around entries is stripped and empty entries (for example from a
    trailing comma) are skipped.

    Raises:
        ValueError: if an entry is not an integer literal, or if no seed is left.
    """
    tokens = (chunk.strip() for chunk in str(seed_text).split(","))
    parsed = [int(token) for token in tokens if token]
    if not parsed:
        raise ValueError("至少需要一个 seed")
    return parsed
|
|
|
|
|
|
def default_model_root(tag: str | None, use_schedule: bool) -> Path:
|
|
suffix = "" if use_schedule else "_no_schedule"
|
|
if tag:
|
|
return Path("models") / f"forward_surrogate_{tag}_ensemble{suffix}"
|
|
return Path("models") / f"forward_surrogate_ensemble{suffix}"
|
|
|
|
|
|
def default_output_dir(tag: str | None, use_schedule: bool) -> Path:
|
|
suffix = "" if use_schedule else "_no_schedule"
|
|
if tag:
|
|
return Path("results") / f"evaluation_{tag}_ensemble_uq{suffix}"
|
|
return Path("results") / f"evaluation_ensemble_uq{suffix}"
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the deep-ensemble UQ evaluation.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``,
            exactly as before; passing a list lets callers and tests drive the
            parser without touching ``sys.argv``.

    Returns:
        The parsed namespace. ``seeds`` is left as the raw comma-separated
        string; ``n_top_uncertain_plots`` is an int.
    """
    parser = argparse.ArgumentParser(description="Evaluate deep-ensemble UQ for forward surrogate")
    parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
    parser.add_argument("--tag", type=str, default=None, help="Experiment tag")
    parser.add_argument("--model-root", type=str, default=None, help="Root dir that contains seed_* members")
    parser.add_argument("--output-dir", type=str, default=None, help="Evaluation output dir")
    parser.add_argument("--seeds", type=str, default="41,42,43", help="Comma-separated seed list")
    parser.add_argument(
        "--no-schedule",
        action="store_true",
        help="Use the *_no_schedule default model/output directories",
    )
    parser.add_argument(
        "--n-top-uncertain-plots",
        type=int,
        default=5,
        help="Number of highest-uncertainty test samples to plot",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
def calc_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    eps_range: float = 1e-3,
    eps_var: float = 1e-6,
) -> dict:
    """Compute point-error statistics of ``y_pred`` against ``y_true``.

    Returns a dict with keys ``rmse``, ``mae``, ``bias`` (mean of
    ``y_pred - y_true``), ``abs_bias``, ``nrmse`` (RMSE divided by the range of
    ``y_true``), ``r2`` and the flags ``valid_nrmse`` / ``valid_r2``.

    ``nrmse`` is NaN unless the range of ``y_true`` exceeds ``eps_range``, and
    ``r2`` is NaN unless the total sum of squares exceeds ``eps_var``; the flags
    record which case applied, so nearly flat targets do not produce
    exploding ratios.
    """
    residual = y_pred - y_true
    squared = residual**2

    rmse = float(np.sqrt(np.mean(squared)))
    mae = float(np.mean(np.abs(residual)))
    bias = float(np.mean(residual))

    spread = float(np.max(y_true) - np.min(y_true))
    total_ss = float(np.sum((y_true - np.mean(y_true)) ** 2))
    residual_ss = float(np.sum(squared))

    has_range = spread > eps_range
    has_variance = total_ss > eps_var

    metrics = {"rmse": rmse, "mae": mae, "bias": bias, "abs_bias": float(abs(bias))}
    metrics["nrmse"] = float(rmse / spread) if has_range else np.nan
    metrics["r2"] = float(1.0 - residual_ss / total_ss) if has_variance else np.nan
    metrics["valid_nrmse"] = bool(has_range)
    metrics["valid_r2"] = bool(has_variance)
    return metrics
|
|
|
|
|
|
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the curve layout stored in ``meta``, or build the default one.

    If ``meta["curve_layout"]`` is present and not ``None`` it is returned
    unchanged. Otherwise the curve is assumed to be three equal consecutive
    segments, ``log_pressure``, ``log_derivative`` and ``slope``, each
    ``curve_dim // 3`` long. Any remainder when ``curve_dim`` is not a multiple
    of 3 is not covered by a part.
    """
    stored = meta.get("curve_layout")
    if stored is not None:
        return stored

    seg_len = curve_dim // 3
    segment_names = ("log_pressure", "log_derivative", "slope")
    parts = [
        {"name": seg_name, "start": k * seg_len, "end": (k + 1) * seg_len}
        for k, seg_name in enumerate(segment_names)
    ]
    return {"n_time_points": int(seg_len), "parts": parts}
|
|
|
|
|
|
def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]:
    """Slice ``curve`` into named segments described by ``layout["parts"]``.

    Each part contributes ``curve[int(start):int(end)]`` under ``str(name)``.
    If two parts share a name, the later one wins. The slices are views of
    ``curve``, not copies.
    """
    return {
        str(spec["name"]): curve[int(spec["start"]) : int(spec["end"])]
        for spec in layout["parts"]
    }
|
|
|
|
|
|
def load_member(checkpoint_path: Path, device: torch.device) -> tuple[ForwardSurrogate, dict]:
    """Rebuild one ensemble member from its checkpoint, ready for inference.

    The checkpoint is read onto the CPU first. A ``ForwardSurrogate`` is then
    constructed from its stored hyper-parameters and moved to ``device``. The
    saved ``model_state_dict`` is loaded and the model is switched to eval mode.
    A checkpoint without ``use_schedule`` is treated as ``use_schedule=True``.

    Returns:
        ``(model, checkpoint)``, where ``checkpoint`` is the raw loaded dict.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    hyper_params = dict(
        param_dim=int(ckpt["param_dim"]),
        schedule_dim=int(ckpt["schedule_dim"]),
        curve_dim=int(ckpt["curve_dim"]),
        hidden_dim=int(ckpt["hidden_dim"]),
        dropout=float(ckpt["dropout"]),
        use_schedule=bool(ckpt.get("use_schedule", True)),
    )
    net = ForwardSurrogate(**hyper_params).to(device)
    net.load_state_dict(ckpt["model_state_dict"])
    net.eval()
    return net, ckpt
|
|
|
|
|
|
def safe_mean(x: np.ndarray) -> float:
    """Return the mean of ``x`` as a Python float, or NaN if ``x`` is empty."""
    if x.size == 0:
        return np.nan
    return float(np.mean(x))
|
|
|
|
|
|
def safe_median(x: np.ndarray) -> float:
    """Return the median of ``x`` as a Python float, or NaN if ``x`` is empty."""
    if x.size == 0:
        return np.nan
    return float(np.median(x))
|
|
|
|
|
|
def safe_percentile(x: np.ndarray, q: float) -> float:
    """Return the ``q``-th percentile of ``x`` as a float, or NaN if ``x`` is empty."""
    if x.size == 0:
        return np.nan
    return float(np.percentile(x, q))
|
|
|
|
|
|
def pearson_corr(x: np.ndarray, y: np.ndarray) -> float:
    """Return the Pearson correlation coefficient between ``x`` and ``y``.

    Returns NaN when either input is empty, or when the standard deviation of
    either input is ``np.allclose`` to zero (an effectively constant input).
    The coefficient is undefined in those cases.
    """
    if 0 in (x.size, y.size):
        return np.nan
    if any(np.allclose(np.std(arr), 0.0) for arr in (x, y)):
        return np.nan
    return float(np.corrcoef(x, y)[0, 1])
|
|
|
|
|
|
def plot_uncertainty_scatter(sample_rows: list[dict], output_path: Path) -> None:
    """Save a scatter plot of per-sample uncertainty against per-sample RMSE.

    Each row of ``sample_rows`` must provide ``unc_mean_std`` (x axis) and
    ``overall_rmse`` (y axis). The title reports their Pearson correlation.
    The figure is written to ``output_path`` at 150 dpi and then closed.
    """
    errors = np.asarray([r["overall_rmse"] for r in sample_rows], dtype=np.float64)
    spreads = np.asarray([r["unc_mean_std"] for r in sample_rows], dtype=np.float64)
    corr = pearson_corr(spreads, errors)

    fig, ax = plt.subplots(figsize=(7, 5))
    ax.scatter(spreads, errors, s=10, alpha=0.35)
    ax.set_xlabel("Predictive Uncertainty (mean std)")
    ax.set_ylabel("Overall RMSE")
    ax.set_title(f"Uncertainty vs Error | Pearson={corr:.4f}")
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close(fig)
|
|
|
|
|
|
def plot_uncertain_sample(
    idx: int,
    curve_true: np.ndarray,
    curve_mean: np.ndarray,
    curve_std: np.ndarray,
    curve_layout: dict,
    output_dir: Path,
    unc_score: float,
    rmse: float,
) -> None:
    """Plot one test sample: true curve, ensemble mean and a mean ± 2 std band.

    One panel is drawn per curve part in ``curve_layout``. The standard parts
    (``log_pressure``, ``log_derivative``, ``slope``) come first in that fixed
    order. Any other parts declared by a layout stored in the dataset meta
    follow in layout order, titled by their raw name. Layouts that lack one of
    the standard parts therefore no longer fail with ``KeyError``. If the layout
    has no parts, nothing is drawn.

    The figure is saved as ``top_uncertain_sample_{idx:04d}.png`` in
    ``output_dir`` and then closed.
    """
    true_parts = split_curve_by_layout(curve_true, curve_layout)
    mean_parts = split_curve_by_layout(curve_mean, curve_layout)
    std_parts = split_curve_by_layout(curve_std, curve_layout)

    title_map = {
        "log_pressure": "Log Pressure",
        "log_derivative": "Log |Derivative|",
        "slope": "Slope of Log Pressure vs Log Time",
    }
    # Standard parts keep their historical panel order; extra parts follow.
    names = [name for name in title_map if name in true_parts]
    names += [name for name in true_parts if name not in title_map]
    if not names:
        return

    n_panels = len(names)
    # squeeze=False keeps ``axes`` 2-D even for a single panel; the height is
    # scaled so the standard three-panel figure stays 12 x 10.
    fig, axes = plt.subplots(n_panels, 1, figsize=(12, 10 * n_panels / 3), squeeze=False)
    fig.suptitle(
        f"High-Uncertainty Sample #{idx} | unc_mean_std={unc_score:.4f}, overall_rmse={rmse:.4f}"
    )

    for ax, name in zip(axes[:, 0], names):
        y_true = true_parts[name]
        y_mean = mean_parts[name]
        y_std = std_parts[name]
        x = np.arange(len(y_true))
        ax.plot(x, y_true, label="True", linewidth=2)
        ax.plot(x, y_mean, label="Ensemble mean", linewidth=2)
        ax.fill_between(x, y_mean - 2.0 * y_std, y_mean + 2.0 * y_std, alpha=0.2, label="mean ± 2 std")
        ax.set_title(title_map.get(name, name))
        ax.grid(True, alpha=0.3)
        ax.legend()

    # Leave the top 4% of the figure for the suptitle.
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(output_dir / f"top_uncertain_sample_{idx:04d}.png", dpi=150, bbox_inches="tight")
    plt.close(fig)
|
|
|
|
|
|
def _load_ensemble(
    model_root: Path,
    seeds: list[int],
    device: torch.device,
) -> tuple[list[tuple[int, ForwardSurrogate]], list[str], bool]:
    """Load ``<model_root>/seed_<seed>/forward_surrogate_best.pt`` for every seed.

    Returns:
        ``(members, member_paths, use_schedule)``. ``members`` holds
        ``(seed, model)`` pairs, ``member_paths`` holds the checkpoint paths as
        strings, and ``use_schedule`` is the flag shared by all checkpoints.

    Raises:
        RuntimeError: if the members disagree on ``use_schedule``.
    """
    members: list[tuple[int, ForwardSurrogate]] = []
    member_paths: list[str] = []
    first_use_schedule = None
    for seed in seeds:
        ckpt_path = model_root / f"seed_{seed}" / "forward_surrogate_best.pt"
        model, checkpoint = load_member(ckpt_path, device)
        member_paths.append(str(ckpt_path))
        members.append((seed, model))
        cur_use_schedule = bool(checkpoint.get("use_schedule", True))
        if first_use_schedule is None:
            first_use_schedule = cur_use_schedule
        elif first_use_schedule != cur_use_schedule:
            raise RuntimeError("Ensemble 成员的 use_schedule 设置不一致")
    return members, member_paths, bool(first_use_schedule)


def _predict_ensemble(
    members: list[tuple[int, ForwardSurrogate]],
    use_schedule: bool,
    x_params_test,
    x_schedule_test,
    y_curve_test,
    scaler_curve,
    curve_layout: dict,
    device: torch.device,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, list[dict], list[dict]]:
    """Predict every test sample with every member and aggregate per sample.

    Member predictions and targets are mapped back through
    ``scaler_curve.inverse_transform`` before the ensemble mean and the
    population std (``ddof=0``) across members are taken.

    Returns:
        ``(all_true, all_mean, all_std, sample_rows, sample_metrics)``: the
        stacked true, mean and std curves; one CSV row dict per sample; and the
        ``calc_metrics`` result for each sample, in the same order.
    """
    all_true: list[np.ndarray] = []
    all_mean: list[np.ndarray] = []
    all_std: list[np.ndarray] = []
    sample_rows: list[dict] = []
    sample_metrics: list[dict] = []

    with torch.no_grad():
        for idx in range(len(x_params_test)):
            params_t = torch.tensor(x_params_test[idx : idx + 1], dtype=torch.float32, device=device)
            # Members without a schedule input are called with None, so only
            # build the schedule tensor when it will actually be consumed.
            schedule_t = (
                torch.tensor(x_schedule_test[idx : idx + 1], dtype=torch.float32, device=device)
                if use_schedule
                else None
            )

            member_preds = []
            for _, model in members:
                pred_scaled = model(params_t, schedule_t).cpu().numpy()
                member_preds.append(scaler_curve.inverse_transform(pred_scaled)[0].astype(np.float32))
            stacked = np.stack(member_preds, axis=0)

            curve_true = scaler_curve.inverse_transform(y_curve_test[idx : idx + 1])[0].astype(np.float32)
            curve_mean = stacked.mean(axis=0).astype(np.float32)
            curve_std = stacked.std(axis=0, ddof=0).astype(np.float32)

            metrics = calc_metrics(curve_true, curve_mean)
            row = {
                "idx": idx,
                "overall_rmse": metrics["rmse"],
                "overall_mae": metrics["mae"],
                "overall_bias": metrics["bias"],
                "overall_r2": metrics["r2"],
                "unc_mean_std": float(np.mean(curve_std)),
                "unc_max_std": float(np.max(curve_std)),
            }
            # One column per layout part, in layout order. The default layout
            # yields unc_log_pressure_mean_std, unc_log_derivative_mean_std and
            # unc_slope_mean_std.
            for name, part_std in split_curve_by_layout(curve_std, curve_layout).items():
                row[f"unc_{name}_mean_std"] = float(np.mean(part_std))

            sample_rows.append(row)
            sample_metrics.append(metrics)
            all_true.append(curve_true)
            all_mean.append(curve_mean)
            all_std.append(curve_std)

    return (
        np.stack(all_true, axis=0),
        np.stack(all_mean, axis=0),
        np.stack(all_std, axis=0),
        sample_rows,
        sample_metrics,
    )


def _build_summary(
    member_paths: list[str],
    use_schedule: bool,
    processed_path: Path,
    sample_metrics: list[dict],
    sample_rows: list[dict],
) -> dict:
    """Assemble the JSON summary from per-sample metrics and uncertainty rows.

    R2 statistics use only samples whose ``valid_r2`` flag is set.
    """
    rmse = np.array([m["rmse"] for m in sample_metrics], dtype=np.float64)
    mae = np.array([m["mae"] for m in sample_metrics], dtype=np.float64)
    r2_valid = np.array([m["r2"] for m in sample_metrics if m["valid_r2"]], dtype=np.float64)
    unc = np.array([row["unc_mean_std"] for row in sample_rows], dtype=np.float64)
    return {
        "ensemble": {
            "member_count": len(member_paths),
            "member_paths": member_paths,
            "use_schedule": bool(use_schedule),
            "processed_path": str(processed_path),
        },
        "prediction": {
            "rmse_mean": safe_mean(rmse),
            "rmse_median": safe_median(rmse),
            "rmse_p90": safe_percentile(rmse, 90),
            "mae_mean": safe_mean(mae),
            "mae_median": safe_median(mae),
            "r2_mean_valid": safe_mean(r2_valid),
            "r2_median_valid": safe_median(r2_valid),
        },
        "uncertainty": {
            "unc_mean": safe_mean(unc),
            "unc_median": safe_median(unc),
            "unc_p90": safe_percentile(unc, 90),
            "unc_vs_rmse_pearson": pearson_corr(unc, rmse),
        },
    }


def main() -> None:
    """Evaluate deep-ensemble uncertainty for the forward surrogate on the test split.

    Loads the processed dataset and one checkpoint per seed, then predicts
    every test sample with each member. Writes to the output directory:
    ``ensemble_uq_summary.json``, ``sample_uncertainty_metrics.csv``,
    ``uncertainty_vs_error.png``, and one ``top_uncertain_sample_XXXX.png``
    per highest-uncertainty sample.

    Raises:
        ValueError: if the test split is empty.
        RuntimeError: if the ensemble members disagree on ``use_schedule``.
    """
    args = parse_args()
    tag = normalize_tag(args.tag)
    # --no-schedule only selects the default model/output directory names;
    # whether members consume schedules is read from the checkpoints.
    use_schedule = not args.no_schedule
    seeds = parse_seed_list(args.seeds)

    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    model_root = Path(args.model_root) if args.model_root is not None else default_model_root(tag, use_schedule)
    output_dir = Path(args.output_dir) if args.output_dir is not None else default_output_dir(tag, use_schedule)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE(review): joblib.load unpickles the file; only point --processed at trusted data.
    data = joblib.load(processed_path)
    x_params_test = data["X_params_test"]
    x_schedule_test = data["X_schedule_test"]
    y_curve_test = data["Y_curve_test"]
    scaler_curve = data["scaler_curve"]
    curve_layout = infer_curve_layout(data["meta"], int(data["meta"]["curve_dim"]))
    if len(x_params_test) == 0:
        # Fail clearly here instead of with an opaque np.stack error later.
        raise ValueError(f"No test samples in {processed_path}; nothing to evaluate")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    members, member_paths, ckpt_use_schedule = _load_ensemble(model_root, seeds, device)
    if ckpt_use_schedule != use_schedule:
        print(
            f"Warning: checkpoints have use_schedule={ckpt_use_schedule} but --no-schedule was "
            f"{'set' if args.no_schedule else 'not set'}; predictions follow the checkpoints.",
            file=sys.stderr,
        )

    all_true, all_mean, all_std, sample_rows, sample_metrics = _predict_ensemble(
        members,
        ckpt_use_schedule,
        x_params_test,
        x_schedule_test,
        y_curve_test,
        scaler_curve,
        curve_layout,
        device,
    )
    summary = _build_summary(member_paths, ckpt_use_schedule, processed_path, sample_metrics, sample_rows)

    with open(output_dir / "ensemble_uq_summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    with open(output_dir / "sample_uncertainty_metrics.csv", "w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=list(sample_rows[0].keys()))
        writer.writeheader()
        writer.writerows(sample_rows)

    plot_uncertainty_scatter(sample_rows, output_dir / "uncertainty_vs_error.png")

    # Clamp to [0, n]: a negative value would otherwise slice ``rows[:-k]`` and
    # plot almost every sample.
    top_k = max(0, min(args.n_top_uncertain_plots, len(sample_rows)))
    top_uncertain = sorted(sample_rows, key=lambda row: row["unc_mean_std"], reverse=True)[:top_k]
    for row in top_uncertain:
        idx = int(row["idx"])
        plot_uncertain_sample(
            idx=idx,
            curve_true=all_true[idx],
            curve_mean=all_mean[idx],
            curve_std=all_std[idx],
            curve_layout=curve_layout,
            output_dir=output_dir,
            unc_score=float(row["unc_mean_std"]),
            rmse=float(row["overall_rmse"]),
        )

    print("Ensemble UQ evaluation complete.")
    print(f"Output dir: {output_dir}")
    print(f"RMSE mean={summary['prediction']['rmse_mean']:.6f}")
    print(f"UQ-RMSE Pearson={summary['uncertainty']['unc_vs_rmse_pearson']:.6f}")


if __name__ == "__main__":
    main()
|