You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/replay_pso_trace_screening.py

420 lines
18 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""回放 PSO 轨迹并评估代理模型筛选效果。
脚本读取已有 PSO trace 中的候选参数和真实求解器目标值,重新用正演代理模型打分,
比较两套目标函数的相关性与 Top-K/保留比例效果,帮助判断代理是否能减少真实求解器
调用量而不明显损失最优候选。
"""
from __future__ import annotations
import argparse
import csv
import json
import math
import sys
from collections import defaultdict
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import joblib
import matplotlib.pyplot as plt
import numpy as np
from scripts.compare_single_case import build_schedule_vector, load_model
from scripts.validate_autofit_local_ranking import infer_curve_layout, predict_surrogate_curve
from src.common.config import Config
from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag
from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features
from src.data.params import Params, Schedule
from src.evaluation.autofit_objective import dual_log_objective
# Trace CSV columns holding each particle's physical parameters; the names double as the
# keyword arguments that params_from_row passes to Params.
PARAM_COLUMNS = ["k", "skin", "wellboreC", "phi", "h", "Cf"]
# Default keep fractions for surrogate screening (same values as the --keep-fracs CLI default).
DEFAULT_KEEP_FRACS = [0.5, 0.6, 0.7]
def parse_args() -> argparse.Namespace:
    """Parse the command line for replaying a PSO trace with surrogate screening.

    Every path option defaults to ``None`` and is filled in later by ``resolve_paths``.
    ``--keep-fracs`` is a comma-separated list of keep fractions; its default is built from
    ``DEFAULT_KEEP_FRACS`` so the constant and the CLI cannot drift apart.
    """
    parser = argparse.ArgumentParser(
        description="Replay a C++ full-solver PSO trace with the forward surrogate as a screening model."
    )
    parser.add_argument(
        "--trace-csv",
        type=str,
        default=None,
        help="Path to pso_baseline_trace_*.csv (default: newest trace with a .meta.json in data/temp)",
    )
    parser.add_argument(
        "--trace-meta",
        type=str,
        default=None,
        help="Path to matching pso_baseline_trace_*.meta.json (default: trace CSV path with .meta.json suffix)",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="family_random_mixed_50k_logparam",
        help="Tag used to derive the default --processed and --model paths",
    )
    parser.add_argument(
        "--stage",
        type=str,
        default="family_random",
        help="Stage used to derive the default --config path",
    )
    parser.add_argument("--processed", type=str, default=None, help="Processed dataset file (default: derived from --tag)")
    parser.add_argument("--model", type=str, default=None, help="Model checkpoint (default: derived from --tag)")
    parser.add_argument("--config", type=str, default=None, help="Config file (default: derived from --stage)")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Output directory (default: results/pso_trace_replay_<run_id>)",
    )
    parser.add_argument(
        "--keep-fracs",
        type=str,
        # Single source of truth for the default screening fractions.
        default=",".join(str(frac) for frac in DEFAULT_KEEP_FRACS),
        help="Comma-separated fractions of each generation kept by surrogate screening",
    )
    return parser.parse_args()
def latest_trace_pair(temp_dir: Path) -> tuple[Path, Path]:
    """Find the newest PSO trace CSV in ``temp_dir`` that has a sibling meta JSON.

    Candidates match ``pso_baseline_trace_*.csv`` and are visited newest-first by mtime. The
    sidecar is the CSV path with its ``.csv`` suffix swapped for ``.meta.json``.

    Returns:
        ``(trace_csv, trace_meta)`` for the first candidate whose sidecar exists.

    Raises:
        FileNotFoundError: if no CSV in ``temp_dir`` has a matching ``.meta.json``.
    """

    def _modified_at(candidate: Path) -> float:
        return candidate.stat().st_mtime

    newest_first = sorted(temp_dir.glob("pso_baseline_trace_*.csv"), key=_modified_at, reverse=True)
    for trace in newest_first:
        sidecar = trace.with_suffix(".meta.json")
        if not sidecar.exists():
            continue
        return trace, sidecar
    raise FileNotFoundError(f"No pso_baseline_trace_*.csv + .meta.json pair found in {temp_dir}")
def resolve_paths(args: argparse.Namespace) -> tuple[Path, Path, Path, Path, Path, Path]:
    """Turn the CLI arguments into absolute input/output paths for the replay.

    Explicit arguments win. Otherwise:
    * the trace pair is the newest one found by ``latest_trace_pair`` under ``ROOT/data/temp``;
      ``--trace-meta`` is only consulted when ``--trace-csv`` is given;
    * the config comes from ``config_for_stage(args.stage)``;
    * the processed data and checkpoint are derived from the normalized ``--tag``;
    * the output dir is ``results/pso_trace_replay_<run_id>``, where ``run_id`` is the trace
      file stem with ``pso_baseline_trace_`` removed.

    Returns:
        ``(trace_csv, trace_meta, config, processed, model, output_dir)``, each ``.resolve()``-d.

    Raises:
        ValueError: if no config can be derived for the stage and ``--config`` is absent.
    """
    tag = normalize_tag(args.tag)

    if args.trace_csv is not None:
        trace_csv = Path(args.trace_csv)
        if args.trace_meta is None:
            trace_meta = trace_csv.with_suffix(".meta.json")
        else:
            trace_meta = Path(args.trace_meta)
    else:
        trace_csv, trace_meta = latest_trace_pair(ROOT / "data" / "temp")

    config_path = config_for_stage(args.stage) if args.config is None else Path(args.config)
    if config_path is None:
        raise ValueError(f"Cannot resolve config for stage={args.stage!r}; pass --config explicitly")

    processed_path = processed_path_for_tag(tag) if args.processed is None else Path(args.processed)
    model_path = model_checkpoint_for_tag(tag, use_schedule=True) if args.model is None else Path(args.model)

    run_id = trace_csv.stem.replace("pso_baseline_trace_", "")
    if args.output_dir is None:
        # NOTE(review): unlike the trace default (anchored at ROOT), this default is relative to the
        # current working directory -- confirm that is intended.
        output_dir = Path("results") / f"pso_trace_replay_{run_id}"
    else:
        output_dir = Path(args.output_dir)

    return tuple(
        candidate.resolve()
        for candidate in (trace_csv, trace_meta, config_path, processed_path, model_path, output_dir)
    )
def read_trace_csv(path: Path) -> list[dict]:
    """Load a PSO trace CSV into a list of row dicts keyed by header name.

    The file is decoded as ``utf-8-sig``, so a leading BOM does not end up in the first column
    name. All values stay raw strings; parsing is left to the callers.
    """
    with path.open("r", newline="", encoding="utf-8-sig") as handle:
        reader = csv.DictReader(handle)
        records = list(reader)
    return records
def to_float(value: str | None, default: float = float("nan")) -> float:
    """Convert a CSV cell to ``float``.

    Returns ``default`` (NaN unless overridden) when the cell is ``None``, the empty string, or
    text that ``float()`` rejects with ``ValueError``.
    """
    if value in (None, ""):
        return default
    try:
        parsed = float(value)
    except ValueError:
        parsed = default
    return parsed
def is_valid_solver_row(row: dict) -> bool:
    """Tell whether a trace row carries a usable real-solver objective.

    The row must be a ``particle_solver`` row and have ``solver_success == "1"``. Its
    ``solver_objective`` must also parse to a finite value below ``1e9``; larger values are
    treated as failure/penalty markers.
    """
    objective = to_float(row.get("solver_objective"))
    if row.get("phase") != "particle_solver":
        return False
    if row.get("solver_success") != "1":
        return False
    return math.isfinite(objective) and objective < 1e9
def build_schedule_from_meta(meta: dict) -> Schedule:
    """Rebuild the flow ``Schedule`` stored in the trace meta JSON.

    Reads ``meta["target"]["flow_points"]`` (a list of ``{"x": ..., "y": ...}`` items; missing or
    empty means no points) and ``meta["target"]["section_index"]``. ``x`` values go to ``timeQ``
    and ``y`` values to ``q``. Leading points whose ``x`` is ``<= 0`` are dropped; once one point
    has been kept, every later point is kept as-is.
    """
    target = meta["target"]
    times: list[float] = []
    rates: list[float] = []
    for point in target.get("flow_points") or []:
        # Both fields are parsed before the skip test, so malformed skipped points still raise.
        step = float(point["x"])
        value = float(point["y"])
        still_leading = not times
        if still_leading and step <= 0:
            continue
        times.append(step)
        rates.append(value)
    return Schedule(sectionIndex=int(target["section_index"]), timeQ=times, q=rates)
def build_target_curve_from_meta(cfg: Config, processed: dict, meta: dict) -> np.ndarray:
    """Rebuild the target log-log feature vector stored in the trace meta JSON.

    Reads ``meta["target"]["target_loglog"]`` (``time`` / ``pressure`` / ``derivative``). It
    passes them through ``clean_curve_for_dataset``, ``is_valid_curve`` and
    ``resample_curve_to_features``, then checks the result against
    ``processed["meta"]["curve_dim"]``.

    Returns:
        The resampled curve as ``float32``.

    Raises:
        RuntimeError: if cleaning returns ``None``, validation fails, or the resampled size differs
            from ``curve_dim``.
    """
    loglog = meta["target"]["target_loglog"]
    time_raw, pressure_raw, deriv_raw = (
        np.asarray(loglog[field], dtype=np.float64) for field in ("time", "pressure", "derivative")
    )

    cleaned = clean_curve_for_dataset(cfg, time_raw, pressure_raw, deriv_raw)
    if cleaned is None:
        raise RuntimeError("Target loglog curve from meta cannot be cleaned for replay")
    time_c, pressure_c, deriv_c = cleaned

    is_ok, why = is_valid_curve(cfg, time_c, pressure_c, deriv_c)
    if not is_ok:
        raise RuntimeError(f"Target loglog curve from meta is invalid: {why}")

    features = resample_curve_to_features(cfg, time_c, pressure_c, deriv_c)
    expected_dim = int(processed["meta"]["curve_dim"])
    if features.size != expected_dim:
        raise RuntimeError(f"Target curve dim mismatch after resample: {features.size} != {expected_dim}")
    return features.astype(np.float32)
def params_from_row(row: dict) -> Params:
    """Build a ``Params`` from the physical-parameter columns of one trace row.

    Each name in ``PARAM_COLUMNS`` is read from the row and passed as the keyword argument of the
    same name. Missing or unparseable cells become NaN via ``to_float``; no range checks are done.
    """
    return Params(**{column: to_float(row.get(column)) for column in PARAM_COLUMNS})
def rank_positions(values: np.ndarray) -> np.ndarray:
    """Return the 0-based ascending rank of each element of a 1-D array.

    The smallest value gets rank 0. Ties receive distinct ranks in whatever order
    ``np.argsort``'s default sort produces.
    """
    ascending_order = np.argsort(values)
    # A permutation's argsort is its inverse, mapping each element back to its sorted position.
    return np.argsort(ascending_order)
def corr_pearson(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Pearson correlation of two equally sized 1-D arrays.

    Returns NaN when ``a`` has fewer than two elements, or when either array is (numerically)
    constant, i.e. its standard deviation is at most ``1e-12``.
    """
    degenerate = a.size < 2 or np.std(a) <= 1e-12 or np.std(b) <= 1e-12
    if degenerate:
        return float("nan")
    correlation_matrix = np.corrcoef(a, b)
    return float(correlation_matrix[0, 1])
