from __future__ import annotations import argparse import json import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) import h5py import numpy as np def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Flatten anchor-neighborhood autofit HDF5 into the standard sample-level HDF5 format" ) parser.add_argument("--input", type=str, required=True, help="Input autofit neighborhood .h5") parser.add_argument("--output", type=str, required=True, help="Output flat sample-level .h5") parser.add_argument( "--neighbors-only", action="store_true", help="Export only neighbor rows and skip anchor rows", ) return parser.parse_args() def _decode_attr_list(raw_names) -> list[str] | None: if raw_names is None: return None return [ item.decode("utf-8") if isinstance(item, (bytes, np.bytes_)) else str(item) for item in raw_names ] def _load_json_string_array(values: np.ndarray) -> list[list[float]]: out: list[list[float]] = [] for item in np.asarray(values).astype(str).tolist(): parsed = json.loads(item) out.append(list(map(float, parsed))) return out def main() -> None: args = parse_args() input_path = Path(args.input).resolve() output_path = Path(args.output).resolve() if not input_path.exists(): raise FileNotFoundError(f"Input file not found: {input_path}") with h5py.File(input_path, "r") as src: anchor_params = np.asarray(src["anchor_params"][:], dtype=np.float32) anchor_schedule = np.asarray(src["anchor_schedule"][:], dtype=np.float32) anchor_curve = np.asarray(src["anchor_curve"][:], dtype=np.float32) anchor_schedule_meta = np.asarray(src["anchor_schedule_meta"][:], dtype=np.float32) anchor_family_name = np.asarray(src["anchor_family_name"][:]).astype(str) anchor_section_index = np.asarray(src["anchor_section_index"][:], dtype=np.int32) anchor_timeQ_json = np.asarray(src["anchor_timeQ_json"][:]).astype(str) anchor_q_json = np.asarray(src["anchor_q_json"][:]).astype(str) neighbor_anchor_id = np.asarray(src["neighbor_anchor_id"][:], dtype=np.int32) neighbor_params = np.asarray(src["neighbor_params"][:], dtype=np.float32) neighbor_curve = np.asarray(src["neighbor_curve"][:], dtype=np.float32) neighbor_objective = np.asarray(src["neighbor_objective"][:], dtype=np.float32) neighbor_objective_p = np.asarray(src["neighbor_objective_p"][:], dtype=np.float32) neighbor_objective_d = np.asarray(src["neighbor_objective_d"][:], dtype=np.float32) neighbor_span_frac = ( np.asarray(src["neighbor_span_frac"][:], dtype=np.float32) if "neighbor_span_frac" in src else np.full((len(neighbor_anchor_id),), np.nan, dtype=np.float32) ) schedule_meta_names = _decode_attr_list(src.attrs.get("schedule_meta_names")) span_fracs = np.asarray(src.attrs.get("span_fracs", []), dtype=np.float32).reshape(-1).tolist() n_anchors = int(anchor_params.shape[0]) n_neighbors = int(neighbor_params.shape[0]) if n_anchors == 0: raise ValueError("Input neighborhood dataset contains no anchors") if n_neighbors == 0: raise ValueError("Input neighborhood dataset contains no neighbors") if args.neighbors_only: total_rows = n_neighbors else: total_rows = n_anchors + n_neighbors output_path.parent.mkdir(parents=True, exist_ok=True) with h5py.File(output_path, "w") as dst: if schedule_meta_names is not None: dst.attrs["schedule_meta_names"] = np.asarray(schedule_meta_names, dtype="S") if span_fracs: dst.attrs["span_fracs"] = np.asarray(span_fracs, dtype=np.float32) param_dim = int(anchor_params.shape[1]) schedule_dim = int(anchor_schedule.shape[1]) curve_dim = int(anchor_curve.shape[1]) schedule_meta_dim = int(anchor_schedule_meta.shape[1]) dst.create_dataset("params", shape=(total_rows, param_dim), dtype=np.float32) dst.create_dataset("schedule", shape=(total_rows, schedule_dim), dtype=np.float32) dst.create_dataset("curve", shape=(total_rows, curve_dim), dtype=np.float32) dst.create_dataset("group_id", shape=(total_rows,), dtype=np.int64) dst.create_dataset("schedule_meta", shape=(total_rows, schedule_meta_dim), dtype=np.float32) dst.create_dataset("family_name", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8")) dst.create_dataset("is_anchor", shape=(total_rows,), dtype=np.int8) dst.create_dataset("neighbor_objective", shape=(total_rows,), dtype=np.float32) dst.create_dataset("neighbor_objective_p", shape=(total_rows,), dtype=np.float32) dst.create_dataset("neighbor_objective_d", shape=(total_rows,), dtype=np.float32) dst.create_dataset("neighbor_span_frac", shape=(total_rows,), dtype=np.float32) dst.create_dataset("section_index", shape=(total_rows,), dtype=np.int32) dst.create_dataset("timeQ_json", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8")) dst.create_dataset("q_json", shape=(total_rows,), dtype=h5py.string_dtype(encoding="utf-8")) write_pos = 0 if not args.neighbors_only: anchor_end = write_pos + n_anchors dst["params"][write_pos:anchor_end] = anchor_params dst["schedule"][write_pos:anchor_end] = anchor_schedule dst["curve"][write_pos:anchor_end] = anchor_curve dst["group_id"][write_pos:anchor_end] = np.arange(n_anchors, dtype=np.int32) dst["schedule_meta"][write_pos:anchor_end] = anchor_schedule_meta dst["family_name"][write_pos:anchor_end] = anchor_family_name.tolist() dst["is_anchor"][write_pos:anchor_end] = 1 dst["neighbor_objective"][write_pos:anchor_end] = 0.0 dst["neighbor_objective_p"][write_pos:anchor_end] = 0.0 dst["neighbor_objective_d"][write_pos:anchor_end] = 0.0 dst["neighbor_span_frac"][write_pos:anchor_end] = 0.0 dst["section_index"][write_pos:anchor_end] = anchor_section_index dst["timeQ_json"][write_pos:anchor_end] = anchor_timeQ_json.tolist() dst["q_json"][write_pos:anchor_end] = anchor_q_json.tolist() write_pos = anchor_end neighbor_end = write_pos + n_neighbors dst["params"][write_pos:neighbor_end] = neighbor_params dst["schedule"][write_pos:neighbor_end] = anchor_schedule[neighbor_anchor_id] dst["curve"][write_pos:neighbor_end] = neighbor_curve dst["group_id"][write_pos:neighbor_end] = neighbor_anchor_id dst["schedule_meta"][write_pos:neighbor_end] = anchor_schedule_meta[neighbor_anchor_id] dst["family_name"][write_pos:neighbor_end] = anchor_family_name[neighbor_anchor_id].tolist() dst["is_anchor"][write_pos:neighbor_end] = 0 dst["neighbor_objective"][write_pos:neighbor_end] = neighbor_objective dst["neighbor_objective_p"][write_pos:neighbor_end] = neighbor_objective_p dst["neighbor_objective_d"][write_pos:neighbor_end] = neighbor_objective_d dst["neighbor_span_frac"][write_pos:neighbor_end] = neighbor_span_frac dst["section_index"][write_pos:neighbor_end] = anchor_section_index[neighbor_anchor_id] dst["timeQ_json"][write_pos:neighbor_end] = anchor_timeQ_json[neighbor_anchor_id].tolist() dst["q_json"][write_pos:neighbor_end] = anchor_q_json[neighbor_anchor_id].tolist() dst.attrs["n_samples"] = int(total_rows) dst.attrs["source_neighborhood_h5"] = str(input_path) print("Autofit neighborhood flatten complete.") print(f"Input: {input_path}") print(f"Output: {output_path}") print(f"anchors={n_anchors}, neighbors={n_neighbors}, total_exported={total_rows}") print(f"neighbors_only={bool(args.neighbors_only)}") if __name__ == "__main__": main()