You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/evaluate_forward.py

585 lines
23 KiB
Python

"""评估固定长度曲线正演代理模型。
脚本加载预处理数据和 `ForwardSurrogate` checkpoint,逐样本预测测试集曲线,
按整体曲线以及压力、导数、斜率三段分别统计 RMSE、MAE、NRMSE、R2 等指标并保存,
同时绘制随机、最佳、最差样本图,作为正演代理模型的离线验收入口。
"""
# pylint: disable=import-error,wrong-import-position
# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments,too-many-statements
# pylint: disable=line-too-long
from __future__ import annotations
import argparse
import csv
import json
import random
import sys
from pathlib import Path
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
# Put the project root (the parent of this scripts/ directory) on sys.path so the
# `src.*` imports below resolve when this file is run directly as a script.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import (
evaluation_dir_for_tag,
model_checkpoint_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.models.forward_surrogate import ForwardSurrogate
# Default for --seed; used to seed Python `random`, NumPy and torch before evaluation
# (it also makes the random sample-plot selection reproducible).
DEFAULT_RANDOM_SEED = 42
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the forward-surrogate evaluation.

    Options cover the processed dataset, checkpoint, optional scaler-fitting dataset,
    output directory, experiment tag, schedule mode (only used to infer the checkpoint
    path when --model is omitted), random seed and how many random/best/worst sample
    plots to draw.
    """
    parser = argparse.ArgumentParser(description="Evaluate forward surrogate model")

    # Optional path-like string options, all defaulting to None (meaning "derive from the tag").
    optional_paths = (
        ("--processed", "Processed dataset path"),
        ("--model", "Model checkpoint path"),
        (
            "--fit-processed",
            "Processed dataset used to fit scalers for the evaluated model; "
            "required for cross-dataset evaluation",
        ),
        ("--output-dir", "Optional evaluation output directory"),
    )
    for flag, help_text in optional_paths:
        parser.add_argument(flag, type=str, default=None, help=help_text)

    parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming")
    parser.add_argument(
        "--no-schedule",
        action="store_true",
        help="When --model is omitted, infer the no-schedule checkpoint path",
    )
    parser.add_argument("--seed", type=int, default=DEFAULT_RANDOM_SEED)

    # Number of sample plots per group.
    for flag in ("--n-random-plots", "--n-best-plots", "--n-worst-plots"):
        parser.add_argument(flag, type=int, default=5)

    return parser.parse_args()
def calc_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    eps_range: float = 1e-3,
    eps_var: float = 1e-6,
) -> dict:
    """Compute RMSE, MAE, bias, NRMSE and R2 between two equally shaped arrays.

    NRMSE (RMSE divided by the value range of ``y_true``) is only reported when that
    range exceeds ``eps_range``. R2 is only reported when the total sum of squares of
    ``y_true`` exceeds ``eps_var``. Otherwise the metric is NaN and the matching
    ``valid_*`` flag is False, so near-constant targets do not yield meaningless ratios.

    Args:
        y_true: Ground-truth values.
        y_pred: Predicted values; must have the same shape as ``y_true``.
        eps_range: Minimum value range of ``y_true`` for NRMSE to be valid.
        eps_var: Minimum total sum of squares of ``y_true`` for R2 to be valid.

    Returns:
        dict with float ``rmse``, ``mae``, ``bias`` (mean of ``y_pred - y_true``),
        ``abs_bias``, ``nrmse``, ``r2``, ``range_true``, ``ss_tot`` and bool
        ``valid_nrmse`` / ``valid_r2``. Empty inputs give NaN error metrics,
        zero ``range_true`` / ``ss_tot`` and both flags False.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different shapes.
    """
    # Accumulate in float64: curves arrive as float32 and sums of squares would
    # otherwise carry float32 round-off.
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    if y_true.shape != y_pred.shape:
        # Refuse silent broadcasting, which would produce plausible-looking but wrong metrics.
        raise ValueError(
            f"calc_metrics shape mismatch: y_true {y_true.shape} vs y_pred {y_pred.shape}"
        )
    if y_true.size == 0:
        # np.max/np.min raise on empty arrays; report "no information" instead of crashing.
        return {
            "rmse": np.nan,
            "mae": np.nan,
            "bias": np.nan,
            "abs_bias": np.nan,
            "nrmse": np.nan,
            "r2": np.nan,
            "range_true": 0.0,
            "ss_tot": 0.0,
            "valid_nrmse": False,
            "valid_r2": False,
        }
    err = y_pred - y_true
    rmse = float(np.sqrt(np.mean(err**2)))
    mae = float(np.mean(np.abs(err)))
    bias = float(np.mean(err))
    value_range = float(np.max(y_true) - np.min(y_true))
    ss_tot = float(np.sum((y_true - np.mean(y_true)) ** 2))
    ss_res = float(np.sum(err**2))
    valid_nrmse = value_range > eps_range
    valid_r2 = ss_tot > eps_var
    nrmse = float(rmse / value_range) if valid_nrmse else np.nan
    r2 = float(1.0 - ss_res / ss_tot) if valid_r2 else np.nan
    return {
        "rmse": rmse,
        "mae": mae,
        "bias": bias,
        "abs_bias": float(abs(bias)),
        "nrmse": nrmse,
        "r2": r2,
        "range_true": value_range,
        "ss_tot": ss_tot,
        "valid_nrmse": bool(valid_nrmse),
        "valid_r2": bool(valid_r2),
    }
def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]:
    """Slice a concatenated curve into its named segments.

    ``layout["parts"]`` is a sequence of mappings with ``name``, ``start`` and ``end``.
    Each segment is ``curve[start:end]`` keyed by ``str(name)``. This is basic slicing,
    so NumPy inputs yield views. A later part with a duplicate name replaces an earlier one.
    """
    return {
        str(segment["name"]): curve[int(segment["start"]) : int(segment["end"])]
        for segment in layout["parts"]
    }
def safe_percentile(x: np.ndarray, q: float) -> float:
    """Return the ``q``-th percentile of ``x`` as a float, or NaN when ``x`` is empty.

    NaN entries are not skipped; a NaN in ``x`` propagates to the result.
    """
    return np.nan if x.size == 0 else float(np.percentile(x, q))
def safe_mean(x: np.ndarray) -> float:
    """Return the arithmetic mean of ``x`` as a float, or NaN when ``x`` is empty.

    NaN entries are not skipped; a NaN in ``x`` propagates to the result.
    """
    if not x.size:
        return np.nan
    mean_value = np.mean(x)
    return float(mean_value)
def safe_median(x: np.ndarray) -> float:
    """Return the median of ``x`` as a float, or NaN when ``x`` is empty.

    NaN entries are not skipped; a NaN in ``x`` propagates to the result.
    """
    return float(np.median(x)) if x.size else np.nan
def summarize_metric_dicts(metric_dicts: list[dict], prefix: str) -> dict:
    """Aggregate per-sample metric dicts into summary statistics and print a short report.

    RMSE, MAE and bias statistics use every sample. NRMSE and R2 statistics use only the
    samples whose ``valid_nrmse`` / ``valid_r2`` flag is truthy; the ``*_valid_ratio``
    entries give the fraction of such samples. ``low_range_count`` / ``low_var_count``
    count samples with ``range_true <= 1e-3`` / ``ss_tot <= 1e-6``. True-range
    statistics are printed but not returned.

    Args:
        metric_dicts: One dict per sample with the keys produced by ``calc_metrics``.
        prefix: Label printed as the ``[prefix]`` section header.

    Returns:
        Flat dict of summary statistics (NaN where a statistic has no data).
    """

    def gather(key: str, flag: str = "") -> np.ndarray:
        # One metric across samples; when ``flag`` is given keep only samples where it is truthy.
        return np.array(
            [m[key] for m in metric_dicts if not flag or m[flag]],
            dtype=np.float64,
        )

    denom = max(len(metric_dicts), 1)
    rmse = gather("rmse")
    mae = gather("mae")
    bias = gather("bias")
    abs_bias = gather("abs_bias")
    ranges = gather("range_true")
    ss_tot = gather("ss_tot")
    nrmse_ok = gather("nrmse", "valid_nrmse")
    r2_ok = gather("r2", "valid_r2")

    stats = {
        "rmse_mean": safe_mean(rmse),
        "rmse_median": safe_median(rmse),
        "rmse_p90": safe_percentile(rmse, 90),
        "mae_mean": safe_mean(mae),
        "mae_median": safe_median(mae),
        "mae_p90": safe_percentile(mae, 90),
        "bias_mean": safe_mean(bias),
        "bias_median": safe_median(bias),
        "abs_bias_mean": safe_mean(abs_bias),
        "nrmse_mean_valid": safe_mean(nrmse_ok),
        "nrmse_median_valid": safe_median(nrmse_ok),
        "nrmse_p90_valid": safe_percentile(nrmse_ok, 90),
        "nrmse_valid_ratio": len(nrmse_ok) / denom,
        "r2_mean_valid": safe_mean(r2_ok),
        "r2_median_valid": safe_median(r2_ok),
        "r2_p10_valid": safe_percentile(r2_ok, 10),
        "r2_valid_ratio": len(r2_ok) / denom,
        "low_range_count": int(np.sum(ranges <= 1e-3)),
        "low_var_count": int(np.sum(ss_tot <= 1e-6)),
    }

    print(f"\n[{prefix}]")
    print(
        f" RMSE mean={stats['rmse_mean']:.6f}, median={stats['rmse_median']:.6f}, "
        f"p90={stats['rmse_p90']:.6f}"
    )
    print(
        f" MAE mean={stats['mae_mean']:.6f}, median={stats['mae_median']:.6f}, "
        f"p90={stats['mae_p90']:.6f}"
    )
    print(
        f" Bias mean={stats['bias_mean']:.6f}, median={stats['bias_median']:.6f}, "
        f"|mean|={stats['abs_bias_mean']:.6f}"
    )
    print(
        f" TrueRange mean={safe_mean(ranges):.6f}, median={safe_median(ranges):.6f}, "
        f"p10={safe_percentile(ranges, 10):.6f}"
    )
    print(
        f" NRMSE(valid only) mean={stats['nrmse_mean_valid']:.6f}, "
        f"median={stats['nrmse_median_valid']:.6f}, p90={stats['nrmse_p90_valid']:.6f}, "
        f"valid_ratio={stats['nrmse_valid_ratio']:.4f}, low_range_count={stats['low_range_count']}"
    )
    print(
        f" R2(valid only) mean={stats['r2_mean_valid']:.6f}, "
        f"median={stats['r2_median_valid']:.6f}, p10={stats['r2_p10_valid']:.6f}, "
        f"valid_ratio={stats['r2_valid_ratio']:.4f}, low_var_count={stats['low_var_count']}"
    )
    return stats
def build_composite_score(overall_m: dict, part_ms: dict) -> float:
    """Collapse overall and per-segment errors into one weighted score (higher is worse).

    The score is used to rank samples for the best/worst plots. Weighted terms:
    overall RMSE (1.0), overall MAE (0.5), log-pressure |bias| (0.8),
    log-derivative RMSE (0.8), slope RMSE (0.2).
    """
    weighted_terms = (
        (1.0, overall_m["rmse"]),
        (0.5, overall_m["mae"]),
        (0.8, part_ms["log_pressure"]["abs_bias"]),
        (0.8, part_ms["log_derivative"]["rmse"]),
        (0.2, part_ms["slope"]["rmse"]),
    )
    # Accumulate left to right (not sum()) so float rounding matches a plain a + b + ... chain.
    score = 0.0
    for weight, value in weighted_terms:
        score += weight * value
    return float(score)
def plot_sample(
    idx: int,
    curve_true: np.ndarray,
    curve_pred: np.ndarray,
    curve_layout: dict,
    output_dir: Path,
    title_prefix: str,
) -> None:
    """Plot one sample's true/predicted segments with their errors and save it as a PNG.

    The 3x2 figure has one row per segment (log_pressure, log_derivative, slope).
    The left panel overlays true and predicted values with per-segment metrics in its
    title; the right panel shows ``pred - true``. The suptitle carries the metrics of
    the whole concatenated curve. The image is written at 150 dpi to
    ``output_dir / f"{title_prefix.lower()}_sample_{idx:04d}.png"``.
    """

    def fmt_optional(value: float) -> str:
        # NRMSE and R2 are NaN when calc_metrics deems them undefined (near-constant targets).
        return "nan" if np.isnan(value) else f"{value:.4f}"

    title_map = {
        "log_pressure": "Log Pressure",
        "log_derivative": "Log |Derivative|",
        "slope": "Slope of Log Pressure vs Log Time",
    }
    true_parts = split_curve_by_layout(curve_true, curve_layout)
    pred_parts = split_curve_by_layout(curve_pred, curve_layout)
    overall = calc_metrics(curve_true, curve_pred)
    save_path = output_dir / f"{title_prefix.lower()}_sample_{idx:04d}.png"

    fig, axes = plt.subplots(3, 2, figsize=(14, 12))
    try:
        fig.suptitle(
            f"{title_prefix} | Sample #{idx} | "
            f"Overall RMSE={overall['rmse']:.4f}, MAE={overall['mae']:.4f}, "
            f"Bias={overall['bias']:.4f}, "
            f"NRMSE={fmt_optional(overall['nrmse'])}, "
            f"R2={fmt_optional(overall['r2'])}"
        )
        for row, name in enumerate(("log_pressure", "log_derivative", "slope")):
            y_true = true_parts[name]
            y_pred = pred_parts[name]
            x = np.arange(len(y_true))
            m = calc_metrics(y_true, y_pred)

            ax_l = axes[row, 0]
            ax_l.plot(x, y_true, label="True", linewidth=2, alpha=0.85)
            ax_l.plot(x, y_pred, label="Predicted", linewidth=2, alpha=0.85)
            ax_l.set_title(
                f"{title_map[name]} | RMSE={m['rmse']:.4f}, MAE={m['mae']:.4f}, "
                f"Bias={m['bias']:.4f}, NRMSE={fmt_optional(m['nrmse'])}, R2={fmt_optional(m['r2'])}"
            )
            ax_l.set_xlabel("Resampled Time Index")
            ax_l.set_ylabel("Value")
            ax_l.grid(True, alpha=0.3)
            ax_l.legend()

            ax_r = axes[row, 1]
            ax_r.plot(x, y_pred - y_true, linewidth=1.5, alpha=0.85)
            ax_r.axhline(0.0, linestyle="--", linewidth=1)
            ax_r.set_title(f"{title_map[name]} Error")
            ax_r.set_xlabel("Resampled Time Index")
            ax_r.set_ylabel("Pred - True")
            ax_r.grid(True, alpha=0.3)

        # Leave room at the top for the suptitle.
        fig.tight_layout(rect=[0, 0, 1, 0.97])
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        # Close this specific figure even if plotting fails, so repeated calls never pile up open figures.
        plt.close(fig)
    print(f"Plot saved: {save_path}")
def save_sample_metrics_csv(sample_scores: list[dict], path: Path) -> None:
    """Write one CSV row per sample; the columns follow the key order of the first row.

    Nothing is written (and no file is created) when ``sample_scores`` is empty. The file
    is encoded as ``utf-8-sig`` (UTF-8 with a BOM), presumably so spreadsheet tools
    detect the encoding.
    """
    if sample_scores:
        with open(path, "w", newline="", encoding="utf-8-sig") as handle:
            writer = csv.DictWriter(handle, fieldnames=list(sample_scores[0]))
            writer.writeheader()
            for row in sample_scores:
                writer.writerow(row)
        print(f"Sample metrics saved: {path}")
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the curve segment layout stored in ``meta``, or a fallback for older data.

    If ``meta["curve_layout"]`` is present and not None it is returned as-is.
    Otherwise the curve is assumed to be three consecutive equal segments
    (log_pressure, log_derivative, slope), each ``curve_dim // 3`` points long.
    Any remainder points belong to no segment.
    """
    stored_layout = meta.get("curve_layout")
    if stored_layout is not None:
        return stored_layout
    seg_len = curve_dim // 3
    segment_names = ("log_pressure", "log_derivative", "slope")
    return {
        "n_time_points": int(seg_len),
        "parts": [
            {"name": name, "start": k * seg_len, "end": (k + 1) * seg_len}
            for k, name in enumerate(segment_names)
        ],
    }
