You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/replay_pso_trace_screening.py

391 lines
16 KiB
Python

from __future__ import annotations
import argparse
import csv
import json
import math
import sys
from collections import defaultdict
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import joblib
import matplotlib.pyplot as plt
import numpy as np
from scripts.compare_single_case import build_schedule_vector, load_model
from scripts.validate_autofit_local_ranking import infer_curve_layout, predict_surrogate_curve
from src.common.config import Config
from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag
from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features
from src.data.params import Params, Schedule
from src.evaluation.autofit_objective import dual_log_objective
# Trace CSV column names that params_from_row reads to build a Params object.
PARAM_COLUMNS = ["k", "skin", "wellboreC", "phi", "h", "Cf"]
# Same fractions as the --keep-fracs CLI default string in parse_args.
# NOTE(review): not referenced anywhere in the visible part of this file.
DEFAULT_KEEP_FRACS = [0.5, 0.6, 0.7]
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the PSO trace replay script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behaviour; passing a
            list allows the parser to be driven programmatically.

    Returns:
        Namespace with ``trace_csv``, ``trace_meta``, ``tag``, ``stage``,
        ``processed``, ``model``, ``config``, ``output_dir`` and
        ``keep_fracs``. ``keep_fracs`` is returned as the raw comma-separated
        string; the caller is responsible for splitting it.
    """
    parser = argparse.ArgumentParser(
        description="Replay a C++ full-solver PSO trace with the forward surrogate as a screening model."
    )
    parser.add_argument("--trace-csv", type=str, default=None, help="Path to pso_baseline_trace_*.csv")
    parser.add_argument(
        "--trace-meta",
        type=str,
        default=None,
        help="Path to matching pso_baseline_trace_*.meta.json",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default="family_random_mixed_50k_logparam",
        help="Experiment tag used to locate the processed data and model checkpoint "
        "when --processed / --model are not given",
    )
    parser.add_argument(
        "--stage",
        type=str,
        default="family_random",
        help="Stage name used to look up the config when --config is not given",
    )
    parser.add_argument("--processed", type=str, default=None, help="Explicit path to the processed data file")
    parser.add_argument("--model", type=str, default=None, help="Explicit path to the surrogate model checkpoint")
    parser.add_argument("--config", type=str, default=None, help="Explicit path to the config file")
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Directory for the replay outputs (default: results/pso_trace_replay_<run id>)",
    )
    parser.add_argument(
        "--keep-fracs",
        type=str,
        default="0.5,0.6,0.7",
        help="Comma-separated fractions of particles kept by surrogate screening",
    )
    return parser.parse_args(argv)
def latest_trace_pair(temp_dir: Path) -> tuple[Path, Path]:
    """Return the most recently modified trace CSV that has a sibling meta file.

    Candidates are ``pso_baseline_trace_*.csv`` files directly inside
    ``temp_dir``, visited from newest to oldest modification time. A CSV
    qualifies when the same path with its suffix replaced by ``.meta.json``
    exists.

    Raises:
        FileNotFoundError: If no CSV in ``temp_dir`` has a matching meta file.
    """
    candidates = list(temp_dir.glob("pso_baseline_trace_*.csv"))
    candidates.sort(key=lambda candidate: candidate.stat().st_mtime, reverse=True)

    pairs = ((candidate, candidate.with_suffix(".meta.json")) for candidate in candidates)
    newest_complete = next((pair for pair in pairs if pair[1].exists()), None)
    if newest_complete is None:
        raise FileNotFoundError(f"No pso_baseline_trace_*.csv + .meta.json pair found in {temp_dir}")
    return newest_complete
def resolve_paths(args: argparse.Namespace) -> tuple[Path, Path, Path, Path, Path, Path]:
    """Turn the parsed CLI options into absolute input and output paths.

    Explicit options win; otherwise the trace pair is the newest one under
    ``ROOT / "data" / "temp"``, the config comes from the stage, and the
    processed data and model checkpoint come from the normalized tag. Note
    that ``--trace-meta`` is only honoured when ``--trace-csv`` is also given.

    Returns:
        ``(trace_csv, trace_meta, config, processed, model, output_dir)``,
        each passed through ``Path.resolve()``.

    Raises:
        ValueError: If no config can be resolved for the requested stage.
    """
    tag = normalize_tag(args.tag)

    if args.trace_csv is not None:
        trace_csv = Path(args.trace_csv)
        if args.trace_meta is None:
            trace_meta = trace_csv.with_suffix(".meta.json")
        else:
            trace_meta = Path(args.trace_meta)
    else:
        trace_csv, trace_meta = latest_trace_pair(ROOT / "data" / "temp")

    if args.config is None:
        config_path = config_for_stage(args.stage)
    else:
        config_path = Path(args.config)
    if config_path is None:
        raise ValueError(f"Cannot resolve config for stage={args.stage!r}; pass --config explicitly")

    if args.processed is None:
        processed_path = processed_path_for_tag(tag)
    else:
        processed_path = Path(args.processed)

    if args.model is None:
        model_path = model_checkpoint_for_tag(tag, use_schedule=True)
    else:
        model_path = Path(args.model)

    # The run id is whatever remains of the CSV stem once the fixed prefix is dropped.
    run_id = trace_csv.stem.replace("pso_baseline_trace_", "")
    if args.output_dir is None:
        output_dir = Path("results") / f"pso_trace_replay_{run_id}"
    else:
        output_dir = Path(args.output_dir)

    ordered = (trace_csv, trace_meta, config_path, processed_path, model_path, output_dir)
    return tuple(item.resolve() for item in ordered)
def read_trace_csv(path: Path) -> list[dict]:
    """Read the whole trace CSV and return one dict per data row.

    The file is decoded as ``utf-8-sig`` so a leading byte-order mark does not
    end up in the first column name. All values are left as strings.
    """
    records: list[dict] = []
    with open(path, mode="r", newline="", encoding="utf-8-sig") as handle:
        records.extend(csv.DictReader(handle))
    return records
def to_float(value: str | None, default: float = float("nan")) -> float:
    """Convert a CSV cell to ``float``, falling back to ``default``.

    ``default`` is returned when the cell is ``None``, the empty string, or
    text that ``float()`` rejects with ``ValueError``.
    """
    if value is None or value == "":
        parsed = default
    else:
        try:
            parsed = float(value)
        except ValueError:
            parsed = default
    return parsed
def is_valid_solver_row(row: dict) -> bool:
    """Tell whether a trace row is a successful particle-solver evaluation.

    A row qualifies when its ``phase`` is ``"particle_solver"``, its
    ``solver_success`` is the string ``"1"``, and its ``solver_objective``
    parses to a finite number below ``1e9``.
    """
    if row.get("phase") != "particle_solver":
        return False
    if row.get("solver_success") != "1":
        return False
    objective = to_float(row.get("solver_objective"))
    # presumably 1e9 is the solver's failure sentinel; verify against the C++ side
    return math.isfinite(objective) and objective < 1e9
def build_schedule_from_meta(meta: dict) -> Schedule:
    """Build the ``Schedule`` described by ``meta["target"]``.

    Each entry of ``target["flow_points"]`` contributes its ``"x"`` value to
    ``timeQ`` and its ``"y"`` value to ``q``. Entries whose ``"x"`` is
    non-positive are dropped for as long as nothing has been kept yet; once
    the first point is accepted every later point is kept. A missing or empty
    ``flow_points`` yields empty lists.
    """
    target = meta["target"]
    kept_points: list[tuple[float, float]] = []
    for point in target.get("flow_points") or []:
        elapsed, flow_rate = float(point["x"]), float(point["y"])
        # Only leading points are filtered; later non-positive values pass through.
        if not kept_points and elapsed <= 0:
            continue
        kept_points.append((elapsed, flow_rate))

    time_q = [elapsed for elapsed, _ in kept_points]
    q = [flow_rate for _, flow_rate in kept_points]
    return Schedule(sectionIndex=int(target["section_index"]), timeQ=time_q, q=q)
def build_target_curve_from_meta(cfg: Config, processed: dict, meta: dict) -> np.ndarray:
    """Convert the target log-log curve stored in the meta file into a feature vector.

    The ``time``, ``pressure`` and ``derivative`` arrays under
    ``meta["target"]["target_loglog"]`` are cleaned, validated and resampled
    with the project's dataset helpers, then returned as ``float32``.

    Raises:
        RuntimeError: If cleaning returns ``None``, the cleaned curve is
            rejected as invalid, or the resampled size differs from
            ``processed["meta"]["curve_dim"]``.
    """
    loglog = meta["target"]["target_loglog"]
    raw_time, raw_pressure, raw_derivative = (
        np.asarray(loglog[field], dtype=np.float64) for field in ("time", "pressure", "derivative")
    )

    cleaned = clean_curve_for_dataset(cfg, raw_time, raw_pressure, raw_derivative)
    if cleaned is None:
        raise RuntimeError("Target loglog curve from meta cannot be cleaned for replay")
    t_clean, p_clean, d_clean = cleaned

    is_ok, reason = is_valid_curve(cfg, t_clean, p_clean, d_clean)
    if not is_ok:
        raise RuntimeError(f"Target loglog curve from meta is invalid: {reason}")

    features = resample_curve_to_features(cfg, t_clean, p_clean, d_clean)
    expected_dim = int(processed["meta"]["curve_dim"])
    if features.size != expected_dim:
        raise RuntimeError(f"Target curve dim mismatch after resample: {features.size} != {expected_dim}")
    return features.astype(np.float32)
def params_from_row(row: dict) -> Params:
    """Build a ``Params`` object from the parameter columns of one trace row.

    Every name in ``PARAM_COLUMNS`` is read from the row with ``to_float``,
    so a missing or unparsable cell becomes NaN rather than raising.
    """
    parsed: dict[str, float] = {}
    for column in PARAM_COLUMNS:
        parsed[column] = to_float(row.get(column))
    return Params(**parsed)
def rank_positions(values: np.ndarray) -> np.ndarray:
    """Return the 0-based ascending rank of every element of a 1-D array.

    The smallest value gets rank 0. Ties receive distinct consecutive ranks
    in whatever order ``np.argsort`` places them.
    """
    ascending_order = np.argsort(values)
    # Sorting a permutation yields its inverse: the sorted position of each element.
    return np.argsort(ascending_order)
def corr_pearson(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Pearson correlation of ``a`` and ``b``, or NaN when undefined.

    NaN is returned when ``a`` has fewer than two elements or when either
    input has a standard deviation of at most ``1e-12``.
    """
    # Evaluated left to right so the std calls are skipped for too-short input.
    undefined = a.size < 2 or np.std(a) <= 1e-12 or np.std(b) <= 1e-12
    if undefined:
        return float("nan")
    correlation_matrix = np.corrcoef(a, b)
    return float(correlation_matrix[0, 1])
