|
|
"""合并常规数据集与困难样本数据集。
|
|
|
|
|
|
脚本校验两个 HDF5 的核心字段、属性和可选元数据是否兼容,按用户指定比例或数量抽样,
|
|
|
再写出一个结构一致的混合集。该流程用于把全局随机样本与自动拟合邻域样本混合,
|
|
|
提升代理模型在实际反演困难区域的覆盖度。
|
|
|
"""
|
|
|
|
|
|
# pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,no-member
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
import argparse
|
|
|
import json
|
|
|
import math
|
|
|
import sys
|
|
|
from pathlib import Path
|
|
|
from types import SimpleNamespace
|
|
|
from typing import Any
|
|
|
|
|
|
import h5py
|
|
|
import numpy as np
|
|
|
|
|
|
# Directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[1]

# Make that directory importable; presumably it is the project root that holds
# the project-local packages -- confirm against the repository layout.
sys.path.append(str(ROOT))
|
|
|
|
|
|
|
|
|
# Datasets that both inputs must contain with matching trailing shape and dtype.
REQUIRED_DATASETS = ("params", "schedule", "curve")

# Datasets carried over when at least one input has them; rows taken from an
# input that lacks the dataset receive a default value instead.
OPTIONAL_DATASETS = (
    "group_id",
    "schedule_meta",
    "family_name",
    "is_anchor",
    "neighbor_objective",
    "neighbor_objective_p",
    "neighbor_objective_d",
    "neighbor_span_frac",
    "section_index",
    "timeQ_json",
    "q_json",
)

