"""训练固定长度曲线正演代理模型。 脚本把命令行超参数整理为 `TrainConfig`,加载预处理后的参数、流量制度和曲线数据, 训练 `ForwardSurrogate` 来直接预测整条重采样压力/导数/斜率曲线,是当前数值试井 代理模型的主训练入口。 """ # pylint: disable=import-error,wrong-import-position from __future__ import annotations import argparse import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] sys.path.append(str(ROOT)) from src.common.experiment_paths import model_dir_for_tag, normalize_tag, processed_path_for_tag from src.training.train_forward import ( LossConfig, LossWeights, ModelConfig, OptimConfig, SampleReweightConfig, TrainConfig, TrainRuntime, train_forward, ) def main() -> None: """读取训练超参数并训练固定长度曲线正演代理模型。""" parser = argparse.ArgumentParser(description="Train forward surrogate model") parser.add_argument( "--processed", type=str, default=None, help="Processed dataset path", ) parser.add_argument( "--output-dir", type=str, default=None, help="Optional model output directory", ) parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming") parser.add_argument("--seed", type=int, default=42) parser.add_argument("--batch-size", type=int, default=256) parser.add_argument("--epochs", type=int, default=220) parser.add_argument("--lr", type=float, default=0.001) parser.add_argument("--weight-decay", type=float, default=0.0005) parser.add_argument("--hidden-dim", type=int, default=256) parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--w-pressure", type=float, default=1.0) parser.add_argument("--w-derivative", type=float, default=2.0) parser.add_argument("--w-slope", type=float, default=0.0) parser.add_argument("--w-bias-pressure", type=float, default=0.15) parser.add_argument("--w-bias-derivative", type=float, default=0.05) parser.add_argument("--w-derivative-shape", type=float, default=0.10) parser.add_argument("--w-autofit-pressure", type=float, default=0.0) parser.add_argument("--w-autofit-derivative", type=float, default=0.0) parser.add_argument("--huber-beta", type=float, default=0.05) parser.add_argument("--use-sample-reweight", action="store_true", default=True) parser.add_argument("--no-sample-reweight", action="store_false", dest="use_sample_reweight") parser.add_argument("--sample-reweight-alpha", type=float, default=0.4) parser.add_argument("--sample-weight-min", type=float, default=1.0) parser.add_argument("--sample-weight-max", type=float, default=2.5) parser.add_argument( "--no-schedule", action="store_true", help="Disable the schedule branch and train a parameter-only forward surrogate", ) args = parser.parse_args() tag = normalize_tag(args.tag) use_schedule = not args.no_schedule # processed/model 路径与 tag 绑定,保证预处理、训练和评估脚本默认指向同一实验。 processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag) output_dir = Path(args.output_dir) if args.output_dir is not None else model_dir_for_tag(tag, use_schedule) cfg = TrainConfig( processed_path=processed_path, output_dir=output_dir, runtime=TrainRuntime(seed=args.seed), optim=OptimConfig( batch_size=args.batch_size, epochs=args.epochs, lr=args.lr, weight_decay=args.weight_decay, ), model=ModelConfig( hidden_dim=args.hidden_dim, dropout=args.dropout, use_schedule=use_schedule, ), loss=LossConfig( weights=LossWeights( pressure=args.w_pressure, derivative=args.w_derivative, slope=args.w_slope, bias_pressure=args.w_bias_pressure, bias_derivative=args.w_bias_derivative, derivative_shape=args.w_derivative_shape, autofit_pressure=args.w_autofit_pressure, autofit_derivative=args.w_autofit_derivative, ), huber_beta=args.huber_beta, ), sample_reweight=SampleReweightConfig( enabled=args.use_sample_reweight, alpha=args.sample_reweight_alpha, weight_min=args.sample_weight_min, weight_max=args.sample_weight_max, ), ) train_forward(cfg) if __name__ == "__main__": main()