|
|
|
|
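"""Train the fixed-length-curve forward surrogate model.

This script collects the command-line hyperparameters into a `TrainConfig`, loads the
preprocessed parameters, rate schedules and curve data, and trains `ForwardSurrogate`
to directly predict the entire resampled pressure/derivative/slope curves. It is the
main training entry point for the current numerical well-test surrogate model.
"""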
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
import argparse
|
|
|
|
|
|
import sys
|
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
# Directory one level above the folder that contains this script. It is appended to
# sys.path before the project imports that follow (presumably so the `src.*` packages
# resolve when this file is executed directly as a script).
ROOT = Path(__file__).resolve().parent.parent
sys.path.append(str(ROOT))
|
|
|
|
|
|
|
|
|
|
|
|
from src.common.experiment_paths import model_dir_for_tag, normalize_tag, processed_path_for_tag
|
|
|
|
|
|
from src.training.train_forward import TrainConfig, train_forward
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the forward-surrogate training script."""
    parser = argparse.ArgumentParser(description="Train forward surrogate model")

    # Dataset / output locations; when omitted they are derived from --tag in main().
    parser.add_argument(
        "--processed",
        type=str,
        default=None,
        help="Processed dataset path",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=None,
        help="Optional model output directory",
    )
    parser.add_argument("--tag", type=str, default=None, help="Experiment tag for auto naming")

    # Optimisation and model-size hyperparameters.
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=220)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight-decay", type=float, default=0.0005)
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--dropout", type=float, default=0.1)

    # Loss-term weights, forwarded unchanged to TrainConfig.
    parser.add_argument("--w-pressure", type=float, default=1.0)
    parser.add_argument("--w-derivative", type=float, default=2.0)
    parser.add_argument("--w-slope", type=float, default=0.0)
    parser.add_argument("--w-bias-pressure", type=float, default=0.15)
    parser.add_argument("--w-bias-derivative", type=float, default=0.05)
    parser.add_argument("--w-derivative-shape", type=float, default=0.10)
    parser.add_argument("--w-autofit-pressure", type=float, default=0.0)
    parser.add_argument("--w-autofit-derivative", type=float, default=0.0)
    parser.add_argument("--huber-beta", type=float, default=0.05)

    # Sample reweighting is enabled by default. The positive flag is kept for backward
    # compatibility; because its default is already True it could never switch the
    # feature off, so a paired negative flag writing to the same destination is provided.
    parser.add_argument(
        "--use-sample-reweight",
        action="store_true",
        default=True,
        help="Enable sample reweighting (this is already the default)",
    )
    parser.add_argument(
        "--no-use-sample-reweight",
        dest="use_sample_reweight",
        action="store_false",
        help="Disable sample reweighting",
    )
    parser.add_argument("--sample-reweight-alpha", type=float, default=0.4)
    parser.add_argument("--sample-weight-min", type=float, default=1.0)
    parser.add_argument("--sample-weight-max", type=float, default=2.5)

    parser.add_argument(
        "--no-schedule",
        action="store_true",
        help="Disable the schedule branch and train a parameter-only forward surrogate",
    )
    return parser


def main(argv: list[str] | None = None) -> None:
    """Read the training hyperparameters and train the fixed-length-curve forward surrogate.

    Args:
        argv: Command-line arguments to parse. When ``None`` (the default), argparse
            reads them from ``sys.argv[1:]``, which is the behaviour of a plain
            ``main()`` call.
    """
    args = _build_parser().parse_args(argv)

    tag = normalize_tag(args.tag)
    use_schedule = not args.no_schedule

    # The processed-data and model paths are tied to the tag, so that the preprocessing,
    # training and evaluation scripts point at the same experiment by default.
    # Explicit --processed / --output-dir values take precedence over the tag defaults.
    if args.processed is not None:
        processed_path = Path(args.processed)
    else:
        processed_path = processed_path_for_tag(tag)

    if args.output_dir is not None:
        output_dir = Path(args.output_dir)
    else:
        output_dir = model_dir_for_tag(tag, use_schedule)

    cfg = TrainConfig(
        processed_path=processed_path,
        output_dir=output_dir,
        seed=args.seed,
        batch_size=args.batch_size,
        epochs=args.epochs,
        lr=args.lr,
        weight_decay=args.weight_decay,
        hidden_dim=args.hidden_dim,
        dropout=args.dropout,
        w_pressure=args.w_pressure,
        w_derivative=args.w_derivative,
        w_slope=args.w_slope,
        w_bias_pressure=args.w_bias_pressure,
        w_bias_derivative=args.w_bias_derivative,
        w_derivative_shape=args.w_derivative_shape,
        w_autofit_pressure=args.w_autofit_pressure,
        w_autofit_derivative=args.w_autofit_derivative,
        huber_beta=args.huber_beta,
        use_sample_reweight=args.use_sample_reweight,
        sample_reweight_alpha=args.sample_reweight_alpha,
        sample_weight_min=args.sample_weight_min,
        sample_weight_max=args.sample_weight_max,
        use_schedule=use_schedule,
    )
    train_forward(cfg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|