def _average_ranks(values: np.ndarray) -> np.ndarray:
    """Return 0-based ascending ranks where tied values share their mean rank."""
    flat = np.asarray(values, dtype=np.float64).ravel()
    if flat.size == 0:
        return np.empty(0, dtype=np.float64)
    order = np.argsort(flat, kind="stable")
    ordered = flat[order]
    # A tie group starts wherever the sorted value differs from its predecessor.
    starts_group = np.empty(flat.size, dtype=bool)
    starts_group[0] = True
    starts_group[1:] = ordered[1:] != ordered[:-1]
    group_of_position = np.cumsum(starts_group) - 1
    group_start = np.flatnonzero(starts_group)
    group_end = np.append(group_start[1:], flat.size)
    # Mean of the consecutive positions start .. end-1 occupied by each group.
    group_mean_rank = (group_start + group_end - 1) / 2.0
    ranks = np.empty(flat.size, dtype=np.float64)
    ranks[order] = group_mean_rank[group_of_position]
    return ranks


def corr_spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Spearman rank correlation of ``a`` and ``b``, or NaN when undefined.

    The coefficient is the Pearson correlation of the ranks. Tied values are
    given the average of the ranks they span, so the result does not depend
    on the arbitrary order in which a sort happens to place equal values.
    For inputs without ties this equals the correlation of the plain ordinal
    ranks.

    Returns NaN when ``a`` has fewer than two elements, or when
    ``corr_pearson`` reports the rank vectors as degenerate.
    """
    if a.size < 2:
        return float("nan")
    return corr_pearson(_average_ranks(a), _average_ranks(b))
def _screening_keep_count(n_rows: int, keep_frac: float) -> int:
    """Return ``ceil(n_rows * keep_frac)`` clamped to ``[1, n_rows]``."""
    # The small tolerance keeps float round-off (e.g. 0.07 * 100 ==
    # 7.000000000000001) from pushing the ceiling one particle too high.
    raw = math.ceil(n_rows * float(keep_frac) - 1e-9)
    return max(1, min(n_rows, int(raw)))


def summarize_generation(rows: list[dict], keep_fracs: list[float]) -> tuple[dict, list[dict]]:
    """Compare solver and surrogate objectives for the particles of one generation.

    A row is *valid* when ``solver_success`` is ``"1"`` and ``solver_objective``
    parses to a finite value below ``1e9``; an unparsable objective makes the
    row invalid instead of raising. Valid rows are modified in place: the keys
    ``solver_rank``, ``surrogate_rank`` and ``rank_gap`` are added to them.

    Args:
        rows: Replayed rows of a single generation. Each needs the keys
            ``solver_success``, ``solver_objective``, ``surrogate_objective``,
            ``particle_id`` and ``generation``.
        keep_fracs: Fractions of ``rows`` (best surrogate objective first)
            that a screening step would keep.

    Returns:
        ``(summary, screening_rows)``: ``summary`` holds the correlation,
        top-10 overlap and regret metrics of the generation, and
        ``screening_rows`` holds one dict per entry of ``keep_fracs``.
        ``({}, [])`` is returned when fewer than two rows are valid.
    """
    valid: list[dict] = []
    valid_objectives: list[float] = []
    for row in rows:
        if row["solver_success"] != "1":
            continue
        try:
            objective = float(row["solver_objective"])
        except (TypeError, ValueError):
            objective = float("nan")
        if math.isfinite(objective) and objective < 1e9:
            valid.append(row)
            valid_objectives.append(objective)
    if len(valid) < 2:
        return {}, []

    solver_obj = np.asarray(valid_objectives, dtype=np.float64)
    surrogate_obj = np.asarray([float(r["surrogate_objective"]) for r in valid], dtype=np.float64)
    solver_rank = rank_positions(solver_obj)
    surrogate_rank = rank_positions(surrogate_obj)
    for row, s_rank, m_rank in zip(valid, solver_rank, surrogate_rank):
        row["solver_rank"] = int(s_rank)
        row["surrogate_rank"] = int(m_rank)
        row["rank_gap"] = int(m_rank - s_rank)

    best_solver_idx = int(np.argmin(solver_obj))
    best_surrogate_idx = int(np.argmin(surrogate_obj))
    top_n = min(10, len(valid))
    top_solver = set(np.argsort(solver_obj)[:top_n].tolist())
    top_surrogate = set(np.argsort(surrogate_obj)[:top_n].tolist())
    valid_ids = [int(r["particle_id"]) for r in valid]

    summary = {
        "generation": int(valid[0]["generation"]),
        "n_particles": len(rows),
        "n_valid_solver": len(valid),
        "pearson_objective": corr_pearson(solver_obj, surrogate_obj),
        "spearman_objective": corr_spearman(solver_obj, surrogate_obj),
        "top10_overlap": int(len(top_solver & top_surrogate)),
        "best_solver_particle_id": valid_ids[best_solver_idx],
        "best_solver_surrogate_rank": int(surrogate_rank[best_solver_idx]),
        "best_surrogate_particle_id": valid_ids[best_surrogate_idx],
        "best_surrogate_solver_rank": int(solver_rank[best_surrogate_idx]),
        "solver_best_objective": float(solver_obj[best_solver_idx]),
        "solver_objective_at_surrogate_best": float(solver_obj[best_surrogate_idx]),
        "surrogate_top1_regret": float(solver_obj[best_surrogate_idx] - solver_obj[best_solver_idx]),
    }

    # Everything below that does not depend on keep_frac is computed once.
    all_surrogate = np.asarray([float(r["surrogate_objective"]) for r in rows], dtype=np.float64)
    surrogate_order = np.argsort(all_surrogate)
    valid_id_set = set(valid_ids)
    best_solver_id = valid_ids[best_solver_idx]
    top_solver_ids = {valid_ids[i] for i in top_solver}

    screening_rows: list[dict] = []
    for keep_frac in keep_fracs:
        keep_n = _screening_keep_count(len(rows), keep_frac)
        kept_ids_in_order = [int(rows[i]["particle_id"]) for i in surrogate_order[:keep_n]]
        kept_particle_ids = set(kept_ids_in_order)
        kept_valid_indices = [i for i, pid in enumerate(valid_ids) if pid in kept_particle_ids]
        if kept_valid_indices:
            screening_regret = float(np.min(solver_obj[kept_valid_indices]) - solver_obj[best_solver_idx])
        else:
            # Nothing the screen kept has a usable solver objective.
            screening_regret = float("inf")
        screening_rows.append(
            {
                "generation": summary["generation"],
                "keep_frac": float(keep_frac),
                "keep_n": int(keep_n),
                "solver_top10_recall": float(len(top_solver_ids & kept_particle_ids) / max(top_n, 1)),
                "best_missed": int(best_solver_id not in kept_particle_ids),
                "screening_regret": screening_regret,
                "kept_valid_solver": int(len(kept_valid_indices)),
                "kept_failed_or_invalid": int(sum(1 for pid in kept_ids_in_order if pid not in valid_id_set)),
            }
        )
    return summary, screening_rows
def write_csv(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as a BOM-prefixed UTF-8 CSV file.

    The header is the union of all row keys in first-seen order; a row that
    lacks a column gets an empty cell. Nothing is written, and no file is
    created, when ``rows`` is empty.
    """
    if not rows:
        return
    # dict.fromkeys de-duplicates while keeping the first-seen key order.
    columns = list(dict.fromkeys(key for record in rows for key in record))
    with open(path, mode="w", newline="", encoding="utf-8-sig") as sink:
        table = csv.DictWriter(sink, fieldnames=columns)
        table.writeheader()
        table.writerows(rows)
