/*=========================================================================

  Program:   Visualization Toolkit
  Module:    vtkThreadedTaskQueue.txx

  Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
  All rights reserved.
  See Copyright.txt or http://www.kitware.com/Copyright.htm for details.

     This software is distributed WITHOUT ANY WARRANTY; without even
     the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
     PURPOSE.  See the above copyright notice for more information.

=========================================================================*/
#include "vtkLogger.h"
#include "vtkMultiThreader.h"

#include <algorithm>
#include <atomic>
#include <cassert>
#include <condition_variable>
#include <cstdint>
#include <functional>
#include <mutex>
#include <queue>
#include <string>
#include <thread>
#include <utility>
#include <vector>

//=============================================================================
namespace vtkThreadedTaskQueueInternals
{

/**
 * FIFO of pending tasks shared between the producer and the worker threads.
 * Every pushed task is tagged with an id taken from a monotonically increasing
 * counter. Access to the queue itself is guarded by `TasksMutex`.
 */
template <typename R>
class TaskQueue
{
public:
  /**
   * @param buffer_size maximum number of pending (not yet popped) tasks; when
   *        a push exceeds it, the oldest pending tasks are discarded. A value
   *        <= 0 means the queue is unbounded.
   */
  explicit TaskQueue(int buffer_size)
    : Done(false)
    , BufferSize(buffer_size)
    , NextTaskId(0)
  {
  }

  ~TaskQueue() = default;

  /**
   * Marks the queue as finished and wakes every waiting worker. Later `Push`
   * calls are ignored; `Pop` keeps handing out already-queued tasks and
   * returns false once the queue is drained.
   */
  void MarkDone()
  {
    {
      // Update the flag while holding the mutex. Otherwise a worker that has
      // just evaluated the wait predicate in `Pop` (not done, queue empty) but
      // has not yet blocked would miss the notify below and sleep forever,
      // deadlocking the thread join in the owner's destructor.
      std::lock_guard<std::mutex> lk(this->TasksMutex);
      this->Done = true;
    }
    this->TasksCV.notify_all();
  }

  /** Id that will be assigned to the next pushed task. */
  std::uint64_t GetNextTaskId() const { return this->NextTaskId; }

  /**
   * Enqueues `task` with the next task id and wakes one worker. Does nothing
   * once `MarkDone` has been called.
   */
  void Push(std::function<R()>&& task)
  {
    if (this->Done)
    {
      return;
    }
    else
    {
      std::lock_guard<std::mutex> lk(this->TasksMutex);
      // vtkLogF(INFO, "pushing-task %d", (int)this->NextTaskId);
      this->Tasks.push(std::make_pair(this->NextTaskId++, std::move(task)));
      // Bounded mode: drop the oldest pending tasks. Their ids are consumed,
      // so they will never produce a result.
      while (this->BufferSize > 0 && static_cast<int>(this->Tasks.size()) > this->BufferSize)
      {
        this->Tasks.pop();
      }
    }
    this->TasksCV.notify_one();
  }

  /**
   * Blocks until a task is available or the queue is marked done.
   *
   * @param task_id receives the id of the popped task.
   * @param task receives the popped task.
   * @return true if a task was popped; false when the queue is done and empty.
   */
  bool Pop(std::uint64_t& task_id, std::function<R()>& task)
  {
    std::unique_lock<std::mutex> lk(this->TasksMutex);
    this->TasksCV.wait(lk, [this] { return this->Done || !this->Tasks.empty(); });
    if (!this->Tasks.empty())
    {
      // Move rather than copy: the queued std::function is about to be
      // discarded anyway.
      auto task_pair = std::move(this->Tasks.front());
      // vtkLogF(TRACE, "popping-task %d", (int)task_pair.first);
      this->Tasks.pop();
      lk.unlock();

      task_id = task_pair.first;
      task = std::move(task_pair.second);
      return true;
    }
    assert(this->Done);
    return false;
  }

private:
  std::atomic_bool Done;
  int BufferSize;
  std::atomic<std::uint64_t> NextTaskId;
  std::queue<std::pair<std::uint64_t, std::function<R()> > > Tasks;
  std::mutex TasksMutex;
  std::condition_variable TasksCV;
};

//=============================================================================
/**
 * Holds results keyed by the id of the task that produced them, ordered by
 * ascending task id. In strict-ordering mode a result can only be popped when
 * its id is exactly the next expected one.
 */
template <typename R>
class ResultQueue
{
public:
  explicit ResultQueue(bool strict_ordering)
    : NextResultId(0)
    , StrictOrdering(strict_ordering)
  {
  }

  ~ResultQueue() = default;

  /** Id of the next result expected to be popped. */
  std::uint64_t GetNextResultId() const { return this->NextResultId; }

