You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/finetune_forward_local_rank...

439 lines
18 KiB
Python

from __future__ import annotations
import argparse
import json
import random
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
import h5py
import joblib
import numpy as np
import torch
import torch.nn.functional as F
from src.common.experiment_paths import model_checkpoint_for_tag, model_dir_for_tag, normalize_tag, processed_path_for_tag
from src.data.param_features import param_feature_transform_from_meta, transform_param_features
from src.models.forward_surrogate import ForwardSurrogate
def parse_args() -> argparse.Namespace:
    """Build the fine-tuning CLI and parse ``sys.argv``.

    ``--neighborhood`` and ``--output-tag`` are mandatory; every other option
    falls back to the default registered below.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    cli = argparse.ArgumentParser(
        description="Fine-tune a forward surrogate with local pairwise autofit ranking constraints"
    )
    add = cli.add_argument

    add("--neighborhood", type=str, required=True, help="Anchor-neighborhood HDF5 path")
    # Optional string inputs; None means "derive from the tag".
    for flag, default in (
        ("--base-tag", "family_random_mixed_50k_biasfix"),
        ("--base-processed", None),
        ("--base-model", None),
    ):
        add(flag, type=str, default=default)
    add("--output-tag", type=str, required=True)
    add("--output-dir", type=str, default=None)
    add("--seed", type=int, default=42)
    add("--epochs", type=int, default=80)
    # Optimizer, loss-weight and ranking-margin hyper-parameters (all floats).
    for flag, default in (
        ("--lr", 1.0e-5),
        ("--weight-decay", 1.0e-5),
        ("--w-rank", 1.0),
        ("--w-forward", 0.25),
        ("--w-anchor-forward", 0.05),
        ("--w-bias", 0.10),
        ("--pair-delta-min", 0.02),
        ("--pair-margin-scale", 0.30),
        ("--pair-margin-min", 0.01),
        ("--pair-margin-max", 0.20),
        ("--huber-beta", 0.05),
    ):
        add(flag, type=float, default=default)
    add("--patience", type=int, default=20)
    return cli.parse_args()
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch.

    CUDA devices are seeded too, but only when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the curve layout stored in ``meta``, or a default three-part split.

    A non-``None`` ``meta["curve_layout"]`` is returned unchanged. Otherwise the
    flat curve vector is treated as three equal consecutive segments --
    ``log_pressure``, ``log_derivative`` and ``slope`` -- each
    ``curve_dim // 3`` long. Any remainder of ``curve_dim`` is not covered by
    any part.
    """
    stored = meta.get("curve_layout")
    if stored is not None:
        return stored

    seg = curve_dim // 3
    bounds = [0, seg, 2 * seg, 3 * seg]
    names = ("log_pressure", "log_derivative", "slope")
    return {
        "n_time_points": int(seg),
        "parts": [
            {"name": name, "start": lo, "end": hi}
            for name, lo, hi in zip(names, bounds, bounds[1:])
        ],
    }
def get_part_slices(curve_layout: dict) -> dict[str, slice]:
    """Map each ``curve_layout["parts"]`` entry's name to ``slice(start, end)``.

    Names are coerced with ``str`` and bounds with ``int``. A later part with a
    duplicate name overwrites an earlier one.
    """
    return {
        str(part["name"]): slice(int(part["start"]), int(part["end"]))
        for part in curve_layout["parts"]
    }