def plot_objective_scatter(rows: list[dict], path: Path) -> None:
    """Save a scatter plot of surrogate objective against solver objective.

    Only rows with ``solver_success == "1"`` and a ``solver_objective`` below
    ``1e9`` are drawn. Points are coloured by generation and a dashed
    diagonal spans the common value range. No file is written when fewer
    than two rows qualify.
    """
    usable: list[dict] = []
    for record in rows:
        if record.get("solver_success") != "1":
            continue
        if float(record.get("solver_objective", "inf")) < 1e9:
            usable.append(record)
    if len(usable) < 2:
        return

    solver_vals = np.asarray([float(r["solver_objective"]) for r in usable], dtype=np.float64)
    surrogate_vals = np.asarray([float(r["surrogate_objective"]) for r in usable], dtype=np.float64)
    generations = np.asarray([int(r["generation"]) for r in usable], dtype=np.int32)

    fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
    points = ax.scatter(solver_vals, surrogate_vals, c=generations, cmap="viridis", alpha=0.8)
    # Diagonal reference line over the combined range of both objectives.
    lo = min(float(solver_vals.min()), float(surrogate_vals.min()))
    hi = max(float(solver_vals.max()), float(surrogate_vals.max()))
    ax.plot([lo, hi], [lo, hi], "k--", linewidth=1)
    ax.set_xlabel("C++ solver objective")
    ax.set_ylabel("Surrogate objective to target")
    ax.set_title("PSO trace replay objective ranking")
    fig.colorbar(points, ax=ax, label="generation")
    fig.tight_layout()
    fig.savefig(path)
    plt.close(fig)
