You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nmWTAI-Platform/ML/Training/Runner/RunnerIO.cpp

196 lines
4.9 KiB
C++

#include "RunnerIO.h"
#include <fstream>
#include <iostream>
#include <stdint.h>
#include <algorithm>
/// Packs four characters into a 32-bit tag with `a` in the least-significant
/// byte and `d` in the most-significant byte. Each character is taken as an
/// unsigned 8-bit value, so chars >= 0x80 do not sign-extend.
static uint32_t makeMagic(const char a, const char b, const char c, const char d)
{
    const char tag[4] = { a, b, c, d };
    uint32_t packed = 0;
    // Start from the most-significant byte (d) and shift each earlier byte in below it.
    for (int i = 3; i >= 0; --i)
    {
        packed = (packed << 8) | static_cast<unsigned char>(tag[i]);
    }
    return packed;
}
/// Reads sizeof(uint32_t) raw bytes (host byte order) from `fs` into `v`.
/// Returns true only if the full value was read and the stream is still good;
/// on a short read `v` may be partially overwritten.
static bool readU32(std::ifstream& fs, uint32_t& v)
{
    char* const dst = reinterpret_cast<char*>(&v);
    return fs.read(dst, sizeof v).good();
}
/// Writes `v` as sizeof(uint32_t) raw bytes (host byte order) to `fs`.
/// Returns the stream's good() state after the write.
static bool writeU32(std::ofstream& fs, uint32_t v)
{
    const char* const src = reinterpret_cast<const char*>(&v);
    return fs.write(src, sizeof v).good();
}
/// Reads sizeof(double) raw bytes (host representation) from `fs` into `v`.
/// Returns true only if the full value was read and the stream is still good.
static bool readDouble(std::ifstream& fs, double& v)
{
    char* const dst = reinterpret_cast<char*>(&v);
    return fs.read(dst, sizeof v).good();
}
/// Writes `n` doubles starting at `p` to `fs` as one raw block (host representation).
/// An empty range (n == 0) returns true without touching the stream, and `p`
/// is not dereferenced in that case. Otherwise returns the stream's good() state.
static inline bool writeDoubles(std::ofstream& fs, const double* p, size_t n)
{
    if (n != 0)
    {
        const std::streamsize byteCount = static_cast<std::streamsize>(n * sizeof(double));
        return fs.write(reinterpret_cast<const char*>(p), byteCount).good();
    }
    return true;
}
/// Writes a single double to `fs` as sizeof(double) raw bytes (host representation).
/// Returns the stream's good() state after the write.
static bool writeDouble(std::ofstream& fs, double v)
{
    const char* const src = reinterpret_cast<const char*>(&v);
    return fs.write(src, sizeof v).good();
}
namespace RunnerIO
{
bool fileExists(const std::string& path)
{
std::ifstream fs(path.c_str(), std::ios::binary);
return fs.good();
}
void printParams(const RunnerParams& p)
{
std::cout << "Params:"
<< " k=" << p.k
<< ", skin=" << p.skin
<< ", C=" << p.wellboreC
<< ", phi=" << p.phi
<< ", h=" << p.h
<< ", Cf=" << p.Cf;
if (!p.timeQ.empty() && !p.q.empty() && p.timeQ.size() == p.q.size()) {
std::cout << ", sectionIndex=" << p.sectionIndex
<< ", nQ=" << p.timeQ.size();
} else {
std::cout << ", schedule=DEFAULT(from scene)";
}
std::cout << std::endl;
}
bool readParamsBin(const std::string& path, RunnerParams& out)
{
std::ifstream fs(path.c_str(), std::ios::binary);
if (!fs)
{
std::cerr << "RunnerIO::readParamsBin: cannot open " << path << std::endl;
return false;
}
const uint32_t MAGIC = makeMagic('P', 'R', 'M', '1');
uint32_t magic = 0, ver = 0;
if (!readU32(fs, magic) || !readU32(fs, ver))
{
std::cerr << "RunnerIO::readParamsBin: header read failed\n";
return false;
}
if (magic != MAGIC || ver != 1)
{
std::cerr << "RunnerIO::readParamsBin: magic/version mismatch\n";
return false;
}
if (!readDouble(fs, out.k)) return false;
if (!readDouble(fs, out.skin)) return false;
if (!readDouble(fs, out.wellboreC)) return false;
if (!readDouble(fs, out.phi)) return false;
if (!readDouble(fs, out.h)) return false;
if (!readDouble(fs, out.Cf)) return false;
out.sectionIndex = 0;
out.timeQ.clear();
out.q.clear();
int nextByte = fs.peek();
if (nextByte == EOF) {
return true;
}
uint32_t sec = 0, nQ = 0;
if (!readU32(fs, sec)) {
std::cerr << "RunnerIO::readParamsBin: schedule extension present but missing sectionIndex\n";
return false;
}
if (!readU32(fs, nQ)) {
std::cerr << "RunnerIO::readParamsBin: schedule extension present but missing nQ\n";
return false;
}
if (nQ < 2 || nQ > 100000) {
std::cerr << "RunnerIO::readParamsBin: invalid nQ=" << nQ << "\n";
return false;
}
out.sectionIndex = sec;
out.timeQ.resize(nQ);
out.q.resize(nQ);
for (uint32_t i = 0; i < nQ; ++i) {
if (!readDouble(fs, out.timeQ[i])) return false;
}
for (uint32_t i = 0; i < nQ; ++i) {
if (!readDouble(fs, out.q[i])) return false;
}
return fs.good();
}
bool writeResultBin(const std::string& path, const RunnerResult& r)
{
std::ofstream fs(path.c_str(), std::ios::binary);
if (!fs)
{
std::cerr << "RunnerIO::writeResultBin: cannot open " << path << std::endl;
return false;
}
//使用普通静态变量
static std::vector<char> buf;
buf.resize(4 * 1024 * 1024);
fs.rdbuf()->pubsetbuf(&buf[0], (std::streamsize)buf.size());
const uint32_t MAGIC = makeMagic('R', 'S', 'B', '1');
if (!writeU32(fs, MAGIC)) return false;
if (!writeU32(fs, 1)) return false;
if (!writeU32(fs, (uint32_t)r.nWells)) return false;
if (!writeU32(fs, (uint32_t)r.nSteps)) return false;
// t
if (r.nSteps > 0) {
if (r.t.size() < r.nSteps) return false;
if (!writeDoubles(fs, &r.t[0], r.nSteps)) return false; // 使用 &r.t[0] 替代 .data()
}
// pw
for (unsigned int w = 0; w < r.nWells; ++w)
{
if (r.nSteps > 0) {
if (w >= r.pw.size()) return false;
if (r.pw[w].size() < r.nSteps) return false;
if (!writeDoubles(fs, &r.pw[w][0], r.nSteps)) return false; // 使用 &r.pw[w][0]
}
}
// loglog
for (unsigned int w = 0; w < r.nWells; ++w)
{
uint32_t nLogLog = 0;
if (w < r.loglog_t.size()) nLogLog = (uint32_t)r.loglog_t[w].size();
if (!writeU32(fs, nLogLog)) return false;
if (nLogLog == 0) continue;
if (w >= r.loglog_p.size() || w >= r.loglog_deriv.size()) return false;
if (r.loglog_p[w].size() != nLogLog) return false;
if (r.loglog_deriv[w].size() != nLogLog) return false;
if (!writeDoubles(fs, &r.loglog_t[w][0], nLogLog)) return false;
if (!writeDoubles(fs, &r.loglog_p[w][0], nLogLog)) return false;
if (!writeDoubles(fs, &r.loglog_deriv[w][0], nLogLog)) return false;
}
return fs.good();
}
}