You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/evaluate_forward.py

585 lines
23 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""评估固定长度曲线正演代理模型。
脚本加载预处理数据和 `ForwardSurrogate` checkpoint,批量预测验证/测试样本曲线,
按整体曲线与压力、导数、斜率分段统计 RMSE、MAE、NRMSE、R2 等指标,并保存
随机、最佳、最差样本图,作为正演代理模型离线验收入口。
"""
# pylint: disable=import-error,wrong-import-position
# pylint: disable=too-many-locals,too-many-arguments,too-many-positional-arguments,too-many-statements
# pylint: disable=line-too-long
from __future__ import annotations
import argparse
import csv
import json
import random
import sys
from pathlib import Path
import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import (
evaluation_dir_for_tag,
model_checkpoint_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.models.forward_surrogate import ForwardSurrogate
DEFAULT_RANDOM_SEED = 42


def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the forward-surrogate evaluation script.

    Returns:
        The namespace produced by ``argparse`` with the attributes
        ``processed``, ``model``, ``fit_processed``, ``output_dir``, ``tag``
        (all ``str`` or ``None``), ``no_schedule`` (bool), ``seed`` (int,
        defaults to ``DEFAULT_RANDOM_SEED``) and the three plot counts
        ``n_random_plots``, ``n_best_plots``, ``n_worst_plots`` (int, default 5).
    """
    cli = argparse.ArgumentParser(description="Evaluate forward surrogate model")

    def add_text_option(flag: str, help_text: str) -> None:
        # Every path/tag option is an optional string that defaults to None.
        cli.add_argument(flag, type=str, default=None, help=help_text)

    add_text_option("--processed", "Processed dataset path")
    add_text_option("--model", "Model checkpoint path")
    add_text_option(
        "--fit-processed",
        "Processed dataset used to fit scalers for the evaluated model; "
        "required for cross-dataset evaluation",
    )
    add_text_option("--output-dir", "Optional evaluation output directory")
    add_text_option("--tag", "Experiment tag for auto naming")

    cli.add_argument(
        "--no-schedule",
        action="store_true",
        help="When --model is omitted, infer the no-schedule checkpoint path",
    )
    cli.add_argument("--seed", type=int, default=DEFAULT_RANDOM_SEED)

    # The three plot-count options share the same type and default.
    for count_flag in ("--n-random-plots", "--n-best-plots", "--n-worst-plots"):
        cli.add_argument(count_flag, type=int, default=5)

    return cli.parse_args()
def calc_metrics(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    eps_range: float = 1e-3,
    eps_var: float = 1e-6,
) -> dict:
    """Compute regression metrics between a reference and a predicted array.

    Args:
        y_true: Reference values.
        y_pred: Predicted values; the error is ``y_pred - y_true``.
        eps_range: NRMSE is only reported when ``max(y_true) - min(y_true)``
            exceeds this threshold.
        eps_var: R2 is only reported when the total sum of squares of
            ``y_true`` exceeds this threshold.

    Returns:
        A dict with ``rmse``, ``mae``, ``bias``, ``abs_bias``, ``nrmse``,
        ``r2``, ``range_true``, ``ss_tot`` and the boolean flags
        ``valid_nrmse`` / ``valid_r2``. ``nrmse`` and ``r2`` are NaN when the
        corresponding flag is False.
    """
    residual = y_pred - y_true
    squared = residual**2

    rmse = float(np.sqrt(np.mean(squared)))
    mae = float(np.mean(np.abs(residual)))
    bias = float(np.mean(residual))

    value_range = float(np.max(y_true) - np.min(y_true))
    centered = y_true - np.mean(y_true)
    ss_tot = float(np.sum(centered**2))
    ss_res = float(np.sum(squared))

    # Normalised metrics are meaningless for (nearly) constant references,
    # so they are reported as NaN and flagged as invalid instead.
    nrmse_ok = value_range > eps_range
    r2_ok = ss_tot > eps_var

    if nrmse_ok:
        nrmse = float(rmse / value_range)
    else:
        nrmse = np.nan

    if r2_ok:
        r2 = float(1.0 - ss_res / ss_tot)
    else:
        r2 = np.nan

    metrics = {
        "rmse": rmse,
        "mae": mae,
        "bias": bias,
        "abs_bias": float(abs(bias)),
        "nrmse": nrmse,
        "r2": r2,
        "range_true": value_range,
        "ss_tot": ss_tot,
        "valid_nrmse": bool(nrmse_ok),
        "valid_r2": bool(r2_ok),
    }
    return metrics
