# nmWTAI-Platform/ML/nmWTAI-ML/scripts/generate_dataset.py (57 lines, 2.1 KiB, Python)

"""批量生成正演代理模型的原始数值试井数据集。
脚本读取数据生成配置调用并行数据集生成器批量采样地层/井筒参数和流量制度
运行底层数值求解器并把有效曲线写入 HDF5它是训练前最上游的数据生产入口
"""
from __future__ import annotations
import argparse
import multiprocessing as mp
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.config import Config
from src.common.experiment_paths import config_for_stage
from src.data.dataset_generation import ParallelDatasetGenerator
def main():
    """Launch parallel numerical well-test sample generation for a configured stage.

    Parses the command line, resolves which config file to load, builds a
    ``ParallelDatasetGenerator`` from it, and prints whatever ``generate()``
    returns (per the original docstring, the path of the raw HDF5 dataset).

    Config resolution order:
      1. ``--config`` if given;
      2. otherwise the result of ``config_for_stage(--stage)`` if it is truthy;
      3. otherwise ``configs/data_gen.yaml`` (relative to the current working directory).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", default=None)
    cli.add_argument(
        "--stage",
        choices=["fixed_case", "case_neighborhood", "family_random", "family_random_hard", "family_random_v2_q"],
        default=None,
    )
    # These four options share the same shape: typed, optional, defaulting to None.
    for flag, caster in (
        ("--n-samples", int),
        ("--n-workers", int),
        ("--seed", int),
        ("--method", str),
    ):
        cli.add_argument(flag, type=caster, default=None)
    cli.add_argument(
        "--dataset-tag",
        type=str,
        default=None,
        help="Optional tag injected into output dataset filename",
    )
    opts = cli.parse_args()

    if opts.config is not None:
        config_file = opts.config
    else:
        # The stage selects a preset config; fall back to the default data-gen
        # config when no preset is found for it.
        preset = config_for_stage(opts.stage)
        config_file = str(preset or Path("configs/data_gen.yaml"))

    settings = Config(config_file)
    settings.ensure_dirs()

    # Sample count, worker count, sampling method and seed from the command line
    # are passed through to the generator so they can override the config values.
    generator = ParallelDatasetGenerator(cfg=settings, n_workers=opts.n_workers)
    dataset_path = generator.generate(
        n_samples=opts.n_samples,
        method=opts.method,
        random_seed=opts.seed,
        dataset_tag=opts.dataset_tag,
    )
    print(dataset_path)
if __name__ == "__main__":
    # Needed so a frozen Windows executable can start child processes; it has
    # no effect when the script is run normally by the interpreter.
    mp.freeze_support()
    try:
        # Force "spawn" so worker processes start the same way on every platform.
        mp.set_start_method("spawn", force=True)
    except (RuntimeError, ValueError) as exc:
        # Best effort: continue with the platform's default start method, but
        # report it on stderr (stdout is reserved for the dataset path).
        print(
            f"warning: could not set multiprocessing start method to 'spawn': {exc}",
            file=sys.stderr,
        )
    main()