  /**
   * Stores `result` for `task_id` and wakes one waiting consumer. Results
   * whose id is older than the next expected id are discarded.
   *
   * Takes `R&&` (not `const R&&`) so the value is really moved into storage
   * instead of silently copied.
   */
  void Push(std::uint64_t task_id, R&& result)
  {
    std::unique_lock<std::mutex> lk(this->ResultsMutex);
    // don't save this result if it's obsolete.
    if (task_id >= this->NextResultId)
    {
      this->Results.emplace(task_id, std::move(result));
    }
    lk.unlock();
    this->ResultsCV.notify_one();
  }

  /**
   * Non-blocking pop.
   * @return false if no result is available or, with strict ordering, if the
   *         available result is not the next one in sequence.
   */
  bool TryPop(R& result)
  {
    std::unique_lock<std::mutex> lk(this->ResultsMutex);
    if (!this->CanPopLocked())
    {
      // Results are not available or, if strict ordering is requested, the
      // result available is not the next one in sequence; hence don't pop
      // anything.
      return false;
    }
    this->PopLocked(lk, result);
    return true;
  }

  /**
   * Blocks until a result can be popped (with strict ordering: until the
   * result for the next expected id arrives), then pops it.
   *
   * The pop happens under the same lock acquisition as the wait, so a
   * concurrent consumer cannot take the result in between.
   */
  bool Pop(R& result)
  {
    std::unique_lock<std::mutex> lk(this->ResultsMutex);
    this->ResultsCV.wait(lk, [this] { return this->CanPopLocked(); });
    this->PopLocked(lk, result);
    return true;
  }

private:
  template <typename T>
  struct Comparator
  {
    // Inverted comparison turns std::priority_queue into a min-heap on id.
    bool operator()(const T& left, const T& right) const { return left.first > right.first; }
  };

  // Caller must hold ResultsMutex.
  bool CanPopLocked() const
  {
    return !this->Results.empty() &&
      (!this->StrictOrdering || this->Results.top().first == this->NextResultId);
  }

  // Caller must hold ResultsMutex through `lk`; the lock is released before
  // the value is handed to the caller.
  void PopLocked(std::unique_lock<std::mutex>& lk, R& result)
  {
    // priority_queue::top() is const, so the pair has to be copied.
    auto result_pair = this->Results.top();
    this->NextResultId = (result_pair.first + 1);
    this->Results.pop();
    lk.unlock();
    result = std::move(result_pair.second);
  }

  std::priority_queue<std::pair<std::uint64_t, R>, std::vector<std::pair<std::uint64_t, R> >,
    Comparator<std::pair<std::uint64_t, R> > >
    Results;
  std::mutex ResultsMutex;
  std::condition_variable ResultsCV;
  std::atomic<std::uint64_t> NextResultId;
  bool StrictOrdering;
};
} // namespace vtkThreadedTaskQueueInternals

//-----------------------------------------------------------------------------
/**
 * Starts the worker threads. Each worker pops tasks until the task queue is
 * marked done and drained, and pushes every task's return value to the result
 * queue tagged with the task's id.
 *
 * `buffer_size` is only honored when `strict_ordering` is false (with strict
 * ordering pending tasks are never discarded). A non-positive
 * `max_concurrent_tasks` selects vtkMultiThreader's global default thread count.
 */
template <typename R, typename... Args>
vtkThreadedTaskQueue<R, Args...>::vtkThreadedTaskQueue(
  std::function<R(Args...)> worker, bool strict_ordering, int buffer_size, int max_concurrent_tasks)
  : Worker(std::move(worker))
  , Tasks(new vtkThreadedTaskQueueInternals::TaskQueue<R>(
      std::max(0, strict_ordering ? 0 : buffer_size)))
  , Results(new vtkThreadedTaskQueueInternals::ResultQueue<R>(strict_ordering))
  , NumberOfThreads(max_concurrent_tasks <= 0 ? vtkMultiThreader::GetGlobalDefaultNumberOfThreads()
                                              : max_concurrent_tasks)
  , Threads{ new std::thread[this->NumberOfThreads] }
{
  auto f = [this](int thread_id) {
    vtkLogger::SetThreadName("ttq::worker" + std::to_string(thread_id));
    while (true)
    {
      // Declared per iteration so each task (and its bound arguments) is
      // destroyed before the worker blocks waiting for the next one.
      std::function<R()> task;
      std::uint64_t task_id = 0;
      if (!this->Tasks->Pop(task_id, task))
      {
        break;
      }
      this->Results->Push(task_id, task());
    }
    // vtkLogF(INFO, "done");
  };

  for (int cc = 0; cc < this->NumberOfThreads; ++cc)
  {
    this->Threads[cc] = std::thread(f, cc);
  }
}

//-----------------------------------------------------------------------------
/** Marks the task queue done and joins all workers; queued tasks still run. */
template <typename R, typename... Args>
vtkThreadedTaskQueue<R, Args...>::~vtkThreadedTaskQueue()
{
  this->Tasks->MarkDone();
  for (int cc = 0; cc < this->NumberOfThreads; ++cc)
  {
    this->Threads[cc].join();
  }
}