def split_curve_by_layout(curve: np.ndarray, layout: dict) -> dict[str, np.ndarray]:
    """Slice a concatenated curve into its named segments.

    Args:
        curve: The concatenated curve to slice.
        layout: Mapping whose ``"parts"`` entry is a sequence of dicts, each
            providing ``"name"``, ``"start"`` and ``"end"``.

    Returns:
        A dict mapping each part name (as ``str``) to ``curve[start:end]``,
        with the bounds converted to ``int``.
    """
    return {
        str(segment["name"]): curve[int(segment["start"]) : int(segment["end"])]
        for segment in layout["parts"]
    }
def safe_percentile(x: np.ndarray, q: float) -> float:
    """Return the ``q``-th percentile of ``x`` with NaN entries ignored.

    Args:
        x: Values to summarise; converted to a float64 array.
        q: Percentile to compute, passed through to ``np.percentile``.

    Returns:
        The percentile of the non-NaN values as a Python float, or NaN when
        ``x`` is empty or contains only NaN.
    """
    values = np.asarray(x, dtype=np.float64)
    # Drop NaN entries explicitly so a single NaN cannot poison the statistic.
    values = values[~np.isnan(values)]
    if values.size == 0:
        return float("nan")
    return float(np.percentile(values, q))
def safe_mean(x: np.ndarray) -> float:
    """Return the mean of ``x`` with NaN entries ignored.

    Args:
        x: Values to summarise; converted to a float64 array.

    Returns:
        The mean of the non-NaN values as a Python float, or NaN when ``x``
        is empty or contains only NaN.
    """
    values = np.asarray(x, dtype=np.float64)
    # Drop NaN entries explicitly so a single NaN cannot poison the statistic.
    values = values[~np.isnan(values)]
    if values.size == 0:
        return float("nan")
    return float(np.mean(values))
def safe_median(x: np.ndarray) -> float:
    """Return the median of ``x`` with NaN entries ignored.

    Args:
        x: Values to summarise; converted to a float64 array.

    Returns:
        The median of the non-NaN values as a Python float, or NaN when ``x``
        is empty or contains only NaN.
    """
    values = np.asarray(x, dtype=np.float64)
    # Drop NaN entries explicitly so a single NaN cannot poison the statistic.
    values = values[~np.isnan(values)]
    if values.size == 0:
        return float("nan")
    return float(np.median(values))
