You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/validate_autofit_local_rank...

394 lines
17 KiB
Python

from __future__ import annotations
import argparse
import csv
import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import joblib
import numpy as np
from scripts.generate_autofit_neighborhood_dataset import _resolve_section_indices, sample_schedule_by_mode
from scripts.validate_autofit_local_ranking import (
_corr_pearson,
_corr_spearman,
_rank_positions,
_sample_local_candidates,
infer_curve_layout,
load_model,
predict_surrogate_curve,
run_solver_and_extract_curve,
)
from src.common.config import Config
from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag
from src.data.params import Params, Schedule, generate_params_dataset
from src.data.runner_client import CppRunner
from src.evaluation.autofit_objective import dual_log_objective
def parse_span_fracs(text: str) -> list[float]:
    """Parse a comma-separated string of span fractions.

    Empty items are ignored, every remaining item is converted with
    ``float``, non-positive values are dropped, and the survivors are rounded
    to 10 decimals, de-duplicated and returned in ascending order.

    Args:
        text: Comma-separated numbers, e.g. ``"0.02,0.05,0.10"``.

    Returns:
        Sorted list of unique, strictly positive span fractions.

    Raises:
        ValueError: If an item is not a valid float, or if no strictly
            positive value remains after filtering.
    """
    tokens = (piece.strip() for piece in str(text).split(","))
    parsed = [float(token) for token in tokens if token]
    # Rounding before building the set merges values that differ only by
    # floating-point noise.
    unique_positive = {round(value, 10) for value in parsed if value > 0.0}
    if not unique_positive:
        raise ValueError("At least one positive span fraction is required")
    return [float(value) for value in sorted(unique_positive)]
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse the process arguments.

    Returns:
        The parsed namespace with attributes ``config``, ``stage``,
        ``processed``, ``model``, ``tag``, ``output_dir``, ``n_targets``,
        ``n_candidates``, ``span_fracs``, ``seed``, ``well_index``,
        ``solver_timeout`` and ``target_max_attempts_factor``.
    """
    stage_choices = ["fixed_case", "case_neighborhood", "family_random", "family_random_hard"]
    # One (flag, add_argument keyword arguments) entry per option, in the
    # order they should appear in ``--help``.
    option_table = [
        ("--config", {"type": str, "default": None}),
        ("--stage", {"choices": stage_choices, "default": "family_random"}),
        ("--processed", {"type": str, "default": None}),
        ("--model", {"type": str, "default": None}),
        ("--tag", {"type": str, "default": "family_random_mixed_50k_biasfix_logparam"}),
        ("--output-dir", {"type": str, "default": None}),
        ("--n-targets", {"type": int, "default": 20}),
        ("--n-candidates", {"type": int, "default": 48}),
        ("--span-fracs", {"type": str, "default": "0.02,0.05,0.10"}),
        ("--seed", {"type": int, "default": 42}),
        ("--well-index", {"type": int, "default": 0}),
        ("--solver-timeout", {"type": int, "default": 120}),
        ("--target-max-attempts-factor", {"type": int, "default": 5}),
    ]

    cli = argparse.ArgumentParser(
        description="Run local ranking validation on multiple synthetic target cases"
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]:
    """Turn CLI arguments into the config object and resolved file paths.

    Explicit ``--config``, ``--processed``, ``--model`` and ``--output-dir``
    values take precedence; otherwise defaults are derived from the stage and
    the normalized tag.

    Args:
        args: Namespace produced by ``parse_args``.

    Returns:
        ``(config, processed_path, model_path, output_dir)`` where the three
        paths have been passed through ``resolve()``.
    """
    tag = normalize_tag(args.tag)

    if args.config is not None:
        config_path = args.config
    else:
        # Fall back to the family-random config when the stage lookup yields
        # a falsy value.
        stage_config = config_for_stage(args.stage)
        config_path = str(stage_config or Path("configs/data_gen_family_random.yaml"))

    if args.processed is not None:
        processed_path = Path(args.processed)
    else:
        processed_path = processed_path_for_tag(tag)

    if args.model is not None:
        model_path = Path(args.model)
    else:
        model_path = model_checkpoint_for_tag(tag, use_schedule=True)

    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
    else:
        output_dir = Path("results") / f"autofit_local_validation_batch_{tag}"

    return Config(config_path), processed_path.resolve(), model_path.resolve(), output_dir.resolve()
def sample_target_case(cfg: Config, rng: np.random.RandomState, seed: int) -> Params:
    """Draw one synthetic target: a parameter set plus a sampled schedule.

    Args:
        cfg: Experiment configuration; ``cfg.raw["params"]`` may carry a
            ``sampling_method`` entry (``"sobol"`` is used when absent).
        rng: Random state used for the schedule and the section choice.
        seed: Seed forwarded to ``generate_params_dataset``.

    Returns:
        The sampled ``Params`` with its ``schedule`` attribute replaced by the
        newly sampled ``Schedule``.
    """
    method = cfg.raw["params"].get("sampling_method", "sobol")
    target = generate_params_dataset(cfg, n_samples=1, method=method, random_seed=seed)[0]

    # The rng is consumed in a fixed order: schedule, section candidates,
    # then the pick among those candidates.
    timeQ, q, _ = sample_schedule_by_mode(cfg, rng)
    section_candidates = _resolve_section_indices(cfg, timeQ, q, rng)
    pick = int(rng.randint(0, len(section_candidates)))

    target.schedule = Schedule(
        sectionIndex=int(section_candidates[pick]),
        timeQ=[float(t) for t in timeQ],
        q=[float(rate) for rate in q],
    )
    return target
def make_runner(cfg: Config, output_dir: Path, tag: str) -> CppRunner:
    """Create a ``CppRunner`` with its own scratch directory.

    The directory ``<output_dir>/_runner_temp/<tag>`` is created if needed and
    handed to the runner as ``temp_dir``. The runner is built with
    ``auto_init=False``.
    """
    scratch_dir = output_dir.joinpath("_runner_temp", tag)
    scratch_dir.mkdir(parents=True, exist_ok=True)
    return CppRunner(cfg=cfg, auto_init=False, temp_dir=scratch_dir)
def summarize_rows(rows: list[dict]) -> dict:
    """Compare solver and surrogate objectives over one candidate set.

    Each row must provide ``solver_objective`` and ``surrogate_objective``.
    The rows are annotated in place with ``solver_rank``, ``surrogate_rank``
    and ``rank_gap`` (surrogate rank minus solver rank). Smaller objective
    values are treated as better.

    Args:
        rows: Candidate rows for a single target / span combination.
            Presumably non-empty; ``np.argmin`` raises on an empty array.

    Returns:
        A dict of agreement metrics: correlations, top-5/10/20 overlap counts,
        cross ranks of each side's best candidate, the top-1 regret and
        mean / mean-absolute objective gaps.
    """
    solver_vals = np.asarray([float(entry["solver_objective"]) for entry in rows], dtype=np.float64)
    surrogate_vals = np.asarray([float(entry["surrogate_objective"]) for entry in rows], dtype=np.float64)

    solver_rank = _rank_positions(solver_vals)
    surrogate_rank = _rank_positions(surrogate_vals)
    for pos, entry in enumerate(rows):
        entry.update(
            {
                "solver_rank": int(solver_rank[pos]),
                "surrogate_rank": int(surrogate_rank[pos]),
                "rank_gap": int(surrogate_rank[pos] - solver_rank[pos]),
            }
        )

    # Ascending orderings, reused for every top-k overlap below.
    solver_order = np.argsort(solver_vals)
    surrogate_order = np.argsort(surrogate_vals)

    def overlap_at(k: int) -> int:
        """Count candidates present in both sides' k smallest objectives."""
        solver_top = set(solver_order[:k].tolist())
        surrogate_top = set(surrogate_order[:k].tolist())
        return int(len(solver_top & surrogate_top))

    top5_overlap = overlap_at(5)
    top10_overlap = overlap_at(10)
    top20_overlap = overlap_at(20)

    solver_winner = int(np.argmin(solver_vals))
    surrogate_winner = int(np.argmin(surrogate_vals))
    best_solver_value = float(solver_vals[solver_winner])
    # Solver objective of the candidate the surrogate would have picked.
    solver_value_at_surrogate_winner = float(solver_vals[surrogate_winner])
    objective_gap = surrogate_vals - solver_vals

    return {
        "n_valid": int(len(rows)),
        "pearson_objective": _corr_pearson(solver_vals, surrogate_vals),
        "spearman_objective": _corr_spearman(solver_vals, surrogate_vals),
        "top5_overlap": top5_overlap,
        "top10_overlap": top10_overlap,
        "top20_overlap": top20_overlap,
        "best_solver_candidate_surrogate_rank": int(surrogate_rank[solver_winner]),
        "best_surrogate_candidate_solver_rank": int(solver_rank[surrogate_winner]),
        "solver_best_objective": best_solver_value,
        "solver_objective_at_surrogate_best": solver_value_at_surrogate_winner,
        "surrogate_top1_regret": float(solver_value_at_surrogate_winner - best_solver_value),
        "mean_objective_gap": float(np.mean(objective_gap)),
        "mae_objective_gap": float(np.mean(np.abs(objective_gap))),
    }