def _average_ranks(values: np.ndarray) -> np.ndarray:
    """Return 0-based ranks of a 1-D array, giving tied values the mean of their positions."""
    flat = np.asarray(values, dtype=np.float64)
    n = flat.size
    if n == 0:
        return np.empty(0, dtype=np.float64)
    order = np.argsort(flat, kind="stable")
    ordered = flat[order]
    # Mark where a new run of equal values starts in sorted order.
    starts_run = np.empty(n, dtype=bool)
    starts_run[0] = True
    starts_run[1:] = ordered[1:] != ordered[:-1]
    run_index = np.cumsum(starts_run) - 1
    run_starts = np.flatnonzero(starts_run)
    run_ends = np.append(run_starts[1:], n)  # exclusive end of each run
    run_mean = (run_starts + run_ends - 1) / 2.0
    ranks = np.empty(n, dtype=np.float64)
    ranks[order] = run_mean[run_index]
    return ranks


def corr_spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Spearman rank correlation of two equally sized 1-D arrays.

    Ties get average ranks (the standard Spearman definition). Ordinal ranks with arbitrary
    tie-breaking would bias the result when objectives coincide, e.g. for identical PSO particles.
    Without ties the value equals Pearson on plain ranks. Returns NaN for fewer than two
    elements, or when either rank vector is constant (see ``corr_pearson``).
    """
    if a.size < 2:
        return float("nan")
    return corr_pearson(_average_ranks(a), _average_ranks(b))
def summarize_generation(rows: list[dict], keep_fracs: list[float]) -> tuple[dict, list[dict]]:
    """Summarize surrogate-vs-solver ranking quality for one PSO generation and simulate screening.

    Args:
        rows: Replayed trace rows of a single generation. Every row needs ``particle_id``,
            ``generation`` and ``surrogate_objective``. Rows accepted by ``is_valid_solver_row``
            also provide a usable ``solver_objective``.
        keep_fracs: Fractions of the generation the surrogate would keep before the real solver
            is run (``keep_n = max(1, ceil(len(rows) * frac))``).

    Returns:
        ``(summary, screening_rows)``. ``summary`` holds correlations, top-10 overlap and
        top-1 regret for the generation. ``screening_rows`` has one dict per keep fraction with
        top-10 recall, whether the solver-best particle was dropped, and the regret of the best
        kept particle (``inf`` when no kept particle has a valid solver result). Both are empty
        when fewer than two rows have a valid solver objective.

    Side effects:
        Adds ``solver_rank``, ``surrogate_rank`` and ``rank_gap`` to every valid row in place,
        so callers holding the same dicts see them too.
    """
    # Shared predicate: skips failed/penalized rows and empty or unparseable objectives
    # instead of raising ValueError on float("").
    valid = [r for r in rows if is_valid_solver_row(r)]
    if len(valid) < 2:
        return {}, []

    solver_obj = np.asarray([to_float(r["solver_objective"]) for r in valid], dtype=np.float64)
    surrogate_obj = np.asarray([float(r["surrogate_objective"]) for r in valid], dtype=np.float64)
    solver_rank = rank_positions(solver_obj)
    surrogate_rank = rank_positions(surrogate_obj)
    # Tag each particle with both ranks so "surrogate likes it, solver does not" cases can be located later.
    for row, s_rank, m_rank in zip(valid, solver_rank, surrogate_rank):
        row["solver_rank"] = int(s_rank)
        row["surrogate_rank"] = int(m_rank)
        row["rank_gap"] = int(m_rank - s_rank)

    best_solver_idx = int(np.argmin(solver_obj))
    best_surrogate_idx = int(np.argmin(surrogate_obj))
    top_n = min(10, len(valid))
    top_solver = set(np.argsort(solver_obj)[:top_n].tolist())
    top_surrogate = set(np.argsort(surrogate_obj)[:top_n].tolist())
    generation = int(valid[0]["generation"])
    summary = {
        "generation": generation,
        "n_particles": len(rows),
        "n_valid_solver": len(valid),
        "pearson_objective": corr_pearson(solver_obj, surrogate_obj),
        "spearman_objective": corr_spearman(solver_obj, surrogate_obj),
        "top10_overlap": int(len(top_solver & top_surrogate)),
        "best_solver_particle_id": int(valid[best_solver_idx]["particle_id"]),
        "best_solver_surrogate_rank": int(surrogate_rank[best_solver_idx]),
        "best_surrogate_particle_id": int(valid[best_surrogate_idx]["particle_id"]),
        "best_surrogate_solver_rank": int(solver_rank[best_surrogate_idx]),
        "solver_best_objective": float(solver_obj[best_solver_idx]),
        "solver_objective_at_surrogate_best": float(solver_obj[best_surrogate_idx]),
        "surrogate_top1_regret": float(solver_obj[best_surrogate_idx] - solver_obj[best_solver_idx]),
    }

    # Screening happens before the solver runs, so the surrogate ranks *all* rows, failed ones included.
    # These are loop-invariant and computed once for every keep fraction.
    all_ids = [int(r["particle_id"]) for r in rows]
    surrogate_order = np.argsort(np.asarray([float(r["surrogate_objective"]) for r in rows], dtype=np.float64))
    valid_ids = [int(r["particle_id"]) for r in valid]
    valid_id_set = set(valid_ids)
    best_solver_id = valid_ids[best_solver_idx]
    top_solver_ids = {valid_ids[i] for i in top_solver}
    best_solver_value = solver_obj[best_solver_idx]

    screening_rows: list[dict] = []
    for keep_frac in keep_fracs:
        keep_n = max(1, int(math.ceil(len(rows) * float(keep_frac))))
        kept_positions = surrogate_order[:keep_n]
        kept_ids = {all_ids[i] for i in kept_positions}
        # Simulates "surrogate screens first, then the real solver evaluates the survivors".
        kept_valid_indices = [i for i, pid in enumerate(valid_ids) if pid in kept_ids]
        if kept_valid_indices:
            screening_regret = float(np.min(solver_obj[kept_valid_indices]) - best_solver_value)
        else:
            # Every kept particle failed or was invalid: nothing usable survives screening.
            screening_regret = float("inf")
        screening_rows.append(
            {
                "generation": generation,
                "keep_frac": float(keep_frac),
                "keep_n": int(keep_n),
                "solver_top10_recall": float(len(top_solver_ids & kept_ids) / max(top_n, 1)),
                "best_missed": int(best_solver_id not in kept_ids),
                "screening_regret": screening_regret,
                "kept_valid_solver": int(len(kept_valid_indices)),
                "kept_failed_or_invalid": int(sum(1 for i in kept_positions if all_ids[i] not in valid_id_set)),
            }
        )
    return summary, screening_rows
def write_csv(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 (with BOM) CSV.

    The header is the ordered union of all row keys, in first-seen order. Rows that lack a column
    get an empty cell. When ``rows`` is empty nothing is written and no file is created.
    """
    if not rows:
        return
    # Dict keys act as an insertion-ordered set; updating an existing key keeps its position.
    columns: dict = {}
    for record in rows:
        columns.update(dict.fromkeys(record))
    with path.open("w", newline="", encoding="utf-8-sig") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(columns))
        writer.writeheader()
        for record in rows:
            writer.writerow(record)