def summarize_metric_dicts(metric_dicts: list[dict], prefix: str) -> dict:
    """Aggregate per-sample metric dicts into summary statistics and print them.

    Args:
        metric_dicts: Per-sample dicts carrying the keys ``rmse``, ``mae``,
            ``bias``, ``abs_bias``, ``range_true``, ``ss_tot``, ``nrmse``,
            ``r2``, ``valid_nrmse`` and ``valid_r2``.
        prefix: Label printed in brackets above the report.

    Returns:
        A dict of mean/median/percentile statistics. NRMSE and R2 statistics
        only use samples whose validity flag is set; the share of such samples
        is reported as ``*_valid_ratio``. ``low_range_count`` and
        ``low_var_count`` count samples with ``range_true <= 1e-3`` and
        ``ss_tot <= 1e-6`` respectively.
    """

    def gather(key: str, flag: str | None = None) -> np.ndarray:
        # Collect one metric over all samples, optionally keeping only the
        # samples whose validity flag is truthy.
        kept = [entry[key] for entry in metric_dicts if flag is None or entry[flag]]
        return np.array(kept, dtype=np.float64)

    rmse = gather("rmse")
    mae = gather("mae")
    bias = gather("bias")
    abs_bias = gather("abs_bias")
    ranges = gather("range_true")
    ss_tot = gather("ss_tot")
    nrmse_valid = gather("nrmse", flag="valid_nrmse")
    r2_valid = gather("r2", flag="valid_r2")

    # Guard the ratios against division by zero for an empty input list.
    denominator = max(len(metric_dicts), 1)

    stats = {
        "rmse_mean": safe_mean(rmse),
        "rmse_median": safe_median(rmse),
        "rmse_p90": safe_percentile(rmse, 90),
        "mae_mean": safe_mean(mae),
        "mae_median": safe_median(mae),
        "mae_p90": safe_percentile(mae, 90),
        "bias_mean": safe_mean(bias),
        "bias_median": safe_median(bias),
        "abs_bias_mean": safe_mean(abs_bias),
        "nrmse_mean_valid": safe_mean(nrmse_valid),
        "nrmse_median_valid": safe_median(nrmse_valid),
        "nrmse_p90_valid": safe_percentile(nrmse_valid, 90),
        "nrmse_valid_ratio": len(nrmse_valid) / denominator,
        "r2_mean_valid": safe_mean(r2_valid),
        "r2_median_valid": safe_median(r2_valid),
        "r2_p10_valid": safe_percentile(r2_valid, 10),
        "r2_valid_ratio": len(r2_valid) / denominator,
        "low_range_count": int(np.sum(ranges <= 1e-3)),
        "low_var_count": int(np.sum(ss_tot <= 1e-6)),
    }

    # The true-range statistics are printed only; they are not part of the
    # returned summary.
    range_mean = safe_mean(ranges)
    range_median = safe_median(ranges)
    range_p10 = safe_percentile(ranges, 10)

    print(f"\n[{prefix}]")
    print(
        f" RMSE mean={stats['rmse_mean']:.6f}, median={stats['rmse_median']:.6f}, "
        f"p90={stats['rmse_p90']:.6f}"
    )
    print(
        f" MAE mean={stats['mae_mean']:.6f}, median={stats['mae_median']:.6f}, "
        f"p90={stats['mae_p90']:.6f}"
    )
    print(
        f" Bias mean={stats['bias_mean']:.6f}, median={stats['bias_median']:.6f}, "
        f"|mean|={stats['abs_bias_mean']:.6f}"
    )
    print(
        f" TrueRange mean={range_mean:.6f}, median={range_median:.6f}, "
        f"p10={range_p10:.6f}"
    )
    print(
        f" NRMSE(valid only) mean={stats['nrmse_mean_valid']:.6f}, "
        f"median={stats['nrmse_median_valid']:.6f}, p90={stats['nrmse_p90_valid']:.6f}, "
        f"valid_ratio={stats['nrmse_valid_ratio']:.4f}, low_range_count={stats['low_range_count']}"
    )
    print(
        f" R2(valid only) mean={stats['r2_mean_valid']:.6f}, "
        f"median={stats['r2_median_valid']:.6f}, p10={stats['r2_p10_valid']:.6f}, "
        f"valid_ratio={stats['r2_valid_ratio']:.4f}, low_var_count={stats['low_var_count']}"
    )
    return stats
def build_composite_score(overall_m: dict, part_ms: dict) -> float:
    """Combine overall and per-part errors into one ranking score.

    Args:
        overall_m: Metrics of the whole curve; ``rmse`` and ``mae`` are read.
        part_ms: Per-part metric dicts keyed by ``log_pressure`` (``abs_bias``
            is read), ``log_derivative`` (``rmse``) and ``slope`` (``rmse``).

    Returns:
        The weighted sum ``1.0*rmse + 0.5*mae + 0.8*|pressure bias|
        + 0.8*derivative rmse + 0.2*slope rmse`` as a Python float.
    """
    pressure = part_ms["log_pressure"]
    derivative = part_ms["log_derivative"]
    slope = part_ms["slope"]

    # Terms are accumulated in a fixed order so the floating-point result is
    # reproducible.
    score = 1.0 * overall_m["rmse"]
    score = score + 0.5 * overall_m["mae"]
    score = score + 0.8 * pressure["abs_bias"]
    score = score + 0.8 * derivative["rmse"]
    score = score + 0.2 * slope["rmse"]
    return float(score)