def choose_output_dir(arg_output_dir: str | None, tag: str | None, use_schedule: bool) -> Path:
    """Return the evaluation output directory.

    An explicit ``--output-dir`` wins. Otherwise the directory is derived from the
    experiment tag and schedule mode via ``evaluation_dir_for_tag``.
    """
    if arg_output_dir is None:
        return evaluation_dir_for_tag(tag, use_schedule)
    return Path(arg_output_dir)
def model_predict(
    model: ForwardSurrogate,
    params_x: torch.Tensor,
    schedule_x: torch.Tensor,
    use_schedule: bool,
) -> torch.Tensor:
    """Run one forward pass of the surrogate and return its output unchanged.

    The output is still in the model's scaled curve space; the caller inverse-transforms
    it with the curve scaler. When ``use_schedule`` is false the schedule tensor is not
    forwarded and ``None`` is passed in its place.
    """
    schedule_arg = schedule_x if use_schedule else None
    return model(params_x, schedule_arg)
def prepare_eval_arrays(eval_data: dict, fit_data: dict | None) -> tuple[np.ndarray, np.ndarray, np.ndarray, object]:
    """Extract the test-split arrays and choose the curve scaler that decodes predictions.

    Without ``fit_data``, the stored test arrays are returned as float32 together with
    the evaluation dataset's curve scaler. The true curves therefore stay in that
    dataset's scaled space.

    With ``fit_data`` (cross-dataset evaluation), params and schedule are decoded with
    the evaluation scalers and re-encoded with the fit dataset's scalers. The true
    curves are decoded to original scale, and the fit dataset's curve scaler is returned.

    Returns:
        ``(x_params, x_schedule, y_curve, curve_scaler)``.

    Raises:
        ValueError: in cross-dataset mode, if the two datasets'
            ``meta["param_feature_transform"]`` values differ.
    """
    params_std = np.asarray(eval_data["X_params_test"], dtype=np.float32)
    schedule_std = np.asarray(eval_data["X_schedule_test"], dtype=np.float32)
    curve_std = np.asarray(eval_data["Y_curve_test"], dtype=np.float32)
    src_params = eval_data["scaler_params"]
    src_schedule = eval_data["scaler_schedule"]
    src_curve = eval_data["scaler_curve"]

    if fit_data is None:
        return params_std, schedule_std, curve_std, src_curve

    dst_params = fit_data["scaler_params"]
    dst_schedule = fit_data["scaler_schedule"]
    dst_curve = fit_data["scaler_curve"]

    def feature_transform_of(dataset: dict):
        # Missing or None "meta" is treated as an empty mapping.
        return (dataset.get("meta", {}) or {}).get("param_feature_transform")

    if feature_transform_of(eval_data) != feature_transform_of(fit_data):
        raise ValueError(
            "Cross-dataset evaluation requires matching param_feature_transform metadata. "
            "Re-preprocess both datasets with the same transform setting."
        )

    # Undo the evaluation dataset's scaling, then re-apply the scaling the model was trained with.
    params_raw = src_params.inverse_transform(params_std)
    schedule_raw = src_schedule.inverse_transform(schedule_std)
    params_fit = dst_params.transform(params_raw).astype(np.float32)
    schedule_fit = dst_schedule.transform(schedule_raw).astype(np.float32)
    curve_raw = src_curve.inverse_transform(curve_std).astype(np.float32)
    return params_fit, schedule_fit, curve_raw, dst_curve
