You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/flatten_autofit_neighborhoo...

216 lines
9.4 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

"""将自动拟合邻域 HDF5 数据展开为普通逐样本数据集。
邻域数据通常按 anchor 和 neighbor 分组保存,便于排序训练;本脚本把这些分组样本
扁平化为常规 `params/schedule/curve` 结构,方便复用已有预处理、评估和合并流程。
"""
# pylint: disable=too-many-locals,too-many-statements
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
import h5py
import numpy as np
# Directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Appended to the module search path at import time. No import in the visible
# part of this file depends on it -- presumably kept for project-local imports;
# confirm before removing.
sys.path.append(str(ROOT))
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the flattening script.

    Returns:
        Namespace with the required string options ``input`` and ``output``
        and the boolean flag ``neighbors_only`` (``False`` unless
        ``--neighbors-only`` is given).
    """
    summary = (
        "Flatten anchor-neighborhood autofit HDF5 into the standard "
        "sample-level HDF5 format"
    )
    cli = argparse.ArgumentParser(description=summary)
    # Both path options have the same shape, so register them in one pass.
    for flag, help_text in (
        ("--input", "Input autofit neighborhood .h5"),
        ("--output", "Output flat sample-level .h5"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    cli.add_argument(
        "--neighbors-only",
        action="store_true",
        help="Export only neighbor rows and skip anchor rows",
    )
    return cli.parse_args()
def _decode_attr_list(raw_names) -> list[str] | None:
    """Decode a list of names read from an HDF5 attribute into ``str`` items.

    Args:
        raw_names: The attribute value. Accepted forms are ``None``, an
            iterable of ``bytes``/``str`` items, or a single scalar
            (``str``, ``bytes`` or a 0-d array) holding either one name or a
            JSON-encoded list of names.

    Returns:
        ``None`` when the attribute is absent, otherwise a list of ``str``.
        Byte strings are decoded as UTF-8; any other item is passed through
        ``str()``.
    """
    if raw_names is None:
        return None
    # A 0-d array cannot be iterated; unwrap it to the scalar it holds.
    if isinstance(raw_names, np.ndarray) and raw_names.ndim == 0:
        raw_names = raw_names.item()
    if isinstance(raw_names, (bytes, np.bytes_)):
        raw_names = raw_names.decode("utf-8")
    if isinstance(raw_names, str):
        # Iterating a scalar string would yield single characters, so treat
        # it as a JSON list when it parses as one and as one name otherwise.
        if not raw_names:
            return []
        try:
            parsed = json.loads(raw_names)
        except ValueError:
            parsed = None
        if not isinstance(parsed, list):
            return [str(raw_names)]
        raw_names = parsed
    return [
        item.decode("utf-8") if isinstance(item, (bytes, np.bytes_)) else str(item)
        for item in raw_names
    ]
def _load_json_string_array(values: np.ndarray) -> list[list[float]]:
    """Parse an array of JSON strings into lists of floats.

    Args:
        values: Array-like of strings, each a JSON document whose elements
            can be converted with ``float``.

    Returns:
        One list of floats per input string, in input order.
    """
    texts = np.asarray(values).astype(str).tolist()
    return [[float(number) for number in json.loads(text)] for text in texts]
def _read_text_dataset(dataset) -> np.ndarray:
    """Read a string dataset (anything sliceable with ``[:]``) as ``str``.

    Byte strings are decoded explicitly as UTF-8, matching the UTF-8 string
    dtype used for the output datasets. A plain ``astype(str)`` performs no
    UTF-8 decoding of bytes, so non-ASCII text would not survive it.
    """
    raw = np.asarray(dataset[:])
    if raw.dtype.kind == "S":
        return np.char.decode(raw, "utf-8")
    if raw.dtype.kind == "O":
        decoded = [
            item.decode("utf-8") if isinstance(item, bytes) else str(item)
            for item in raw.reshape(-1).tolist()
        ]
        return np.asarray(decoded, dtype=str).reshape(raw.shape)
    return raw.astype(str)


def _check_row_count(expected: int, **arrays: np.ndarray) -> None:
    """Raise ``ValueError`` if any named array lacks ``expected`` rows."""
    for name, array in arrays.items():
        if array.shape[:1] != (expected,):
            raise ValueError(
                f"Dataset '{name}' has shape {array.shape}; "
                f"expected {expected} rows"
            )


def _check_anchor_ids(anchor_ids: np.ndarray, n_anchors: int) -> None:
    """Raise ``ValueError`` if any anchor id is outside ``[0, n_anchors)``.

    NumPy integer indexing accepts negative indices and counts them from the
    end, so an unchecked negative id would silently select the wrong anchor.
    """
    if anchor_ids.size == 0:
        return
    lowest = int(anchor_ids.min())
    highest = int(anchor_ids.max())
    if lowest < 0 or highest >= n_anchors:
        raise ValueError(
            f"neighbor_anchor_id values must lie in [0, {n_anchors}); "
            f"found range [{lowest}, {highest}]"
        )


def main() -> None:
    """Flatten an anchor/neighbor autofit HDF5 file into sample-level HDF5.

    Reads the ``anchor_*`` and ``neighbor_*`` datasets from ``--input`` and
    writes one row per sample to ``--output``: anchor rows first (unless
    ``--neighbors-only`` is set), followed by all neighbor rows. Each
    neighbor row takes its schedule, schedule metadata, family name, section
    index and JSON strings from the anchor named by ``neighbor_anchor_id``.

    Raises:
        FileNotFoundError: If the input file does not exist.
        ValueError: If the output path equals the input path, the input has
            no anchors or no neighbors, dataset row counts disagree, or an
            anchor id is out of range.
    """
    args = parse_args()
    input_path = Path(args.input).resolve()
    output_path = Path(args.output).resolve()
    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")
    if output_path == input_path:
        # The output is opened in "w" mode, which would truncate the input.
        raise ValueError(f"Output path must differ from input path: {input_path}")

    with h5py.File(input_path, "r") as src:
        # Anchors hold the target samples themselves; neighbors hold the
        # candidate parameters and curves perturbed around a target.
        anchor_params = np.asarray(src["anchor_params"][:], dtype=np.float32)
        anchor_schedule = np.asarray(src["anchor_schedule"][:], dtype=np.float32)
        anchor_curve = np.asarray(src["anchor_curve"][:], dtype=np.float32)
        anchor_schedule_meta = np.asarray(src["anchor_schedule_meta"][:], dtype=np.float32)
        anchor_family_name = _read_text_dataset(src["anchor_family_name"])
        anchor_section_index = np.asarray(src["anchor_section_index"][:], dtype=np.int32)
        anchor_time_q_json = _read_text_dataset(src["anchor_time_q_json"])
        anchor_q_json = _read_text_dataset(src["anchor_q_json"])
        neighbor_anchor_id = np.asarray(src["neighbor_anchor_id"][:], dtype=np.int32)
        neighbor_params = np.asarray(src["neighbor_params"][:], dtype=np.float32)
        neighbor_curve = np.asarray(src["neighbor_curve"][:], dtype=np.float32)
        neighbor_objective = np.asarray(src["neighbor_objective"][:], dtype=np.float32)
        neighbor_objective_p = np.asarray(src["neighbor_objective_p"][:], dtype=np.float32)
        neighbor_objective_d = np.asarray(src["neighbor_objective_d"][:], dtype=np.float32)
        # The span fraction dataset is optional; fill with NaN when absent.
        neighbor_span_frac = (
            np.asarray(src["neighbor_span_frac"][:], dtype=np.float32)
            if "neighbor_span_frac" in src
            else np.full((len(neighbor_anchor_id),), np.nan, dtype=np.float32)
        )
        schedule_meta_names = _decode_attr_list(src.attrs.get("schedule_meta_names"))
        span_fracs = (
            np.asarray(
                src.attrs.get("span_fracs", []),
                dtype=np.float32,
            )
            .reshape(-1)
            .tolist()
        )

    n_anchors = int(anchor_params.shape[0])
    n_neighbors = int(neighbor_params.shape[0])
    if n_anchors == 0:
        raise ValueError("Input neighborhood dataset contains no anchors")
    if n_neighbors == 0:
        raise ValueError("Input neighborhood dataset contains no neighbors")

    # Fail early with a named dataset instead of a broadcast or index error
    # halfway through writing the output file.
    _check_row_count(
        n_anchors,
        anchor_schedule=anchor_schedule,
        anchor_curve=anchor_curve,
        anchor_schedule_meta=anchor_schedule_meta,
        anchor_family_name=anchor_family_name,
        anchor_section_index=anchor_section_index,
        anchor_time_q_json=anchor_time_q_json,
        anchor_q_json=anchor_q_json,
    )
    _check_row_count(
        n_neighbors,
        neighbor_anchor_id=neighbor_anchor_id,
        neighbor_curve=neighbor_curve,
        neighbor_objective=neighbor_objective,
        neighbor_objective_p=neighbor_objective_p,
        neighbor_objective_d=neighbor_objective_d,
        neighbor_span_frac=neighbor_span_frac,
    )
    _check_anchor_ids(neighbor_anchor_id, n_anchors)

    total_rows = n_neighbors if args.neighbors_only else n_anchors + n_neighbors

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with h5py.File(output_path, "w") as dst:
        if schedule_meta_names is not None:
            # Encode explicitly as UTF-8 so the attribute round-trips through
            # _decode_attr_list; a str -> "S" cast only accepts ASCII.
            dst.attrs["schedule_meta_names"] = np.asarray(
                [name.encode("utf-8") for name in schedule_meta_names],
                dtype="S",
            )
        if span_fracs:
            dst.attrs["span_fracs"] = np.asarray(span_fracs, dtype=np.float32)

        param_dim = int(anchor_params.shape[1])
        schedule_dim = int(anchor_schedule.shape[1])
        curve_dim = int(anchor_curve.shape[1])
        schedule_meta_dim = int(anchor_schedule_meta.shape[1])
        text_dtype = h5py.string_dtype(encoding="utf-8")
        layout = (
            ("params", (total_rows, param_dim), np.float32),
            ("schedule", (total_rows, schedule_dim), np.float32),
            ("curve", (total_rows, curve_dim), np.float32),
            ("group_id", (total_rows,), np.int64),
            ("schedule_meta", (total_rows, schedule_meta_dim), np.float32),
            ("family_name", (total_rows,), text_dtype),
            ("is_anchor", (total_rows,), np.int8),
            ("neighbor_objective", (total_rows,), np.float32),
            ("neighbor_objective_p", (total_rows,), np.float32),
            ("neighbor_objective_d", (total_rows,), np.float32),
            ("neighbor_span_frac", (total_rows,), np.float32),
            ("section_index", (total_rows,), np.int32),
            ("timeQ_json", (total_rows,), text_dtype),
            ("q_json", (total_rows,), text_dtype),
        )
        for name, shape, dtype in layout:
            dst.create_dataset(name, shape=shape, dtype=dtype)

        write_pos = 0
        if not args.neighbors_only:
            # Anchors are written as standalone samples: group_id is the
            # anchor's own index, and the objective columns are 0 because an
            # anchor matches itself exactly.
            anchor_rows = slice(write_pos, write_pos + n_anchors)
            dst["params"][anchor_rows] = anchor_params
            dst["schedule"][anchor_rows] = anchor_schedule
            dst["curve"][anchor_rows] = anchor_curve
            dst["group_id"][anchor_rows] = np.arange(n_anchors, dtype=np.int64)
            dst["schedule_meta"][anchor_rows] = anchor_schedule_meta
            dst["family_name"][anchor_rows] = anchor_family_name.tolist()
            dst["is_anchor"][anchor_rows] = 1
            dst["neighbor_objective"][anchor_rows] = 0.0
            dst["neighbor_objective_p"][anchor_rows] = 0.0
            dst["neighbor_objective_d"][anchor_rows] = 0.0
            dst["neighbor_span_frac"][anchor_rows] = 0.0
            dst["section_index"][anchor_rows] = anchor_section_index
            dst["timeQ_json"][anchor_rows] = anchor_time_q_json.tolist()
            dst["q_json"][anchor_rows] = anchor_q_json.tolist()
            write_pos = anchor_rows.stop

        # Neighbors inherit schedule, group and metadata from their anchor;
        # only parameters, curve and objective values are their own.
        neighbor_rows = slice(write_pos, write_pos + n_neighbors)
        dst["params"][neighbor_rows] = neighbor_params
        dst["schedule"][neighbor_rows] = anchor_schedule[neighbor_anchor_id]
        dst["curve"][neighbor_rows] = neighbor_curve
        dst["group_id"][neighbor_rows] = neighbor_anchor_id
        dst["schedule_meta"][neighbor_rows] = anchor_schedule_meta[neighbor_anchor_id]
        dst["family_name"][neighbor_rows] = anchor_family_name[neighbor_anchor_id].tolist()
        dst["is_anchor"][neighbor_rows] = 0
        dst["neighbor_objective"][neighbor_rows] = neighbor_objective
        dst["neighbor_objective_p"][neighbor_rows] = neighbor_objective_p
        dst["neighbor_objective_d"][neighbor_rows] = neighbor_objective_d
        dst["neighbor_span_frac"][neighbor_rows] = neighbor_span_frac
        dst["section_index"][neighbor_rows] = anchor_section_index[neighbor_anchor_id]
        dst["timeQ_json"][neighbor_rows] = anchor_time_q_json[neighbor_anchor_id].tolist()
        dst["q_json"][neighbor_rows] = anchor_q_json[neighbor_anchor_id].tolist()

        dst.attrs["n_samples"] = int(total_rows)
        dst.attrs["source_neighborhood_h5"] = str(input_path)

    print("Autofit neighborhood flatten complete.")
    print(f"Input: {input_path}")
    print(f"Output: {output_path}")
    print(
        f"anchors={n_anchors}, neighbors={n_neighbors}, "
        f"total_exported={total_rows}"
    )
    print(f"neighbors_only={bool(args.neighbors_only)}")
# Run the flattening only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()