You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/flatten_autofit_neighborhoo...

216 lines
9.4 KiB
Python

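"""Flatten an autofit-neighborhood HDF5 file into an ordinary per-sample dataset.

Neighborhood data is usually stored grouped by anchor and neighbor, which is
convenient for ranking-style training. This script flattens those grouped
samples into the conventional `params/schedule/curve` layout so that the
existing preprocessing, evaluation, and merging pipelines can be reused.
"""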
# pylint: disable=too-many-locals,too-many-statements
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import h5py
import numpy as np
# ROOT is the parent of the directory that contains this script
# (``parents[0]`` is the script's own directory, ``parents[1]`` one level up).
ROOT = Path(__file__).resolve().parents[1]
# Append ROOT to the import search path.
# NOTE(review): presumably so project-local packages under ROOT are importable;
# nothing in the visible part of this file imports from it - confirm before removing.
sys.path.append(str(ROOT))
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the flatten script.

    Options:
        --input: required path of the autofit neighborhood ``.h5`` file.
        --output: required path of the flat sample-level ``.h5`` file.
        --neighbors-only: flag; when set, only neighbor rows are exported.

    Returns:
        The populated ``argparse.Namespace`` read from ``sys.argv``.
    """
    summary = (
        "Flatten anchor-neighborhood autofit HDF5 into the standard "
        "sample-level HDF5 format"
    )
    cli = argparse.ArgumentParser(description=summary)

    # Both path options share the same shape: a required string value.
    required_paths = (
        ("--input", "Input autofit neighborhood .h5"),
        ("--output", "Output flat sample-level .h5"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        "--neighbors-only",
        action="store_true",
        help="Export only neighbor rows and skip anchor rows",
    )
    return cli.parse_args()
def _decode_attr_list(raw_names) -> list[str] | None:
    """Decode a list of names stored in an HDF5 attribute into ``list[str]``.

    Supported encodings of the attribute value:

    * ``None`` (attribute absent) -> ``None``.
    * An iterable/array of ``bytes`` or ``str`` items -> each item decoded
      (UTF-8 for bytes, ``str()`` otherwise).
    * A single ``str``/``bytes`` scalar (or a zero-dimensional array wrapping
      one). If the text is a JSON list, its items are returned as strings;
      otherwise the text itself is returned as a one-element list.

    Args:
        raw_names: Value as returned by ``attrs.get(...)``.

    Returns:
        The decoded names, or ``None`` when ``raw_names`` is ``None``.
    """
    if raw_names is None:
        return None

    # A zero-dimensional array cannot be iterated; unwrap it to its scalar.
    if isinstance(raw_names, np.ndarray) and raw_names.ndim == 0:
        raw_names = raw_names.item()

    # A scalar string must not be iterated element-wise: that would yield
    # single characters (str) or integer byte values (bytes).
    if isinstance(raw_names, (bytes, str)):
        text = raw_names.decode("utf-8") if isinstance(raw_names, bytes) else str(raw_names)
        try:
            parsed = json.loads(text)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            parsed = None
        if isinstance(parsed, list):
            return [str(item) for item in parsed]
        return [text]

    return [
        item.decode("utf-8") if isinstance(item, (bytes, np.bytes_)) else str(item)
        for item in raw_names
    ]
def _load_json_string_array(values: np.ndarray) -> list[list[float]]:
    """Parse an array of JSON-encoded number lists into lists of floats.

    Args:
        values: Array-like whose entries are JSON texts, each encoding a
            sequence of numbers.

    Returns:
        One list of ``float`` per input entry, in input order.
    """
    json_texts = np.asarray(values).astype(str).tolist()
    return [
        [float(number) for number in json.loads(text)]
        for text in json_texts
    ]
def _read_utf8_strings(dataset) -> np.ndarray:
    """Read a string dataset and return its entries as a NumPy ``str`` array.

    Bytes entries are decoded explicitly as UTF-8. A plain ``.astype(str)``
    decodes bytes as ASCII and therefore fails on non-ASCII text.
    """
    raw = np.asarray(dataset[:])
    decoded = [
        item.decode("utf-8") if isinstance(item, bytes) else str(item)
        for item in raw.reshape(-1).tolist()
    ]
    return np.asarray(decoded, dtype=str).reshape(raw.shape)


def _require_rows(name: str, array: np.ndarray, expected: int) -> None:
    """Raise ``ValueError`` unless ``array`` has exactly ``expected`` rows."""
    if array.ndim == 0 or array.shape[0] != expected:
        raise ValueError(
            f"Dataset '{name}' has shape {array.shape}; expected {expected} rows"
        )


def main() -> None:
    """Flatten an anchor/neighbor autofit HDF5 file into a sample-level HDF5 file.

    Reads every ``anchor_*`` and ``neighbor_*`` dataset from the input file,
    validates them, and writes one row per sample to the output file. Anchor
    rows (unless ``--neighbors-only`` is given) come first, followed by the
    neighbor rows. Each neighbor row copies schedule, schedule metadata,
    family name, section index and the two JSON columns from the anchor that
    ``neighbor_anchor_id`` points to.

    Raises:
        FileNotFoundError: The input path does not exist.
        KeyError: A required dataset is missing from the input file.
        ValueError: The input has no anchors or no neighbors, a dataset has an
            inconsistent row count, or ``neighbor_anchor_id`` contains an id
            outside ``[0, n_anchors)``.
    """
    args = parse_args()
    input_path = Path(args.input).resolve()
    output_path = Path(args.output).resolve()
    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")

    with h5py.File(input_path, "r") as src:
        # Anchors hold the target samples themselves; neighbors hold the
        # candidate parameters and curves associated with an anchor.
        anchor_params = np.asarray(src["anchor_params"][:], dtype=np.float32)
        anchor_schedule = np.asarray(src["anchor_schedule"][:], dtype=np.float32)
        anchor_curve = np.asarray(src["anchor_curve"][:], dtype=np.float32)
        anchor_schedule_meta = np.asarray(src["anchor_schedule_meta"][:], dtype=np.float32)
        anchor_family_name = _read_utf8_strings(src["anchor_family_name"])
        anchor_section_index = np.asarray(src["anchor_section_index"][:], dtype=np.int32)
        anchor_time_q_json = _read_utf8_strings(src["anchor_time_q_json"])
        anchor_q_json = _read_utf8_strings(src["anchor_q_json"])
        neighbor_anchor_id = np.asarray(src["neighbor_anchor_id"][:], dtype=np.int32)
        neighbor_params = np.asarray(src["neighbor_params"][:], dtype=np.float32)
        neighbor_curve = np.asarray(src["neighbor_curve"][:], dtype=np.float32)
        neighbor_objective = np.asarray(src["neighbor_objective"][:], dtype=np.float32)
        neighbor_objective_p = np.asarray(src["neighbor_objective_p"][:], dtype=np.float32)
        neighbor_objective_d = np.asarray(src["neighbor_objective_d"][:], dtype=np.float32)
        # "neighbor_span_frac" is optional; fall back to NaN for every neighbor.
        neighbor_span_frac = (
            np.asarray(src["neighbor_span_frac"][:], dtype=np.float32)
            if "neighbor_span_frac" in src
            else np.full((len(neighbor_anchor_id),), np.nan, dtype=np.float32)
        )
        schedule_meta_names = _decode_attr_list(src.attrs.get("schedule_meta_names"))
        span_fracs = (
            np.asarray(
                src.attrs.get("span_fracs", []),
                dtype=np.float32,
            )
            .reshape(-1)
            .tolist()
        )

    n_anchors = int(anchor_params.shape[0])
    n_neighbors = int(neighbor_params.shape[0])
    if n_anchors == 0:
        raise ValueError("Input neighborhood dataset contains no anchors")
    if n_neighbors == 0:
        raise ValueError("Input neighborhood dataset contains no neighbors")

    # Validate everything BEFORE the output file is created, so a bad input
    # never leaves a partially written file behind.
    per_anchor = {
        "anchor_schedule": anchor_schedule,
        "anchor_curve": anchor_curve,
        "anchor_schedule_meta": anchor_schedule_meta,
        "anchor_family_name": anchor_family_name,
        "anchor_section_index": anchor_section_index,
        "anchor_time_q_json": anchor_time_q_json,
        "anchor_q_json": anchor_q_json,
    }
    for name, array in per_anchor.items():
        _require_rows(name, array, n_anchors)
    per_neighbor = {
        "neighbor_anchor_id": neighbor_anchor_id,
        "neighbor_curve": neighbor_curve,
        "neighbor_objective": neighbor_objective,
        "neighbor_objective_p": neighbor_objective_p,
        "neighbor_objective_d": neighbor_objective_d,
        "neighbor_span_frac": neighbor_span_frac,
    }
    for name, array in per_neighbor.items():
        _require_rows(name, array, n_neighbors)

    # NumPy indexing wraps negative indices silently, which would attach a
    # neighbor to the wrong anchor; reject any id outside [0, n_anchors).
    min_id = int(neighbor_anchor_id.min())
    max_id = int(neighbor_anchor_id.max())
    if min_id < 0 or max_id >= n_anchors:
        raise ValueError(
            "neighbor_anchor_id values must lie in "
            f"[0, {n_anchors}); found range [{min_id}, {max_id}]"
        )

    if args.neighbors_only:
        total_rows = n_neighbors
    else:
        total_rows = n_anchors + n_neighbors

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with h5py.File(output_path, "w") as dst:
        if schedule_meta_names is not None:
            dst.attrs["schedule_meta_names"] = np.asarray(schedule_meta_names, dtype="S")
        if span_fracs:
            dst.attrs["span_fracs"] = np.asarray(span_fracs, dtype=np.float32)

        param_dim = int(anchor_params.shape[1])
        schedule_dim = int(anchor_schedule.shape[1])
        curve_dim = int(anchor_curve.shape[1])
        schedule_meta_dim = int(anchor_schedule_meta.shape[1])
        utf8 = h5py.string_dtype(encoding="utf-8")

        dst.create_dataset("params", shape=(total_rows, param_dim), dtype=np.float32)
        dst.create_dataset("schedule", shape=(total_rows, schedule_dim), dtype=np.float32)
        dst.create_dataset("curve", shape=(total_rows, curve_dim), dtype=np.float32)
        dst.create_dataset("group_id", shape=(total_rows,), dtype=np.int64)
        dst.create_dataset("schedule_meta", shape=(total_rows, schedule_meta_dim), dtype=np.float32)
        dst.create_dataset("family_name", shape=(total_rows,), dtype=utf8)
        dst.create_dataset("is_anchor", shape=(total_rows,), dtype=np.int8)
        dst.create_dataset("neighbor_objective", shape=(total_rows,), dtype=np.float32)
        dst.create_dataset("neighbor_objective_p", shape=(total_rows,), dtype=np.float32)
        dst.create_dataset("neighbor_objective_d", shape=(total_rows,), dtype=np.float32)
        dst.create_dataset("neighbor_span_frac", shape=(total_rows,), dtype=np.float32)
        dst.create_dataset("section_index", shape=(total_rows,), dtype=np.int32)
        dst.create_dataset("timeQ_json", shape=(total_rows,), dtype=utf8)
        dst.create_dataset("q_json", shape=(total_rows,), dtype=utf8)

        write_pos = 0
        if not args.neighbors_only:
            # Anchors are written as standalone samples: group_id is the
            # anchor's own index and the objective columns are set to 0.
            anchor_end = write_pos + n_anchors
            dst["params"][write_pos:anchor_end] = anchor_params
            dst["schedule"][write_pos:anchor_end] = anchor_schedule
            dst["curve"][write_pos:anchor_end] = anchor_curve
            dst["group_id"][write_pos:anchor_end] = np.arange(n_anchors, dtype=np.int64)
            dst["schedule_meta"][write_pos:anchor_end] = anchor_schedule_meta
            dst["family_name"][write_pos:anchor_end] = anchor_family_name.tolist()
            dst["is_anchor"][write_pos:anchor_end] = 1
            dst["neighbor_objective"][write_pos:anchor_end] = 0.0
            dst["neighbor_objective_p"][write_pos:anchor_end] = 0.0
            dst["neighbor_objective_d"][write_pos:anchor_end] = 0.0
            dst["neighbor_span_frac"][write_pos:anchor_end] = 0.0
            dst["section_index"][write_pos:anchor_end] = anchor_section_index
            dst["timeQ_json"][write_pos:anchor_end] = anchor_time_q_json.tolist()
            dst["q_json"][write_pos:anchor_end] = anchor_q_json.tolist()
            write_pos = anchor_end

        neighbor_end = write_pos + n_neighbors
        # Neighbors inherit schedule, grouping and metadata from their anchor;
        # only params, curve and the objective columns are their own.
        dst["params"][write_pos:neighbor_end] = neighbor_params
        dst["schedule"][write_pos:neighbor_end] = anchor_schedule[neighbor_anchor_id]
        dst["curve"][write_pos:neighbor_end] = neighbor_curve
        dst["group_id"][write_pos:neighbor_end] = neighbor_anchor_id
        dst["schedule_meta"][write_pos:neighbor_end] = anchor_schedule_meta[neighbor_anchor_id]
        dst["family_name"][write_pos:neighbor_end] = anchor_family_name[
            neighbor_anchor_id
        ].tolist()
        dst["is_anchor"][write_pos:neighbor_end] = 0
        dst["neighbor_objective"][write_pos:neighbor_end] = neighbor_objective
        dst["neighbor_objective_p"][write_pos:neighbor_end] = neighbor_objective_p
        dst["neighbor_objective_d"][write_pos:neighbor_end] = neighbor_objective_d
        dst["neighbor_span_frac"][write_pos:neighbor_end] = neighbor_span_frac
        dst["section_index"][write_pos:neighbor_end] = anchor_section_index[
            neighbor_anchor_id
        ]
        dst["timeQ_json"][write_pos:neighbor_end] = anchor_time_q_json[
            neighbor_anchor_id
        ].tolist()
        dst["q_json"][write_pos:neighbor_end] = anchor_q_json[
            neighbor_anchor_id
        ].tolist()

        dst.attrs["n_samples"] = int(total_rows)
        dst.attrs["source_neighborhood_h5"] = str(input_path)

    print("Autofit neighborhood flatten complete.")
    print(f"Input: {input_path}")
    print(f"Output: {output_path}")
    print(
        f"anchors={n_anchors}, neighbors={n_neighbors}, "
        f"total_exported={total_rows}"
    )
    print(f"neighbors_only={bool(args.neighbors_only)}")
# Run the flattening only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()