//-----------------------------------------------------------------------------
/** Queues a call of the worker with `args`; arguments are forwarded into the bound call. */
template <typename R, typename... Args>
void vtkThreadedTaskQueue<R, Args...>::Push(Args&&... args)
{
  this->Tasks->Push(std::bind(this->Worker, std::forward<Args>(args)...));
}

//-----------------------------------------------------------------------------
/** Non-blocking; returns false if no result can be popped right now. */
template <typename R, typename... Args>
bool vtkThreadedTaskQueue<R, Args...>::TryPop(R& result)
{
  return this->Results->TryPop(result);
}

//-----------------------------------------------------------------------------
/**
 * Returns false immediately if no result is outstanding; otherwise blocks
 * until the next result can be popped.
 */
template <typename R, typename... Args>
bool vtkThreadedTaskQueue<R, Args...>::Pop(R& result)
{
  if (this->IsEmpty())
  {
    return false;
  }

  return this->Results->Pop(result);
}

//-----------------------------------------------------------------------------
/** True when the next expected result id has caught up with the next task id. */
template <typename R, typename... Args>
bool vtkThreadedTaskQueue<R, Args...>::IsEmpty() const
{
  return this->Results->GetNextResultId() == this->Tasks->GetNextTaskId();
}

//-----------------------------------------------------------------------------
/** Pops (and discards) results until the queue reports empty. */
template <typename R, typename... Args>
void vtkThreadedTaskQueue<R, Args...>::Flush()
{
  R tmp;
  while (!this->IsEmpty())
  {
    this->Pop(tmp);
  }
}

//=============================================================================
// ** specialization for `void` returns types.
//=============================================================================

//-----------------------------------------------------------------------------
/**
 * Same as the generic constructor, but there are no results to store: after a
 * task completes, `NextResultId` is advanced to at least `task_id + 1` and
 * anyone waiting in `Flush` is notified.
 */
template <typename... Args>
vtkThreadedTaskQueue<void, Args...>::vtkThreadedTaskQueue(std::function<void(Args...)> worker,
  bool strict_ordering, int buffer_size, int max_concurrent_tasks)
  : Worker(std::move(worker))
  , Tasks(new vtkThreadedTaskQueueInternals::TaskQueue<void>(
      std::max(0, strict_ordering ? 0 : buffer_size)))
  , NextResultId(0)
  , NumberOfThreads(max_concurrent_tasks <= 0 ? vtkMultiThreader::GetGlobalDefaultNumberOfThreads()
                                              : max_concurrent_tasks)
  , Threads{ new std::thread[this->NumberOfThreads] }
{
  auto f = [this](int thread_id) {
    vtkLogger::SetThreadName("ttq::worker" + std::to_string(thread_id));
    while (true)
    {
      std::function<void()> task;
      std::uint64_t task_id = 0;
      if (!this->Tasks->Pop(task_id, task))
      {
        break;
      }
      task();

      // NOTE(review): using max() means a fast task with a higher id can make
      // IsEmpty()/Flush() report "done" while a slower lower-id task is still
      // running. Counting completions would fix this but needs extra state in
      // the class declaration (and must account for tasks dropped by the
      // buffer in non-strict mode).
      std::unique_lock<std::mutex> lk(this->NextResultIdMutex);
      this->NextResultId = std::max(static_cast<std::uint64_t>(this->NextResultId), task_id + 1);
      lk.unlock();
      this->ResultsCV.notify_all();
    }
    this->ResultsCV.notify_all();
    // vtkLogF(INFO, "done");
  };

  for (int cc = 0; cc < this->NumberOfThreads; ++cc)
  {
    this->Threads[cc] = std::thread(f, cc);
  }
}

//-----------------------------------------------------------------------------
/** Marks the task queue done and joins all workers; queued tasks still run. */
template <typename... Args>
vtkThreadedTaskQueue<void, Args...>::~vtkThreadedTaskQueue()
{
  this->Tasks->MarkDone();
  for (int cc = 0; cc < this->NumberOfThreads; ++cc)
  {
    this->Threads[cc].join();
  }
}

//-----------------------------------------------------------------------------
/** Queues a call of the worker with `args`; arguments are forwarded into the bound call. */
template <typename... Args>
void vtkThreadedTaskQueue<void, Args...>::Push(Args&&... args)
{
  this->Tasks->Push(std::bind(this->Worker, std::forward<Args>(args)...));
}

//-----------------------------------------------------------------------------
/** True when the completed-task marker has caught up with the next task id. */
template <typename... Args>
bool vtkThreadedTaskQueue<void, Args...>::IsEmpty() const
{
  return this->NextResultId == this->Tasks->GetNextTaskId();
}

//-----------------------------------------------------------------------------
/** Blocks until IsEmpty() holds, waiting on notifications from the workers. */
template <typename... Args>
void vtkThreadedTaskQueue<void, Args...>::Flush()
{
  if (this->IsEmpty())
  {
    return;
  }
  std::unique_lock<std::mutex> lk(this->NextResultIdMutex);
  this->ResultsCV.wait(lk, [this] { return this->IsEmpty(); });
}