"""Merge a "normal" and a "hard" raw HDF5 dataset into one shuffled dataset.

Rows are sampled without replacement from each input, interleaved in a
seeded random order, and written to a new HDF5 file together with
per-row provenance (``source_id``, ``source_name``, ``source_row``).
A JSON summary of the merge is written next to the output file.
"""

from __future__ import annotations

import argparse
import json
import math
import sys
from pathlib import Path
from typing import Any, Iterator

import h5py
import numpy as np

# Parent of this file's directory, appended to ``sys.path``
# (presumably so project packages import when run as a script).
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))

# Datasets that must exist in both inputs with matching row shape and dtype.
REQUIRED_DATASETS = ("params", "schedule", "curve")

# Datasets copied when at least one input has them; rows coming from an
# input that lacks the dataset receive a default fill value instead.
OPTIONAL_DATASETS = (
    "group_id",
    "schedule_meta",
    "family_name",
    "is_anchor",
    "neighbor_objective",
    "neighbor_objective_p",
    "neighbor_objective_d",
    "neighbor_span_frac",
    "section_index",
    "timeQ_json",
    "q_json",
)

# File-level attributes that must agree between inputs and are copied over.
COPY_ATTRS = ("param_names", "schedule_meta_names")


def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments of the merge script."""
    parser = argparse.ArgumentParser(
        description="Merge normal and hard raw HDF5 datasets into one mixed dataset"
    )
    parser.add_argument(
        "--normal-input", type=str, required=True,
        help="Path to the normal/main .h5 dataset",
    )
    parser.add_argument(
        "--hard-input", type=str, required=True,
        help="Path to the hard-targeted .h5 dataset",
    )
    parser.add_argument(
        "--output", type=str, default=None,
        help="Optional output .h5 path",
    )
    parser.add_argument(
        "--tag", type=str, default="family_random_mixed_50k",
        help="Tag used for default output naming",
    )
    parser.add_argument(
        "--total-samples", type=int, default=50000,
        help="Total rows to export",
    )
    parser.add_argument(
        "--hard-ratio", type=float, default=0.30,
        help="Hard-sample fraction when explicit counts are omitted",
    )
    parser.add_argument(
        "--normal-count", type=int, default=None,
        help="Explicit normal sample count",
    )
    parser.add_argument(
        "--hard-count", type=int, default=None,
        help="Explicit hard sample count",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--normal-label", type=str, default="normal",
        help="Value written to source_name for normal rows",
    )
    parser.add_argument(
        "--hard-label", type=str, default="hard",
        help="Value written to source_name for hard rows",
    )
    parser.add_argument(
        "--batch-size", type=int, default=4096,
        help="Output write batch size",
    )
    return parser.parse_args()


def normalize_tag(tag: str | None) -> str:
    """Return *tag* stripped, with spaces and dashes replaced by underscores.

    ``None`` and tags that are empty after stripping become ``"merged"``.
    """
    if tag is None:
        return "merged"
    cleaned = str(tag).strip().replace(" ", "_").replace("-", "_")
    return cleaned or "merged"


def default_output_path(tag: str) -> Path:
    """Return the default output location ``data/samples/dataset_<tag>.h5``."""
    return Path("data") / "samples" / f"dataset_{tag}.h5"


def _decode_text(item: Any) -> str:
    """Return *item* as ``str``, decoding byte strings as UTF-8."""
    if isinstance(item, (bytes, np.bytes_)):
        return item.decode("utf-8")
    return str(item)


def _decode_attr_list(raw_names: Any) -> list[str] | None:
    """Decode an iterable attribute value into a list of ``str``.

    Returns ``None`` when the attribute is absent (``raw_names is None``).
    """
    if raw_names is None:
        return None
    return [_decode_text(item) for item in raw_names]


def _sorted_unique_indices(indices: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Split *indices* into sorted unique row numbers plus an inverse map.

    h5py only accepts index lists in increasing order, so rows are read
    once in sorted order and ``values[inverse]`` restores the caller's
    order (and re-expands any repeated index).
    """
    flat = np.asarray(indices, dtype=np.int64).reshape(-1)
    unique_idx, inverse = np.unique(flat, return_inverse=True)
    return unique_idx, inverse.reshape(-1)


