You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
439 lines
18 KiB
Python
439 lines
18 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import random
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Repository root (two directory levels above this script). It is appended to
# sys.path so the ``src.*`` imports below resolve when this file is run
# directly as a script.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(f"{ROOT}")
|
|
|
|
import h5py
|
|
import joblib
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from src.common.experiment_paths import model_checkpoint_for_tag, model_dir_for_tag, normalize_tag, processed_path_for_tag
|
|
from src.data.param_features import param_feature_transform_from_meta, transform_param_features
|
|
from src.models.forward_surrogate import ForwardSurrogate
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the command-line interface for local ranking fine-tuning and parse ``sys.argv``.

    Only ``--neighborhood`` and ``--output-tag`` are mandatory. Every other
    option falls back to the default listed in the tables below. Options are
    registered in a fixed order so ``--help`` lists them consistently.
    """
    parser = argparse.ArgumentParser(
        description="Fine-tune a forward surrogate with local pairwise autofit ranking constraints"
    )
    parser.add_argument("--neighborhood", type=str, required=True, help="Anchor-neighborhood HDF5 path")

    # Optional string overrides for the base experiment (tag / processed data / checkpoint).
    base_options = (
        ("--base-tag", "family_random_mixed_50k_biasfix"),
        ("--base-processed", None),
        ("--base-model", None),
    )
    for flag, fallback in base_options:
        parser.add_argument(flag, type=str, default=fallback)

    parser.add_argument("--output-tag", type=str, required=True)
    parser.add_argument("--output-dir", type=str, default=None)

    # Numeric hyper-parameters: (flag, type, default).
    numeric_options = (
        ("--seed", int, 42),
        ("--epochs", int, 80),
        ("--lr", float, 1.0e-5),
        ("--weight-decay", float, 1.0e-5),
        ("--w-rank", float, 1.0),
        ("--w-forward", float, 0.25),
        ("--w-anchor-forward", float, 0.05),
        ("--w-bias", float, 0.10),
        ("--pair-delta-min", float, 0.02),
        ("--pair-margin-scale", float, 0.30),
        ("--pair-margin-min", float, 0.01),
        ("--pair-margin-max", float, 0.20),
        ("--huber-beta", float, 0.05),
        ("--patience", int, 20),
    )
    for flag, caster, fallback in numeric_options:
        parser.add_argument(flag, type=caster, default=fallback)

    return parser.parse_args()
|
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``.

    All CUDA devices are also seeded when CUDA is available.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
|
|
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the curve layout stored in ``meta``, or infer the default 3-part layout.

    When ``meta["curve_layout"]`` is present (and not None) it is returned
    unchanged. Otherwise the curve vector is assumed to be three equal-width
    parts laid out back to back: ``log_pressure``, ``log_derivative`` and
    ``slope``.

    Args:
        meta: Processed-dataset metadata dict.
        curve_dim: Total length of one flattened curve vector.

    Returns:
        A dict with ``"n_time_points"`` and a ``"parts"`` list of
        ``{"name", "start", "end"}`` entries (end exclusive).

    Raises:
        ValueError: if no layout is stored and ``curve_dim`` is not a positive
            multiple of 3. The three parts could then not tile the curve
            vector, and every objective built on these slices would be wrong.
    """
    curve_layout = meta.get("curve_layout")
    if curve_layout is not None:
        return curve_layout

    curve_dim = int(curve_dim)
    if curve_dim <= 0 or curve_dim % 3 != 0:
        raise ValueError(
            f"Cannot infer curve layout: curve_dim={curve_dim} is not a positive multiple of 3 "
            "and meta has no 'curve_layout'"
        )

    n_time_points = curve_dim // 3
    return {
        "n_time_points": int(n_time_points),
        "parts": [
            {"name": "log_pressure", "start": 0, "end": n_time_points},
            {"name": "log_derivative", "start": n_time_points, "end": 2 * n_time_points},
            {"name": "slope", "start": 2 * n_time_points, "end": 3 * n_time_points},
        ],
    }
