from __future__ import annotations import argparse import csv import json import math import sys from collections import defaultdict from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) import joblib import matplotlib.pyplot as plt import numpy as np from scripts.compare_single_case import build_schedule_vector, load_model from scripts.validate_autofit_local_ranking import infer_curve_layout, predict_surrogate_curve from src.common.config import Config from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features from src.data.params import Params, Schedule from src.evaluation.autofit_objective import dual_log_objective PARAM_COLUMNS = ["k", "skin", "wellboreC", "phi", "h", "Cf"] DEFAULT_KEEP_FRACS = [0.5, 0.6, 0.7] def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Replay a C++ full-solver PSO trace with the forward surrogate as a screening model." 
) parser.add_argument("--trace-csv", type=str, default=None, help="Path to pso_baseline_trace_*.csv") parser.add_argument("--trace-meta", type=str, default=None, help="Path to matching pso_baseline_trace_*.meta.json") parser.add_argument("--tag", type=str, default="family_random_mixed_50k_logparam") parser.add_argument("--stage", type=str, default="family_random") parser.add_argument("--processed", type=str, default=None) parser.add_argument("--model", type=str, default=None) parser.add_argument("--config", type=str, default=None) parser.add_argument("--output-dir", type=str, default=None) parser.add_argument("--keep-fracs", type=str, default="0.5,0.6,0.7") return parser.parse_args() def latest_trace_pair(temp_dir: Path) -> tuple[Path, Path]: csv_paths = sorted(temp_dir.glob("pso_baseline_trace_*.csv"), key=lambda p: p.stat().st_mtime, reverse=True) for csv_path in csv_paths: meta_path = csv_path.with_suffix(".meta.json") if meta_path.exists(): return csv_path, meta_path raise FileNotFoundError(f"No pso_baseline_trace_*.csv + .meta.json pair found in {temp_dir}") def resolve_paths(args: argparse.Namespace) -> tuple[Path, Path, Path, Path, Path, Path]: tag = normalize_tag(args.tag) if args.trace_csv is None: trace_csv, trace_meta = latest_trace_pair(ROOT / "data" / "temp") else: trace_csv = Path(args.trace_csv) trace_meta = Path(args.trace_meta) if args.trace_meta is not None else trace_csv.with_suffix(".meta.json") config_path = Path(args.config) if args.config is not None else config_for_stage(args.stage) if config_path is None: raise ValueError(f"Cannot resolve config for stage={args.stage!r}; pass --config explicitly") processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) model_path = Path(args.model) if args.model is not None else model_checkpoint_for_tag(tag, use_schedule=True) run_id = trace_csv.stem.replace("pso_baseline_trace_", "") output_dir = Path(args.output_dir) if args.output_dir is not None else 
Path("results") / f"pso_trace_replay_{run_id}" return ( trace_csv.resolve(), trace_meta.resolve(), config_path.resolve(), processed_path.resolve(), model_path.resolve(), output_dir.resolve(), ) def read_trace_csv(path: Path) -> list[dict]: with path.open("r", newline="", encoding="utf-8-sig") as f: return list(csv.DictReader(f)) def to_float(value: str | None, default: float = float("nan")) -> float: if value is None or value == "": return default try: return float(value) except ValueError: return default def is_valid_solver_row(row: dict) -> bool: obj = to_float(row.get("solver_objective")) return row.get("phase") == "particle_solver" and row.get("solver_success") == "1" and math.isfinite(obj) and obj < 1e9 def build_schedule_from_meta(meta: dict) -> Schedule: target = meta["target"] flow_points = target.get("flow_points") or [] time_q: list[float] = [] q: list[float] = [] for item in flow_points: dt = float(item["x"]) rate = float(item["y"]) if dt <= 0 and not time_q: continue time_q.append(dt) q.append(rate) return Schedule(sectionIndex=int(target["section_index"]), timeQ=time_q, q=q) def build_target_curve_from_meta(cfg: Config, processed: dict, meta: dict) -> np.ndarray: target = meta["target"]["target_loglog"] t = np.asarray(target["time"], dtype=np.float64) p = np.asarray(target["pressure"], dtype=np.float64) d = np.asarray(target["derivative"], dtype=np.float64) cleaned = clean_curve_for_dataset(cfg, t, p, d) if cleaned is None: raise RuntimeError("Target loglog curve from meta cannot be cleaned for replay") t_clean, p_clean, d_clean = cleaned valid, reason = is_valid_curve(cfg, t_clean, p_clean, d_clean) if not valid: raise RuntimeError(f"Target loglog curve from meta is invalid: {reason}") curve = resample_curve_to_features(cfg, t_clean, p_clean, d_clean) curve_dim = int(processed["meta"]["curve_dim"]) if curve.size != curve_dim: raise RuntimeError(f"Target curve dim mismatch after resample: {curve.size} != {curve_dim}") return curve.astype(np.float32) 
def params_from_row(row: dict) -> Params:
    """Build a ``Params`` object from the parameter columns of one trace row.

    Every column named in ``PARAM_COLUMNS`` is parsed with ``to_float``, so a
    missing or malformed cell becomes that helper's default value.
    """
    values = {name: to_float(row.get(name)) for name in PARAM_COLUMNS}
    return Params(
        k=values["k"],
        skin=values["skin"],
        wellboreC=values["wellboreC"],
        phi=values["phi"],
        h=values["h"],
        Cf=values["Cf"],
    )


