"""Build a mixed raw + processed dataset from "normal" and "hard" HDF5 pools.

The two pools are first merged into a single raw ``.h5`` file with
``scripts.merge_datasets.merge_datasets``; the merged file is then passed to
``src.data.preprocess.preprocess_dataset`` to produce the processed ``.pkl``.
Output locations default to the tag-derived paths from
``src.common.experiment_paths`` unless overridden on the command line.
"""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Parent of the directory containing this file (presumably the repository
# root, with this script living in ``scripts/``). It must be importable so the
# ``scripts`` and ``src`` packages below resolve when the file is run directly.
# Prepend (instead of append) so the local packages take precedence over any
# same-named package installed elsewhere on ``sys.path``, and skip the insert
# when the entry is already present to avoid duplicates.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from scripts.merge_datasets import merge_datasets  # noqa: E402
from src.common.experiment_paths import (  # noqa: E402
    normalize_tag,
    processed_path_for_tag,
    sample_path_for_tag,
)
from src.data.preprocess import preprocess_dataset  # noqa: E402


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse and validate the command-line arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original CLI behavior.

    Returns:
        The parsed namespace. Invalid values are reported via
        ``ArgumentParser.error`` (usage message, exit status 2) before any
        merging or preprocessing starts.
    """
    parser = argparse.ArgumentParser(
        description="Build a mixed raw+processed dataset from normal and hard HDF5 pools"
    )
    parser.add_argument("--normal-input", type=str, required=True, help="Path to the normal/main .h5 dataset")
    parser.add_argument("--hard-input", type=str, required=True, help="Path to the hard-targeted .h5 dataset")
    parser.add_argument("--tag", type=str, default="family_random_mixed_50k", help="Experiment tag")
    parser.add_argument("--output-h5", type=str, default=None, help="Optional merged raw .h5 path")
    parser.add_argument("--output-processed", type=str, default=None, help="Optional processed .pkl path")
    parser.add_argument("--total-samples", type=int, default=50000)
    parser.add_argument("--hard-ratio", type=float, default=0.30)
    parser.add_argument("--normal-count", type=int, default=None)
    parser.add_argument("--hard-count", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--test-size", type=float, default=0.15)
    parser.add_argument("--val-size", type=float, default=0.15)
    parser.add_argument("--normal-label", type=str, default="normal")
    parser.add_argument("--hard-label", type=str, default="hard")
    parser.add_argument("--batch-size", type=int, default=4096)

    args = parser.parse_args(argv)
    _validate_args(parser, args)
    return args


def _validate_args(parser: argparse.ArgumentParser, args: argparse.Namespace) -> None:
    """Reject argument values that can never produce a valid dataset.

    Checking here fails fast: without it, a bad split size would only surface
    in ``preprocess_dataset`` after the (potentially long) merge has already
    written its output.
    """
    for flag, raw_path in (("--normal-input", args.normal_input), ("--hard-input", args.hard_input)):
        if not Path(raw_path).exists():
            parser.error(f"{flag} does not exist: {raw_path}")

    if args.total_samples <= 0:
        parser.error(f"--total-samples must be positive, got {args.total_samples}")
    if not 0.0 <= args.hard_ratio <= 1.0:
        parser.error(f"--hard-ratio must be within [0, 1], got {args.hard_ratio}")
    for flag, count in (("--normal-count", args.normal_count), ("--hard-count", args.hard_count)):
        if count is not None and count < 0:
            parser.error(f"{flag} must be non-negative, got {count}")
    if args.batch_size <= 0:
        parser.error(f"--batch-size must be positive, got {args.batch_size}")

    for flag, fraction in (("--test-size", args.test_size), ("--val-size", args.val_size)):
        if not 0.0 <= fraction < 1.0:
            parser.error(f"{flag} must be within [0, 1), got {fraction}")
    if args.test_size + args.val_size >= 1.0:
        # Otherwise nothing would be left for the training split.
        parser.error(
            f"--test-size + --val-size must be < 1, got {args.test_size} + {args.val_size}"
        )


def main(argv: list[str] | None = None) -> None:
    """Merge the normal and hard pools, then preprocess the merged file.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`;
            ``None`` reads the process command line.
    """
    args = parse_args(argv)
    tag = normalize_tag(args.tag)

    output_h5 = Path(args.output_h5) if args.output_h5 is not None else sample_path_for_tag(tag)
    output_processed = (
        Path(args.output_processed) if args.output_processed is not None else processed_path_for_tag(tag)
    )

    # Make sure both destinations can be written; a no-op when the
    # directories already exist (e.g. if the helpers create them themselves).
    for destination in (output_h5, output_processed):
        Path(destination).parent.mkdir(parents=True, exist_ok=True)

    merge_meta = merge_datasets(
        normal_input=args.normal_input,
        hard_input=args.hard_input,
        output=output_h5,
        tag=tag,
        total_samples=args.total_samples,
        hard_ratio=args.hard_ratio,
        normal_count=args.normal_count,
        hard_count=args.hard_count,
        seed=args.seed,
        normal_label=args.normal_label,
        hard_label=args.hard_label,
        batch_size=args.batch_size,
    )

    # The merged raw file is the input of preprocessing; the same seed drives
    # both the merge sampling and the split.
    preprocess_dataset(
        input_path=output_h5,
        output_path=output_processed,
        test_size=args.test_size,
        val_size=args.val_size,
        random_seed=args.seed,
    )

    print(f"Merged raw dataset: {merge_meta['output_path']}")
    print(f"Merge summary: {merge_meta['summary_path']}")
    print(f"Processed dataset: {output_processed}")


if __name__ == "__main__":
    main()