|
|
"""生成面向自动拟合/PSO 局部排序的数据集。
|
|
|
|
|
|
脚本先采样 anchor 参数与流量制度,再围绕每个 anchor 在变换参数空间中构造多尺度邻域,
|
|
|
调用 C++ 数值求解器得到真实曲线和自动拟合目标函数,并按目标值分层保留候选样本。
|
|
|
输出的 HDF5 同时包含曲线、参数、制度编码和排序训练所需的邻域元数据。
|
|
|
"""
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import argparse
|
|
|
import json
|
|
|
import sys
|
|
|
from collections import Counter
|
|
|
from pathlib import Path
|
|
|
|
|
|
ROOT = Path(__file__).resolve().parents[1]
|
|
|
sys.path.append(str(ROOT))
|
|
|
|
|
|
import h5py
|
|
|
import numpy as np
|
|
|
|
|
|
from src.common.config import Config
|
|
|
from src.common.experiment_paths import config_for_stage, normalize_tag
|
|
|
from src.data.curve_processing import clean_curve_for_dataset, is_valid_curve, resample_curve_to_features
|
|
|
from src.data.params import Params, Schedule, generate_params_dataset
|
|
|
from src.data.runner_client import CppRunner, read_result_bin
|
|
|
from src.data.schedule_encoding import encode_schedule_to_timegrid
|
|
|
from src.evaluation.autofit_objective import dual_log_objective
|
|
|
|
|
|
|
|
|
# Column names of the per-schedule metadata vector. The order matches the
# vector assembled by build_schedule_metadata and the width of the
# "anchor_schedule_meta" dataset.
SCHEDULE_META_NAMES = [
    # identity / layout
    "family_id", "section_index", "n_sections", "n_prod",
    # durations
    "prod_total_time", "shutin_dt",
    # production-rate statistics
    "q_first_prod", "q_last_prod", "q_mean_prod",
    "q_std_prod", "q_min_prod", "q_max_prod",
]
|
|
|
|
|
|
|
|
|
def _pick_mixture(rng: np.random.RandomState, items: list[dict]) -> dict:
    """Draw one entry of ``items`` using each entry's ``"prob"`` value as weight.

    A missing ``"prob"`` counts as 0. When the weights sum to zero or less,
    an entry is drawn uniformly with ``rng.randint``; otherwise a single
    ``rng.rand()`` draw is compared against the running cumulative weight.

    Args:
        rng: Random state supplying the draw.
        items: Candidate components (dicts), optionally carrying ``"prob"``.

    Returns:
        The selected dict (the same object that is stored in ``items``).
    """
    weights = np.asarray(
        [float(entry.get("prob", 0.0)) for entry in items],
        dtype=np.float64,
    )
    weight_sum = float(np.sum(weights))
    if weight_sum <= 0:
        # No usable weights: fall back to a uniform choice.
        return items[int(rng.randint(0, len(items)))]

    normalized = weights / weight_sum
    draw = float(rng.rand())
    cumulative = 0.0
    for position, share in enumerate(normalized):
        cumulative += float(share)
        if draw <= cumulative:
            return items[position]
    # Rounding can leave the cumulative sum marginally below the draw.
    return items[-1]
|
|
|
|
|
|
|
|
|
def _normalize_durations_to_total(dt: np.ndarray, total: float, min_dt: float) -> np.ndarray:
    """Rescale section durations so they sum to ``total`` with a per-section floor.

    Every duration is first raised to ``min_dt``. ``total`` is raised to
    ``min_dt * len(dt)`` when it is too small to honour the floor. Sections
    that would fall below ``min_dt`` after scaling are pinned at the floor and
    the remaining budget is shared proportionally among the other sections, so
    the result sums to ``total`` even when the floor is active. (Previously the
    whole excess was pushed onto the last section and clamped again, so the
    returned durations could sum to noticeably more than ``total``.)

    Args:
        dt: Raw section durations (any array-like of numbers).
        total: Requested sum of the returned durations.
        min_dt: Minimum allowed duration of a single section.

    Returns:
        A new float64 array; the input is not modified. An empty input yields
        an empty array.
    """
    floor = float(min_dt)
    durations = np.maximum(np.asarray(dt, dtype=np.float64), floor)
    n = int(durations.size)
    if n == 0:
        return durations

    total = max(float(total), floor * float(n))
    if float(np.sum(durations)) <= 0:
        return np.full_like(durations, total / n)

    # Pin sections that drop below the floor and rescale the rest until the
    # set of pinned sections stops growing (at most n rounds).
    pinned = np.zeros(n, dtype=bool)
    scaled = durations
    for _ in range(n):
        free = ~pinned
        if not free.any():
            break
        free_sum = float(np.sum(durations[free]))
        if free_sum <= 0:
            break
        budget = total - floor * float(np.count_nonzero(pinned))
        scaled = np.where(free, durations * (budget / free_sum), floor)
        below = free & (scaled < floor)
        if not below.any():
            break
        pinned |= below

    scaled = np.maximum(scaled, floor)
    # Absorb the remaining floating-point residual in the last section.
    scaled[-1] += total - float(np.sum(scaled))
    return np.maximum(scaled, floor)
|
|
|
|
|
|
|
|
|
def _family_id_map(cfg: Config) -> dict[str, int]:
    """Return the mapping from lower-cased schedule family name to integer id.

    In ``family_random`` mode the id is the position of the family in the
    configured ``families`` list (entries without a name get
    ``family_<index>``). In every other mode a fixed mapping is returned.
    ``"unknown"`` maps to -1 unless a configured family already uses that name.
    """
    schedule_cfg = cfg.raw["schedule"]
    if str(schedule_cfg["generation_mode"]).lower() != "family_random":
        return {"fixed_case": 0, "case_neighborhood": 1, "unknown": -1}

    ids: dict[str, int] = {}
    for position, family in enumerate(schedule_cfg["family_random"]["families"]):
        key = str(family.get("name", f"family_{position}")).lower()
        ids[key] = position
    ids.setdefault("unknown", -1)
    return ids
