You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/merge_datasets.py

520 lines
21 KiB
Python

"""合并常规数据集与困难样本数据集。
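The script checks that the core datasets, the file attributes and the optional
metadata of the two HDF5 files are compatible, samples rows either by a
user-specified ratio or by explicit counts, and then writes one mixed dataset
with a consistent structure. The workflow mixes globally random samples with
samples taken from the neighborhood of automatic fits, in order to improve the
surrogate model's coverage of regions that are hard for real inversions.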
"""
# pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,no-member
from __future__ import annotations
import argparse
import json
import math
import sys
from pathlib import Path
from types import SimpleNamespace
from typing import Any
import h5py
import numpy as np
# Directory one level above the folder that contains this script
# (``parents[1]`` of the resolved file path).
ROOT = Path(__file__).resolve().parents[1]
# Make ROOT importable. Presumably this is for project-local modules; no such
# import is visible in this part of the file -- TODO confirm it is still needed.
sys.path.append(str(ROOT))
# Datasets that must exist in BOTH inputs with identical trailing shape and
# dtype; they are validated by ``_assert_dataset_compatible``.
REQUIRED_DATASETS = ("params", "schedule", "curve")
# Datasets that may be absent from one or both inputs; rows coming from an
# input that lacks one of them are filled with a default value when merging.
OPTIONAL_DATASETS = (
    "group_id",
    "schedule_meta",
    "family_name",
    "is_anchor",
    "neighbor_objective",
    "neighbor_objective_p",
    "neighbor_objective_d",
    "neighbor_span_frac",
    "section_index",
    "timeQ_json",
    "q_json",
)
# File-level attributes that must agree between the inputs and are copied to
# the merged output file.
COPY_ATTRS = ("param_names", "schedule_meta_names")
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the dataset merge script.

    Args:
        argv: Optional list of arguments to parse instead of ``sys.argv[1:]``.
            The default ``None`` keeps the normal command-line behavior and
            allows other code (or tests) to drive the parser programmatically.

    Returns:
        Namespace with ``normal_input``, ``hard_input``, ``output``, ``tag``,
        ``total_samples``, ``hard_ratio``, ``normal_count``, ``hard_count``,
        ``seed``, ``normal_label``, ``hard_label`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser(description="Merge normal and hard raw HDF5 datasets into one mixed dataset")
    parser.add_argument("--normal-input", type=str, required=True, help="Path to the normal/main .h5 dataset")
    parser.add_argument("--hard-input", type=str, required=True, help="Path to the hard-targeted .h5 dataset")
    parser.add_argument("--output", type=str, default=None, help="Optional output .h5 path")
    parser.add_argument("--tag", type=str, default="family_random_mixed_50k", help="Tag used for default output naming")
    parser.add_argument("--total-samples", type=int, default=50000, help="Total rows to export")
    parser.add_argument("--hard-ratio", type=float, default=0.30, help="Hard-sample fraction when explicit counts are omitted")
    parser.add_argument("--normal-count", type=int, default=None, help="Explicit normal sample count")
    parser.add_argument("--hard-count", type=int, default=None, help="Explicit hard sample count")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--normal-label", type=str, default="normal", help="Value written to source_name for normal rows")
    parser.add_argument("--hard-label", type=str, default="hard", help="Value written to source_name for hard rows")
    parser.add_argument("--batch-size", type=int, default=4096, help="Output write batch size")
    # ``parse_args(None)`` falls back to ``sys.argv[1:]``, so existing callers
    # that pass nothing see no change.
    return parser.parse_args(argv)
def normalize_tag(tag: str | None) -> str:
    """Return a path-friendly version of an experiment tag.

    Surrounding whitespace is stripped and every remaining space or dash is
    replaced by an underscore. ``None`` and tags that end up empty fall back
    to ``"merged"``.
    """
    fallback = "merged"
    if tag is None:
        return fallback
    trimmed = str(tag).strip()
    # One translation pass maps both separators to underscores.
    normalized = trimmed.translate(str.maketrans({" ": "_", "-": "_"}))
    return normalized if normalized else fallback
def default_output_path(tag: str) -> Path:
    """Build the default relative output path ``data/samples/dataset_<tag>.h5``."""
    file_name = f"dataset_{tag}.h5"
    return Path("data", "samples", file_name)
def _decode_attr_list(raw_names: Any) -> list[str] | None:
"""解码 HDF5 属性中保存的 JSON 字符串或字节串列表。"""
if raw_names is None:
return None
return [
item.decode("utf-8") if isinstance(item, (bytes, np.bytes_)) else str(item)
for item in raw_names
]
def _read_string_rows(ds: h5py.Dataset, indices: np.ndarray) -> np.ndarray:
"""从 HDF5 字符串数据集中读取指定行,并统一转成 Python 字符串。"""
if indices.size == 0:
return np.asarray([], dtype=object)
indices = indices.astype(np.int64)
order = np.argsort(indices, kind="stable")
inverse = np.empty_like(order)
inverse[order] = np.arange(order.size)
values_sorted = np.asarray(ds[indices[order]]).astype(str)
return values_sorted[inverse]
def _read_numeric_rows(ds: h5py.Dataset, indices: np.ndarray) -> np.ndarray:
"""从 HDF5 数值数据集中读取指定行。"""
if indices.size == 0:
shape_tail = ds.shape[1:]
return np.empty((0, *shape_tail), dtype=ds.dtype) if shape_tail else np.empty((0,), dtype=ds.dtype)
indices = indices.astype(np.int64)
order = np.argsort(indices, kind="stable")
inverse = np.empty_like(order)
inverse[order] = np.arange(order.size)
values_sorted = np.asarray(ds[indices[order]])
return values_sorted[inverse]
def _assert_attr_compatible(name: str, lhs: h5py.File, rhs: h5py.File) -> list[str] | None:
    """Verify that both input files agree on a name-list attribute.

    Args:
        name: Attribute name to compare.
        lhs: First input file.
        rhs: Second input file.

    Returns:
        The decoded attribute value, or ``None`` when neither file has it.

    Raises:
        ValueError: If only one file has the attribute or the values differ.
    """
    left = _decode_attr_list(lhs.attrs.get(name))
    right = _decode_attr_list(rhs.attrs.get(name))
    absent = [value is None for value in (left, right)]
    if all(absent):
        return None
    if any(absent):
        raise ValueError(f"Attribute mismatch for {name}: one input has it and the other does not")
    if left == right:
        return left
    raise ValueError(f"Attribute mismatch for {name}: {left} != {right}")
def _assert_dataset_compatible(name: str, lhs: h5py.File, rhs: h5py.File) -> tuple[tuple[int, ...], np.dtype]:
"""检查多个输入文件同名数据集的维度和 dtype 是否可合并。"""
if name not in lhs or name not in rhs:
raise ValueError(f"Dataset mismatch for {name}: both inputs must contain it")
lhs_ds = lhs[name]
rhs_ds = rhs[name]
if lhs_ds.shape[1:] != rhs_ds.shape[1:]:
raise ValueError(f"Dataset shape mismatch for {name}: {lhs_ds.shape} vs {rhs_ds.shape}")
if lhs_ds.dtype != rhs_ds.dtype:
raise ValueError(f"Dataset dtype mismatch for {name}: {lhs_ds.dtype} vs {rhs_ds.dtype}")
return lhs_ds.shape[1:], lhs_ds.dtype
def _is_string_dtype(dtype: np.dtype) -> bool:
    """Tell whether a dataset dtype holds strings.

    A dtype counts as a string type when h5py reports string metadata for it,
    or when its NumPy kind is object (``"O"``), bytes (``"S"``) or unicode
    (``"U"``).
    """
    if h5py.check_string_dtype(dtype) is not None:
        return True
    return dtype.kind in ("O", "S", "U")
def _optional_dataset_schema(name: str, lhs: h5py.File, rhs: h5py.File) -> dict[str, Any] | None:
    """Describe an optional dataset so the output can mirror its layout.

    Args:
        name: Dataset name.
        lhs: The "normal" input file.
        rhs: The "hard" input file.

    Returns:
        ``None`` when neither input has the dataset. Otherwise a dict with
        ``shape_tail``, ``dtype``, ``is_string`` (taken from ``lhs`` when it
        has the dataset, else from ``rhs``) and the presence flags
        ``normal_has`` / ``hard_has``.

    Raises:
        ValueError: If both inputs have the dataset but the trailing shapes
            differ, one is a string type and the other is not, or two
            non-string dtypes differ.
    """
    has_l = name in lhs
    has_r = name in rhs
    if not (has_l or has_r):
        return None

    reference = lhs[name] if has_l else rhs[name]
    is_string = _is_string_dtype(reference.dtype)

    if has_l and has_r:
        lhs_ds, rhs_ds = lhs[name], rhs[name]
        if lhs_ds.shape[1:] != rhs_ds.shape[1:]:
            raise ValueError(f"Dataset shape mismatch for {name}: {lhs_ds.shape} vs {rhs_ds.shape}")
        # ``reference`` is ``lhs_ds`` here, so ``is_string`` already describes
        # the left side. String dtypes only need to agree on being strings;
        # numeric dtypes must match exactly.
        kinds_differ = is_string != _is_string_dtype(rhs_ds.dtype)
        numeric_differ = (not is_string) and lhs_ds.dtype != rhs_ds.dtype
        if kinds_differ or numeric_differ:
            raise ValueError(f"Dataset dtype mismatch for {name}: {lhs_ds.dtype} vs {rhs_ds.dtype}")

    schema: dict[str, Any] = {}
    schema["shape_tail"] = tuple(reference.shape[1:])
    schema["dtype"] = reference.dtype
    schema["is_string"] = bool(is_string)
    schema["normal_has"] = bool(has_l)
    schema["hard_has"] = bool(has_r)
    return schema
def _default_numeric_value(name: str) -> float | int:
"""为缺失的可选数值数据集生成合适的填充值。"""
if name in {"group_id", "section_index", "is_anchor"}:
return -1
return np.nan
def resolve_counts(args: argparse.Namespace, n_normal_avail: int, n_hard_avail: int) -> tuple[int, int]:
    """Work out how many rows to draw from each input.

    Explicit counts win when given (both are then required); otherwise the
    counts are derived from ``total_samples`` and ``hard_ratio``.

    Args:
        args: Object exposing ``normal_count``, ``hard_count``,
            ``total_samples`` and ``hard_ratio``.
        n_normal_avail: Rows available in the normal input.
        n_hard_avail: Rows available in the hard input.

    Returns:
        ``(normal_count, hard_count)``.

    Raises:
        ValueError: If only one explicit count is given, the total or ratio is
            out of range, a resolved count is not positive, or a count exceeds
            the rows available.
    """
    if args.normal_count is None and args.hard_count is None:
        # Ratio mode: split the requested total.
        total = int(args.total_samples)
        if total <= 0:
            raise ValueError("--total-samples must be positive")
        hard_ratio = float(args.hard_ratio)
        if hard_ratio < 0.0 or hard_ratio > 1.0:
            raise ValueError("--hard-ratio must be in [0, 1]")
        hard_count = int(round(total * hard_ratio))
        normal_count = total - hard_count
    elif args.normal_count is None or args.hard_count is None:
        raise ValueError("Both --normal-count and --hard-count are required when using explicit counts")
    else:
        # Explicit mode: take the counts as given.
        normal_count = int(args.normal_count)
        hard_count = int(args.hard_count)

    if min(normal_count, hard_count) <= 0:
        raise ValueError("Both normal_count and hard_count must be positive")
    if normal_count > n_normal_avail:
        raise ValueError(f"Requested normal_count={normal_count}, but only {n_normal_avail} rows are available")
    if hard_count > n_hard_avail:
        raise ValueError(f"Requested hard_count={hard_count}, but only {n_hard_avail} rows are available")
    return normal_count, hard_count
def create_output_file(
    output_path: Path,
    total_rows: int,
    required_schemas: dict[str, tuple[tuple[int, ...], np.dtype]],
    optional_schemas: dict[str, dict[str, Any] | None],
    copied_attrs: dict[str, list[str] | None],
    merge_meta: dict[str, Any],
) -> h5py.File:
    """Create the merged HDF5 output file with all datasets pre-allocated.

    Args:
        output_path: Destination file; parent directories are created and an
            existing file is overwritten (mode ``"w"``).
        total_rows: Number of rows every dataset is allocated with.
        required_schemas: ``name -> (trailing_shape, dtype)`` for the required
            datasets.
        optional_schemas: ``name -> schema dict`` for the optional datasets;
            ``None`` entries are skipped.
        copied_attrs: Name-list attributes to copy; ``None`` values are skipped.
        merge_meta: Merge description; must contain ``normal_label`` and
            ``hard_label`` and is stored as JSON in ``merge_meta_json``.

    Returns:
        The open, writable file. The caller is responsible for closing it
        (for example by using it as a context manager).

    Raises:
        Exception: Whatever h5py raises while creating datasets or attributes;
            the file handle is closed before the exception propagates.
    """
    output_path.parent.mkdir(parents=True, exist_ok=True)
    f = h5py.File(output_path, "w")
    try:
        # Required datasets are identical in both inputs, so they are created
        # directly at the final row count.
        for name, (shape_tail, dtype) in required_schemas.items():
            f.create_dataset(name, shape=(total_rows, *shape_tail), dtype=dtype)
        # Optional datasets may be missing from one input; rows from that input
        # keep the fill value (numeric) or an empty string.
        for name, schema in optional_schemas.items():
            if schema is None:
                continue
            shape_tail = tuple(schema["shape_tail"])
            dtype = schema["dtype"]
            if bool(schema["is_string"]):
                # String datasets are always written as 1-D UTF-8 columns.
                f.create_dataset(name, shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8"))
            else:
                f.create_dataset(
                    name,
                    shape=(total_rows, *shape_tail),
                    dtype=dtype,
                    fillvalue=_default_numeric_value(name),
                )
        # The source_* columns record which input file and row each output row
        # came from, for later error attribution and sampling audits.
        f.create_dataset("source_id", shape=(total_rows,), dtype=np.int8)
        f.create_dataset("source_name", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8"))
        f.create_dataset("source_row", shape=(total_rows,), dtype=np.int64)
        for attr_name, values in copied_attrs.items():
            if values is not None:
                f.attrs[attr_name] = np.asarray(values, dtype="S")
        f.attrs["n_samples"] = int(total_rows)
        f.attrs["source_name_vocab"] = np.asarray(
            [str(merge_meta["normal_label"]), str(merge_meta["hard_label"])],
            dtype="S",
        )
        f.attrs["merge_meta_json"] = json.dumps(merge_meta, ensure_ascii=False)
    except BaseException:
        # Do not leak the open handle when initialization fails part-way.
        f.close()
        raise
    return f
def build_sample_plan(
    rng: np.random.RandomState,
    normal_count: int,
    hard_count: int,
    n_normal_avail: int,
    n_hard_avail: int,
) -> np.ndarray:
    """Draw the rows to copy from each input and shuffle them together.

    Args:
        rng: Random state used for sampling and shuffling.
        normal_count: Rows to draw from the normal input.
        hard_count: Rows to draw from the hard input.
        n_normal_avail: Rows available in the normal input.
        n_hard_avail: Rows available in the hard input.

    Returns:
        ``int64`` array of shape ``(normal_count + hard_count, 2)``. Column 0
        is the source id (0 = normal, 1 = hard) and column 1 is the row number
        inside that source. Rows are sampled without replacement and the plan
        is shuffled row-wise.
    """
    # The order of the three rng calls is part of the reproducible result.
    picked_normal = rng.choice(n_normal_avail, size=normal_count, replace=False)
    picked_hard = rng.choice(n_hard_avail, size=hard_count, replace=False)

    plan = np.empty((normal_count + hard_count, 2), dtype=np.int64)
    plan[:, 0] = np.repeat(np.array([0, 1], dtype=np.int64), [normal_count, hard_count])
    plan[:, 1] = np.concatenate((picked_normal, picked_hard))
    rng.shuffle(plan)
    return plan
def summarize_family_counts(source_file: h5py.File, chosen_rows: np.ndarray) -> dict[str, int]:
    """Count the ``family_name`` values among the chosen rows of one input.

    Args:
        source_file: Input file to read ``family_name`` from.
        chosen_rows: Row numbers selected from that file.

    Returns:
        ``family name -> number of chosen rows``; an empty dict when the file
        has no ``family_name`` dataset.
    """
    if "family_name" in source_file:
        names = _read_string_rows(source_file["family_name"], chosen_rows)
        labels, tallies = np.unique(names, return_counts=True)
        summary: dict[str, int] = {}
        for label, tally in zip(labels.tolist(), tallies.tolist()):
            summary[str(label)] = int(tally)
        return summary
    return {}
def write_batch(
    out_file: h5py.File,
    batch_plan: np.ndarray,
    normal_file: h5py.File,
    hard_file: h5py.File,
    normal_label: str,
    hard_label: str,
    start: int,
    optional_names: tuple[str, ...],
    optional_schemas: dict[str, dict[str, Any] | None],
) -> None:
    """Copy one batch of planned rows from the inputs into the output file.

    Args:
        out_file: Open output file with all datasets already created.
        batch_plan: Slice of the sample plan; column 0 is the source id
            (0 = normal, 1 = hard) and column 1 the row inside that source.
        normal_file: Open normal input.
        hard_file: Open hard input.
        normal_label: Text written to ``source_name`` for normal rows.
        hard_label: Text written to ``source_name`` for hard rows.
        start: Output row at which this batch begins.
        optional_names: Optional datasets to copy.
        optional_schemas: Schema dict per optional dataset name.
    """
    n_rows = int(batch_plan.shape[0])
    stop = start + n_rows
    origin_ids = batch_plan[:, 0]
    origin_rows = batch_plan[:, 1]
    slots = np.arange(n_rows, dtype=np.int64)

    file_of = {0: normal_file, 1: hard_file}
    label_of = {0: normal_label, 1: hard_label}
    presence_key = {0: "normal_has", 1: "hard_has"}

    # Normal and hard rows are interleaved within a batch. Group them by source
    # once: (source id, positions inside the batch, rows inside the source).
    groups = []
    for origin in (0, 1):
        picked = origin_ids == origin
        if np.any(picked):
            groups.append((origin, slots[picked], origin_rows[picked]))

    for dataset_name in REQUIRED_DATASETS:
        target = out_file[dataset_name]
        buffer = np.empty((n_rows, *target.shape[1:]), dtype=target.dtype)
        for origin, where, rows in groups:
            buffer[where] = _read_numeric_rows(file_of[origin][dataset_name], rows)
        target[start:stop] = buffer

    for dataset_name in optional_names:
        schema = optional_schemas[dataset_name]
        if schema is None:
            continue
        # Only sources that actually contain this dataset contribute values.
        providers = [group for group in groups if bool(schema[presence_key[group[0]]])]
        if bool(schema["is_string"]):
            # String columns have no numeric fill value; rows from a source
            # without the dataset stay as empty strings.
            texts = np.full((n_rows,), "", dtype=object)
            for origin, where, rows in providers:
                texts[where] = _read_string_rows(file_of[origin][dataset_name], rows)
            out_file[dataset_name][start:stop] = texts.tolist()
        else:
            target = out_file[dataset_name]
            # Start from the default value and overwrite what the sources provide.
            buffer = np.full(
                (n_rows, *target.shape[1:]),
                _default_numeric_value(dataset_name),
                dtype=target.dtype,
            )
            for origin, where, rows in providers:
                buffer[where] = _read_numeric_rows(file_of[origin][dataset_name], rows)
            target[start:stop] = buffer

    out_file["source_id"][start:stop] = origin_ids.astype(np.int8)
    out_file["source_name"][start:stop] = [label_of[int(origin)] for origin in origin_ids]
    out_file["source_row"][start:stop] = origin_rows.astype(np.int64)
def merge_datasets(
    normal_input: str | Path,
    hard_input: str | Path,
    output: str | Path | None = None,
    tag: str | None = "family_random_mixed_50k",
    total_samples: int = 50000,
    hard_ratio: float = 0.30,
    normal_count: int | None = None,
    hard_count: int | None = None,
    seed: int = 42,
    normal_label: str = "normal",
    hard_label: str = "hard",
    batch_size: int = 4096,
) -> dict[str, Any]:
    """Merge a normal and a hard HDF5 dataset into one shuffled output file.

    Args:
        normal_input: Path of the normal/main input file.
        hard_input: Path of the hard-sample input file.
        output: Output path; when ``None`` it is derived from ``tag``.
        tag: Experiment tag, normalized before use.
        total_samples: Total rows to export when no explicit counts are given.
        hard_ratio: Fraction of hard rows when no explicit counts are given.
        normal_count: Explicit number of normal rows.
        hard_count: Explicit number of hard rows.
        seed: Seed of the random state used for sampling and shuffling.
        normal_label: ``source_name`` value for normal rows.
        hard_label: ``source_name`` value for hard rows.
        batch_size: Rows copied per write; values below 1 are raised to 1.

    Returns:
        The merge metadata (also written next to the output as
        ``*.merge_summary.json``) plus a ``summary_path`` entry.

    Raises:
        FileNotFoundError: If an input file does not exist.
        ValueError: If the inputs are incompatible or the requested counts
            are invalid.
    """
    tag = normalize_tag(tag)
    output_path = default_output_path(tag) if output is None else Path(output)
    rng = np.random.RandomState(int(seed))

    normal_input = Path(normal_input)
    hard_input = Path(hard_input)
    if not normal_input.exists():
        raise FileNotFoundError(f"normal_input not found: {normal_input}")
    if not hard_input.exists():
        raise FileNotFoundError(f"hard_input not found: {hard_input}")

    merge_meta: dict[str, Any]
    count_request = SimpleNamespace(
        normal_count=normal_count,
        hard_count=hard_count,
        total_samples=total_samples,
        hard_ratio=hard_ratio,
    )

    with h5py.File(normal_input, "r") as normal_file, h5py.File(hard_input, "r") as hard_file:
        # Validate every schema up front so incompatibilities surface before
        # anything is written.
        required_schemas = {}
        for name in REQUIRED_DATASETS:
            required_schemas[name] = _assert_dataset_compatible(name, normal_file, hard_file)
        optional_schemas = {}
        for name in OPTIONAL_DATASETS:
            optional_schemas[name] = _optional_dataset_schema(name, normal_file, hard_file)
        copied_attrs = {}
        for name in COPY_ATTRS:
            copied_attrs[name] = _assert_attr_compatible(name, normal_file, hard_file)

        n_normal_avail = int(normal_file["params"].shape[0])
        n_hard_avail = int(hard_file["params"].shape[0])
        n_normal, n_hard = resolve_counts(
            count_request,
            n_normal_avail=n_normal_avail,
            n_hard_avail=n_hard_avail,
        )

        # Plan column 0 is the source id, column 1 the row inside that source;
        # the plan is already shuffled, so the output rows are mixed.
        plan = build_sample_plan(
            rng=rng,
            normal_count=n_normal,
            hard_count=n_hard,
            n_normal_avail=n_normal_avail,
            n_hard_avail=n_hard_avail,
        )
        source_ids = plan[:, 0]
        normal_rows = plan[source_ids == 0, 1]
        hard_rows = plan[source_ids == 1, 1]
        n_selected = n_normal + n_hard

        merge_meta = {
            "tag": tag,
            "seed": int(seed),
            "normal_input": str(normal_input),
            "hard_input": str(hard_input),
            "output_path": str(output_path),
            "normal_label": str(normal_label),
            "hard_label": str(hard_label),
            "normal_count": int(n_normal),
            "hard_count": int(n_hard),
            "total_samples": int(n_selected),
            "hard_ratio_actual": float(n_hard / max(n_selected, 1)),
            "normal_family_counts": summarize_family_counts(normal_file, normal_rows),
            "hard_family_counts": summarize_family_counts(hard_file, hard_rows),
        }

        optional_names = tuple(
            name for name in optional_schemas if optional_schemas[name] is not None
        )
        total_rows = int(plan.shape[0])
        with create_output_file(
            output_path=output_path,
            total_rows=total_rows,
            required_schemas=required_schemas,
            optional_schemas=optional_schemas,
            copied_attrs=copied_attrs,
            merge_meta=merge_meta,
        ) as out_file:
            batch_size = max(1, int(batch_size))
            # Copy in batches so large inputs are never loaded into memory at once.
            for start in range(0, total_rows, batch_size):
                stop = min(start + batch_size, total_rows)
                write_batch(
                    out_file=out_file,
                    batch_plan=plan[start:stop],
                    normal_file=normal_file,
                    hard_file=hard_file,
                    normal_label=str(normal_label),
                    hard_label=str(hard_label),
                    start=start,
                    optional_names=optional_names,
                    optional_schemas=optional_schemas,
                )

    summary_path = output_path.with_suffix(".merge_summary.json")
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(merge_meta, f, ensure_ascii=False, indent=2)

    result = dict(merge_meta)
    result["summary_path"] = str(summary_path)
    return result
def main() -> None:
    """Command-line entry point: parse the options, run the merge, print a report."""
    args = parse_args()
    # Every CLI option maps one-to-one onto a ``merge_datasets`` keyword.
    option_names = (
        "normal_input",
        "hard_input",
        "output",
        "tag",
        "total_samples",
        "hard_ratio",
        "normal_count",
        "hard_count",
        "seed",
        "normal_label",
        "hard_label",
        "batch_size",
    )
    options = {option: getattr(args, option) for option in option_names}
    merge_meta = merge_datasets(**options)

    print(f"Merged dataset written to: {merge_meta['output_path']}")
    counts_line = (
        f"normal_count={merge_meta['normal_count']}, "
        f"hard_count={merge_meta['hard_count']}, "
        f"total={merge_meta['total_samples']}, "
        f"hard_ratio_actual={merge_meta['hard_ratio_actual']:.4f}"
    )
    print(counts_line)
    print(f"Merge summary written to: {merge_meta['summary_path']}")
# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()