"""合并常规数据集与困难样本数据集。 脚本校验两个 HDF5 的核心字段、属性和可选元数据是否兼容,按用户指定比例或数量抽样, 再写出一个结构一致的混合集。该流程用于把全局随机样本与自动拟合邻域样本混合, 提升代理模型在实际反演困难区域的覆盖度。 """ # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,no-member from __future__ import annotations import argparse import json import math import sys from pathlib import Path from types import SimpleNamespace from typing import Any import h5py import numpy as np ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) REQUIRED_DATASETS = ("params", "schedule", "curve") OPTIONAL_DATASETS = ( "group_id", "schedule_meta", "family_name", "is_anchor", "neighbor_objective", "neighbor_objective_p", "neighbor_objective_d", "neighbor_span_frac", "section_index", "timeQ_json", "q_json", ) COPY_ATTRS = ("param_names", "schedule_meta_names") def parse_args() -> argparse.Namespace: """解析两份 HDF5 数据集的抽样数量、混合比例、标签和输出路径。""" parser = argparse.ArgumentParser(description="Merge normal and hard raw HDF5 datasets into one mixed dataset") parser.add_argument("--normal-input", type=str, required=True, help="Path to the normal/main .h5 dataset") parser.add_argument("--hard-input", type=str, required=True, help="Path to the hard-targeted .h5 dataset") parser.add_argument("--output", type=str, default=None, help="Optional output .h5 path") parser.add_argument("--tag", type=str, default="family_random_mixed_50k", help="Tag used for default output naming") parser.add_argument("--total-samples", type=int, default=50000, help="Total rows to export") parser.add_argument("--hard-ratio", type=float, default=0.30, help="Hard-sample fraction when explicit counts are omitted") parser.add_argument("--normal-count", type=int, default=None, help="Explicit normal sample count") parser.add_argument("--hard-count", type=int, default=None, help="Explicit hard sample count") parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--normal-label", type=str, default="normal", help="Value written to source_name for normal rows") parser.add_argument("--hard-label", type=str, default="hard", help="Value written to source_name for hard rows") parser.add_argument("--batch-size", type=int, default=4096, help="Output write batch size") return parser.parse_args() def normalize_tag(tag: str | None) -> str: """规范化实验标签,避免空字符串、大小写或空白导致路径不一致。""" if tag is None: return "merged" cleaned = str(tag).strip().replace(" ", "_").replace("-", "_") return cleaned or "merged" def default_output_path(tag: str) -> Path: """根据合并标签生成默认 HDF5 输出文件路径。""" return Path("data") / "samples" / f"dataset_{tag}.h5" def _decode_attr_list(raw_names: Any) -> list[str] | None: """解码 HDF5 属性中保存的 JSON 字符串或字节串列表。""" if raw_names is None: return None return [ item.decode("utf-8") if isinstance(item, (bytes, np.bytes_)) else str(item) for item in raw_names ] def _read_string_rows(ds: h5py.Dataset, indices: np.ndarray) -> np.ndarray: """从 HDF5 字符串数据集中读取指定行,并统一转成 Python 字符串。""" if indices.size == 0: return np.asarray([], dtype=object) indices = indices.astype(np.int64) order = np.argsort(indices, kind="stable") inverse = np.empty_like(order) inverse[order] = np.arange(order.size) values_sorted = np.asarray(ds[indices[order]]).astype(str) return values_sorted[inverse] def _read_numeric_rows(ds: h5py.Dataset, indices: np.ndarray) -> np.ndarray: """从 HDF5 数值数据集中读取指定行。""" if indices.size == 0: shape_tail = ds.shape[1:] return np.empty((0, *shape_tail), dtype=ds.dtype) if shape_tail else np.empty((0,), dtype=ds.dtype) indices = indices.astype(np.int64) order = np.argsort(indices, kind="stable") inverse = np.empty_like(order) inverse[order] = np.arange(order.size) values_sorted = np.asarray(ds[indices[order]]) return values_sorted[inverse] def _assert_attr_compatible(name: str, lhs: h5py.File, rhs: h5py.File) -> list[str] | None: """检查多个输入文件的关键属性是否一致,避免合并不同 schema。""" left = _decode_attr_list(lhs.attrs.get(name)) right = _decode_attr_list(rhs.attrs.get(name)) if left is None and right is None: return None if left is None or right is None: raise ValueError(f"Attribute mismatch for {name}: one input has it and the other does not") if left != right: raise ValueError(f"Attribute mismatch for {name}: {left} != {right}") return left def _assert_dataset_compatible(name: str, lhs: h5py.File, rhs: h5py.File) -> tuple[tuple[int, ...], np.dtype]: """检查多个输入文件同名数据集的维度和 dtype 是否可合并。""" if name not in lhs or name not in rhs: raise ValueError(f"Dataset mismatch for {name}: both inputs must contain it") lhs_ds = lhs[name] rhs_ds = rhs[name] if lhs_ds.shape[1:] != rhs_ds.shape[1:]: raise ValueError(f"Dataset shape mismatch for {name}: {lhs_ds.shape} vs {rhs_ds.shape}") if lhs_ds.dtype != rhs_ds.dtype: raise ValueError(f"Dataset dtype mismatch for {name}: {lhs_ds.dtype} vs {rhs_ds.dtype}") return lhs_ds.shape[1:], lhs_ds.dtype def _is_string_dtype(dtype: np.dtype) -> bool: """判断 HDF5 数据集是否使用字符串 dtype。""" return h5py.check_string_dtype(dtype) is not None or dtype.kind in ("O", "S", "U") def _optional_dataset_schema(name: str, lhs: h5py.File, rhs: h5py.File) -> dict[str, Any] | None: """读取可选数据集的 dtype 和尾部形状,用于在输出文件中创建同 schema 数据集。""" has_l = name in lhs has_r = name in rhs if not has_l and not has_r: return None ref = lhs[name] if has_l else rhs[name] shape_tail = ref.shape[1:] dtype = ref.dtype is_string = _is_string_dtype(dtype) if has_l and has_r: lhs_ds = lhs[name] rhs_ds = rhs[name] if lhs_ds.shape[1:] != rhs_ds.shape[1:]: raise ValueError(f"Dataset shape mismatch for {name}: {lhs_ds.shape} vs {rhs_ds.shape}") if _is_string_dtype(lhs_ds.dtype) != _is_string_dtype(rhs_ds.dtype): raise ValueError(f"Dataset dtype mismatch for {name}: {lhs_ds.dtype} vs {rhs_ds.dtype}") if not is_string and lhs_ds.dtype != rhs_ds.dtype: raise ValueError(f"Dataset dtype mismatch for {name}: {lhs_ds.dtype} vs {rhs_ds.dtype}") return { "shape_tail": tuple(shape_tail), "dtype": dtype, "is_string": bool(is_string), "normal_has": bool(has_l), "hard_has": bool(has_r), } def _default_numeric_value(name: str) -> float | int: """为缺失的可选数值数据集生成合适的填充值。""" if name in {"group_id", "section_index", "is_anchor"}: return -1 return np.nan def resolve_counts(args: argparse.Namespace, n_normal_avail: int, n_hard_avail: int) -> tuple[int, int]: """根据每个输入文件样本量和用户给定比例/数量计算抽样数量。""" if args.normal_count is not None or args.hard_count is not None: if args.normal_count is None or args.hard_count is None: raise ValueError("Both --normal-count and --hard-count are required when using explicit counts") normal_count = int(args.normal_count) hard_count = int(args.hard_count) else: total = int(args.total_samples) if total <= 0: raise ValueError("--total-samples must be positive") hard_ratio = float(args.hard_ratio) if hard_ratio < 0.0 or hard_ratio > 1.0: raise ValueError("--hard-ratio must be in [0, 1]") hard_count = int(round(total * hard_ratio)) normal_count = total - hard_count if normal_count <= 0 or hard_count <= 0: raise ValueError("Both normal_count and hard_count must be positive") if normal_count > n_normal_avail: raise ValueError(f"Requested normal_count={normal_count}, but only {n_normal_avail} rows are available") if hard_count > n_hard_avail: raise ValueError(f"Requested hard_count={hard_count}, but only {n_hard_avail} rows are available") return normal_count, hard_count def create_output_file( output_path: Path, total_rows: int, required_schemas: dict[str, tuple[tuple[int, ...], np.dtype]], optional_schemas: dict[str, dict[str, Any] | None], copied_attrs: dict[str, list[str] | None], merge_meta: dict[str, Any], ) -> h5py.File: """创建邻域 HDF5 文件,并初始化锚点、候选、曲线和元数据数据集。""" output_path.parent.mkdir(parents=True, exist_ok=True) f = h5py.File(output_path, "w") # 必需数据集必须在两个来源中完全兼容,合并输出直接按最终行数一次性创建。 for name, (shape_tail, dtype) in required_schemas.items(): f.create_dataset(name, shape=(total_rows, *shape_tail), dtype=dtype) # 可选数据集允许某个来源缺失;缺失行会写入默认值或空字符串。 for name, schema in optional_schemas.items(): if schema is None: continue shape_tail = tuple(schema["shape_tail"]) dtype = schema["dtype"] if bool(schema["is_string"]): f.create_dataset(name, shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8")) else: f.create_dataset( name, shape=(total_rows, *shape_tail), dtype=dtype, fillvalue=_default_numeric_value(name), ) f.create_dataset("source_id", shape=(total_rows,), dtype=np.int8) f.create_dataset("source_name", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8")) f.create_dataset("source_row", shape=(total_rows,), dtype=np.int64) # source_* 字段记录每一行来自哪个输入文件,方便后续误差归因和抽样审计。 for attr_name, values in copied_attrs.items(): if values is not None: f.attrs[attr_name] = np.asarray(values, dtype="S") f.attrs["n_samples"] = int(total_rows) f.attrs["source_name_vocab"] = np.asarray( [str(merge_meta["normal_label"]), str(merge_meta["hard_label"])], dtype="S", ) f.attrs["merge_meta_json"] = json.dumps(merge_meta, ensure_ascii=False) return f def build_sample_plan( rng: np.random.RandomState, normal_count: int, hard_count: int, n_normal_avail: int, n_hard_avail: int, ) -> np.ndarray: """为每个输入文件生成需要复制的行索引和来源标签。""" normal_idx = rng.choice(n_normal_avail, size=normal_count, replace=False) hard_idx = rng.choice(n_hard_avail, size=hard_count, replace=False) plan = np.empty((normal_count + hard_count, 2), dtype=np.int64) plan[:normal_count, 0] = 0 plan[:normal_count, 1] = normal_idx plan[normal_count:, 0] = 1 plan[normal_count:, 1] = hard_idx rng.shuffle(plan) return plan def summarize_family_counts(source_file: h5py.File, chosen_rows: np.ndarray) -> dict[str, int]: """统计合并后各流量制度族的样本数量。""" if "family_name" not in source_file: return {} family_name = _read_string_rows(source_file["family_name"], chosen_rows) unique, counts = np.unique(family_name, return_counts=True) return {str(name): int(count) for name, count in zip(unique.tolist(), counts.tolist())} def write_batch( out_file: h5py.File, batch_plan: np.ndarray, normal_file: h5py.File, hard_file: h5py.File, normal_label: str, hard_label: str, start: int, optional_names: tuple[str, ...], optional_schemas: dict[str, dict[str, Any] | None], ) -> None: """把一批选中的源样本复制到合并输出文件。""" batch_size = int(batch_plan.shape[0]) source_col = batch_plan[:, 0] row_col = batch_plan[:, 1] row_positions = np.arange(batch_size, dtype=np.int64) sources = { 0: {"file": normal_file, "label": normal_label}, 1: {"file": hard_file, "label": hard_label}, } for dataset_name in REQUIRED_DATASETS: out_ds = out_file[dataset_name] shape_tail = out_ds.shape[1:] batch_arr = np.empty((batch_size, *shape_tail), dtype=out_ds.dtype) # 同一个 batch 中 normal/hard 行交错排列,先按来源分组读取,再放回原位置。 for source_id in (0, 1): mask = source_col == source_id if not np.any(mask): continue positions = row_positions[mask] rows = row_col[mask] batch_arr[positions] = _read_numeric_rows(sources[source_id]["file"][dataset_name], rows) out_ds[start : start + batch_size] = batch_arr for dataset_name in optional_names: schema = optional_schemas[dataset_name] if schema is None: continue if bool(schema["is_string"]): # 字符串可选字段没有数值 fillvalue,缺失来源统一填空字符串。 batch_arr = np.full((batch_size,), "", dtype=object) for source_id in (0, 1): mask = source_col == source_id if not np.any(mask): continue has_dataset = bool(schema["normal_has"] if source_id == 0 else schema["hard_has"]) if not has_dataset: continue positions = row_positions[mask] rows = row_col[mask] batch_arr[positions] = _read_string_rows(sources[source_id]["file"][dataset_name], rows) out_file[dataset_name][start : start + batch_size] = batch_arr.tolist() continue out_ds = out_file[dataset_name] shape_tail = out_ds.shape[1:] batch_arr = np.full( (batch_size, *shape_tail), _default_numeric_value(dataset_name), dtype=out_ds.dtype, ) # 数值可选字段先填默认值,只覆盖实际存在于来源文件中的部分。 for source_id in (0, 1): mask = source_col == source_id if not np.any(mask): continue has_dataset = bool(schema["normal_has"] if source_id == 0 else schema["hard_has"]) if not has_dataset: continue positions = row_positions[mask] rows = row_col[mask] batch_arr[positions] = _read_numeric_rows(sources[source_id]["file"][dataset_name], rows) out_ds[start : start + batch_size] = batch_arr out_file["source_id"][start : start + batch_size] = source_col.astype(np.int8) out_file["source_name"][start : start + batch_size] = [ sources[int(source_id)]["label"] for source_id in source_col ] out_file["source_row"][start : start + batch_size] = row_col.astype(np.int64) def merge_datasets( normal_input: str | Path, hard_input: str | Path, output: str | Path | None = None, tag: str | None = "family_random_mixed_50k", total_samples: int = 50000, hard_ratio: float = 0.30, normal_count: int | None = None, hard_count: int | None = None, seed: int = 42, normal_label: str = "normal", hard_label: str = "hard", batch_size: int = 4096, ) -> dict[str, Any]: """执行 HDF5 数据集合并,复制主数据集、可选数据集和元数据,并写出摘要。""" tag = normalize_tag(tag) output_path = Path(output) if output is not None else default_output_path(tag) rng = np.random.RandomState(int(seed)) normal_input = Path(normal_input) hard_input = Path(hard_input) if not normal_input.exists(): raise FileNotFoundError(f"normal_input not found: {normal_input}") if not hard_input.exists(): raise FileNotFoundError(f"hard_input not found: {hard_input}") merge_meta: dict[str, Any] count_args = SimpleNamespace( normal_count=normal_count, hard_count=hard_count, total_samples=total_samples, hard_ratio=hard_ratio, ) with h5py.File(normal_input, "r") as normal_file, h5py.File(hard_input, "r") as hard_file: # 合并前先锁定 schema,避免后面边写边发现字段不兼容。 required_schemas = { name: _assert_dataset_compatible(name, normal_file, hard_file) for name in REQUIRED_DATASETS } optional_schemas = { name: _optional_dataset_schema(name, normal_file, hard_file) for name in OPTIONAL_DATASETS } copied_attrs = { name: _assert_attr_compatible(name, normal_file, hard_file) for name in COPY_ATTRS } n_normal_avail = int(normal_file["params"].shape[0]) n_hard_avail = int(hard_file["params"].shape[0]) normal_count_resolved, hard_count_resolved = resolve_counts( count_args, n_normal_avail=n_normal_avail, n_hard_avail=n_hard_avail, ) # plan 的第 0 列是来源 id,第 1 列是来源文件中的行号;shuffle 后输出数据混排。 plan = build_sample_plan( rng=rng, normal_count=normal_count_resolved, hard_count=hard_count_resolved, n_normal_avail=n_normal_avail, n_hard_avail=n_hard_avail, ) normal_rows = plan[plan[:, 0] == 0, 1] hard_rows = plan[plan[:, 0] == 1, 1] merge_meta = { "tag": tag, "seed": int(seed), "normal_input": str(normal_input), "hard_input": str(hard_input), "output_path": str(output_path), "normal_label": str(normal_label), "hard_label": str(hard_label), "normal_count": int(normal_count_resolved), "hard_count": int(hard_count_resolved), "total_samples": int(normal_count_resolved + hard_count_resolved), "hard_ratio_actual": float(hard_count_resolved / max(normal_count_resolved + hard_count_resolved, 1)), "normal_family_counts": summarize_family_counts(normal_file, normal_rows), "hard_family_counts": summarize_family_counts(hard_file, hard_rows), } optional_names = tuple(name for name, schema in optional_schemas.items() if schema is not None) with create_output_file( output_path=output_path, total_rows=int(plan.shape[0]), required_schemas=required_schemas, optional_schemas=optional_schemas, copied_attrs=copied_attrs, merge_meta=merge_meta, ) as out_file: batch_size = max(1, int(batch_size)) n_batches = int(math.ceil(plan.shape[0] / batch_size)) for batch_idx in range(n_batches): # 分批复制可以避免一次性把大 HDF5 数据集全部读进内存。 start = batch_idx * batch_size end = min(start + batch_size, int(plan.shape[0])) write_batch( out_file=out_file, batch_plan=plan[start:end], normal_file=normal_file, hard_file=hard_file, normal_label=str(normal_label), hard_label=str(hard_label), start=start, optional_names=optional_names, optional_schemas=optional_schemas, ) summary_path = output_path.with_suffix(".merge_summary.json") with open(summary_path, "w", encoding="utf-8") as f: json.dump(merge_meta, f, ensure_ascii=False, indent=2) result = dict(merge_meta) result["summary_path"] = str(summary_path) return result def main() -> None: """按指定比例抽取 normal/hard 样本并合并为一份可追踪来源的 HDF5。""" args = parse_args() merge_meta = merge_datasets( normal_input=args.normal_input, hard_input=args.hard_input, output=args.output, tag=args.tag, total_samples=args.total_samples, hard_ratio=args.hard_ratio, normal_count=args.normal_count, hard_count=args.hard_count, seed=args.seed, normal_label=args.normal_label, hard_label=args.hard_label, batch_size=args.batch_size, ) print(f"Merged dataset written to: {merge_meta['output_path']}") print( f"normal_count={merge_meta['normal_count']}, " f"hard_count={merge_meta['hard_count']}, " f"total={merge_meta['total_samples']}, " f"hard_ratio_actual={merge_meta['hard_ratio_actual']:.4f}" ) print(f"Merge summary written to: {merge_meta['summary_path']}") if __name__ == "__main__": main()