1、数据集生成路径适配

feature/ModelOpt20260526
1294271022 2 weeks ago
parent 47d00106b1
commit abe4804690

@ -360,20 +360,18 @@ int main(int argc, char** argv)
return runServer(argv[2], argv[3], argv[4]);
}
// 默认单次模式
std::string dataDir = exeDir + "\\..\\..\\nmWTAI-ML\\data\\temp";
std::string datasetPath = dataDir + "\\dataset.bin";
std::string paramsPath = dataDir + "\\params.bin";
std::string resultPath = dataDir + "\\result.bin";
std::string dllPath = exeDir + "\\HX_NWTM.dll";
std::string licPath = exeDir + "\\..\\..\\..\\Bin\\Res\\license\\HXNWTM_license.dat";
if (argc >= 2) datasetPath = argv[1];
if (argc >= 3) paramsPath = argv[2];
if (argc >= 4) resultPath = argv[3];
if (argc >= 5) dllPath = argv[4];
if (argc >= 6) licPath = argv[5];
// Single-run mode requires explicit paths. Do not fall back to exeDir\\HX_NWTM.dll,
// because Runner and HX_NWTM.dll must be built against the same PEBI header version.
if (argc < 6) {
std::cerr << "Usage: runner.exe <dataset.bin> <params.bin> <result.bin> <HX_NWTM.dll> <license.dat>\n";
return 2;
}
std::string datasetPath = argv[1];
std::string paramsPath = argv[2];
std::string resultPath = argv[3];
std::string dllPath = argv[4];
std::string licPath = argv[5];
if (!fileExistsA(datasetPath)) { std::cerr << "ERROR: dataset not found\n"; return 10; }
if (!fileExistsA(paramsPath)) { std::cerr << "ERROR: params not found\n"; return 11; }

@ -12,7 +12,7 @@ paths:
cpp:
training_exe: "../Training/Release/training.exe"
runner_exe: "../Training/Release/runner.exe"
hx_dll: "../Training/Release/HX_NWTM.dll"
hx_dll: "../../3rd/Pebi/V1/bin/HX_NWTM.dll"
license_dat: "../../Bin/Res/license/HXNWTM_license.dat"
dataset_runtime:
@ -26,7 +26,7 @@ streaming_hdf5: # HDF5 样本文件流式写入设置
compression: null # HDF5 压缩方式,null 表示不压缩
parallel: # 并行样本生成设置
n_workers: 36 # 并行生成样本的工作进程数量
n_workers: 12 # 并行生成样本的工作进程数量
max_in_flight: 48 # 同时提交但尚未完成的最大任务数
checkpoint_every_n: 5000 # 每生成多少条样本输出一次检查点
checkpoint_every_sec: 180.0 # 每隔多少秒输出一次检查点

@ -7,3 +7,4 @@ numpy==2.4.3
PyYAML==6.0.3
scikit-learn==1.8.0
torch==2.11.0
tqdm==4.67.3

@ -433,7 +433,14 @@ def _worker_simulate_parallel(args):
_RESULT_BIN.unlink()
# 每个 worker 使用独立临时目录,避免并行运行时 params.bin/result.bin 相互覆盖。
cmd = [str(cfg.runner_exe), str(cfg.dataset_bin), str(_PARAMS_BIN), str(_RESULT_BIN)]
cmd = [
str(cfg.runner_exe),
str(cfg.dataset_bin),
str(_PARAMS_BIN),
str(_RESULT_BIN),
str(cfg.hx_dll),
str(cfg.license_dat),
]
result = subprocess.run(
cmd,
cwd=str(cfg.runner_exe.parent),
@ -472,8 +479,8 @@ def _worker_simulate_parallel(args):
sec = int(np.clip(int(sch.sectionIndex), 1, max(len(sch.timeQ), 1)))
enc = encode_schedule_to_timegrid(
cfg,
sectionIndex=sec,
timeQ=sch.timeQ,
section_index=sec,
time_q=sch.timeQ,
q=sch.q,
n_sections=len(sch.timeQ),
)

Loading…
Cancel
Save