You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/build_mixed_dataset.py

200 lines
6.2 KiB
Python

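"""Build a mixed training dataset made of "normal samples + hard samples".

The script first calls the merge logic to combine the regular dataset and the
local auto-fit neighborhood dataset into a single HDF5 file, and then reuses
the unified preprocessing pipeline to produce the standardized data file that
model training needs.

It is intended for cases where the forward surrogate model has to cover both
globally distributed samples and samples from the hard regions found by
PSO / automatic fitting.
"""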
from __future__ import annotations
import argparse
import importlib
import sys
from pathlib import Path
from typing import Any, Callable
# Project root: the parent of the directory that contains this script.
# It is appended to sys.path (once) so that the project packages loaded later
# by dotted name can be resolved when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(ROOT)) == 0:
    sys.path.append(str(ROOT))
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the mixed-dataset builder.

    The options describe the two input HDF5 pools (normal and hard), the
    mixing ratio or explicit per-pool counts, the optional output locations,
    and the split sizes used for preprocessing.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            ``argparse`` falls back to ``sys.argv[1:]``, so existing callers
            that invoke ``parse_args()`` behave exactly as before. Passing an
            explicit list makes the parser usable from tests and other code
            without touching global state.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Build a mixed raw+processed dataset from normal and hard HDF5 pools"
    )

    # Input pools (both mandatory).
    parser.add_argument(
        "--normal-input",
        type=str,
        required=True,
        help="Path to the normal/main .h5 dataset",
    )
    parser.add_argument(
        "--hard-input",
        type=str,
        required=True,
        help="Path to the hard-targeted .h5 dataset",
    )

    # Experiment tag and optional explicit output locations.
    parser.add_argument(
        "--tag",
        type=str,
        default="family_random_mixed_50k",
        help="Experiment tag",
    )
    parser.add_argument(
        "--output-h5",
        type=str,
        default=None,
        help="Optional merged raw .h5 path",
    )
    parser.add_argument(
        "--output-processed",
        type=str,
        default=None,
        help="Optional processed .pkl path",
    )

    # Mixing controls forwarded to the merge step.
    parser.add_argument("--total-samples", type=int, default=50000)
    parser.add_argument("--hard-ratio", type=float, default=0.30)
    parser.add_argument("--normal-count", type=int, default=None)
    parser.add_argument("--hard-count", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)

    # Split sizes forwarded to the preprocessing step.
    parser.add_argument("--test-size", type=float, default=0.15)
    parser.add_argument("--val-size", type=float, default=0.15)

    # Remaining merge options.
    parser.add_argument("--normal-label", type=str, default="normal")
    parser.add_argument("--hard-label", type=str, default="hard")
    parser.add_argument("--batch-size", type=int, default=4096)

    return parser.parse_args(argv)
def load_project_functions() -> tuple[
    Callable[..., dict[str, Any]],
    Callable[[str], str],
    Callable[[str], Path],
    Callable[[str], Path],
    Callable[..., None],
]:
    """Import the project-internal helpers lazily, by dotted module name.

    The modules are loaded through ``importlib`` at call time instead of with
    top-level ``import`` statements; per the original author's note this
    avoids Pylint false positives caused by the dynamically extended project
    path.

    Returns:
        A 5-tuple ``(merge_datasets, normalize_tag, processed_path_for_tag,
        sample_path_for_tag, preprocess_dataset)``.
    """
    merge_mod, paths_mod, preprocess_mod = [
        importlib.import_module(dotted_name)
        for dotted_name in (
            "scripts.merge_datasets",
            "src.common.experiment_paths",
            "src.data.preprocess",
        )
    ]

    merge_fn = merge_mod.merge_datasets
    normalize_fn = paths_mod.normalize_tag
    processed_path_fn = paths_mod.processed_path_for_tag
    sample_path_fn = paths_mod.sample_path_for_tag
    preprocess_fn = preprocess_mod.preprocess_dataset
    return merge_fn, normalize_fn, processed_path_fn, sample_path_fn, preprocess_fn
def resolve_output_paths(
    args: argparse.Namespace,
    tag: str,
    processed_path_for_tag: Callable[[str], Path],
    sample_path_for_tag: Callable[[str], Path],
) -> tuple[Path, Path]:
    """Decide where the merged HDF5 and the processed output are written.

    An explicitly supplied ``args.output_h5`` / ``args.output_processed``
    takes precedence and is wrapped in ``Path``. When an option is ``None``,
    the matching tag-based helper is called with ``tag`` to derive the path;
    the helper is not called otherwise.

    Returns:
        ``(output_h5, output_processed)``.
    """
    if args.output_h5 is None:
        merged_h5_path = sample_path_for_tag(tag)
    else:
        merged_h5_path = Path(args.output_h5)

    if args.output_processed is None:
        processed_path = processed_path_for_tag(tag)
    else:
        processed_path = Path(args.output_processed)

    return merged_h5_path, processed_path
def merge_raw_datasets(
    args: argparse.Namespace,
    tag: str,
    output_h5: Path,
    merge_datasets: Callable[..., dict[str, Any]],
) -> dict[str, Any]:
    """Merge the normal and hard sample pools and return the merge metadata.

    All mixing options are read from ``args`` and forwarded, by keyword, to
    the injected ``merge_datasets`` callable together with ``tag`` and the
    target ``output_h5`` path. Whatever that callable returns is returned
    unchanged.
    """
    merge_kwargs: dict[str, Any] = {
        "normal_input": args.normal_input,
        "hard_input": args.hard_input,
        "output": output_h5,
        "tag": tag,
        "total_samples": args.total_samples,
        "hard_ratio": args.hard_ratio,
        "normal_count": args.normal_count,
        "hard_count": args.hard_count,
        "seed": args.seed,
        "normal_label": args.normal_label,
        "hard_label": args.hard_label,
        "batch_size": args.batch_size,
    }
    return merge_datasets(**merge_kwargs)
def build_processed_dataset(
    args: argparse.Namespace,
    output_h5: Path,
    output_processed: Path,
    preprocess_dataset: Callable[..., None],
) -> None:
    """Run the shared preprocessing step on the merged HDF5 file.

    The injected ``preprocess_dataset`` callable receives the merged file as
    ``input_path``, the target file as ``output_path``, the test/validation
    split sizes from ``args``, and ``args.seed`` as ``random_seed``. Its
    return value is ignored.
    """
    preprocess_kwargs: dict[str, Any] = {
        "input_path": output_h5,
        "output_path": output_processed,
        "test_size": args.test_size,
        "val_size": args.val_size,
        "random_seed": args.seed,
    }
    preprocess_dataset(**preprocess_kwargs)
def print_outputs(merge_meta: dict[str, Any], output_processed: Path) -> None:
    """Print where the merged raw data, the merge summary and the processed data are.

    ``merge_meta`` must provide the ``"output_path"`` and ``"summary_path"``
    keys; a missing key raises ``KeyError``.
    """
    # Each value is looked up right before its line is printed, so lines
    # preceding a missing key are still emitted.
    raw_location = merge_meta['output_path']
    print(f"Merged raw dataset: {raw_location}")
    summary_location = merge_meta['summary_path']
    print(f"Merge summary: {summary_location}")
    print(f"Processed dataset: {output_processed}")
def main() -> None:
    """Merge the normal and hard pools, then build the processed training data right away."""
    cli_args = parse_args()
    project_functions = load_project_functions()
    merge_fn, normalize_fn, processed_path_fn, sample_path_fn, preprocess_fn = project_functions

    experiment_tag = normalize_fn(cli_args.tag)
    merged_h5, processed_out = resolve_output_paths(
        cli_args,
        experiment_tag,
        processed_path_fn,
        sample_path_fn,
    )

    # Step 1: sample and merge at the raw HDF5 level according to the
    # requested ratio. Per the original author's note, the merge keeps a
    # source_label so the origin of each sample can be traced later.
    merge_summary = merge_raw_datasets(cli_args, experiment_tag, merged_h5, merge_fn)

    # Step 2: feed the merged raw data straight into the same preprocessing
    # pipeline so the training-set format matches that of regular datasets.
    build_processed_dataset(cli_args, merged_h5, processed_out, preprocess_fn)

    print_outputs(merge_summary, processed_out)


if __name__ == "__main__":
    main()