You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/validate_autofit_local_rank...

416 lines
19 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""批量运行局部自动拟合排序验证。
相比 `validate_autofit_local_ranking.py` 的单案例验证,本脚本会按多个随机种子/目标样本
重复采样局部候选,汇总代理目标与真实目标之间的排序相关性、保留比例和失败案例,
用于给出更稳健的 PSO 预筛选可用性判断。
"""
from __future__ import annotations
import argparse
import csv
import json
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import joblib
import numpy as np
from scripts.generate_autofit_neighborhood_dataset import _resolve_section_indices, sample_schedule_by_mode
from scripts.validate_autofit_local_ranking import (
_corr_pearson,
_corr_spearman,
_rank_positions,
_sample_local_candidates,
infer_curve_layout,
load_model,
predict_surrogate_curve,
run_solver_and_extract_curve,
)
from src.common.config import Config
from src.common.experiment_paths import config_for_stage, model_checkpoint_for_tag, normalize_tag, processed_path_for_tag
from src.data.params import Params, Schedule, generate_params_dataset
from src.data.runner_client import CppRunner
from src.evaluation.autofit_objective import dual_log_objective
def parse_span_fracs(text: str) -> list[float]:
    """Parse the comma-separated list of local perturbation span fractions.

    Blank items are ignored, non-positive values are dropped, and values are
    de-duplicated after rounding to 10 decimal places.

    Args:
        text: Comma-separated numbers, e.g. ``"0.02,0.05,0.10"``. The value is
            passed through ``str()`` before splitting.

    Returns:
        The unique positive fractions in ascending order.

    Raises:
        ValueError: If an item is not a valid float, or if no positive value
            remains after filtering.
    """
    tokens = (piece.strip() for piece in str(text).split(","))
    parsed = [float(token) for token in tokens if token]
    # Rounding before de-duplication merges values that differ only by float noise.
    unique_positive = {round(frac, 10) for frac in parsed if frac > 0.0}
    if not unique_positive:
        raise ValueError("At least one positive span fraction is required")
    return [float(frac) for frac in sorted(unique_positive)]
def parse_args() -> argparse.Namespace:
    """Build the command-line parser for batch local ranking validation and parse ``sys.argv``.

    Returns:
        The parsed namespace with the attributes ``config``, ``stage``,
        ``processed``, ``model``, ``tag``, ``output_dir``, ``n_targets``,
        ``n_candidates``, ``span_fracs``, ``seed``, ``well_index``,
        ``solver_timeout`` and ``target_max_attempts_factor``.
    """
    cli = argparse.ArgumentParser(
        description="Run local ranking validation on multiple synthetic target cases"
    )
    stage_choices = ["fixed_case", "case_neighborhood", "family_random", "family_random_hard"]
    cli.add_argument("--config", type=str, default=None)
    cli.add_argument("--stage", choices=stage_choices, default="family_random")

    # The remaining options only differ in flag name, value type and default.
    simple_options = (
        ("--processed", str, None),
        ("--model", str, None),
        ("--tag", str, "family_random_mixed_50k_biasfix_logparam"),
        ("--output-dir", str, None),
        ("--n-targets", int, 20),
        ("--n-candidates", int, 48),
        ("--span-fracs", str, "0.02,0.05,0.10"),
        ("--seed", int, 42),
        ("--well-index", int, 0),
        ("--solver-timeout", int, 120),
        ("--target-max-attempts-factor", int, 5),
    )
    for flag, value_type, default in simple_options:
        cli.add_argument(flag, type=value_type, default=default)
    return cli.parse_args()
def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]:
    """Resolve the config, processed-data, model and output locations for a run.

    Explicit command-line values win; anything left unset is derived from the
    normalized tag (or, for the config, from the selected stage).

    Args:
        args: Parsed command-line namespace from ``parse_args``.

    Returns:
        ``(config, processed_path, model_path, output_dir)`` where the three
        paths have been passed through ``Path.resolve()``.
    """
    tag = normalize_tag(args.tag)

    if args.config is not None:
        config_path = args.config
    else:
        # Fall back to the family-random config when the stage has no mapped config.
        config_path = str(config_for_stage(args.stage) or Path("configs/data_gen_family_random.yaml"))

    if args.processed is None:
        processed_path = processed_path_for_tag(tag)
    else:
        processed_path = Path(args.processed)

    if args.model is None:
        model_path = model_checkpoint_for_tag(tag, use_schedule=True)
    else:
        model_path = Path(args.model)

    if args.output_dir is None:
        output_dir = Path("results") / f"autofit_local_validation_batch_{tag}"
    else:
        output_dir = Path(args.output_dir)

    return Config(config_path), processed_path.resolve(), model_path.resolve(), output_dir.resolve()
def sample_target_case(cfg: Config, rng: np.random.RandomState, seed: int) -> Params:
    """Draw one target case (parameters plus schedule) for local ranking validation.

    Args:
        cfg: Experiment configuration; ``cfg.raw["params"]["sampling_method"]``
            selects the sampling method and defaults to ``"sobol"``.
        rng: Random state used for the schedule, the section-index resolution
            and the final section pick.
        seed: Random seed forwarded to ``generate_params_dataset``.

    Returns:
        The sampled ``Params`` with its ``schedule`` attribute replaced by a
        freshly sampled ``Schedule``.
    """
    sampling_method = cfg.raw["params"].get("sampling_method", "sobol")
    params = generate_params_dataset(cfg, n_samples=1, method=sampling_method, random_seed=seed)[0]

    timeQ, q, _sched_info = sample_schedule_by_mode(cfg, rng)
    section_indices = _resolve_section_indices(cfg, timeQ, q, rng)
    # Pick one of the resolved section indices uniformly at random.
    pick = int(rng.randint(0, len(section_indices)))
    section_index = int(section_indices[pick])

    params.schedule = Schedule(
        sectionIndex=section_index,
        timeQ=[float(t) for t in timeQ],
        q=[float(rate) for rate in q],
    )
    return params
def make_runner(cfg: Config, output_dir: Path, tag: str) -> CppRunner:
    """Create a C++ solver client with its own temporary directory.

    The directory ``<output_dir>/_runner_temp/<tag>`` is created if missing and
    handed to the runner, which is constructed with ``auto_init=False``.

    Args:
        cfg: Experiment configuration passed to the runner.
        output_dir: Root output directory of the current run.
        tag: Sub-directory name identifying this runner.

    Returns:
        The newly constructed ``CppRunner``.
    """
    scratch_dir = output_dir.joinpath("_runner_temp", tag)
    scratch_dir.mkdir(parents=True, exist_ok=True)
    return CppRunner(cfg=cfg, auto_init=False, temp_dir=scratch_dir)
def summarize_rows(rows: list[dict]) -> dict:
    """Summarize how well the surrogate objective ranks one group of candidates.

    Each row in ``rows`` is updated in place with ``solver_rank``,
    ``surrogate_rank`` and ``rank_gap`` (surrogate rank minus solver rank).

    Args:
        rows: Candidate rows holding ``solver_objective`` and
            ``surrogate_objective`` values; lower objective is treated as better.

    Returns:
        A dict with the number of rows, Pearson/Spearman correlations of the two
        objectives, top-5/10/20 overlap counts, cross ranks of each side's best
        candidate, the top-1 regret and the mean / mean-absolute objective gap.
    """
    solver_obj = np.asarray([float(row["solver_objective"]) for row in rows], dtype=np.float64)
    surrogate_obj = np.asarray([float(row["surrogate_objective"]) for row in rows], dtype=np.float64)

    solver_rank = _rank_positions(solver_obj)
    surrogate_rank = _rank_positions(surrogate_obj)
    for idx, row in enumerate(rows):
        row["solver_rank"] = int(solver_rank[idx])
        row["surrogate_rank"] = int(surrogate_rank[idx])
        row["rank_gap"] = int(surrogate_rank[idx] - solver_rank[idx])

    # Ascending orderings; slicing the first k entries yields the k best candidates.
    solver_order = np.argsort(solver_obj)
    surrogate_order = np.argsort(surrogate_obj)

    def top_overlap(k: int) -> int:
        best_by_solver = set(solver_order[:k].tolist())
        best_by_surrogate = set(surrogate_order[:k].tolist())
        return int(len(best_by_solver & best_by_surrogate))

    top5_overlap = top_overlap(5)
    top10_overlap = top_overlap(10)
    top20_overlap = top_overlap(20)

    best_solver_idx = int(np.argmin(solver_obj))
    best_surrogate_idx = int(np.argmin(surrogate_obj))
    solver_best = float(solver_obj[best_solver_idx])
    solver_at_surrogate_best = float(solver_obj[best_surrogate_idx])
    objective_gap = surrogate_obj - solver_obj

    return {
        "n_valid": int(len(rows)),
        "pearson_objective": _corr_pearson(solver_obj, surrogate_obj),
        "spearman_objective": _corr_spearman(solver_obj, surrogate_obj),
        "top5_overlap": top5_overlap,
        "top10_overlap": top10_overlap,
        "top20_overlap": top20_overlap,
        "best_solver_candidate_surrogate_rank": int(surrogate_rank[best_solver_idx]),
        "best_surrogate_candidate_solver_rank": int(solver_rank[best_surrogate_idx]),
        "solver_best_objective": solver_best,
        "solver_objective_at_surrogate_best": solver_at_surrogate_best,
        # Regret: how much worse the true objective is when trusting the surrogate's pick.
        "surrogate_top1_regret": float(solver_at_surrogate_best - solver_best),
        "mean_objective_gap": float(np.mean(objective_gap)),
        "mae_objective_gap": float(np.mean(np.abs(objective_gap))),
    }
def summarize_screening(candidate_rows: list[dict], keep_fracs: list[float]) -> list[dict]:
    """Measure whether surrogate pre-screening keeps the truly good candidates.

    Rows are grouped by ``(target_id, span_frac)``. For every keep fraction the
    surrogate-best ``keep_n`` candidates of each group are retained, and the
    retained set is compared with the solver's top 10 and solver's best
    candidate (lower objective is treated as better).

    Args:
        candidate_rows: Candidate rows holding ``target_id``, ``span_frac``,
            ``solver_objective`` and ``surrogate_objective``.
        keep_fracs: Fractions of each group to retain.

    Returns:
        One dict per keep fraction with the mean number of kept candidates,
        mean / median / 10th-percentile recall of the solver top 10, the ratio
        of groups whose solver-best candidate was dropped, and mean / median /
        90th-percentile regret. Empty when ``candidate_rows`` is empty.
    """
    grouped: dict[tuple[int, float], list[dict]] = {}
    for row in candidate_rows:
        key = (int(row["target_id"]), float(row["span_frac"]))
        grouped.setdefault(key, []).append(row)

    # Everything below is independent of keep_frac, so compute it once per group.
    groups = []
    for rows in grouped.values():
        solver_obj = np.asarray([float(r["solver_objective"]) for r in rows], dtype=np.float64)
        surrogate_obj = np.asarray([float(r["surrogate_objective"]) for r in rows], dtype=np.float64)
        n = int(solver_obj.size)
        if n == 0:
            continue
        surrogate_order = np.argsort(surrogate_obj)
        solver_top10 = set(np.argsort(solver_obj)[: min(10, n)].tolist())
        solver_best_idx = int(np.argmin(solver_obj))
        groups.append((solver_obj, surrogate_order, solver_top10, solver_best_idx))

    if not groups:
        return []

    out: list[dict] = []
    for keep_frac in keep_fracs:
        recall_top10 = []
        regret = []
        missed_solver_best = []
        kept_counts = []
        for solver_obj, surrogate_order, solver_top10, solver_best_idx in groups:
            n = int(solver_obj.size)
            # Round before ceil: a product such as 0.28 * 25 evaluates to
            # 7.000000000000001 and would otherwise keep 8 candidates instead of 7.
            # Clamp to n so the reported count never exceeds the group size.
            keep_n = min(n, max(1, int(np.ceil(round(float(keep_frac) * n, 9)))))
            # Simulate surrogate pre-screening: keep only the keep_n candidates with
            # the best surrogate objective, then check which truly good ones survive.
            kept_order = surrogate_order[:keep_n]
            keep_idx = set(kept_order.tolist())
            kept_solver = solver_obj[kept_order]
            recall_top10.append(float(len(keep_idx & solver_top10) / max(len(solver_top10), 1)))
            regret.append(float(np.min(kept_solver) - np.min(solver_obj)))
            missed_solver_best.append(float(solver_best_idx not in keep_idx))
            kept_counts.append(float(keep_n))

        recall_arr = np.asarray(recall_top10, dtype=np.float64)
        regret_arr = np.asarray(regret, dtype=np.float64)
        missed_arr = np.asarray(missed_solver_best, dtype=np.float64)
        kept_arr = np.asarray(kept_counts, dtype=np.float64)
        out.append(
            {
                "keep_frac": float(keep_frac),
                "mean_kept_candidates": float(np.mean(kept_arr)),
                "solver_top10_recall_mean": float(np.mean(recall_arr)),
                "solver_top10_recall_median": float(np.median(recall_arr)),
                "solver_top10_recall_p10": float(np.percentile(recall_arr, 10)),
                "solver_best_missed_ratio": float(np.mean(missed_arr)),
                "regret_mean": float(np.mean(regret_arr)),
                "regret_median": float(np.median(regret_arr)),
                "regret_p90": float(np.percentile(regret_arr, 90)),
            }
        )
    return out
def _write_csv(path: Path, rows: list[dict], fieldnames: list[str]) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 (with BOM) CSV file with a header row."""
    with open(path, "w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main() -> None:
    """Run batch local ranking validation and write the result files.

    For each successfully solved target case, local candidates are sampled at
    every span fraction, each candidate is evaluated with both the solver and
    the surrogate against the same target curve, and per-(target, span)
    summaries are collected. The output directory receives
    ``failed_targets.csv``, ``target_summaries.csv``, ``candidate_objectives.csv``,
    ``screening_summary.csv`` and ``summary.json``.

    Raises:
        RuntimeError: If no (target, span) summary could be produced.
            ``failed_targets.csv`` is written before raising so the failure
            reasons are not lost.
    """
    args = parse_args()
    cfg, processed_path, model_path, output_dir = resolve_paths(args)
    output_dir.mkdir(parents=True, exist_ok=True)
    # NOTE: joblib.load unpickles the file; only point --processed at trusted artifacts.
    processed = joblib.load(processed_path)
    curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))
    model, use_schedule, device = load_model(model_path)
    span_fracs = parse_span_fracs(args.span_fracs)
    rng = np.random.RandomState(int(args.seed))
    all_candidate_rows: list[dict] = []
    target_rows: list[dict] = []
    failed_targets: list[dict] = []
    max_target_attempts = max(int(args.n_targets), int(args.n_targets) * int(args.target_max_attempts_factor))
    target_id = 0
    for attempt in range(max_target_attempts):
        if target_id >= int(args.n_targets):
            break
        # A target may be skipped because the solver failed, so the number of
        # attempts can exceed the number of targets.
        target_params = sample_target_case(cfg, rng, seed=int(args.seed) + 100000 + attempt)
        target_runner = make_runner(cfg, output_dir, f"target_{target_id:03d}")
        try:
            target_curve, _target_raw = run_solver_and_extract_curve(
                runner=target_runner,
                cfg=cfg,
                params=target_params,
                well_index=int(args.well_index),
                timeout=int(args.solver_timeout),
            )
        except Exception as exc:
            # The runner is closed exactly once, by the finally clause below.
            failed_targets.append({"attempt_id": attempt, "reason": str(exc)})
            continue
        finally:
            try:
                target_runner.close()
            except Exception as close_exc:
                # Closing is best-effort, but report the failure instead of hiding it.
                print(f"[warn] target attempt={attempt} runner close failed: {close_exc}")
        print(
            f"[target {target_id:03d}] "
            f"k={target_params.k:.6g}, skin={target_params.skin:.6g}, "
            f"C={target_params.wellboreC:.6g}, phi={target_params.phi:.6g}, h={target_params.h:.6g}, "
            f"section={target_params.schedule.sectionIndex}, spans={span_fracs}"
        )
        for span_frac in span_fracs:
            # Test several perturbation radii around the same target to see how
            # sensitive the surrogate ranking is to the size of the neighborhood.
            candidates = _sample_local_candidates(
                cfg=cfg,
                base=target_params,
                n=int(args.n_candidates),
                seed=int(args.seed) + target_id * 1000 + int(round(float(span_frac) * 100000)),
                span_frac=float(span_frac),
            )
            rows: list[dict] = []
            for cand_id, cand in enumerate(candidates):
                runner = make_runner(cfg, output_dir, f"target_{target_id:03d}_span_{span_frac:g}_cand_{cand_id:03d}")
                try:
                    # Both the solver curve and the surrogate curve are compared with
                    # the same target_curve so the objectives can be ranked together.
                    solver_curve, _ = run_solver_and_extract_curve(
                        runner=runner,
                        cfg=cfg,
                        params=cand,
                        well_index=int(args.well_index),
                        timeout=int(args.solver_timeout),
                    )
                    pred_curve = predict_surrogate_curve(
                        processed=processed,
                        model=model,
                        device=device,
                        use_schedule=use_schedule,
                        params=cand,
                        schedule=cand.schedule,
                        cfg=cfg,
                    )
                    solver_obj = dual_log_objective(target_curve, solver_curve, curve_layout)
                    surrogate_obj = dual_log_objective(target_curve, pred_curve, curve_layout)
                    rows.append(
                        {
                            "target_id": target_id,
                            "span_frac": float(span_frac),
                            "candidate_id": cand_id,
                            "k": cand.k,
                            "skin": cand.skin,
                            "wellboreC": cand.wellboreC,
                            "phi": cand.phi,
                            "h": cand.h,
                            "Cf": cand.Cf,
                            "solver_objective": solver_obj["dual_log_objective"],
                            "solver_p_obj": solver_obj["log_pressure_objective"],
                            "solver_d_obj": solver_obj["log_derivative_objective"],
                            "surrogate_objective": surrogate_obj["dual_log_objective"],
                            "surrogate_p_obj": surrogate_obj["log_pressure_objective"],
                            "surrogate_d_obj": surrogate_obj["log_derivative_objective"],
                        }
                    )
                except Exception as exc:
                    print(f"[warn] target={target_id} span={span_frac:g} cand={cand_id} skipped: {exc}")
                finally:
                    runner.close()
            if rows:
                summary = summarize_rows(rows)
                # target_rows holds one summary per (target, span) pair;
                # all_candidate_rows keeps per-candidate detail for the screening statistics.
                summary.update(
                    {
                        "target_id": target_id,
                        "span_frac": float(span_frac),
                        "target_k": target_params.k,
                        "target_skin": target_params.skin,
                        "target_wellboreC": target_params.wellboreC,
                        "target_phi": target_params.phi,
                        "target_h": target_params.h,
                        "target_Cf": target_params.Cf,
                        "target_sectionIndex": int(target_params.schedule.sectionIndex),
                        "target_timeQ_json": json.dumps(target_params.schedule.timeQ),
                        "target_q_json": json.dumps(target_params.schedule.q),
                    }
                )
                target_rows.append(summary)
                all_candidate_rows.extend(rows)
                print(
                    f"  span={span_frac:g}: Spearman={summary['spearman_objective']:.4f}, "
                    f"top10={summary['top10_overlap']}, "
                    f"best_solver_sur_rank={summary['best_solver_candidate_surrogate_rank']}, "
                    f"regret={summary['surrogate_top1_regret']:.6f}"
                )
            else:
                print(f"  span={span_frac:g}: no valid candidates")
        target_id += 1

    if target_id < int(args.n_targets):
        print(
            f"[warn] only {target_id} of {int(args.n_targets)} requested targets completed "
            f"after {max_target_attempts} attempts"
        )

    # Write the failure log first so it survives the "no valid target" error below.
    _write_csv(output_dir / "failed_targets.csv", failed_targets, ["attempt_id", "reason"])
    if not target_rows:
        raise RuntimeError("No valid target summaries were produced")

    _write_csv(output_dir / "target_summaries.csv", target_rows, list(target_rows[0].keys()))
    if all_candidate_rows:
        _write_csv(
            output_dir / "candidate_objectives.csv",
            all_candidate_rows,
            list(all_candidate_rows[0].keys()),
        )
    # screening_summary aggregates the true top-10 recall and the regret when
    # only the surrogate-best 30%-70% of candidates are kept.
    screening_rows = summarize_screening(
        all_candidate_rows,
        keep_fracs=[0.30, 0.40, 0.50, 0.60, 0.70],
    )
    if screening_rows:
        _write_csv(output_dir / "screening_summary.csv", screening_rows, list(screening_rows[0].keys()))

    spearman = np.asarray([float(r["spearman_objective"]) for r in target_rows], dtype=np.float64)
    top10 = np.asarray([float(r["top10_overlap"]) for r in target_rows], dtype=np.float64)
    best_solver_rank = np.asarray([float(r["best_solver_candidate_surrogate_rank"]) for r in target_rows], dtype=np.float64)
    regret = np.asarray([float(r["surrogate_top1_regret"]) for r in target_rows], dtype=np.float64)
    # summary.json gives an overview of the batch; the CSV files keep the traceable detail.
    summary = {
        "config_path": str(cfg.path),
        "processed_path": str(processed_path),
        "model_path": str(model_path),
        "n_targets_requested": int(args.n_targets),
        "n_targets_completed": int(target_id),
        "n_target_span_rows": int(len(target_rows)),
        "span_fracs": span_fracs,
        "n_candidates": int(args.n_candidates),
        "spearman_mean": float(np.mean(spearman)),
        "spearman_median": float(np.median(spearman)),
        "top10_overlap_mean": float(np.mean(top10)),
        "best_solver_surrogate_rank_median": float(np.median(best_solver_rank)),
        # NOTE(review): "< 10" means top 10 only if _rank_positions is 0-based — confirm.
        "best_solver_in_surrogate_top10_ratio": float(np.mean(best_solver_rank < 10)),
        "surrogate_top1_regret_mean": float(np.mean(regret)),
        "surrogate_top1_regret_median": float(np.median(regret)),
        "failed_target_count": int(len(failed_targets)),
        "screening": screening_rows,
    }
    with open(output_dir / "summary.json", "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print("\nBatch local ranking validation complete.")
    print(f"Output dir: {output_dir}")
    print(
        f"Spearman mean={summary['spearman_mean']:.4f}, "
        f"best_solver_in_surrogate_top10_ratio={summary['best_solver_in_surrogate_top10_ratio']:.4f}, "
        f"regret median={summary['surrogate_top1_regret_median']:.6f}"
    )
# Script entry point: run the batch validation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()