You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/src/data/params.py

294 lines
11 KiB
Python

# -*- coding: utf-8 -*-
"""物理参数、流量制度及二进制输入编码。
本模块定义数值试井一次正演所需的核心输入地层/井筒物理参数 Params以及分段
流量制度 Schedule除了数据结构本身还负责按照 C++ 求解器约定的字段顺序写入
params.bin因此这里的二进制布局必须与求解器端严格一致
参数采样函数支持 SobolLHS 和普通均匀采样并可通过 targeted_sampling 在困难
参数区间增加样本密度用于提升代理模型在自动拟合敏感区域的表现
"""
from __future__ import annotations
import os
import struct
from dataclasses import dataclass, asdict
from typing import Dict, List, Optional
import numpy as np
from src.common.config import Config
@dataclass
class Schedule:
"""一次数值试井模拟使用的分段流量制度。
Attributes:
sectionIndex: 当前样本对应的观测/拟合分段编号采用从 1 开始的约定
当同一条流量制度按多个分段展开训练样本时该字段用于告诉求解器和模型
当前监督曲线对应哪个阶段
timeQ: 每个流动段或关井段的持续时间长度必须与 q 一致单位由 C++ 求解器
侧配置决定
q: 每个分段对应的流量约定非负接近 0 的末段通常表示关井段
"""
sectionIndex: int
timeQ: List[float]
q: List[float]
def validate(self) -> bool:
"""检查流量制度的分段时长、流量数量和 sectionIndex 是否合法。"""
if self.timeQ is None or self.q is None:
return False
if len(self.timeQ) < 2 or len(self.q) != len(self.timeQ):
return False
if int(self.sectionIndex) < 1:
return False
t = np.asarray(self.timeQ, dtype=np.float64)
q = np.asarray(self.q, dtype=np.float64)
if not (np.isfinite(t).all() and np.isfinite(q).all()):
return False
if np.any(t <= 0.0):
return False
if np.any(q < 0.0):
return False
return True
def clipped(self, max_points: int) -> "Schedule":
"""把流量制度裁剪到求解器允许的时长和流量范围内。"""
return Schedule(
sectionIndex=max(1, int(self.sectionIndex)),
timeQ=list(map(float, self.timeQ[:max_points])),
q=list(map(float, self.q[:max_points])),
)
@dataclass
class Params:
"""描述单个油藏/井筒物理参数样本,并提供写入 C++ 求解器二进制输入的能力。"""
k: float
skin: float
wellboreC: float
phi: float
h: float
Cf: float
schedule: Optional[Schedule] = None
def to_dict(self) -> Dict[str, float]:
"""把 Params 对象转换为普通字典,便于写入 JSON、CSV 或日志。"""
return asdict(self)
def to_bin_bytes(self, cfg: Config, include_schedule: Optional[bool] = None) -> bytes:
"""按 C++ 求解器约定把参数和可选流量制度编码为 params.bin 内容。
二进制头部固定为 PRM1 magic + version随后写入 6 double 物理参数
kskinwellboreCphihCf include_schedule=True则继续写入
sectionIndex分段数量 nQtimeQ 数组和 q 数组
这里的字段顺序字节序和数据类型必须与 C++ runner 完全一致否则求解器会
读取到错误参数表现为模拟失败或曲线异常
"""
if include_schedule is None:
include_schedule = bool(cfg.get("schedule", "write_schedule_to_params_bin", default=False))
magic = ord("P") | (ord("R") << 8) | (ord("M") << 16) | (ord("1") << 24)
version = 1
b = struct.pack(
"<II6d",
int(magic), int(version),
float(self.k), float(self.skin), float(self.wellboreC),
float(self.phi), float(self.h), float(self.Cf),
)
if include_schedule and self.schedule is not None:
sch = self.schedule.clipped(int(cfg.get("schedule", "max_points", default=512)))
if not sch.validate():
raise ValueError("Invalid schedule extension")
nQ = len(sch.timeQ)
b += struct.pack("<II", int(sch.sectionIndex) & 0xFFFFFFFF, nQ & 0xFFFFFFFF)
b += struct.pack("<" + "d" * nQ, *map(float, sch.timeQ))
b += struct.pack("<" + "d" * nQ, *map(float, sch.q))
return b
def write_params_bin(path: str, p: Params, cfg: Config, include_schedule: Optional[bool] = None) -> None:
"""把 Params 的二进制表示写入求解器输入文件。"""
dirpath = os.path.dirname(path)
if dirpath:
os.makedirs(dirpath, exist_ok=True)
with open(path, "wb") as f:
f.write(p.to_bin_bytes(cfg=cfg, include_schedule=include_schedule))
def _clip(cfg: Config, name: str, v: float) -> float:
"""将标量裁剪到给定上下界。"""
lo, hi = cfg.raw["params"]["ranges"][name]
return float(np.clip(float(v), lo, hi))
def _safe_log10(x: float, eps: float = 1e-30) -> float:
"""对正数取 log10并对非正数使用下限保护避免数值错误。"""
return float(np.log10(max(float(x), eps)))
def _sample_from_components(
cfg: Config,
name: str,
u: float,
default_log_scale: bool,
components: list[dict],
) -> float:
"""从 targeted_sampling 的多个区间组件中抽取一个参数值。
全局 Sobol/LHS 样本 u 先用于选择混合组件再被线性映射为该组件内部的 local_u
这样可以在保持低差异采样顺序稳定的同时把更多样本放到配置指定的困难区间
每个组件可独立指定线性尺度或 log 尺度最终结果仍会被裁剪到全局参数范围内
"""
comps = []
probs = []
for comp in components:
if "range" not in comp:
continue
lo, hi = comp["range"]
lo = float(lo)
hi = float(hi)
if hi <= lo:
continue
comps.append(
{
"lo": lo,
"hi": hi,
"scale": str(comp.get("scale", "log" if default_log_scale else "linear")).lower(),
}
)
probs.append(float(comp.get("prob", 1.0)))
if not comps:
# 未配置 targeted components 时退回到参数全局范围采样。
lo, hi = cfg.raw["params"]["ranges"][name]
if default_log_scale:
lo = max(lo, 1e-30)
return _clip(cfg, name, 10 ** (_safe_log10(lo) + u * (_safe_log10(hi) - _safe_log10(lo))))
return _clip(cfg, name, lo + u * (hi - lo))
probs_arr = np.asarray(probs, dtype=np.float64)
probs_arr = np.maximum(probs_arr, 0.0)
if float(np.sum(probs_arr)) <= 0:
probs_arr[:] = 1.0
probs_arr = probs_arr / float(np.sum(probs_arr))
cdf = np.cumsum(probs_arr)
comp_idx = int(np.searchsorted(cdf, float(u), side="right"))
comp_idx = min(max(comp_idx, 0), len(comps) - 1)
cdf_lo = 0.0 if comp_idx == 0 else float(cdf[comp_idx - 1])
cdf_hi = float(cdf[comp_idx])
local_u = 0.0 if cdf_hi <= cdf_lo else (float(u) - cdf_lo) / max(cdf_hi - cdf_lo, 1e-12)
local_u = float(np.clip(local_u, 0.0, 1.0))
# 同一个全局 u 先决定组件,再映射到该组件内部的局部 [0, 1] 坐标。
lo = float(comps[comp_idx]["lo"])
hi = float(comps[comp_idx]["hi"])
scale = str(comps[comp_idx]["scale"])
if scale == "log":
lo = max(lo, 1e-30)
v = 10 ** (_safe_log10(lo) + local_u * (_safe_log10(hi) - _safe_log10(lo)))
else:
v = lo + local_u * (hi - lo)
return _clip(cfg, name, v)
def _sample_param_value(
cfg: Config,
name: str,
u: float,
default_log_scale: bool,
) -> float:
"""根据单个参数的配置生成一个采样值。"""
targeted = cfg.raw["params"].get("targeted_sampling", {}) or {}
if bool(targeted.get("enabled", False)):
strategies = targeted.get("strategies", {}) or {}
strategy = strategies.get(name, {}) or {}
components = strategy.get("components", []) or []
if components:
# targeted_sampling 用于在困难参数区间增加样本密度。
return _sample_from_components(cfg, name, u, default_log_scale, components)
lo, hi = cfg.raw["params"]["ranges"][name]
if default_log_scale:
lo = max(lo, 1e-30)
v = 10 ** (_safe_log10(lo) + u * (_safe_log10(hi) - _safe_log10(lo)))
else:
v = lo + u * (hi - lo)
return _clip(cfg, name, v)
def _qmc_unit(n: int, d: int, method: str, seed: int | None) -> np.ndarray:
"""生成准蒙特卡洛或伪随机的单位区间样本矩阵。"""
method = (method or "sobol").lower()
rng = np.random.RandomState(seed)
try:
from scipy.stats import qmc
if method == "sobol":
sampler = qmc.Sobol(d=d, scramble=True, seed=seed)
m = int(np.ceil(np.log2(max(n, 1))))
U = sampler.random_base2(m=m)
return U[:n]
if method == "lhs":
sampler = qmc.LatinHypercube(d=d, seed=seed)
return sampler.random(n=n)
if method == "uniform":
return rng.rand(n, d)
except Exception:
return rng.rand(n, d)
raise ValueError(f"Unknown sampling method: {method}")
def _build_full_param_dict(cfg: Config, sampled_vals: Dict[str, float]) -> Dict[str, float]:
"""把采样参数与默认参数合并成完整物理参数字典。"""
out: Dict[str, float] = {}
fixed_cfg = cfg.raw["params"].get("fixed_params", {}) or {}
for name in cfg.raw["params"]["all_physical_param_names"]:
fixed = fixed_cfg.get(name, {}) or {}
if bool(fixed.get("enabled", False)):
out[name] = _clip(cfg, name, float(fixed["value"]))
else:
out[name] = _clip(cfg, name, sampled_vals[name])
return out
def generate_params_dataset(cfg: Config, n_samples: int, method: str | None = None, random_seed: int | None = None) -> list[Params]:
"""批量生成物理参数样本。
采样只直接作用于 active_param_names未激活的参数由 fixed_params 或全局范围补齐
从而保证写给 C++ 求解器的 Params 始终包含完整 6 个字段函数会对 active 参数
组合做简单去重避免 Sobol 截断或固定参数导致重复样本进入数据集
"""
method = (method or cfg.raw["params"].get("sampling_method", "sobol")).lower()
active_names = list(cfg.raw["params"]["active_param_names"])
log_params = set(cfg.raw["params"]["log_params"])
U = _qmc_unit(n_samples, len(active_names), method, random_seed)
out: list[Params] = []
seen = set()
for row in U:
sampled_vals: Dict[str, float] = {}
for i, name in enumerate(active_names):
u = float(row[i])
sampled_vals[name] = _sample_param_value(
cfg=cfg,
name=name,
u=u,
default_log_scale=(name in log_params),
)
# active_names 之外的参数由 fixed_params 或配置范围补齐,保证求解器字段完整。
p = Params(**_build_full_param_dict(cfg, sampled_vals))
key = tuple(round(float(getattr(p, name)), 10) for name in active_names)
if key not in seen:
seen.add(key)
out.append(p)
if len(out) >= n_samples:
break
return out