|
|
|
|
|
|
|
|
|
def build_schedule_metadata(
    cfg: Config,
    timeQ: list[float],
    q: list[float],
    family_name: str,
    section_index: int,
) -> tuple[np.ndarray, str]:
    """Summarise a rate schedule as a fixed-length float32 feature vector.

    The last section is treated as a shut-in when its rate is at or below
    ``schedule.canonicalize_for_model.q_thr`` (default 1e-6); all sections
    before it count as production sections. The vector contains, in order:
    family id, section index, number of sections, number of production
    sections, total production time, shut-in duration, and the first / last /
    mean / std / min / max production rate. Statistics are 0 when there is no
    production section.

    Returns:
        ``(meta_vec, family_name)`` where ``family_name`` is lower-cased.
    """
    durations = np.asarray(timeQ, dtype=np.float64).reshape(-1)
    rates = np.asarray(q, dtype=np.float64).reshape(-1)
    n_sections = int(len(durations))

    canon_cfg = cfg.raw["schedule"].get("canonicalize_for_model", {}) or {}
    q_thr = float(canon_cfg.get("q_thr", 1e-6))

    has_shutin = bool(n_sections > 0 and rates[-1] <= q_thr)
    n_prod = int(max(n_sections - 1, 0) if has_shutin else n_sections)
    shutin_dt = float(durations[-1]) if has_shutin else 0.0

    if n_prod > 0:
        prod_dt = durations[:n_prod]
        prod_q = rates[:n_prod]
        prod_features = [
            float(np.sum(prod_dt)),
            float(shutin_dt),
            float(prod_q[0]),
            float(prod_q[-1]),
            float(np.mean(prod_q)),
            float(np.std(prod_q)),
            float(np.min(prod_q)),
            float(np.max(prod_q)),
        ]
    else:
        # No production section: only the shut-in duration carries information.
        prod_features = [0.0, float(shutin_dt), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

    family_name = str(family_name).lower()
    family_id = int(_family_id_map(cfg).get(family_name, -1))

    header = [float(family_id), float(section_index), float(n_sections), float(n_prod)]
    meta_vec = np.asarray(header + prod_features, dtype=np.float32)
    return meta_vec, family_name
|
|
|
|
|
|
|
|
|
def _sample_schedule_fixed_case(cfg: Config) -> tuple[list[float], list[float], dict]:
    """Return the schedule stored under ``schedule.case_schedule`` unchanged.

    Returns:
        ``(timeQ, q, info)`` with both lists converted to floats and
        ``info == {"family_name": "fixed_case"}``.
    """
    case = cfg.raw["schedule"]["case_schedule"]
    durations = [float(value) for value in case["timeQ"]]
    rates = [float(value) for value in case["q"]]
    return durations, rates, {"family_name": "fixed_case"}
|
|
|
|
|
|
|
|
|
def _sample_schedule_case_neighborhood(cfg: Config, rng: np.random.RandomState) -> tuple[list[float], list[float], dict]:
    """Jitter the configured case schedule to obtain a nearby schedule.

    All sections except the last are treated as production sections. Their
    durations and rates receive independent relative uniform noise; the last
    section's duration gets its own jitter range. Durations are floored at
    ``min_dt`` and production rates are clipped to ``[q_min, q_max]``.
    Settings come from ``schedule.case_neighborhood``.

    Returns:
        ``(timeQ, q, info)`` with ``info == {"family_name": "case_neighborhood"}``.
    """
    schedule_cfg = cfg.raw["schedule"]
    base = schedule_cfg["case_schedule"]
    jitter = schedule_cfg["case_neighborhood"]

    base_dt = np.asarray(base["timeQ"], dtype=np.float64)
    base_rate = np.asarray(base["q"], dtype=np.float64)
    new_dt = base_dt.copy()
    new_rate = base_rate.copy()
    n_prod = len(base_dt) - 1

    min_dt = float(jitter["min_dt"])

    # The draw order (production durations, shut-in duration, production
    # rates) is kept so that a given seed reproduces the same schedule.
    dt_rel = float(jitter["dt_jitter_rel"])
    dt_noise = rng.uniform(-dt_rel, dt_rel, size=n_prod)
    new_dt[:n_prod] = np.maximum(base_dt[:n_prod] * (1.0 + dt_noise), min_dt)

    tail_rel = float(jitter["shutin_dt_jitter_rel"])
    tail_noise = rng.uniform(-tail_rel, tail_rel)
    new_dt[-1] = max(base_dt[-1] * (1.0 + tail_noise), min_dt)

    q_rel = float(jitter["q_jitter_rel"])
    q_noise = rng.uniform(-q_rel, q_rel, size=n_prod)
    new_rate[:n_prod] = np.clip(
        base_rate[:n_prod] * (1.0 + q_noise),
        float(jitter["q_min"]),
        float(jitter["q_max"]),
    )

    if n_prod > 1 and bool(jitter["keep_monotonic_prod"]):
        # Running maximum makes the production rates non-decreasing.
        new_rate[:n_prod] = np.maximum.accumulate(new_rate[:n_prod])
    if bool(jitter["keep_last_q_zero"]):
        new_rate[-1] = 0.0

    return new_dt.tolist(), new_rate.tolist(), {"family_name": "case_neighborhood"}
|
|
|
|
|
|
|
|
|
def _sample_schedule_family_random(
    cfg: Config,
    rng: np.random.RandomState,
    family_override: str | None = None,
) -> tuple[list[float], list[float], dict]:
    """Generate a random production + shut-in schedule for one schedule family.

    The family is ``family_override`` when given, otherwise it is drawn from
    ``schedule.family_random.families`` by weight. Production durations are
    log-normal draws rescaled to a random total time; production rates follow
    a family-specific pattern (increasing, decreasing, single step, or a mild
    random walk for any other name), optionally multiplied by log-normal noise
    and clipped to ``q_range``. A zero-rate shut-in section is appended.

    Returns:
        ``(timeQ, q, info)`` with ``info == {"family_name": <family>}``.
    """
    family_cfg = cfg.raw["schedule"]["family_random"]
    if family_override is not None:
        family = str(family_override).lower()
    else:
        picked = _pick_mixture(rng, family_cfg["families"])
        family = str(picked.get("name", "inc_tail_shutin")).lower()

    # The sequence of rng draws below determines the output for a given seed;
    # keep it unchanged when editing.
    n_lo, n_hi = family_cfg["n_prod_sections_range"]
    n_prod = int(rng.randint(int(n_lo), int(n_hi) + 1))

    total_lo, total_hi = family_cfg["prod_total_time_range"]
    prod_total = float(total_lo + rng.rand() * (total_hi - total_lo))

    durations = rng.lognormal(
        mean=float(family_cfg["duration_lognormal_mu"]),
        sigma=float(family_cfg["duration_lognormal_sigma"]),
        size=n_prod,
    ).astype(np.float64)
    # The 0.05 minimum section duration is hard-coded here.
    durations = _normalize_durations_to_total(durations, total=prod_total, min_dt=0.05)

    q_lo, q_hi = family_cfg["q_range"]
    q_start = float(q_lo + rng.rand() * (q_hi - q_lo))
    max_rel_step = float(family_cfg["max_rel_step"])
    noise_sigma = float(family_cfg["mult_noise_sigma"])
    jump_lo, jump_hi = family_cfg["step_jump_rel_range"]

    rates = np.zeros((n_prod,), dtype=np.float64)
    rates[0] = q_start

    if family == "inc_tail_shutin":
        # Each step multiplies the previous rate by a factor >= 1.02.
        for step in range(1, n_prod):
            rates[step] = rates[step - 1] * rng.uniform(1.02, max_rel_step)
        rates = np.maximum.accumulate(rates)
    elif family == "dec_tail_shutin":
        # Each step multiplies the previous rate by a factor <= 0.98, floored at q_lo.
        for step in range(1, n_prod):
            rates[step] = max(rates[step - 1] * rng.uniform(1.0 / max_rel_step, 0.98), q_lo)
        rates = np.minimum.accumulate(rates)
    elif family == "mild_step_tail_shutin":
        # Constant rate with a single relative jump from a random section on.
        rates[:] = q_start
        jump_at = int(rng.randint(1, max(2, n_prod)))
        rates[jump_at:] *= float(rng.uniform(jump_lo, jump_hi))
    else:
        # Any other family: random walk with the relative step damped to 15%.
        for step in range(1, n_prod):
            ratio = rng.uniform(1.0 / max_rel_step, max_rel_step)
            ratio = 1.0 + 0.15 * (ratio - 1.0)
            rates[step] = rates[step - 1] * ratio

    if noise_sigma > 0:
        rates *= np.exp(rng.normal(loc=0.0, scale=noise_sigma, size=n_prod))
    rates = np.clip(rates, q_lo, q_hi)

    shut_lo, shut_hi = family_cfg["shutin_dt_range"]
    shutin_dt = float(shut_lo + rng.rand() * (shut_hi - shut_lo))
    return durations.tolist() + [shutin_dt], rates.tolist() + [0.0], {"family_name": family}
|
|
|
|
|
|
|
|
|
def sample_schedule_by_mode(
    cfg: Config,
    rng: np.random.RandomState,
    family_override: str | None = None,
) -> tuple[list[float], list[float], dict]:
    """Sample a schedule with the sampler selected by ``schedule.generation_mode``.

    ``family_override`` is only forwarded in ``family_random`` mode.

    Raises:
        ValueError: If the configured mode is not one of ``fixed_case``,
            ``case_neighborhood`` or ``family_random``.
    """
    mode = str(cfg.raw["schedule"]["generation_mode"]).lower()
    # Lazily evaluated dispatch table: only the selected sampler runs.
    samplers = {
        "fixed_case": lambda: _sample_schedule_fixed_case(cfg),
        "case_neighborhood": lambda: _sample_schedule_case_neighborhood(cfg, rng),
        "family_random": lambda: _sample_schedule_family_random(cfg, rng, family_override=family_override),
    }
    sampler = samplers.get(mode)
    if sampler is None:
        raise ValueError(f"Unknown schedule generation_mode: {mode}")
    return sampler()
|
|
|
|
|
|
|
|
|
def _resolve_section_indices(cfg: Config, timeQ, q, rng: np.random.RandomState) -> list[int]:
    """Return the 1-based section indices allowed by ``schedule.section_policy``.

    Modes: ``fixed_last`` (the last section), ``fixed_value`` (configured value
    clipped into ``[1, n]``), ``all_sections`` (every index ``1..n``) and
    ``uniform_one`` (one random index drawn from ``rng``). ``q`` is accepted
    but not used.

    Raises:
        ValueError: For any other policy mode.
    """
    policy = cfg.raw["schedule"]["section_policy"]
    mode = str(policy["mode"]).lower()
    n_sections = int(len(timeQ))

    if mode == "fixed_last":
        return [n_sections]
    if mode == "fixed_value":
        requested = int(policy["fixed_value"])
        return [int(np.clip(requested, 1, n_sections))]
    if mode == "all_sections":
        return [*range(1, n_sections + 1)]
    if mode == "uniform_one":
        drawn = rng.randint(1, n_sections + 1)
        return [int(drawn)]
    raise ValueError(f"Unknown section_policy.mode: {mode}")
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the neighborhood dataset generator.

    Covers the config / stage selection, output location, anchor and
    neighbor counts, retry factors, random seed, perturbation span(s),
    family balancing, objective binning and solver options.

    ``--balance-families`` stays enabled by default. It used to be a
    ``store_true`` flag whose default was already ``True``, so balancing could
    never be switched off; ``--no-balance-families`` now disables it.

    Returns:
        The namespace parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Generate anchor-neighborhood autofit dataset for generalized local ranking")
    parser.add_argument("--config", type=str, default=None)
    parser.add_argument(
        "--stage",
        choices=["fixed_case", "case_neighborhood", "family_random", "family_random_hard"],
        default="family_random",
    )
    parser.add_argument("--output", type=str, default=None, help="Output HDF5 path")
    parser.add_argument("--tag", type=str, default="family_random_autofit_neighborhood")
    parser.add_argument("--n-anchors", type=int, default=64)
    parser.add_argument("--neighbors-per-anchor", type=int, default=24)
    parser.add_argument("--max-attempts-factor", type=int, default=4)
    parser.add_argument("--anchor-max-attempts-factor", type=int, default=5)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--span-frac", type=float, default=0.08)
    parser.add_argument(
        "--span-fracs",
        type=str,
        default=None,
        help="Optional comma-separated span list, e.g. 0.02,0.05,0.10; overrides --span-frac",
    )
    parser.add_argument("--max-perturbed-dims", type=int, default=3)
    # Paired on/off flags sharing one destination; the default remains True.
    parser.add_argument(
        "--balance-families",
        dest="balance_families",
        action="store_true",
        default=True,
        help="Split anchors evenly across schedule families (default)",
    )
    parser.add_argument(
        "--no-balance-families",
        dest="balance_families",
        action="store_false",
        help="Draw the schedule family of each anchor from the configured weights instead",
    )
    parser.add_argument("--objective-bins", type=int, default=3)
    parser.add_argument("--solver-timeout", type=int, default=120)
    parser.add_argument("--well-index", type=int, default=0)
    parser.add_argument(
        "--use-runner-server",
        action="store_true",
        help="Use runner --server mode; faster but less robust for long neighborhood generation",
    )
    return parser.parse_args()
|
|
|
|
|
|
|
|
|
def resolve_span_fracs(args: argparse.Namespace) -> list[float]:
    """Return the sorted, de-duplicated list of positive perturbation spans.

    ``args.span_fracs`` (comma-separated, empty items ignored) takes
    precedence over ``args.span_frac``. Values are rounded to 10 decimals
    before de-duplication and non-positive values are dropped.

    Raises:
        ValueError: If no positive span remains (or an item is not a number).
    """
    if args.span_fracs is None:
        raw_values = [float(args.span_frac)]
    else:
        tokens = (token.strip() for token in str(args.span_fracs).split(","))
        raw_values = [float(token) for token in tokens if token]

    unique_positive = {round(float(value), 10) for value in raw_values if float(value) > 0.0}
    if not unique_positive:
        raise ValueError("At least one positive span fraction is required")
    return [float(value) for value in sorted(unique_positive)]
|
|
|
|
|
|
|
|
|
def resolve_config(args: argparse.Namespace) -> Config:
    """Load the configuration selected on the command line.

    ``args.config`` wins when given. Otherwise the path comes from
    ``config_for_stage(args.stage)``, falling back to
    ``configs/data_gen_family_random.yaml`` when that returns a falsy value.
    """
    if args.config is not None:
        return Config(args.config)
    stage_config = config_for_stage(args.stage) or Path("configs/data_gen_family_random.yaml")
    return Config(str(stage_config))
|
|
|
|
|
|
|
|
|
def resolve_output_path(cfg: Config, args: argparse.Namespace) -> Path:
    """Return the absolute path of the HDF5 file to write.

    An explicit ``args.output`` is used as-is. Otherwise the file is named
    after the normalised ``args.tag`` (``"autofit_neighborhood"`` when the
    normalised tag is empty) inside ``cfg.paths.samples_dir``.
    """
    explicit = args.output
    if explicit is None:
        tag = normalize_tag(args.tag) or "autofit_neighborhood"
        default_path = cfg.paths.samples_dir / f"{tag}.h5"
        return default_path.resolve()
    return Path(explicit).resolve()
|
|
|
|
|
|
|
|
|
def build_schedule_vector(cfg: Config, schedule: Schedule) -> np.ndarray:
    """Encode a schedule as one flat float32 feature vector.

    Calls ``encode_schedule_to_timegrid`` and concatenates its flattened
    ``x_sched`` output followed by its flattened ``x_sec`` output.
    """
    encoded = encode_schedule_to_timegrid(
        cfg,
        sectionIndex=int(schedule.sectionIndex),
        timeQ=schedule.timeQ,
        q=schedule.q,
        n_sections=len(schedule.timeQ),
    )
    grid_part = np.asarray(encoded.x_sched, dtype=np.float32).reshape(-1)
    section_part = np.asarray(encoded.x_sec, dtype=np.float32).reshape(-1)
    combined = np.concatenate([grid_part, section_part], axis=0)
    return combined.astype(np.float32)
|
|
|
|
|
|
|
|
|
def run_solver_and_extract_curve(
    runner: CppRunner,
    cfg: Config,
    params: Params,
    well_index: int,
    timeout: int,
) -> tuple[np.ndarray, dict]:
    """Run one solver simulation and turn its log-log output into curve features.

    The result file is read whenever it exists, even if the runner reported
    failure. The selected well's ``t`` / ``p`` / ``deriv`` series are cleaned,
    validated and resampled with the project curve helpers.

    Returns:
        ``(curve_feat, raw)`` where ``curve_feat`` is the output of
        ``resample_curve_to_features`` and ``raw`` holds the cleaned series as
        lists plus the solver's ``nSteps`` / ``nWells`` counts.

    Raises:
        RuntimeError: With a short reason code when the solver produced no
            usable result, the well index is out of range, or the curve fails
            cleaning or validation.
    """
    solver_ok = runner.run_simulation(
        params,
        timeout=timeout,
        override_schedule=params.schedule,
        include_schedule=True,
    )
    # NOTE(review): a result file left over from an earlier run would be
    # accepted here even when this run failed; confirm the runner clears it.
    result = None
    if runner.result_bin.exists():
        result = read_result_bin(runner.result_bin)

    if result is None:
        raise RuntimeError("solver_no_loglog" if solver_ok else "solver_failed_no_result")
    if not result["loglog"]:
        raise RuntimeError("solver_no_loglog")
    if not 0 <= well_index < len(result["loglog"]):
        raise RuntimeError(f"well_index_out_of_range_{well_index}")

    well_curve = result["loglog"][well_index]
    t_raw = np.asarray(well_curve["t"], dtype=np.float64)
    p_raw = np.asarray(well_curve["p"], dtype=np.float64)
    d_raw = np.asarray(well_curve["deriv"], dtype=np.float64)

    cleaned = clean_curve_for_dataset(cfg, t_raw, p_raw, d_raw)
    if cleaned is None:
        raise RuntimeError("curve_clean_failed")
    t_clean, p_clean, d_clean = cleaned

    valid, reason = is_valid_curve(cfg, t_clean, p_clean, d_clean)
    if not valid:
        raise RuntimeError(f"curve_invalid_{reason}")

    curve_feat = resample_curve_to_features(cfg, t_clean, p_clean, d_clean)
    raw = {
        "t": t_clean.tolist(),
        "p": p_clean.tolist(),
        "d": d_clean.tolist(),
        "n_steps": int(result["nSteps"]),
        "n_wells": int(result["nWells"]),
    }
    return curve_feat, raw
|
|
|
|
|
|
|
|
|
def params_to_array(params: Params) -> np.ndarray:
    """Pack the physical parameters into a float32 array.

    The order is fixed: ``k, skin, wellboreC, phi, h, Cf``.
    """
    ordered_values = (
        params.k,
        params.skin,
        params.wellboreC,
        params.phi,
        params.h,
        params.Cf,
    )
    return np.asarray(ordered_values, dtype=np.float32)
|
|
|
|
|
|
|
|
|
def sample_anchor_params_and_schedule(
    cfg: Config,
    rng: np.random.RandomState,
    family_override: str | None = None,
) -> tuple[Params, np.ndarray, str]:
    """Sample the physical parameters and the schedule of one anchor.

    Parameters come from ``generate_params_dataset`` (one Sobol sample seeded
    from ``rng``); the schedule from ``sample_schedule_by_mode``. One of the
    section indices allowed by the section policy is chosen at random and the
    resulting ``Schedule`` is attached to the returned parameters.

    Returns:
        ``(params, schedule_meta, family_name)`` with ``schedule_meta`` as
        float32.
    """
    # rng is consumed in this order: parameter seed, schedule, section
    # policy, section pick.
    param_seed = int(rng.randint(0, 2**31 - 1))
    params = generate_params_dataset(cfg, n_samples=1, method="sobol", random_seed=param_seed)[0]

    timeQ, q, sched_info = sample_schedule_by_mode(cfg, rng, family_override=family_override)
    allowed_sections = _resolve_section_indices(cfg, timeQ, q, rng)
    pick = int(rng.randint(0, len(allowed_sections)))
    schedule = Schedule(
        sectionIndex=int(allowed_sections[pick]),
        timeQ=[float(value) for value in timeQ],
        q=[float(value) for value in q],
    )
    params.schedule = schedule

    meta_vec, family_name = build_schedule_metadata(
        cfg=cfg,
        timeQ=schedule.timeQ,
        q=schedule.q,
        family_name=str(sched_info.get("family_name", "unknown")),
        section_index=int(schedule.sectionIndex),
    )
    return params, meta_vec.astype(np.float32), str(family_name)
|
|
|
|
|
|
|
|
|
def build_family_plan(cfg: Config, n_anchors: int, balance_families: bool) -> list[str | None]:
    """Return the schedule family to force for each anchor slot.

    Outside ``family_random`` mode, when balancing is disabled, or when no
    named family is configured, every slot is ``None`` (no override).
    Otherwise anchors are split as evenly as possible across the named
    families in configuration order, with earlier families receiving the
    remainder.
    """
    unconstrained: list[str | None] = [None] * int(n_anchors)
    schedule_cfg = cfg.raw["schedule"]
    if str(schedule_cfg["generation_mode"]).lower() != "family_random":
        return unconstrained

    lowered = (str(item.get("name", "")).lower() for item in schedule_cfg["family_random"]["families"])
    families = [name for name in lowered if name]
    if not (balance_families and families):
        return unconstrained

    share, leftover = divmod(int(n_anchors), len(families))
    return [
        family
        for position, family in enumerate(families)
        for _ in range(share + (1 if position < leftover else 0))
    ]
|
|
|
|
|
|
|
|
|
def _fixed_param_value(cfg: Config, name: str) -> float | None:
    """Return the fixed value configured for parameter ``name``, if any.

    Reads ``params.fixed_params[name]`` from the configuration and returns its
    ``"value"`` as a float when ``"enabled"`` is truthy; otherwise ``None``.
    """
    fixed_params = cfg.raw["params"].get("fixed_params", {}) or {}
    entry = fixed_params.get(name, {}) or {}
    if not bool(entry.get("enabled", False)):
        return None
    return float(entry["value"])
|
|
|
|
|
|
|
|
|
def search_param_names(cfg: Config) -> list[str]:
    """Return the names of the parameters that the local search may perturb.

    These are the configured ``active_param_names`` (default: all physical
    parameter names) minus every parameter with an enabled fixed value.

    Raises:
        ValueError: If no searchable parameter remains.
    """
    params_cfg = cfg.raw["params"]
    active = list(params_cfg.get("active_param_names", params_cfg["all_physical_param_names"]))
    searchable = []
    for name in active:
        if _fixed_param_value(cfg, name) is None:
            searchable.append(name)
    if searchable:
        return searchable
    raise ValueError("No searchable parameters: active_param_names is empty after removing fixed_params")
|
|
|
|
|
|
|
|
|
def sample_neighbor_params(
    cfg: Config,
    base: Params,
    rng: np.random.RandomState,
    span_frac: float,
    max_perturbed_dims: int,
) -> Params:
    """Sample one neighbor of ``base`` by perturbing a random subset of parameters.

    Between 1 and ``max_perturbed_dims`` searchable parameters are chosen at
    random. Each chosen parameter is redrawn uniformly around its anchor value
    within ``span_frac`` (times a per-parameter scale) of its configured
    range, in log10 space for names listed in ``params.log_params``. Results
    are clipped to the range shrunk by a 1% margin on each side. Parameters
    with an enabled fixed value take that value; all others keep the anchor
    value. The anchor's schedule is reused unchanged.

    Returns:
        A new ``Params`` carrying ``base.schedule``.
    """
    params_cfg = cfg.raw["params"]
    all_names = list(params_cfg["all_physical_param_names"])
    searchable = search_param_names(cfg)
    log_params = set(params_cfg["log_params"])
    anchor_values = {
        "k": float(base.k),
        "skin": float(base.skin),
        "wellboreC": float(base.wellboreC),
        "phi": float(base.phi),
        "h": float(base.h),
        "Cf": float(base.Cf),
    }

    # rng order: number of perturbed dims, which dims, then one uniform draw
    # per perturbed parameter in all_names order.
    n_perturbed = int(rng.randint(1, min(max_perturbed_dims, len(searchable)) + 1))
    perturbed = set(rng.choice(searchable, size=n_perturbed, replace=False).tolist())

    # Per-parameter multipliers applied to span_frac.
    span_scale = {
        "k": 1.0,
        "skin": 1.0,
        "wellboreC": 0.75,
        "phi": 0.60,
        "h": 0.75,
        "Cf": 0.40,
    }

    candidate = dict(anchor_values)
    for name in all_names:
        fixed_value = _fixed_param_value(cfg, name)
        if fixed_value is not None:
            candidate[name] = fixed_value
            continue

        lo, hi = map(float, params_cfg["ranges"][name])
        anchor_val = float(anchor_values[name])
        if name not in perturbed:
            candidate[name] = anchor_val
            continue

        width = float(span_frac) * float(span_scale.get(name, 1.0))
        if name in log_params:
            # Parameters spanning orders of magnitude are perturbed in log10
            # space so that relative changes are spread more evenly.
            lo_t = np.log10(max(lo, 1e-30))
            hi_t = np.log10(max(hi, 1e-30))
            margin_t = 0.01 * (hi_t - lo_t)
            inner_lo_t = lo_t + margin_t
            inner_hi_t = hi_t - margin_t
            if inner_hi_t <= inner_lo_t:
                inner_lo_t, inner_hi_t = lo_t, hi_t
            center_t = np.clip(np.log10(max(anchor_val, 1e-30)), inner_lo_t, inner_hi_t)
            half_span = width * (hi_t - lo_t)
            drawn_t = np.clip(rng.uniform(center_t - half_span, center_t + half_span), inner_lo_t, inner_hi_t)
            candidate[name] = float(10 ** drawn_t)
        else:
            # Linear parameters are perturbed directly; the 1% margin keeps
            # samples from piling up on the range boundaries.
            margin = 0.01 * (hi - lo)
            inner_lo = lo + margin
            inner_hi = hi - margin
            if inner_hi <= inner_lo:
                inner_lo, inner_hi = lo, hi
            half_span = width * (inner_hi - inner_lo)
            drawn = rng.uniform(anchor_val - half_span, anchor_val + half_span)
            candidate[name] = float(np.clip(drawn, inner_lo, inner_hi))

    return Params(
        k=candidate["k"],
        skin=candidate["skin"],
        wellboreC=candidate["wellboreC"],
        phi=candidate["phi"],
        h=candidate["h"],
        Cf=candidate["Cf"],
        schedule=base.schedule,
    )
|
|
|
|
|
|
|
|
|
def select_objective_stratified_indices(objectives: np.ndarray, k: int, objective_bins: int) -> list[int]:
    """Pick up to ``k`` indices spread across objective-value strata.

    When there are at most ``k`` objectives, all indices are returned in their
    original order. Otherwise indices are sorted by objective (stable), split
    into ``objective_bins`` contiguous chunks (at least one) and taken
    round-robin: the best of each chunk, then the second best of each, and so
    on, until ``k`` indices are collected.
    """
    objectives = np.asarray(objectives, dtype=np.float64).reshape(-1)
    if objectives.size <= k:
        return list(range(int(objectives.size)))

    order = np.argsort(objectives, kind="stable")
    chunks = np.array_split(order, max(int(objective_bins), 1))
    deepest = max(len(chunk) for chunk in chunks)

    # Walk the chunks level by level so that easy / medium / hard candidates
    # are all represented in the selection.
    picked: list[int] = []
    depth = 0
    while len(picked) < k and depth < deepest:
        for chunk in chunks:
            if len(picked) >= k:
                break
            if depth < len(chunk):
                picked.append(int(chunk[depth]))
        depth += 1
    return picked[:k]
|
|
|
|
|
|
|
|
|
def select_multiscale_rows(
    rows: list[tuple[np.ndarray, np.ndarray, dict[str, float], float]],
    k: int,
    objective_bins: int,
    span_fracs: list[float],
) -> list[tuple[np.ndarray, np.ndarray, dict[str, float], float]]:
    """Keep up to ``k`` candidate rows, balanced across perturbation spans.

    Each row is read as ``(…, …, objective dict, span_frac)``: item 2 supplies
    ``"dual_log_objective"`` and item 3 the span the row was sampled with.
    With a single span, or no more than ``k`` rows, selection is purely
    objective-stratified. Otherwise every span first gets an equal share of
    ``k`` (earlier spans receive the remainder), and any shortfall is topped
    up from the rows that were not chosen.
    """

    def objectives_at(indices) -> np.ndarray:
        """Collect the dual-log objective of the rows at ``indices``."""
        return np.asarray(
            [float(rows[i][2]["dual_log_objective"]) for i in indices],
            dtype=np.float64,
        )

    if len(rows) <= k or len(span_fracs) <= 1:
        keep = select_objective_stratified_indices(
            objectives_at(range(len(rows))),
            k=k,
            objective_bins=objective_bins,
        )
        return [rows[i] for i in keep]

    picked: list[int] = []
    leftovers: list[int] = []

    # Give every span a base quota first so the final set does not come from
    # a single span_frac.
    quota, extra = divmod(k, len(span_fracs))
    for span_pos, span in enumerate(span_fracs):
        target = quota + (1 if span_pos < extra else 0)
        members = [i for i, row in enumerate(rows) if abs(float(row[3]) - float(span)) <= 1e-10]
        if not members:
            continue

        chosen = select_objective_stratified_indices(
            objectives_at(members),
            k=min(target, len(members)),
            objective_bins=objective_bins,
        )
        chosen_positions = {int(pos) for pos in chosen}
        picked.extend(members[pos] for pos in chosen)
        leftovers.extend(member for pos, member in enumerate(members) if pos not in chosen_positions)

    if len(picked) < k and leftovers:
        top_up = select_objective_stratified_indices(
            objectives_at(leftovers),
            k=min(k - len(picked), len(leftovers)),
            objective_bins=objective_bins,
        )
        picked.extend(leftovers[pos] for pos in top_up)

    return [rows[i] for i in picked[:k]]
|
|
|
|
|
|
|
|
|
def build_span_targets(total_keep: int, span_fracs: list[float]) -> dict[float, int]:
    """Split ``total_keep`` as evenly as possible across the spans.

    Earlier spans receive the remainder. An empty ``span_fracs`` yields an
    empty dict.
    """
    if not span_fracs:
        return {}
    share, leftover = divmod(int(total_keep), len(span_fracs))
    return {
        float(span): share + (1 if position < leftover else 0)
        for position, span in enumerate(span_fracs)
    }
|
|
|
|
|
|
|
|
|
def create_output_file(
    output_path: Path,
    param_dim: int,
    schedule_dim: int,
    curve_dim: int,
    span_fracs: list[float],
    search_names: list[str],
) -> h5py.File:
    """Create the output HDF5 file with empty, resizable datasets.

    The parent directory is created if needed and an existing file is
    overwritten. File attributes record the parameter names, schedule
    metadata names, spans and searchable parameter names, plus
    ``n_anchors`` / ``n_neighbors`` counters initialised to 0.

    Returns:
        The open ``h5py.File``; the caller is responsible for closing it.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    f = h5py.File(output_path, "w")
    f.attrs["param_names"] = np.asarray(["k", "skin", "wellboreC", "phi", "h", "Cf"], dtype="S")
    f.attrs["schedule_meta_names"] = np.asarray(SCHEDULE_META_NAMES, dtype="S")
    f.attrs["span_fracs"] = np.asarray(span_fracs, dtype=np.float32)
    f.attrs["search_param_names"] = np.asarray(search_names, dtype="S")

    utf8 = h5py.string_dtype(encoding="utf-8")
    meta_dim = len(SCHEDULE_META_NAMES)

    # anchor_* hold the centre samples; neighbor_* hold the perturbed
    # candidates around a centre together with their objective values.
    # Each spec is (name, row width or None for 1-D, dtype, rows per chunk).
    dataset_specs = [
        ("anchor_params", param_dim, np.float32, 64),
        ("anchor_schedule", schedule_dim, np.float32, 64),
        ("anchor_curve", curve_dim, np.float32, 64),
        ("anchor_schedule_meta", meta_dim, np.float32, 64),
        ("anchor_family_name", None, utf8, 64),
        ("anchor_section_index", None, np.int32, 64),
        ("anchor_timeQ_json", None, utf8, 64),
        ("anchor_q_json", None, utf8, 64),
        ("neighbor_anchor_id", None, np.int32, 256),
        ("neighbor_params", param_dim, np.float32, 256),
        ("neighbor_curve", curve_dim, np.float32, 256),
        ("neighbor_objective", None, np.float32, 256),
        ("neighbor_objective_p", None, np.float32, 256),
        ("neighbor_objective_d", None, np.float32, 256),
        ("neighbor_span_frac", None, np.float32, 256),
    ]
    for name, width, dtype, chunk_rows in dataset_specs:
        trailing = () if width is None else (width,)
        f.create_dataset(
            name,
            shape=(0,) + trailing,
            maxshape=(None,) + trailing,
            dtype=dtype,
            chunks=(chunk_rows,) + trailing,
        )

    f.attrs["n_anchors"] = 0
    f.attrs["n_neighbors"] = 0
    return f
|
|
|
|
|
|
|
|
|
def append_anchor(
    f: h5py.File,
    anchor_params: np.ndarray,
    schedule_vec: np.ndarray,
    curve: np.ndarray,
    schedule_meta: np.ndarray,
    family_name: str,
    section_index: int,
    timeQ: list[float],
    q: list[float],
) -> int:
    """Append one anchor sample to the ``anchor_*`` datasets.

    Every anchor dataset grows by one row, the schedule lists are stored as
    JSON strings, and the ``n_anchors`` file attribute is incremented.

    Returns:
        The row index (anchor id) the sample was written to.
    """
    row = int(f.attrs["n_anchors"])
    new_len = row + 1

    # Prepare all values first so a bad input fails before any dataset grows.
    matrix_rows = {
        "anchor_params": anchor_params.reshape(1, -1),
        "anchor_schedule": schedule_vec.reshape(1, -1),
        "anchor_curve": curve.reshape(1, -1),
        "anchor_schedule_meta": schedule_meta.reshape(1, -1),
    }
    scalar_values = {
        "anchor_family_name": family_name,
        "anchor_section_index": int(section_index),
        "anchor_timeQ_json": json.dumps([float(value) for value in timeQ]),
        "anchor_q_json": json.dumps([float(value) for value in q]),
    }

    for name, row_values in matrix_rows.items():
        dataset = f[name]
        dataset.resize((new_len, dataset.shape[1]))
        dataset[row:new_len] = row_values

    for name, scalar in scalar_values.items():
        dataset = f[name]
        dataset.resize((new_len,))
        dataset[row] = scalar

    f.attrs["n_anchors"] = new_len
    return row
|
|
|
|
|
|
|
|
|
def append_neighbors(
    f: h5py.File,
    anchor_id: int,
    params_list: list[np.ndarray],
    curve_list: list[np.ndarray],
    obj_list: list[dict[str, float]],
    span_frac_list: list[float],
) -> None:
    """Append a batch of neighbor candidates of one anchor to the ``neighbor_*`` datasets.

    All arrays are built and length-checked before any dataset is resized, so
    a malformed batch cannot leave the file partially extended. An empty batch
    is a no-op; previously it crashed inside ``np.stack``.

    Args:
        f: Open output file created by ``create_output_file``.
        anchor_id: Row index of the anchor these candidates belong to.
        params_list: One parameter vector per candidate.
        curve_list: One curve feature vector per candidate.
        obj_list: One dict per candidate with the keys
            ``dual_log_objective``, ``log_pressure_objective`` and
            ``log_derivative_objective``.
        span_frac_list: Perturbation span used for each candidate.

    Raises:
        ValueError: If the four lists do not have the same length.
    """
    batch = len(params_list)
    if batch == 0:
        return
    if not (len(curve_list) == len(obj_list) == len(span_frac_list) == batch):
        raise ValueError(
            "append_neighbors: params_list, curve_list, obj_list and span_frac_list must have the same length"
        )

    start = int(f.attrs["n_neighbors"])
    end = start + batch

    matrix_columns = {
        "neighbor_params": np.stack(params_list, axis=0).astype(np.float32),
        "neighbor_curve": np.stack(curve_list, axis=0).astype(np.float32),
    }
    vector_columns = {
        "neighbor_objective": np.asarray([x["dual_log_objective"] for x in obj_list], dtype=np.float32),
        "neighbor_objective_p": np.asarray([x["log_pressure_objective"] for x in obj_list], dtype=np.float32),
        "neighbor_objective_d": np.asarray([x["log_derivative_objective"] for x in obj_list], dtype=np.float32),
        "neighbor_span_frac": np.asarray(span_frac_list, dtype=np.float32),
    }

    def resize_1d(name: str):
        """Grow a 1-D resizable dataset to the new row count and return it."""
        ds = f[name]
        ds.resize((end,))
        return ds

    def resize_2d(name: str):
        """Grow a 2-D resizable dataset to the new row count and return it."""
        ds = f[name]
        ds.resize((end, ds.shape[1]))
        return ds

    resize_1d("neighbor_anchor_id")[start:end] = int(anchor_id)
    for name, values in matrix_columns.items():
        resize_2d(name)[start:end] = values
    for name, values in vector_columns.items():
        resize_1d(name)[start:end] = values

    f.attrs["n_neighbors"] = end
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Generate the autofit neighborhood dataset and its JSON run summary.

    Steps, in order:

    1. Parse CLI arguments, resolve the config and output path, and create the
       HDF5 output file.
    2. Until the requested number of anchors is stored (or the attempt budget
       is used up), sample anchor parameters and a schedule, solve the anchor,
       and append it to the file. Anchors whose solve raises are counted in
       ``anchor_fail_reasons`` and skipped.
    3. For each stored anchor, sample perturbed candidates while cycling
       through ``span_fracs``, solve each with the anchor's runner, and score
       it against the anchor curve with ``dual_log_objective``. Candidates
       that raise are counted in ``neighbor_fail_reasons``.
    4. Keep the rows chosen by ``select_multiscale_rows`` and append them.
    5. Write the collected statistics to the output path with its suffix
       replaced by ``.summary.json``.

    The per-anchor runner is closed in a ``finally`` clause so it is released
    on every exit path, including unexpected exceptions and interrupts.
    """
    args = parse_args()
    cfg = resolve_config(args)
    output_path = resolve_output_path(cfg, args)
    rng = np.random.RandomState(int(args.seed))
    span_fracs = resolve_span_fracs(args)

    schedule_dim = int(np.prod(cfg.schedule_grid_shape)) + int(cfg.sec_feat_dim)
    curve_dim = int(cfg.curve_dim)
    param_dim = len(cfg.raw["params"]["all_physical_param_names"])
    local_search_names = search_param_names(cfg)

    # Slices of the curve vector handed to the objective: three consecutive
    # ranges split at one third and two thirds of curve_dim. Built once since
    # it does not change between candidates.
    objective_layout = {"parts": [
        {"name": "log_pressure", "start": 0, "end": curve_dim // 3},
        {"name": "log_derivative", "start": curve_dim // 3, "end": 2 * curve_dim // 3},
        {"name": "slope", "start": 2 * curve_dim // 3, "end": curve_dim},
    ]}

    f = create_output_file(
        output_path,
        param_dim=param_dim,
        schedule_dim=schedule_dim,
        curve_dim=curve_dim,
        span_fracs=span_fracs,
        search_names=local_search_names,
    )
    summary = {
        "config_path": str(cfg.path),
        "output_path": str(output_path),
        "n_anchor_requested": int(args.n_anchors),
        "neighbors_per_anchor_requested": int(args.neighbors_per_anchor),
        "span_fracs": span_fracs,
        "seed": int(args.seed),
        "well_index": int(args.well_index),
        "solver_timeout": int(args.solver_timeout),
        "use_runner_server": bool(args.use_runner_server),
        "max_perturbed_dims": int(args.max_perturbed_dims),
        "search_param_names": local_search_names,
        "objective_bins": int(args.objective_bins),
        "anchor_fail_reasons": Counter(),
        "neighbor_fail_reasons": Counter(),
        "family_counter_attempted": Counter(),
        "family_counter_kept": Counter(),
        "neighbor_span_counter_valid": Counter(),
        "neighbor_span_counter_kept": Counter(),
    }

    try:
        anchor_attempts = 0
        max_anchor_attempts = max(int(args.n_anchors), int(args.n_anchors) * int(args.anchor_max_attempts_factor))
        family_plan = build_family_plan(cfg, int(args.n_anchors), balance_families=bool(args.balance_families))
        while int(f.attrs["n_anchors"]) < int(args.n_anchors) and anchor_attempts < max_anchor_attempts:
            anchor_idx = anchor_attempts
            anchor_attempts += 1
            # The plan is indexed by the number of anchors stored so far, so a
            # failed attempt retries the same planned family.
            target_family = family_plan[int(f.attrs["n_anchors"])] if family_plan else None
            anchor_params, schedule_meta, family_name = sample_anchor_params_and_schedule(
                cfg, rng, family_override=target_family
            )
            summary["family_counter_attempted"][family_name] += 1

            anchor_runner = CppRunner(
                cfg=cfg,
                auto_init=False,
                temp_dir=(output_path.parent / f"{output_path.stem}_temp" / f"anchor_{anchor_idx:04d}").resolve(),
                use_server=bool(args.use_runner_server),
            )
            try:
                try:
                    anchor_curve, _ = run_solver_and_extract_curve(
                        runner=anchor_runner,
                        cfg=cfg,
                        params=anchor_params,
                        well_index=int(args.well_index),
                        timeout=int(args.solver_timeout),
                    )
                except Exception as exc:
                    # Best effort: record the failure and try another anchor.
                    summary["anchor_fail_reasons"][str(exc)] += 1
                    print(f"[warn] anchor {anchor_idx} skipped: {exc}")
                    continue

                schedule_vec = build_schedule_vector(cfg, anchor_params.schedule)
                anchor_id = append_anchor(
                    f=f,
                    anchor_params=params_to_array(anchor_params),
                    schedule_vec=schedule_vec,
                    curve=anchor_curve,
                    schedule_meta=schedule_meta,
                    family_name=family_name,
                    section_index=int(anchor_params.schedule.sectionIndex),
                    timeQ=anchor_params.schedule.timeQ,
                    q=anchor_params.schedule.q,
                )
                summary["family_counter_kept"][family_name] += 1
                print(
                    f"[anchor {anchor_id:03d}] start family={family_name} "
                    f"target_neighbors={args.neighbors_per_anchor} spans={[float(x) for x in span_fracs]} "
                    f"use_server={bool(args.use_runner_server)}"
                )

                valid_neighbor_rows: list[tuple[np.ndarray, np.ndarray, dict[str, float], float]] = []
                valid_span_counter: Counter[str] = Counter()
                max_attempts = max(
                    int(args.neighbors_per_anchor),
                    int(args.neighbors_per_anchor) * int(args.max_attempts_factor),
                ) * len(span_fracs)
                span_targets = build_span_targets(int(args.neighbors_per_anchor), span_fracs)
                # Collect up to objective_bins times the per-span target before
                # stopping early, so the later selection has a surplus to pick from.
                span_stop_targets = {
                    float(span_frac): max(int(target), int(target) * max(int(args.objective_bins), 1))
                    for span_frac, target in span_targets.items()
                }

                for attempt in range(max_attempts):
                    # Round-robin over the spans so each gets an equal share of attempts.
                    span_frac = float(span_fracs[attempt % len(span_fracs)])
                    cand = sample_neighbor_params(
                        cfg,
                        anchor_params,
                        rng,
                        span_frac=span_frac,
                        max_perturbed_dims=int(args.max_perturbed_dims),
                    )
                    try:
                        cand_curve, _ = run_solver_and_extract_curve(
                            runner=anchor_runner,
                            cfg=cfg,
                            params=cand,
                            well_index=int(args.well_index),
                            timeout=int(args.solver_timeout),
                        )
                        obj = dual_log_objective(anchor_curve, cand_curve, objective_layout)
                        valid_neighbor_rows.append(
                            (params_to_array(cand), cand_curve.astype(np.float32), obj, span_frac)
                        )
                        span_key = f"{span_frac:g}"
                        summary["neighbor_span_counter_valid"][span_key] += 1
                        valid_span_counter[span_key] += 1
                    except Exception as exc:
                        # Best effort: a failed candidate is only counted, never fatal.
                        summary["neighbor_fail_reasons"][str(exc)] += 1

                    enough_per_span = all(
                        valid_span_counter.get(f"{float(span_frac):g}", 0) >= int(span_stop_targets[float(span_frac)])
                        for span_frac in span_fracs
                    )
                    if enough_per_span:
                        print(
                            f"[anchor {anchor_id:03d}] early-stop attempts={attempt + 1}/{max_attempts} "
                            f"valid={len(valid_neighbor_rows)} per_span={dict(sorted(valid_span_counter.items()))}"
                        )
                        break

                    if (attempt + 1) % max(12, len(span_fracs) * 4) == 0:
                        print(
                            f"[anchor {anchor_id:03d}] progress attempts={attempt + 1}/{max_attempts} "
                            f"valid={len(valid_neighbor_rows)} per_span={dict(sorted(valid_span_counter.items()))}"
                        )

                selected_rows = select_multiscale_rows(
                    valid_neighbor_rows,
                    k=int(args.neighbors_per_anchor),
                    objective_bins=int(args.objective_bins),
                    span_fracs=span_fracs,
                )

                neighbor_params_rows = [row[0] for row in selected_rows]
                neighbor_curve_rows = [row[1] for row in selected_rows]
                neighbor_obj_rows = [row[2] for row in selected_rows]
                neighbor_span_rows = [float(row[3]) for row in selected_rows]
                for span_frac in neighbor_span_rows:
                    summary["neighbor_span_counter_kept"][f"{span_frac:g}"] += 1

                if neighbor_params_rows:
                    append_neighbors(
                        f=f,
                        anchor_id=anchor_id,
                        params_list=neighbor_params_rows,
                        curve_list=neighbor_curve_rows,
                        obj_list=neighbor_obj_rows,
                        span_frac_list=neighbor_span_rows,
                    )
                    f.flush()
                    print(
                        f"[anchor {anchor_id:03d}] family={family_name} "
                        f"neighbors={len(neighbor_params_rows)}/{args.neighbors_per_anchor} "
                        f"spans={dict(sorted(Counter(f'{x:g}' for x in neighbor_span_rows).items()))}"
                    )
                else:
                    print(f"[warn] anchor {anchor_id:03d} produced no valid neighbors")
            finally:
                # Release the runner on every path: skip, success, or error.
                anchor_runner.close()

        summary["n_anchor_kept"] = int(f.attrs["n_anchors"])
        summary["n_neighbor_kept"] = int(f.attrs["n_neighbors"])
        summary["n_anchor_attempts"] = int(anchor_attempts)
        # Counters are converted to plain dicts before the summary is dumped.
        for counter_key in (
            "anchor_fail_reasons",
            "neighbor_fail_reasons",
            "family_counter_attempted",
            "family_counter_kept",
            "neighbor_span_counter_valid",
            "neighbor_span_counter_kept",
        ):
            summary[counter_key] = dict(summary[counter_key])
    finally:
        f.close()

    summary_path = output_path.with_suffix(".summary.json")
    with open(summary_path, "w", encoding="utf-8") as fp:
        json.dump(summary, fp, ensure_ascii=False, indent=2)

    print("\nAutofit neighborhood dataset generation complete.")
    print(f"Output HDF5: {output_path}")
    print(f"Summary JSON: {summary_path}")
    print(f"Anchors kept={summary['n_anchor_kept']}, neighbors kept={summary['n_neighbor_kept']}")
|
|
|
|
|
|
|
|
|
# Run the generator only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|