You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
124 lines
5.1 KiB
Python
124 lines
5.1 KiB
Python
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Directory two levels above this file (the parent of the folder containing
# this script). It is appended to the import search path; presumably this is
# what lets the `src.*` imports below resolve when the file is run directly
# as a script — confirm against the repository layout.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
|
|
|
|
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
|
|
from src.training.train_forward import TrainConfig, train_forward
|
|
|
|
|
|
def parse_seed_list(seed_text: str) -> list[int]:
    """Parse a comma-separated seed string into a list of unique integer seeds.

    Whitespace around tokens is stripped and empty tokens (for example from a
    trailing comma) are ignored. Repeated seeds are collapsed to their first
    occurrence: the training loop in this file writes each ensemble member to
    a directory named ``seed_<seed>``, so a duplicate would retrain into, and
    overwrite, the same member.

    Args:
        seed_text: Text such as ``"41, 42,43"``. Non-string values are
            converted with ``str()`` before splitting.

    Returns:
        The parsed seeds in first-seen order, without duplicates.

    Raises:
        ValueError: If a token is not a valid integer, or if no seed remains
            after parsing.
    """
    seeds: list[int] = []
    for token in str(seed_text).split(","):
        token = token.strip()
        if not token:
            continue
        try:
            seeds.append(int(token))
        except ValueError:
            # Re-raise with the offending token so the CLI user can see which
            # part of the list is malformed.
            raise ValueError(f"invalid seed {token!r}: expected an integer") from None
    if not seeds:
        raise ValueError("至少需要一个 seed")
    # dict.fromkeys preserves insertion order, so this de-duplicates stably.
    return list(dict.fromkeys(seeds))
|
|
|
|
|
|
def default_output_root(tag: str | None, use_schedule: bool) -> Path:
    """Return the default ensemble output directory under ``models``.

    The directory name is ``forward_surrogate[_<tag>]_ensemble``, with
    ``_no_schedule`` appended when ``use_schedule`` is false. A falsy ``tag``
    (``None`` or an empty string) leaves the tag segment out.
    """
    tag_part = f"_{tag}" if tag else ""
    schedule_part = "" if use_schedule else "_no_schedule"
    return Path("models", f"forward_surrogate{tag_part}_ensemble{schedule_part}")
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for ensemble training."""
    parser = argparse.ArgumentParser(description="Train a deep-ensemble forward surrogate for UQ")
    parser.add_argument("--processed", type=str, default=None, help="Processed dataset path")
    parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming")
    parser.add_argument("--output-root", type=str, default=None, help="Optional ensemble root directory")
    parser.add_argument("--seeds", type=str, default="41,42,43", help="Comma-separated seed list")
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=220)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=0.0005)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--w-pressure", type=float, default=1.0)
    parser.add_argument("--w-derivative", type=float, default=2.0)
    parser.add_argument("--w-slope", type=float, default=0.0)
    parser.add_argument("--w-bias-pressure", type=float, default=0.15)
    parser.add_argument("--w-bias-derivative", type=float, default=0.05)
    parser.add_argument("--w-derivative-shape", type=float, default=0.10)
    parser.add_argument("--w-autofit-pressure", type=float, default=0.0)
    parser.add_argument("--w-autofit-derivative", type=float, default=0.0)
    parser.add_argument("--huber-beta", type=float, default=0.05)
    # Sample reweighting is on by default. The original flag is kept so
    # existing command lines still parse, but on its own it could never turn
    # the option off, hence the explicit negative flag below.
    parser.add_argument("--use-sample-reweight", dest="use_sample_reweight", action="store_true", default=True)
    parser.add_argument(
        "--no-sample-reweight",
        dest="use_sample_reweight",
        action="store_false",
        help="Disable sample reweighting (enabled by default)",
    )
    parser.add_argument("--sample-reweight-alpha", type=float, default=0.4)
    parser.add_argument("--sample-weight-min", type=float, default=1.0)
    parser.add_argument("--sample-weight-max", type=float, default=2.5)
    parser.add_argument("--no-schedule", action="store_true")
    return parser


def main(argv: list[str] | None = None) -> None:
    """Train one forward-surrogate model per seed and write an ensemble manifest.

    For every seed, a ``TrainConfig`` is built from the command-line options
    and passed to ``train_forward``, with the member's output directory set to
    ``<output_root>/seed_<seed>``. After all members have been trained,
    ``ensemble_manifest.json`` is written to the output root, listing the tag,
    dataset path, schedule flag, seeds and per-member paths.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the original behavior.

    Raises:
        ValueError: Propagated from ``parse_seed_list`` when ``--seeds`` is
            empty or malformed.
    """
    args = _build_parser().parse_args(argv)

    tag = normalize_tag(args.tag)
    use_schedule = not args.no_schedule
    seeds = parse_seed_list(args.seeds)

    # Explicit paths win; otherwise both locations are derived from the tag.
    if args.processed is not None:
        processed_path = Path(args.processed)
    else:
        processed_path = processed_path_for_tag(tag)
    if args.output_root is not None:
        output_root = Path(args.output_root)
    else:
        output_root = default_output_root(tag, use_schedule)
    output_root.mkdir(parents=True, exist_ok=True)
    manifest_path = output_root / "ensemble_manifest.json"

    manifest = {
        "tag": tag,
        "processed_path": str(processed_path),
        "use_schedule": use_schedule,
        "seeds": seeds,
        "members": [],
    }

    for seed in seeds:
        member_dir = output_root / f"seed_{seed}"
        print(f"\n=== Training ensemble member seed={seed} -> {member_dir} ===")
        cfg = TrainConfig(
            processed_path=processed_path,
            output_dir=member_dir,
            seed=seed,
            batch_size=args.batch_size,
            epochs=args.epochs,
            lr=args.lr,
            weight_decay=args.weight_decay,
            hidden_dim=args.hidden_dim,
            dropout=args.dropout,
            w_pressure=args.w_pressure,
            w_derivative=args.w_derivative,
            w_slope=args.w_slope,
            w_bias_pressure=args.w_bias_pressure,
            w_bias_derivative=args.w_bias_derivative,
            w_derivative_shape=args.w_derivative_shape,
            w_autofit_pressure=args.w_autofit_pressure,
            w_autofit_derivative=args.w_autofit_derivative,
            huber_beta=args.huber_beta,
            use_sample_reweight=args.use_sample_reweight,
            sample_reweight_alpha=args.sample_reweight_alpha,
            sample_weight_min=args.sample_weight_min,
            sample_weight_max=args.sample_weight_max,
            use_schedule=use_schedule,
        )
        train_forward(cfg)
        # NOTE(review): the checkpoint and metrics file names are recorded as
        # expected paths only; presumably train_forward writes them, but their
        # existence is not checked here.
        manifest["members"].append(
            {
                "seed": seed,
                "dir": str(member_dir),
                "checkpoint": str(member_dir / "forward_surrogate_best.pt"),
                "metrics": str(member_dir / "metrics.json"),
            }
        )

    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)

    print("\nEnsemble training complete.")
    print(f"Manifest saved: {manifest_path}")
|
|
|
|
|
|
# Run the ensemble-training CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|