def plot_sample(
    idx: int,
    curve_true: np.ndarray,
    curve_pred: np.ndarray,
    curve_layout: dict,
    output_dir: Path,
    title_prefix: str,
) -> None:
    """Plot true vs. predicted curves and their error for one sample and save it.

    The figure is a 3x2 grid: one row per curve part (``log_pressure``,
    ``log_derivative``, ``slope``), with the true/predicted overlay on the left
    and the ``pred - true`` error on the right. Metrics from ``calc_metrics``
    are shown in the figure title and in each left-hand subplot title.

    Args:
        idx: Sample index, used in the title and the file name.
        curve_true: Reference concatenated curve.
        curve_pred: Predicted concatenated curve.
        curve_layout: Layout passed to ``split_curve_by_layout``.
        output_dir: Directory in which the PNG is written.
        title_prefix: Label used in the title; its lower-cased form prefixes
            the file name.
    """

    def fmt(value: float) -> str:
        # Render NaN metrics with an explicit "nan" label.
        return "nan" if np.isnan(value) else f"{value:.4f}"

    true_parts = split_curve_by_layout(curve_true, curve_layout)
    pred_parts = split_curve_by_layout(curve_pred, curve_layout)
    overall = calc_metrics(curve_true, curve_pred)

    plot_order = ["log_pressure", "log_derivative", "slope"]
    title_map = {
        "log_pressure": "Log Pressure",
        "log_derivative": "Log |Derivative|",
        "slope": "Slope of Log Pressure vs Log Time",
    }

    fig, axes = plt.subplots(3, 2, figsize=(14, 12))
    # Work on the explicit figure object and always close it, so a failure
    # while drawing or saving does not leave the figure open.
    try:
        fig.suptitle(
            f"{title_prefix} | Sample #{idx} | "
            f"Overall RMSE={overall['rmse']:.4f}, MAE={overall['mae']:.4f}, "
            f"Bias={overall['bias']:.4f}, "
            f"NRMSE={fmt(overall['nrmse'])}, "
            f"R2={fmt(overall['r2'])}"
        )

        for row, name in enumerate(plot_order):
            y_true = true_parts[name]
            y_pred = pred_parts[name]
            err = y_pred - y_true
            x = np.arange(len(y_true))
            m = calc_metrics(y_true, y_pred)

            # Left column: overlay of the true and the predicted segment.
            ax_l = axes[row, 0]
            ax_l.plot(x, y_true, label="True", linewidth=2, alpha=0.85)
            ax_l.plot(x, y_pred, label="Predicted", linewidth=2, alpha=0.85)
            ax_l.set_title(
                f"{title_map[name]} | RMSE={m['rmse']:.4f}, MAE={m['mae']:.4f}, "
                f"Bias={m['bias']:.4f}, NRMSE={fmt(m['nrmse'])}, R2={fmt(m['r2'])}"
            )
            ax_l.set_xlabel("Resampled Time Index")
            ax_l.set_ylabel("Value")
            ax_l.grid(True, alpha=0.3)
            ax_l.legend()

            # Right column: point-wise error with a zero reference line.
            ax_r = axes[row, 1]
            ax_r.plot(x, err, linewidth=1.5, alpha=0.85)
            ax_r.axhline(0.0, linestyle="--", linewidth=1)
            ax_r.set_title(f"{title_map[name]} Error")
            ax_r.set_xlabel("Resampled Time Index")
            ax_r.set_ylabel("Pred - True")
            ax_r.grid(True, alpha=0.3)

        # Leave room at the top for the figure-level title.
        fig.tight_layout(rect=[0, 0, 1, 0.97])
        save_path = output_dir / f"{title_prefix.lower()}_sample_{idx:04d}.png"
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"Plot saved: {save_path}")
def save_sample_metrics_csv(sample_scores: list[dict], path: Path) -> None:
    """Write per-sample metric rows to a CSV file.

    Nothing is written when ``sample_scores`` is empty. The column order is
    taken from the keys of the first row, and the file is encoded as
    ``utf-8-sig`` (UTF-8 with a byte-order mark).

    Args:
        sample_scores: One dict per sample, all sharing the same keys.
        path: Destination CSV path.
    """
    if not sample_scores:
        return

    columns = list(sample_scores[0])
    with open(path, "w", newline="", encoding="utf-8-sig") as handle:
        table = csv.DictWriter(handle, fieldnames=columns)
        table.writeheader()
        for record in sample_scores:
            table.writerow(record)
    print(f"Sample metrics saved: {path}")
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the curve layout from metadata, or a three-way fallback layout.

    Args:
        meta: Dataset metadata; its ``"curve_layout"`` entry is returned
            unchanged when present and not ``None``.
        curve_dim: Length of the concatenated curve, used for the fallback.

    Returns:
        The stored layout, or a layout that splits the curve into
        ``log_pressure``, ``log_derivative`` and ``slope`` segments of
        ``curve_dim // 3`` points each.
    """
    stored = meta.get("curve_layout")
    if stored is not None:
        return stored

    # Fallback: assume three equally long, consecutive segments.
    segment_len = curve_dim // 3
    segment_names = ("log_pressure", "log_derivative", "slope")
    segments = [
        {"name": label, "start": pos * segment_len, "end": (pos + 1) * segment_len}
        for pos, label in enumerate(segment_names)
    ]
    return {"n_time_points": int(segment_len), "parts": segments}
def choose_output_dir(arg_output_dir: str | None, tag: str | None, use_schedule: bool) -> Path:
    """Resolve the directory that receives the evaluation artifacts.

    Args:
        arg_output_dir: Directory given on the command line, if any.
        tag: Experiment tag forwarded to ``evaluation_dir_for_tag``.
        use_schedule: Schedule flag forwarded to ``evaluation_dir_for_tag``.

    Returns:
        ``Path(arg_output_dir)`` when a directory was given, otherwise the
        result of ``evaluation_dir_for_tag(tag, use_schedule)``.
    """
    if arg_output_dir is None:
        return evaluation_dir_for_tag(tag, use_schedule)
    return Path(arg_output_dir)
def model_predict(
    model: ForwardSurrogate,
    params_x: torch.Tensor,
    schedule_x: torch.Tensor,
    use_schedule: bool,
) -> torch.Tensor:
    """Call the surrogate model once, with or without the schedule input.

    Args:
        model: The model to call.
        params_x: Parameter input, passed as the first argument.
        schedule_x: Schedule input; only forwarded when ``use_schedule`` is
            true.
        use_schedule: When false, ``None`` is passed in place of
            ``schedule_x``.

    Returns:
        Whatever the model returns; no post-processing is applied here.
    """
    schedule_input = schedule_x if use_schedule else None
    return model(params_x, schedule_input)
def prepare_eval_arrays(eval_data: dict, fit_data: dict | None) -> tuple[np.ndarray, np.ndarray, np.ndarray, object]:
    """Extract the test arrays and the curve scaler used to de-normalise predictions.

    Args:
        eval_data: Processed dataset being evaluated. Read keys:
            ``X_params_test``, ``X_schedule_test``, ``Y_curve_test``,
            ``scaler_params``, ``scaler_schedule``, ``scaler_curve`` and
            optionally ``meta``.
        fit_data: Processed dataset whose scalers the model was trained with,
            or ``None`` for same-dataset evaluation.

    Returns:
        ``(x_params, x_schedule, y_curve, curve_scaler)``.

        * Without ``fit_data`` the arrays are returned as stored (cast to
          float32) together with the evaluation dataset's curve scaler; the
          curve array has not been inverse-transformed.
        * With ``fit_data`` the inputs are inverse-transformed with the
          evaluation scalers and re-transformed with the fit scalers, the
          curve array is inverse-transformed with the evaluation curve scaler,
          and the fit dataset's curve scaler is returned.

    Raises:
        ValueError: If both datasets do not carry the same
            ``param_feature_transform`` metadata value.
    """
    params_in = np.asarray(eval_data["X_params_test"], dtype=np.float32)
    schedule_in = np.asarray(eval_data["X_schedule_test"], dtype=np.float32)
    curves_in = np.asarray(eval_data["Y_curve_test"], dtype=np.float32)

    src_params_scaler = eval_data["scaler_params"]
    src_schedule_scaler = eval_data["scaler_schedule"]
    src_curve_scaler = eval_data["scaler_curve"]

    if fit_data is not None:
        dst_params_scaler = fit_data["scaler_params"]
        dst_schedule_scaler = fit_data["scaler_schedule"]
        dst_curve_scaler = fit_data["scaler_curve"]

        def declared_transform(bundle: dict):
            # A missing or None "meta" entry is treated like an empty dict.
            return (bundle.get("meta", {}) or {}).get("param_feature_transform")

        if declared_transform(eval_data) != declared_transform(fit_data):
            raise ValueError(
                "Cross-dataset evaluation requires matching param_feature_transform metadata. "
                "Re-preprocess both datasets with the same transform setting."
            )

        # Undo the evaluation dataset's scaling, then apply the scaling the
        # model was trained with.
        params_out = dst_params_scaler.transform(
            src_params_scaler.inverse_transform(params_in)
        ).astype(np.float32)
        schedule_out = dst_schedule_scaler.transform(
            src_schedule_scaler.inverse_transform(schedule_in)
        ).astype(np.float32)
        curves_out = src_curve_scaler.inverse_transform(curves_in).astype(np.float32)
        return params_out, schedule_out, curves_out, dst_curve_scaler

    return params_in, schedule_in, curves_in, src_curve_scaler
_PART_NAMES = ("log_pressure", "log_derivative", "slope")
_METRIC_KEYS = ("rmse", "mae", "bias", "abs_bias", "nrmse", "r2")
_VALIDITY_KEYS = ("valid_nrmse", "valid_r2")


def _select_plot_indices(
    score_composite: np.ndarray,
    n_best: int,
    n_worst: int,
    n_random: int,
) -> tuple[list[int], list[int], list[int]]:
    """Pick the best, worst and random sample indices to plot.

    Counts are clamped to ``[0, len(score_composite)]``. The worst indices are
    sliced from an explicit start position because ``ranking[-0:]`` would
    select every sample instead of none.
    """
    total = len(score_composite)
    n_best = min(max(n_best, 0), total)
    n_worst = min(max(n_worst, 0), total)
    n_random = min(max(n_random, 0), total)

    ranking = np.argsort(score_composite)
    best_indices = ranking[:n_best].tolist()
    worst_indices = ranking[total - n_worst :].tolist()
    random_indices = random.sample(range(total), n_random)
    return best_indices, worst_indices, random_indices


def _build_sample_row(idx: int, overall_m: dict, part_ms: dict) -> dict:
    """Flatten one sample's overall and per-part metrics into a CSV row dict."""
    sections = {"overall": overall_m}
    for name in _PART_NAMES:
        sections[name] = part_ms[name]

    row = {
        "idx": idx,
        "composite_score": build_composite_score(overall_m, part_ms),
    }
    # Column order: all metric columns section by section, then all validity
    # flags section by section.
    for keys in (_METRIC_KEYS, _VALIDITY_KEYS):
        for section, metrics in sections.items():
            for key in keys:
                row[f"{section}_{key}"] = metrics[key]
    return row