# Curve segments evaluated separately, in report/CSV order.
_PART_NAMES = ("log_pressure", "log_derivative", "slope")
# Per-sample metric keys copied into sample_metrics.csv, in column order.
_SAMPLE_METRIC_KEYS = ("rmse", "mae", "bias", "abs_bias", "nrmse", "r2")


def _run_inference(
    model: ForwardSurrogate,
    x_params: np.ndarray,
    x_schedule: np.ndarray,
    y_curve: np.ndarray,
    pred_scaler_curve: object,
    use_schedule: bool,
    device: torch.device,
    y_curve_is_raw: bool,
) -> tuple[np.ndarray, np.ndarray]:
    """Predict every sample one at a time; return stacked float32 (true, pred) curves in original scale."""
    all_true: list[np.ndarray] = []
    all_pred: list[np.ndarray] = []
    with torch.no_grad():
        for idx in range(len(x_params)):
            # Single-sample inference keeps per-sample decoding simple; test splits are usually small.
            params_x = torch.tensor(x_params[idx : idx + 1], dtype=torch.float32).to(device)
            schedule_x = torch.tensor(x_schedule[idx : idx + 1], dtype=torch.float32).to(device)
            curve_pred_scaled = model_predict(model, params_x, schedule_x, use_schedule).cpu().numpy()
            if y_curve_is_raw:
                # Cross-dataset mode: prepare_eval_arrays already decoded the true curves.
                curve_true = y_curve[idx]
            else:
                # Regular mode: true curves are still scaled and share the prediction scaler.
                curve_true = pred_scaler_curve.inverse_transform(y_curve[idx : idx + 1])[0]
            curve_pred = pred_scaler_curve.inverse_transform(curve_pred_scaled)[0]
            all_true.append(curve_true.astype(np.float32))
            all_pred.append(curve_pred.astype(np.float32))
    return np.stack(all_true, axis=0), np.stack(all_pred, axis=0)


