"""批量生成正演代理模型的原始数值试井数据集。 脚本读取数据生成配置,调用并行数据集生成器批量采样地层/井筒参数和流量制度, 运行底层数值求解器并把有效曲线写入 HDF5。它是训练前最上游的数据生产入口。 """ # pylint: disable=import-error,wrong-import-position,broad-exception-caught from __future__ import annotations import argparse import multiprocessing as mp import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.config import Config from src.common.experiment_paths import config_for_stage from src.data.dataset_generation import ParallelDatasetGenerator def main() -> None: """按配置阶段启动并行数值试井样本生成,输出原始 HDF5 数据集路径。""" parser = argparse.ArgumentParser() parser.add_argument("--config", default=None) parser.add_argument( "--stage", choices=[ "fixed_case", "case_neighborhood", "family_random", "family_random_hard", "family_random_v2_q", ], default=None, ) parser.add_argument("--n-samples", type=int, default=None) parser.add_argument("--n-workers", type=int, default=None) parser.add_argument("--seed", type=int, default=None) parser.add_argument("--method", type=str, default=None) parser.add_argument( "--dataset-tag", type=str, default=None, help="Optional tag injected into output dataset filename", ) args = parser.parse_args() config_path = args.config if config_path is None: config_path = str(config_for_stage(args.stage) or Path("configs/data_gen.yaml")) # stage 用来选择预设配置;命令行参数继续覆盖样本数、并行数和随机种子。 cfg = Config(config_path) cfg.ensure_dirs() path = ParallelDatasetGenerator( cfg=cfg, n_workers=args.n_workers, ).generate( n_samples=args.n_samples, method=args.method, random_seed=args.seed, dataset_tag=args.dataset_tag, ) print(path) if __name__ == "__main__": mp.freeze_support() try: mp.set_start_method("spawn", force=True) except RuntimeError: pass main()