def plot_objective_scatter(rows: list[dict], path: Path) -> None:
    """Save a scatter of real-solver vs surrogate objectives to ``path``.

    Only rows accepted by ``is_valid_solver_row`` are plotted. The old ad-hoc filter raised on
    empty objective strings and let ``-inf`` through. Points are coloured by ``generation``, and
    a dashed ``y = x`` line spans the combined value range. Nothing is written when fewer than
    two rows qualify.
    """
    valid = [r for r in rows if is_valid_solver_row(r)]
    if len(valid) < 2:
        return
    solver = np.asarray([to_float(r["solver_objective"]) for r in valid], dtype=np.float64)
    surrogate = np.asarray([float(r["surrogate_objective"]) for r in valid], dtype=np.float64)
    gen = np.asarray([int(r["generation"]) for r in valid], dtype=np.int32)

    fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
    try:
        sc = ax.scatter(solver, surrogate, c=gen, cmap="viridis", alpha=0.8)
        lo = min(float(np.min(solver)), float(np.min(surrogate)))
        hi = max(float(np.max(solver)), float(np.max(surrogate)))
        ax.plot([lo, hi], [lo, hi], "k--", linewidth=1)
        ax.set_xlabel("C++ solver objective")
        ax.set_ylabel("Surrogate objective to target")
        ax.set_title("PSO trace replay objective ranking")
        fig.colorbar(sc, ax=ax, label="generation")
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
def _replay_surrogate_objectives(
    rows: list[dict],
    *,
    processed: dict,
    model,
    device,
    use_schedule,
    schedule: Schedule,
    cfg: Config,
    target_curve: np.ndarray,
    curve_layout,
) -> list[dict]:
    """Copy every ``particle_solver`` trace row and attach surrogate objectives against the target."""
    replayed: list[dict] = []
    for trace_row in rows:
        if trace_row.get("phase") != "particle_solver":
            continue
        # The real solver is not re-run; only the stored particle parameters are re-scored.
        particle = params_from_row(trace_row)
        particle.schedule = schedule
        predicted = predict_surrogate_curve(
            processed=processed,
            model=model,
            device=device,
            use_schedule=use_schedule,
            params=particle,
            schedule=schedule,
            cfg=cfg,
        )
        scores = dual_log_objective(target_curve, predicted, curve_layout)
        replayed.append(
            {
                **trace_row,
                "surrogate_objective": scores["dual_log_objective"],
                "surrogate_p_obj": scores["log_pressure_objective"],
                "surrogate_d_obj": scores["log_derivative_objective"],
            }
        )
    return replayed