def _read_string_rows(ds: h5py.Dataset, indices: np.ndarray) -> np.ndarray:
    """Read the rows *indices* of a string dataset, in the order requested.

    Byte strings are decoded explicitly as UTF-8 instead of relying on
    ``astype(str)``, which is not a UTF-8 decode.

    Returns:
        A unicode array aligned with *indices*; an empty object array when
        *indices* is empty.
    """
    if indices.size == 0:
        return np.asarray([], dtype=object)
    unique_idx, inverse = _sorted_unique_indices(indices)
    raw = np.asarray(ds[unique_idx])
    decoded = np.asarray(
        [_decode_text(item) for item in raw.reshape(-1).tolist()],
        dtype=str,
    ).reshape(raw.shape)
    return decoded[inverse]


def _read_numeric_rows(ds: h5py.Dataset, indices: np.ndarray) -> np.ndarray:
    """Read the rows *indices* of a numeric dataset, in the order requested.

    Returns:
        An array of shape ``(len(indices), *ds.shape[1:])`` with the
        dataset's dtype.
    """
    if indices.size == 0:
        return np.empty((0, *ds.shape[1:]), dtype=ds.dtype)
    unique_idx, inverse = _sorted_unique_indices(indices)
    values = np.asarray(ds[unique_idx])
    return values[inverse]


def _assert_attr_compatible(name: str, lhs: h5py.File, rhs: h5py.File) -> list[str] | None:
    """Check that attribute *name* agrees between both files.

    Returns:
        The decoded attribute value, or ``None`` when neither file has it.

    Raises:
        ValueError: If only one file has the attribute or the values differ.
    """
    left = _decode_attr_list(lhs.attrs.get(name))
    right = _decode_attr_list(rhs.attrs.get(name))
    if left is None and right is None:
        return None
    if left is None or right is None:
        raise ValueError(f"Attribute mismatch for {name}: one input has it and the other does not")
    if left != right:
        raise ValueError(f"Attribute mismatch for {name}: {left} != {right}")
    return left


def _assert_dataset_compatible(
    name: str, lhs: h5py.File, rhs: h5py.File
) -> tuple[tuple[int, ...], np.dtype]:
    """Check that required dataset *name* is present and compatible in both files.

    Only the per-row shape (everything after axis 0) and the dtype are
    compared; the row counts may differ.

    Returns:
        ``(shape_tail, dtype)`` of the dataset.

    Raises:
        ValueError: If the dataset is missing from either file, or the
            per-row shape or dtype differ.
    """
    if name not in lhs or name not in rhs:
        raise ValueError(f"Dataset mismatch for {name}: both inputs must contain it")
    lhs_ds = lhs[name]
    rhs_ds = rhs[name]
    if lhs_ds.shape[1:] != rhs_ds.shape[1:]:
        raise ValueError(f"Dataset shape mismatch for {name}: {lhs_ds.shape} vs {rhs_ds.shape}")
    if lhs_ds.dtype != rhs_ds.dtype:
        raise ValueError(f"Dataset dtype mismatch for {name}: {lhs_ds.dtype} vs {rhs_ds.dtype}")
    return lhs_ds.shape[1:], lhs_ds.dtype


def _is_string_dtype(dtype: np.dtype) -> bool:
    """Return True for h5py string dtypes and NumPy kinds ``O``, ``S``, ``U``."""
    return h5py.check_string_dtype(dtype) is not None or dtype.kind in ("O", "S", "U")


def _optional_dataset_schema(name: str, lhs: h5py.File, rhs: h5py.File) -> dict[str, Any] | None:
    """Describe optional dataset *name* across both inputs.

    Returns:
        ``None`` when neither file has the dataset; otherwise a dict with
        keys ``shape_tail``, ``dtype``, ``is_string``, ``normal_has`` and
        ``hard_has``. Shape and dtype are taken from *lhs* when it has the
        dataset, else from *rhs*.

    Raises:
        ValueError: If both files have the dataset but disagree on per-row
            shape, on string-ness, or (for non-string data) on dtype.
    """
    has_l = name in lhs
    has_r = name in rhs
    if not has_l and not has_r:
        return None

    ref = lhs[name] if has_l else rhs[name]
    shape_tail = ref.shape[1:]
    dtype = ref.dtype
    is_string = _is_string_dtype(dtype)

    if has_l and has_r:
        lhs_ds = lhs[name]
        rhs_ds = rhs[name]
        if lhs_ds.shape[1:] != rhs_ds.shape[1:]:
            raise ValueError(f"Dataset shape mismatch for {name}: {lhs_ds.shape} vs {rhs_ds.shape}")
        if _is_string_dtype(lhs_ds.dtype) != _is_string_dtype(rhs_ds.dtype):
            raise ValueError(f"Dataset dtype mismatch for {name}: {lhs_ds.dtype} vs {rhs_ds.dtype}")
        # String datasets may differ in exact storage dtype; numeric ones may not.
        if not is_string and lhs_ds.dtype != rhs_ds.dtype:
            raise ValueError(f"Dataset dtype mismatch for {name}: {lhs_ds.dtype} vs {rhs_ds.dtype}")

    return {
        "shape_tail": tuple(shape_tail),
        "dtype": dtype,
        "is_string": bool(is_string),
        "normal_has": bool(has_l),
        "hard_has": bool(has_r),
    }


