You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/preprocess_dataset.py

76 lines
2.1 KiB
Python

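"""Preprocess a raw HDF5 dataset for surrogate-model training.

The script reads the raw curve data from the HDF5 file given by ``--input``
and calls the shared preprocessing function, which performs the parameter
transformation, curve cleaning/standardization, the train/validation/test
split, and saves the scalers. The output location is either an explicit
``--output`` path or one derived from the experiment ``--tag``; the result is
the processed data file that the later training scripts consume directly.
"""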
# pylint: disable=import-error,wrong-import-position
from __future__ import annotations
import argparse
import sys
from pathlib import Path
# ROOT is the parent of the directory that contains this file. It is appended
# to sys.path (so it is searched last) before the ``src.*`` imports that
# follow, which is why those imports come after ordinary statements.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
from src.data.preprocess import preprocess_dataset
def main(argv: list[str] | None = None) -> None:
    """Parse the command line and run the HDF5 preprocessing step.

    The dataset named by ``--input`` is passed to ``preprocess_dataset``
    together with the test/validation split sizes, the random seed and the
    parameter-feature-transform switch. The result is written to ``--output``
    when that option is given; otherwise the path is obtained from
    ``processed_path_for_tag`` using the normalized ``--tag``.

    Args:
        argv: Command-line arguments, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so a bare
            ``main()`` call behaves exactly as it did before this parameter
            existed. Passing a list allows programmatic use and testing.

    Raises:
        SystemExit: Raised by argparse for ``--help`` (code 0) and for
            missing or malformed arguments (code 2).
    """
    parser = argparse.ArgumentParser(
        description="Preprocess HDF5 dataset for forward surrogate"
    )
    parser.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to the generated .h5 dataset",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Optional output .pkl path",
    )
    parser.add_argument(
        "--tag",
        type=str,
        default=None,
        help="Experiment tag for auto naming",
    )
    parser.add_argument("--test-size", type=float, default=0.15)
    parser.add_argument("--val-size", type=float, default=0.15)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--no-param-feature-transform",
        action="store_true",
        help="Keep raw physical parameters before StandardScaler; default uses log/asinh features",
    )
    args = parser.parse_args(argv)

    # The tag is normalized unconditionally (as in the original flow), even
    # when an explicit --output means it is not used for naming.
    tag = normalize_tag(args.tag)
    output_path = (
        Path(args.output)
        if args.output is not None
        else processed_path_for_tag(tag)
    )

    preprocess_dataset(
        input_path=Path(args.input),
        output_path=output_path,
        test_size=args.test_size,
        val_size=args.val_size,
        random_seed=args.seed,
        # The CLI flag is a negative switch; the callee takes the positive form.
        use_param_feature_transform=not args.no_param_feature_transform,
    )


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()