from __future__ import annotations import os import struct from dataclasses import dataclass, asdict from typing import Dict, List, Optional import numpy as np from src.common.config import Config @dataclass class Schedule: sectionIndex: int timeQ: List[float] q: List[float] def validate(self) -> bool: if self.timeQ is None or self.q is None: return False if len(self.timeQ) < 2 or len(self.q) != len(self.timeQ): return False if int(self.sectionIndex) < 1: return False t = np.asarray(self.timeQ, dtype=np.float64) q = np.asarray(self.q, dtype=np.float64) if not (np.isfinite(t).all() and np.isfinite(q).all()): return False if np.any(t <= 0.0): return False if np.any(q < 0.0): return False return True def clipped(self, max_points: int) -> "Schedule": return Schedule( sectionIndex=max(1, int(self.sectionIndex)), timeQ=list(map(float, self.timeQ[:max_points])), q=list(map(float, self.q[:max_points])), ) @dataclass class Params: k: float skin: float wellboreC: float phi: float h: float Cf: float schedule: Optional[Schedule] = None def to_dict(self) -> Dict[str, float]: return asdict(self) def to_bin_bytes(self, cfg: Config, include_schedule: Optional[bool] = None) -> bytes: if include_schedule is None: include_schedule = bool(cfg.get("schedule", "write_schedule_to_params_bin", default=False)) magic = ord("P") | (ord("R") << 8) | (ord("M") << 16) | (ord("1") << 24) version = 1 b = struct.pack( " None: dirpath = os.path.dirname(path) if dirpath: os.makedirs(dirpath, exist_ok=True) with open(path, "wb") as f: f.write(p.to_bin_bytes(cfg=cfg, include_schedule=include_schedule)) def _clip(cfg: Config, name: str, v: float) -> float: lo, hi = cfg.raw["params"]["ranges"][name] return float(np.clip(float(v), lo, hi)) def _safe_log10(x: float, eps: float = 1e-30) -> float: return float(np.log10(max(float(x), eps))) def _sample_from_components( cfg: Config, name: str, u: float, default_log_scale: bool, components: list[dict], ) -> float: comps = [] probs = [] for comp in components: 
if "range" not in comp: continue lo, hi = comp["range"] lo = float(lo) hi = float(hi) if hi <= lo: continue comps.append( { "lo": lo, "hi": hi, "scale": str(comp.get("scale", "log" if default_log_scale else "linear")).lower(), } ) probs.append(float(comp.get("prob", 1.0))) if not comps: lo, hi = cfg.raw["params"]["ranges"][name] if default_log_scale: lo = max(lo, 1e-30) return _clip(cfg, name, 10 ** (_safe_log10(lo) + u * (_safe_log10(hi) - _safe_log10(lo)))) return _clip(cfg, name, lo + u * (hi - lo)) probs_arr = np.asarray(probs, dtype=np.float64) probs_arr = np.maximum(probs_arr, 0.0) if float(np.sum(probs_arr)) <= 0: probs_arr[:] = 1.0 probs_arr = probs_arr / float(np.sum(probs_arr)) cdf = np.cumsum(probs_arr) comp_idx = int(np.searchsorted(cdf, float(u), side="right")) comp_idx = min(max(comp_idx, 0), len(comps) - 1) cdf_lo = 0.0 if comp_idx == 0 else float(cdf[comp_idx - 1]) cdf_hi = float(cdf[comp_idx]) local_u = 0.0 if cdf_hi <= cdf_lo else (float(u) - cdf_lo) / max(cdf_hi - cdf_lo, 1e-12) local_u = float(np.clip(local_u, 0.0, 1.0)) lo = float(comps[comp_idx]["lo"]) hi = float(comps[comp_idx]["hi"]) scale = str(comps[comp_idx]["scale"]) if scale == "log": lo = max(lo, 1e-30) v = 10 ** (_safe_log10(lo) + local_u * (_safe_log10(hi) - _safe_log10(lo))) else: v = lo + local_u * (hi - lo) return _clip(cfg, name, v) def _sample_param_value( cfg: Config, name: str, u: float, default_log_scale: bool, ) -> float: targeted = cfg.raw["params"].get("targeted_sampling", {}) or {} if bool(targeted.get("enabled", False)): strategies = targeted.get("strategies", {}) or {} strategy = strategies.get(name, {}) or {} components = strategy.get("components", []) or [] if components: return _sample_from_components(cfg, name, u, default_log_scale, components) lo, hi = cfg.raw["params"]["ranges"][name] if default_log_scale: lo = max(lo, 1e-30) v = 10 ** (_safe_log10(lo) + u * (_safe_log10(hi) - _safe_log10(lo))) else: v = lo + u * (hi - lo) return _clip(cfg, name, v) def 
def _qmc_unit(n: int, d: int, method: str, seed: int | None) -> np.ndarray:
    """Draw ``n`` points from the ``d``-dimensional unit hypercube.

    Args:
        n: Number of points (rows) to return.
        d: Number of dimensions (columns).
        method: ``"sobol"``, ``"lhs"`` or ``"uniform"`` (case-insensitive).
            A falsy value selects ``"sobol"``.
        seed: Seed passed both to the SciPy sampler and to the NumPy
            ``RandomState`` used for ``"uniform"`` and for the fallback.

    Returns:
        Array of shape ``(n, d)``.

    Raises:
        ValueError: If ``method`` is not one of the supported names. This is
            checked before SciPy is touched, so the outcome does not depend
            on whether SciPy is importable.

    If the SciPy sampler cannot be imported or fails, a ``RuntimeWarning`` is
    emitted and pseudo-random uniform points are returned instead, so the
    degradation is visible rather than silent.
    """
    import warnings  # function scope: only needed on the fallback path

    method = (method or "sobol").lower()
    if method not in ("sobol", "lhs", "uniform"):
        raise ValueError(f"Unknown sampling method: {method}")

    rng = np.random.RandomState(seed)
    if method == "uniform":
        # Plain pseudo-random sampling needs nothing from SciPy.
        return rng.rand(n, d)

    try:
        from scipy.stats import qmc

        if method == "sobol":
            sampler = qmc.Sobol(d=d, scramble=True, seed=seed)
            # random_base2 produces 2**m points: request the smallest power
            # of two that covers n, then keep only the first n rows.
            m = int(np.ceil(np.log2(max(n, 1))))
            return sampler.random_base2(m=m)[:n]
        return qmc.LatinHypercube(d=d, seed=seed).random(n=n)
    except Exception as exc:  # deliberate best-effort fallback, made visible
        warnings.warn(
            f"QMC method {method!r} unavailable ({exc!r}); "
            "falling back to pseudo-random uniform sampling",
            RuntimeWarning,
            stacklevel=2,
        )
        return rng.rand(n, d)


def _build_full_param_dict(cfg: Config, sampled_vals: Dict[str, float]) -> Dict[str, float]:
    """Combine fixed and sampled values into one clipped value per parameter.

    For every name in ``params.all_physical_param_names`` the value comes
    from ``params.fixed_params[name]["value"]`` when that entry has a truthy
    ``enabled`` flag, otherwise from ``sampled_vals[name]``. Either way the
    value is passed through ``_clip``.

    Args:
        cfg: Configuration object whose ``raw["params"]`` mapping is read.
        sampled_vals: Sampled value for each parameter that is not fixed.

    Returns:
        Mapping from parameter name to its clipped value.

    Raises:
        KeyError: If a parameter is not fixed and is missing from
            ``sampled_vals``, or an enabled fixed entry has no ``"value"``.
    """
    params_cfg = cfg.raw["params"]
    fixed_cfg = params_cfg.get("fixed_params", {}) or {}

    out: Dict[str, float] = {}
    for name in params_cfg["all_physical_param_names"]:
        fixed = fixed_cfg.get(name, {}) or {}
        if bool(fixed.get("enabled", False)):
            raw_value = float(fixed["value"])
        else:
            raw_value = sampled_vals[name]
        out[name] = _clip(cfg, name, raw_value)
    return out


def generate_params_dataset(
    cfg: Config,
    n_samples: int,
    method: str | None = None,
    random_seed: int | None = None,
) -> list[Params]:
    """Sample up to ``n_samples`` distinct ``Params`` objects.

    One unit-cube point is drawn per sample with ``_qmc_unit``. Each
    coordinate is mapped to a value for the corresponding entry of
    ``params.active_param_names`` via ``_sample_param_value``, and the
    result is completed by ``_build_full_param_dict``.

    Args:
        cfg: Configuration object whose ``raw["params"]`` mapping is read.
        n_samples: Number of unit-cube points to draw.
        method: Sampling method name. When falsy, ``params.sampling_method``
            from the config is used, and ``"sobol"`` if that is also unset
            or null.
        random_seed: Seed forwarded to ``_qmc_unit``.

    Returns:
        The generated ``Params`` in draw order. Samples whose active values
        coincide after rounding to 10 decimals are dropped, so the list may
        hold fewer than ``n_samples`` entries.
    """
    params_cfg = cfg.raw["params"]
    # Trailing `or "sobol"` also covers an explicit null in the config,
    # which would otherwise fail on `.lower()`.
    method = (method or params_cfg.get("sampling_method") or "sobol").lower()
    active_names = list(params_cfg["active_param_names"])
    log_params = set(params_cfg["log_params"])

    unit_points = _qmc_unit(n_samples, len(active_names), method, random_seed)

    out: list[Params] = []
    seen = set()
    for row in unit_points:
        sampled_vals: Dict[str, float] = {
            name: _sample_param_value(
                cfg=cfg,
                name=name,
                u=float(u),
                default_log_scale=(name in log_params),
            )
            for name, u in zip(active_names, row)
        }
        p = Params(**_build_full_param_dict(cfg, sampled_vals))
        # Deduplicate on the active parameters only, after rounding.
        key = tuple(round(float(getattr(p, name)), 10) for name in active_names)
        if key in seen:
            continue
        seen.add(key)
        out.append(p)
    return out