You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/build_mixed_dataset.py

84 lines
3.6 KiB
Python

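"""构建“普通样本 + 困难样本”的混合训练数据集。

脚本先调用合并逻辑，把常规数据集与局部自动拟合邻域数据集合并为一个 HDF5 文件，
再复用统一的预处理流程，生成模型训练所需的标准化数据文件。适合在正演代理模型
需要兼顾全局覆盖和 PSO/自动拟合困难区域时使用。

Build a mixed training dataset of "normal samples + hard samples".

The script first runs the merge logic to combine the regular dataset with the
local auto-fit neighbourhood dataset into a single HDF5 file, then reuses the
shared preprocessing pipeline to produce the standardized data file used for
model training. Use it when the forward-modeling surrogate model needs both
global coverage and coverage of the regions that are hard for PSO/auto-fitting.
"""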
from __future__ import annotations
import argparse
import sys
from pathlib import Path
def _ensure_on_sys_path(path: Path) -> None:
    """Put *path* at the front of ``sys.path``, adding it at most once.

    Prepending (rather than appending) lets the project's own ``scripts`` and
    ``src`` packages win over any same-named top-level packages installed in
    the environment. The membership check stops repeated imports of this
    module from growing ``sys.path``.
    """
    entry = str(path)
    if entry not in sys.path:
        sys.path.insert(0, entry)


# Project root: the parent of the directory containing this file. The
# ``scripts.*`` and ``src.*`` imports that follow are resolved against it.
ROOT = Path(__file__).resolve().parents[1]
_ensure_on_sys_path(ROOT)
from scripts.merge_datasets import merge_datasets
from src.common.experiment_paths import normalize_tag, processed_path_for_tag, sample_path_for_tag
from src.data.preprocess import preprocess_dataset
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command line for the mixed normal/hard dataset build.

    Options cover:
    - the two input HDF5 pools;
    - the sampling controls: total size, hard fraction, or explicit per-pool
      counts (how these combine is decided by ``merge_datasets``);
    - the output locations;
    - the split settings forwarded to preprocessing.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, so existing ``parse_args()`` calls behave as
            before. Passing a list makes the parser usable from tests and
            other scripts.

    Returns:
        The parsed ``argparse.Namespace``.

    Raises:
        SystemExit: Raised through ``parser.error`` when an option is
            malformed or out of range (``--hard-ratio`` outside [0, 1], or a
            negative sample count).
    """
    parser = argparse.ArgumentParser(
        description="Build a mixed raw+processed dataset from normal and hard HDF5 pools"
    )
    parser.add_argument("--normal-input", type=str, required=True, help="Path to the normal/main .h5 dataset")
    parser.add_argument("--hard-input", type=str, required=True, help="Path to the hard-targeted .h5 dataset")
    parser.add_argument("--tag", type=str, default="family_random_mixed_50k", help="Experiment tag")
    parser.add_argument("--output-h5", type=str, default=None, help="Optional merged raw .h5 path")
    parser.add_argument("--output-processed", type=str, default=None, help="Optional processed .pkl path")
    parser.add_argument(
        "--total-samples",
        type=int,
        default=50000,
        help="Requested size of the merged dataset (default: %(default)s)",
    )
    parser.add_argument(
        "--hard-ratio",
        type=float,
        default=0.30,
        help="Proportion of merged samples taken from the hard pool, in [0, 1] (default: %(default)s)",
    )
    parser.add_argument(
        "--normal-count",
        type=int,
        default=None,
        help="Optional explicit number of samples from the normal pool",
    )
    parser.add_argument(
        "--hard-count",
        type=int,
        default=None,
        help="Optional explicit number of samples from the hard pool",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for merge sampling and the data split (default: %(default)s)",
    )
    parser.add_argument(
        "--test-size",
        type=float,
        default=0.15,
        help="Test split size passed to preprocessing (default: %(default)s)",
    )
    parser.add_argument(
        "--val-size",
        type=float,
        default=0.15,
        help="Validation split size passed to preprocessing (default: %(default)s)",
    )
    parser.add_argument(
        "--normal-label",
        type=str,
        default="normal",
        help="Label passed to the merge step for normal-pool samples (default: %(default)s)",
    )
    parser.add_argument(
        "--hard-label",
        type=str,
        default="hard",
        help="Label passed to the merge step for hard-pool samples (default: %(default)s)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=4096,
        help="Batch size passed to the merge step (default: %(default)s)",
    )

    args = parser.parse_args(argv)

    # Fail fast at the CLI boundary rather than deep inside the merge step.
    # The chained comparison also rejects NaN.
    if not 0.0 <= args.hard_ratio <= 1.0:
        parser.error(f"--hard-ratio must be within [0, 1], got {args.hard_ratio}")
    for flag, value in (
        ("--total-samples", args.total_samples),
        ("--normal-count", args.normal_count),
        ("--hard-count", args.hard_count),
    ):
        if value is not None and value < 0:
            parser.error(f"{flag} must be non-negative, got {value}")
    return args
def main() -> None:
    """Entry point: merge the normal and hard HDF5 pools, then preprocess them.

    Steps:
      1. Parse the CLI and normalise the experiment tag.
      2. Resolve the merged raw ``.h5`` path and the processed output path.
         An explicit ``--output-h5`` / ``--output-processed`` wins; otherwise
         the path is derived from the normalised tag.
      3. Call ``merge_datasets`` with the sampling options.
      4. Feed the merged file straight into ``preprocess_dataset`` so it goes
         through the same preprocessing as other datasets. The same
         ``--seed`` drives both the merge sampling and the split.
      5. Print the locations of the merged file, merge summary and processed
         file.
    """
    opts = parse_args()
    exp_tag = normalize_tag(opts.tag)

    if opts.output_h5 is None:
        merged_h5 = sample_path_for_tag(exp_tag)
    else:
        merged_h5 = Path(opts.output_h5)

    if opts.output_processed is None:
        processed_file = processed_path_for_tag(exp_tag)
    else:
        processed_file = Path(opts.output_processed)

    # Options controlling how many samples are drawn from each pool.
    sampling = dict(
        total_samples=opts.total_samples,
        hard_ratio=opts.hard_ratio,
        normal_count=opts.normal_count,
        hard_count=opts.hard_count,
        seed=opts.seed,
    )

    # Sample from both raw pools and merge them at the HDF5 level. The pool
    # labels are forwarded as well; the original author notes they are kept as
    # a source_label so each sample's origin can be traced later.
    merge_report = merge_datasets(
        normal_input=opts.normal_input,
        hard_input=opts.hard_input,
        output=merged_h5,
        tag=exp_tag,
        normal_label=opts.normal_label,
        hard_label=opts.hard_label,
        batch_size=opts.batch_size,
        **sampling,
    )

    # Reuse the shared preprocessing pipeline so the processed file has the
    # same format as one built from a regular dataset.
    preprocess_dataset(
        input_path=merged_h5,
        output_path=processed_file,
        test_size=opts.test_size,
        val_size=opts.val_size,
        random_seed=opts.seed,
    )

    print(f"Merged raw dataset: {merge_report['output_path']}")
    print(f"Merge summary: {merge_report['summary_path']}")
    print(f"Processed dataset: {processed_file}")
# Run the build pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()