|
|
|
|
|
|
def get_part_slices(curve_layout: dict) -> dict[str, slice]:
    """Map each part name in ``curve_layout["parts"]`` to ``slice(start, end)``.

    Names are coerced to ``str`` and bounds to ``int``. If a name repeats,
    the last entry wins.
    """
    return {
        str(part["name"]): slice(int(part["start"]), int(part["end"]))
        for part in curve_layout["parts"]
    }
|
|
|
|
|
|
def load_model(checkpoint_path: Path) -> tuple[ForwardSurrogate, dict]:
    """Rebuild a ``ForwardSurrogate`` from a checkpoint and load its weights.

    Tensors are loaded onto CPU (``map_location="cpu"``). The architecture
    hyper-parameters are read from the checkpoint itself. ``use_schedule``
    defaults to True when the checkpoint does not record it.

    Returns:
        ``(model, checkpoint)``, where ``checkpoint`` is the raw loaded dict.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    architecture = {
        "param_dim": int(ckpt["param_dim"]),
        "schedule_dim": int(ckpt["schedule_dim"]),
        "curve_dim": int(ckpt["curve_dim"]),
        "hidden_dim": int(ckpt["hidden_dim"]),
        "dropout": float(ckpt["dropout"]),
        "use_schedule": bool(ckpt.get("use_schedule", True)),
    }
    surrogate = ForwardSurrogate(**architecture)
    surrogate.load_state_dict(ckpt["model_state_dict"])
    return surrogate, ckpt
|
|
|
|
|
|
def resolve_default_processed_path(base_tag: str | None) -> Path:
    """Locate the processed-dataset file for ``base_tag``.

    The path for the tag itself is preferred. If that file does not exist and
    the tag ends in ``_biasfix`` or ``_autofit``, the path for the tag with
    that suffix stripped is tried instead. When no candidate exists (or
    ``base_tag`` is None), the direct path is returned unchanged.
    """
    direct = processed_path_for_tag(base_tag)
    if direct.exists() or base_tag is None:
        return direct

    stripped_tags = [base_tag[: -len(sfx)] for sfx in ("_biasfix", "_autofit") if base_tag.endswith(sfx)]
    for stem in stripped_tags:
        candidate = processed_path_for_tag(stem)
        if candidate.exists():
            return candidate
    return direct
|
|
|
|
|
|
def load_neighborhood_groups(
    neighborhood_path: Path,
    processed: dict,
) -> list[dict[str, np.ndarray]]:
    """Read an anchor-neighbourhood HDF5 file and build one scaled group per anchor.

    Each group holds the anchor's scaled params/schedule/curve, its raw curve,
    and the scaled params/curves plus solver objectives of every neighbour
    whose ``neighbor_anchor_id`` equals that anchor's row index. Neighbours
    reuse their anchor's schedule. Scaling uses the ``scaler_*`` objects from
    ``processed``. Anchors with fewer than two neighbours are skipped, since
    no ranking pair can be formed from them.

    Raises:
        ValueError: if no anchor has at least two neighbours.
    """
    scaler_params = processed["scaler_params"]
    scaler_schedule = processed["scaler_schedule"]
    scaler_curve = processed["scaler_curve"]
    param_transform = param_feature_transform_from_meta(processed.get("meta", {}))

    def scale_params(raw: np.ndarray) -> np.ndarray:
        # Apply the base dataset's feature transform, then its standardiser.
        return scaler_params.transform(transform_param_features(raw, param_transform)).astype(np.float32)

    # (dataset name, dtype) pairs, read in this order from the HDF5 file.
    fields = (
        ("anchor_params", np.float32),
        ("anchor_schedule", np.float32),
        ("anchor_curve", np.float32),
        ("neighbor_anchor_id", np.int64),
        ("neighbor_params", np.float32),
        ("neighbor_curve", np.float32),
        ("neighbor_objective", np.float32),
    )
    with h5py.File(neighborhood_path, "r") as h5:
        arrays = {name: np.asarray(h5[name][:], dtype=dtype) for name, dtype in fields}

    anchor_params = arrays["anchor_params"]
    anchor_schedule = arrays["anchor_schedule"]
    anchor_curve = arrays["anchor_curve"]
    owner = arrays["neighbor_anchor_id"]
    neighbor_params = arrays["neighbor_params"]
    neighbor_curve = arrays["neighbor_curve"]
    neighbor_objective = arrays["neighbor_objective"]

    groups: list[dict[str, np.ndarray]] = []
    for anchor_id in range(anchor_params.shape[0]):
        members = np.nonzero(owner == anchor_id)[0]
        if members.size < 2:
            continue

        row = slice(anchor_id, anchor_id + 1)  # keep a leading batch axis of size 1
        groups.append(
            {
                "anchor_id": np.asarray([anchor_id], dtype=np.int64),
                "anchor_params_x": scale_params(anchor_params[row]),
                "anchor_schedule_x": scaler_schedule.transform(anchor_schedule[row]).astype(np.float32),
                "anchor_curve_x": scaler_curve.transform(anchor_curve[row]).astype(np.float32),
                "anchor_curve_raw": anchor_curve[row].astype(np.float32),
                "neighbor_params_x": scale_params(neighbor_params[members]),
                "neighbor_schedule_x": scaler_schedule.transform(anchor_schedule[owner[members]]).astype(np.float32),
                "neighbor_curve_x": scaler_curve.transform(neighbor_curve[members]).astype(np.float32),
                "neighbor_objective": neighbor_objective[members].astype(np.float32),
            }
        )

    if not groups:
        raise ValueError(f"No usable anchor groups found in {neighborhood_path}")
    return groups
|
|
|
|
|
|
def to_tensor(group: dict[str, np.ndarray], device: torch.device) -> dict[str, torch.Tensor]:
    """Copy every array in ``group`` to a float32 tensor on ``device``.

    The integer bookkeeping entry ``"anchor_id"`` is dropped.
    """
    tensors: dict[str, torch.Tensor] = {}
    for name, array in group.items():
        if name == "anchor_id":
            continue
        tensors[name] = torch.tensor(array, dtype=torch.float32, device=device)
    return tensors
|
|
|
|
|
|
def restore_raw(x_scaled: torch.Tensor, mean: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Invert standardisation: ``x_scaled * scale + mean``.

    ``mean`` and ``scale`` gain a leading broadcast axis, so they apply to
    every row of ``x_scaled``.
    """
    per_row_scale = scale[None]
    per_row_mean = mean[None]
    return x_scaled * per_row_scale + per_row_mean
