You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
4.1 KiB
Python
105 lines
4.1 KiB
Python
"""训练时间条件正演代理模型。
|
|
|
|
脚本把预处理数据展开为按时间点监督的训练任务,并将命令行参数封装为
|
|
`TimeConditionedTrainConfig`。模型学习在给定参数、制度和时间特征时预测压力/导数,
|
|
适合处理可变时间采样或逐点推理场景。
|
|
"""
|
|
|
|
# pylint: disable=import-error,wrong-import-position
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Make the repository root (the parent of the directory containing this script)
# importable, so the ``src.*`` imports below resolve when the file is executed
# directly as a script.
# NOTE(review): ``append`` puts ROOT *last* on sys.path, so any earlier entry that
# also provides a top-level ``src`` package would shadow this repository's ``src``.
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT))
|
|
|
|
from src.common.experiment_paths import normalize_tag, processed_path_for_tag
|
|
from src.training.train_time_conditioned import (
|
|
RiskWeightConfig,
|
|
TimeConditionedTrainConfig,
|
|
TimeLossConfig,
|
|
TimeModelConfig,
|
|
TimeOptimConfig,
|
|
TimeRuntimeConfig,
|
|
train_time_conditioned,
|
|
)
|
|
|
|
|
|
def main() -> None:
|
|
"""读取训练参数并训练按时间点展开的时间条件代理模型。"""
|
|
parser = argparse.ArgumentParser(description="Train a time-conditioned point-wise forward surrogate")
|
|
parser.add_argument("--processed", type=str, default=None)
|
|
parser.add_argument("--tag", type=str, default=None)
|
|
parser.add_argument("--output-dir", type=str, default=None)
|
|
parser.add_argument("--seed", type=int, default=42)
|
|
parser.add_argument("--batch-size", type=int, default=4096)
|
|
parser.add_argument("--epochs", type=int, default=120)
|
|
parser.add_argument("--lr", type=float, default=1.0e-3)
|
|
parser.add_argument("--weight-decay", type=float, default=1.0e-4)
|
|
parser.add_argument("--hidden-dim", type=int, default=256)
|
|
parser.add_argument("--n-blocks", type=int, default=4)
|
|
parser.add_argument("--dropout", type=float, default=0.05)
|
|
parser.add_argument("--w-pressure", type=float, default=1.0)
|
|
parser.add_argument("--w-derivative", type=float, default=2.0)
|
|
parser.add_argument("--huber-beta", type=float, default=0.05)
|
|
parser.add_argument("--no-schedule", action="store_true")
|
|
parser.add_argument(
|
|
"--sample-weight-mode",
|
|
choices=["none", "risk_region"],
|
|
default="none",
|
|
help="Optional risk-region sample weighting; default keeps the original unweighted training behavior",
|
|
)
|
|
parser.add_argument("--risk-weight", type=float, default=2.5)
|
|
parser.add_argument("--skin-lt-minus8-weight", type=float, default=3.5)
|
|
parser.add_argument("--sample-weight-min", type=float, default=1.0)
|
|
parser.add_argument("--sample-weight-max", type=float, default=4.0)
|
|
args = parser.parse_args()
|
|
|
|
tag = normalize_tag(args.tag)
|
|
processed_path = Path(args.processed) if args.processed is not None else processed_path_for_tag(tag)
|
|
# 时间条件模型使用独立目录,避免覆盖固定长度曲线代理模型的 checkpoint。
|
|
if args.output_dir is not None:
|
|
output_dir = Path(args.output_dir)
|
|
elif tag:
|
|
output_dir = Path("models") / f"time_conditioned_surrogate_{tag}"
|
|
else:
|
|
output_dir = Path("models") / "time_conditioned_surrogate"
|
|
|
|
cfg = TimeConditionedTrainConfig(
|
|
processed_path=processed_path,
|
|
output_dir=output_dir,
|
|
runtime=TimeRuntimeConfig(seed=int(args.seed)),
|
|
optim=TimeOptimConfig(
|
|
batch_size=int(args.batch_size),
|
|
epochs=int(args.epochs),
|
|
lr=float(args.lr),
|
|
weight_decay=float(args.weight_decay),
|
|
),
|
|
model=TimeModelConfig(
|
|
hidden_dim=int(args.hidden_dim),
|
|
n_blocks=int(args.n_blocks),
|
|
dropout=float(args.dropout),
|
|
use_schedule=not bool(args.no_schedule),
|
|
),
|
|
loss=TimeLossConfig(
|
|
w_pressure=float(args.w_pressure),
|
|
w_derivative=float(args.w_derivative),
|
|
huber_beta=float(args.huber_beta),
|
|
),
|
|
risk_weight=RiskWeightConfig(
|
|
mode=str(args.sample_weight_mode),
|
|
risk_weight=float(args.risk_weight),
|
|
skin_lt_minus8_weight=float(args.skin_lt_minus8_weight),
|
|
weight_min=float(args.sample_weight_min),
|
|
weight_max=float(args.sample_weight_max),
|
|
),
|
|
)
|
|
train_time_conditioned(cfg)
|
|
|
|
|
|
# Script entry point: training starts only when this file is executed directly,
# not when the module is imported.
if __name__ == "__main__":
    main()
|