# File-level attributes that must agree between the inputs and are copied to
# the merged output.
COPY_ATTRS = ("param_names", "schedule_meta_names")
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the merge script.

    The options cover the two input HDF5 paths, the optional output path and
    tag, the sampling controls (total rows, hard ratio, or explicit counts),
    the random seed, the source labels, and the write batch size.

    Returns:
        The namespace produced by ``argparse`` from ``sys.argv``.
    """
    # Options are registered in table order, which is also the --help order.
    option_table = (
        ("--normal-input", {"type": str, "required": True, "help": "Path to the normal/main .h5 dataset"}),
        ("--hard-input", {"type": str, "required": True, "help": "Path to the hard-targeted .h5 dataset"}),
        ("--output", {"type": str, "default": None, "help": "Optional output .h5 path"}),
        ("--tag", {"type": str, "default": "family_random_mixed_50k", "help": "Tag used for default output naming"}),
        ("--total-samples", {"type": int, "default": 50000, "help": "Total rows to export"}),
        (
            "--hard-ratio",
            {"type": float, "default": 0.30, "help": "Hard-sample fraction when explicit counts are omitted"},
        ),
        ("--normal-count", {"type": int, "default": None, "help": "Explicit normal sample count"}),
        ("--hard-count", {"type": int, "default": None, "help": "Explicit hard sample count"}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed"}),
        (
            "--normal-label",
            {"type": str, "default": "normal", "help": "Value written to source_name for normal rows"},
        ),
        ("--hard-label", {"type": str, "default": "hard", "help": "Value written to source_name for hard rows"}),
        ("--batch-size", {"type": int, "default": 4096, "help": "Output write batch size"}),
    )

    cli = argparse.ArgumentParser(description="Merge normal and hard raw HDF5 datasets into one mixed dataset")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
|
|
|
|
|
|
def normalize_tag(tag: str | None) -> str:
    """Return a path-friendly version of the experiment tag.

    Surrounding whitespace is stripped and every space or hyphen becomes an
    underscore. ``None`` and tags that are empty after cleaning fall back to
    ``"merged"``. Letter case is left untouched.
    """
    fallback = "merged"
    if tag is None:
        return fallback

    text = str(tag).strip()
    for separator in (" ", "-"):
        text = text.replace(separator, "_")
    return text if text else fallback
|
|
|
|
|
|
|
|
|
def default_output_path(tag: str) -> Path:
    """Build the default output location ``data/samples/dataset_<tag>.h5``."""
    file_name = f"dataset_{tag}.h5"
    return Path("data", "samples", file_name)
|
|
|
|
|
|
|
|
|
def _decode_attr_list(raw_names: Any) -> list[str] | None:
|
|
|
"""解码 HDF5 属性中保存的 JSON 字符串或字节串列表。"""
|
|
|
if raw_names is None:
|
|
|
return None
|
|
|
return [
|
|
|
item.decode("utf-8") if isinstance(item, (bytes, np.bytes_)) else str(item)
|
|
|
for item in raw_names
|
|
|
]
|
|
|
|
|
|
|
|
|
def _read_string_rows(ds: h5py.Dataset, indices: np.ndarray) -> np.ndarray:
|
|
|
"""从 HDF5 字符串数据集中读取指定行,并统一转成 Python 字符串。"""
|
|
|
if indices.size == 0:
|
|
|
return np.asarray([], dtype=object)
|
|
|
indices = indices.astype(np.int64)
|
|
|
order = np.argsort(indices, kind="stable")
|
|
|
inverse = np.empty_like(order)
|
|
|
inverse[order] = np.arange(order.size)
|
|
|
values_sorted = np.asarray(ds[indices[order]]).astype(str)
|
|
|
return values_sorted[inverse]
|
|
|
|
|
|
|
|
|
def _read_numeric_rows(ds: h5py.Dataset, indices: np.ndarray) -> np.ndarray:
|
|
|
"""从 HDF5 数值数据集中读取指定行。"""
|
|
|
if indices.size == 0:
|
|
|
shape_tail = ds.shape[1:]
|
|
|
return np.empty((0, *shape_tail), dtype=ds.dtype) if shape_tail else np.empty((0,), dtype=ds.dtype)
|
|
|
indices = indices.astype(np.int64)
|
|
|
order = np.argsort(indices, kind="stable")
|
|
|
inverse = np.empty_like(order)
|
|
|
inverse[order] = np.arange(order.size)
|
|
|
values_sorted = np.asarray(ds[indices[order]])
|
|
|
return values_sorted[inverse]
|
|
|
|
|
|
|
|
|
def _assert_attr_compatible(name: str, lhs: h5py.File, rhs: h5py.File) -> list[str] | None:
    """Check that a file attribute agrees between the two inputs.

    Args:
        name: Attribute name to look up on both files.
        lhs: First input file.
        rhs: Second input file.

    Returns:
        The decoded attribute value, or ``None`` when neither file has it.

    Raises:
        ValueError: If only one file has the attribute or the values differ.
    """
    decoded = [_decode_attr_list(handle.attrs.get(name)) for handle in (lhs, rhs)]
    left, right = decoded
    absent = [value is None for value in decoded]

    if all(absent):
        return None
    if any(absent):
        raise ValueError(f"Attribute mismatch for {name}: one input has it and the other does not")
    if left != right:
        raise ValueError(f"Attribute mismatch for {name}: {left} != {right}")
    return left
|
|
|
|
|
|
|
|
|
def _assert_dataset_compatible(name: str, lhs: h5py.File, rhs: h5py.File) -> tuple[tuple[int, ...], np.dtype]:
|
|
|
"""检查多个输入文件同名数据集的维度和 dtype 是否可合并。"""
|
|
|
if name not in lhs or name not in rhs:
|
|
|
raise ValueError(f"Dataset mismatch for {name}: both inputs must contain it")
|
|
|
lhs_ds = lhs[name]
|
|
|
rhs_ds = rhs[name]
|
|
|
if lhs_ds.shape[1:] != rhs_ds.shape[1:]:
|
|
|
raise ValueError(f"Dataset shape mismatch for {name}: {lhs_ds.shape} vs {rhs_ds.shape}")
|
|
|
if lhs_ds.dtype != rhs_ds.dtype:
|
|
|
raise ValueError(f"Dataset dtype mismatch for {name}: {lhs_ds.dtype} vs {rhs_ds.dtype}")
|
|
|
return lhs_ds.shape[1:], lhs_ds.dtype
|
|
|
|
|
|
|
|
|
def _is_string_dtype(dtype: np.dtype) -> bool:
    """Tell whether a dataset dtype should be handled as text.

    A dtype counts as text when h5py recognises it as a string dtype, or when
    its NumPy kind is object (``O``), bytes (``S``) or unicode (``U``).
    """
    if h5py.check_string_dtype(dtype) is not None:
        return True
    return dtype.kind in ("O", "S", "U")
|
|
|
|
|
|
|
|
|
def _optional_dataset_schema(name: str, lhs: h5py.File, rhs: h5py.File) -> dict[str, Any] | None:
    """Describe an optional dataset so the output can mirror its layout.

    The schema is taken from ``lhs`` when it has the dataset, otherwise from
    ``rhs``. When both files have it, the trailing shapes must match, both must
    agree on being text or not, and numeric dtypes must be identical.

    Args:
        name: Dataset name.
        lhs: The "normal" input file.
        rhs: The "hard" input file.

    Returns:
        ``None`` when neither file has the dataset, otherwise a dict with the
        keys ``shape_tail``, ``dtype``, ``is_string``, ``normal_has`` and
        ``hard_has``.

    Raises:
        ValueError: If both files have the dataset with incompatible trailing
            shapes or dtypes.
    """
    in_normal = name in lhs
    in_hard = name in rhs
    if not (in_normal or in_hard):
        return None

    reference = lhs[name] if in_normal else rhs[name]
    string_like = _is_string_dtype(reference.dtype)

    if in_normal and in_hard:
        # Here the reference is the lhs dataset; compare it with the rhs one.
        other = rhs[name]
        if reference.shape[1:] != other.shape[1:]:
            raise ValueError(f"Dataset shape mismatch for {name}: {reference.shape} vs {other.shape}")
        kinds_differ = string_like != _is_string_dtype(other.dtype)
        if kinds_differ or (not string_like and reference.dtype != other.dtype):
            raise ValueError(f"Dataset dtype mismatch for {name}: {reference.dtype} vs {other.dtype}")

    return {
        "shape_tail": tuple(reference.shape[1:]),
        "dtype": reference.dtype,
        "is_string": bool(string_like),
        "normal_has": bool(in_normal),
        "hard_has": bool(in_hard),
    }
|
|
|
|
|
|
|
|
|
def _default_numeric_value(name: str) -> float | int:
|
|
|
"""为缺失的可选数值数据集生成合适的填充值。"""
|
|
|
if name in {"group_id", "section_index", "is_anchor"}:
|
|
|
return -1
|
|
|
return np.nan
|
|
|
|
|
|
|
|
|
def resolve_counts(args: argparse.Namespace, n_normal_avail: int, n_hard_avail: int) -> tuple[int, int]:
    """Work out how many rows to draw from each input.

    With explicit counts both ``args.normal_count`` and ``args.hard_count``
    must be given. Otherwise the counts come from ``args.total_samples`` and
    ``args.hard_ratio``: the hard count is the rounded product and the normal
    count is the remainder.

    Args:
        args: Object exposing ``normal_count``, ``hard_count``,
            ``total_samples`` and ``hard_ratio``.
        n_normal_avail: Rows available in the normal input.
        n_hard_avail: Rows available in the hard input.

    Returns:
        ``(normal_count, hard_count)``.

    Raises:
        ValueError: If only one explicit count is given, the total is not
            positive, the ratio is outside ``[0, 1]``, either resolved count is
            not positive, or a count exceeds the rows available.
    """
    explicit = (args.normal_count, args.hard_count)

    if all(value is None for value in explicit):
        total = int(args.total_samples)
        if total <= 0:
            raise ValueError("--total-samples must be positive")
        hard_ratio = float(args.hard_ratio)
        if hard_ratio > 1.0 or hard_ratio < 0.0:
            raise ValueError("--hard-ratio must be in [0, 1]")
        hard_n = int(round(total * hard_ratio))
        normal_n = total - hard_n
    elif any(value is None for value in explicit):
        raise ValueError("Both --normal-count and --hard-count are required when using explicit counts")
    else:
        normal_n = int(args.normal_count)
        hard_n = int(args.hard_count)

    if min(normal_n, hard_n) <= 0:
        raise ValueError("Both normal_count and hard_count must be positive")

    requests = (
        ("normal_count", normal_n, n_normal_avail),
        ("hard_count", hard_n, n_hard_avail),
    )
    for label, requested, available in requests:
        if requested > available:
            raise ValueError(f"Requested {label}={requested}, but only {available} rows are available")

    return normal_n, hard_n
|
|
|
|
|
|
|
|
|
def create_output_file(
    output_path: Path,
    total_rows: int,
    required_schemas: dict[str, tuple[tuple[int, ...], np.dtype]],
    optional_schemas: dict[str, dict[str, Any] | None],
    copied_attrs: dict[str, list[str] | None],
    merge_meta: dict[str, Any],
) -> h5py.File:
    """Create the merged output HDF5 file and pre-allocate all of its datasets.

    The parent directory is created when missing and any existing file at
    ``output_path`` is overwritten. Every dataset is allocated with
    ``total_rows`` rows up front.

    Args:
        output_path: Destination of the merged file.
        total_rows: Number of rows every dataset is allocated with.
        required_schemas: ``name -> (trailing_shape, dtype)`` for the datasets
            both inputs contain.
        optional_schemas: ``name -> schema dict`` (keys ``shape_tail``,
            ``dtype``, ``is_string``), or ``None`` for datasets to skip.
        copied_attrs: ``name -> list of strings`` written as file attributes;
            ``None`` values are skipped.
        merge_meta: Merge description; must contain ``normal_label`` and
            ``hard_label`` and be JSON-serialisable.

    Returns:
        The open, writable file. The caller is responsible for closing it.

    Raises:
        Exception: Anything raised while creating datasets or attributes is
            re-raised after the partially initialised file has been closed.
    """

    def _utf8_bytes_array(values: list[str]) -> np.ndarray:
        # Encode explicitly: np.asarray(..., dtype="S") on str input only
        # accepts ASCII and would reject non-ASCII names or labels.
        return np.asarray([str(value).encode("utf-8") for value in values], dtype="S")

    output_path.parent.mkdir(parents=True, exist_ok=True)
    f = h5py.File(output_path, "w")
    try:
        # Required datasets are fully compatible across both sources, so they
        # are created once with their final row count.
        for name, (shape_tail, dtype) in required_schemas.items():
            f.create_dataset(name, shape=(total_rows, *shape_tail), dtype=dtype)

        # Optional datasets may be missing from one source; rows from that
        # source keep the default value or an empty string.
        for name, schema in optional_schemas.items():
            if schema is None:
                continue
            shape_tail = tuple(schema["shape_tail"])
            dtype = schema["dtype"]
            if bool(schema["is_string"]):
                f.create_dataset(name, shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8"))
            else:
                # NOTE(review): for an integer dtype whose name maps to NaN in
                # _default_numeric_value this fill value cannot be represented
                # -- confirm such a dataset cannot occur.
                f.create_dataset(
                    name,
                    shape=(total_rows, *shape_tail),
                    dtype=dtype,
                    fillvalue=_default_numeric_value(name),
                )

        # The source_* datasets record which input file and row each output
        # row came from, for later error attribution and sampling audits.
        f.create_dataset("source_id", shape=(total_rows,), dtype=np.int8)
        f.create_dataset("source_name", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8"))
        f.create_dataset("source_row", shape=(total_rows,), dtype=np.int64)

        for attr_name, values in copied_attrs.items():
            if values is not None:
                f.attrs[attr_name] = _utf8_bytes_array(values)

        f.attrs["n_samples"] = int(total_rows)
        f.attrs["source_name_vocab"] = _utf8_bytes_array(
            [str(merge_meta["normal_label"]), str(merge_meta["hard_label"])]
        )
        f.attrs["merge_meta_json"] = json.dumps(merge_meta, ensure_ascii=False)
    except BaseException:
        # Do not leak the open handle when initialisation fails part-way.
        f.close()
        raise
    return f
|
|
|
|
|
|
|
|
|
def build_sample_plan(
    rng: np.random.RandomState,
    normal_count: int,
    hard_count: int,
    n_normal_avail: int,
    n_hard_avail: int,
) -> np.ndarray:
    """Draw the rows to copy from each input and shuffle them together.

    Rows are sampled without replacement, first from the normal input and then
    from the hard input, and the combined plan is shuffled row-wise with the
    same generator.

    Args:
        rng: Random generator used for both draws and the shuffle.
        normal_count: Rows to draw from the normal input.
        hard_count: Rows to draw from the hard input.
        n_normal_avail: Rows available in the normal input.
        n_hard_avail: Rows available in the hard input.

    Returns:
        An ``int64`` array of shape ``(normal_count + hard_count, 2)``. Column 0
        is the source id (0 = normal, 1 = hard) and column 1 is the row index
        inside that source.
    """
    # Keep this draw order: normal rows, hard rows, then the shuffle.
    picked_normal = rng.choice(n_normal_avail, size=normal_count, replace=False)
    picked_hard = rng.choice(n_hard_avail, size=hard_count, replace=False)

    source_ids = np.repeat(np.array([0, 1], dtype=np.int64), [normal_count, hard_count])
    source_rows = np.concatenate((picked_normal, picked_hard))
    plan = np.column_stack((source_ids, source_rows)).astype(np.int64)

    rng.shuffle(plan)
    return plan
|
|
|
|
|
|
|
|
|
def summarize_family_counts(source_file: h5py.File, chosen_rows: np.ndarray) -> dict[str, int]:
    """Count the selected rows of one input per ``family_name`` value.

    Args:
        source_file: Input file the rows were drawn from.
        chosen_rows: Row indices selected from that file.

    Returns:
        ``family name -> row count`` with the names in sorted order, or an
        empty dict when the file has no ``family_name`` dataset.
    """
    if "family_name" in source_file:
        names = _read_string_rows(source_file["family_name"], chosen_rows)
        labels, tallies = np.unique(names, return_counts=True)
        return dict(zip(map(str, labels.tolist()), map(int, tallies.tolist())))
    return {}
|
|
|
|
|
|
|
|
|
def write_batch(
    out_file: h5py.File,
    batch_plan: np.ndarray,
    normal_file: h5py.File,
    hard_file: h5py.File,
    normal_label: str,
    hard_label: str,
    start: int,
    optional_names: tuple[str, ...],
    optional_schemas: dict[str, dict[str, Any] | None],
) -> None:
    """Copy one batch of planned rows into the merged output file.

    Args:
        out_file: Open output file with pre-allocated datasets.
        batch_plan: Slice of the sample plan; column 0 is the source id
            (0 = normal, 1 = hard) and column 1 the row index in that source.
        normal_file: Open normal input.
        hard_file: Open hard input.
        normal_label: ``source_name`` value for normal rows.
        hard_label: ``source_name`` value for hard rows.
        start: Output row at which this batch begins.
        optional_names: Optional datasets to copy.
        optional_schemas: Schema per optional dataset; ``None`` entries are
            skipped.
    """
    n_rows = int(batch_plan.shape[0])
    stop = start + n_rows
    source_col = batch_plan[:, 0]
    row_col = batch_plan[:, 1]
    slots = np.arange(n_rows, dtype=np.int64)

    file_of = {0: normal_file, 1: hard_file}
    label_of = {0: normal_label, 1: hard_label}
    has_key_of = {0: "normal_has", 1: "hard_has"}

    # Normal and hard rows are interleaved within a batch. For each source
    # present, record once which batch slots it fills and which rows it reads.
    picks: dict[int, tuple[np.ndarray, np.ndarray]] = {}
    for sid in (0, 1):
        chosen = source_col == sid
        if np.any(chosen):
            picks[sid] = (slots[chosen], row_col[chosen])

    for dataset_name in REQUIRED_DATASETS:
        target = out_file[dataset_name]
        buffer = np.empty((n_rows, *target.shape[1:]), dtype=target.dtype)
        for sid, (positions, rows) in picks.items():
            buffer[positions] = _read_numeric_rows(file_of[sid][dataset_name], rows)
        target[start:stop] = buffer

    for dataset_name in optional_names:
        schema = optional_schemas[dataset_name]
        if schema is None:
            continue

        if bool(schema["is_string"]):
            # Text datasets have no numeric fill value; rows from a source
            # that lacks the dataset stay as empty strings.
            text_buffer = np.full((n_rows,), "", dtype=object)
            for sid, (positions, rows) in picks.items():
                if bool(schema[has_key_of[sid]]):
                    text_buffer[positions] = _read_string_rows(file_of[sid][dataset_name], rows)
            out_file[dataset_name][start:stop] = text_buffer.tolist()
        else:
            # Numeric datasets start from the default value and only the rows
            # whose source has the dataset are overwritten.
            target = out_file[dataset_name]
            buffer = np.full(
                (n_rows, *target.shape[1:]),
                _default_numeric_value(dataset_name),
                dtype=target.dtype,
            )
            for sid, (positions, rows) in picks.items():
                if bool(schema[has_key_of[sid]]):
                    buffer[positions] = _read_numeric_rows(file_of[sid][dataset_name], rows)
            target[start:stop] = buffer

    out_file["source_id"][start:stop] = source_col.astype(np.int8)
    out_file["source_name"][start:stop] = [label_of[int(sid)] for sid in source_col]
    out_file["source_row"][start:stop] = row_col.astype(np.int64)
|
|
|
|
|
|
|
|
|
def merge_datasets(
    normal_input: str | Path,
    hard_input: str | Path,
    output: str | Path | None = None,
    tag: str | None = "family_random_mixed_50k",
    total_samples: int = 50000,
    hard_ratio: float = 0.30,
    normal_count: int | None = None,
    hard_count: int | None = None,
    seed: int = 42,
    normal_label: str = "normal",
    hard_label: str = "hard",
    batch_size: int = 4096,
) -> dict[str, Any]:
    """Merge a normal and a hard HDF5 dataset into one shuffled mixed dataset.

    Both inputs are validated for compatible datasets and attributes, rows are
    sampled without replacement from each, and the selection is written in
    batches to the output file. A JSON summary is written next to the output
    with the suffix ``.merge_summary.json``.

    Args:
        normal_input: Path of the normal input file.
        hard_input: Path of the hard input file.
        output: Output path; derived from ``tag`` when ``None``.
        tag: Tag used for the default output name and stored in the metadata.
        total_samples: Total rows when explicit counts are not given.
        hard_ratio: Fraction of hard rows when explicit counts are not given.
        normal_count: Explicit number of normal rows (requires ``hard_count``).
        hard_count: Explicit number of hard rows (requires ``normal_count``).
        seed: Seed of the random generator.
        normal_label: ``source_name`` value for normal rows.
        hard_label: ``source_name`` value for hard rows.
        batch_size: Rows written per batch; values below 1 are treated as 1.

    Returns:
        The merge metadata plus ``summary_path``.

    Raises:
        FileNotFoundError: If an input file does not exist.
        ValueError: If the inputs are incompatible or the counts are invalid.
    """
    tag = normalize_tag(tag)
    output_path = default_output_path(tag) if output is None else Path(output)
    rng = np.random.RandomState(int(seed))

    normal_input = Path(normal_input)
    hard_input = Path(hard_input)
    for role, candidate in (("normal_input", normal_input), ("hard_input", hard_input)):
        if not candidate.exists():
            raise FileNotFoundError(f"{role} not found: {candidate}")

    merge_meta: dict[str, Any]
    count_request = SimpleNamespace(
        normal_count=normal_count,
        hard_count=hard_count,
        total_samples=total_samples,
        hard_ratio=hard_ratio,
    )

    with h5py.File(normal_input, "r") as normal_file, h5py.File(hard_input, "r") as hard_file:
        # Pin the schema before writing anything, so incompatibilities surface
        # here instead of half-way through the copy.
        required_schemas = {}
        for name in REQUIRED_DATASETS:
            required_schemas[name] = _assert_dataset_compatible(name, normal_file, hard_file)
        optional_schemas = {}
        for name in OPTIONAL_DATASETS:
            optional_schemas[name] = _optional_dataset_schema(name, normal_file, hard_file)
        copied_attrs = {}
        for name in COPY_ATTRS:
            copied_attrs[name] = _assert_attr_compatible(name, normal_file, hard_file)

        n_normal_avail = int(normal_file["params"].shape[0])
        n_hard_avail = int(hard_file["params"].shape[0])
        n_normal, n_hard = resolve_counts(
            count_request,
            n_normal_avail=n_normal_avail,
            n_hard_avail=n_hard_avail,
        )

        # Plan column 0 is the source id and column 1 the row inside that
        # source; the plan is already shuffled, so the output is mixed.
        plan = build_sample_plan(
            rng=rng,
            normal_count=n_normal,
            hard_count=n_hard,
            n_normal_avail=n_normal_avail,
            n_hard_avail=n_hard_avail,
        )
        total_rows = int(plan.shape[0])
        from_normal = plan[:, 0] == 0
        from_hard = plan[:, 0] == 1
        normal_rows = plan[from_normal, 1]
        hard_rows = plan[from_hard, 1]

        requested_total = n_normal + n_hard
        merge_meta = {
            "tag": tag,
            "seed": int(seed),
            "normal_input": str(normal_input),
            "hard_input": str(hard_input),
            "output_path": str(output_path),
            "normal_label": str(normal_label),
            "hard_label": str(hard_label),
            "normal_count": int(n_normal),
            "hard_count": int(n_hard),
            "total_samples": int(requested_total),
            "hard_ratio_actual": float(n_hard / max(requested_total, 1)),
            "normal_family_counts": summarize_family_counts(normal_file, normal_rows),
            "hard_family_counts": summarize_family_counts(hard_file, hard_rows),
        }

        optional_names = tuple(name for name, schema in optional_schemas.items() if schema is not None)

        with create_output_file(
            output_path=output_path,
            total_rows=total_rows,
            required_schemas=required_schemas,
            optional_schemas=optional_schemas,
            copied_attrs=copied_attrs,
            merge_meta=merge_meta,
        ) as out_file:
            batch_size = max(1, int(batch_size))
            # Copy in batches so large inputs are never loaded into memory at once.
            for start in range(0, total_rows, batch_size):
                stop = min(start + batch_size, total_rows)
                write_batch(
                    out_file=out_file,
                    batch_plan=plan[start:stop],
                    normal_file=normal_file,
                    hard_file=hard_file,
                    normal_label=str(normal_label),
                    hard_label=str(hard_label),
                    start=start,
                    optional_names=optional_names,
                    optional_schemas=optional_schemas,
                )

    summary_path = output_path.with_suffix(".merge_summary.json")
    with open(summary_path, "w", encoding="utf-8") as summary_handle:
        json.dump(merge_meta, summary_handle, ensure_ascii=False, indent=2)

    return {**merge_meta, "summary_path": str(summary_path)}
|
|
|
|
|
|
|
|
|
def main() -> None:
    """Run the merge from the command line and print a short report.

    The parsed options are passed to ``merge_datasets``; the output path, the
    resolved counts and the summary path are then printed.
    """
    args = parse_args()
    forwarded = (
        "normal_input",
        "hard_input",
        "output",
        "tag",
        "total_samples",
        "hard_ratio",
        "normal_count",
        "hard_count",
        "seed",
        "normal_label",
        "hard_label",
        "batch_size",
    )
    merge_meta = merge_datasets(**{field: getattr(args, field) for field in forwarded})

    print(f"Merged dataset written to: {merge_meta['output_path']}")
    report = ", ".join(
        (
            f"normal_count={merge_meta['normal_count']}",
            f"hard_count={merge_meta['hard_count']}",
            f"total={merge_meta['total_samples']}",
            f"hard_ratio_actual={merge_meta['hard_ratio_actual']:.4f}",
        )
    )
    print(report)
    print(f"Merge summary written to: {merge_meta['summary_path']}")
|
|
|
|
|
|
|
|
|
# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|