"""训练固定长度曲线正演代理模型。 脚本把命令行超参数整理为 `TrainConfig`,加载预处理后的参数、流量制度和曲线数据, 训练 `ForwardSurrogate` 来直接预测整条重采样压力/导数/斜率曲线,是当前数值试井 代理模型的主训练入口。 """ from __future__ import annotations import argparse import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import model_dir_for_tag, normalize_tag, processed_path_for_tag from src.training.train_forward import TrainConfig, train_forward def main() -> None: """读取训练超参数并训练固定长度曲线正演代理模型。""" parser = argparse.ArgumentParser(description="Train forward surrogate model") parser.add_argument( "--processed", type=str, default=None, help="Processed dataset path", ) parser.add_argument( "--output-dir", type=str, default=None, help="Optional model output directory", ) parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming") parser.add_argument("--seed", type=int, default=42) parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--epochs", type=int, default=220) parser.add_argument("--lr", type=float, default=0.001) parser.add_argument("--weight-decay", type=float, default=0.0005) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--w-pressure", type=float, default=1.0) parser.add_argument("--w-derivative", type=float, default=2.0) parser.add_argument("--w-slope", type=float, default=0.0) parser.add_argument("--w-bias-pressure", type=float, default=0.15) parser.add_argument("--w-bias-derivative", type=float, default=0.05) parser.add_argument("--w-derivative-shape", type=float, default=0.10) parser.add_argument("--w-autofit-pressure", type=float, default=0.0) parser.add_argument("--w-autofit-derivative", type=float, default=0.0) parser.add_argument("--huber-beta", type=float, default=0.05) parser.add_argument("--use-sample-reweight", action="store_true", default=True) parser.add_argument("--sample-reweight-alpha", type=float, default=0.4) parser.add_argument("--sample-weight-min", type=float, default=1.0) parser.add_argument("--sample-weight-max", type=float, default=2.5) parser.add_argument( "--no-schedule", action="store_true", help="Disable the schedule branch and train a parameter-only forward surrogate", ) args = parser.parse_args() tag = normalize_tag(args.tag) use_schedule = not args.no_schedule # processed/model 路径与 tag 绑定,保证预处理、训练和评估脚本默认指向同一实验。 processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) output_dir = Path(args.output_dir) if args.output_dir is not None else model_dir_for_tag(tag, use_schedule) cfg = TrainConfig( processed_path=processed_path, output_dir=output_dir, seed=args.seed, batch_size=args.batch_size, epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay, hidden_dim=args.hidden_dim, dropout=args.dropout, w_pressure=args.w_pressure, w_derivative=args.w_derivative, w_slope=args.w_slope, w_bias_pressure=args.w_bias_pressure, w_bias_derivative=args.w_bias_derivative, w_derivative_shape=args.w_derivative_shape, w_autofit_pressure=args.w_autofit_pressure, w_autofit_derivative=args.w_autofit_derivative, huber_beta=args.huber_beta, use_sample_reweight=args.use_sample_reweight, sample_reweight_alpha=args.sample_reweight_alpha, sample_weight_min=args.sample_weight_min, sample_weight_max=args.sample_weight_max, use_schedule=use_schedule, ) train_forward(cfg) if __name__ == "__main__": main()