def _parse_keep_fracs(spec: str) -> list[float]:
    """Split a comma-separated list of fractions, skipping blank entries."""
    return [float(token) for token in spec.split(",") if token.strip()]


def _replay_particle_rows(
    rows: list[dict],
    *,
    processed: dict,
    model,
    device,
    use_schedule,
    schedule: Schedule,
    cfg: Config,
    target_curve: np.ndarray,
    curve_layout,
) -> list[dict]:
    """Score every ``particle_solver`` row with the surrogate and return augmented copies."""
    replay_rows: list[dict] = []
    for row in rows:
        if row.get("phase") != "particle_solver":
            continue
        params = params_from_row(row)
        params.schedule = schedule
        pred_curve = predict_surrogate_curve(
            processed=processed,
            model=model,
            device=device,
            use_schedule=use_schedule,
            params=params,
            schedule=schedule,
            cfg=cfg,
        )
        sur_obj = dual_log_objective(target_curve, pred_curve, curve_layout)
        replay_row = dict(row)
        replay_row["surrogate_objective"] = sur_obj["dual_log_objective"]
        replay_row["surrogate_p_obj"] = sur_obj["log_pressure_objective"]
        replay_row["surrogate_d_obj"] = sur_obj["log_derivative_objective"]
        replay_rows.append(replay_row)
    return replay_rows


