You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
73 lines
2.8 KiB
Python
73 lines
2.8 KiB
Python
from __future__ import annotations

import argparse
import sys
from pathlib import Path

# ROOT is the parent of the directory that contains this file. It is appended
# to sys.path *before* the project imports below so that the `scripts` and
# `src` packages can be resolved from it when this file is run directly
# (presumably ROOT is the repository root -- confirm against the layout).
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))

# Project imports: these must stay below the sys.path amendment above.
from scripts.merge_datasets import merge_datasets
from src.common.experiment_paths import normalize_tag, processed_path_for_tag, sample_path_for_tag
from src.data.preprocess import preprocess_dataset
|
|
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the mixed-dataset build.

    Args:
        argv: Argument strings to parse. When ``None`` (the default),
            ``argparse`` reads ``sys.argv[1:]``, which matches the behaviour
            of calling this function with no arguments. Passing an explicit
            list lets callers drive the parser programmatically.

    Returns:
        The populated ``argparse.Namespace``. ``--normal-input`` and
        ``--hard-input`` are required; every other option has a default.
    """
    parser = argparse.ArgumentParser(
        description="Build a mixed raw+processed dataset from normal and hard HDF5 pools"
    )

    # Input pools (both mandatory).
    parser.add_argument("--normal-input", type=str, required=True, help="Path to the normal/main .h5 dataset")
    parser.add_argument("--hard-input", type=str, required=True, help="Path to the hard-targeted .h5 dataset")

    # Experiment tag and optional explicit output locations.
    parser.add_argument("--tag", type=str, default="family_random_mixed_50k", help="Experiment tag")
    parser.add_argument("--output-h5", type=str, default=None, help="Optional merged raw .h5 path")
    parser.add_argument("--output-processed", type=str, default=None, help="Optional processed .pkl path")

    # Mixing controls forwarded to the merge step.
    parser.add_argument("--total-samples", type=int, default=50000)
    parser.add_argument("--hard-ratio", type=float, default=0.30)
    parser.add_argument("--normal-count", type=int, default=None)
    parser.add_argument("--hard-count", type=int, default=None)
    parser.add_argument("--seed", type=int, default=42)

    # Split fractions forwarded to the preprocessing step.
    parser.add_argument("--test-size", type=float, default=0.15)
    parser.add_argument("--val-size", type=float, default=0.15)

    # Labels and batch size forwarded to the merge step.
    parser.add_argument("--normal-label", type=str, default="normal")
    parser.add_argument("--hard-label", type=str, default="hard")
    parser.add_argument("--batch-size", type=int, default=4096)

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)
|
|
|
|
|
|
def main() -> None:
    """Merge the normal and hard pools, then preprocess the merged file.

    Steps, in order:
      1. Parse the CLI options and normalize the experiment tag.
      2. Resolve the raw and processed output paths: an explicit
         ``--output-h5`` / ``--output-processed`` value is used as given,
         otherwise the path is derived from the normalized tag.
      3. Call ``merge_datasets`` to write the merged raw file.
      4. Call ``preprocess_dataset`` on that file, reusing the same seed.
      5. Print the merged path, the merge summary path and the processed path.
    """
    cli = parse_args()
    experiment_tag = normalize_tag(cli.tag)

    # Only a missing option (None) triggers the tag-derived default.
    if cli.output_h5 is None:
        merged_h5_path = sample_path_for_tag(experiment_tag)
    else:
        merged_h5_path = Path(cli.output_h5)

    if cli.output_processed is None:
        output_processed = processed_path_for_tag(experiment_tag)
    else:
        output_processed = Path(cli.output_processed)

    merge_options = {
        "normal_input": cli.normal_input,
        "hard_input": cli.hard_input,
        "output": merged_h5_path,
        "tag": experiment_tag,
        "total_samples": cli.total_samples,
        "hard_ratio": cli.hard_ratio,
        "normal_count": cli.normal_count,
        "hard_count": cli.hard_count,
        "seed": cli.seed,
        "normal_label": cli.normal_label,
        "hard_label": cli.hard_label,
        "batch_size": cli.batch_size,
    }
    # The result is subscripted with 'output_path' and 'summary_path' below.
    merge_meta = merge_datasets(**merge_options)

    # The split uses the same seed that was handed to the merge step.
    preprocess_dataset(
        input_path=merged_h5_path,
        output_path=output_processed,
        test_size=cli.test_size,
        val_size=cli.val_size,
        random_seed=cli.seed,
    )

    print(f"Merged raw dataset: {merge_meta['output_path']}")
    print(f"Merge summary: {merge_meta['summary_path']}")
    print(f"Processed dataset: {output_processed}")
|
|
|
|
|
|
# Run the build only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|