def load_model(checkpoint_path: Path) -> tuple[ForwardSurrogate, dict]:
    """Rebuild a ``ForwardSurrogate`` from a checkpoint file and load its weights.

    The checkpoint is mapped onto the CPU. Architecture hyper-parameters come
    from the checkpoint's top-level keys; ``use_schedule`` defaults to ``True``
    when absent.

    Returns:
        ``(model, checkpoint)``: the network with ``model_state_dict`` loaded,
        and the raw checkpoint dict so callers can reuse its metadata.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    architecture = {
        "param_dim": int(ckpt["param_dim"]),
        "schedule_dim": int(ckpt["schedule_dim"]),
        "curve_dim": int(ckpt["curve_dim"]),
        "hidden_dim": int(ckpt["hidden_dim"]),
        "dropout": float(ckpt["dropout"]),
        "use_schedule": bool(ckpt.get("use_schedule", True)),
    }
    net = ForwardSurrogate(**architecture)
    net.load_state_dict(ckpt["model_state_dict"])
    return net, ckpt
def resolve_default_processed_path(base_tag: str | None) -> Path:
    """Resolve the processed-dataset path for ``base_tag``, with suffix fallbacks.

    The path for the tag itself is used when it exists, or when ``base_tag`` is
    ``None``. Otherwise, if the tag ends in ``_biasfix`` or ``_autofit``, the
    path for the tag without that suffix is tried, and the first existing one
    wins. If nothing exists, the tag's own path is returned anyway.
    """
    primary = processed_path_for_tag(base_tag)
    if primary.exists() or base_tag is None:
        return primary

    stripped_tags = (
        base_tag[: -len(suffix)]
        for suffix in ("_biasfix", "_autofit")
        if base_tag.endswith(suffix)
    )
    for stem in stripped_tags:
        candidate = processed_path_for_tag(stem)
        if candidate.exists():
            return candidate
    return primary
def _neighbor_rows_per_anchor(neighbor_anchor_id: np.ndarray, n_anchors: int) -> list[np.ndarray]:
    """Return the ascending neighbor row indices for every anchor id in ``range(n_anchors)``.

    Gives the same result as ``np.where(neighbor_anchor_id == a)[0]`` for each
    ``a``. It uses one stable sort (O(N log N)) instead of a full scan per
    anchor (O(A * N)). Ids outside ``[0, n_anchors)`` are never returned.
    """
    ids = np.ravel(neighbor_anchor_id)
    # A stable sort keeps equal ids in their original order, so each run is ascending.
    order = np.argsort(ids, kind="stable")
    sorted_ids = ids[order]
    anchors = np.arange(n_anchors)
    starts = np.searchsorted(sorted_ids, anchors, side="left")
    ends = np.searchsorted(sorted_ids, anchors, side="right")
    return [order[lo:hi] for lo, hi in zip(starts, ends)]


def load_neighborhood_groups(
    neighborhood_path: Path,
    processed: dict,
) -> list[dict[str, np.ndarray]]:
    """Load an anchor-neighborhood HDF5 file into per-anchor training groups.

    Reads the ``anchor_params`` / ``anchor_schedule`` / ``anchor_curve``
    datasets plus flat neighbor datasets tagged by ``neighbor_anchor_id``.
    Anchors with fewer than two neighbors are skipped, since there is no pair
    to rank.

    Scaling:
        * params go through the processed file's param feature transform, then
          ``scaler_params``;
        * schedules go through ``scaler_schedule`` -- every neighbor reuses its
          anchor's schedule;
        * curves go through ``scaler_curve``.

    Args:
        neighborhood_path: HDF5 file holding the datasets named above.
        processed: Processed-dataset dict; must contain ``scaler_params``,
            ``scaler_schedule`` and ``scaler_curve``, and may contain ``meta``.

    Returns:
        One dict per usable anchor. Keys are ``anchor_id`` (int64) and the
        float32 arrays ``anchor_params_x``, ``anchor_schedule_x``,
        ``anchor_curve_x``, ``anchor_curve_raw``, ``neighbor_params_x``,
        ``neighbor_schedule_x``, ``neighbor_curve_x`` and ``neighbor_objective``.

    Raises:
        ValueError: if the neighbor datasets disagree in length, or if no
            anchor has at least two neighbors.
    """
    scaler_params = processed["scaler_params"]
    scaler_schedule = processed["scaler_schedule"]
    scaler_curve = processed["scaler_curve"]
    param_transform = param_feature_transform_from_meta(processed.get("meta", {}))

    with h5py.File(neighborhood_path, "r") as f:
        anchor_params = np.asarray(f["anchor_params"][:], dtype=np.float32)
        anchor_schedule = np.asarray(f["anchor_schedule"][:], dtype=np.float32)
        anchor_curve = np.asarray(f["anchor_curve"][:], dtype=np.float32)
        neighbor_anchor_id = np.asarray(f["neighbor_anchor_id"][:], dtype=np.int64)
        neighbor_params = np.asarray(f["neighbor_params"][:], dtype=np.float32)
        neighbor_curve = np.asarray(f["neighbor_curve"][:], dtype=np.float32)
        neighbor_objective = np.asarray(f["neighbor_objective"][:], dtype=np.float32)

    # Fail fast on inconsistent files instead of an IndexError or silently ignored rows.
    n_neighbors = neighbor_anchor_id.shape[0]
    mismatched = {
        name: arr.shape[0]
        for name, arr in (
            ("neighbor_params", neighbor_params),
            ("neighbor_curve", neighbor_curve),
            ("neighbor_objective", neighbor_objective),
        )
        if arr.shape[0] != n_neighbors
    }
    if mismatched:
        raise ValueError(
            f"Neighbor datasets in {neighborhood_path} disagree in length: "
            f"neighbor_anchor_id has {n_neighbors} rows, but {mismatched}"
        )

    rows_per_anchor = _neighbor_rows_per_anchor(neighbor_anchor_id, anchor_params.shape[0])
    groups: list[dict[str, np.ndarray]] = []
    for anchor_id, idx in enumerate(rows_per_anchor):
        if idx.size < 2:
            continue  # need at least one pair of neighbors to rank
        one = slice(anchor_id, anchor_id + 1)  # keeps a leading row axis of size 1
        groups.append(
            {
                "anchor_id": np.asarray([anchor_id], dtype=np.int64),
                "anchor_params_x": scaler_params.transform(
                    transform_param_features(anchor_params[one], param_transform)
                ).astype(np.float32),
                "anchor_schedule_x": scaler_schedule.transform(anchor_schedule[one]).astype(np.float32),
                "anchor_curve_x": scaler_curve.transform(anchor_curve[one]).astype(np.float32),
                "anchor_curve_raw": anchor_curve[one].astype(np.float32),
                "neighbor_params_x": scaler_params.transform(
                    transform_param_features(neighbor_params[idx], param_transform)
                ).astype(np.float32),
                # Every neighbor of this anchor shares the anchor's schedule.
                "neighbor_schedule_x": scaler_schedule.transform(
                    anchor_schedule[neighbor_anchor_id[idx]]
                ).astype(np.float32),
                "neighbor_curve_x": scaler_curve.transform(neighbor_curve[idx]).astype(np.float32),
                "neighbor_objective": neighbor_objective[idx].astype(np.float32),
            }
        )
    if not groups:
        raise ValueError(f"No usable anchor groups found in {neighborhood_path}")
    return groups
def to_tensor(group: dict[str, np.ndarray], device: torch.device) -> dict[str, torch.Tensor]:
    """Copy every array in ``group`` except ``"anchor_id"`` into a float32 tensor on ``device``.

    ``torch.tensor`` always copies, so the returned tensors never alias the
    source arrays. Key order follows ``group``.
    """
    tensors: dict[str, torch.Tensor] = {}
    for name, array in group.items():
        if name == "anchor_id":
            continue  # bookkeeping only; not a model input
        tensors[name] = torch.tensor(array, dtype=torch.float32, device=device)
    return tensors
def restore_raw(x_scaled: torch.Tensor, mean: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Undo per-feature standardisation: return ``x_scaled * scale + mean``.

    ``mean`` and ``scale`` gain a leading axis so they broadcast across the
    rows of ``x_scaled``.
    """
    row_scale = scale[None]
    row_mean = mean[None]
    return x_scaled * row_scale + row_mean
