|
|
"""构建“普通样本 + 困难样本”的混合训练数据集。
|
|
|
|
|
|
脚本先调用合并逻辑把常规数据集与局部自动拟合邻域数据集合成为一个 HDF5,
|
|
|
再复用统一预处理流程生成模型训练所需的标准化数据文件。适合在正演代理模型
|
|
|
需要兼顾全局覆盖和 PSO/自动拟合困难区域时使用。
|
|
|
"""
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import argparse
|
|
|
import sys
|
|
|
from pathlib import Path
|
|
|
|
|
|
ROOT = Path(__file__).resolve().parents[1]
|
|
|
sys.path.append(str(ROOT))
|
|
|
|
|
|
from scripts.merge_datasets import merge_datasets
|
|
|
from src.common.experiment_paths import normalize_tag, processed_path_for_tag, sample_path_for_tag
|
|
|
from src.data.preprocess import preprocess_dataset
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
|
"""解析 normal/hard 两类 HDF5 的混合比例、输出路径和预处理切分参数。"""
|
|
|
parser = argparse.ArgumentParser(description="Build a mixed raw+processed dataset from normal and hard HDF5 pools")
|
|
|
parser.add_argument("--normal-input", type=str, required=True, help="Path to the normal/main .h5 dataset")
|
|
|
parser.add_argument("--hard-input", type=str, required=True, help="Path to the hard-targeted .h5 dataset")
|
|
|
parser.add_argument("--tag", type=str, default="family_random_mixed_50k", help="Experiment tag")
|
|
|
parser.add_argument("--output-h5", type=str, default=None, help="Optional merged raw .h5 path")
|
|
|
parser.add_argument("--output-processed", type=str, default=None, help="Optional processed .pkl path")
|
|
|
parser.add_argument("--total-samples", type=int, default=50000)
|
|
|
parser.add_argument("--hard-ratio", type=float, default=0.30)
|
|
|
parser.add_argument("--normal-count", type=int, default=None)
|
|
|
parser.add_argument("--hard-count", type=int, default=None)
|
|
|
parser.add_argument("--seed", type=int, default=42)
|
|
|
parser.add_argument("--test-size", type=float, default=0.15)
|
|
|
parser.add_argument("--val-size", type=float, default=0.15)
|
|
|
parser.add_argument("--normal-label", type=str, default="normal")
|
|
|
parser.add_argument("--hard-label", type=str, default="hard")
|
|
|
parser.add_argument("--batch-size", type=int, default=4096)
|
|
|
return parser.parse_args()
|
|
|
|
|
|
|
|
|
def main() -> None:
|
|
|
"""合并普通样本与困难样本,并立即生成对应的 processed 训练数据。"""
|
|
|
args = parse_args()
|
|
|
tag = normalize_tag(args.tag)
|
|
|
output_h5 = Path(args.output_h5) if args.output_h5 is not None else sample_path_for_tag(tag)
|
|
|
output_processed = (
|
|
|
Path(args.output_processed) if args.output_processed is not None else processed_path_for_tag(tag)
|
|
|
)
|
|
|
|
|
|
# 先在原始 HDF5 层面按比例抽样合并,保留 source_label 便于之后追踪样本来源。
|
|
|
merge_meta = merge_datasets(
|
|
|
normal_input=args.normal_input,
|
|
|
hard_input=args.hard_input,
|
|
|
output=output_h5,
|
|
|
tag=tag,
|
|
|
total_samples=args.total_samples,
|
|
|
hard_ratio=args.hard_ratio,
|
|
|
normal_count=args.normal_count,
|
|
|
hard_count=args.hard_count,
|
|
|
seed=args.seed,
|
|
|
normal_label=args.normal_label,
|
|
|
hard_label=args.hard_label,
|
|
|
batch_size=args.batch_size,
|
|
|
)
|
|
|
|
|
|
# 合并后的原始数据立即进入同一套预处理流程,保证训练集格式与普通数据一致。
|
|
|
preprocess_dataset(
|
|
|
input_path=output_h5,
|
|
|
output_path=output_processed,
|
|
|
test_size=args.test_size,
|
|
|
val_size=args.val_size,
|
|
|
random_seed=args.seed,
|
|
|
)
|
|
|
|
|
|
print(f"Merged raw dataset: {merge_meta['output_path']}")
|
|
|
print(f"Merge summary: {merge_meta['summary_path']}")
|
|
|
print(f"Processed dataset: {output_processed}")
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
main()
|