You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/merge_datasets.py

485 lines
18 KiB
Python

from __future__ import annotations
import argparse
import json
import math
import sys
from pathlib import Path
from typing import Any
import h5py
import numpy as np
# Directory one level above this script's folder; appended to sys.path,
# presumably so project-local packages import when run as a script (TODO confirm).
ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT))

# Datasets both inputs must contain with identical trailing shape and dtype.
REQUIRED_DATASETS = ("params", "schedule", "curve")

# Datasets copied when at least one input has them; rows coming from an input
# that lacks one receive a filler value instead.
OPTIONAL_DATASETS = (
    "group_id",
    "schedule_meta",
    "family_name",
    "is_anchor",
    "neighbor_objective",
    "neighbor_objective_p",
    "neighbor_objective_d",
    "neighbor_span_frac",
    "section_index",
    "timeQ_json",
    "q_json",
)

# File-level list attributes that must agree across inputs and are copied over.
COPY_ATTRS = ("param_names", "schedule_meta_names")
def parse_args() -> argparse.Namespace:
    """Parse the command line for the dataset merge.

    Returns:
        Namespace with ``normal_input``, ``hard_input``, ``output``, ``tag``,
        ``total_samples``, ``hard_ratio``, ``normal_count``, ``hard_count``,
        ``seed``, ``normal_label``, ``hard_label`` and ``batch_size``.
    """
    cli = argparse.ArgumentParser(
        description="Merge normal and hard raw HDF5 datasets into one mixed dataset"
    )
    # (flag, keyword options) pairs, registered in this exact order so --help is stable.
    option_specs = (
        ("--normal-input", {"type": str, "required": True, "help": "Path to the normal/main .h5 dataset"}),
        ("--hard-input", {"type": str, "required": True, "help": "Path to the hard-targeted .h5 dataset"}),
        ("--output", {"type": str, "default": None, "help": "Optional output .h5 path"}),
        ("--tag", {"type": str, "default": "family_random_mixed_50k", "help": "Tag used for default output naming"}),
        ("--total-samples", {"type": int, "default": 50000, "help": "Total rows to export"}),
        ("--hard-ratio", {"type": float, "default": 0.30, "help": "Hard-sample fraction when explicit counts are omitted"}),
        ("--normal-count", {"type": int, "default": None, "help": "Explicit normal sample count"}),
        ("--hard-count", {"type": int, "default": None, "help": "Explicit hard sample count"}),
        ("--seed", {"type": int, "default": 42, "help": "Random seed"}),
        ("--normal-label", {"type": str, "default": "normal", "help": "Value written to source_name for normal rows"}),
        ("--hard-label", {"type": str, "default": "hard", "help": "Value written to source_name for hard rows"}),
        ("--batch-size", {"type": int, "default": 4096, "help": "Output write batch size"}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def normalize_tag(tag: str | None) -> str:
    """Turn a free-form tag into a filename-friendly token.

    Surrounding whitespace is stripped, then spaces and hyphens become
    underscores. ``None`` or a tag that is empty after stripping yields
    ``"merged"``.
    """
    fallback = "merged"
    if tag is None:
        return fallback
    token = str(tag).strip()
    for separator in (" ", "-"):
        token = token.replace(separator, "_")
    return token if token else fallback
def default_output_path(tag: str) -> Path:
    """Return the conventional output file ``data/samples/dataset_<tag>.h5``.

    The path is relative, so it resolves against the current working directory.
    """
    file_name = f"dataset_{tag}.h5"
    return Path("data", "samples", file_name)
def _decode_attr_list(raw_names: Any) -> list[str] | None:
    """Decode an HDF5 attribute holding names into a list of Python strings.

    Args:
        raw_names: Attribute value as returned by ``attrs.get`` - usually an
            array/list of ``bytes`` or ``str``; a single scalar string (``str``,
            ``bytes``, numpy scalar or 0-d array) is treated as one name.

    Returns:
        The decoded names (bytes are decoded as UTF-8), or ``None`` when the
        attribute is absent.
    """
    if raw_names is None:
        return None
    # Iterating a scalar string would split it into characters (or, for bytes,
    # into integer byte values), so wrap scalars as a one-element list.
    if np.ndim(raw_names) == 0:
        if isinstance(raw_names, np.ndarray):
            raw_names = raw_names.item()
        raw_names = [raw_names]
    return [
        item.decode("utf-8") if isinstance(item, (bytes, np.bytes_)) else str(item)
        for item in raw_names
    ]
def _read_string_rows(ds: h5py.Dataset, indices: np.ndarray) -> np.ndarray:
    """Read rows of a 1-D string dataset, returned in the caller's index order.

    HDF5 fancy indexing only accepts strictly increasing indices, so each
    distinct row is fetched once in sorted order and then scattered back to
    the requested order (duplicates allowed).

    Args:
        ds: String dataset (fixed-length bytes or variable-length bytes/str).
        indices: Integer row indices in any order.

    Returns:
        ``str`` array aligned with ``indices``; an empty ``object`` array when
        ``indices`` is empty. Byte values are decoded as UTF-8.
    """
    indices = np.asarray(indices, dtype=np.int64).reshape(-1)
    if indices.size == 0:
        return np.asarray([], dtype=object)
    unique_rows, inverse = np.unique(indices, return_inverse=True)
    raw = np.asarray(ds[unique_rows])
    # Decode explicitly: numpy's astype(str) does not decode bytes as UTF-8,
    # so non-ASCII text would fail or be mangled.
    decoded = np.asarray(
        [
            item.decode("utf-8") if isinstance(item, (bytes, np.bytes_)) else str(item)
            for item in raw
        ],
        dtype=str,
    )
    return decoded[inverse.reshape(-1)]
def _read_numeric_rows(ds: h5py.Dataset, indices: np.ndarray) -> np.ndarray:
    """Read rows of a dataset, returned in the caller's index order.

    HDF5 fancy indexing only accepts strictly increasing indices, so each
    distinct row is fetched once in sorted order and then scattered back to
    the requested order (duplicates allowed).

    Args:
        ds: Dataset whose first axis indexes rows.
        indices: Integer row indices in any order.

    Returns:
        Array of shape ``(len(indices), *ds.shape[1:])`` with ``ds.dtype``.
    """
    indices = np.asarray(indices, dtype=np.int64).reshape(-1)
    if indices.size == 0:
        return np.empty((0, *ds.shape[1:]), dtype=ds.dtype)
    unique_rows, inverse = np.unique(indices, return_inverse=True)
    values = np.asarray(ds[unique_rows])
    return values[inverse.reshape(-1)]
def _assert_attr_compatible(name: str, lhs: h5py.File, rhs: h5py.File) -> list[str] | None:
    """Check that list attribute ``name`` agrees between two files and return it.

    Returns:
        The decoded list shared by both files, or ``None`` when neither file
        defines the attribute.

    Raises:
        ValueError: exactly one file defines it, or the decoded lists differ.
    """
    left, right = (_decode_attr_list(h5.attrs.get(name)) for h5 in (lhs, rhs))
    left_missing = left is None
    right_missing = right is None
    if left_missing and right_missing:
        return None
    if left_missing != right_missing:
        raise ValueError(f"Attribute mismatch for {name}: one input has it and the other does not")
    if left != right:
        raise ValueError(f"Attribute mismatch for {name}: {left} != {right}")
    return left
def _assert_dataset_compatible(name: str, lhs: h5py.File, rhs: h5py.File) -> tuple[tuple[int, ...], np.dtype]:
    """Validate that required dataset ``name`` has the same per-row schema in both files.

    Only the trailing dimensions are compared; row counts may differ.

    Returns:
        ``(shape_tail, dtype)`` taken from ``lhs``.

    Raises:
        ValueError: the dataset is missing from either file, or the trailing
            shapes or dtypes differ.
    """
    if not (name in lhs and name in rhs):
        raise ValueError(f"Dataset mismatch for {name}: both inputs must contain it")
    left_ds, right_ds = lhs[name], rhs[name]
    tail = left_ds.shape[1:]
    if tail != right_ds.shape[1:]:
        raise ValueError(f"Dataset shape mismatch for {name}: {left_ds.shape} vs {right_ds.shape}")
    if left_ds.dtype != right_ds.dtype:
        raise ValueError(f"Dataset dtype mismatch for {name}: {left_ds.dtype} vs {right_ds.dtype}")
    return tail, left_ds.dtype
def _is_string_dtype(dtype: np.dtype) -> bool:
    """Return True when ``dtype`` should be treated as text.

    Counts h5py string dtypes (``h5py.check_string_dtype`` returns info) and
    any numpy object (``O``), bytes (``S``) or unicode (``U``) dtype.
    """
    if h5py.check_string_dtype(dtype) is not None:
        return True
    # NOTE(review): every object dtype is classified as text here, including
    # h5py variable-length numeric dtypes if any input ever uses them - confirm.
    return dtype.kind in {"O", "S", "U"}
def _optional_dataset_schema(name: str, lhs: h5py.File, rhs: h5py.File) -> dict[str, Any] | None:
    """Work out how optional dataset ``name`` should be laid out in the merged file.

    ``lhs`` is the normal input and ``rhs`` the hard input (see the
    ``normal_has`` / ``hard_has`` keys).

    Returns:
        ``None`` when neither input has the dataset. Otherwise a dict with
        ``shape_tail`` (trailing dims), ``dtype`` and ``is_string`` taken from
        ``lhs`` when present (else ``rhs``), plus ``normal_has`` / ``hard_has``.

    Raises:
        ValueError: both inputs have the dataset but their trailing shapes
            differ, only one of them is a string dtype, or (non-string data)
            their dtypes differ.
    """
    in_normal, in_hard = name in lhs, name in rhs
    if not in_normal and not in_hard:
        return None
    reference = lhs[name] if in_normal else rhs[name]
    tail = tuple(reference.shape[1:])
    ref_dtype = reference.dtype
    ref_is_string = bool(_is_string_dtype(ref_dtype))
    if in_normal and in_hard:
        left_ds, right_ds = lhs[name], rhs[name]
        if left_ds.shape[1:] != right_ds.shape[1:]:
            raise ValueError(f"Dataset shape mismatch for {name}: {left_ds.shape} vs {right_ds.shape}")
        # String-ness must agree; exact dtypes only matter for numeric data,
        # since text is re-encoded into one output string type.
        if _is_string_dtype(left_ds.dtype) != _is_string_dtype(right_ds.dtype):
            raise ValueError(f"Dataset dtype mismatch for {name}: {left_ds.dtype} vs {right_ds.dtype}")
        if not ref_is_string and left_ds.dtype != right_ds.dtype:
            raise ValueError(f"Dataset dtype mismatch for {name}: {left_ds.dtype} vs {right_ds.dtype}")
    return {
        "shape_tail": tail,
        "dtype": ref_dtype,
        "is_string": ref_is_string,
        "normal_has": bool(in_normal),
        "hard_has": bool(in_hard),
    }
def _default_numeric_value(name: str) -> float | int:
    """Return the filler used for numeric dataset ``name`` when a row has no source data.

    ``group_id``, ``section_index`` and ``is_anchor`` use ``-1``; every other
    name uses NaN. NaN is only representable in floating-point dtypes.
    """
    integer_coded = {"group_id", "section_index", "is_anchor"}
    return -1 if name in integer_coded else np.nan
def resolve_counts(args: argparse.Namespace, n_normal_avail: int, n_hard_avail: int) -> tuple[int, int]:
    """Decide how many normal and hard rows to sample.

    Explicit counts (``normal_count`` and ``hard_count``, both required once
    either is given) take precedence. Otherwise ``total_samples`` is split by
    ``hard_ratio``; the hard share is rounded with Python's ``round``.

    Args:
        args: Object exposing ``normal_count``, ``hard_count``,
            ``total_samples`` and ``hard_ratio``.
        n_normal_avail: Rows available in the normal input.
        n_hard_avail: Rows available in the hard input.

    Returns:
        ``(normal_count, hard_count)``.

    Raises:
        ValueError: only one explicit count given, non-positive total,
            ratio outside [0, 1] (including NaN), a non-positive resulting
            count, or more rows requested than available.
    """
    if args.normal_count is not None or args.hard_count is not None:
        if args.normal_count is None or args.hard_count is None:
            raise ValueError("Both --normal-count and --hard-count are required when using explicit counts")
        normal_count = int(args.normal_count)
        hard_count = int(args.hard_count)
    else:
        total = int(args.total_samples)
        if total <= 0:
            raise ValueError("--total-samples must be positive")
        hard_ratio = float(args.hard_ratio)
        # The chained comparison is False for NaN, so NaN is rejected here
        # instead of failing later inside round() with an unrelated message.
        if not 0.0 <= hard_ratio <= 1.0:
            raise ValueError("--hard-ratio must be in [0, 1]")
        hard_count = int(round(total * hard_ratio))
        normal_count = total - hard_count
    if normal_count <= 0 or hard_count <= 0:
        raise ValueError(
            "Both normal_count and hard_count must be positive "
            f"(got normal_count={normal_count}, hard_count={hard_count})"
        )
    if normal_count > n_normal_avail:
        raise ValueError(f"Requested normal_count={normal_count}, but only {n_normal_avail} rows are available")
    if hard_count > n_hard_avail:
        raise ValueError(f"Requested hard_count={hard_count}, but only {n_hard_avail} rows are available")
    return normal_count, hard_count
def _utf8_string_array(values: Any) -> np.ndarray:
    """Return ``values`` as a fixed-length bytes (``S``) array of UTF-8 text.

    ``np.asarray(str_values, dtype="S")`` only accepts ASCII text, so each
    value is encoded first; ASCII input yields the same array as before.
    """
    return np.asarray([str(value).encode("utf-8") for value in values], dtype="S")


def create_output_file(
    output_path: Path,
    total_rows: int,
    required_schemas: dict[str, tuple[tuple[int, ...], np.dtype]],
    optional_schemas: dict[str, dict[str, Any] | None],
    copied_attrs: dict[str, list[str] | None],
    merge_meta: dict[str, Any],
) -> h5py.File:
    """Create (truncating if present) the merged HDF5 file with all datasets pre-allocated.

    Args:
        output_path: Destination file; its parent directories are created.
        total_rows: Number of rows every dataset is allocated with.
        required_schemas: ``name -> (shape_tail, dtype)`` for required datasets.
        optional_schemas: Schemas from ``_optional_dataset_schema``; ``None``
            entries are skipped. String datasets become 1-D UTF-8 strings;
            numeric ones use ``_default_numeric_value(name)`` as fill value.
        copied_attrs: List attributes to copy (``None`` values are skipped).
        merge_meta: Metadata dict; must contain ``normal_label`` and
            ``hard_label`` and is stored as JSON in ``merge_meta_json``.

    Returns:
        The open, writable file. The caller owns it and must close it.

    Raises:
        Whatever h5py raises during setup; the handle is closed first.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    f = h5py.File(output_path, "w")
    try:
        for name, (shape_tail, dtype) in required_schemas.items():
            f.create_dataset(name, shape=(total_rows, *shape_tail), dtype=dtype)
        for name, schema in optional_schemas.items():
            if schema is None:
                continue
            if bool(schema["is_string"]):
                # NOTE(review): string datasets are always created 1-D; any trailing
                # dims of the source string dataset are not preserved - confirm.
                f.create_dataset(name, shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8"))
            else:
                # NOTE(review): the fill value is NaN for most names, which presumably
                # cannot be stored in an integer dtype - confirm such datasets are float.
                f.create_dataset(
                    name,
                    shape=(total_rows, *tuple(schema["shape_tail"])),
                    dtype=schema["dtype"],
                    fillvalue=_default_numeric_value(name),
                )
        f.create_dataset("source_id", shape=(total_rows,), dtype=np.int8)
        f.create_dataset("source_name", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8"))
        f.create_dataset("source_row", shape=(total_rows,), dtype=np.int64)
        for attr_name, values in copied_attrs.items():
            if values is not None:
                f.attrs[attr_name] = _utf8_string_array(values)
        f.attrs["n_samples"] = int(total_rows)
        f.attrs["source_name_vocab"] = _utf8_string_array(
            [merge_meta["normal_label"], merge_meta["hard_label"]]
        )
        f.attrs["merge_meta_json"] = json.dumps(merge_meta, ensure_ascii=False)
    except BaseException:
        # Do not leak the open handle when any dataset/attribute setup fails.
        f.close()
        raise
    return f
def build_sample_plan(
    rng: np.random.RandomState,
    normal_count: int,
    hard_count: int,
    n_normal_avail: int,
    n_hard_avail: int,
) -> np.ndarray:
    """Sample rows without replacement from each input and interleave them randomly.

    Returns:
        ``int64`` array of shape ``(normal_count + hard_count, 2)``. Column 0
        is the source id (0 = normal, 1 = hard) and column 1 the row index in
        that source. Row order is shuffled with ``rng`` after sampling.
    """
    picks_normal = rng.choice(n_normal_avail, size=normal_count, replace=False)
    picks_hard = rng.choice(n_hard_avail, size=hard_count, replace=False)
    source_ids = np.concatenate(
        [np.zeros(normal_count, dtype=np.int64), np.ones(hard_count, dtype=np.int64)]
    )
    row_ids = np.concatenate([picks_normal, picks_hard]).astype(np.int64)
    plan = np.column_stack((source_ids, row_ids))
    rng.shuffle(plan)
    return plan
def summarize_family_counts(source_file: h5py.File, chosen_rows: np.ndarray) -> dict[str, int]:
    """Count how many of ``chosen_rows`` fall in each ``family_name``.

    Returns:
        ``{family_name: count}`` with keys in ``np.unique`` (sorted) order, or
        an empty dict when the file has no ``family_name`` dataset.
    """
    if "family_name" in source_file:
        names = _read_string_rows(source_file["family_name"], chosen_rows)
        labels, tallies = np.unique(names, return_counts=True)
        return dict(zip(map(str, labels.tolist()), map(int, tallies.tolist())))
    return {}
def _split_batch_by_source(batch_plan: np.ndarray) -> list[tuple[int, np.ndarray, np.ndarray]]:
    """Group a batch plan by source id.

    Returns one ``(source_id, positions, rows)`` tuple for each of source ids
    0 and 1 that occurs in the batch: ``positions`` are offsets within the
    batch, ``rows`` the matching row indices in that source file.
    """
    source_col = batch_plan[:, 0]
    row_col = batch_plan[:, 1]
    groups = []
    for source_id in (0, 1):
        positions = np.flatnonzero(source_col == source_id)
        if positions.size:
            groups.append((source_id, positions, row_col[positions]))
    return groups


def write_batch(
    out_file: h5py.File,
    batch_plan: np.ndarray,
    normal_file: h5py.File,
    hard_file: h5py.File,
    normal_label: str,
    hard_label: str,
    start: int,
    optional_names: tuple[str, ...],
    optional_schemas: dict[str, dict[str, Any] | None],
) -> None:
    """Copy one batch of planned rows into ``out_file[start:start + len(batch_plan)]``.

    Args:
        out_file: Output file with all datasets pre-allocated.
        batch_plan: Slice of the sample plan; column 0 is the source id
            (0 = normal, 1 = hard), column 1 the row index in that source.
        normal_file: Open input for source id 0.
        hard_file: Open input for source id 1.
        normal_label: Value written to ``source_name`` for source id 0.
        hard_label: Value written to ``source_name`` for source id 1.
        start: First output row of this batch.
        optional_names: Optional datasets to copy.
        optional_schemas: Per-dataset schema dicts; ``None`` entries are skipped.

    Rows whose source lacks an optional dataset keep a filler: ``""`` for
    string datasets, ``_default_numeric_value(name)`` for numeric ones.
    """
    batch_size = int(batch_plan.shape[0])
    stop = start + batch_size
    files = {0: normal_file, 1: hard_file}
    labels = {0: normal_label, 1: hard_label}
    # The source split is the same for every dataset, so compute it once per batch.
    groups = _split_batch_by_source(batch_plan)

    for dataset_name in REQUIRED_DATASETS:
        out_ds = out_file[dataset_name]
        # np.empty is fine assuming every source id is 0 or 1 (as produced by
        # build_sample_plan), so each position is filled by exactly one group.
        batch_arr = np.empty((batch_size, *out_ds.shape[1:]), dtype=out_ds.dtype)
        for source_id, positions, rows in groups:
            batch_arr[positions] = _read_numeric_rows(files[source_id][dataset_name], rows)
        out_ds[start:stop] = batch_arr

    for dataset_name in optional_names:
        schema = optional_schemas[dataset_name]
        if schema is None:
            continue
        present = {0: bool(schema["normal_has"]), 1: bool(schema["hard_has"])}
        out_ds = out_file[dataset_name]
        is_string = bool(schema["is_string"])
        if is_string:
            batch_arr = np.full((batch_size,), "", dtype=object)
            read_rows = _read_string_rows
        else:
            batch_arr = np.full(
                (batch_size, *out_ds.shape[1:]),
                _default_numeric_value(dataset_name),
                dtype=out_ds.dtype,
            )
            read_rows = _read_numeric_rows
        for source_id, positions, rows in groups:
            if present[source_id]:
                batch_arr[positions] = read_rows(files[source_id][dataset_name], rows)
        # String columns are written as a list of Python str objects.
        out_ds[start:stop] = batch_arr.tolist() if is_string else batch_arr

    source_col = batch_plan[:, 0]
    out_file["source_id"][start:stop] = source_col.astype(np.int8)
    out_file["source_name"][start:stop] = [labels[int(source_id)] for source_id in source_col]
    out_file["source_row"][start:stop] = batch_plan[:, 1].astype(np.int64)
def merge_datasets(
    normal_input: str | Path,
    hard_input: str | Path,
    output: str | Path | None = None,
    tag: str | None = "family_random_mixed_50k",
    total_samples: int = 50000,
    hard_ratio: float = 0.30,
    normal_count: int | None = None,
    hard_count: int | None = None,
    seed: int = 42,
    normal_label: str = "normal",
    hard_label: str = "hard",
    batch_size: int = 4096,
) -> dict[str, Any]:
    """Sample rows from a normal and a hard HDF5 dataset into one shuffled file.

    Args:
        normal_input: Path to the normal input file (source id 0).
        hard_input: Path to the hard input file (source id 1).
        output: Output path; defaults to ``default_output_path(normalized tag)``.
        tag: Free-form tag, normalized with ``normalize_tag``.
        total_samples: Rows to export when explicit counts are omitted.
        hard_ratio: Hard fraction when explicit counts are omitted.
        normal_count: Explicit normal row count (requires ``hard_count``).
        hard_count: Explicit hard row count (requires ``normal_count``).
        seed: Seed for ``np.random.RandomState`` driving sampling and shuffling.
        normal_label: ``source_name`` value for normal rows.
        hard_label: ``source_name`` value for hard rows.
        batch_size: Rows per write batch (clamped to at least 1).

    Returns:
        The merge metadata (also stored in the ``merge_meta_json`` attribute and
        a JSON sidecar written at the output path with its suffix replaced by
        ``.merge_summary.json``) plus ``summary_path``.

    Raises:
        FileNotFoundError: an input path does not exist.
        ValueError: the output path is one of the inputs, input schemas or
            attributes disagree, or the requested counts are invalid.
    """
    tag = normalize_tag(tag)
    output_path = Path(output) if output is not None else default_output_path(tag)
    rng = np.random.RandomState(int(seed))
    normal_input = Path(normal_input)
    hard_input = Path(hard_input)
    if not normal_input.exists():
        raise FileNotFoundError(f"normal_input not found: {normal_input}")
    if not hard_input.exists():
        raise FileNotFoundError(f"hard_input not found: {hard_input}")
    # The output is opened with h5py mode "w" (truncate) while both inputs are
    # open for reading, so it must never alias an input file.
    if output_path.resolve() in (normal_input.resolve(), hard_input.resolve()):
        raise ValueError(f"output path must differ from both inputs: {output_path}")
    # resolve_counts only reads these four attributes.
    count_args = argparse.Namespace(
        normal_count=normal_count,
        hard_count=hard_count,
        total_samples=total_samples,
        hard_ratio=hard_ratio,
    )
    with h5py.File(normal_input, "r") as normal_file, h5py.File(hard_input, "r") as hard_file:
        required_schemas = {
            name: _assert_dataset_compatible(name, normal_file, hard_file)
            for name in REQUIRED_DATASETS
        }
        optional_schemas = {
            name: _optional_dataset_schema(name, normal_file, hard_file)
            for name in OPTIONAL_DATASETS
        }
        copied_attrs = {
            name: _assert_attr_compatible(name, normal_file, hard_file)
            for name in COPY_ATTRS
        }
        n_normal_avail = int(normal_file["params"].shape[0])
        n_hard_avail = int(hard_file["params"].shape[0])
        normal_count_resolved, hard_count_resolved = resolve_counts(
            count_args,
            n_normal_avail=n_normal_avail,
            n_hard_avail=n_hard_avail,
        )
        plan = build_sample_plan(
            rng=rng,
            normal_count=normal_count_resolved,
            hard_count=hard_count_resolved,
            n_normal_avail=n_normal_avail,
            n_hard_avail=n_hard_avail,
        )
        total_rows = int(plan.shape[0])
        normal_rows = plan[plan[:, 0] == 0, 1]
        hard_rows = plan[plan[:, 0] == 1, 1]
        merge_meta: dict[str, Any] = {
            "tag": tag,
            "seed": int(seed),
            "normal_input": str(normal_input),
            "hard_input": str(hard_input),
            "output_path": str(output_path),
            "normal_label": str(normal_label),
            "hard_label": str(hard_label),
            "normal_count": int(normal_count_resolved),
            "hard_count": int(hard_count_resolved),
            "total_samples": int(normal_count_resolved + hard_count_resolved),
            "hard_ratio_actual": float(hard_count_resolved / max(normal_count_resolved + hard_count_resolved, 1)),
            "normal_family_counts": summarize_family_counts(normal_file, normal_rows),
            "hard_family_counts": summarize_family_counts(hard_file, hard_rows),
        }
        optional_names = tuple(name for name, schema in optional_schemas.items() if schema is not None)
        batch_size = max(1, int(batch_size))
        with create_output_file(
            output_path=output_path,
            total_rows=total_rows,
            required_schemas=required_schemas,
            optional_schemas=optional_schemas,
            copied_attrs=copied_attrs,
            merge_meta=merge_meta,
        ) as out_file:
            for start in range(0, total_rows, batch_size):
                write_batch(
                    out_file=out_file,
                    batch_plan=plan[start : start + batch_size],
                    normal_file=normal_file,
                    hard_file=hard_file,
                    normal_label=str(normal_label),
                    hard_label=str(hard_label),
                    start=start,
                    optional_names=optional_names,
                    optional_schemas=optional_schemas,
                )
    summary_path = output_path.with_suffix(".merge_summary.json")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(merge_meta, f, ensure_ascii=False, indent=2)
    result = dict(merge_meta)
    result["summary_path"] = str(summary_path)
    return result
def main() -> None:
    """Command-line entry point: parse flags, run the merge, report the outputs."""
    cli_args = parse_args()
    summary = merge_datasets(
        normal_input=cli_args.normal_input,
        hard_input=cli_args.hard_input,
        output=cli_args.output,
        tag=cli_args.tag,
        total_samples=cli_args.total_samples,
        hard_ratio=cli_args.hard_ratio,
        normal_count=cli_args.normal_count,
        hard_count=cli_args.hard_count,
        seed=cli_args.seed,
        normal_label=cli_args.normal_label,
        hard_label=cli_args.hard_label,
        batch_size=cli_args.batch_size,
    )
    print(f"Merged dataset written to: {summary['output_path']}")
    count_fields = (
        f"normal_count={summary['normal_count']}",
        f"hard_count={summary['hard_count']}",
        f"total={summary['total_samples']}",
        f"hard_ratio_actual={summary['hard_ratio_actual']:.4f}",
    )
    print(", ".join(count_fields))
    print(f"Merge summary written to: {summary['summary_path']}")


if __name__ == "__main__":
    main()