def _default_numeric_value(name: str) -> float | int:
    """Return the fill value for rows whose source lacks numeric dataset *name*.

    ``-1`` for ``group_id``, ``section_index`` and ``is_anchor``; NaN otherwise.
    """
    # NOTE(review): NaN only fits floating dtypes; an integer-typed optional
    # dataset outside this set would fail when filled - confirm none exist.
    if name in {"group_id", "section_index", "is_anchor"}:
        return -1
    return np.nan


def resolve_counts(args: argparse.Namespace, n_normal_avail: int, n_hard_avail: int) -> tuple[int, int]:
    """Work out how many rows to draw from each input.

    Explicit ``normal_count``/``hard_count`` take precedence (both must be
    given together); otherwise the split is derived from ``total_samples``
    and ``hard_ratio``.

    Args:
        args: Object exposing ``normal_count``, ``hard_count``,
            ``total_samples`` and ``hard_ratio``.
        n_normal_avail: Rows available in the normal input.
        n_hard_avail: Rows available in the hard input.

    Returns:
        ``(normal_count, hard_count)``.

    Raises:
        ValueError: If only one explicit count is given, the total is not
            positive, the ratio is outside ``[0, 1]``, either resulting
            count is not positive, or a count exceeds what is available.
    """
    if args.normal_count is not None or args.hard_count is not None:
        if args.normal_count is None or args.hard_count is None:
            raise ValueError("Both --normal-count and --hard-count are required when using explicit counts")
        normal_count = int(args.normal_count)
        hard_count = int(args.hard_count)
    else:
        total = int(args.total_samples)
        if total <= 0:
            raise ValueError("--total-samples must be positive")
        hard_ratio = float(args.hard_ratio)
        if hard_ratio < 0.0 or hard_ratio > 1.0:
            raise ValueError("--hard-ratio must be in [0, 1]")
        hard_count = int(round(total * hard_ratio))
        normal_count = total - hard_count

    if normal_count <= 0 or hard_count <= 0:
        raise ValueError("Both normal_count and hard_count must be positive")
    if normal_count > n_normal_avail:
        raise ValueError(f"Requested normal_count={normal_count}, but only {n_normal_avail} rows are available")
    if hard_count > n_hard_avail:
        raise ValueError(f"Requested hard_count={hard_count}, but only {n_hard_avail} rows are available")
    return normal_count, hard_count


