You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/Training/Runner/runner_main.cpp

442 lines
13 KiB
C++

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "singlePhaseSolver.h"
#include <Windows.h>
#include <algorithm>
#include <exception>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>
#include "pch.h"
#include "DatasetIO.h"
#include "RunnerIO.h"
// HX_NWTM
typedef void (*HX_NWTM_MODEL_Func)(HX_NWTM_MODEL_OUTPUT&, const HX_NWTM_MODEL_INPUT&, std::string);
// log-log pre
typedef bool (*PreLogFun)(const std::vector<Point>&, const int&, double*, double*, int, std::vector<Point>&);
static bool fileExistsA(const std::string& p)
{
DWORD attr = GetFileAttributesA(p.c_str());
return (attr != INVALID_FILE_ATTRIBUTES) && !(attr & FILE_ATTRIBUTE_DIRECTORY);
}
static std::string getExeDir()
{
char buf[MAX_PATH] = { 0 };
DWORD len = GetModuleFileNameA(NULL, buf, MAX_PATH);
if (len == 0) return ".";
std::string path(buf, len);
size_t pos = path.find_last_of("\\/");
if (pos != std::string::npos) path.resize(pos);
return path;
}
// =============== Build model input ===============
// Builds the solver input for `gridOutput2` from the dataset scene.
// - Copies the solver type, the per-well rate schedule (t/qo/qg/qw), the CS.C / CS.S
//   arrays, every PVT field and the scalar Base settings.
// - Fills Base.k / Base.phi / Base.h with one value per cell (scene.Base.k_ref /
//   phi_ref / h_ref). The cell count is gridOutput2.Trinodexy.size().
static HX_NWTM_MODEL_INPUT buildModelInputFromDataset(const PebiScene& scene, const HX_NWTM_GRID_OUTPUT2& gridOutput2)
{
    HX_NWTM_MODEL_INPUT input(gridOutput2);
    input.T = scene.solverType;

    // Per-well rate schedule.
    input.Rate.t  = scene.Rate.t;
    input.Rate.qo = scene.Rate.qo;
    input.Rate.qg = scene.Rate.qg;
    input.Rate.qw = scene.Rate.qw;

    // Per-well C / S values.
    input.CS.C = scene.CS.C;
    input.CS.S = scene.CS.S;

    // PVT section, copied field by field.
    auto& pvt = input.PVT;
    const auto& srcPvt = scene.PVT;
    pvt.p            = srcPvt.p;
    pvt.pb           = srcPvt.pb;
    pvt.Rso          = srcPvt.Rso;
    pvt.Bo           = srcPvt.Bo;
    pvt.Co           = srcPvt.Co;
    pvt.miuo         = srcPvt.miuo;
    pvt.rouo         = srcPvt.rouo;
    pvt.Rv           = srcPvt.Rv;
    pvt.Bg           = srcPvt.Bg;
    pvt.Cg           = srcPvt.Cg;
    pvt.miug         = srcPvt.miug;
    pvt.roug         = srcPvt.roug;
    pvt.Z            = srcPvt.Z;
    pvt.Rsw          = srcPvt.Rsw;
    pvt.Bw           = srcPvt.Bw;
    pvt.Cw           = srcPvt.Cw;
    pvt.miuw         = srcPvt.miuw;
    pvt.rouw         = srcPvt.rouw;
    pvt.V            = srcPvt.V;
    pvt.k_kinitial   = srcPvt.k_kinitial;
    pvt.Cf_Cfinitial = srcPvt.Cf_Cfinitial;
    pvt.So           = srcPvt.So;
    pvt.Kro          = srcPvt.Kro;
    pvt.Sg           = srcPvt.Sg;
    pvt.Krg          = srcPvt.Krg;
    pvt.Sw           = srcPvt.Sw;
    pvt.Krw          = srcPvt.Krw;

    // Scalar Base settings.
    auto& base = input.Base;
    const auto& srcBase = scene.Base;
    base.Pi     = srcBase.Pi;
    base.Cti    = srcBase.Cti;
    base.Cf     = srcBase.Cf;
    base.Soi    = srcBase.Soi;
    base.Sgi    = srcBase.Sgi;
    base.Swi    = srcBase.Swi;
    base.d      = srcBase.d;
    base.dt_Min = srcBase.dt_Min;
    base.dt_Max = srcBase.dt_Max;

    // Per-cell arrays start out uniform at the dataset's reference values.
    const size_t cellCount = gridOutput2.Trinodexy.size();
    base.k   = dVec1(cellCount, srcBase.k_ref);
    base.phi = dVec1(cellCount, srcBase.phi_ref);
    base.h   = dVec1(cellCount, srcBase.h_ref);
    return input;
}
// Number of wells implied by `in`.
// Returns the largest of Rate.t.size(), CS.C.size() and CS.S.size(),
// and never less than 1.
static inline size_t inferWellCount(const HX_NWTM_MODEL_INPUT& in)
{
    const size_t rateWells = in.Rate.t.size();
    const size_t cWells = in.CS.C.size();
    const size_t sWells = in.CS.S.size();
    const size_t widest = (std::max)(rateWells, (std::max)(cWells, sWells));
    return (widest != 0) ? widest : 1;
}
// =============== Apply sampled parameters ===============
// Writes one sampled parameter set `p` into the model input `in`, in place.
// - Every well gets CS.S = p.skin and CS.C = p.wellboreC.
// - Base.k / Base.phi / Base.h are resized to the cell count (GRID.Trinodexy.size())
//   and filled with p.k / p.phi / p.h. Base.Cf is set to p.Cf.
// - Only when p.timeQ and p.q are both non-empty and equally long, every well's rate
//   history is replaced by t = p.timeQ, qo = p.q, qg = qw = 0. Otherwise Rate keeps
//   whatever it already holds.
// Per-well containers are grown (never shrunk) to inferWellCount(in) entries. All
// vectors are resized in place, so previously allocated capacity can be reused.
static void applySampledParamsAndMaybeOverrideRate(HX_NWTM_MODEL_INPUT& in, const RunnerParams& p)
{
    const size_t wellCount = inferWellCount(in);

    // Rate vectors: make every well index valid before the optional schedule override.
    if (in.Rate.t.size()  < wellCount) in.Rate.t.resize(wellCount);
    if (in.Rate.qo.size() < wellCount) in.Rate.qo.resize(wellCount);
    if (in.Rate.qg.size() < wellCount) in.Rate.qg.resize(wellCount);
    if (in.Rate.qw.size() < wellCount) in.Rate.qw.resize(wellCount);

    // Skin / wellboreC: one sampled value shared by every well.
    if (in.CS.S.size() < wellCount) in.CS.S.resize(wellCount, 0.0);
    std::fill(in.CS.S.begin(), in.CS.S.end(), p.skin);
    if (in.CS.C.size() < wellCount) in.CS.C.resize(wellCount, 0.0);
    std::fill(in.CS.C.begin(), in.CS.C.end(), p.wellboreC);

    // k / phi / h: uniform over all grid cells.
    const size_t cellCount = in.GRID.Trinodexy.size();
    in.Base.k.resize(cellCount);
    std::fill(in.Base.k.begin(), in.Base.k.end(), p.k);
    in.Base.phi.resize(cellCount);
    std::fill(in.Base.phi.begin(), in.Base.phi.end(), p.phi);
    in.Base.h.resize(cellCount);
    std::fill(in.Base.h.begin(), in.Base.h.end(), p.h);
    in.Base.Cf = p.Cf;

    // Rate schedule: only overridden when params.bin carries a consistent one.
    const bool hasSchedule = !p.timeQ.empty() && !p.q.empty() && p.timeQ.size() == p.q.size();
    if (!hasSchedule)
        return;

    const size_t pointCount = p.timeQ.size();
    for (size_t well = 0; well < wellCount; ++well)
    {
        dVec1& tSeries = in.Rate.t[well];
        tSeries.resize(pointCount);
        std::copy(p.timeQ.begin(), p.timeQ.end(), tSeries.begin());

        dVec1& qoSeries = in.Rate.qo[well];
        qoSeries.resize(pointCount);
        std::copy(p.q.begin(), p.q.end(), qoSeries.begin());

        dVec1& qgSeries = in.Rate.qg[well];
        qgSeries.resize(pointCount);
        std::fill(qgSeries.begin(), qgSeries.end(), 0.0);

        dVec1& qwSeries = in.Rate.qw[well];
        qwSeries.resize(pointCount);
        std::fill(qwSeries.begin(), qwSeries.end(), 0.0);
    }
}
// =============== Convert solver output ===============
// Copies the solver output into `rr`:
// - nSteps = out.t.size(), t = out.t, nWells = out.pw.size().
// - pw[w] = out.pw[w], zero-padded up to nSteps when shorter (never truncated).
static void toRunnerResult(const HX_NWTM_MODEL_OUTPUT& out, RunnerResult& rr)
{
    rr.nSteps = static_cast<unsigned int>(out.t.size());
    rr.nWells = static_cast<unsigned int>(out.pw.size());
    rr.t = out.t;

    rr.pw.resize(rr.nWells);
    for (unsigned int well = 0; well < rr.nWells; ++well)
    {
        auto& series = rr.pw[well];
        series = out.pw[well];
        if (series.size() < rr.nSteps)
            series.resize(rr.nSteps, 0.0);
    }
}
// =============== Compute log-log curves (using an already loaded preLogFun) ===============
// Fills rr.loglog_t / loglog_p / loglog_deriv with one (possibly empty) vector per well.
// - If preLogFun is NULL, all curves stay empty and the function returns true.
// - Otherwise, for each well the pressure samples (t, pw[w], 0) are passed to preLogFun
//   together with a rate schedule. A complete params schedule (timeQ/q non-empty and
//   equally long) is preferred; otherwise the scene's per-well schedule and
//   section index are used.
// - The first and last returned points are dropped. The remaining x / y / z values go
//   to loglog_t / loglog_p / loglog_deriv. A failed call, or one returning <= 2 points,
//   leaves that well's curves empty.
// Returns false as soon as a well has no usable schedule (< 2 points or mismatched
// lengths); curves of earlier wells are left as filled.
static bool computeLogLogCurves(RunnerResult& rr, const PebiScene& scene, const RunnerParams& params, PreLogFun preLogFun)
{
    // Shape the result for every well up front so writeResultBin never sees missing or
    // stale per-well entries, whether or not the DLL is available.
    rr.loglog_t.assign(rr.nWells, std::vector<double>());
    rr.loglog_p.assign(rr.nWells, std::vector<double>());
    rr.loglog_deriv.assign(rr.nWells, std::vector<double>());

    if (preLogFun == NULL)
        return true; // no logLogPre export: curves stay empty, layout is still valid

    const bool useParamsSchedule =
        !params.timeQ.empty() && !params.q.empty() && params.timeQ.size() == params.q.size();

    for (unsigned int w = 0; w < rr.nWells; ++w)
    {
        std::vector<Point> samples(rr.nSteps);
        for (unsigned int step = 0; step < rr.nSteps; ++step)
        {
            Point& sample = samples[step];
            sample.x = rr.t[step];
            sample.y = rr.pw[w][step];
            sample.z = 0.0;
        }

        int section = 0;
        std::vector<double> schedT;
        std::vector<double> schedQ;
        if (useParamsSchedule)
        {
            section = (int)params.sectionIndex;
            schedT = params.timeQ;
            schedQ = params.q;
        }
        else
        {
            if (!scene.wellFlowSectionIndex.empty())
            {
                const size_t pick = (w < static_cast<unsigned int>(scene.wellFlowSectionIndex.size())) ? w : 0;
                section = scene.wellFlowSectionIndex[pick];
            }
            if (w < scene.Rate.t.size())  schedT = scene.Rate.t[w];
            if (w < scene.Rate.qo.size()) schedQ = scene.Rate.qo[w];
        }

        if (schedT.size() < 2 || schedQ.size() != schedT.size())
            return false;

        // Both schedules are non-empty here, so taking &v[0] is safe.
        std::vector<Point> curve;
        const bool ok = preLogFun(samples, section, &schedT[0], &schedQ[0], (int)schedT.size(), curve);
        if (!ok || curve.size() <= 2)
            continue;

        const size_t interiorCount = curve.size() - 2;
        std::vector<double>& outT = rr.loglog_t[w];
        std::vector<double>& outP = rr.loglog_p[w];
        std::vector<double>& outD = rr.loglog_deriv[w];
        outT.reserve(interiorCount);
        outP.reserve(interiorCount);
        outD.reserve(interiorCount);
        for (size_t k = 1; k <= interiorCount; ++k)
        {
            outT.push_back(curve[k].x);
            outP.push_back(curve[k].y);
            outD.push_back(curve[k].z);
        }
    }
    return true;
}
// =============== Server mode: process "paramsPath resultPath" requests in a loop ===============
// Startup: loads the dataset, the HX solver DLL and (optionally) singlePhaseSolverDll.dll
// once.
// Loop: reads requests from stdin, one per line as "paramsPath resultPath". The two
// fields are whitespace separated, so paths must not contain spaces. For each valid
// request it:
//   1. applies the sampled parameters,
//   2. runs the solver,
//   3. computes the log-log curves,
//   4. writes the result file,
//   5. prints "OK <resultPath>" on stdout.
// Failed requests are reported on stderr and skipped.
// Returns 0 when stdin is exhausted, or a startup error code:
// 10/12/13 missing file, 30 dataset load, 40/41 solver DLL load/lookup.
static int runServer(const std::string& datasetPath,
    const std::string& dllPath,
    const std::string& licPath)
{
    if (!fileExistsA(datasetPath)) { std::cerr << "ERROR: dataset not found\n"; return 10; }
    if (!fileExistsA(dllPath)) { std::cerr << "ERROR: dll not found\n"; return 12; }
    if (!fileExistsA(licPath)) { std::cerr << "ERROR: license not found\n"; return 13; }
    // 1) load dataset once
    PebiScene scene;
    HX_NWTM_GRID_OUTPUT2 gridOutput2;
    if (!DatasetIO::loadDataset(datasetPath, scene, gridOutput2, true)) {
        std::cerr << "ERROR: loadDataset failed\n";
        return 30;
    }
    // 2) prepare base input once; it is updated in place for every request
    HX_NWTM_MODEL_INPUT modelInput = buildModelInputFromDataset(scene, gridOutput2);
    // 3) load HX solver once
    HMODULE hx = LoadLibraryA(dllPath.c_str());
    if (!hx) { std::cerr << "ERROR: LoadLibrary(HX) failed\n"; return 40; }
    HX_NWTM_MODEL_Func HX_NWTM_MODEL = (HX_NWTM_MODEL_Func)GetProcAddress(hx, "HX_NWTM_MODEL");
    if (!HX_NWTM_MODEL) { std::cerr << "ERROR: GetProcAddress(HX_NWTM_MODEL) failed\n"; FreeLibrary(hx); return 41; }
    // 4) load the optional log-log DLL once. LoadLibraryW is called explicitly because
    //    the name is a wide literal; plain LoadLibrary only accepts it in UNICODE builds.
    PreLogFun preLogFun = NULL;
    HMODULE hMod_solver = LoadLibraryW(L"singlePhaseSolverDll.dll");
    if (hMod_solver) {
        preLogFun = (PreLogFun)GetProcAddress(hMod_solver, "logLogPre");
        if (!preLogFun) {
            FreeLibrary(hMod_solver);
            hMod_solver = NULL;
        }
    }
    std::string line;
    RunnerResult rr; // reused across requests; every field it carries is reassigned below
    // NOTE(review): failed requests print nothing on stdout. A client that blocks waiting
    // for one reply line per request may hang; consider emitting an explicit error line
    // if the Python side supports it.
    while (std::getline(std::cin, line))
    {
        if (line.empty()) continue;
        std::istringstream iss(line);
        std::string paramsPath, resultPath;
        if (!(iss >> paramsPath >> resultPath)) {
            std::cerr << "ERROR: invalid input line, expected: paramsPath resultPath\n";
            continue;
        }
        RunnerParams params;
        if (!RunnerIO::readParamsBin(paramsPath, params)) {
            std::cerr << "ERROR: read params failed: " << paramsPath << "\n";
            continue;
        }
        // Restore the dataset's rate schedule before applying this request. The
        // override below only happens when params.bin carries a schedule. Without
        // this reset, a schedule from an earlier request would leak into later requests
        // that carry none, and the solver would disagree with the scene schedule used by
        // computeLogLogCurves.
        modelInput.Rate.t = scene.Rate.t;
        modelInput.Rate.qo = scene.Rate.qo;
        modelInput.Rate.qg = scene.Rate.qg;
        modelInput.Rate.qw = scene.Rate.qw;
        // Overwrite the sampled parameters (reusing allocated memory).
        applySampledParamsAndMaybeOverrideRate(modelInput, params);
        HX_NWTM_MODEL_OUTPUT modelOutput;
        try {
            HX_NWTM_MODEL(modelOutput, modelInput, licPath);
        }
        catch (const std::exception& e) {
            std::cerr << "ERROR: solver exception: " << e.what() << "\n";
            continue;
        }
        catch (...) {
            std::cerr << "ERROR: solver exception\n";
            continue;
        }
        toRunnerResult(modelOutput, rr);
        if (!computeLogLogCurves(rr, scene, params, preLogFun)) {
            std::cerr << "ERROR: computeLogLogCurves failed\n";
            continue;
        }
        if (!RunnerIO::writeResultBin(resultPath, rr)) {
            std::cerr << "ERROR: write result failed: " << resultPath << "\n";
            continue;
        }
        // One "OK <resultPath>" line per successful request, for the Python side to read.
        std::cout << "OK " << resultPath << "\n";
        std::cout.flush();
    }
    if (hMod_solver) FreeLibrary(hMod_solver);
    FreeLibrary(hx);
    return 0;
}
// =============== 单次模式 ===============
int main(int argc, char** argv)
{
std::string exeDir = getExeDir();
// server: runner.exe --server dataset.bin HX_NWTM.dll lic.dat
if (argc >= 2 && std::string(argv[1]) == "--server")
{
if (argc < 5) {
std::cerr << "Usage: runner.exe --server <dataset.bin> <HX_NWTM.dll> <license.dat>\n";
return 2;
}
return runServer(argv[2], argv[3], argv[4]);
}
// Single-run mode requires explicit paths. Do not fall back to exeDir\\HX_NWTM.dll,
// because Runner and HX_NWTM.dll must be built against the same PEBI header version.
if (argc < 6) {
std::cerr << "Usage: runner.exe <dataset.bin> <params.bin> <result.bin> <HX_NWTM.dll> <license.dat>\n";
return 2;
}
std::string datasetPath = argv[1];
std::string paramsPath = argv[2];
std::string resultPath = argv[3];
std::string dllPath = argv[4];
std::string licPath = argv[5];
if (!fileExistsA(datasetPath)) { std::cerr << "ERROR: dataset not found\n"; return 10; }
if (!fileExistsA(paramsPath)) { std::cerr << "ERROR: params not found\n"; return 11; }
if (!fileExistsA(dllPath)) { std::cerr << "ERROR: dll not found\n"; return 12; }
if (!fileExistsA(licPath)) { std::cerr << "ERROR: license not found\n"; return 13; }
RunnerParams params;
if (!RunnerIO::readParamsBin(paramsPath, params)) {
std::cerr << "ERROR: read params failed\n";
return 20;
}
PebiScene scene;
HX_NWTM_GRID_OUTPUT2 gridOutput2;
if (!DatasetIO::loadDataset(datasetPath, scene, gridOutput2, true)) {
std::cerr << "ERROR: loadDataset failed\n";
return 30;
}
HX_NWTM_MODEL_INPUT modelInput = buildModelInputFromDataset(scene, gridOutput2);
applySampledParamsAndMaybeOverrideRate(modelInput, params);
HMODULE hx = LoadLibraryA(dllPath.c_str());
if (!hx) { std::cerr << "ERROR: LoadLibrary failed\n"; return 40; }
HX_NWTM_MODEL_Func HX_NWTM_MODEL = (HX_NWTM_MODEL_Func)GetProcAddress(hx, "HX_NWTM_MODEL");
if (!HX_NWTM_MODEL) { std::cerr << "ERROR: GetProcAddress failed\n"; FreeLibrary(hx); return 41; }
// loglog dll
PreLogFun preLogFun = NULL;
HMODULE hMod_solver = LoadLibrary(L"singlePhaseSolverDll.dll");
if (hMod_solver) {
preLogFun = (PreLogFun)GetProcAddress(hMod_solver, "logLogPre");
if (!preLogFun) { FreeLibrary(hMod_solver); hMod_solver = NULL; }
}
HX_NWTM_MODEL_OUTPUT modelOutput;
try {
HX_NWTM_MODEL(modelOutput, modelInput, licPath);
}
catch (...) {
std::cerr << "ERROR: solver exception\n";
if (hMod_solver) FreeLibrary(hMod_solver);
FreeLibrary(hx);
return 50;
}
RunnerResult rr;
toRunnerResult(modelOutput, rr);
if (!computeLogLogCurves(rr, scene, params, preLogFun)) {
std::cerr << "ERROR: computeLogLogCurves failed\n";
if (hMod_solver) FreeLibrary(hMod_solver);
FreeLibrary(hx);
return 55;
}
if (!RunnerIO::writeResultBin(resultPath, rr)) {
std::cerr << "ERROR: write result failed\n";
if (hMod_solver) FreeLibrary(hMod_solver);
FreeLibrary(hx);
return 60;
}
if (hMod_solver) FreeLibrary(hMod_solver);
FreeLibrary(hx);
return 0;
}