from __future__ import annotations import argparse import json import random import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) import h5py import joblib import numpy as np import torch import torch.nn.functional as F from src.common.experiment_paths import model_checkpoint_for_tag, model_dir_for_tag, normalize_tag, processed_path_for_tag from src.data.param_features import param_feature_transform_from_meta, transform_param_features from src.models.forward_surrogate import ForwardSurrogate def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Fine-tune a forward surrogate with local pairwise autofit ranking constraints" ) parser.add_argument("--neighborhood", type=str, required=True, help="Anchor-neighborhood HDF5 path") parser.add_argument("--base-tag", type=str, default="family_random_mixed_50k_biasfix") parser.add_argument("--base-processed", type=str, default=None) parser.add_argument("--base-model", type=str, default=None) parser.add_argument("--output-tag", type=str, required=True) parser.add_argument("--output-dir", type=str, default=None) parser.add_argument("--seed", type=int, default=42) parser.add_argument("--epochs", type=int, default=80) parser.add_argument("--lr", type=float, default=1.0e-5) parser.add_argument("--weight-decay", type=float, default=1.0e-5) parser.add_argument("--w-rank", type=float, default=1.0) parser.add_argument("--w-forward", type=float, default=0.25) parser.add_argument("--w-anchor-forward", type=float, default=0.05) parser.add_argument("--w-bias", type=float, default=0.10) parser.add_argument("--pair-delta-min", type=float, default=0.02) parser.add_argument("--pair-margin-scale", type=float, default=0.30) parser.add_argument("--pair-margin-min", type=float, default=0.01) parser.add_argument("--pair-margin-max", type=float, default=0.20) parser.add_argument("--huber-beta", type=float, default=0.05) parser.add_argument("--patience", type=int, default=20) return parser.parse_args() def set_seed(seed: int) -> None: random.seed(seed) np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) def infer_curve_layout(meta: dict, curve_dim: int) -> dict: curve_layout = meta.get("curve_layout") if curve_layout is not None: return curve_layout n_time_points = curve_dim // 3 return { "n_time_points": int(n_time_points), "parts": [ {"name": "log_pressure", "start": 0, "end": n_time_points}, {"name": "log_derivative", "start": n_time_points, "end": 2 * n_time_points}, {"name": "slope", "start": 2 * n_time_points, "end": 3 * n_time_points}, ], } def get_part_slices(curve_layout: dict) -> dict[str, slice]: out: dict[str, slice] = {} for part in curve_layout["parts"]: out[str(part["name"])] = slice(int(part["start"]), int(part["end"])) return out def load_model(checkpoint_path: Path) -> tuple[ForwardSurrogate, dict]: checkpoint = torch.load(checkpoint_path, map_location="cpu") model = ForwardSurrogate( param_dim=int(checkpoint["param_dim"]), schedule_dim=int(checkpoint["schedule_dim"]), curve_dim=int(checkpoint["curve_dim"]), hidden_dim=int(checkpoint["hidden_dim"]), dropout=float(checkpoint["dropout"]), use_schedule=bool(checkpoint.get("use_schedule", True)), ) model.load_state_dict(checkpoint["model_state_dict"]) return model, checkpoint def resolve_default_processed_path(base_tag: str | None) -> Path: direct = processed_path_for_tag(base_tag) if direct.exists() or base_tag is None: return direct for suffix in ("_biasfix", "_autofit"): if base_tag.endswith(suffix): fallback = processed_path_for_tag(base_tag[: -len(suffix)]) if fallback.exists(): return fallback return direct def load_neighborhood_groups( neighborhood_path: Path, processed: dict, ) -> list[dict[str, np.ndarray]]: scaler_params = processed["scaler_params"] scaler_schedule = processed["scaler_schedule"] scaler_curve = processed["scaler_curve"] param_transform = param_feature_transform_from_meta(processed.get("meta", {})) with h5py.File(neighborhood_path, "r") as f: anchor_params = np.asarray(f["anchor_params"][:], dtype=np.float32) anchor_schedule = np.asarray(f["anchor_schedule"][:], dtype=np.float32) anchor_curve = np.asarray(f["anchor_curve"][:], dtype=np.float32) neighbor_anchor_id = np.asarray(f["neighbor_anchor_id"][:], dtype=np.int64) neighbor_params = np.asarray(f["neighbor_params"][:], dtype=np.float32) neighbor_curve = np.asarray(f["neighbor_curve"][:], dtype=np.float32) neighbor_objective = np.asarray(f["neighbor_objective"][:], dtype=np.float32) groups: list[dict[str, np.ndarray]] = [] for anchor_id in range(anchor_params.shape[0]): idx = np.where(neighbor_anchor_id == anchor_id)[0] if idx.size < 2: continue anchor_params_scaled = scaler_params.transform( transform_param_features(anchor_params[anchor_id : anchor_id + 1], param_transform) ).astype(np.float32) anchor_schedule_scaled = scaler_schedule.transform(anchor_schedule[anchor_id : anchor_id + 1]).astype(np.float32) anchor_curve_scaled = scaler_curve.transform(anchor_curve[anchor_id : anchor_id + 1]).astype(np.float32) groups.append( { "anchor_id": np.asarray([anchor_id], dtype=np.int64), "anchor_params_x": anchor_params_scaled, "anchor_schedule_x": anchor_schedule_scaled, "anchor_curve_x": anchor_curve_scaled, "anchor_curve_raw": anchor_curve[anchor_id : anchor_id + 1].astype(np.float32), "neighbor_params_x": scaler_params.transform( transform_param_features(neighbor_params[idx], param_transform) ).astype(np.float32), "neighbor_schedule_x": scaler_schedule.transform(anchor_schedule[neighbor_anchor_id[idx]]).astype(np.float32), "neighbor_curve_x": scaler_curve.transform(neighbor_curve[idx]).astype(np.float32), "neighbor_objective": neighbor_objective[idx].astype(np.float32), } ) if not groups: raise ValueError(f"No usable anchor groups found in {neighborhood_path}") return groups def to_tensor(group: dict[str, np.ndarray], device: torch.device) -> dict[str, torch.Tensor]: return { key: torch.tensor(value, dtype=torch.float32, device=device) for key, value in group.items() if key != "anchor_id" } def restore_raw(x_scaled: torch.Tensor, mean: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: return x_scaled * scale.unsqueeze(0) + mean.unsqueeze(0) def objective_1d(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor: weight_factor = torch.clamp(torch.abs(target) * 0.01, max=100.0) weight = 1.0 / (1.0 + weight_factor) scale = torch.maximum(torch.maximum(torch.abs(target), torch.abs(pred)), torch.full_like(target, 1e-12)) relative_error = torch.abs(target - pred) / scale absolute_error = torch.abs(target - pred) point_error = 0.7 * relative_error + 0.3 * absolute_error weighted_mse = (weight * (point_error**2)).sum(dim=1) / torch.clamp(weight.sum(dim=1), min=1e-12) return torch.sqrt(weighted_mse) def dual_log_objective_torch( anchor_curve_raw: torch.Tensor, pred_curve_raw: torch.Tensor, slices: dict[str, slice], ) -> torch.Tensor: target = anchor_curve_raw.expand(pred_curve_raw.shape[0], -1) p_obj = objective_1d(target[:, slices["log_pressure"]], pred_curve_raw[:, slices["log_pressure"]]) d_obj = objective_1d(target[:, slices["log_derivative"]], pred_curve_raw[:, slices["log_derivative"]]) return 0.5 * p_obj + 0.5 * d_obj def pairwise_ranking_loss( pred_objective: torch.Tensor, solver_objective: torch.Tensor, delta_min: float, margin_scale: float, margin_min: float, margin_max: float, ) -> torch.Tensor: solver_delta = solver_objective.unsqueeze(0) - solver_objective.unsqueeze(1) pred_delta = pred_objective.unsqueeze(0) - pred_objective.unsqueeze(1) mask = solver_delta > float(delta_min) if not bool(torch.any(mask)): return pred_objective.sum() * 0.0 margin = torch.clamp(float(margin_scale) * solver_delta, min=float(margin_min), max=float(margin_max)) return F.softplus(margin[mask] - pred_delta[mask]).mean() def smooth_l1_loss(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor: return F.smooth_l1_loss(pred, target, beta=float(beta), reduction="mean") def run_epoch( model: ForwardSurrogate, groups: list[dict[str, np.ndarray]], optimizer: torch.optim.Optimizer | None, device: torch.device, slices: dict[str, slice], curve_mean: torch.Tensor, curve_scale: torch.Tensor, args: argparse.Namespace, ) -> dict[str, float]: training = optimizer is not None model.train(training) order = list(range(len(groups))) if training: random.shuffle(order) totals = { "loss": 0.0, "rank": 0.0, "forward": 0.0, "anchor_forward": 0.0, "bias": 0.0, } for pos in order: g = to_tensor(groups[pos], device=device) if training: optimizer.zero_grad() pred_neighbor_scaled = model(g["neighbor_params_x"], g["neighbor_schedule_x"]) pred_neighbor_raw = restore_raw(pred_neighbor_scaled, curve_mean, curve_scale) pred_obj = dual_log_objective_torch(g["anchor_curve_raw"], pred_neighbor_raw, slices) rank_loss = pairwise_ranking_loss( pred_objective=pred_obj, solver_objective=g["neighbor_objective"], delta_min=float(args.pair_delta_min), margin_scale=float(args.pair_margin_scale), margin_min=float(args.pair_margin_min), margin_max=float(args.pair_margin_max), ) forward_loss = smooth_l1_loss(pred_neighbor_scaled, g["neighbor_curve_x"], beta=float(args.huber_beta)) pred_anchor_scaled = model(g["anchor_params_x"], g["anchor_schedule_x"]) anchor_forward_loss = smooth_l1_loss(pred_anchor_scaled, g["anchor_curve_x"], beta=float(args.huber_beta)) pred_p_mean = pred_neighbor_scaled[:, slices["log_pressure"]].mean(dim=1) true_p_mean = g["neighbor_curve_x"][:, slices["log_pressure"]].mean(dim=1) pred_d_mean = pred_neighbor_scaled[:, slices["log_derivative"]].mean(dim=1) true_d_mean = g["neighbor_curve_x"][:, slices["log_derivative"]].mean(dim=1) bias_loss = F.l1_loss(pred_p_mean, true_p_mean) + F.l1_loss(pred_d_mean, true_d_mean) loss = ( float(args.w_rank) * rank_loss + float(args.w_forward) * forward_loss + float(args.w_anchor_forward) * anchor_forward_loss + float(args.w_bias) * bias_loss ) if training: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() totals["loss"] += float(loss.detach().cpu()) totals["rank"] += float(rank_loss.detach().cpu()) totals["forward"] += float(forward_loss.detach().cpu()) totals["anchor_forward"] += float(anchor_forward_loss.detach().cpu()) totals["bias"] += float(bias_loss.detach().cpu()) denom = max(len(order), 1) return {key: value / denom for key, value in totals.items()} def main() -> None: args = parse_args() set_seed(int(args.seed)) base_tag = normalize_tag(args.base_tag) output_tag = normalize_tag(args.output_tag) if output_tag is None: raise ValueError("--output-tag is required") processed_path = Path(args.base_processed) if args.base_processed is not None else resolve_default_processed_path(base_tag) model_path = Path(args.base_model) if args.base_model is not None else model_checkpoint_for_tag(base_tag, use_schedule=True) output_dir = Path(args.output_dir) if args.output_dir is not None else model_dir_for_tag(output_tag, use_schedule=True) output_dir.mkdir(parents=True, exist_ok=True) processed = joblib.load(processed_path) curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"])) slices = get_part_slices(curve_layout) groups = load_neighborhood_groups(Path(args.neighborhood), processed) rng = np.random.RandomState(int(args.seed)) perm = rng.permutation(len(groups)) n_val = max(1, int(round(0.20 * len(groups)))) val_ids = set(int(x) for x in perm[:n_val]) train_groups = [g for i, g in enumerate(groups) if i not in val_ids] val_groups = [g for i, g in enumerate(groups) if i in val_ids] model, checkpoint = load_model(model_path) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = model.to(device) scaler_curve = processed["scaler_curve"] curve_mean = torch.tensor(np.asarray(scaler_curve.mean_, dtype=np.float32), dtype=torch.float32, device=device) curve_scale = torch.tensor(np.asarray(scaler_curve.scale_, dtype=np.float32), dtype=torch.float32, device=device) optimizer = torch.optim.Adam(model.parameters(), lr=float(args.lr), weight_decay=float(args.weight_decay)) best_val = float("inf") best_path = output_dir / "forward_surrogate_best.pt" history: list[dict] = [] bad_epochs = 0 print("Local ranking fine-tune config:") print(f" processed={processed_path}") print(f" base_model={model_path}") print(f" neighborhood={args.neighborhood}") print(f" output_dir={output_dir}") print(f" device={device}, groups train={len(train_groups)}, val={len(val_groups)}") print( f" weights rank={args.w_rank}, forward={args.w_forward}, " f"anchor_forward={args.w_anchor_forward}, bias={args.w_bias}" ) for epoch in range(1, int(args.epochs) + 1): train_metrics = run_epoch( model=model, groups=train_groups, optimizer=optimizer, device=device, slices=slices, curve_mean=curve_mean, curve_scale=curve_scale, args=args, ) with torch.no_grad(): val_metrics = run_epoch( model=model, groups=val_groups, optimizer=None, device=device, slices=slices, curve_mean=curve_mean, curve_scale=curve_scale, args=args, ) row = {"epoch": epoch} row.update({f"train_{k}": float(v) for k, v in train_metrics.items()}) row.update({f"val_{k}": float(v) for k, v in val_metrics.items()}) history.append(row) print( f"[Epoch {epoch:03d}] " f"train={train_metrics['loss']:.6f} (rank={train_metrics['rank']:.6f}, fwd={train_metrics['forward']:.6f}) " f"val={val_metrics['loss']:.6f} (rank={val_metrics['rank']:.6f}, fwd={val_metrics['forward']:.6f})" ) if val_metrics["loss"] < best_val - 1e-6: best_val = val_metrics["loss"] bad_epochs = 0 torch.save( { "model_state_dict": model.state_dict(), "param_dim": int(checkpoint["param_dim"]), "schedule_dim": int(checkpoint["schedule_dim"]), "curve_dim": int(checkpoint["curve_dim"]), "hidden_dim": int(checkpoint["hidden_dim"]), "dropout": float(checkpoint["dropout"]), "use_schedule": bool(checkpoint.get("use_schedule", True)), "seed": int(args.seed), "curve_layout": curve_layout, "base_model_path": str(model_path), "base_processed_path": str(processed_path), "neighborhood_path": str(Path(args.neighborhood)), "fine_tune": { "type": "local_pairwise_ranking", "best_val_loss": float(best_val), "weights": { "rank": float(args.w_rank), "forward": float(args.w_forward), "anchor_forward": float(args.w_anchor_forward), "bias": float(args.w_bias), }, "pair_delta_min": float(args.pair_delta_min), "pair_margin_scale": float(args.pair_margin_scale), "pair_margin_min": float(args.pair_margin_min), "pair_margin_max": float(args.pair_margin_max), }, }, best_path, ) print(f" -> best model saved to: {best_path}") else: bad_epochs += 1 if bad_epochs >= int(args.patience): print(f"Early stopping at epoch {epoch}; best_val={best_val:.6f}") break with open(output_dir / "history.json", "w", encoding="utf-8") as f: json.dump(history, f, ensure_ascii=False, indent=2) with open(output_dir / "metrics.json", "w", encoding="utf-8") as f: json.dump( { "best_val_loss": float(best_val), "base_model_path": str(model_path), "base_processed_path": str(processed_path), "neighborhood_path": str(Path(args.neighborhood)), "n_train_groups": int(len(train_groups)), "n_val_groups": int(len(val_groups)), "history_last": history[-1] if history else {}, }, f, ensure_ascii=False, indent=2, ) print("\nLocal ranking fine-tune complete.") print(f"Best checkpoint: {best_path}") print(f"Metrics: {output_dir / 'metrics.json'}") if __name__ == "__main__": main()