def _summarize_generations(replay_rows: list[dict], keep_fracs: list[float]) -> tuple[list[dict], list[dict]]:
    """Group replayed rows by generation (ascending) and run ``summarize_generation`` on each."""
    # Per-generation stats: early and late PSO particle distributions differ too much for a global mean.
    grouped: dict[int, list[dict]] = defaultdict(list)
    for replayed in replay_rows:
        grouped[int(replayed["generation"])].append(replayed)
    generation_summaries: list[dict] = []
    screening_rows: list[dict] = []
    for generation in sorted(grouped):
        gen_summary, gen_screening = summarize_generation(grouped[generation], keep_fracs)
        if gen_summary:
            generation_summaries.append(gen_summary)
        screening_rows.extend(gen_screening)
    return generation_summaries, screening_rows


def _generation_means(generation_summaries: list[dict]) -> dict:
    """Average the ranking-quality metrics across generations (NaN-aware); empty if no generations."""
    means: dict = {}
    if not generation_summaries:
        return means
    for key in (
        "pearson_objective",
        "spearman_objective",
        "top10_overlap",
        "best_solver_surrogate_rank",
        "best_surrogate_solver_rank",
        "surrogate_top1_regret",
    ):
        column = np.asarray([float(s[key]) for s in generation_summaries], dtype=np.float64)
        means[key] = float(np.nanmean(column))
    return means


