|
|
|
|
|
"""将自动拟合邻域 HDF5 数据展开为普通逐样本数据集。
|
|
|
|
|
|
|
|
|
|
|
|
邻域数据通常按 anchor 和 neighbor 分组保存,便于排序训练;本脚本把这些分组样本
|
|
|
|
|
|
扁平化为常规 `params/schedule/curve` 结构,方便复用已有预处理、评估和合并流程。
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
|
import json
|
|
|
|
|
|
import sys
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
# Directory two levels above this file (the parent of the script's own directory);
# presumably the project root — confirm against repository layout.
ROOT = Path(__file__).resolve().parents[1]
# Appended to the module search path so project-local modules can be imported when
# this file is run directly as a script.
sys.path.append(f"{ROOT}")
|
|
|
|
|
|
|
|
|
|
|
|
import h5py
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Options:
        --input: required path to the autofit neighborhood ``.h5`` file.
        --output: required path of the flat sample-level ``.h5`` file to write.
        --neighbors-only: flag; when set, anchor rows are not exported.

    Returns:
        The parsed ``argparse.Namespace`` (``input``, ``output``, ``neighbors_only``).
    """
    cli = argparse.ArgumentParser(
        description="Flatten anchor-neighborhood autofit HDF5 into the standard sample-level HDF5 format"
    )
    # Both path options are mandatory strings; register them in the original order so
    # the generated --help text is unchanged.
    for flag, help_text in (
        ("--input", "Input autofit neighborhood .h5"),
        ("--output", "Output flat sample-level .h5"),
    ):
        cli.add_argument(flag, type=str, required=True, help=help_text)
    cli.add_argument(
        "--neighbors-only",
        action="store_true",
        help="Export only neighbor rows and skip anchor rows",
    )
    namespace = cli.parse_args()
    return namespace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _decode_attr_list(raw_names) -> list[str] | None:
|
|
|
|
|
|
"""解码 HDF5 属性中保存的 JSON 字符串或字节串列表。"""
|
|
|
|
|
|
if raw_names is None:
|
|
|
|
|
|
return None
|
|
|
|
|
|
return [
|
|
|
|
|
|
item.decode("utf-8") if isinstance(item, (bytes, np.bytes_)) else str(item)
|
|
|
|
|
|
for item in raw_names
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _load_json_string_array(values: np.ndarray) -> list[list[float]]:
|
|
|
|
|
|
"""从 HDF5 字符串数据集中读取 JSON 字符串并解析为 Python 对象。"""
|
|
|
|
|
|
out: list[list[float]] = []
|
|
|
|
|
|
for item in np.asarray(values).astype(str).tolist():
|
|
|
|
|
|
parsed = json.loads(item)
|
|
|
|
|
|
out.append(list(map(float, parsed)))
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _read_str_array(dataset) -> np.ndarray:
    """Read an HDF5 string dataset into a flat object array of Python ``str``.

    h5py returns variable-length strings as ``bytes`` objects, and ``astype(str)`` on
    such an object array yields ``"b'...'"`` reprs, so every element is decoded
    explicitly (bytes as UTF-8, anything else via ``str``).
    """
    raw = np.asarray(dataset[:], dtype=object).reshape(-1).tolist()
    return np.asarray(
        [item.decode("utf-8") if isinstance(item, (bytes, np.bytes_)) else str(item) for item in raw],
        dtype=object,
    )


def _check_leading_dim(arrays: dict[str, np.ndarray], expected: int) -> None:
    """Raise ValueError unless every array in ``arrays`` has ``expected`` rows along axis 0."""
    for key, arr in arrays.items():
        if arr.shape[:1] != (expected,):
            raise ValueError(f"Dataset '{key}' has shape {arr.shape}; expected {expected} rows")


def _write_rows(dst, start: int, stop: int, columns: dict) -> None:
    """Assign ``columns[name]`` to ``dst[name][start:stop]`` for every column (scalars broadcast)."""
    for name, values in columns.items():
        dst[name][start:stop] = values


def main() -> None:
    """Flatten an anchor/neighbor autofit neighborhood HDF5 into a sample-level HDF5.

    Each anchor (unless ``--neighbors-only``) and each neighbor becomes one output row.
    Neighbor rows copy ``schedule``, ``schedule_meta``, ``family_name``,
    ``section_index``, ``timeQ_json`` and ``q_json`` from the anchor selected by
    ``neighbor_anchor_id``, and use that anchor index as ``group_id``. Anchor rows use
    their own index as ``group_id`` and 0 for the objective / span-fraction columns.

    Raises:
        FileNotFoundError: if ``--input`` does not exist.
        ValueError: if the input has no anchors or no neighbors, inconsistent row
            counts or per-row shapes, or ``neighbor_anchor_id`` values outside
            ``[0, n_anchors)``.
    """
    args = parse_args()
    input_path = Path(args.input).resolve()
    output_path = Path(args.output).resolve()

    if not input_path.exists():
        raise FileNotFoundError(f"Input file not found: {input_path}")

    with h5py.File(input_path, "r") as src:
        # Anchors hold the target samples themselves; neighbors hold perturbed candidate
        # params and curves around a target.
        anchor_params = np.asarray(src["anchor_params"][:], dtype=np.float32)
        anchor_schedule = np.asarray(src["anchor_schedule"][:], dtype=np.float32)
        anchor_curve = np.asarray(src["anchor_curve"][:], dtype=np.float32)
        anchor_schedule_meta = np.asarray(src["anchor_schedule_meta"][:], dtype=np.float32)
        anchor_family_name = _read_str_array(src["anchor_family_name"])
        anchor_section_index = np.asarray(src["anchor_section_index"][:], dtype=np.int32)
        anchor_timeQ_json = _read_str_array(src["anchor_timeQ_json"])
        anchor_q_json = _read_str_array(src["anchor_q_json"])

        # int64 matches the output group_id dtype and avoids narrowing large ids.
        neighbor_anchor_id = np.asarray(src["neighbor_anchor_id"][:], dtype=np.int64)
        neighbor_params = np.asarray(src["neighbor_params"][:], dtype=np.float32)
        neighbor_curve = np.asarray(src["neighbor_curve"][:], dtype=np.float32)
        neighbor_objective = np.asarray(src["neighbor_objective"][:], dtype=np.float32)
        neighbor_objective_p = np.asarray(src["neighbor_objective_p"][:], dtype=np.float32)
        neighbor_objective_d = np.asarray(src["neighbor_objective_d"][:], dtype=np.float32)
        # neighbor_span_frac is optional in the input; fill with NaN when it is absent.
        if "neighbor_span_frac" in src:
            neighbor_span_frac = np.asarray(src["neighbor_span_frac"][:], dtype=np.float32)
        else:
            neighbor_span_frac = np.full((len(neighbor_anchor_id),), np.nan, dtype=np.float32)

        schedule_meta_names = _decode_attr_list(src.attrs.get("schedule_meta_names"))
        span_fracs = np.asarray(src.attrs.get("span_fracs", []), dtype=np.float32).reshape(-1).tolist()

    n_anchors = int(anchor_params.shape[0])
    n_neighbors = int(neighbor_params.shape[0])

    if n_anchors == 0:
        raise ValueError("Input neighborhood dataset contains no anchors")
    if n_neighbors == 0:
        raise ValueError("Input neighborhood dataset contains no neighbors")

    # Validate everything before the output file is created so a malformed input
    # never leaves a half-written file behind.
    _check_leading_dim(
        {
            "anchor_schedule": anchor_schedule,
            "anchor_curve": anchor_curve,
            "anchor_schedule_meta": anchor_schedule_meta,
            "anchor_family_name": anchor_family_name,
            "anchor_section_index": anchor_section_index,
            "anchor_timeQ_json": anchor_timeQ_json,
            "anchor_q_json": anchor_q_json,
        },
        n_anchors,
    )
    _check_leading_dim(
        {
            "neighbor_anchor_id": neighbor_anchor_id,
            "neighbor_curve": neighbor_curve,
            "neighbor_objective": neighbor_objective,
            "neighbor_objective_p": neighbor_objective_p,
            "neighbor_objective_d": neighbor_objective_d,
            "neighbor_span_frac": neighbor_span_frac,
        },
        n_neighbors,
    )
    for key, nb_arr, an_arr in (
        ("params", neighbor_params, anchor_params),
        ("curve", neighbor_curve, anchor_curve),
    ):
        if nb_arr.shape[1:] != an_arr.shape[1:]:
            raise ValueError(
                f"neighbor_{key} row shape {nb_arr.shape[1:]} does not match "
                f"anchor_{key} row shape {an_arr.shape[1:]}"
            )
    # Negative ids would silently wrap around under NumPy fancy indexing and pick the
    # wrong anchor, so reject them together with ids past the end.
    out_of_range = (neighbor_anchor_id < 0) | (neighbor_anchor_id >= n_anchors)
    if out_of_range.any():
        raise ValueError(
            f"neighbor_anchor_id has {int(out_of_range.sum())} value(s) outside [0, {n_anchors})"
        )

    total_rows = n_neighbors if args.neighbors_only else n_anchors + n_neighbors

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with h5py.File(output_path, "w") as dst:
        if schedule_meta_names is not None:
            # Encode explicitly as UTF-8: a plain dtype="S" cast only accepts ASCII text.
            dst.attrs["schedule_meta_names"] = np.asarray(
                [name.encode("utf-8") for name in schedule_meta_names], dtype="S"
            )
        if span_fracs:
            dst.attrs["span_fracs"] = np.asarray(span_fracs, dtype=np.float32)

        str_dtype = h5py.string_dtype(encoding="utf-8")
        # (dataset name, per-row shape, dtype); every dataset has total_rows rows.
        layout = (
            ("params", anchor_params.shape[1:], np.float32),
            ("schedule", anchor_schedule.shape[1:], np.float32),
            ("curve", anchor_curve.shape[1:], np.float32),
            ("group_id", (), np.int64),
            ("schedule_meta", anchor_schedule_meta.shape[1:], np.float32),
            ("family_name", (), str_dtype),
            ("is_anchor", (), np.int8),
            ("neighbor_objective", (), np.float32),
            ("neighbor_objective_p", (), np.float32),
            ("neighbor_objective_d", (), np.float32),
            ("neighbor_span_frac", (), np.float32),
            ("section_index", (), np.int32),
            ("timeQ_json", (), str_dtype),
            ("q_json", (), str_dtype),
        )
        for name, row_shape, dtype in layout:
            dst.create_dataset(name, shape=(total_rows, *row_shape), dtype=dtype)

        write_pos = 0

        if not args.neighbors_only:
            # Anchors are written as standalone samples: group_id is the anchor's own
            # index and the objectives are 0, meaning an exact match with itself.
            anchor_end = write_pos + n_anchors
            _write_rows(
                dst,
                write_pos,
                anchor_end,
                {
                    "params": anchor_params,
                    "schedule": anchor_schedule,
                    "curve": anchor_curve,
                    "group_id": np.arange(n_anchors, dtype=np.int64),
                    "schedule_meta": anchor_schedule_meta,
                    "family_name": anchor_family_name.tolist(),
                    "is_anchor": 1,
                    "neighbor_objective": 0.0,
                    "neighbor_objective_p": 0.0,
                    "neighbor_objective_d": 0.0,
                    "neighbor_span_frac": 0.0,
                    "section_index": anchor_section_index,
                    "timeQ_json": anchor_timeQ_json.tolist(),
                    "q_json": anchor_q_json.tolist(),
                },
            )
            write_pos = anchor_end

        # Neighbors inherit schedule, grouping and metadata from their anchor; only the
        # params, curve and objectives come from the neighbor itself.
        neighbor_end = write_pos + n_neighbors
        ids = neighbor_anchor_id
        _write_rows(
            dst,
            write_pos,
            neighbor_end,
            {
                "params": neighbor_params,
                "schedule": anchor_schedule[ids],
                "curve": neighbor_curve,
                "group_id": ids,
                "schedule_meta": anchor_schedule_meta[ids],
                "family_name": anchor_family_name[ids].tolist(),
                "is_anchor": 0,
                "neighbor_objective": neighbor_objective,
                "neighbor_objective_p": neighbor_objective_p,
                "neighbor_objective_d": neighbor_objective_d,
                "neighbor_span_frac": neighbor_span_frac,
                "section_index": anchor_section_index[ids],
                "timeQ_json": anchor_timeQ_json[ids].tolist(),
                "q_json": anchor_q_json[ids].tolist(),
            },
        )

        dst.attrs["n_samples"] = int(total_rows)
        dst.attrs["source_neighborhood_h5"] = str(input_path)

    print("Autofit neighborhood flatten complete.")
    print(f"Input: {input_path}")
    print(f"Output: {output_path}")
    print(f"anchors={n_anchors}, neighbors={n_neighbors}, total_exported={total_rows}")
    print(f"neighbors_only={bool(args.neighbors_only)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run the conversion only when executed directly, not on import.
    # main() returns None, so SystemExit(None) exits with status 0 exactly as before.
    raise SystemExit(main())
|