def _summarize_all_generations(replay_rows: list[dict], keep_fracs: list[float]) -> tuple[list[dict], list[dict]]:
    """Group replayed rows by generation and summarize each group in ascending order."""
    by_generation: dict[int, list[dict]] = defaultdict(list)
    for row in replay_rows:
        by_generation[int(row["generation"])].append(row)
    generation_summaries: list[dict] = []
    screening_rows: list[dict] = []
    for generation in sorted(by_generation):
        summary, gen_screening_rows = summarize_generation(by_generation[generation], keep_fracs)
        # Generations with fewer than two valid rows return an empty summary and are skipped.
        if summary:
            generation_summaries.append(summary)
            screening_rows.extend(gen_screening_rows)
    return generation_summaries, screening_rows


def _mean_generation_metrics(generation_summaries: list[dict]) -> dict:
    """Average the per-generation metrics, ignoring NaN entries."""
    means: dict = {}
    if not generation_summaries:
        return means
    for key in [
        "pearson_objective",
        "spearman_objective",
        "top10_overlap",
        "best_solver_surrogate_rank",
        "best_surrogate_solver_rank",
        "surrogate_top1_regret",
    ]:
        values = np.asarray([float(r[key]) for r in generation_summaries], dtype=np.float64)
        means[key] = float(np.nanmean(values))
    return means