def _screening_means(screening_rows: list[dict], keep_fracs: list[float]) -> dict:
    """Average screening metrics per keep fraction, keyed by ``str(keep_frac)``."""
    means: dict = {}
    for keep_frac in keep_fracs:
        matching = [r for r in screening_rows if abs(float(r["keep_frac"]) - keep_frac) < 1e-12]
        if not matching:
            continue
        # Answers directly: at this keep ratio, does the surrogate throw away truly good particles?
        means[str(keep_frac)] = {
            out_key: float(np.mean([float(r[in_key]) for r in matching]))
            for out_key, in_key in (
                ("solver_top10_recall", "solver_top10_recall"),
                ("best_missed_ratio", "best_missed"),
                ("screening_regret", "screening_regret"),
                ("kept_failed_or_invalid", "kept_failed_or_invalid"),
            )
        }
    return means


def _print_report(output_dir: Path, summary: dict) -> None:
    """Print a short human-readable digest of the replay summary."""
    print(f"Trace replay output: {output_dir}")
    print(f"particle_rows={summary['n_particle_rows']}, generations={summary['n_generations']}")
    means = summary["generation_summary_mean"]
    if means:
        print(
            "mean Spearman="
            f"{means['spearman_objective']:.4f}, "
            f"mean best_solver_surrogate_rank={means['best_solver_surrogate_rank']:.2f}"
        )
    for keep_frac, metrics in summary["screening_mean"].items():
        print(
            f"keep={keep_frac}: recall={metrics['solver_top10_recall']:.4f}, "
            f"best_missed={metrics['best_missed_ratio']:.4f}, "
            f"regret={metrics['screening_regret']:.6f}"
        )