def _build_sample_row(idx: int, overall_m: dict, part_ms: dict[str, dict]) -> dict:
    """Flatten one sample's overall and per-segment metrics into a CSV row with a fixed column order."""
    row: dict = {"idx": idx, "composite_score": build_composite_score(overall_m, part_ms)}
    for key in _SAMPLE_METRIC_KEYS:
        row[f"overall_{key}"] = overall_m[key]
    for name in _PART_NAMES:
        for key in _SAMPLE_METRIC_KEYS:
            row[f"{name}_{key}"] = part_ms[name][key]
    # Validity flags go last, overall first, then each segment.
    for prefix, metrics in (("overall", overall_m), *((name, part_ms[name]) for name in _PART_NAMES)):
        row[f"{prefix}_valid_nrmse"] = metrics["valid_nrmse"]
        row[f"{prefix}_valid_r2"] = metrics["valid_r2"]
    return row


def _select_plot_indices(
    scores: np.ndarray,
    n_best: int,
    n_worst: int,
    n_random: int,
) -> tuple[list[int], list[int], list[int]]:
    """Pick best/worst indices by ascending score plus a random subset; counts are clamped to [0, n]."""
    n_samples = len(scores)
    n_random = max(0, min(n_random, n_samples))
    n_best = max(0, min(n_best, n_samples))
    n_worst = max(0, min(n_worst, n_samples))
    order = np.argsort(scores)
    best = order[:n_best].tolist()
    # order[-0:] would select every sample, so slice from an explicit start index.
    worst = order[n_samples - n_worst :].tolist()
    picked = random.sample(range(n_samples), n_random)
    return best, worst, picked


