You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/finetune_forward_local_rank...

487 lines
20 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

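"""Fine-tune the forward surrogate model with a local-neighborhood ranking objective.

The script reads groups of autofit neighborhood samples and uses the true
objective-function ranking of the candidates around the same anchor as the
supervision signal. On top of the original curve-fitting loss it adds a
pairwise ranking loss, so that the surrogate not only fits the curves but also
better preserves the local "which candidate is better" ordering needed in
autofit/PSO scenarios.
"""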
from __future__ import annotations
import argparse
import json
import random
import sys
from pathlib import Path
import h5py
import joblib
import numpy as np
import torch
import torch.nn.functional as F
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import (
model_checkpoint_for_tag,
model_dir_for_tag,
normalize_tag,
processed_path_for_tag,
)
from src.data.param_features import (
param_feature_transform_from_meta,
transform_param_features,
)
from src.models.forward_surrogate import ForwardSurrogate
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the local-ranking fine-tune script.

    The options cover the neighborhood HDF5 file, the base model / processed
    data (selected by tag or by explicit path), the output location, the
    optimizer settings, the loss weights and the pairwise-margin settings.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``, which reads
        the arguments from ``sys.argv``.
    """
    cli = argparse.ArgumentParser(
        description="Fine-tune a forward surrogate with local pairwise autofit ranking constraints"
    )
    cli.add_argument(
        "--neighborhood",
        type=str,
        required=True,
        help="Anchor-neighborhood HDF5 path",
    )

    # Remaining options, in declaration order: (flag, value type, extra kwargs).
    option_table = (
        ("--base-tag", str, {"default": "family_random_mixed_50k_biasfix"}),
        ("--base-processed", str, {"default": None}),
        ("--base-model", str, {"default": None}),
        ("--output-tag", str, {"required": True}),
        ("--output-dir", str, {"default": None}),
        ("--seed", int, {"default": 42}),
        ("--epochs", int, {"default": 80}),
        ("--lr", float, {"default": 1.0e-5}),
        ("--weight-decay", float, {"default": 1.0e-5}),
        ("--w-rank", float, {"default": 1.0}),
        ("--w-forward", float, {"default": 0.25}),
        ("--w-anchor-forward", float, {"default": 0.05}),
        ("--w-bias", float, {"default": 0.10}),
        ("--pair-delta-min", float, {"default": 0.02}),
        ("--pair-margin-scale", float, {"default": 0.30}),
        ("--pair-margin-min", float, {"default": 0.01}),
        ("--pair-margin-max", float, {"default": 0.20}),
        ("--huber-beta", float, {"default": 0.05}),
        ("--patience", int, {"default": 20}),
    )
    for flag, value_type, extra in option_table:
        cli.add_argument(flag, type=value_type, **extra)

    return cli.parse_args()
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available, the generators of all CUDA devices are seeded too.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def infer_curve_layout(meta: dict, curve_dim: int) -> dict:
    """Return the curve segment layout, falling back to an equal three-way split.

    Args:
        meta: Metadata mapping; its ``"curve_layout"`` entry is returned as-is
            when present and not ``None``.
        curve_dim: Total curve length, used only for the fallback.

    Returns:
        The stored layout, or a layout with ``curve_dim // 3`` points per
        segment for ``log_pressure``, ``log_derivative`` and ``slope``
        (in that order, contiguous from index 0).
    """
    stored_layout = meta.get("curve_layout")
    if stored_layout is not None:
        return stored_layout

    # Older data carries no layout: assume three equally sized segments.
    points = curve_dim // 3
    bounds = (0, points, 2 * points, 3 * points)
    segment_names = ("log_pressure", "log_derivative", "slope")
    parts = [
        {"name": name, "start": lo, "end": hi}
        for name, lo, hi in zip(segment_names, bounds[:-1], bounds[1:])
    ]
    return {"n_time_points": int(points), "parts": parts}
def get_part_slices(curve_layout: dict) -> dict[str, slice]:
    """Map each segment name in ``curve_layout["parts"]`` to a ``slice(start, end)``.

    Names are converted with ``str`` and the bounds with ``int``.
    """
    return {
        str(segment["name"]): slice(int(segment["start"]), int(segment["end"]))
        for segment in curve_layout["parts"]
    }
def load_model(checkpoint_path: Path) -> tuple[ForwardSurrogate, dict]:
    """Rebuild a ``ForwardSurrogate`` from a checkpoint file and load its weights.

    The checkpoint is loaded onto the CPU. The network is constructed from the
    dimensions and hyper-parameters stored in it (``use_schedule`` defaults to
    ``True`` when the key is absent), then ``model_state_dict`` is loaded.

    This function does not call ``eval()`` or ``train()`` on the model; the
    caller is responsible for selecting the mode.

    Returns:
        ``(model, checkpoint)`` where ``checkpoint`` is the loaded object.
    """
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    network_kwargs = {
        "param_dim": int(ckpt["param_dim"]),
        "schedule_dim": int(ckpt["schedule_dim"]),
        "curve_dim": int(ckpt["curve_dim"]),
        "hidden_dim": int(ckpt["hidden_dim"]),
        "dropout": float(ckpt["dropout"]),
        "use_schedule": bool(ckpt.get("use_schedule", True)),
    }
    network = ForwardSurrogate(**network_kwargs)
    network.load_state_dict(ckpt["model_state_dict"])
    return network, ckpt
def resolve_default_processed_path(base_tag: str | None) -> Path:
    """Pick the processed-data path for ``base_tag`` when none was given explicitly.

    The path for the tag itself is used when it exists or when ``base_tag`` is
    ``None``. Otherwise, if the tag ends with ``_biasfix`` or ``_autofit``, the
    path of the tag with that suffix removed is used when it exists. If
    nothing matches, the path for the tag itself is returned even though it
    does not exist.
    """
    primary = processed_path_for_tag(base_tag)
    if primary.exists() or base_tag is None:
        return primary

    for suffix in ("_biasfix", "_autofit"):
        if not base_tag.endswith(suffix):
            continue
        candidate = processed_path_for_tag(base_tag[: -len(suffix)])
        if candidate.exists():
            return candidate

    return primary
def load_neighborhood_groups(
    neighborhood_path: Path,
    processed: dict,
) -> list[dict[str, np.ndarray]]:
    """Read the neighborhood HDF5 file and build one scaled group per anchor.

    Args:
        neighborhood_path: HDF5 file providing the ``anchor_*`` and
            ``neighbor_*`` datasets read below.
        processed: Mapping holding ``scaler_params``, ``scaler_schedule`` and
            ``scaler_curve`` (objects exposing ``transform``) and optionally
            ``meta``, from which the parameter feature transform is derived.

    Returns:
        One dict per anchor that has at least two neighbors. Each dict holds
        the anchor id, the scaled anchor inputs and curve, the raw anchor
        curve, the scaled neighbor inputs and curves, and the neighbor
        objective values. Neighbors reuse their anchor's schedule.

    Raises:
        ValueError: If no anchor has at least two neighbors.
    """
    params_scaler = processed["scaler_params"]
    schedule_scaler = processed["scaler_schedule"]
    curve_scaler = processed["scaler_curve"]
    feature_transform = param_feature_transform_from_meta(processed.get("meta", {}))

    def scaled_params(raw_params: np.ndarray) -> np.ndarray:
        # Parameters go through the feature transform before being standardized.
        features = transform_param_features(raw_params, feature_transform)
        return params_scaler.transform(features).astype(np.float32)

    def scaled_with(scaler, raw: np.ndarray) -> np.ndarray:
        return scaler.transform(raw).astype(np.float32)

    with h5py.File(neighborhood_path, "r") as h5:

        def read(name: str, dtype) -> np.ndarray:
            return np.asarray(h5[name][:], dtype=dtype)

        anchor_params = read("anchor_params", np.float32)
        anchor_schedule = read("anchor_schedule", np.float32)
        anchor_curve = read("anchor_curve", np.float32)
        neighbor_anchor_id = read("neighbor_anchor_id", np.int64)
        neighbor_params = read("neighbor_params", np.float32)
        neighbor_curve = read("neighbor_curve", np.float32)
        neighbor_objective = read("neighbor_objective", np.float32)

    groups: list[dict[str, np.ndarray]] = []
    for anchor_id in range(anchor_params.shape[0]):
        members = np.where(neighbor_anchor_id == anchor_id)[0]
        if members.size < 2:
            # Fewer than two candidates cannot form a ranking pair.
            continue

        anchor_row = slice(anchor_id, anchor_id + 1)
        group = {
            "anchor_id": np.asarray([anchor_id], dtype=np.int64),
            "anchor_params_x": scaled_params(anchor_params[anchor_row]),
            "anchor_schedule_x": scaled_with(schedule_scaler, anchor_schedule[anchor_row]),
            "anchor_curve_x": scaled_with(curve_scaler, anchor_curve[anchor_row]),
            "anchor_curve_raw": anchor_curve[anchor_row].astype(np.float32),
            "neighbor_params_x": scaled_params(neighbor_params[members]),
            "neighbor_schedule_x": scaled_with(
                schedule_scaler, anchor_schedule[neighbor_anchor_id[members]]
            ),
            "neighbor_curve_x": scaled_with(curve_scaler, neighbor_curve[members]),
            "neighbor_objective": neighbor_objective[members].astype(np.float32),
        }
        groups.append(group)

    if not groups:
        raise ValueError(f"No usable anchor groups found in {neighborhood_path}")
    return groups
def to_tensor(group: dict[str, np.ndarray], device: torch.device) -> dict[str, torch.Tensor]:
    """Copy every array of ``group`` except ``"anchor_id"`` into a float32 tensor on ``device``."""
    tensors: dict[str, torch.Tensor] = {}
    for name, array in group.items():
        if name == "anchor_id":
            continue
        tensors[name] = torch.tensor(array, dtype=torch.float32, device=device)
    return tensors
def restore_raw(x_scaled: torch.Tensor, mean: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Undo standardization: return ``x_scaled * scale + mean``.

    ``mean`` and ``scale`` each get a leading axis (``unsqueeze(0)``) so they
    broadcast across the rows of ``x_scaled``.
    """
    row_scale = scale.unsqueeze(0)
    row_mean = mean.unsqueeze(0)
    return x_scaled * row_scale + row_mean
def objective_1d(target: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    """Compute a per-row misfit mixing relative and absolute error.

    Per point, the error is ``0.7 * relative + 0.3 * absolute``, where the
    relative error divides ``|target - pred|`` by ``max(|target|, |pred|, 1e-12)``.
    Points are weighted by ``1 / (1 + min(0.01 * |target|, 100))``. The result
    is the square root of the weighted mean of the squared point errors,
    reduced over ``dim=1`` (one value per row).
    """
    abs_target = target.abs()

    # Points with a large |target| get a smaller weight.
    point_weight = 1.0 / (1.0 + (abs_target * 0.01).clamp(max=100.0))

    # Relative-error denominator, floored to avoid division by zero.
    floor = torch.full_like(target, 1e-12)
    denom = torch.maximum(torch.maximum(abs_target, pred.abs()), floor)

    residual = (target - pred).abs()
    mixed_error = 0.7 * (residual / denom) + 0.3 * residual

    weighted_sum = (point_weight * (mixed_error**2)).sum(dim=1)
    weight_total = point_weight.sum(dim=1).clamp(min=1e-12)
    return (weighted_sum / weight_total).sqrt()
def dual_log_objective_torch(
    anchor_curve_raw: torch.Tensor,
    pred_curve_raw: torch.Tensor,
    slices: dict[str, slice],
) -> torch.Tensor:
    """Average the ``objective_1d`` misfit of the pressure and derivative segments.

    The anchor curve is expanded to one row per predicted curve and used as
    the target for every prediction. The ``"log_pressure"`` and
    ``"log_derivative"`` column ranges from ``slices`` are scored separately
    and combined with equal weight (0.5 each).
    """
    n_candidates = pred_curve_raw.shape[0]
    reference = anchor_curve_raw.expand(n_candidates, -1)

    pressure_cols = slices["log_pressure"]
    derivative_cols = slices["log_derivative"]

    pressure_term = objective_1d(reference[:, pressure_cols], pred_curve_raw[:, pressure_cols])
    derivative_term = objective_1d(reference[:, derivative_cols], pred_curve_raw[:, derivative_cols])
    return 0.5 * pressure_term + 0.5 * derivative_term
def pairwise_ranking_loss(
    pred_objective: torch.Tensor,
    solver_objective: torch.Tensor,
    delta_min: float,
    margin_scale: float,
    margin_min: float,
    margin_max: float,
) -> torch.Tensor:
    """Pairwise margin loss that makes the surrogate reproduce the solver's ordering.

    For every ordered pair ``(i, j)`` whose solver objectives satisfy
    ``solver[j] - solver[i] > delta_min``, the surrogate gap
    ``pred[j] - pred[i]`` is pushed above a margin of
    ``clamp(margin_scale * (solver[j] - solver[i]), margin_min, margin_max)``
    through ``softplus(margin - gap)``. The loss is the mean over those pairs.

    Returns:
        A scalar tensor. When no pair qualifies, a zero that is still derived
        from ``pred_objective`` (``pred_objective.sum() * 0.0``).
    """
    # Entry [i, j] holds value[j] - value[i].
    true_gap = solver_objective[None, :] - solver_objective[:, None]
    surrogate_gap = pred_objective[None, :] - pred_objective[:, None]

    valid_pairs = true_gap > float(delta_min)
    if not bool(valid_pairs.any()):
        return pred_objective.sum() * 0.0

    required_gap = (float(margin_scale) * true_gap).clamp(
        min=float(margin_min), max=float(margin_max)
    )
    shortfall = required_gap[valid_pairs] - surrogate_gap[valid_pairs]
    return F.softplus(shortfall).mean()
def smooth_l1_loss(pred: torch.Tensor, target: torch.Tensor, beta: float) -> torch.Tensor:
    """Return the mean Smooth L1 loss between ``pred`` and ``target`` with threshold ``beta``."""
    threshold = float(beta)
    return F.smooth_l1_loss(input=pred, target=target, reduction="mean", beta=threshold)
def run_epoch(
    model: ForwardSurrogate,
    groups: list[dict[str, np.ndarray]],
    optimizer: torch.optim.Optimizer | None,
    device: torch.device,
    slices: dict[str, slice],
    curve_mean: torch.Tensor,
    curve_scale: torch.Tensor,
    args: argparse.Namespace,
) -> dict[str, float]:
    """Run one pass over the anchor groups and return the mean loss terms.

    Args:
        model: Network called as ``model(params, schedule)``.
        groups: Anchor groups; each group is one optimization step.
        optimizer: When given, the pass trains (shuffled order, backward,
            gradient clipping, optimizer step). When ``None``, the model is
            put in eval mode and only the losses are computed.
        device: Device the group tensors are created on.
        slices: Column ranges for ``"log_pressure"`` and ``"log_derivative"``.
        curve_mean: Mean used to map scaled curves back to raw values.
        curve_scale: Scale used to map scaled curves back to raw values.
        args: Namespace providing the loss weights, pair settings and
            ``huber_beta``.

    Returns:
        Mean of ``loss``, ``rank``, ``forward``, ``anchor_forward`` and
        ``bias`` over the visited groups (all zeros for an empty list).
    """
    is_training = optimizer is not None
    model.train(is_training)

    visit_order = list(range(len(groups)))
    if is_training:
        # Shuffle groups only when training; evaluation keeps a fixed order.
        random.shuffle(visit_order)

    running = dict.fromkeys(("loss", "rank", "forward", "anchor_forward", "bias"), 0.0)

    for group_index in visit_order:
        batch = to_tensor(groups[group_index], device=device)
        if is_training:
            optimizer.zero_grad()

        # Neighbor predictions are mapped back to raw scale before the autofit
        # objective used by the ranking loss is evaluated.
        neighbor_pred = model(batch["neighbor_params_x"], batch["neighbor_schedule_x"])
        neighbor_pred_raw = restore_raw(neighbor_pred, curve_mean, curve_scale)
        surrogate_objective = dual_log_objective_torch(
            batch["anchor_curve_raw"], neighbor_pred_raw, slices
        )
        rank_term = pairwise_ranking_loss(
            pred_objective=surrogate_objective,
            solver_objective=batch["neighbor_objective"],
            delta_min=float(args.pair_delta_min),
            margin_scale=float(args.pair_margin_scale),
            margin_min=float(args.pair_margin_min),
            margin_max=float(args.pair_margin_max),
        )
        forward_term = smooth_l1_loss(
            neighbor_pred, batch["neighbor_curve_x"], beta=float(args.huber_beta)
        )

        # The anchor keeps its own forward-fit term so that fine-tuning for
        # ranking does not degrade the original forward accuracy.
        anchor_pred = model(batch["anchor_params_x"], batch["anchor_schedule_x"])
        anchor_term = smooth_l1_loss(
            anchor_pred, batch["anchor_curve_x"], beta=float(args.huber_beta)
        )

        # Bias term: L1 gap between predicted and true per-curve mean level,
        # for the pressure and derivative segments.
        neighbor_true = batch["neighbor_curve_x"]
        pressure_bias = F.l1_loss(
            neighbor_pred[:, slices["log_pressure"]].mean(dim=1),
            neighbor_true[:, slices["log_pressure"]].mean(dim=1),
        )
        derivative_bias = F.l1_loss(
            neighbor_pred[:, slices["log_derivative"]].mean(dim=1),
            neighbor_true[:, slices["log_derivative"]].mean(dim=1),
        )
        bias_term = pressure_bias + derivative_bias

        total = (
            float(args.w_rank) * rank_term
            + float(args.w_forward) * forward_term
            + float(args.w_anchor_forward) * anchor_term
            + float(args.w_bias) * bias_term
        )

        if is_training:
            total.backward()
            # Pairwise terms can produce large gradients; clipping keeps the
            # fine-tuning steps from oscillating.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        step_values = {
            "loss": total,
            "rank": rank_term,
            "forward": forward_term,
            "anchor_forward": anchor_term,
            "bias": bias_term,
        }
        for name, value in step_values.items():
            running[name] += float(value.detach().cpu())

    n_steps = max(len(visit_order), 1)
    return {name: accumulated / n_steps for name, accumulated in running.items()}
def main() -> None:
    """Fine-tune the forward surrogate on autofit neighborhood groups.

    Steps:
        1. Parse the CLI options and seed the random generators.
        2. Resolve the processed-data path, base checkpoint and output directory.
        3. Load the processed bundle and the anchor groups, then split the
           groups 80/20 into train/validation at group level.
        4. Train with Adam; after every epoch evaluate on the validation
           groups, save the checkpoint whenever the validation loss improves
           by more than 1e-6, and stop early after ``--patience`` epochs
           without improvement.
        5. Write ``history.json`` and ``metrics.json`` to the output directory.

    Raises:
        ValueError: If the normalized output tag is ``None``, or if the split
            leaves no training groups (fewer than two usable anchor groups).
            ``load_neighborhood_groups`` may also raise it when no group is
            usable.
    """
    args = parse_args()
    set_seed(int(args.seed))

    base_tag = normalize_tag(args.base_tag)
    output_tag = normalize_tag(args.output_tag)
    if output_tag is None:
        raise ValueError("--output-tag is required")

    # Explicit paths win; otherwise derive them from the tags.
    processed_path = (
        Path(args.base_processed)
        if args.base_processed is not None
        else resolve_default_processed_path(base_tag)
    )
    model_path = (
        Path(args.base_model)
        if args.base_model is not None
        else model_checkpoint_for_tag(base_tag, use_schedule=True)
    )
    output_dir = (
        Path(args.output_dir)
        if args.output_dir is not None
        else model_dir_for_tag(output_tag, use_schedule=True)
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    # NOTE: joblib.load unpickles the file; only point it at trusted artifacts.
    processed = joblib.load(processed_path)
    curve_layout = infer_curve_layout(processed["meta"], int(processed["meta"]["curve_dim"]))
    slices = get_part_slices(curve_layout)
    groups = load_neighborhood_groups(Path(args.neighborhood), processed)

    rng = np.random.RandomState(int(args.seed))
    perm = rng.permutation(len(groups))
    n_val = max(1, int(round(0.20 * len(groups))))
    val_ids = set(int(x) for x in perm[:n_val])
    # Split by anchor group so that candidates of one anchor never end up on
    # both sides (which would leak information into validation).
    train_groups = [g for i, g in enumerate(groups) if i not in val_ids]
    val_groups = [g for i, g in enumerate(groups) if i in val_ids]
    if not train_groups:
        # With a single usable group everything lands in validation. Training
        # would then run zero optimizer steps and save the untouched base
        # weights as a "fine-tuned" checkpoint, so fail loudly instead.
        raise ValueError(
            f"Need at least 2 usable anchor groups for a train/validation split, "
            f"got {len(groups)} from {args.neighborhood}"
        )

    model, checkpoint = load_model(model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    scaler_curve = processed["scaler_curve"]
    curve_mean = torch.tensor(
        np.asarray(scaler_curve.mean_, dtype=np.float32), dtype=torch.float32, device=device
    )
    curve_scale = torch.tensor(
        np.asarray(scaler_curve.scale_, dtype=np.float32), dtype=torch.float32, device=device
    )

    # Fine-tuning continues from the existing forward model's weights.
    optimizer = torch.optim.Adam(
        model.parameters(), lr=float(args.lr), weight_decay=float(args.weight_decay)
    )

    best_val = float("inf")
    best_path = output_dir / "forward_surrogate_best.pt"
    history: list[dict] = []
    bad_epochs = 0

    print("Local ranking fine-tune config:")
    print(f" processed={processed_path}")
    print(f" base_model={model_path}")
    print(f" neighborhood={args.neighborhood}")
    print(f" output_dir={output_dir}")
    print(f" device={device}, groups train={len(train_groups)}, val={len(val_groups)}")
    print(
        f" weights rank={args.w_rank}, forward={args.w_forward}, "
        f"anchor_forward={args.w_anchor_forward}, bias={args.w_bias}"
    )

    for epoch in range(1, int(args.epochs) + 1):
        train_metrics = run_epoch(
            model=model,
            groups=train_groups,
            optimizer=optimizer,
            device=device,
            slices=slices,
            curve_mean=curve_mean,
            curve_scale=curve_scale,
            args=args,
        )
        with torch.no_grad():
            val_metrics = run_epoch(
                model=model,
                groups=val_groups,
                optimizer=None,
                device=device,
                slices=slices,
                curve_mean=curve_mean,
                curve_scale=curve_scale,
                args=args,
            )

        row = {"epoch": epoch}
        row.update({f"train_{k}": float(v) for k, v in train_metrics.items()})
        row.update({f"val_{k}": float(v) for k, v in val_metrics.items()})
        history.append(row)

        print(
            f"[Epoch {epoch:03d}] "
            f"train={train_metrics['loss']:.6f} (rank={train_metrics['rank']:.6f}, fwd={train_metrics['forward']:.6f}) "
            f"val={val_metrics['loss']:.6f} (rank={val_metrics['rank']:.6f}, fwd={val_metrics['forward']:.6f})"
        )

        if val_metrics["loss"] < best_val - 1e-6:
            best_val = val_metrics["loss"]
            bad_epochs = 0
            # Keep the base model and neighborhood paths in the checkpoint so
            # the origin of the fine-tuned weights can be traced.
            torch.save(
                {
                    "model_state_dict": model.state_dict(),
                    "param_dim": int(checkpoint["param_dim"]),
                    "schedule_dim": int(checkpoint["schedule_dim"]),
                    "curve_dim": int(checkpoint["curve_dim"]),
                    "hidden_dim": int(checkpoint["hidden_dim"]),
                    "dropout": float(checkpoint["dropout"]),
                    "use_schedule": bool(checkpoint.get("use_schedule", True)),
                    "seed": int(args.seed),
                    "curve_layout": curve_layout,
                    "base_model_path": str(model_path),
                    "base_processed_path": str(processed_path),
                    "neighborhood_path": str(Path(args.neighborhood)),
                    "fine_tune": {
                        "type": "local_pairwise_ranking",
                        "best_val_loss": float(best_val),
                        "weights": {
                            "rank": float(args.w_rank),
                            "forward": float(args.w_forward),
                            "anchor_forward": float(args.w_anchor_forward),
                            "bias": float(args.w_bias),
                        },
                        "pair_delta_min": float(args.pair_delta_min),
                        "pair_margin_scale": float(args.pair_margin_scale),
                        "pair_margin_min": float(args.pair_margin_min),
                        "pair_margin_max": float(args.pair_margin_max),
                    },
                },
                best_path,
            )
            print(f" -> best model saved to: {best_path}")
        else:
            bad_epochs += 1
            if bad_epochs >= int(args.patience):
                print(f"Early stopping at epoch {epoch}; best_val={best_val:.6f}")
                break

    with open(output_dir / "history.json", "w", encoding="utf-8") as f:
        json.dump(history, f, ensure_ascii=False, indent=2)

    with open(output_dir / "metrics.json", "w", encoding="utf-8") as f:
        json.dump(
            {
                "best_val_loss": float(best_val),
                "base_model_path": str(model_path),
                "base_processed_path": str(processed_path),
                "neighborhood_path": str(Path(args.neighborhood)),
                "n_train_groups": int(len(train_groups)),
                "n_val_groups": int(len(val_groups)),
                "history_last": history[-1] if history else {},
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    print("\nLocal ranking fine-tune complete.")
    print(f"Best checkpoint: {best_path}")
    print(f"Metrics: {output_dir / 'metrics.json'}")
# Run the fine-tuning only when this file is executed as a script.
if __name__ == "__main__":
    main()