def summarize_screening(candidate_rows: list[dict], keep_fracs: list[float]) -> list[dict]:
    """Estimate how well the surrogate works as a pre-filter.

    Rows are grouped by ``(target_id, span_frac)``. For every keep fraction,
    each group keeps the ``ceil(keep_frac * n)`` candidates (at least one)
    with the smallest surrogate objective, and the kept set is compared with
    the solver's ordering of the same group.

    Args:
        candidate_rows: Rows carrying ``target_id``, ``span_frac``,
            ``solver_objective`` and ``surrogate_objective``.
        keep_fracs: Fractions of each group to keep.

    Returns:
        One dict per keep fraction (in input order) with the mean number of
        kept candidates, recall of the solver's top 10 (mean / median / 10th
        percentile), the share of groups where the solver's best candidate was
        dropped, and regret statistics (mean / median / 90th percentile).
        Keep fractions with no groups to evaluate produce no entry.
    """
    buckets: dict[tuple[int, float], list[dict]] = {}
    for row in candidate_rows:
        bucket_key = (int(row["target_id"]), float(row["span_frac"]))
        buckets.setdefault(bucket_key, []).append(row)

    summaries: list[dict] = []
    for keep_frac in keep_fracs:
        # One (recall, regret, missed_best, kept_count) tuple per group.
        per_group: list[tuple[float, float, float, float]] = []
        for members in buckets.values():
            solver_vals = np.asarray([float(m["solver_objective"]) for m in members], dtype=np.float64)
            surrogate_vals = np.asarray([float(m["surrogate_objective"]) for m in members], dtype=np.float64)
            count = int(solver_vals.size)
            if count == 0:
                continue

            n_keep = max(1, int(np.ceil(float(keep_frac) * count)))
            kept = set(np.argsort(surrogate_vals)[:n_keep].tolist())
            solver_top10 = set(np.argsort(solver_vals)[: min(10, count)].tolist())
            solver_winner = int(np.argmin(solver_vals))
            kept_solver_vals = solver_vals[list(kept)]

            per_group.append(
                (
                    float(len(kept & solver_top10) / max(len(solver_top10), 1)),
                    # Regret: best solver objective among kept vs. overall best.
                    float(np.min(kept_solver_vals) - np.min(solver_vals)),
                    float(solver_winner not in kept),
                    float(n_keep),
                )
            )

        if not per_group:
            continue

        recall, regret, missed, kept_count = (
            np.asarray(column, dtype=np.float64) for column in zip(*per_group)
        )
        summaries.append(
            {
                "keep_frac": float(keep_frac),
                "mean_kept_candidates": float(np.mean(kept_count)),
                "solver_top10_recall_mean": float(np.mean(recall)),
                "solver_top10_recall_median": float(np.median(recall)),
                "solver_top10_recall_p10": float(np.percentile(recall, 10)),
                "solver_best_missed_ratio": float(np.mean(missed)),
                "regret_mean": float(np.mean(regret)),
                "regret_median": float(np.median(regret)),
                "regret_p90": float(np.percentile(regret, 90)),
            }
        )
    return summaries
