// nmWTAI-Platform/ML/Training/Runner/runner_main.cpp
// (Metadata captured from the repository viewer: 442 lines, 13 KiB, C++.)

#include "singlePhaseSolver.h"
#include <Windows.h>
#include <iostream>
#include <string>
#include <vector>
#include <fstream>
#include <sstream>
#include <algorithm>
#include "pch.h"
#include "DatasetIO.h"
#include "RunnerIO.h"
// HX_NWTM
typedef void (*HX_NWTM_MODEL_Func)(HX_NWTM_MODEL_OUTPUT&, const HX_NWTM_MODEL_INPUT&, std::string);
// log-log pre
typedef bool (*PreLogFun)(const std::vector<Point>&, const int&, double*, double*, int, std::vector<Point>&);
static bool fileExistsA(const std::string& p)
{
DWORD attr = GetFileAttributesA(p.c_str());
return (attr != INVALID_FILE_ATTRIBUTES) && !(attr & FILE_ATTRIBUTE_DIRECTORY);
}
static std::string getExeDir()
{
char buf[MAX_PATH] = { 0 };
DWORD len = GetModuleFileNameA(NULL, buf, MAX_PATH);
if (len == 0) return ".";
std::string path(buf, len);
size_t pos = path.find_last_of("\\/");
if (pos != std::string::npos) path.resize(pos);
return path;
}
// =============== Build model input ===============
// Creates the solver input for a loaded dataset:
//  - the grid is handed to the HX_NWTM_MODEL_INPUT constructor;
//  - solver type, the rate schedule (Rate), wellbore storage C / skin S (CS), the PVT
//    block and the scalar Base properties are copied field by field from `scene`;
//  - per-cell k, phi and h are filled uniformly with scene.Base.k_ref / phi_ref / h_ref.
// NOTE(review): the per-cell length is gridOutput2.Trinodexy.size(); confirm that
// Trinodexy holds exactly one entry per cell.
static HX_NWTM_MODEL_INPUT buildModelInputFromDataset(const PebiScene& scene, const HX_NWTM_GRID_OUTPUT2& gridOutput2)
{
    HX_NWTM_MODEL_INPUT input(gridOutput2);
    input.T = scene.solverType;

    // Rate schedule.
    const auto& srcRate = scene.Rate;
    auto& dstRate = input.Rate;
    dstRate.t = srcRate.t;
    dstRate.qo = srcRate.qo;
    dstRate.qg = srcRate.qg;
    dstRate.qw = srcRate.qw;

    // Wellbore storage (C) and skin (S).
    input.CS.C = scene.CS.C;
    input.CS.S = scene.CS.S;

    // PVT block.
    const auto& srcPvt = scene.PVT;
    auto& dstPvt = input.PVT;
    dstPvt.p = srcPvt.p;
    dstPvt.pb = srcPvt.pb;
    dstPvt.Rso = srcPvt.Rso;
    dstPvt.Bo = srcPvt.Bo;
    dstPvt.Co = srcPvt.Co;
    dstPvt.miuo = srcPvt.miuo;
    dstPvt.rouo = srcPvt.rouo;
    dstPvt.Rv = srcPvt.Rv;
    dstPvt.Bg = srcPvt.Bg;
    dstPvt.Cg = srcPvt.Cg;
    dstPvt.miug = srcPvt.miug;
    dstPvt.roug = srcPvt.roug;
    dstPvt.Z = srcPvt.Z;
    dstPvt.Rsw = srcPvt.Rsw;
    dstPvt.Bw = srcPvt.Bw;
    dstPvt.Cw = srcPvt.Cw;
    dstPvt.miuw = srcPvt.miuw;
    dstPvt.rouw = srcPvt.rouw;
    dstPvt.V = srcPvt.V;
    dstPvt.k_kinitial = srcPvt.k_kinitial;
    dstPvt.Cf_Cfinitial = srcPvt.Cf_Cfinitial;
    dstPvt.So = srcPvt.So;
    dstPvt.Kro = srcPvt.Kro;
    dstPvt.Sg = srcPvt.Sg;
    dstPvt.Krg = srcPvt.Krg;
    dstPvt.Sw = srcPvt.Sw;
    dstPvt.Krw = srcPvt.Krw;

    // Scalar base properties.
    const auto& srcBase = scene.Base;
    auto& dstBase = input.Base;
    dstBase.Pi = srcBase.Pi;
    dstBase.Cti = srcBase.Cti;
    dstBase.Cf = srcBase.Cf;
    dstBase.Soi = srcBase.Soi;
    dstBase.Sgi = srcBase.Sgi;
    dstBase.Swi = srcBase.Swi;
    dstBase.d = srcBase.d;
    dstBase.dt_Min = srcBase.dt_Min;
    dstBase.dt_Max = srcBase.dt_Max;

    // Per-cell properties, uniform at the scene's reference values.
    const size_t cellCount = gridOutput2.Trinodexy.size();
    dstBase.k = dVec1(cellCount, srcBase.k_ref);
    dstBase.phi = dVec1(cellCount, srcBase.phi_ref);
    dstBase.h = dVec1(cellCount, srcBase.h_ref);
    return input;
}
// Number of wells implied by the model input: the largest of Rate.t, CS.C and CS.S
// sizes, never less than 1.
static inline size_t inferWellCount(const HX_NWTM_MODEL_INPUT& in)
{
    const size_t sizes[] = { in.Rate.t.size(), in.CS.C.size(), in.CS.S.size() };
    size_t wells = 0;
    for (size_t s : sizes) {
        if (s > wells) wells = s;
    }
    return wells > 0 ? wells : 1;
}
// =============== Apply sampled parameters ===============
// Writes one sampled parameter set into the model input, in place:
//  - per-well Rate containers are grown (never shrunk) to inferWellCount(in) entries;
//  - every entry of CS.S / CS.C is set to p.skin / p.wellboreC (grown likewise);
//  - k, phi, h become in.GRID.Trinodexy.size() copies of p.k / p.phi / p.h;
//  - Base.Cf is set to p.Cf;
//  - only if p carries a schedule (non-empty timeQ and q of equal length): every well's
//    Rate.t / Rate.qo are replaced by p.timeQ / p.q and Rate.qg / Rate.qw are zeroed.
//    Without a schedule the existing Rate contents are left untouched.
static void applySampledParamsAndMaybeOverrideRate(HX_NWTM_MODEL_INPUT& in, const RunnerParams& p)
{
    const size_t wellCount = inferWellCount(in);

    // Every per-well rate container gets at least one entry per well.
    if (in.Rate.t.size() < wellCount) in.Rate.t.resize(wellCount);
    if (in.Rate.qo.size() < wellCount) in.Rate.qo.resize(wellCount);
    if (in.Rate.qg.size() < wellCount) in.Rate.qg.resize(wellCount);
    if (in.Rate.qw.size() < wellCount) in.Rate.qw.resize(wellCount);

    // Skin / wellbore storage: one value for every (possibly grown) entry.
    in.CS.S.assign((std::max)(in.CS.S.size(), wellCount), p.skin);
    in.CS.C.assign((std::max)(in.CS.C.size(), wellCount), p.wellboreC);

    // Uniform reservoir properties over all cells.
    const size_t cellCount = in.GRID.Trinodexy.size();
    in.Base.k.assign(cellCount, p.k);
    in.Base.phi.assign(cellCount, p.phi);
    in.Base.h.assign(cellCount, p.h);
    in.Base.Cf = p.Cf;

    // Rate schedule override, only when params.bin provides one.
    const bool hasSchedule = !p.timeQ.empty() && !p.q.empty() && p.timeQ.size() == p.q.size();
    if (!hasSchedule)
        return;

    const size_t steps = p.timeQ.size();
    for (size_t well = 0; well < wellCount; ++well)
    {
        in.Rate.t[well].assign(p.timeQ.begin(), p.timeQ.end());
        in.Rate.qo[well].assign(p.q.begin(), p.q.end());
        in.Rate.qg[well].assign(steps, 0.0);
        in.Rate.qw[well].assign(steps, 0.0);
    }
}
// =============== Convert solver output ===============
// Copies the solver output into the runner result: the time axis (out.t) and one
// pressure series per well (out.pw). A well series shorter than the time axis is
// zero-padded to out.t.size(); a longer one is kept unchanged.
static void toRunnerResult(const HX_NWTM_MODEL_OUTPUT& out, RunnerResult& rr)
{
    rr.t = out.t;
    rr.nSteps = static_cast<unsigned int>(out.t.size());

    const unsigned int wellCount = static_cast<unsigned int>(out.pw.size());
    rr.nWells = wellCount;
    rr.pw.resize(wellCount);
    for (unsigned int well = 0; well < wellCount; ++well)
    {
        auto& series = rr.pw[well];
        series = out.pw[well];
        if (series.size() < rr.nSteps)
            series.resize(rr.nSteps, 0.0);
    }
}
// =============== Log-log curves (preLogFun = already-resolved "logLogPre" export) ===============
// Fills rr.loglog_t / rr.loglog_p / rr.loglog_deriv, one vector per well.
//
// The three arrays are always reset to rr.nWells empty vectors first, so the result is
// structurally complete for the writer even when no curves are produced.
//
// Schedule used for well w: params.timeQ / params.q with params.sectionIndex when params
// carries a valid schedule (non-empty, equal lengths); otherwise scene.Rate.t[w],
// scene.Rate.qo[w] and scene.wellFlowSectionIndex[w] (entry 0 if w is out of range).
//
// Returns true when preLogFun is NULL (curves stay empty) or all wells were visited;
// a well whose logLogPre call fails or yields <= 2 points keeps empty curves.
// Returns false as soon as a well's schedule has < 2 points or mismatched lengths.
// The first and last points returned by logLogPre are discarded.
static bool computeLogLogCurves(RunnerResult& rr, const PebiScene& scene, const RunnerParams& params, PreLogFun preLogFun)
{
    rr.loglog_t.assign(rr.nWells, std::vector<double>());
    rr.loglog_p.assign(rr.nWells, std::vector<double>());
    rr.loglog_deriv.assign(rr.nWells, std::vector<double>());
    if (preLogFun == NULL)
        return true;

    const bool paramsHaveSchedule =
        !params.timeQ.empty() && !params.q.empty() && params.timeQ.size() == params.q.size();

    for (unsigned int well = 0; well < rr.nWells; ++well)
    {
        // Pressure history of this well as (t, p, 0) points.
        std::vector<Point> pressure(rr.nSteps);
        for (unsigned int step = 0; step < rr.nSteps; ++step) {
            Point& pt = pressure[step];
            pt.x = rr.t[step];
            pt.y = rr.pw[well][step];
            pt.z = 0.0;
        }

        // Fresh copies per well: logLogPre receives non-const double* buffers.
        int section = 0;
        std::vector<double> scheduleT;
        std::vector<double> scheduleQ;
        if (paramsHaveSchedule) {
            section = static_cast<int>(params.sectionIndex);
            scheduleT = params.timeQ;
            scheduleQ = params.q;
        }
        else {
            if (!scene.wellFlowSectionIndex.empty()) {
                const size_t idx = (well < static_cast<unsigned int>(scene.wellFlowSectionIndex.size())) ? well : 0;
                section = scene.wellFlowSectionIndex[idx];
            }
            if (well < scene.Rate.t.size()) scheduleT = scene.Rate.t[well];
            if (well < scene.Rate.qo.size()) scheduleQ = scene.Rate.qo[well];
        }

        if (scheduleT.size() < 2 || scheduleQ.size() != scheduleT.size())
            return false;

        std::vector<Point> curve;
        const bool ok = preLogFun(pressure,
                                  section,
                                  scheduleT.empty() ? NULL : &scheduleT[0],
                                  scheduleQ.empty() ? NULL : &scheduleQ[0],
                                  static_cast<int>(scheduleT.size()),
                                  curve);
        if (!ok || curve.size() <= 2)
            continue;

        // Keep only the interior points (indices 1 .. size-2).
        const size_t interior = curve.size() - 2;
        std::vector<double>& outT = rr.loglog_t[well];
        std::vector<double>& outP = rr.loglog_p[well];
        std::vector<double>& outD = rr.loglog_deriv[well];
        outT.reserve(interior);
        outP.reserve(interior);
        outD.reserve(interior);
        for (size_t k = 1; k <= interior; ++k) {
            outT.push_back(curve[k].x);
            outP.push_back(curve[k].y);
            outD.push_back(curve[k].z);
        }
    }
    return true;
}
// =============== server 模式:循环处理 paramsPath/resultPath ===============
static int runServer(const std::string& datasetPath,
const std::string& dllPath,
const std::string& licPath)
{
if (!fileExistsA(datasetPath)) { std::cerr << "ERROR: dataset not found\n"; return 10; }
if (!fileExistsA(dllPath)) { std::cerr << "ERROR: dll not found\n"; return 12; }
if (!fileExistsA(licPath)) { std::cerr << "ERROR: license not found\n"; return 13; }
// 1) load dataset once
PebiScene scene;
HX_NWTM_GRID_OUTPUT2 gridOutput2;
if (!DatasetIO::loadDataset(datasetPath, scene, gridOutput2, true)) {
std::cerr << "ERROR: loadDataset failed\n";
return 30;
}
// 2) prepare base input once
HX_NWTM_MODEL_INPUT modelInput = buildModelInputFromDataset(scene, gridOutput2);
// 3) load HX solver once
HMODULE hx = LoadLibraryA(dllPath.c_str());
if (!hx) { std::cerr << "ERROR: LoadLibrary(HX) failed\n"; return 40; }
HX_NWTM_MODEL_Func HX_NWTM_MODEL = (HX_NWTM_MODEL_Func)GetProcAddress(hx, "HX_NWTM_MODEL");
if (!HX_NWTM_MODEL) { std::cerr << "ERROR: GetProcAddress(HX_NWTM_MODEL) failed\n"; FreeLibrary(hx); return 41; }
// 4) load loglog dll once
PreLogFun preLogFun = NULL;
HMODULE hMod_solver = LoadLibrary(L"singlePhaseSolverDll.dll");
if (hMod_solver) {
preLogFun = (PreLogFun)GetProcAddress(hMod_solver, "logLogPre");
if (!preLogFun) {
FreeLibrary(hMod_solver);
hMod_solver = NULL;
}
}
std::string line;
RunnerResult rr;
// 输入格式:每行 "paramsPath resultPath"
while (std::getline(std::cin, line))
{
if (line.empty()) continue;
std::istringstream iss(line);
std::string paramsPath, resultPath;
if (!(iss >> paramsPath >> resultPath)) {
std::cerr << "ERROR: invalid input line, expected: paramsPath resultPath\n";
continue;
}
RunnerParams params;
if (!RunnerIO::readParamsBin(paramsPath, params)) {
std::cerr << "ERROR: read params failed: " << paramsPath << "\n";
continue;
}
// 覆盖参数(复用内存)
applySampledParamsAndMaybeOverrideRate(modelInput, params);
HX_NWTM_MODEL_OUTPUT modelOutput;
try {
HX_NWTM_MODEL(modelOutput, modelInput, licPath);
}
catch (...) {
std::cerr << "ERROR: solver exception\n";
continue;
}
toRunnerResult(modelOutput, rr);
if (!computeLogLogCurves(rr, scene, params, preLogFun)) {
std::cerr << "ERROR: computeLogLogCurves failed\n";
continue;
}
if (!RunnerIO::writeResultBin(resultPath, rr)) {
std::cerr << "ERROR: write result failed: " << resultPath << "\n";
continue;
}
// 方便 Python 端读取:成功就输出一行 OK
std::cout << "OK " << resultPath << "\n";
std::cout.flush();
}
if (hMod_solver) FreeLibrary(hMod_solver);
FreeLibrary(hx);
return 0;
}
// =============== 单次模式 ===============
int main(int argc, char** argv)
{
std::string exeDir = getExeDir();
// server: runner.exe --server dataset.bin HX_NWTM.dll lic.dat
if (argc >= 2 && std::string(argv[1]) == "--server")
{
if (argc < 5) {
std::cerr << "Usage: runner.exe --server <dataset.bin> <HX_NWTM.dll> <license.dat>\n";
return 2;
}
return runServer(argv[2], argv[3], argv[4]);
}
// Single-run mode requires explicit paths. Do not fall back to exeDir\\HX_NWTM.dll,
// because Runner and HX_NWTM.dll must be built against the same PEBI header version.
if (argc < 6) {
std::cerr << "Usage: runner.exe <dataset.bin> <params.bin> <result.bin> <HX_NWTM.dll> <license.dat>\n";
return 2;
}
std::string datasetPath = argv[1];
std::string paramsPath = argv[2];
std::string resultPath = argv[3];
std::string dllPath = argv[4];
std::string licPath = argv[5];
if (!fileExistsA(datasetPath)) { std::cerr << "ERROR: dataset not found\n"; return 10; }
if (!fileExistsA(paramsPath)) { std::cerr << "ERROR: params not found\n"; return 11; }
if (!fileExistsA(dllPath)) { std::cerr << "ERROR: dll not found\n"; return 12; }
if (!fileExistsA(licPath)) { std::cerr << "ERROR: license not found\n"; return 13; }
RunnerParams params;
if (!RunnerIO::readParamsBin(paramsPath, params)) {
std::cerr << "ERROR: read params failed\n";
return 20;
}
PebiScene scene;
HX_NWTM_GRID_OUTPUT2 gridOutput2;
if (!DatasetIO::loadDataset(datasetPath, scene, gridOutput2, true)) {
std::cerr << "ERROR: loadDataset failed\n";
return 30;
}
HX_NWTM_MODEL_INPUT modelInput = buildModelInputFromDataset(scene, gridOutput2);
applySampledParamsAndMaybeOverrideRate(modelInput, params);
HMODULE hx = LoadLibraryA(dllPath.c_str());
if (!hx) { std::cerr << "ERROR: LoadLibrary failed\n"; return 40; }
HX_NWTM_MODEL_Func HX_NWTM_MODEL = (HX_NWTM_MODEL_Func)GetProcAddress(hx, "HX_NWTM_MODEL");
if (!HX_NWTM_MODEL) { std::cerr << "ERROR: GetProcAddress failed\n"; FreeLibrary(hx); return 41; }
// loglog dll
PreLogFun preLogFun = NULL;
HMODULE hMod_solver = LoadLibrary(L"singlePhaseSolverDll.dll");
if (hMod_solver) {
preLogFun = (PreLogFun)GetProcAddress(hMod_solver, "logLogPre");
if (!preLogFun) { FreeLibrary(hMod_solver); hMod_solver = NULL; }
}
HX_NWTM_MODEL_OUTPUT modelOutput;
try {
HX_NWTM_MODEL(modelOutput, modelInput, licPath);
}
catch (...) {
std::cerr << "ERROR: solver exception\n";
if (hMod_solver) FreeLibrary(hMod_solver);
FreeLibrary(hx);
return 50;
}
RunnerResult rr;
toRunnerResult(modelOutput, rr);
if (!computeLogLogCurves(rr, scene, params, preLogFun)) {
std::cerr << "ERROR: computeLogLogCurves failed\n";
if (hMod_solver) FreeLibrary(hMod_solver);
FreeLibrary(hx);
return 55;
}
if (!RunnerIO::writeResultBin(resultPath, rr)) {
std::cerr << "ERROR: write result failed\n";
if (hMod_solver) FreeLibrary(hMod_solver);
FreeLibrary(hx);
return 60;
}
if (hMod_solver) FreeLibrary(hMod_solver);
FreeLibrary(hx);
return 0;
}