def objective_1d(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Weighted RMS of a blended relative/absolute error, reduced over dim 1.

    Per element:
      * weight = 1 / (1 + min(|target| * 0.01, 100)), which down-weights
        large-magnitude targets;
      * relative error = |target - pred| / max(|target|, |pred|, 1e-12);
      * point error = 0.7 * relative error + 0.3 * absolute error.

    Each row becomes ``sqrt(sum(weight * err**2) / max(sum(weight), 1e-12))``,
    so a 2-D ``(rows, points)`` input yields one value per row.
    """
    residual = torch.abs(target - pred)
    target_mag = torch.abs(target)
    w = 1.0 / (1.0 + torch.clamp(target_mag * 0.01, max=100.0))
    # The 1e-12 floor avoids dividing by zero when both target and pred are 0.
    denom = torch.maximum(torch.maximum(target_mag, torch.abs(pred)), torch.full_like(target, 1e-12))
    err = 0.7 * (residual / denom) + 0.3 * residual
    numerator = (w * (err**2)).sum(dim=1)
    return torch.sqrt(numerator / torch.clamp(w.sum(dim=1), min=1e-12))
def dual_log_objective_torch(
    anchor_curve_raw: torch.Tensor,
    pred_curve_raw: torch.Tensor,
    slices: dict[str, slice],
) -> torch.Tensor:
    """Score every predicted curve against the anchor's target curve.

    ``anchor_curve_raw`` is expanded to one copy per row of ``pred_curve_raw``.
    The score is the equal-weight mean of ``objective_1d`` over the
    ``log_pressure`` and ``log_derivative`` segments, so one value is returned
    per predicted row.
    """
    reference = anchor_curve_raw.expand(pred_curve_raw.shape[0], -1)
    pressure_term, derivative_term = (
        objective_1d(reference[:, slices[part]], pred_curve_raw[:, slices[part]])
        for part in ("log_pressure", "log_derivative")
    )
    return 0.5 * pressure_term + 0.5 * derivative_term
def pairwise_ranking_loss(
    pred_objective: torch.Tensor,
    solver_objective: torch.Tensor,
    delta_min: float,
    margin_scale: float,
    margin_min: float,
    margin_max: float,
) -> torch.Tensor:
    """Softplus pairwise hinge that asks predicted scores to follow the solver's ordering.

    For every ordered pair ``(i, j)`` with
    ``solver[j] - solver[i] > delta_min``, the loss term is
    ``softplus(m_ij - (pred[j] - pred[i]))``, where
    ``m_ij = clip(margin_scale * (solver[j] - solver[i]), margin_min, margin_max)``.

    Returns the mean over qualifying pairs. When no pair qualifies, returns a
    zero that is still connected to ``pred_objective``'s graph.
    """
    gap_solver = solver_objective[None, :] - solver_objective[:, None]
    gap_pred = pred_objective[None, :] - pred_objective[:, None]
    pairs = gap_solver > float(delta_min)
    if not pairs.any().item():
        return pred_objective.sum() * 0.0
    margins = (float(margin_scale) * gap_solver).clamp(min=float(margin_min), max=float(margin_max))
    return F.softplus(margins[pairs] - gap_pred[pairs]).mean()
def smooth_l1_loss(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor:
    """Return the mean smooth-L1 (Huber-style) loss with transition point ``beta``.

    This is a thin wrapper over ``torch.nn.functional.smooth_l1_loss``.
    """
    huber_beta = float(beta)
    return F.smooth_l1_loss(input=pred, target=target, reduction="mean", beta=huber_beta)
def _group_losses(
    model: ForwardSurrogate,
    g: dict[str, torch.Tensor],
    slices: dict[str, slice],
    curve_mean: torch.Tensor,
    curve_scale: torch.Tensor,
    args: argparse.Namespace,
) -> dict[str, torch.Tensor]:
    """Compute the weighted total and the four loss components for one anchor group.

    The model runs on the neighbors first and then on the anchor; this keeps
    the same call order, and so the same dropout RNG consumption.
    """
    beta = float(args.huber_beta)
    neigh_pred = model(g["neighbor_params_x"], g["neighbor_schedule_x"])
    neigh_raw = restore_raw(neigh_pred, curve_mean, curve_scale)
    rank = pairwise_ranking_loss(
        pred_objective=dual_log_objective_torch(g["anchor_curve_raw"], neigh_raw, slices),
        solver_objective=g["neighbor_objective"],
        delta_min=float(args.pair_delta_min),
        margin_scale=float(args.pair_margin_scale),
        margin_min=float(args.pair_margin_min),
        margin_max=float(args.pair_margin_max),
    )
    forward = smooth_l1_loss(neigh_pred, g["neighbor_curve_x"], beta=beta)
    anchor_pred = model(g["anchor_params_x"], g["anchor_schedule_x"])
    anchor_forward = smooth_l1_loss(anchor_pred, g["anchor_curve_x"], beta=beta)
    # Penalise the offset between the per-curve segment means (scaled space).
    pressure_bias, derivative_bias = (
        F.l1_loss(
            neigh_pred[:, slices[part]].mean(dim=1),
            g["neighbor_curve_x"][:, slices[part]].mean(dim=1),
        )
        for part in ("log_pressure", "log_derivative")
    )
    bias = pressure_bias + derivative_bias
    total = (
        float(args.w_rank) * rank
        + float(args.w_forward) * forward
        + float(args.w_anchor_forward) * anchor_forward
        + float(args.w_bias) * bias
    )
    return {
        "loss": total,
        "rank": rank,
        "forward": forward,
        "anchor_forward": anchor_forward,
        "bias": bias,
    }


def run_epoch(
    model: ForwardSurrogate,
    groups: list[dict[str, np.ndarray]],
    optimizer: torch.optim.Optimizer | None,
    device: torch.device,
    slices: dict[str, slice],
    curve_mean: torch.Tensor,
    curve_scale: torch.Tensor,
    args: argparse.Namespace,
) -> dict[str, float]:
    """Run one pass over ``groups``, one optimizer step per anchor group.

    With an ``optimizer``, the model is put in train mode, the group order is
    shuffled with Python's ``random``, and each group's total loss is
    backpropagated with gradients clipped to norm 1.0. Without one, the model
    is put in eval mode and nothing is updated. This function does not disable
    autograd itself; ``main`` wraps evaluation in ``torch.no_grad()``.

    Returns:
        The mean over groups (each group weighted equally) of ``loss``,
        ``rank``, ``forward``, ``anchor_forward`` and ``bias``. All are 0.0
        when ``groups`` is empty.
    """
    is_train = optimizer is not None
    model.train(is_train)
    visit = list(range(len(groups)))
    if is_train:
        random.shuffle(visit)

    keys = ("loss", "rank", "forward", "anchor_forward", "bias")
    sums = dict.fromkeys(keys, 0.0)
    for gi in visit:
        batch = to_tensor(groups[gi], device=device)
        if is_train:
            optimizer.zero_grad()
        parts = _group_losses(model, batch, slices, curve_mean, curve_scale, args)
        if is_train:
            parts["loss"].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        for key in keys:
            sums[key] += float(parts[key].detach().cpu())

    n_groups = max(len(visit), 1)
    return {key: total / n_groups for key, total in sums.items()}
def _split_train_val(
    groups: list[dict[str, np.ndarray]],
    seed: int,
    val_fraction: float = 0.20,
) -> tuple[list[dict[str, np.ndarray]], list[dict[str, np.ndarray]]]:
    """Split ``groups`` into ``(train, val)`` using a permutation seeded by ``seed``.

    ``round(val_fraction * len(groups))`` groups (at least one) go to
    validation. Both splits keep the groups' original relative order.

    Raises:
        ValueError: if holding out the validation groups would leave training
            empty (a single group at the default fraction).
    """
    rng = np.random.RandomState(int(seed))
    perm = rng.permutation(len(groups))
    n_val = max(1, int(round(val_fraction * len(groups))))
    if n_val >= len(groups):
        raise ValueError(
            f"Need at least one training group after holding out {n_val} for validation; "
            f"only {len(groups)} usable anchor group(s) available"
        )
    val_ids = {int(i) for i in perm[:n_val]}
    train_groups = [g for i, g in enumerate(groups) if i not in val_ids]
    val_groups = [g for i, g in enumerate(groups) if i in val_ids]
    return train_groups, val_groups


def _checkpoint_payload(
    model: ForwardSurrogate,
    base_checkpoint: dict,
    args: argparse.Namespace,
    curve_layout: dict,
    model_path: Path,
    processed_path: Path,
    best_val: float,
) -> dict:
    """Build the dict saved as the fine-tuned checkpoint.

    Architecture keys are copied from the base checkpoint, so ``load_model``
    can rebuild the network. Fine-tune settings go under ``"fine_tune"``.
    """
    return {
        "model_state_dict": model.state_dict(),
        "param_dim": int(base_checkpoint["param_dim"]),
        "schedule_dim": int(base_checkpoint["schedule_dim"]),
        "curve_dim": int(base_checkpoint["curve_dim"]),
        "hidden_dim": int(base_checkpoint["hidden_dim"]),
        "dropout": float(base_checkpoint["dropout"]),
        "use_schedule": bool(base_checkpoint.get("use_schedule", True)),
        "seed": int(args.seed),
        "curve_layout": curve_layout,
        "base_model_path": str(model_path),
        "base_processed_path": str(processed_path),
        "neighborhood_path": str(Path(args.neighborhood)),
        "fine_tune": {
            "type": "local_pairwise_ranking",
            "best_val_loss": float(best_val),
            "weights": {
                "rank": float(args.w_rank),
                "forward": float(args.w_forward),
                "anchor_forward": float(args.w_anchor_forward),
                "bias": float(args.w_bias),
            },
            "pair_delta_min": float(args.pair_delta_min),
            "pair_margin_scale": float(args.pair_margin_scale),
            "pair_margin_min": float(args.pair_margin_min),
            "pair_margin_max": float(args.pair_margin_max),
        },
    }


def main() -> None:
    """Fine-tune a base forward surrogate with local pairwise ranking constraints.

    Steps:
      1. Parse CLI args and seed all RNGs.
      2. Resolve the processed dataset, base checkpoint and output directory.
         Tags are used unless explicit paths are given.
      3. Load the neighborhood groups and hold out ~20% (at least one group)
         for validation.
      4. Train with early stopping on validation loss. The best weights are
         saved to ``forward_surrogate_best.pt``; ``history.json`` and
         ``metrics.json`` are written to the output directory.

    Raises:
        ValueError: if ``--output-tag`` normalises to ``None``, or there are too
            few anchor groups to form a non-empty training split.
    """
    args = parse_args()
    set_seed(int(args.seed))
    base_tag = normalize_tag(args.base_tag)
    output_tag = normalize_tag(args.output_tag)
    if output_tag is None:
        raise ValueError("--output-tag is required")
    processed_path = (
        Path(args.base_processed)
        if args.base_processed is not None
        else resolve_default_processed_path(base_tag)
    )
    model_path = (
        Path(args.base_model)
        if args.base_model is not None
        else model_checkpoint_for_tag(base_tag, use_schedule=True)
    )
    output_dir = (
        Path(args.output_dir)
        if args.output_dir is not None
        else model_dir_for_tag(output_tag, use_schedule=True)
    )
    output_dir.mkdir(parents=True, exist_ok=True)
    neighborhood_path = Path(args.neighborhood)

    # NOTE: joblib.load unpickles the file; only point this at trusted artifacts.
    processed = joblib.load(processed_path)
    curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))
    slices = get_part_slices(curve_layout)
    groups = load_neighborhood_groups(neighborhood_path, processed)
    train_groups, val_groups = _split_train_val(groups, int(args.seed))

    model, checkpoint = load_model(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    scaler_curve = processed["scaler_curve"]
    curve_mean = torch.tensor(np.asarray(scaler_curve.mean_, dtype=np.float32), dtype=torch.float32, device=device)
    curve_scale = torch.tensor(np.asarray(scaler_curve.scale_, dtype=np.float32), dtype=torch.float32, device=device)
    optimizer = torch.optim.Adam(model.parameters(), lr=float(args.lr), weight_decay=float(args.weight_decay))

    best_val = float("inf")
    best_path = output_dir / "forward_surrogate_best.pt"
    history: list[dict] = []
    bad_epochs = 0
    # Tracks whether THIS run wrote best_path (a stale file may exist from an earlier run).
    saved_checkpoint = False

    print("Local ranking fine-tune config:")
    print(f" processed={processed_path}")
    print(f" base_model={model_path}")
    print(f" neighborhood={args.neighborhood}")
    print(f" output_dir={output_dir}")
    print(f" device={device}, groups train={len(train_groups)}, val={len(val_groups)}")
    print(
        f" weights rank={args.w_rank}, forward={args.w_forward}, "
        f"anchor_forward={args.w_anchor_forward}, bias={args.w_bias}"
    )

    for epoch in range(1, int(args.epochs) + 1):
        train_metrics = run_epoch(
            model=model,
            groups=train_groups,
            optimizer=optimizer,
            device=device,
            slices=slices,
            curve_mean=curve_mean,
            curve_scale=curve_scale,
            args=args,
        )
        with torch.no_grad():
            val_metrics = run_epoch(
                model=model,
                groups=val_groups,
                optimizer=None,
                device=device,
                slices=slices,
                curve_mean=curve_mean,
                curve_scale=curve_scale,
                args=args,
            )
        row = {"epoch": epoch}
        row.update({f"train_{k}": float(v) for k, v in train_metrics.items()})
        row.update({f"val_{k}": float(v) for k, v in val_metrics.items()})
        history.append(row)
        print(
            f"[Epoch {epoch:03d}] "
            f"train={train_metrics['loss']:.6f} (rank={train_metrics['rank']:.6f}, fwd={train_metrics['forward']:.6f}) "
            f"val={val_metrics['loss']:.6f} (rank={val_metrics['rank']:.6f}, fwd={val_metrics['forward']:.6f})"
        )
        # Require a small absolute improvement so float noise does not reset patience.
        if val_metrics["loss"] < best_val - 1e-6:
            best_val = val_metrics["loss"]
            bad_epochs = 0
            torch.save(
                _checkpoint_payload(
                    model, checkpoint, args, curve_layout, model_path, processed_path, best_val
                ),
                best_path,
            )
            saved_checkpoint = True
            print(f" -> best model saved to: {best_path}")
        else:
            bad_epochs += 1
            if bad_epochs >= int(args.patience):
                print(f"Early stopping at epoch {epoch}; best_val={best_val:.6f}")
                break

    with open(output_dir / "history.json", "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)
    with open(output_dir / "metrics.json", "w", encoding="utf-8") as f:
        json.dump(
            {
                "best_val_loss": float(best_val),
                "base_model_path": str(model_path),
                "base_processed_path": str(processed_path),
                "neighborhood_path": str(neighborhood_path),
                "n_train_groups": int(len(train_groups)),
                "n_val_groups": int(len(val_groups)),
                "history_last": history[-1] if history else {},
            },
            f,
            ensure_ascii=False,
            indent=2,
        )
    print("\nLocal ranking fine-tune complete.")
    if saved_checkpoint:
        print(f"Best checkpoint: {best_path}")
    else:
        print(
            "WARNING: validation loss never became finite/improved "
            "(zero epochs or non-finite losses); no checkpoint was written by this run."
        )
    print(f"Metrics: {output_dir / 'metrics.json'}")
if __name__ == "__main__":
    # Run the fine-tuning pipeline only when executed as a script, never on import.
    main()