def rank_positions(values: np.ndarray) -> np.ndarray:
    """Return the zero-based ordinal rank of every element (0 = smallest).

    Ties are broken by original position (stable sort), so the result is
    deterministic. NaN entries sort last.
    """
    order = np.argsort(values, kind="stable")
    ranks = np.empty_like(order)
    ranks[order] = np.arange(values.size)
    return ranks


def _average_ranks(values: np.ndarray) -> np.ndarray:
    """Return zero-based ranks where tied values share the mean of their positions."""
    flat = np.asarray(values, dtype=np.float64).reshape(-1)
    _, inverse, counts = np.unique(flat, return_inverse=True, return_counts=True)
    # cumsum(counts) is one past the last ordinal position of each tie group;
    # stepping back by (count + 1) / 2 lands on the centre of the group.
    group_rank = np.cumsum(counts) - (counts + 1) / 2.0
    return group_rank[np.asarray(inverse).reshape(-1)]


def corr_pearson(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Pearson correlation of ``a`` and ``b``.

    Returns NaN when there are fewer than two samples or when either input
    is (numerically) constant, where the coefficient is undefined.
    """
    if a.size < 2:
        return float("nan")
    if np.std(a) <= 1e-12 or np.std(b) <= 1e-12:
        return float("nan")
    return float(np.corrcoef(a, b)[0, 1])


def corr_spearman(a: np.ndarray, b: np.ndarray) -> float:
    """Return the Spearman rank correlation of ``a`` and ``b``.

    Tied values receive their average rank. A constant input therefore
    yields NaN (as ``corr_pearson`` does) rather than a correlation that
    depends on the arbitrary ordering of the tied elements.
    """
    if a.size < 2:
        return float("nan")
    return corr_pearson(_average_ranks(a), _average_ranks(b))


def _usable_solver_objective(row: dict) -> float | None:
    """Return the row's solver objective, or ``None`` if the solver result is unusable.

    A result is usable when ``solver_success`` is ``"1"`` and the objective
    parses to a finite number below ``1e9``. Missing keys and blank or
    malformed cells count as unusable instead of raising.
    """
    if row.get("solver_success") != "1":
        return None
    try:
        objective = float(row.get("solver_objective"))
    except (TypeError, ValueError):
        return None
    # 1e9 is presumably the solver's penalty sentinel -- confirm against the C++ side.
    if math.isfinite(objective) and objective < 1e9:
        return objective
    return None


def summarize_generation(rows: list[dict], keep_fracs: list[float]) -> tuple[dict, list[dict]]:
    """Compare solver and surrogate objectives for the particles of one generation.

    Args:
        rows: Trace rows of one generation. Each needs ``particle_id``,
            ``generation``, ``solver_success``, ``solver_objective`` and
            ``surrogate_objective``.
        keep_fracs: Fractions of the particles (ranked by surrogate
            objective) that a screening step would keep.

    Returns:
        ``(summary, screening_rows)``: ranking-agreement metrics for the
        generation, and one screening row per keep fraction. Returns
        ``({}, [])`` when fewer than two rows have a usable solver result.

    Side effects:
        Adds ``solver_rank``, ``surrogate_rank`` and ``rank_gap`` to every
        row that has a usable solver result.
    """
    valid: list[dict] = []
    objectives: list[float] = []
    for row in rows:
        objective = _usable_solver_objective(row)
        if objective is not None:
            valid.append(row)
            objectives.append(objective)
    if len(valid) < 2:
        return {}, []

    solver_obj = np.asarray(objectives, dtype=np.float64)
    surrogate_obj = np.asarray([float(r["surrogate_objective"]) for r in valid], dtype=np.float64)

    # Stable sorts keep ties in input order and push NaN surrogate values last.
    solver_order = np.argsort(solver_obj, kind="stable")
    surrogate_order = np.argsort(surrogate_obj, kind="stable")
    solver_rank = rank_positions(solver_obj)
    surrogate_rank = rank_positions(surrogate_obj)
    for i, row in enumerate(valid):
        row["solver_rank"] = int(solver_rank[i])
        row["surrogate_rank"] = int(surrogate_rank[i])
        row["rank_gap"] = int(surrogate_rank[i] - solver_rank[i])

    # Taking the head of the sorted order (not np.argmin) so that a NaN
    # surrogate objective is never reported as the best candidate.
    best_solver_idx = int(solver_order[0])
    best_surrogate_idx = int(surrogate_order[0])
    top_n = min(10, len(valid))
    top_solver = set(solver_order[:top_n].tolist())
    top_surrogate = set(surrogate_order[:top_n].tolist())

    valid_ids = [int(r["particle_id"]) for r in valid]
    valid_id_set = set(valid_ids)
    best_solver_id = valid_ids[best_solver_idx]
    top_solver_ids = {valid_ids[i] for i in top_solver}

    summary = {
        "generation": int(valid[0]["generation"]),
        "n_particles": len(rows),
        "n_valid_solver": len(valid),
        "pearson_objective": corr_pearson(solver_obj, surrogate_obj),
        "spearman_objective": corr_spearman(solver_obj, surrogate_obj),
        "top10_overlap": int(len(top_solver & top_surrogate)),
        "best_solver_particle_id": best_solver_id,
        "best_solver_surrogate_rank": int(surrogate_rank[best_solver_idx]),
        "best_surrogate_particle_id": valid_ids[best_surrogate_idx],
        "best_surrogate_solver_rank": int(solver_rank[best_surrogate_idx]),
        "solver_best_objective": float(solver_obj[best_solver_idx]),
        "solver_objective_at_surrogate_best": float(solver_obj[best_surrogate_idx]),
        "surrogate_top1_regret": float(solver_obj[best_surrogate_idx] - solver_obj[best_solver_idx]),
    }

    # Screening ranks ALL particles by surrogate objective, including those
    # whose solver run failed, because a screen cannot know that in advance.
    all_surrogate = np.asarray([float(r["surrogate_objective"]) for r in rows], dtype=np.float64)
    all_order = np.argsort(all_surrogate, kind="stable")

    screening_rows: list[dict] = []
    for keep_frac in keep_fracs:
        keep_n = max(1, int(math.ceil(len(rows) * float(keep_frac))))
        kept_positions = all_order[:keep_n]
        kept_ids = {int(rows[i]["particle_id"]) for i in kept_positions}
        kept_valid_indices = [i for i, pid in enumerate(valid_ids) if pid in kept_ids]

        if kept_valid_indices:
            kept_best = float(np.min(solver_obj[kept_valid_indices]))
            screening_regret = float(kept_best - solver_obj[best_solver_idx])
        else:
            # Nothing usable survived the screen.
            screening_regret = float("inf")

        screening_rows.append(
            {
                "generation": summary["generation"],
                "keep_frac": float(keep_frac),
                "keep_n": int(keep_n),
                "solver_top10_recall": float(len(top_solver_ids & kept_ids) / max(top_n, 1)),
                "best_missed": int(best_solver_id not in kept_ids),
                "screening_regret": screening_regret,
                "kept_valid_solver": int(len(kept_valid_indices)),
                "kept_failed_or_invalid": int(
                    sum(1 for i in kept_positions if int(rows[i]["particle_id"]) not in valid_id_set)
                ),
            }
        )
    return summary, screening_rows


def write_csv(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as CSV; do nothing when ``rows`` is empty.

    The header is the union of all row keys in first-seen order. Cells a
    row does not define are left empty.
    """
    if not rows:
        return
    # A dict gives ordered, de-duplicated keys with O(1) membership.
    columns: dict[str, None] = {}
    for row in rows:
        columns.update(dict.fromkeys(row))
    with path.open("w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=list(columns))
        writer.writeheader()
        writer.writerows(rows)


def plot_objective_scatter(rows: list[dict], path: Path) -> None:
    """Save a solver-vs-surrogate objective scatter plot, coloured by generation.

    Only rows with a usable solver result and a finite surrogate objective
    are drawn. Nothing is written when fewer than two such rows exist.
    """
    points: list[tuple[float, float, int]] = []
    for row in rows:
        solver_value = _usable_solver_objective(row)
        if solver_value is None:
            continue
        surrogate_value = float(row["surrogate_objective"])
        # A non-finite value cannot be drawn and would poison the diagonal limits.
        if not math.isfinite(surrogate_value):
            continue
        points.append((solver_value, surrogate_value, int(row["generation"])))
    if len(points) < 2:
        return

    solver = np.asarray([p[0] for p in points], dtype=np.float64)
    surrogate = np.asarray([p[1] for p in points], dtype=np.float64)
    gen = np.asarray([p[2] for p in points], dtype=np.int32)

    plt.figure(figsize=(6, 5), dpi=150)
    try:
        sc = plt.scatter(solver, surrogate, c=gen, cmap="viridis", alpha=0.8)
        lo = min(float(np.min(solver)), float(np.min(surrogate)))
        hi = max(float(np.max(solver)), float(np.max(surrogate)))
        # y = x reference line: points on it have equal objectives.
        plt.plot([lo, hi], [lo, hi], "k--", linewidth=1)
        plt.xlabel("C++ solver objective")
        plt.ylabel("Surrogate objective to target")
        plt.title("PSO trace replay objective ranking")
        plt.colorbar(sc, label="generation")
        plt.tight_layout()
        plt.savefig(path)
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close()


def _mean_ignoring_nan(values: list[float]) -> float:
    """Return the mean of the non-NaN entries, or NaN when there are none."""
    arr = np.asarray(values, dtype=np.float64)
    arr = arr[~np.isnan(arr)]
    return float(arr.mean()) if arr.size else float("nan")


def _json_safe(value):
    """Return a copy of ``value`` with non-finite floats replaced by ``None``.

    ``json.dumps`` would otherwise emit ``NaN`` / ``Infinity`` tokens, which
    are not valid JSON and are rejected by strict parsers.
    """
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, dict):
        return {key: _json_safe(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(item) for item in value]
    return value


def main() -> None:
    """Replay a PSO trace through the surrogate and write the comparison outputs.

    Writes ``candidate_objectives.csv``, ``generation_summary.csv``,
    ``screening_summary.csv``, ``objective_scatter.png`` and
    ``summary.json`` into the resolved output directory, then prints a short
    report. Non-finite numbers appear as ``null`` in ``summary.json``.
    """
    args = parse_args()
    trace_csv, trace_meta, config_path, processed_path, model_path, output_dir = resolve_paths(args)
    keep_fracs = [float(x) for x in args.keep_fracs.split(",") if x.strip()]

    cfg = Config(config_path)
    processed = joblib.load(processed_path)
    model, use_schedule, device = load_model(model_path)
    curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))

    meta = json.loads(trace_meta.read_text(encoding="utf-8"))
    rows = read_trace_csv(trace_csv)
    schedule = build_schedule_from_meta(meta)
    target_curve = build_target_curve_from_meta(cfg, processed, meta)

    output_dir.mkdir(parents=True, exist_ok=True)

    # Score every particle-solver row with the surrogate.
    replay_rows: list[dict] = []
    for row in rows:
        if row.get("phase") != "particle_solver":
            continue
        params = params_from_row(row)
        params.schedule = schedule
        pred_curve = predict_surrogate_curve(
            processed=processed,
            model=model,
            device=device,
            use_schedule=use_schedule,
            params=params,
            schedule=schedule,
            cfg=cfg,
        )
        sur_obj = dual_log_objective(target_curve, pred_curve, curve_layout)
        replay_row = dict(row)
        replay_row["surrogate_objective"] = sur_obj["dual_log_objective"]
        replay_row["surrogate_p_obj"] = sur_obj["log_pressure_objective"]
        replay_row["surrogate_d_obj"] = sur_obj["log_derivative_objective"]
        replay_rows.append(replay_row)

    by_generation: dict[int, list[dict]] = defaultdict(list)
    for row in replay_rows:
        by_generation[int(row["generation"])].append(row)

    generation_summaries: list[dict] = []
    screening_rows: list[dict] = []
    for generation in sorted(by_generation):
        # summarize_generation also annotates the rows with rank columns,
        # which is why candidate_objectives.csv is written after this loop.
        gen_summary, gen_screening_rows = summarize_generation(by_generation[generation], keep_fracs)
        if gen_summary:
            generation_summaries.append(gen_summary)
            screening_rows.extend(gen_screening_rows)

    write_csv(output_dir / "candidate_objectives.csv", replay_rows)
    write_csv(output_dir / "generation_summary.csv", generation_summaries)
    write_csv(output_dir / "screening_summary.csv", screening_rows)
    plot_objective_scatter(replay_rows, output_dir / "objective_scatter.png")

    summary = {
        "trace_csv": str(trace_csv),
        "trace_meta": str(trace_meta),
        "processed_path": str(processed_path),
        "model_path": str(model_path),
        "run_id": meta.get("run_id"),
        "n_particle_rows": len(replay_rows),
        "n_generations": len(generation_summaries),
        "target_points_raw": len(meta["target"]["target_loglog"]["time"]),
        "target_curve_dim_resampled": int(target_curve.size),
        "use_schedule": bool(use_schedule),
        "keep_fracs": keep_fracs,
        "generation_summary_mean": {},
        "screening_mean": {},
    }

    if generation_summaries:
        for key in [
            "pearson_objective",
            "spearman_objective",
            "top10_overlap",
            "best_solver_surrogate_rank",
            "best_surrogate_solver_rank",
            "surrogate_top1_regret",
        ]:
            # Correlations are NaN for degenerate generations; skip those.
            summary["generation_summary_mean"][key] = _mean_ignoring_nan(
                [float(r[key]) for r in generation_summaries]
            )
        for keep_frac in keep_fracs:
            rows_k = [r for r in screening_rows if abs(float(r["keep_frac"]) - keep_frac) < 1e-12]
            if not rows_k:
                continue
            summary["screening_mean"][str(keep_frac)] = {
                "solver_top10_recall": float(np.mean([float(r["solver_top10_recall"]) for r in rows_k])),
                "best_missed_ratio": float(np.mean([float(r["best_missed"]) for r in rows_k])),
                "screening_regret": float(np.mean([float(r["screening_regret"]) for r in rows_k])),
                "kept_failed_or_invalid": float(np.mean([float(r["kept_failed_or_invalid"]) for r in rows_k])),
            }

    # Only the serialized copy is sanitized: the prints below still need floats.
    (output_dir / "summary.json").write_text(
        json.dumps(_json_safe(summary), indent=2, ensure_ascii=False, allow_nan=False),
        encoding="utf-8",
    )

    print(f"Trace replay output: {output_dir}")
    print(f"particle_rows={summary['n_particle_rows']}, generations={summary['n_generations']}")
    if summary["generation_summary_mean"]:
        print(
            "mean Spearman="
            f"{summary['generation_summary_mean']['spearman_objective']:.4f}, "
            f"mean best_solver_surrogate_rank={summary['generation_summary_mean']['best_solver_surrogate_rank']:.2f}"
        )
    for keep_frac, metrics in summary["screening_mean"].items():
        print(
            f"keep={keep_frac}: recall={metrics['solver_top10_recall']:.4f}, "
            f"best_missed={metrics['best_missed_ratio']:.4f}, "
            f"regret={metrics['screening_regret']:.6f}"
        )


if __name__ == "__main__":
    main()