def create_output_file(
    output_path: Path,
    total_rows: int,
    required_schemas: dict[str, tuple[tuple[int, ...], np.dtype]],
    optional_schemas: dict[str, dict[str, Any] | None],
    copied_attrs: dict[str, list[str] | None],
    merge_meta: dict[str, Any],
) -> h5py.File:
    """Create the output HDF5 file with all datasets and attributes allocated.

    The parent directory is created if needed and an existing file at
    *output_path* is overwritten. The caller owns the returned handle and
    must close it (it works as a context manager).

    Args:
        output_path: Destination ``.h5`` path.
        total_rows: Number of rows to allocate in every dataset.
        required_schemas: ``name -> (shape_tail, dtype)`` for required datasets.
        optional_schemas: ``name -> schema dict`` (or ``None`` to skip).
        copied_attrs: ``name -> list of str`` (or ``None`` to skip).
        merge_meta: Merge description; must contain ``normal_label`` and
            ``hard_label`` and be JSON-serialisable.

    Returns:
        The open, writable ``h5py.File``.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    f = h5py.File(output_path, "w")
    try:
        for name, (shape_tail, dtype) in required_schemas.items():
            f.create_dataset(name, shape=(total_rows, *shape_tail), dtype=dtype)

        for name, schema in optional_schemas.items():
            if schema is None:
                continue
            if bool(schema["is_string"]):
                f.create_dataset(
                    name,
                    shape=(total_rows,),
                    dtype=h5py.string_dtype(encoding="utf-8"),
                )
            else:
                f.create_dataset(
                    name,
                    shape=(total_rows, *tuple(schema["shape_tail"])),
                    dtype=schema["dtype"],
                    fillvalue=_default_numeric_value(name),
                )

        # Per-row provenance: which input a row came from and its row there.
        f.create_dataset("source_id", shape=(total_rows,), dtype=np.int8)
        f.create_dataset(
            "source_name",
            shape=(total_rows,),
            dtype=h5py.string_dtype(encoding="utf-8"),
        )
        f.create_dataset("source_row", shape=(total_rows,), dtype=np.int64)

        for attr_name, values in copied_attrs.items():
            if values is not None:
                f.attrs[attr_name] = np.asarray(values, dtype="S")
        f.attrs["n_samples"] = int(total_rows)
        f.attrs["source_name_vocab"] = np.asarray(
            [str(merge_meta["normal_label"]), str(merge_meta["hard_label"])],
            dtype="S",
        )
        f.attrs["merge_meta_json"] = json.dumps(merge_meta, ensure_ascii=False)
    except BaseException:
        # Do not leak the open handle if allocation fails part-way.
        f.close()
        raise
    return f


def build_sample_plan(
    rng: np.random.RandomState,
    normal_count: int,
    hard_count: int,
    n_normal_avail: int,
    n_hard_avail: int,
) -> np.ndarray:
    """Draw rows from both inputs without replacement and shuffle them.

    Returns:
        An ``int64`` array of shape ``(normal_count + hard_count, 2)``;
        column 0 is the source id (0 = normal, 1 = hard) and column 1 the
        row index within that source.
    """
    normal_idx = rng.choice(n_normal_avail, size=normal_count, replace=False)
    hard_idx = rng.choice(n_hard_avail, size=hard_count, replace=False)

    plan = np.empty((normal_count + hard_count, 2), dtype=np.int64)
    plan[:normal_count, 0] = 0
    plan[:normal_count, 1] = normal_idx
    plan[normal_count:, 0] = 1
    plan[normal_count:, 1] = hard_idx
    # In-place shuffle along axis 0, i.e. whole (source, row) pairs move together.
    rng.shuffle(plan)
    return plan


def summarize_family_counts(source_file: h5py.File, chosen_rows: np.ndarray) -> dict[str, int]:
    """Count ``family_name`` values among *chosen_rows* of *source_file*.

    Returns an empty dict when the file has no ``family_name`` dataset.
    """
    if "family_name" not in source_file:
        return {}
    family_name = _read_string_rows(source_file["family_name"], chosen_rows)
    unique, counts = np.unique(family_name, return_counts=True)
    return {str(name): int(count) for name, count in zip(unique.tolist(), counts.tolist())}


def _source_selections(batch_plan: np.ndarray) -> Iterator[tuple[int, np.ndarray, np.ndarray]]:
    """Yield ``(source_id, batch_positions, source_rows)`` per source in the batch.

    Sources with no rows in the batch are skipped.
    """
    source_col = batch_plan[:, 0]
    row_col = batch_plan[:, 1]
    for source_id in (0, 1):
        positions = np.flatnonzero(source_col == source_id)
        if positions.size:
            yield source_id, positions, row_col[positions]


def write_batch(
    out_file: h5py.File,
    batch_plan: np.ndarray,
    normal_file: h5py.File,
    hard_file: h5py.File,
    normal_label: str,
    hard_label: str,
    start: int,
    optional_names: tuple[str, ...],
    optional_schemas: dict[str, dict[str, Any] | None],
) -> None:
    """Copy one batch of planned rows into ``out_file[start:start + len(batch)]``.

    Args:
        out_file: Output file with all datasets already allocated.
        batch_plan: ``(n, 2)`` slice of the sample plan (source id, source row).
        normal_file: Input read for source id 0.
        hard_file: Input read for source id 1.
        normal_label: ``source_name`` value for source id 0.
        hard_label: ``source_name`` value for source id 1.
        start: Output row at which this batch begins.
        optional_names: Optional datasets to write.
        optional_schemas: Schema per optional dataset; ``None`` entries are skipped.
    """
    batch_size = int(batch_plan.shape[0])
    stop = start + batch_size
    source_col = batch_plan[:, 0]
    row_col = batch_plan[:, 1]
    files = {0: normal_file, 1: hard_file}
    labels = {0: normal_label, 1: hard_label}
    # The per-source split is the same for every dataset; compute it once.
    selections = list(_source_selections(batch_plan))

    for dataset_name in REQUIRED_DATASETS:
        out_ds = out_file[dataset_name]
        batch_arr = np.empty((batch_size, *out_ds.shape[1:]), dtype=out_ds.dtype)
        for source_id, positions, rows in selections:
            batch_arr[positions] = _read_numeric_rows(files[source_id][dataset_name], rows)
        out_ds[start:stop] = batch_arr

    for dataset_name in optional_names:
        schema = optional_schemas[dataset_name]
        if schema is None:
            continue
        available = {0: bool(schema["normal_has"]), 1: bool(schema["hard_has"])}
        is_string = bool(schema["is_string"])
        out_ds = out_file[dataset_name]

        # Start from the default so rows whose source lacks the dataset
        # keep "" (strings) or the numeric fill value.
        if is_string:
            batch_arr = np.full((batch_size,), "", dtype=object)
            read_rows = _read_string_rows
        else:
            batch_arr = np.full(
                (batch_size, *out_ds.shape[1:]),
                _default_numeric_value(dataset_name),
                dtype=out_ds.dtype,
            )
            read_rows = _read_numeric_rows

        for source_id, positions, rows in selections:
            if not available[source_id]:
                continue
            batch_arr[positions] = read_rows(files[source_id][dataset_name], rows)

        out_ds[start:stop] = batch_arr.tolist() if is_string else batch_arr

    out_file["source_id"][start:stop] = source_col.astype(np.int8)
    out_file["source_name"][start:stop] = [labels[int(source_id)] for source_id in source_col]
    out_file["source_row"][start:stop] = row_col.astype(np.int64)


def merge_datasets(
    normal_input: str | Path,
    hard_input: str | Path,
    output: str | Path | None = None,
    tag: str | None = "family_random_mixed_50k",
    total_samples: int = 50000,
    hard_ratio: float = 0.30,
    normal_count: int | None = None,
    hard_count: int | None = None,
    seed: int = 42,
    normal_label: str = "normal",
    hard_label: str = "hard",
    batch_size: int = 4096,
) -> dict[str, Any]:
    """Sample rows from two HDF5 datasets and write them as one shuffled dataset.

    Also writes ``<output stem>.merge_summary.json`` next to the output.

    Args:
        normal_input: Path to the normal dataset.
        hard_input: Path to the hard dataset.
        output: Output path; defaults to ``default_output_path(tag)``.
        tag: Name used for the default output path (normalised first).
        total_samples: Total rows when explicit counts are not given.
        hard_ratio: Fraction of hard rows when explicit counts are not given.
        normal_count: Explicit number of normal rows (requires *hard_count*).
        hard_count: Explicit number of hard rows (requires *normal_count*).
        seed: Seed for the ``RandomState`` that samples and shuffles.
        normal_label: ``source_name`` written for normal rows.
        hard_label: ``source_name`` written for hard rows.
        batch_size: Rows written per batch (values below 1 are treated as 1).

    Returns:
        The merge metadata dict, plus ``summary_path``.

    Raises:
        FileNotFoundError: If an input file does not exist.
        ValueError: If the output path is one of the inputs, the inputs are
            incompatible, or the requested counts are invalid.
    """
    tag = normalize_tag(tag)
    output_path = Path(output) if output is not None else default_output_path(tag)
    rng = np.random.RandomState(int(seed))

    normal_input = Path(normal_input)
    hard_input = Path(hard_input)
    if not normal_input.exists():
        raise FileNotFoundError(f"normal_input not found: {normal_input}")
    if not hard_input.exists():
        raise FileNotFoundError(f"hard_input not found: {hard_input}")
    # The output is opened in "w" mode; never let that target an input file.
    if output_path.resolve() in {normal_input.resolve(), hard_input.resolve()}:
        raise ValueError(f"output path must differ from the input files: {output_path}")

    count_args = argparse.Namespace(
        normal_count=normal_count,
        hard_count=hard_count,
        total_samples=total_samples,
        hard_ratio=hard_ratio,
    )

    with h5py.File(normal_input, "r") as normal_file, h5py.File(hard_input, "r") as hard_file:
        required_schemas = {
            name: _assert_dataset_compatible(name, normal_file, hard_file)
            for name in REQUIRED_DATASETS
        }
        optional_schemas = {
            name: _optional_dataset_schema(name, normal_file, hard_file)
            for name in OPTIONAL_DATASETS
        }
        copied_attrs = {
            name: _assert_attr_compatible(name, normal_file, hard_file)
            for name in COPY_ATTRS
        }

        # The row count of "params" defines how many rows each input offers.
        n_normal_avail = int(normal_file["params"].shape[0])
        n_hard_avail = int(hard_file["params"].shape[0])
        normal_count_resolved, hard_count_resolved = resolve_counts(
            count_args,
            n_normal_avail=n_normal_avail,
            n_hard_avail=n_hard_avail,
        )
        plan = build_sample_plan(
            rng=rng,
            normal_count=normal_count_resolved,
            hard_count=hard_count_resolved,
            n_normal_avail=n_normal_avail,
            n_hard_avail=n_hard_avail,
        )
        total_rows = int(plan.shape[0])
        normal_rows = plan[plan[:, 0] == 0, 1]
        hard_rows = plan[plan[:, 0] == 1, 1]

        merge_meta: dict[str, Any] = {
            "tag": tag,
            "seed": int(seed),
            "normal_input": str(normal_input),
            "hard_input": str(hard_input),
            "output_path": str(output_path),
            "normal_label": str(normal_label),
            "hard_label": str(hard_label),
            "normal_count": int(normal_count_resolved),
            "hard_count": int(hard_count_resolved),
            "total_samples": int(normal_count_resolved + hard_count_resolved),
            "hard_ratio_actual": float(
                hard_count_resolved / max(normal_count_resolved + hard_count_resolved, 1)
            ),
            "normal_family_counts": summarize_family_counts(normal_file, normal_rows),
            "hard_family_counts": summarize_family_counts(hard_file, hard_rows),
        }

        optional_names = tuple(
            name for name, schema in optional_schemas.items() if schema is not None
        )
        with create_output_file(
            output_path=output_path,
            total_rows=total_rows,
            required_schemas=required_schemas,
            optional_schemas=optional_schemas,
            copied_attrs=copied_attrs,
            merge_meta=merge_meta,
        ) as out_file:
            batch_size = max(1, int(batch_size))
            n_batches = int(math.ceil(total_rows / batch_size))
            for batch_idx in range(n_batches):
                start = batch_idx * batch_size
                end = min(start + batch_size, total_rows)
                write_batch(
                    out_file=out_file,
                    batch_plan=plan[start:end],
                    normal_file=normal_file,
                    hard_file=hard_file,
                    normal_label=str(normal_label),
                    hard_label=str(hard_label),
                    start=start,
                    optional_names=optional_names,
                    optional_schemas=optional_schemas,
                )

    summary_path = output_path.with_suffix(".merge_summary.json")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(merge_meta, f, ensure_ascii=False, indent=2)

    result = dict(merge_meta)
    result["summary_path"] = str(summary_path)
    return result


def main() -> None:
    """Command-line entry point: parse arguments, merge, and print a summary."""
    args = parse_args()
    merge_meta = merge_datasets(
        normal_input=args.normal_input,
        hard_input=args.hard_input,
        output=args.output,
        tag=args.tag,
        total_samples=args.total_samples,
        hard_ratio=args.hard_ratio,
        normal_count=args.normal_count,
        hard_count=args.hard_count,
        seed=args.seed,
        normal_label=args.normal_label,
        hard_label=args.hard_label,
        batch_size=args.batch_size,
    )
    print(f"Merged dataset written to: {merge_meta['output_path']}")
    print(
        f"normal_count={merge_meta['normal_count']}, "
        f"hard_count={merge_meta['hard_count']}, "
        f"total={merge_meta['total_samples']}, "
        f"hard_ratio_actual={merge_meta['hard_ratio_actual']:.4f}"
    )
    print(f"Merge summary written to: {merge_meta['summary_path']}")


if __name__ == "__main__":
    main()