You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/q_sweep_local_ranking.py

564 lines
22 KiB
Python

"""扫描流量制度变化对局部排序能力的影响。
脚本固定目标地层参数,构造多组不同的生产/关井制度,分别运行数值求解器和代理模型,
比较候选排序相关系数和各保留比例下的最优值,用于判断代理模型在流量制度变化时
是否保持稳定的自动拟合筛选能力。
"""
# pylint: disable=import-error,wrong-import-position,wrong-import-order,too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements,broad-exception-caught
from __future__ import annotations
import argparse
import csv
import json
import math
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import joblib
import numpy as np
from scripts.validate_autofit_local_ranking import (
_corr_pearson,
_corr_spearman,
_rank_positions,
_sample_local_candidates,
infer_curve_layout,
load_model,
predict_surrogate_curve,
run_solver_and_extract_curve,
)
from src.common.config import Config
from src.common.experiment_paths import (
config_for_stage,
model_checkpoint_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.data.params import Params, Schedule
from src.data.runner_client import CppRunner
from src.evaluation.autofit_objective import dual_log_objective
# Default theta* physical parameters. Each value is the default of the matching
# CLI option (--k, --skin, ...), and the keys match the Params keyword names.
THETA_STAR = dict(
    k=0.008584308759474064,
    skin=1.2096901115405767,
    wellboreC=0.03024150790038185,
    phi=0.2736409306526184,
    h=21.499309979069142,
    Cf=4.315e-4,
)
# Fractions of candidates kept when screening by surrogate objective.
KEEP_FRACS = [0.50, 0.60, 0.70]
def parse_args() -> argparse.Namespace:
    """Parse the command line for the multi-schedule Q sweep.

    Options cover config/stage selection, artifact locations (processed data,
    model checkpoint, tag, output directory), local candidate sampling (count,
    span, seed, attempt factor), solver settings (well index, timeout), case
    filtering, and the theta* parameter values (defaults from THETA_STAR).
    """
    stage_choices = [
        "fixed_case",
        "case_neighborhood",
        "family_random",
        "family_random_hard",
        "family_random_v2_q",
    ]
    ap = argparse.ArgumentParser(
        description="Generate Q schedules around theta* and run first-layer local ranking validation."
    )
    ap.add_argument("--config", type=str, default=None)
    ap.add_argument("--stage", choices=stage_choices, default="family_random")
    for path_flag in ("--processed", "--model"):
        ap.add_argument(path_flag, type=str, default=None)
    ap.add_argument("--tag", type=str, default="family_random_mixed_50k_logparam")
    ap.add_argument("--output-dir", type=str, default=None)
    # Sampling / solver knobs, registered in their original order.
    numeric_opts = (
        ("--n-candidates", int, 48),
        ("--span-frac", float, 0.05),
        ("--seed", int, 42),
        ("--well-index", int, 0),
        ("--solver-timeout", int, 120),
        ("--max-candidate-attempts-factor", int, 4),
    )
    for flag, caster, default_value in numeric_opts:
        ap.add_argument(flag, type=caster, default=default_value)
    ap.add_argument(
        "--max-q-cases",
        type=int,
        default=None,
        help="Optional cap for quick smoke runs or batched execution",
    )
    ap.add_argument(
        "--case-id-contains",
        type=str,
        default=None,
        help="Only run Q cases whose case_id contains this text",
    )
    # One float option per theta* parameter; dest equals the parameter name.
    for param_name in ("k", "skin", "wellboreC", "phi", "h", "Cf"):
        ap.add_argument(f"--{param_name}", type=float, default=THETA_STAR[param_name])
    return ap.parse_args()
def resolve_paths(args: argparse.Namespace) -> tuple[Config, Path, Path, Path]:
    """Resolve the config, processed dataset, model checkpoint and output directory.

    Explicit CLI paths win. Otherwise:
    - the config comes from the stage;
    - the processed data and model paths come from the normalized tag;
    - the output directory defaults to results/q_sweep_local_ranking_<tag>.

    Returns:
        (Config, processed_path, model_path, output_dir); the three paths are absolute.

    Raises:
        ValueError: no config can be resolved for the stage and --config was not given.
    """
    tag = normalize_tag(args.tag)

    if args.config is None:
        config_path = config_for_stage(args.stage)
    else:
        config_path = Path(args.config)
    if config_path is None:
        raise ValueError(f"Cannot resolve config for stage={args.stage!r}; pass --config explicitly")

    if args.processed is None:
        processed_path = processed_path_for_tag(tag)
    else:
        processed_path = Path(args.processed)

    if args.model is None:
        model_path = model_checkpoint_for_tag(tag, use_schedule=True)
    else:
        model_path = Path(args.model)

    if args.output_dir is None:
        output_dir = Path("results") / f"q_sweep_local_ranking_{tag}"
    else:
        output_dir = Path(args.output_dir)

    return Config(config_path), processed_path.resolve(), model_path.resolve(), output_dir.resolve()
def make_theta_params(args: argparse.Namespace, schedule: Schedule) -> Params:
    """Build the target (theta*) Params from the CLI parameter values.

    Reads k, skin, wellboreC, phi, h and Cf from ``args`` (cast to float) and
    attaches ``schedule``. No perturbation is applied here.
    """
    theta = {
        name: float(getattr(args, name))
        for name in ("k", "skin", "wellboreC", "phi", "h", "Cf")
    }
    return Params(**theta, schedule=schedule)
def normalize_prod_durations(base_prod_dt: list[float], total: float) -> list[float]:
    """Rescale production-section durations proportionally so they sum to ``total``.

    Steps:
    1. Floor each duration at 0.05.
    2. Scale all durations by the same factor.
    3. Fold any floating-point residue into the last section.
    4. Floor every result at 0.05 again, which can push the sum above ``total``
       when scaled sections fall below 0.05.

    The input sequence is not modified. An empty ``base_prod_dt`` raises
    IndexError (the residue step indexes the last element).
    """
    target = float(total)
    floored = np.maximum(np.asarray(base_prod_dt, dtype=np.float64), 0.05)
    scaled = floored * (target / max(float(np.sum(floored)), 1e-12))
    # Absorb rounding error so the sections add up to the requested total.
    scaled[-1] += target - float(np.sum(scaled))
    return [max(float(value), 0.05) for value in scaled]
def schedule_features(time_q: list[float], q: list[float]) -> dict:
    """Summarize a flow-rate schedule into scalar descriptors for the sweep report.

    A trailing section with rate <= 1e-6 is treated as the shut-in; every earlier
    section counts as production.

    Returned keys:
    - n_sections: number of entries in ``time_q``.
    - n_prod_sections: number of production sections.
    - prod_total_time: sum of production durations.
    - shutin_time: duration of the trailing shut-in, else 0.0.
    - q_min, q_max, q_mean, q_std (population std) and q_cv (std / mean):
      statistics of the production rates.
    - max_step_ratio: largest max/min ratio between consecutive production rates;
      1.0 when there are fewer than two production sections.

    All production statistics are 0.0 when there is no production section.
    """
    durations = np.asarray(time_q, dtype=np.float64)
    rates = np.asarray(q, dtype=np.float64)
    ends_with_shutin = bool(len(rates) > 0 and rates[-1] <= 1e-6)
    n_prod = int(len(rates) - 1) if ends_with_shutin else int(len(rates))
    prod_durations = durations[:n_prod]
    prod_rates = rates[:n_prod]

    # Ratio between neighbouring rates; the denominator is floored to avoid /0.
    ratios = [
        max(a, b) / max(min(a, b), 1e-12)
        for a, b in zip(prod_rates[:-1].tolist(), prod_rates[1:].tolist())
    ]

    if n_prod:
        prod_total = float(np.sum(prod_durations))
        q_lo = float(np.min(prod_rates))
        q_hi = float(np.max(prod_rates))
        mean_q = float(np.mean(prod_rates))
        std_q = float(np.std(prod_rates))
        cv_q = float(std_q / max(mean_q, 1e-12))
    else:
        prod_total = q_lo = q_hi = mean_q = std_q = cv_q = 0.0

    return {
        "n_sections": int(len(time_q)),
        "n_prod_sections": n_prod,
        "prod_total_time": prod_total,
        "shutin_time": float(durations[-1]) if ends_with_shutin else 0.0,
        "q_min": q_lo,
        "q_max": q_hi,
        "q_mean": mean_q,
        "q_std": std_q,
        "q_cv": cv_q,
        "max_step_ratio": float(max(ratios)) if ratios else 1.0,
    }
def add_case(
    cases: list[dict],
    seen: set[tuple],
    case_id: str,
    family: str,
    time_q: list[float],
    q: list[float],
    axis: str,
) -> None:
    """Append a schedule case to ``cases`` unless an identical schedule is already present.

    Identity is the concatenated ``time_q + q`` values rounded to 8 decimals. ``seen``
    holds those keys and is updated in place, so the first case with a given schedule
    wins. The stored record contains:
    - the Schedule object;
    - float copies of timeQ and q;
    - the schedule_features() descriptors.
    """
    fingerprint = tuple(round(float(value), 8) for value in (time_q + q))
    if fingerprint in seen:
        return
    seen.add(fingerprint)

    time_floats = [float(t) for t in time_q]
    rate_floats = [float(r) for r in q]
    schedule = Schedule(sectionIndex=len(time_q), timeQ=list(time_floats), q=list(rate_floats))
    record = {
        "case_id": case_id,
        "axis": axis,
        "family": family,
        "schedule": schedule,
        "timeQ": time_floats,
        "q": rate_floats,
    }
    record.update(schedule_features(time_q, q))
    cases.append(record)
def generate_q_cases() -> list[dict]:
    """Build the flow-rate (Q) schedule cases for the local-ranking sweep.

    Start from a baseline schedule: 4 production sections plus a final shut-in.
    Each sweep axis varies one schedule factor while the others stay at baseline:
    - total production time;
    - shut-in duration;
    - flow-rate scale;
    - step ratio between alternating rates;
    - number of production sections;
    - rate-profile family.

    add_case drops schedules identical to an earlier one, so e.g. ``shutin_72`` and
    ``q_scale_1p0`` (both equal to the baseline) are not emitted twice.

    Returns:
        Case dicts as built by add_case, in generation order.
    """
    cases: list[dict] = []
    seen: set[tuple] = set()
    # The baseline resembles the B3 alternating-step case; each later sweep
    # mainly changes a single flow-rate factor.
    base_prod_dt = [12.0, 12.0, 24.0, 36.0]
    base_q = [170.0, 170.0, 210.0, 250.0]
    base_shutin = 72.0
    add_case(
        cases,
        seen,
        "Q000_baseline_b3_alt_like",
        "mild_step",
        base_prod_dt + [base_shutin],
        base_q + [0.0],
        "baseline",
    )

    # Total production time: rescale section durations proportionally.
    for total in [24.0, 48.0, 72.0, 96.0, 160.0, 220.0, 260.0]:
        time_q = normalize_prod_durations(base_prod_dt, total) + [base_shutin]
        add_case(cases, seen, f"prod_total_{int(total)}", "mild_step", time_q, base_q + [0.0], "prod_total_time")

    # Shut-in duration.
    for shutin in [12.0, 24.0, 48.0, 72.0, 96.0, 140.0, 180.0]:
        add_case(cases, seen, f"shutin_{int(shutin)}", "mild_step", base_prod_dt + [shutin], base_q + [0.0], "shutin_time")

    # Flow-rate scale: multiply every production rate, keep the baseline durations.
    for scale in [0.5, 0.75, 1.0, 1.25, 1.5, 2.0, 3.0]:
        q_prod = [float(x * scale) for x in base_q]
        add_case(
            cases,
            seen,
            f"q_scale_{str(scale).replace('.', 'p')}",
            "mild_step",
            base_prod_dt + [base_shutin],
            q_prod + [0.0],
            "q_scale",
        )

    # Step strength: alternate between q0 and q0 * ratio.
    for ratio in [1.2, 1.6, 2.0, 2.2, 2.8, 3.5]:
        q0 = 180.0
        q_prod = [q0, q0 * ratio, q0, q0 * ratio]
        add_case(
            cases,
            seen,
            f"step_ratio_{str(ratio).replace('.', 'p')}",
            "sharp_step",
            base_prod_dt + [base_shutin],
            q_prod + [0.0],
            "step_ratio",
        )

    # Number of production sections: equal durations over a fixed total of 84.0.
    rates_by_n_prod = {
        2: [170.0, 250.0],
        3: [170.0, 210.0, 250.0],
        4: base_q,
        5: [160.0, 180.0, 205.0, 230.0, 250.0],
        6: [150.0, 170.0, 190.0, 210.0, 230.0, 250.0],
    }
    for n_prod, q_prod in rates_by_n_prod.items():
        time_prod = [84.0 / n_prod] * n_prod
        add_case(cases, seen, f"n_prod_{n_prod}", "increasing", time_prod + [base_shutin], q_prod + [0.0], "n_prod_sections")

    # Rate-profile families on the baseline durations.
    family_specs = [
        ("flat", [200.0, 200.0, 200.0, 200.0]),
        ("increasing", [150.0, 180.0, 220.0, 260.0]),
        ("decreasing", [260.0, 220.0, 180.0, 150.0]),
        ("mild_step", [160.0, 195.0, 225.0, 240.0]),
        ("sharp_step", [100.0, 250.0, 120.0, 300.0]),
    ]
    for family, q_prod in family_specs:
        add_case(cases, seen, f"family_{family}", family, base_prod_dt + [base_shutin], q_prod + [0.0], "family")
    return cases
def make_runner(cfg: Config, output_dir: Path, tag: str) -> CppRunner:
    """Create a CppRunner whose scratch files live in ``output_dir/_runner_temp/<tag>``.

    The scratch directory is created if missing. The runner is always built with
    ``auto_init=False``; what that flag controls is defined by CppRunner (not
    visible here).
    """
    scratch_dir = output_dir.joinpath("_runner_temp", tag)
    scratch_dir.mkdir(parents=True, exist_ok=True)
    runner = CppRunner(cfg=cfg, auto_init=False, temp_dir=scratch_dir)
    return runner
def summarize_case(rows: list[dict], keep_fracs: list[float]) -> dict:
    """Summarize solver-vs-surrogate ranking agreement for one Q case.

    Each row must provide "solver_objective" and "surrogate_objective" (lower is
    better). Rows are mutated in place: "solver_rank", "surrogate_rank" and
    "rank_gap" (surrogate_rank - solver_rank) are added. Ranks come from
    ``_rank_positions``; the comparisons below assume they are 0-based with 0 for
    the lowest objective (NOTE(review): confirm against _rank_positions).

    For each fraction in ``keep_fracs`` the ``max(1, ceil(frac * n))`` candidates
    with the lowest surrogate objective are kept. The function records:
    - the keep count;
    - whether the solver-best candidate survived;
    - the recall of the solver top-10;
    - the regret (best kept solver objective minus overall best).

    ``risk_label`` walks the keep fractions in increasing order:
    - ``safe_keepXX``: the solver-best candidate is ranked inside the smallest keep
      set and that set has (near) zero regret;
    - ``borderline_keepXX``: the first larger keep set that ranks it inside;
    - ``unsafe``: none of the keep sets does.

    For the default (0.5, 0.6, 0.7) the labels are safe_keep50, borderline_keep60,
    borderline_keep70 and unsafe.

    Expects at least one row (the caller requires two).
    """
    solver_obj = np.asarray([float(r["solver_objective"]) for r in rows], dtype=np.float64)
    surrogate_obj = np.asarray([float(r["surrogate_objective"]) for r in rows], dtype=np.float64)
    solver_rank = _rank_positions(solver_obj)
    surrogate_rank = _rank_positions(surrogate_obj)
    for i, row in enumerate(rows):
        row["solver_rank"] = int(solver_rank[i])
        row["surrogate_rank"] = int(surrogate_rank[i])
        row["rank_gap"] = int(surrogate_rank[i] - solver_rank[i])

    n = int(len(rows))
    best_solver_idx = int(np.argmin(solver_obj))
    best_surrogate_idx = int(np.argmin(surrogate_obj))
    top10_n = min(10, n)
    solver_top10 = set(np.argsort(solver_obj)[:top10_n].tolist())
    # Same ordering for every keep fraction, so sort once.
    surrogate_order = np.argsort(surrogate_obj)

    screening: dict[str, float | int | bool] = {}
    keep_labels: list[tuple[float, str]] = []
    for keep_frac in keep_fracs:
        keep_n = max(1, int(math.ceil(float(keep_frac) * n)))
        keep_idx = set(surrogate_order[:keep_n].tolist())
        kept_solver = solver_obj[list(keep_idx)]
        label = f"keep{int(round(keep_frac * 100))}"
        keep_labels.append((float(keep_frac), label))
        # After screening by surrogate objective, check whether the true best and the
        # true top-10 candidates are still retained.
        screening[f"{label}_n"] = int(keep_n)
        screening[f"best_in_{label}"] = bool(best_solver_idx in keep_idx)
        screening[f"{label}_top10_recall"] = float(len(solver_top10 & keep_idx) / max(top10_n, 1))
        screening[f"{label}_regret"] = float(np.min(kept_solver) - np.min(solver_obj))

    best_rank = int(surrogate_rank[best_solver_idx])
    # Derive the label from whichever keep fractions were requested instead of
    # hard-coding keep50/60/70 (which raised KeyError for other keep_fracs).
    risk_label = "unsafe"
    for position, (_, label) in enumerate(sorted(keep_labels)):
        keep_n = int(screening[f"{label}_n"])
        if position == 0:
            if best_rank < keep_n and float(screening[f"{label}_regret"]) <= 1e-8:
                risk_label = f"safe_{label}"
                break
        elif best_rank < keep_n:
            risk_label = f"borderline_{label}"
            break

    return {
        "n_valid": n,
        "pearson_objective": _corr_pearson(solver_obj, surrogate_obj),
        "spearman_objective": _corr_spearman(solver_obj, surrogate_obj),
        "best_solver_candidate_id": best_solver_idx,
        "best_solver_surrogate_rank": best_rank,
        "best_surrogate_candidate_id": best_surrogate_idx,
        "best_surrogate_solver_rank": int(solver_rank[best_surrogate_idx]),
        "solver_best_objective": float(solver_obj[best_solver_idx]),
        "solver_objective_at_surrogate_best": float(solver_obj[best_surrogate_idx]),
        "surrogate_top1_regret": float(solver_obj[best_surrogate_idx] - solver_obj[best_solver_idx]),
        "risk_label": risk_label,
        **screening,
    }
def write_csv(path: Path, rows: list[dict]) -> None:
    """Write ``rows`` to ``path`` as a UTF-8 CSV with BOM (``utf-8-sig``).

    The header is the union of all row keys in first-seen order; rows lacking a
    column get an empty cell (csv.DictWriter's default ``restval``).

    When ``rows`` is empty, no CSV is written and an existing file at ``path`` is
    removed. Otherwise a rerun into the same output directory would leave a stale
    file from an earlier run (e.g. an old failure list) next to fresh results.
    """
    if not rows:
        if path.is_file():
            path.unlink()
        return
    # dict preserves insertion order, giving an ordered key union in linear time.
    fieldnames = list(dict.fromkeys(key for row in rows for key in row))
    with path.open("w", newline="", encoding="utf-8-sig") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)
