You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
75 lines
2.3 KiB
Python
75 lines
2.3 KiB
Python
"""批量生成正演代理模型的原始数值试井数据集。
|
|
|
|
脚本读取数据生成配置,调用并行数据集生成器批量采样地层/井筒参数和流量制度,
|
|
运行底层数值求解器并把有效曲线写入 HDF5。它是训练前最上游的数据生产入口。
|
|
"""
|
|
# pylint: disable=import-error,wrong-import-position,broad-exception-caught
|
|
|
|
from __future__ import annotations
|
|
import argparse
|
|
import multiprocessing as mp
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Repository root: the parent of the directory that contains this script.
# It is put on sys.path so that the project's ``src`` package is importable
# when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parents[1]

# Prepend (rather than append) so this repository's ``src`` package takes
# precedence over any unrelated top-level ``src`` package that might already
# be importable, e.g. from site-packages. Skip the insert when the root is
# already on the path to avoid duplicate entries.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
|
|
|
from src.common.config import Config
|
|
from src.common.experiment_paths import config_for_stage
|
|
from src.data.dataset_generation import ParallelDatasetGenerator
|
|
|
|
|
|
def main() -> None:
    """Parse CLI options and launch parallel numerical well-test sample generation.

    The configuration file is resolved in this order:

    1. ``--config`` if given (``--stage`` is then not used);
    2. otherwise the preset returned by ``config_for_stage(--stage)``;
    3. if that is falsy, ``configs/data_gen.yaml`` under the repository root.

    ``--n-samples``, ``--n-workers`` and ``--seed`` override the values from
    the selected config; ``--method`` and ``--dataset-tag`` are forwarded to
    ``ParallelDatasetGenerator.generate`` unchanged. Options left unset are
    passed as ``None`` (presumably meaning "use the config value" — confirm
    against ``ParallelDatasetGenerator``).

    The path returned by the generator (the raw HDF5 dataset) is printed to
    stdout.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config",
        default=None,
        help="Path to a data-generation config file; takes precedence over --stage",
    )
    parser.add_argument(
        "--stage",
        choices=[
            "fixed_case",
            "case_neighborhood",
            "family_random",
            "family_random_hard",
            "family_random_v2_q",
        ],
        default=None,
        help="Preset stage whose config is used when --config is not given",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=None,
        help="Number of samples to generate (overrides the config value)",
    )
    parser.add_argument(
        "--n-workers",
        type=int,
        default=None,
        help="Number of parallel workers (overrides the config value)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed (overrides the config value)",
    )
    parser.add_argument(
        "--method",
        type=str,
        default=None,
        help="Generation method forwarded to the dataset generator",
    )
    parser.add_argument(
        "--dataset-tag",
        type=str,
        default=None,
        help="Optional tag injected into output dataset filename",
    )
    args = parser.parse_args()

    config_path = args.config
    if config_path is None:
        # Anchor the fallback at the repository root instead of the current
        # working directory, so the script also works when it is launched
        # from outside the repo.
        default_config = ROOT / "configs" / "data_gen.yaml"
        config_path = str(config_for_stage(args.stage) or default_config)

    # The stage only selects a preset config; the CLI options above still
    # override sample count, worker count and random seed.
    cfg = Config(config_path)
    cfg.ensure_dirs()
    path = ParallelDatasetGenerator(
        cfg=cfg,
        n_workers=args.n_workers,
    ).generate(
        n_samples=args.n_samples,
        method=args.method,
        random_seed=args.seed,
        dataset_tag=args.dataset_tag,
    )
    print(path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Required for frozen Windows executables; a no-op in every other case.
    mp.freeze_support()
    # Use the "spawn" start method on every platform so worker processes are
    # started from a fresh interpreter. With force=True this call does not
    # raise if a start method has already been set, so no guard is needed.
    mp.set_start_method("spawn", force=True)
    main()
|