def main() -> None:
    """Replay an existing PSO trace offline with the surrogate and report screening quality.

    Writes ``candidate_objectives.csv``, ``generation_summary.csv``, ``screening_summary.csv``,
    ``objective_scatter.png`` and ``summary.json`` into the output directory, then prints a
    short digest.
    """
    args = parse_args()
    trace_csv, trace_meta, config_path, processed_path, model_path, output_dir = resolve_paths(args)
    keep_fracs = [float(token) for token in args.keep_fracs.split(",") if token.strip()]

    cfg = Config(config_path)
    processed = joblib.load(processed_path)
    model, use_schedule, device = load_model(model_path)
    curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))
    meta = json.loads(trace_meta.read_text(encoding="utf-8"))
    rows = read_trace_csv(trace_csv)
    schedule = build_schedule_from_meta(meta)
    target_curve = build_target_curve_from_meta(cfg, processed, meta)
    output_dir.mkdir(parents=True, exist_ok=True)

    replay_rows = _replay_surrogate_objectives(
        rows,
        processed=processed,
        model=model,
        device=device,
        use_schedule=use_schedule,
        schedule=schedule,
        cfg=cfg,
        target_curve=target_curve,
        curve_layout=curve_layout,
    )
    generation_summaries, screening_rows = _summarize_generations(replay_rows, keep_fracs)

    write_csv(output_dir / "candidate_objectives.csv", replay_rows)
    write_csv(output_dir / "generation_summary.csv", generation_summaries)
    write_csv(output_dir / "screening_summary.csv", screening_rows)
    plot_objective_scatter(replay_rows, output_dir / "objective_scatter.png")

    summary = {
        "trace_csv": str(trace_csv),
        "trace_meta": str(trace_meta),
        "processed_path": str(processed_path),
        "model_path": str(model_path),
        "run_id": meta.get("run_id"),
        "n_particle_rows": len(replay_rows),
        "n_generations": len(generation_summaries),
        "target_points_raw": len(meta["target"]["target_loglog"]["time"]),
        "target_curve_dim_resampled": int(target_curve.size),
        "use_schedule": bool(use_schedule),
        "keep_fracs": keep_fracs,
        # Cross-generation mean ranking quality, for comparing surrogate model versions.
        "generation_summary_mean": _generation_means(generation_summaries),
        "screening_mean": _screening_means(screening_rows, keep_fracs),
    }
    (output_dir / "summary.json").write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
    _print_report(output_dir, summary)
if __name__ == "__main__":
    # Run the replay only when executed as a script, not when imported.
    main()