def main() -> None:
    """Evaluate the forward surrogate on the test split and write the artifacts.

    Loads the processed dataset and the checkpoint, runs per-sample inference,
    computes overall and per-part metrics, then writes
    ``summary_metrics.json``, ``sample_metrics.csv`` and the best / random /
    worst sample plots into the output directory.

    Raises:
        ValueError: If the test split contains no samples.
    """
    args = parse_args()
    tag = normalize_tag(args.tag)

    # Unless given explicitly, the processed dataset and the checkpoint are
    # resolved from the experiment tag. --fit-processed supplies another
    # dataset whose scalers are used to de-normalise the predictions.
    processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
    model_path = (
        Path(args.model)
        if args.model is not None
        else model_checkpoint_for_tag(tag, use_schedule=not args.no_schedule)
    )

    print("Loading processed dataset...")
    data = joblib.load(processed_path)
    fit_processed_path = Path(args.fit_processed) if args.fit_processed is not None else None
    fit_data = joblib.load(fit_processed_path) if fit_processed_path is not None else None

    # prepare_eval_arrays handles the case where the evaluated dataset and
    # the dataset the scalers were fitted on differ.
    x_params_test, x_schedule_test, y_curve_test, pred_scaler_curve = prepare_eval_arrays(data, fit_data)

    meta = data["meta"]
    param_dim = int(meta["param_dim"])
    schedule_dim = int(meta["schedule_dim"])
    curve_dim = int(meta["curve_dim"])
    curve_layout = infer_curve_layout(meta, curve_dim)
    print(f"Test size: {len(x_params_test)}")
    print(f"param_dim={param_dim}, schedule_dim={schedule_dim}, curve_dim={curve_dim}")
    print(f"curve_layout={curve_layout}")

    # Fail early with a clear message instead of a stacking error later on.
    if len(x_params_test) == 0:
        raise ValueError("Test split is empty; nothing to evaluate.")

    print("Loading trained model...")
    checkpoint = torch.load(model_path, map_location="cpu")
    use_schedule = bool(checkpoint.get("use_schedule", True))
    model = ForwardSurrogate(
        param_dim=int(checkpoint["param_dim"]),
        schedule_dim=int(checkpoint["schedule_dim"]),
        curve_dim=int(checkpoint["curve_dim"]),
        hidden_dim=int(checkpoint["hidden_dim"]),
        dropout=float(checkpoint["dropout"]),
        use_schedule=use_schedule,
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    output_dir = choose_output_dir(args.output_dir, tag, use_schedule)
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Using device: {device}")
    print(f"use_schedule={use_schedule}")
    print(f"output_dir={output_dir}")
    if fit_processed_path is not None:
        print(f"fit_processed={fit_processed_path}")

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    all_true = []
    all_pred = []
    print("\nRunning inference on full test split...")
    with torch.no_grad():
        for idx in range(len(x_params_test)):
            # Single-sample inference keeps per-sample de-normalisation and
            # metric bookkeeping simple.
            params_x = torch.tensor(x_params_test[idx : idx + 1], dtype=torch.float32).to(device)
            schedule_x = torch.tensor(x_schedule_test[idx : idx + 1], dtype=torch.float32).to(device)
            curve_pred_scaled = model_predict(model, params_x, schedule_x, use_schedule).cpu().numpy()
            if fit_processed_path is None:
                # Regular evaluation: the reference curve is still scaled and
                # has to be inverse-transformed here.
                curve_true = pred_scaler_curve.inverse_transform(y_curve_test[idx : idx + 1])[0]
            else:
                # Cross-dataset evaluation: prepare_eval_arrays already
                # inverse-transformed the reference curves.
                curve_true = y_curve_test[idx]
            curve_pred = pred_scaler_curve.inverse_transform(curve_pred_scaled)[0]
            all_true.append(curve_true.astype(np.float32))
            all_pred.append(curve_pred.astype(np.float32))
    all_true = np.stack(all_true, axis=0)
    all_pred = np.stack(all_pred, axis=0)
    print("Inference complete.")

    overall_metric_list: list[dict] = []
    part_metric_lists: dict[str, list[dict]] = {name: [] for name in _PART_NAMES}
    sample_scores: list[dict] = []
    for idx, (curve_true, curve_pred) in enumerate(zip(all_true, all_pred)):
        # Record overall and per-part metrics so the source of an error
        # (pressure, derivative or slope) can be told apart.
        overall_m = calc_metrics(curve_true, curve_pred)
        overall_metric_list.append(overall_m)
        true_parts = split_curve_by_layout(curve_true, curve_layout)
        pred_parts = split_curve_by_layout(curve_pred, curve_layout)
        part_ms: dict[str, dict] = {}
        for name in _PART_NAMES:
            part_m = calc_metrics(true_parts[name], pred_parts[name])
            part_metric_lists[name].append(part_m)
            part_ms[name] = part_m
        sample_scores.append(_build_sample_row(idx, overall_m, part_ms))

    print("\n" + "=" * 60)
    print("Full test split summary")
    # summary_metrics.json holds the aggregated metrics; sample_metrics.csv
    # holds the per-sample metrics for later inspection.
    summary = {
        "overall": summarize_metric_dicts(overall_metric_list, "overall"),
        "log_pressure": summarize_metric_dicts(part_metric_lists["log_pressure"], "log_pressure"),
        "log_derivative": summarize_metric_dicts(part_metric_lists["log_derivative"], "log_derivative"),
        "slope": summarize_metric_dicts(part_metric_lists["slope"], "slope"),
        "checkpoint": {
            "model_path": str(model_path),
            "processed_path": str(processed_path),
            "fit_processed_path": str(fit_processed_path) if fit_processed_path is not None else str(processed_path),
            "use_schedule": use_schedule,
        },
    }
    with open(output_dir / "summary_metrics.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print(f"\nSummary saved: {output_dir / 'summary_metrics.json'}")
    save_sample_metrics_csv(sample_scores, output_dir / "sample_metrics.csv")

    score_composite = np.array([row["composite_score"] for row in sample_scores], dtype=np.float64)
    best_indices, worst_indices, random_indices = _select_plot_indices(
        score_composite,
        args.n_best_plots,
        args.n_worst_plots,
        args.n_random_plots,
    )

    # Best plots show the lowest composite scores, worst plots the highest,
    # and random plots give an unbiased impression of typical samples.
    print("\nBest sample indices:", best_indices)
    print("Worst sample indices:", worst_indices)
    print("Random sample indices:", random_indices)
    for idx in random_indices:
        plot_sample(idx, all_true[idx], all_pred[idx], curve_layout, output_dir, "random")
    for idx in best_indices:
        plot_sample(idx, all_true[idx], all_pred[idx], curve_layout, output_dir, "best")
    for idx in worst_indices:
        plot_sample(idx, all_true[idx], all_pred[idx], curve_layout, output_dir, "worst")

    print("\nArtifacts written to:", output_dir)
    print("1. summary_metrics.json")
    print("2. sample_metrics.csv")
    print("3. best/random/worst sample plots")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()