def _mean_screening_metrics(screening_rows: list[dict], keep_fracs: list[float]) -> dict:
    """Average the screening metrics over generations, keyed by ``str(keep_frac)``."""
    means: dict = {}
    for keep_frac in keep_fracs:
        rows_k = [r for r in screening_rows if abs(float(r["keep_frac"]) - keep_frac) < 1e-12]
        if not rows_k:
            continue
        means[str(keep_frac)] = {
            "solver_top10_recall": float(np.mean([float(r["solver_top10_recall"]) for r in rows_k])),
            "best_missed_ratio": float(np.mean([float(r["best_missed"]) for r in rows_k])),
            "screening_regret": float(np.mean([float(r["screening_regret"]) for r in rows_k])),
            "kept_failed_or_invalid": float(np.mean([float(r["kept_failed_or_invalid"]) for r in rows_k])),
        }
    return means


def _print_report(output_dir: Path, summary: dict) -> None:
    """Print the headline numbers of the replay to stdout."""
    print(f"Trace replay output: {output_dir}")
    print(f"particle_rows={summary['n_particle_rows']}, generations={summary['n_generations']}")
    if summary["generation_summary_mean"]:
        print(
            "mean Spearman="
            f"{summary['generation_summary_mean']['spearman_objective']:.4f}, "
            f"mean best_solver_surrogate_rank={summary['generation_summary_mean']['best_solver_surrogate_rank']:.2f}"
        )
    for keep_frac, metrics in summary["screening_mean"].items():
        print(
            f"keep={keep_frac}: recall={metrics['solver_top10_recall']:.4f}, "
            f"best_missed={metrics['best_missed_ratio']:.4f}, "
            f"regret={metrics['screening_regret']:.6f}"
        )