def _close_runner_quietly(runner, label: str) -> None:
    """Close a solver runner, reporting (not raising) any error from ``close()``."""
    try:
        runner.close()
    except Exception as exc:
        # Cleanup is best-effort: a failing close must not abort the batch.
        print(f"[warn] failed to close runner for {label}: {exc}")


def _write_csv(path: Path, rows: list[dict], fieldnames=None) -> None:
    """Write ``rows`` to ``path`` as UTF-8-with-BOM CSV.

    When ``fieldnames`` is ``None`` the columns are the keys of the first row,
    so ``rows`` must then be non-empty.
    """
    if fieldnames is None:
        fieldnames = list(rows[0].keys())
    with open(path, "w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main() -> None:
    """Run the batch local-ranking validation and write its reports.

    Workflow:
      1. Resolve config / dataset / model paths from the CLI and load them.
      2. Sample synthetic target cases and run the solver on each one. Failed
         targets are recorded and a new one is sampled, for at most
         ``max(n_targets, n_targets * target_max_attempts_factor)`` attempts.
      3. For every accepted target and span fraction, sample local candidates,
         score each with the solver and with the surrogate through
         ``dual_log_objective``, and summarise the agreement of the two.
      4. Write ``target_summaries.csv``, ``candidate_objectives.csv``,
         ``screening_summary.csv``, ``failed_targets.csv`` and
         ``summary.json`` into the output directory.

    Raises:
        RuntimeError: If no target / span summary could be produced.
    """
    args = parse_args()
    cfg, processed_path, model_path, output_dir = resolve_paths(args)
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE: joblib.load unpickles the file and can execute arbitrary code;
    # only point --processed at datasets from a trusted source.
    processed = joblib.load(processed_path)
    curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))
    model, use_schedule, device = load_model(model_path)
    span_fracs = parse_span_fracs(args.span_fracs)
    rng = np.random.RandomState(int(args.seed))

    all_candidate_rows: list[dict] = []
    target_rows: list[dict] = []
    failed_targets: list[dict] = []

    n_targets = int(args.n_targets)
    max_target_attempts = max(n_targets, n_targets * int(args.target_max_attempts_factor))
    # target_id only advances for targets whose solver run succeeded.
    target_id = 0
    for attempt in range(max_target_attempts):
        if target_id >= n_targets:
            break

        target_params = sample_target_case(cfg, rng, seed=int(args.seed) + 100000 + attempt)
        target_label = f"target_{target_id:03d}"
        target_runner = make_runner(cfg, output_dir, target_label)
        try:
            target_curve, _target_raw = run_solver_and_extract_curve(
                runner=target_runner,
                cfg=cfg,
                params=target_params,
                well_index=int(args.well_index),
                timeout=int(args.solver_timeout),
            )
        except Exception as exc:
            failed_targets.append({"attempt_id": attempt, "reason": str(exc)})
            continue
        finally:
            # Single cleanup point for both the success and the failure path
            # (``finally`` also runs before the ``continue`` above).
            _close_runner_quietly(target_runner, target_label)

        print(
            f"[target {target_id:03d}] "
            f"k={target_params.k:.6g}, skin={target_params.skin:.6g}, "
            f"C={target_params.wellboreC:.6g}, phi={target_params.phi:.6g}, h={target_params.h:.6g}, "
            f"section={target_params.schedule.sectionIndex}, spans={span_fracs}"
        )

        for span_frac in span_fracs:
            candidates = _sample_local_candidates(
                cfg=cfg,
                base=target_params,
                n=int(args.n_candidates),
                seed=int(args.seed) + target_id * 1000 + int(round(float(span_frac) * 100000)),
                span_frac=float(span_frac),
            )

            rows: list[dict] = []
            for cand_id, cand in enumerate(candidates):
                cand_label = f"target_{target_id:03d}_span_{span_frac:g}_cand_{cand_id:03d}"
                runner = make_runner(cfg, output_dir, cand_label)
                try:
                    solver_curve, _ = run_solver_and_extract_curve(
                        runner=runner,
                        cfg=cfg,
                        params=cand,
                        well_index=int(args.well_index),
                        timeout=int(args.solver_timeout),
                    )
                    pred_curve = predict_surrogate_curve(
                        processed=processed,
                        model=model,
                        device=device,
                        use_schedule=use_schedule,
                        params=cand,
                        schedule=cand.schedule,
                        cfg=cfg,
                    )
                    solver_obj = dual_log_objective(target_curve, solver_curve, curve_layout)
                    surrogate_obj = dual_log_objective(target_curve, pred_curve, curve_layout)
                    rows.append(
                        {
                            "target_id": target_id,
                            "span_frac": float(span_frac),
                            "candidate_id": cand_id,
                            "k": cand.k,
                            "skin": cand.skin,
                            "wellboreC": cand.wellboreC,
                            "phi": cand.phi,
                            "h": cand.h,
                            "Cf": cand.Cf,
                            "solver_objective": solver_obj["dual_log_objective"],
                            "solver_p_obj": solver_obj["log_pressure_objective"],
                            "solver_d_obj": solver_obj["log_derivative_objective"],
                            "surrogate_objective": surrogate_obj["dual_log_objective"],
                            "surrogate_p_obj": surrogate_obj["log_pressure_objective"],
                            "surrogate_d_obj": surrogate_obj["log_derivative_objective"],
                        }
                    )
                except Exception as exc:
                    # A single failing candidate is skipped, not fatal.
                    print(f"[warn] target={target_id} span={span_frac:g} cand={cand_id} skipped: {exc}")
                finally:
                    _close_runner_quietly(runner, cand_label)

            if not rows:
                print(f" span={span_frac:g}: no valid candidates")
                continue

            summary = summarize_rows(rows)
            summary.update(
                {
                    "target_id": target_id,
                    "span_frac": float(span_frac),
                    "target_k": target_params.k,
                    "target_skin": target_params.skin,
                    "target_wellboreC": target_params.wellboreC,
                    "target_phi": target_params.phi,
                    "target_h": target_params.h,
                    "target_Cf": target_params.Cf,
                    "target_sectionIndex": int(target_params.schedule.sectionIndex),
                    "target_timeQ_json": json.dumps(target_params.schedule.timeQ),
                    "target_q_json": json.dumps(target_params.schedule.q),
                }
            )
            target_rows.append(summary)
            all_candidate_rows.extend(rows)
            print(
                f" span={span_frac:g}: Spearman={summary['spearman_objective']:.4f}, "
                f"top10={summary['top10_overlap']}, "
                f"best_solver_sur_rank={summary['best_solver_candidate_surrogate_rank']}, "
                f"regret={summary['surrogate_top1_regret']:.6f}"
            )

        target_id += 1

    if not target_rows:
        raise RuntimeError("No valid target summaries were produced")

    _write_csv(output_dir / "target_summaries.csv", target_rows)
    if all_candidate_rows:
        _write_csv(output_dir / "candidate_objectives.csv", all_candidate_rows)

    screening_rows = summarize_screening(
        all_candidate_rows,
        keep_fracs=[0.30, 0.40, 0.50, 0.60, 0.70],
    )
    if screening_rows:
        _write_csv(output_dir / "screening_summary.csv", screening_rows)

    # Always written (possibly header-only) so the file's presence is stable.
    _write_csv(output_dir / "failed_targets.csv", failed_targets, fieldnames=["attempt_id", "reason"])

    spearman = np.asarray([float(r["spearman_objective"]) for r in target_rows], dtype=np.float64)
    top10 = np.asarray([float(r["top10_overlap"]) for r in target_rows], dtype=np.float64)
    best_solver_rank = np.asarray(
        [float(r["best_solver_candidate_surrogate_rank"]) for r in target_rows], dtype=np.float64
    )
    regret = np.asarray([float(r["surrogate_top1_regret"]) for r in target_rows], dtype=np.float64)

    batch_summary = {
        "config_path": str(cfg.path),
        "processed_path": str(processed_path),
        "model_path": str(model_path),
        "n_targets_requested": int(args.n_targets),
        "n_target_span_rows": int(len(target_rows)),
        "span_fracs": span_fracs,
        "n_candidates": int(args.n_candidates),
        "spearman_mean": float(np.mean(spearman)),
        "spearman_median": float(np.median(spearman)),
        "top10_overlap_mean": float(np.mean(top10)),
        "best_solver_surrogate_rank_median": float(np.median(best_solver_rank)),
        # NOTE(review): "< 10" is a top-10 test only if _rank_positions is
        # 0-based; with 1-based ranks this would be top-9 — confirm.
        "best_solver_in_surrogate_top10_ratio": float(np.mean(best_solver_rank < 10)),
        "surrogate_top1_regret_mean": float(np.mean(regret)),
        "surrogate_top1_regret_median": float(np.median(regret)),
        "failed_target_count": int(len(failed_targets)),
        "screening": screening_rows,
    }
    with open(output_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(batch_summary, f, ensure_ascii=False, indent=2)

    print("\nBatch local ranking validation complete.")
    print(f"Output dir: {output_dir}")
    print(
        f"Spearman mean={batch_summary['spearman_mean']:.4f}, "
        f"best_solver_in_surrogate_top10_ratio={batch_summary['best_solver_in_surrogate_top10_ratio']:.4f}, "
        f"regret median={batch_summary['surrogate_top1_regret_median']:.6f}"
    )
# Script entry point: run the batch validation only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    main()