def main() -> None:
    """Evaluate the forward surrogate on the full test split.

    Loads the processed dataset (and optionally a second dataset whose scalers the model
    was trained with), restores the checkpoint, and predicts every test sample. It then
    prints and saves aggregate metrics (``summary_metrics.json``) and per-sample metrics
    (``sample_metrics.csv``), and plots random / best / worst samples ranked by
    ``build_composite_score``.

    Raises:
        ValueError: if the test split is empty.
    """
    args = parse_args()
    tag = normalize_tag(args.tag)
    # By default, evaluate the processed data and best checkpoint of the same experiment tag.
    # --fit-processed decodes predictions with another dataset's scalers instead.
    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    model_path = (
        Path(args.model)
        if args.model is not None
        else model_checkpoint_for_tag(tag, use_schedule=not args.no_schedule)
    )
    print("Loading processed dataset...")
    data = joblib.load(processed_path)
    fit_processed_path = Path(args.fit_processed) if args.fit_processed is not None else None
    fit_data = joblib.load(fit_processed_path) if fit_processed_path is not None else None
    # prepare_eval_arrays reconciles the evaluation split with the scaler-fitting dataset when they differ.
    x_params_test, x_schedule_test, y_curve_test, pred_scaler_curve = prepare_eval_arrays(data, fit_data)
    meta = data["meta"]
    param_dim = int(meta["param_dim"])
    schedule_dim = int(meta["schedule_dim"])
    curve_dim = int(meta["curve_dim"])
    curve_layout = infer_curve_layout(meta, curve_dim)
    print(f"Test size: {len(x_params_test)}")
    print(f"param_dim={param_dim}, schedule_dim={schedule_dim}, curve_dim={curve_dim}")
    print(f"curve_layout={curve_layout}")
    if len(x_params_test) == 0:
        # Fail fast with a clear message instead of an np.stack error after loading the model.
        raise ValueError(f"Test split in {processed_path} is empty; nothing to evaluate.")

    print("Loading trained model...")
    checkpoint = torch.load(model_path, map_location="cpu")
    use_schedule = bool(checkpoint.get("use_schedule", True))
    model = ForwardSurrogate(
        param_dim=int(checkpoint["param_dim"]),
        schedule_dim=int(checkpoint["schedule_dim"]),
        curve_dim=int(checkpoint["curve_dim"]),
        hidden_dim=int(checkpoint["hidden_dim"]),
        dropout=float(checkpoint["dropout"]),
        use_schedule=use_schedule,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    output_dir = choose_output_dir(args.output_dir, tag, use_schedule)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Using device: {device}")
    print(f"use_schedule={use_schedule}")
    print(f"output_dir={output_dir}")
    if fit_processed_path is not None:
        print(f"fit_processed={fit_processed_path}")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    print("\nRunning inference on full test split...")
    all_true, all_pred = _run_inference(
        model,
        x_params_test,
        x_schedule_test,
        y_curve_test,
        pred_scaler_curve,
        use_schedule,
        device,
        y_curve_is_raw=fit_processed_path is not None,
    )
    print("Inference complete.")

    overall_metric_list: list[dict] = []
    part_metric_lists: dict[str, list[dict]] = {name: [] for name in _PART_NAMES}
    sample_scores: list[dict] = []
    for idx, (curve_true, curve_pred) in enumerate(zip(all_true, all_pred)):
        # Record overall and per-segment metrics to tell whether errors come from
        # pressure, derivative or the auxiliary slope.
        overall_m = calc_metrics(curve_true, curve_pred)
        overall_metric_list.append(overall_m)
        true_parts = split_curve_by_layout(curve_true, curve_layout)
        pred_parts = split_curve_by_layout(curve_pred, curve_layout)
        part_ms = {name: calc_metrics(true_parts[name], pred_parts[name]) for name in _PART_NAMES}
        for name, part_m in part_ms.items():
            part_metric_lists[name].append(part_m)
        sample_scores.append(_build_sample_row(idx, overall_m, part_ms))

    print("\n" + "=" * 60)
    print("Full test split summary")
    # summary_metrics.json holds aggregate metrics; sample_metrics.csv holds per-sample metrics for triage.
    summary = {
        "overall": summarize_metric_dicts(overall_metric_list, "overall"),
        **{name: summarize_metric_dicts(part_metric_lists[name], name) for name in _PART_NAMES},
        "checkpoint": {
            "model_path": str(model_path),
            "processed_path": str(processed_path),
            "fit_processed_path": str(fit_processed_path if fit_processed_path is not None else processed_path),
            "use_schedule": use_schedule,
        },
    }
    summary_path = output_dir / "summary_metrics.json"
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"\nSummary saved: {summary_path}")
    save_sample_metrics_csv(sample_scores, output_dir / "sample_metrics.csv")

    score_composite = np.array([row["composite_score"] for row in sample_scores], dtype=np.float64)
    best_indices, worst_indices, random_indices = _select_plot_indices(
        score_composite, args.n_best_plots, args.n_worst_plots, args.n_random_plots
    )
    # best shows the ceiling, worst locates failure modes, random checks the overall look.
    print("\nBest sample indices:", best_indices)
    print("Worst sample indices:", worst_indices)
    print("Random sample indices:", random_indices)
    for group, indices in (("random", random_indices), ("best", best_indices), ("worst", worst_indices)):
        for idx in indices:
            plot_sample(idx, all_true[idx], all_pred[idx], curve_layout, output_dir, group)
    print("\nArtifacts written to:", output_dir)
    print("1. summary_metrics.json")
    print("2. sample_metrics.csv")
    print("3. best/random/worst sample plots")


if __name__ == "__main__":
    main()