def main() -> None:
    """Replay a PSO trace through the surrogate and write the comparison artefacts.

    Writes ``candidate_objectives.csv``, ``generation_summary.csv``,
    ``screening_summary.csv``, ``objective_scatter.png`` and ``summary.json``
    into the resolved output directory, then prints a short report.
    """
    args = parse_args()
    trace_csv, trace_meta, config_path, processed_path, model_path, output_dir = resolve_paths(args)
    keep_fracs = _parse_keep_fracs(args.keep_fracs)

    cfg = Config(config_path)
    # NOTE: joblib.load unpickles its input; only point --processed at trusted files.
    processed = joblib.load(processed_path)
    model, use_schedule, device = load_model(model_path)
    curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))

    # utf-8-sig accepts files with or without a BOM, matching how the trace CSV is read.
    meta = json.loads(trace_meta.read_text(encoding="utf-8-sig"))
    rows = read_trace_csv(trace_csv)
    schedule = build_schedule_from_meta(meta)
    target_curve = build_target_curve_from_meta(cfg, processed, meta)
    output_dir.mkdir(parents=True, exist_ok=True)

    replay_rows = _replay_particle_rows(
        rows,
        processed=processed,
        model=model,
        device=device,
        use_schedule=use_schedule,
        schedule=schedule,
        cfg=cfg,
        target_curve=target_curve,
        curve_layout=curve_layout,
    )
    # Summarizing adds rank columns to the valid replay rows, so it must run
    # before candidate_objectives.csv is written.
    generation_summaries, screening_rows = _summarize_all_generations(replay_rows, keep_fracs)

    write_csv(output_dir / "candidate_objectives.csv", replay_rows)
    write_csv(output_dir / "generation_summary.csv", generation_summaries)
    write_csv(output_dir / "screening_summary.csv", screening_rows)
    plot_objective_scatter(replay_rows, output_dir / "objective_scatter.png")

    summary = {
        "trace_csv": str(trace_csv),
        "trace_meta": str(trace_meta),
        "processed_path": str(processed_path),
        "model_path": str(model_path),
        "run_id": meta.get("run_id"),
        "n_particle_rows": len(replay_rows),
        "n_generations": len(generation_summaries),
        "target_points_raw": len(meta["target"]["target_loglog"]["time"]),
        "target_curve_dim_resampled": int(target_curve.size),
        "use_schedule": bool(use_schedule),
        "keep_fracs": keep_fracs,
        "generation_summary_mean": _mean_generation_metrics(generation_summaries),
        "screening_mean": _mean_screening_metrics(screening_rows, keep_fracs),
    }
    (output_dir / "summary.json").write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
    _print_report(output_dir, summary)
# Run the replay only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()