def _case_fields(q_case: dict) -> dict:
    """Copy a q_case record without its Schedule object (the rest goes into output rows)."""
    return {key: value for key, value in q_case.items() if key != "schedule"}


def _param_fields(params: "Params") -> dict:
    """Flatten the six physical parameters of ``params`` into a dict, in CSV column order."""
    return {
        "k": params.k,
        "skin": params.skin,
        "wellboreC": params.wellboreC,
        "phi": params.phi,
        "h": params.h,
        "Cf": params.Cf,
    }


def main() -> None:
    """Check how reliably the surrogate screens local candidates across flow-rate schedules.

    For every selected Q case:
    1. Run the solver at theta* to obtain the target curve.
    2. Sample local candidates around theta*.
    3. Score each candidate against the target with dual_log_objective, using both
       the solver curve and the surrogate curve.
    4. Summarize the ranking agreement with summarize_case.

    Writes q_case_summaries.csv, candidate_objectives.csv, failed_q_cases.csv,
    failed_candidates.csv and summary.json into the output directory.

    Raises:
        FileNotFoundError: the processed dataset or the model checkpoint is missing.
    """
    args = parse_args()
    cfg, processed_path, model_path, output_dir = resolve_paths(args)
    output_dir.mkdir(parents=True, exist_ok=True)
    if not processed_path.exists():
        raise FileNotFoundError(f"Processed dataset not found: {processed_path}")
    if not model_path.exists():
        raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
    # NOTE: joblib.load unpickles the file; only point --processed at trusted artifacts.
    processed = joblib.load(processed_path)
    curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))
    model, use_schedule, device = load_model(model_path)

    # Index cases in the full generated list *before* filtering. The per-case seed,
    # runner temp dirs and case_index then stay the same for a given case whichever
    # subset (--case-id-contains / --max-q-cases) is run, so batched runs reproduce.
    indexed_cases = list(enumerate(generate_q_cases()))
    if args.case_id_contains:
        needle = str(args.case_id_contains)
        indexed_cases = [(idx, case) for idx, case in indexed_cases if needle in str(case["case_id"])]
    if args.max_q_cases is not None:
        indexed_cases = indexed_cases[: max(0, int(args.max_q_cases))]
    q_cases = [case for _, case in indexed_cases]

    case_rows: list[dict] = []
    candidate_rows: list[dict] = []
    failed_case_rows: list[dict] = []
    failed_candidate_rows: list[dict] = []
    # Each q_case fixes one target schedule; only the physical parameters are
    # perturbed afterwards, to test whether the local ranking is robust.
    for case_idx, q_case in indexed_cases:
        schedule: Schedule = q_case["schedule"]
        if not schedule.validate():
            failed_case_rows.append({**_case_fields(q_case), "reason": "invalid schedule"})
            print(f"[Q {case_idx:03d}] {q_case['case_id']} [fail] invalid schedule")
            continue
        target_params = make_theta_params(args, schedule)
        print(
            f"[Q {case_idx:03d}] {q_case['case_id']} axis={q_case['axis']} "
            f"n_prod={q_case['n_prod_sections']} prodT={q_case['prod_total_time']:.3g} "
            f"shut={q_case['shutin_time']:.3g} q=[{q_case['q_min']:.3g},{q_case['q_max']:.3g}]"
        )
        target_runner = make_runner(cfg, output_dir, f"target_{case_idx:03d}")
        try:
            # The solver's target curve for this schedule is the reference that
            # every candidate curve is compared against.
            target_curve, _ = run_solver_and_extract_curve(
                runner=target_runner,
                cfg=cfg,
                params=target_params,
                well_index=int(args.well_index),
                timeout=int(args.solver_timeout),
            )
        except Exception as exc:
            failed_case_rows.append({**_case_fields(q_case), "reason": str(exc)})
            print(f" [fail] target solver failed: {exc}")
            continue
        finally:
            target_runner.close()

        # Over-sample so failed candidates can be replaced up to n_candidates.
        candidates = _sample_local_candidates(
            cfg=cfg,
            base=target_params,
            n=max(int(args.n_candidates), int(args.n_candidates) * int(args.max_candidate_attempts_factor)),
            seed=int(args.seed) + case_idx * 1009,
            span_frac=float(args.span_frac),
        )
        rows: list[dict] = []
        for attempt_id, cand in enumerate(candidates):
            if len(rows) >= int(args.n_candidates):
                break
            runner = make_runner(cfg, output_dir, f"case_{case_idx:03d}_cand_{attempt_id:04d}")
            try:
                # Score the same candidate with both the solver and the surrogate;
                # only their rankings are compared later, never mixed values.
                solver_curve, _ = run_solver_and_extract_curve(
                    runner=runner,
                    cfg=cfg,
                    params=cand,
                    well_index=int(args.well_index),
                    timeout=int(args.solver_timeout),
                )
                pred_curve = predict_surrogate_curve(
                    processed=processed,
                    model=model,
                    device=device,
                    use_schedule=use_schedule,
                    params=cand,
                    schedule=schedule,
                    cfg=cfg,
                )
                solver_obj = dual_log_objective(target_curve, solver_curve, curve_layout)
                surrogate_obj = dual_log_objective(target_curve, pred_curve, curve_layout)
                rows.append(
                    {
                        "case_id": q_case["case_id"],
                        "case_index": case_idx,
                        "candidate_id": len(rows),
                        "attempt_id": attempt_id,
                        **_param_fields(cand),
                        "solver_objective": solver_obj["dual_log_objective"],
                        "solver_p_obj": solver_obj["log_pressure_objective"],
                        "solver_d_obj": solver_obj["log_derivative_objective"],
                        "surrogate_objective": surrogate_obj["dual_log_objective"],
                        "surrogate_p_obj": surrogate_obj["log_pressure_objective"],
                        "surrogate_d_obj": surrogate_obj["log_derivative_objective"],
                    }
                )
            except Exception as exc:
                failed_candidate_rows.append(
                    {
                        "case_id": q_case["case_id"],
                        "case_index": case_idx,
                        "attempt_id": attempt_id,
                        "reason": str(exc),
                        **_param_fields(cand),
                    }
                )
            finally:
                runner.close()

        if len(rows) < 2:
            failed_case_rows.append({**_case_fields(q_case), "reason": "fewer than two valid candidates"})
            print(f" [fail] valid candidates={len(rows)}")
            continue

        case_summary = summarize_case(rows, KEEP_FRACS)
        # case_rows hold the per-schedule ranking verdict; candidate_rows keep the
        # per-candidate details.
        case_rows.append(
            {
                **_case_fields(q_case),
                "timeQ_json": json.dumps(q_case["timeQ"]),
                "q_json": json.dumps(q_case["q"]),
                **case_summary,
            }
        )
        candidate_rows.extend(rows)
        print(
            f" risk={case_summary['risk_label']} spearman={case_summary['spearman_objective']:.4f} "
            f"best_sur_rank={case_summary['best_solver_surrogate_rank']} "
            f"keep50={case_summary['best_in_keep50']} regret={case_summary['surrogate_top1_regret']:.6g}"
        )

    write_csv(output_dir / "q_case_summaries.csv", case_rows)
    write_csv(output_dir / "candidate_objectives.csv", candidate_rows)
    write_csv(output_dir / "failed_q_cases.csv", failed_case_rows)
    write_csv(output_dir / "failed_candidates.csv", failed_candidate_rows)

    risk_counts: dict[str, int] = {}
    for row in case_rows:
        risk = str(row["risk_label"])
        risk_counts[risk] = risk_counts.get(risk, 0) + 1
    summary = {
        "config_path": str(cfg.path),
        "processed_path": str(processed_path),
        "model_path": str(model_path),
        "theta_star": {
            "k": float(args.k),
            "skin": float(args.skin),
            "wellboreC": float(args.wellboreC),
            "phi": float(args.phi),
            "h": float(args.h),
            "Cf": float(args.Cf),
        },
        "n_q_cases_generated": int(len(q_cases)),
        "n_q_cases_valid": int(len(case_rows)),
        "n_q_cases_failed": int(len(failed_case_rows)),
        "n_candidates_requested_per_case": int(args.n_candidates),
        "span_frac": float(args.span_frac),
        "keep_fracs": KEEP_FRACS,
        "risk_counts": risk_counts,
    }
    with (output_dir / "summary.json").open("w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    print("\nQ sweep local ranking complete.")
    print(f"Output dir: {output_dir}")
    print(f"Valid Q cases={len(case_rows)}/{len(q_cases)}, failed={len(failed_case_rows)}")
    print(f"Risk counts={risk_counts}")


if __name__ == "__main__":
    main()