|
|
|
|
|
|
def objective_1d(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Row-wise weighted RMS of a blended relative/absolute error.

    Per point, the error is ``0.7 * |t - p| / max(|t|, |p|, 1e-12) + 0.3 * |t - p|``.
    Each point is weighted by ``1 / (1 + min(0.01 * |t|, 100))``, so points
    with a large target magnitude count less. The weighted mean of squared
    errors is taken over ``dim=1`` and square-rooted, giving one value per row.
    """
    abs_target = torch.abs(target)
    weight = 1.0 / (1.0 + torch.clamp(abs_target * 0.01, max=100.0))

    residual = torch.abs(target - pred)
    # Floor the denominator so that target == pred == 0 yields 0 instead of 0/0.
    denom = torch.clamp(torch.maximum(abs_target, torch.abs(pred)), min=1e-12)
    point_error = 0.7 * (residual / denom) + 0.3 * residual

    weighted_sq_sum = (weight * point_error.pow(2)).sum(dim=1)
    weight_total = torch.clamp(weight.sum(dim=1), min=1e-12)
    return torch.sqrt(weighted_sq_sum / weight_total)
|
|
|
|
|
|
def dual_log_objective_torch(
    anchor_curve_raw: torch.Tensor,
    pred_curve_raw: torch.Tensor,
    slices: dict[str, slice],
) -> torch.Tensor:
    """Score each predicted curve against the anchor curve, in raw (unscaled) units.

    The score is the equal-weight average of ``objective_1d`` on the
    ``log_pressure`` and ``log_derivative`` parts. ``anchor_curve_raw`` is
    broadcast with ``expand`` to the batch size of ``pred_curve_raw``, so it
    is expected to have a single row (or already ``batch`` rows).

    Returns:
        One objective value per row of ``pred_curve_raw``.
    """
    batch_size = pred_curve_raw.shape[0]
    target = anchor_curve_raw.expand(batch_size, -1)

    pressure = slices["log_pressure"]
    pressure_term = objective_1d(target[:, pressure], pred_curve_raw[:, pressure])

    derivative = slices["log_derivative"]
    derivative_term = objective_1d(target[:, derivative], pred_curve_raw[:, derivative])

    return 0.5 * pressure_term + 0.5 * derivative_term
|
|
|
|
|
|
def pairwise_ranking_loss(
    pred_objective: torch.Tensor,
    solver_objective: torch.Tensor,
    delta_min: float,
    margin_scale: float,
    margin_min: float,
    margin_max: float,
) -> torch.Tensor:
    """Smooth-hinge loss that makes predicted objectives preserve the solver's ordering.

    For every ordered pair ``(i, j)`` with ``solver[j] - solver[i] > delta_min``,
    the loss penalises ``softplus(margin_ij - (pred[j] - pred[i]))``. Here
    ``margin_ij = clamp(margin_scale * (solver[j] - solver[i]), margin_min, margin_max)``.
    The result is the mean over those pairs.

    When no pair qualifies, this returns a zero that stays attached to
    ``pred_objective``'s graph (``sum() * 0``).
    """
    solver_gap = solver_objective[None] - solver_objective[:, None]
    pred_gap = pred_objective[None] - pred_objective[:, None]

    ordered_pairs = solver_gap > float(delta_min)
    if not ordered_pairs.any():
        return pred_objective.sum() * 0.0

    required_gap = (float(margin_scale) * solver_gap).clamp(min=float(margin_min), max=float(margin_max))
    shortfall = required_gap[ordered_pairs] - pred_gap[ordered_pairs]
    return F.softplus(shortfall).mean()
|
|
|
|
|
|
def smooth_l1_loss(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor:
    """Mean Huber-style (smooth L1) loss with transition point ``beta``.

    The loss is quadratic for ``|pred - target| < beta`` and linear beyond it.
    """
    criterion = torch.nn.SmoothL1Loss(reduction="mean", beta=float(beta))
    return criterion(pred, target)
|
|
|
|
|
|
def run_epoch(
    model: ForwardSurrogate,
    groups: list[dict[str, np.ndarray]],
    optimizer: torch.optim.Optimizer | None,
    device: torch.device,
    slices: dict[str, slice],
    curve_mean: torch.Tensor,
    curve_scale: torch.Tensor,
    args: argparse.Namespace,
) -> dict[str, float]:
    """Run one pass over ``groups`` and return per-group mean loss components.

    Passing an optimizer selects training: the model is put in train mode, the
    group order is shuffled with Python's ``random`` and a step is taken after
    every group. Gradients are clipped to a total norm of 1.0 before each step.
    With ``optimizer=None`` the model is put in eval mode and no update is made.
    This function does not disable autograd itself; callers wrap evaluation in
    ``torch.no_grad()``.

    Each group contributes the weighted sum of four terms:
      * rank: pairwise ranking of surrogate-predicted objectives vs. solver objectives;
      * forward: smooth-L1 on the neighbours' scaled curves;
      * anchor_forward: smooth-L1 on the anchor's scaled curve;
      * bias: L1 between predicted/true per-curve means of the log_pressure
        and log_derivative parts.

    Returns:
        Dict with keys ``loss``, ``rank``, ``forward``, ``anchor_forward``,
        ``bias``, each averaged over the groups (0.0 when ``groups`` is empty).
    """
    is_train = optimizer is not None
    model.train(is_train)

    visit = list(range(len(groups)))
    if is_train:
        random.shuffle(visit)

    keys = ("loss", "rank", "forward", "anchor_forward", "bias")
    sums = dict.fromkeys(keys, 0.0)

    for gid in visit:
        batch = to_tensor(groups[gid], device=device)
        if is_train:
            optimizer.zero_grad()

        # Neighbour predictions, mapped back to raw units for the autofit-style objective.
        neigh_pred = model(batch["neighbor_params_x"], batch["neighbor_schedule_x"])
        neigh_pred_raw = restore_raw(neigh_pred, curve_mean, curve_scale)
        surrogate_obj = dual_log_objective_torch(batch["anchor_curve_raw"], neigh_pred_raw, slices)
        rank_term = pairwise_ranking_loss(
            pred_objective=surrogate_obj,
            solver_objective=batch["neighbor_objective"],
            delta_min=float(args.pair_delta_min),
            margin_scale=float(args.pair_margin_scale),
            margin_min=float(args.pair_margin_min),
            margin_max=float(args.pair_margin_max),
        )

        neigh_true = batch["neighbor_curve_x"]
        forward_term = smooth_l1_loss(neigh_pred, neigh_true, beta=float(args.huber_beta))

        # Second model call (after the neighbour call) keeps the anchor itself well fitted.
        anchor_pred = model(batch["anchor_params_x"], batch["anchor_schedule_x"])
        anchor_term = smooth_l1_loss(anchor_pred, batch["anchor_curve_x"], beta=float(args.huber_beta))

        # Penalise a systematic offset in the per-curve mean of each log part.
        p_part = slices["log_pressure"]
        d_part = slices["log_derivative"]
        pressure_bias = F.l1_loss(neigh_pred[:, p_part].mean(dim=1), neigh_true[:, p_part].mean(dim=1))
        derivative_bias = F.l1_loss(neigh_pred[:, d_part].mean(dim=1), neigh_true[:, d_part].mean(dim=1))
        bias_term = pressure_bias + derivative_bias

        total = (
            float(args.w_rank) * rank_term
            + float(args.w_forward) * forward_term
            + float(args.w_anchor_forward) * anchor_term
            + float(args.w_bias) * bias_term
        )

        if is_train:
            total.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        for key, term in zip(keys, (total, rank_term, forward_term, anchor_term, bias_term)):
            sums[key] += float(term.detach().cpu())

    n_groups = max(len(visit), 1)
    return {key: acc / n_groups for key, acc in sums.items()}
|
|
|
|
|
|
def _split_groups(
    groups: list[dict[str, np.ndarray]],
    seed: int,
    val_fraction: float = 0.20,
) -> tuple[list[dict[str, np.ndarray]], list[dict[str, np.ndarray]]]:
    """Split anchor groups into ``(train, val)`` using a seeded random permutation.

    At least one group goes to validation. For two or more groups the rounding
    can never consume every group, so the training split is non-empty.

    Raises:
        ValueError: if fewer than two groups are given, since one of the two
            splits would otherwise be empty.
    """
    n_groups = len(groups)
    if n_groups < 2:
        raise ValueError(
            f"Need at least 2 usable anchor groups for a train/val split, found {n_groups}"
        )
    rng = np.random.RandomState(int(seed))
    perm = rng.permutation(n_groups)
    n_val = max(1, int(round(val_fraction * n_groups)))
    val_ids = {int(x) for x in perm[:n_val]}
    train_groups = [g for i, g in enumerate(groups) if i not in val_ids]
    val_groups = [g for i, g in enumerate(groups) if i in val_ids]
    return train_groups, val_groups


def _checkpoint_payload(
    model: ForwardSurrogate,
    base_checkpoint: dict,
    args: argparse.Namespace,
    curve_layout: dict,
    model_path: Path,
    processed_path: Path,
    best_val: float,
) -> dict:
    """Assemble the dict saved as the best fine-tuned checkpoint.

    Architecture hyper-parameters are copied from the base checkpoint (the
    same keys ``load_model`` reads). The fine-tune settings and provenance
    paths are recorded alongside them.
    """
    return {
        "model_state_dict": model.state_dict(),
        "param_dim": int(base_checkpoint["param_dim"]),
        "schedule_dim": int(base_checkpoint["schedule_dim"]),
        "curve_dim": int(base_checkpoint["curve_dim"]),
        "hidden_dim": int(base_checkpoint["hidden_dim"]),
        "dropout": float(base_checkpoint["dropout"]),
        "use_schedule": bool(base_checkpoint.get("use_schedule", True)),
        "seed": int(args.seed),
        "curve_layout": curve_layout,
        "base_model_path": str(model_path),
        "base_processed_path": str(processed_path),
        "neighborhood_path": str(Path(args.neighborhood)),
        "fine_tune": {
            "type": "local_pairwise_ranking",
            "best_val_loss": float(best_val),
            "weights": {
                "rank": float(args.w_rank),
                "forward": float(args.w_forward),
                "anchor_forward": float(args.w_anchor_forward),
                "bias": float(args.w_bias),
            },
            "pair_delta_min": float(args.pair_delta_min),
            "pair_margin_scale": float(args.pair_margin_scale),
            "pair_margin_min": float(args.pair_margin_min),
            "pair_margin_max": float(args.pair_margin_max),
        },
    }


def main() -> None:
    """Fine-tune the base forward surrogate on local anchor neighbourhoods.

    Steps:
      1. Parse CLI args and resolve the processed-data, base-model and output paths.
      2. Load the processed dataset and the neighbourhood groups, then split the
         groups into train/val.
      3. Train with early stopping on validation loss, saving the best model to
         ``<output_dir>/forward_surrogate_best.pt``.
      4. Write ``history.json`` and ``metrics.json`` to the output directory.

    Raises:
        ValueError: if ``--output-tag`` normalises to None, if the curve layout
            lacks the ``log_pressure`` / ``log_derivative`` parts the objective
            needs, or if fewer than two usable anchor groups exist.
    """
    args = parse_args()
    set_seed(int(args.seed))

    base_tag = normalize_tag(args.base_tag)
    output_tag = normalize_tag(args.output_tag)
    if output_tag is None:
        raise ValueError("--output-tag is required")

    processed_path = (
        Path(args.base_processed)
        if args.base_processed is not None
        else resolve_default_processed_path(base_tag)
    )
    model_path = (
        Path(args.base_model)
        if args.base_model is not None
        else model_checkpoint_for_tag(base_tag, use_schedule=True)
    )
    output_dir = (
        Path(args.output_dir)
        if args.output_dir is not None
        else model_dir_for_tag(output_tag, use_schedule=True)
    )
    output_dir.mkdir(parents=True, exist_ok=True)
    neighborhood_path = Path(args.neighborhood)

    processed = joblib.load(processed_path)
    curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))
    slices = get_part_slices(curve_layout)
    # run_epoch indexes these parts on every step. Fail fast with a clear message
    # rather than a KeyError deep inside the first training step.
    missing_parts = [name for name in ("log_pressure", "log_derivative") if name not in slices]
    if missing_parts:
        raise ValueError(
            f"curve_layout is missing required parts {missing_parts}; found {sorted(slices)}"
        )

    groups = load_neighborhood_groups(neighborhood_path, processed)
    # With a single group the old 20% split put it in validation and trained on nothing.
    train_groups, val_groups = _split_groups(groups, seed=int(args.seed))

    model, checkpoint = load_model(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    scaler_curve = processed["scaler_curve"]
    curve_mean = torch.tensor(np.asarray(scaler_curve.mean_, dtype=np.float32), dtype=torch.float32, device=device)
    curve_scale = torch.tensor(np.asarray(scaler_curve.scale_, dtype=np.float32), dtype=torch.float32, device=device)

    optimizer = torch.optim.Adam(model.parameters(), lr=float(args.lr), weight_decay=float(args.weight_decay))

    best_val = float("inf")
    best_path = output_dir / "forward_surrogate_best.pt"
    checkpoint_saved = False
    history: list[dict] = []
    bad_epochs = 0

    print("Local ranking fine-tune config:")
    print(f" processed={processed_path}")
    print(f" base_model={model_path}")
    print(f" neighborhood={args.neighborhood}")
    print(f" output_dir={output_dir}")
    print(f" device={device}, groups train={len(train_groups)}, val={len(val_groups)}")
    print(
        f" weights rank={args.w_rank}, forward={args.w_forward}, "
        f"anchor_forward={args.w_anchor_forward}, bias={args.w_bias}"
    )

    for epoch in range(1, int(args.epochs) + 1):
        train_metrics = run_epoch(
            model=model,
            groups=train_groups,
            optimizer=optimizer,
            device=device,
            slices=slices,
            curve_mean=curve_mean,
            curve_scale=curve_scale,
            args=args,
        )
        with torch.no_grad():
            val_metrics = run_epoch(
                model=model,
                groups=val_groups,
                optimizer=None,
                device=device,
                slices=slices,
                curve_mean=curve_mean,
                curve_scale=curve_scale,
                args=args,
            )

        row = {"epoch": epoch}
        row.update({f"train_{k}": float(v) for k, v in train_metrics.items()})
        row.update({f"val_{k}": float(v) for k, v in val_metrics.items()})
        history.append(row)

        print(
            f"[Epoch {epoch:03d}] "
            f"train={train_metrics['loss']:.6f} (rank={train_metrics['rank']:.6f}, fwd={train_metrics['forward']:.6f}) "
            f"val={val_metrics['loss']:.6f} (rank={val_metrics['rank']:.6f}, fwd={val_metrics['forward']:.6f})"
        )

        # A NaN validation loss never compares smaller, so it counts as a bad epoch.
        if val_metrics["loss"] < best_val - 1e-6:
            best_val = val_metrics["loss"]
            bad_epochs = 0
            torch.save(
                _checkpoint_payload(
                    model=model,
                    base_checkpoint=checkpoint,
                    args=args,
                    curve_layout=curve_layout,
                    model_path=model_path,
                    processed_path=processed_path,
                    best_val=best_val,
                ),
                best_path,
            )
            checkpoint_saved = True
            print(f" -> best model saved to: {best_path}")
        else:
            bad_epochs += 1
            if bad_epochs >= int(args.patience):
                print(f"Early stopping at epoch {epoch}; best_val={best_val:.6f}")
                break

    if not checkpoint_saved:
        # Happens with --epochs 0 or when every validation loss is NaN/inf. Report
        # it instead of pointing at a file that does not exist.
        print("WARNING: no epoch produced a finite validation loss; no checkpoint was written.")

    with open(output_dir / "history.json", "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)
    with open(output_dir / "metrics.json", "w", encoding="utf-8") as f:
        json.dump(
            {
                # null instead of the non-standard JSON token `Infinity` when nothing was saved.
                "best_val_loss": float(best_val) if checkpoint_saved else None,
                "base_model_path": str(model_path),
                "base_processed_path": str(processed_path),
                "neighborhood_path": str(neighborhood_path),
                "n_train_groups": int(len(train_groups)),
                "n_val_groups": int(len(val_groups)),
                "history_last": history[-1] if history else {},
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    print("\nLocal ranking fine-tune complete.")
    if checkpoint_saved:
        print(f"Best checkpoint: {best_path}")
    else:
        print("Best checkpoint: none")
    print(f"Metrics: {output_dir / 'metrics.json'}")


if __name__ == "__main__":
    main()
|