You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/nmWTAI-ML/scripts/train_forward.py

126 lines
4.6 KiB
Python

"""训练固定长度曲线正演代理模型。
脚本把命令行超参数整理为 `TrainConfig`加载预处理后的参数流量制度和曲线数据
训练 `ForwardSurrogate` 来直接预测整条重采样压力/导数/斜率曲线是当前数值试井
代理模型的主训练入口
"""
# pylint: disable=import-error,wrong-import-position
from __future__ import annotations
import argparse
import sys
from pathlib import Path
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
from src.common.experiment_paths import model_dir_for_tag, normalize_tag, processed_path_for_tag
from src.training.train_forward import (
LossConfig,
LossWeights,
ModelConfig,
OptimConfig,
SampleReweightConfig,
TrainConfig,
TrainRuntime,
train_forward,
)
def main() -> None:
"""读取训练超参数并训练固定长度曲线正演代理模型。"""
parser = argparse.ArgumentParser(description="Train forward surrogate model")
parser.add_argument(
"--processed",
type=str,
default=None,
help="Processed dataset path",
)
parser.add_argument(
"--output-dir",
type=str,
default=None,
help="Optional model output directory",
)
parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--batch-size", type=int, default=256)
parser.add_argument("--epochs", type=int, default=220)
parser.add_argument("--lr", type=float, default=0.001)
parser.add_argument("--weight-decay", type=float, default=0.0005)
parser.add_argument("--hidden-dim", type=int, default=256)
parser.add_argument("--dropout", type=float, default=0.1)
parser.add_argument("--w-pressure", type=float, default=1.0)
parser.add_argument("--w-derivative", type=float, default=2.0)
parser.add_argument("--w-slope", type=float, default=0.0)
parser.add_argument("--w-bias-pressure", type=float, default=0.15)
parser.add_argument("--w-bias-derivative", type=float, default=0.05)
parser.add_argument("--w-derivative-shape", type=float, default=0.10)
parser.add_argument("--w-autofit-pressure", type=float, default=0.0)
parser.add_argument("--w-autofit-derivative", type=float, default=0.0)
parser.add_argument("--huber-beta", type=float, default=0.05)
parser.add_argument("--use-sample-reweight", action="store_true", default=True)
parser.add_argument("--no-sample-reweight", action="store_false", dest="use_sample_reweight")
parser.add_argument("--sample-reweight-alpha", type=float, default=0.4)
parser.add_argument("--sample-weight-min", type=float, default=1.0)
parser.add_argument("--sample-weight-max", type=float, default=2.5)
parser.add_argument(
"--no-schedule",
action="store_true",
help="Disable the schedule branch and train a parameter-only forward surrogate",
)
args = parser.parse_args()
tag = normalize_tag(args.tag)
use_schedule = not args.no_schedule
# processed/model 路径与 tag 绑定,保证预处理、训练和评估脚本默认指向同一实验。
processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
output_dir = Path(args.output_dir) if args.output_dir is not None else model_dir_for_tag(tag, use_schedule)
cfg = TrainConfig(
processed_path=processed_path,
output_dir=output_dir,
runtime=TrainRuntime(seed=args.seed),
optim=OptimConfig(
batch_size=args.batch_size,
epochs=args.epochs,
lr=args.lr,
weight_decay=args.weight_decay,
),
model=ModelConfig(
hidden_dim=args.hidden_dim,
dropout=args.dropout,
use_schedule=use_schedule,
),
loss=LossConfig(
weights=LossWeights(
pressure=args.w_pressure,
derivative=args.w_derivative,
slope=args.w_slope,
bias_pressure=args.w_bias_pressure,
bias_derivative=args.w_bias_derivative,
derivative_shape=args.w_derivative_shape,
autofit_pressure=args.w_autofit_pressure,
autofit_derivative=args.w_autofit_derivative,
),
huber_beta=args.huber_beta,
),
sample_reweight=SampleReweightConfig(
enabled=args.use_sample_reweight,
alpha=args.sample_reweight_alpha,
weight_min=args.sample_weight_min,
weight_max=args.sample_weight_max,
),
)
train_forward(cfg)
if __name__ == "__main__":
    # Start training only when this file is executed as a script, so that
    # importing the module does not kick off a run.
    main()