You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
391 lines
16 KiB
Python
391 lines
16 KiB
Python
|
3 weeks ago
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import argparse
|
||
|
|
import csv
|
||
|
|
import json
|
||
|
|
import math
|
||
|
|
import sys
|
||
|
|
from collections import defaultdict
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Directory two levels above this file (the parent of the script's folder).
# It is appended to sys.path so the project-local `scripts.` / `src.` imports resolve
# when the file is executed directly.
ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT))
|
||
|
|
|
||
|
|
import joblib
|
||
|
|
import matplotlib.pyplot as plt
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
from scripts.compare_single_case import build_schedule_vector, load_model
|
||
|
|
from scripts.validate_autofit_local_ranking import infer_curve_layout, predict_surrogate_curve
|
||
|
|
from src.common.config import Config
|
||
|
|
from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag
|
||
|
|
from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features
|
||
|
|
from src.data.params import Params, Schedule
|
||
|
|
from src.evaluation.autofit_objective import dual_log_objective
|
||
|
|
|
||
|
|
|
||
|
|
# Trace-CSV column names read by params_from_row to build a Params instance.
PARAM_COLUMNS = "k skin wellboreC phi h Cf".split()
# Screening keep fractions; same values as the --keep-fracs command-line default.
DEFAULT_KEEP_FRACS = [0.5, 0.6, 0.7]
|
||
|
|
|
||
|
|
|
||
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Every path option defaults to ``None``; ``resolve_paths`` derives the
    missing ones from ``--tag`` / ``--stage``. ``--keep-fracs`` is returned as
    the raw comma-separated string.
    """
    cli = argparse.ArgumentParser(
        description="Replay a C++ full-solver PSO trace with the forward surrogate as a screening model."
    )
    cli.add_argument("--trace-csv", type=str, default=None, help="Path to pso_baseline_trace_*.csv")
    cli.add_argument("--trace-meta", type=str, default=None, help="Path to matching pso_baseline_trace_*.meta.json")
    cli.add_argument("--tag", type=str, default="family_random_mixed_50k_logparam")
    cli.add_argument("--stage", type=str, default="family_random")
    # Optional overrides for derived paths, registered in their original order.
    for optional_path in ("--processed", "--model", "--config", "--output-dir"):
        cli.add_argument(optional_path, type=str, default=None)
    cli.add_argument("--keep-fracs", type=str, default="0.5,0.6,0.7")
    return cli.parse_args()
|
||
|
|
|
||
|
|
|
||
|
|
def latest_trace_pair(temp_dir: Path) -> tuple[Path, Path]:
    """Return the newest ``pso_baseline_trace_*.csv`` in ``temp_dir`` that has a ``.meta.json`` sibling.

    Candidates are ordered by modification time, newest first. The first one
    whose ``<stem>.meta.json`` exists is returned as ``(csv_path, meta_path)``.

    Raises:
        FileNotFoundError: if no candidate CSV has a matching meta file.
    """
    newest_first = sorted(
        temp_dir.glob("pso_baseline_trace_*.csv"),
        key=lambda candidate: candidate.stat().st_mtime,
        reverse=True,
    )
    pairs = ((candidate, candidate.with_suffix(".meta.json")) for candidate in newest_first)
    match = next((pair for pair in pairs if pair[1].exists()), None)
    if match is None:
        raise FileNotFoundError(f"No pso_baseline_trace_*.csv + .meta.json pair found in {temp_dir}")
    return match
|
||
|
|
|
||
|
|
|
||
|
|
def resolve_paths(args: argparse.Namespace) -> tuple[Path, Path, Path, Path, Path, Path]:
    """Turn parsed CLI arguments into absolute paths.

    Returns ``(trace_csv, trace_meta, config, processed, model, output_dir)``,
    each passed through ``Path.resolve()``. Defaults for unset options:

    * trace pair: newest pair from ``latest_trace_pair(ROOT / "data" / "temp")``;
      with an explicit ``--trace-csv`` but no ``--trace-meta``, the CSV's
      ``.meta.json`` sibling;
    * config: ``config_for_stage(args.stage)``;
    * processed data / model: derived from the normalized ``--tag``;
    * output dir: ``results/pso_trace_replay_<run_id>`` relative to the current
      working directory, where ``run_id`` is the CSV stem minus its prefix.

    Raises:
        ValueError: if no config can be derived for ``args.stage``.
    """
    tag = normalize_tag(args.tag)

    if args.trace_csv is not None:
        trace_csv = Path(args.trace_csv)
        trace_meta = trace_csv.with_suffix(".meta.json") if args.trace_meta is None else Path(args.trace_meta)
    else:
        # --trace-meta is only honoured together with an explicit --trace-csv.
        trace_csv, trace_meta = latest_trace_pair(ROOT / "data" / "temp")

    config_path = config_for_stage(args.stage) if args.config is None else Path(args.config)
    if config_path is None:
        raise ValueError(f"Cannot resolve config for stage={args.stage!r}; pass --config explicitly")

    processed_path = processed_path_for_tag(tag) if args.processed is None else Path(args.processed)
    model_path = model_checkpoint_for_tag(tag, use_schedule=True) if args.model is None else Path(args.model)

    run_id = trace_csv.stem.replace("pso_baseline_trace_", "")
    if args.output_dir is None:
        output_dir = Path("results") / f"pso_trace_replay_{run_id}"
    else:
        output_dir = Path(args.output_dir)

    return tuple(
        path.resolve()
        for path in (trace_csv, trace_meta, config_path, processed_path, model_path, output_dir)
    )
|
||
|
|
|
||
|
|
|
||
|
|
def read_trace_csv(path: Path) -> list[dict]:
    """Load every data row of the CSV at ``path`` as a dict keyed by header name.

    The ``utf-8-sig`` codec drops a leading byte-order mark, so the first
    column name is not polluted by it. All cell values are strings.
    """
    with path.open("r", newline="", encoding="utf-8-sig") as handle:
        reader = csv.DictReader(handle)
        records = list(reader)
    return records
|
||
|
|
|
||
|
|
|
||
|
|
def to_float(value: str | None, default: float = float("nan")) -> float:
    """Parse ``value`` as a float, or return ``default``.

    ``default`` (NaN unless given) is returned for ``None``, the empty string,
    or any text ``float()`` rejects with ``ValueError``.
    """
    if value is not None and value != "":
        try:
            return float(value)
        except ValueError:
            pass
    return default
|
||
|
|
|
||
|
|
|
||
|
|
def is_valid_solver_row(row: dict) -> bool:
    """Tell whether a trace row is a successful ``particle_solver`` evaluation.

    The row must have ``phase == "particle_solver"`` and ``solver_success == "1"``,
    and its ``solver_objective`` must parse to a finite value below 1e9.
    The 1e9 cut presumably screens out a penalty sentinel; confirm against the C++ solver.
    """
    objective = to_float(row.get("solver_objective"))
    if row.get("phase") != "particle_solver":
        return False
    if row.get("solver_success") != "1":
        return False
    return math.isfinite(objective) and objective < 1e9
|
||
|
|
|
||
|
|
|
||
|
|
def build_schedule_from_meta(meta: dict) -> Schedule:
    """Build the rate ``Schedule`` recorded under ``meta["target"]``.

    Each item of ``flow_points`` is a mapping with ``"x"`` and ``"y"``; ``x``
    goes into ``timeQ`` and ``y`` into ``q``. Leading items whose ``x`` is
    ``<= 0`` are dropped; once one item has been kept, all later items are kept.
    A missing or empty ``flow_points`` gives empty lists.
    """
    target = meta["target"]
    times: list[float] = []
    rates: list[float] = []
    for point in target.get("flow_points") or []:
        time_value, rate_value = float(point["x"]), float(point["y"])
        if not times and time_value <= 0:
            continue
        times.append(time_value)
        rates.append(rate_value)
    return Schedule(sectionIndex=int(target["section_index"]), timeQ=times, q=rates)
|
||
|
|
|
||
|
|
|
||
|
|
def build_target_curve_from_meta(cfg: Config, processed: dict, meta: dict) -> np.ndarray:
    """Convert the target log-log curve stored in the trace meta into a feature vector.

    Reads ``meta["target"]["target_loglog"]["time" | "pressure" | "derivative"]``
    and passes them through the project's curve helpers (clean, validate,
    resample). The resampled size must equal ``processed["meta"]["curve_dim"]``.

    Returns:
        The resampled curve cast to ``float32``.

    Raises:
        RuntimeError: if cleaning returns ``None``, validation rejects the
            curve, or the resampled size does not match ``curve_dim``.
    """
    loglog = meta["target"]["target_loglog"]
    raw_arrays = [np.asarray(loglog[key], dtype=np.float64) for key in ("time", "pressure", "derivative")]

    cleaned = clean_curve_for_dataset(cfg, *raw_arrays)
    if cleaned is None:
        raise RuntimeError("Target loglog curve from meta cannot be cleaned for replay")
    t_clean, p_clean, d_clean = cleaned

    ok, why = is_valid_curve(cfg, t_clean, p_clean, d_clean)
    if not ok:
        raise RuntimeError(f"Target loglog curve from meta is invalid: {why}")

    features = resample_curve_to_features(cfg, t_clean, p_clean, d_clean)
    expected_dim = int(processed["meta"]["curve_dim"])
    if features.size != expected_dim:
        raise RuntimeError(f"Target curve dim mismatch after resample: {features.size} != {expected_dim}")
    return features.astype(np.float32)
|
||
|
|
|
||
|
|
|
||
|
|
def params_from_row(row: dict) -> Params:
    """Build ``Params`` from the ``PARAM_COLUMNS`` cells of one trace row.

    Missing, blank or unparsable cells become NaN (``to_float``'s default).
    """
    parsed: dict[str, float] = {}
    for column in PARAM_COLUMNS:
        parsed[column] = to_float(row.get(column))
    return Params(
        k=parsed["k"],
        skin=parsed["skin"],
        wellboreC=parsed["wellboreC"],
        phi=parsed["phi"],
        h=parsed["h"],
        Cf=parsed["Cf"],
    )
|
||
|
|
|
||
|
|
|
||
|
|
def rank_positions(values: np.ndarray) -> np.ndarray:
    """Return each element's 0-based ordinal rank in a 1-D array (0 = smallest).

    Tied values still receive distinct ranks, in the order ``np.argsort``
    happens to place them. The result has ``argsort``'s integer dtype.
    """
    # The argsort of a permutation is its inverse permutation, i.e. the rank table.
    return np.argsort(np.argsort(values))
|
||
|
|
|
||
|
|
|
||
|
|
def corr_pearson(a: np.ndarray, b: np.ndarray) -> float:
|
||
|
|
if a.size < 2:
|
||
|
|
return float("nan")
|
||
|
|
if np.std(a) <= 1e-12 or np.std(b) <= 1e-12:
|
||
|
|
return float("nan")
|
||
|
|
return float(np.corrcoef(a, b)[0, 1])
|
||
|
|
|
||
|
|
|
||
|
|
def corr_spearman(a: np.ndarray, b: np.ndarray) -> float:
|
||
|
|
if a.size < 2:
|
||
|
|
return float("nan")
|
||
|
|
return corr_pearson(rank_positions(a).astype(np.float64), rank_positions(b).astype(np.float64))
|
||
|
|
|
||
|
|
|
||
|
|
def summarize_generation(rows: list[dict], keep_fracs: list[float]) -> tuple[dict, list[dict]]:
|
||
|
|
valid = [r for r in rows if r["solver_success"] == "1" and math.isfinite(float(r["solver_objective"])) and float(r["solver_objective"]) < 1e9]
|
||
|
|
if len(valid) < 2:
|
||
|
|
return {}, []
|
||
|
|
|
||
|
|
solver_obj = np.asarray([float(r["solver_objective"]) for r in valid], dtype=np.float64)
|
||
|
|
surrogate_obj = np.asarray([float(r["surrogate_objective"]) for r in valid], dtype=np.float64)
|
||
|
|
solver_rank = rank_positions(solver_obj)
|
||
|
|
surrogate_rank = rank_positions(surrogate_obj)
|
||
|
|
|
||
|
|
for i, row in enumerate(valid):
|
||
|
|
row["solver_rank"] = int(solver_rank[i])
|
||
|
|
row["surrogate_rank"] = int(surrogate_rank[i])
|
||
|
|
row["rank_gap"] = int(surrogate_rank[i] - solver_rank[i])
|
||
|
|
|
||
|
|
best_solver_idx = int(np.argmin(solver_obj))
|
||
|
|
best_surrogate_idx = int(np.argmin(surrogate_obj))
|
||
|
|
top_n = min(10, len(valid))
|
||
|
|
top_solver = set(np.argsort(solver_obj)[:top_n].tolist())
|
||
|
|
top_surrogate = set(np.argsort(surrogate_obj)[:top_n].tolist())
|
||
|
|
|
||
|
|
summary = {
|
||
|
|
"generation": int(valid[0]["generation"]),
|
||
|
|
"n_particles": len(rows),
|
||
|
|
"n_valid_solver": len(valid),
|
||
|
|
"pearson_objective": corr_pearson(solver_obj, surrogate_obj),
|
||
|
|
"spearman_objective": corr_spearman(solver_obj, surrogate_obj),
|
||
|
|
"top10_overlap": int(len(top_solver & top_surrogate)),
|
||
|
|
"best_solver_particle_id": int(valid[best_solver_idx]["particle_id"]),
|
||
|
|
"best_solver_surrogate_rank": int(surrogate_rank[best_solver_idx]),
|
||
|
|
"best_surrogate_particle_id": int(valid[best_surrogate_idx]["particle_id"]),
|
||
|
|
"best_surrogate_solver_rank": int(solver_rank[best_surrogate_idx]),
|
||
|
|
"solver_best_objective": float(solver_obj[best_solver_idx]),
|
||
|
|
"solver_objective_at_surrogate_best": float(solver_obj[best_surrogate_idx]),
|
||
|
|
"surrogate_top1_regret": float(solver_obj[best_surrogate_idx] - solver_obj[best_solver_idx]),
|
||
|
|
}
|
||
|
|
|
||
|
|
screening_rows: list[dict] = []
|
||
|
|
all_surrogate = np.asarray([float(r["surrogate_objective"]) for r in rows], dtype=np.float64)
|
||
|
|
valid_by_particle = {int(r["particle_id"]): r for r in valid}
|
||
|
|
for keep_frac in keep_fracs:
|
||
|
|
keep_n = max(1, int(math.ceil(len(rows) * float(keep_frac))))
|
||
|
|
kept_all_positions = np.argsort(all_surrogate)[:keep_n]
|
||
|
|
kept_particle_ids = {int(rows[i]["particle_id"]) for i in kept_all_positions}
|
||
|
|
kept_valid_indices = [i for i, r in enumerate(valid) if int(r["particle_id"]) in kept_particle_ids]
|
||
|
|
kept_solver_obj = solver_obj[kept_valid_indices] if kept_valid_indices else np.asarray([], dtype=np.float64)
|
||
|
|
|
||
|
|
best_missed = int(int(valid[best_solver_idx]["particle_id"]) not in kept_particle_ids)
|
||
|
|
top_recall = float(len({int(valid[i]["particle_id"]) for i in top_solver} & kept_particle_ids) / max(top_n, 1))
|
||
|
|
if kept_solver_obj.size:
|
||
|
|
screening_regret = float(np.min(kept_solver_obj) - solver_obj[best_solver_idx])
|
||
|
|
else:
|
||
|
|
screening_regret = float("inf")
|
||
|
|
|
||
|
|
screening_rows.append(
|
||
|
|
{
|
||
|
|
"generation": summary["generation"],
|
||
|
|
"keep_frac": float(keep_frac),
|
||
|
|
"keep_n": int(keep_n),
|
||
|
|
"solver_top10_recall": top_recall,
|
||
|
|
"best_missed": best_missed,
|
||
|
|
"screening_regret": screening_regret,
|
||
|
|
"kept_valid_solver": int(len(kept_valid_indices)),
|
||
|
|
"kept_failed_or_invalid": int(sum(1 for i in kept_all_positions if int(rows[i]["particle_id"]) not in valid_by_particle)),
|
||
|
|
}
|
||
|
|
)
|
||
|
|
|
||
|
|
return summary, screening_rows
|
||
|
|
|
||
|
|
|
||
|
|
def write_csv(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 CSV with a byte-order mark.

    Nothing is written (no file is created) when ``rows`` is empty. The header
    is the union of all rows' keys in first-seen order, and cells a row lacks
    are left blank.
    """
    if rows:
        # dict.fromkeys preserves insertion order, giving an ordered de-duplicated key union.
        header = list(dict.fromkeys(key for row in rows for key in row))
        with path.open("w", newline="", encoding="utf-8-sig") as out:
            writer = csv.DictWriter(out, fieldnames=header)
            writer.writeheader()
            writer.writerows(rows)
|
||
|
|
|
||
|
|
|
||
|
|
def plot_objective_scatter(rows: list[dict], path: Path) -> None:
|
||
|
|
valid = [r for r in rows if r.get("solver_success") == "1" and float(r.get("solver_objective", "inf")) < 1e9]
|
||
|
|
if len(valid) < 2:
|
||
|
|
return
|
||
|
|
solver = np.asarray([float(r["solver_objective"]) for r in valid], dtype=np.float64)
|
||
|
|
surrogate = np.asarray([float(r["surrogate_objective"]) for r in valid], dtype=np.float64)
|
||
|
|
gen = np.asarray([int(r["generation"]) for r in valid], dtype=np.int32)
|
||
|
|
|
||
|
|
plt.figure(figsize=(6, 5), dpi=150)
|
||
|
|
sc = plt.scatter(solver, surrogate, c=gen, cmap="viridis", alpha=0.8)
|
||
|
|
lo = min(float(np.min(solver)), float(np.min(surrogate)))
|
||
|
|
hi = max(float(np.max(solver)), float(np.max(surrogate)))
|
||
|
|
plt.plot([lo, hi], [lo, hi], "k--", linewidth=1)
|
||
|
|
plt.xlabel("C++ solver objective")
|
||
|
|
plt.ylabel("Surrogate objective to target")
|
||
|
|
plt.title("PSO trace replay objective ranking")
|
||
|
|
plt.colorbar(sc, label="generation")
|
||
|
|
plt.tight_layout()
|
||
|
|
plt.savefig(path)
|
||
|
|
plt.close()
|
||
|
|
|
||
|
|
|
||
|
|
def main() -> None:
|
||
|
|
args = parse_args()
|
||
|
|
trace_csv, trace_meta, config_path, processed_path, model_path, output_dir = resolve_paths(args)
|
||
|
|
keep_fracs = [float(x) for x in args.keep_fracs.split(",") if x.strip()]
|
||
|
|
|
||
|
|
cfg = Config(config_path)
|
||
|
|
processed = joblib.load(processed_path)
|
||
|
|
model, use_schedule, device = load_model(model_path)
|
||
|
|
curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))
|
||
|
|
|
||
|
|
meta = json.loads(trace_meta.read_text(encoding="utf-8"))
|
||
|
|
rows = read_trace_csv(trace_csv)
|
||
|
|
schedule = build_schedule_from_meta(meta)
|
||
|
|
target_curve = build_target_curve_from_meta(cfg, processed, meta)
|
||
|
|
|
||
|
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||
|
|
|
||
|
|
replay_rows: list[dict] = []
|
||
|
|
for row in rows:
|
||
|
|
if row.get("phase") != "particle_solver":
|
||
|
|
continue
|
||
|
|
params = params_from_row(row)
|
||
|
|
params.schedule = schedule
|
||
|
|
pred_curve = predict_surrogate_curve(
|
||
|
|
processed=processed,
|
||
|
|
model=model,
|
||
|
|
device=device,
|
||
|
|
use_schedule=use_schedule,
|
||
|
|
params=params,
|
||
|
|
schedule=schedule,
|
||
|
|
cfg=cfg,
|
||
|
|
)
|
||
|
|
sur_obj = dual_log_objective(target_curve, pred_curve, curve_layout)
|
||
|
|
replay_row = dict(row)
|
||
|
|
replay_row["surrogate_objective"] = sur_obj["dual_log_objective"]
|
||
|
|
replay_row["surrogate_p_obj"] = sur_obj["log_pressure_objective"]
|
||
|
|
replay_row["surrogate_d_obj"] = sur_obj["log_derivative_objective"]
|
||
|
|
replay_rows.append(replay_row)
|
||
|
|
|
||
|
|
by_generation: dict[int, list[dict]] = defaultdict(list)
|
||
|
|
for row in replay_rows:
|
||
|
|
by_generation[int(row["generation"])].append(row)
|
||
|
|
|
||
|
|
generation_summaries: list[dict] = []
|
||
|
|
screening_rows: list[dict] = []
|
||
|
|
for generation in sorted(by_generation):
|
||
|
|
summary, gen_screening_rows = summarize_generation(by_generation[generation], keep_fracs)
|
||
|
|
if summary:
|
||
|
|
generation_summaries.append(summary)
|
||
|
|
screening_rows.extend(gen_screening_rows)
|
||
|
|
|
||
|
|
write_csv(output_dir / "candidate_objectives.csv", replay_rows)
|
||
|
|
write_csv(output_dir / "generation_summary.csv", generation_summaries)
|
||
|
|
write_csv(output_dir / "screening_summary.csv", screening_rows)
|
||
|
|
plot_objective_scatter(replay_rows, output_dir / "objective_scatter.png")
|
||
|
|
|
||
|
|
summary = {
|
||
|
|
"trace_csv": str(trace_csv),
|
||
|
|
"trace_meta": str(trace_meta),
|
||
|
|
"processed_path": str(processed_path),
|
||
|
|
"model_path": str(model_path),
|
||
|
|
"run_id": meta.get("run_id"),
|
||
|
|
"n_particle_rows": len(replay_rows),
|
||
|
|
"n_generations": len(generation_summaries),
|
||
|
|
"target_points_raw": len(meta["target"]["target_loglog"]["time"]),
|
||
|
|
"target_curve_dim_resampled": int(target_curve.size),
|
||
|
|
"use_schedule": bool(use_schedule),
|
||
|
|
"keep_fracs": keep_fracs,
|
||
|
|
"generation_summary_mean": {},
|
||
|
|
"screening_mean": {},
|
||
|
|
}
|
||
|
|
if generation_summaries:
|
||
|
|
for key in [
|
||
|
|
"pearson_objective",
|
||
|
|
"spearman_objective",
|
||
|
|
"top10_overlap",
|
||
|
|
"best_solver_surrogate_rank",
|
||
|
|
"best_surrogate_solver_rank",
|
||
|
|
"surrogate_top1_regret",
|
||
|
|
]:
|
||
|
|
values = np.asarray([float(r[key]) for r in generation_summaries], dtype=np.float64)
|
||
|
|
summary["generation_summary_mean"][key] = float(np.nanmean(values))
|
||
|
|
|
||
|
|
for keep_frac in keep_fracs:
|
||
|
|
rows_k = [r for r in screening_rows if abs(float(r["keep_frac"]) - keep_frac) < 1e-12]
|
||
|
|
if not rows_k:
|
||
|
|
continue
|
||
|
|
summary["screening_mean"][str(keep_frac)] = {
|
||
|
|
"solver_top10_recall": float(np.mean([float(r["solver_top10_recall"]) for r in rows_k])),
|
||
|
|
"best_missed_ratio": float(np.mean([float(r["best_missed"]) for r in rows_k])),
|
||
|
|
"screening_regret": float(np.mean([float(r["screening_regret"]) for r in rows_k])),
|
||
|
|
"kept_failed_or_invalid": float(np.mean([float(r["kept_failed_or_invalid"]) for r in rows_k])),
|
||
|
|
}
|
||
|
|
|
||
|
|
(output_dir / "summary.json").write_text(json.dumps(summary, indent=2, ensure_ascii=False), encoding="utf-8")
|
||
|
|
|
||
|
|
print(f"Trace replay output: {output_dir}")
|
||
|
|
print(f"particle_rows={summary['n_particle_rows']}, generations={summary['n_generations']}")
|
||
|
|
if summary["generation_summary_mean"]:
|
||
|
|
print(
|
||
|
|
"mean Spearman="
|
||
|
|
f"{summary['generation_summary_mean']['spearman_objective']:.4f}, "
|
||
|
|
f"mean best_solver_surrogate_rank={summary['generation_summary_mean']['best_solver_surrogate_rank']:.2f}"
|
||
|
|
)
|
||
|
|
for keep_frac, metrics in summary["screening_mean"].items():
|
||
|
|
print(
|
||
|
|
f"keep={keep_frac}: recall={metrics['solver_top10_recall']:.4f}, "
|
||
|
|
f"best_missed={metrics['best_missed_ratio']:.4f}, "
|
||
|
|
f"regret={metrics['screening_regret']:.6f}"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
main()
|