You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
346 lines
9.6 KiB
Plaintext
346 lines
9.6 KiB
Plaintext
3 weeks ago
|
/*=========================================================================
|
||
|
|
||
|
Program: Visualization Toolkit
|
||
|
Module: vtkThreadedTaskQueue.txx
|
||
|
|
||
|
Copyright (c) Ken Martin, Will Schroeder, Bill Lorensen
|
||
|
All rights reserved.
|
||
|
See Copyright.txt or http://www.kitware.com/Copyright.htm for details.
|
||
|
|
||
|
This software is distributed WITHOUT ANY WARRANTY; without even
|
||
|
the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
|
||
|
PURPOSE. See the above copyright notice for more information.
|
||
|
|
||
|
=========================================================================*/
|
||
|
#include "vtkLogger.h"
|
||
|
|
||
|
#include "vtkMultiThreader.h"
|
||
|
|
||
|
#include <cassert>
|
||
|
#include <queue>
|
||
|
#include <utility>
|
||
|
|
||
|
//=============================================================================
|
||
|
namespace vtkThreadedTaskQueueInternals
|
||
|
{
|
||
|
|
||
|
template <typename R>
|
||
|
class TaskQueue
|
||
|
{
|
||
|
public:
|
||
|
TaskQueue(int buffer_size)
|
||
|
: Done(false)
|
||
|
, BufferSize(buffer_size)
|
||
|
, NextTaskId(0)
|
||
|
{
|
||
|
}
|
||
|
|
||
|
~TaskQueue() = default;
|
||
|
|
||
|
void MarkDone()
|
||
|
{
|
||
|
this->Done = true;
|
||
|
this->TasksCV.notify_all();
|
||
|
}
|
||
|
|
||
|
std::uint64_t GetNextTaskId() const { return this->NextTaskId; }
|
||
|
|
||
|
void Push(std::function<R()>&& task)
|
||
|
{
|
||
|
if (this->Done)
|
||
|
{
|
||
|
return;
|
||
|
}
|
||
|
else
|
||
|
{
|
||
|
std::lock_guard<std::mutex> lk(this->TasksMutex);
|
||
|
// vtkLogF(INFO, "pushing-task %d", (int)this->NextTaskId);
|
||
|
this->Tasks.push(std::make_pair(this->NextTaskId++, std::move(task)));
|
||
|
while (this->BufferSize > 0 && static_cast<int>(this->Tasks.size()) > this->BufferSize)
|
||
|
{
|
||
|
this->Tasks.pop();
|
||
|
}
|
||
|
}
|
||
|
this->TasksCV.notify_one();
|
||
|
}
|
||
|
|
||
|
bool Pop(std::uint64_t& task_id, std::function<R()>& task)
|
||
|
{
|
||
|
std::unique_lock<std::mutex> lk(this->TasksMutex);
|
||
|
this->TasksCV.wait(lk, [this] { return this->Done || this->Tasks.size() > 0; });
|
||
|
if (this->Tasks.size() > 0)
|
||
|
{
|
||
|
auto task_pair = this->Tasks.front();
|
||
|
// vtkLogF(TRACE, "popping-task %d", (int)task_pair.first);
|
||
|
this->Tasks.pop();
|
||
|
lk.unlock();
|
||
|
|
||
|
task_id = task_pair.first;
|
||
|
task = std::move(task_pair.second);
|
||
|
return true;
|
||
|
}
|
||
|
assert(this->Done);
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
std::atomic_bool Done;
|
||
|
int BufferSize;
|
||
|
std::atomic<std::uint64_t> NextTaskId;
|
||
|
std::queue<std::pair<std::uint64_t, std::function<R()> > > Tasks;
|
||
|
std::mutex TasksMutex;
|
||
|
std::condition_variable TasksCV;
|
||
|
};
|
||
|
|
||
|
//=============================================================================
|
||
|
template <typename R>
|
||
|
class ResultQueue
|
||
|
{
|
||
|
public:
|
||
|
ResultQueue(bool strict_ordering)
|
||
|
: NextResultId(0)
|
||
|
, StrictOrdering(strict_ordering)
|
||
|
{
|
||
|
}
|
||
|
|
||
|
~ResultQueue() = default;
|
||
|
|
||
|
std::uint64_t GetNextResultId() const { return this->NextResultId; }
|
||
|
|
||
|
void Push(std::uint64_t task_id, const R&& result)
|
||
|
{
|
||
|
std::unique_lock<std::mutex> lk(this->ResultsMutex);
|
||
|
// don't save this result if it's obsolete.
|
||
|
if (task_id >= this->NextResultId)
|
||
|
{
|
||
|
this->Results.push(std::make_pair(task_id, std::move(result)));
|
||
|
}
|
||
|
lk.unlock();
|
||
|
this->ResultsCV.notify_one();
|
||
|
}
|
||
|
|
||
|
bool TryPop(R& result)
|
||
|
{
|
||
|
std::unique_lock<std::mutex> lk(this->ResultsMutex);
|
||
|
if (this->Results.empty() ||
|
||
|
(this->StrictOrdering && this->Results.top().first != this->NextResultId))
|
||
|
{
|
||
|
// results are not available or of strict-ordering is requested, the
|
||
|
// result available is not the next one in sequence, hence don't pop
|
||
|
// anything.
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
auto result_pair = this->Results.top();
|
||
|
this->NextResultId = (result_pair.first + 1);
|
||
|
this->Results.pop();
|
||
|
lk.unlock();
|
||
|
|
||
|
result = std::move(result_pair.second);
|
||
|
return true;
|
||
|
}
|
||
|
|
||
|
bool Pop(R& result)
|
||
|
{
|
||
|
std::unique_lock<std::mutex> lk(this->ResultsMutex);
|
||
|
this->ResultsCV.wait(lk, [this] {
|
||
|
return !this->Results.empty() &&
|
||
|
(!this->StrictOrdering || this->Results.top().first == this->NextResultId);
|
||
|
});
|
||
|
lk.unlock();
|
||
|
return this->TryPop(result);
|
||
|
}
|
||
|
|
||
|
private:
|
||
|
template <typename T>
|
||
|
struct Comparator
|
||
|
{
|
||
|
bool operator()(const T& left, const T& right) const { return left.first > right.first; }
|
||
|
};
|
||
|
std::priority_queue<std::pair<std::uint64_t, R>, std::vector<std::pair<std::uint64_t, R> >,
|
||
|
Comparator<std::pair<std::uint64_t, R> > >
|
||
|
Results;
|
||
|
std::mutex ResultsMutex;
|
||
|
std::condition_variable ResultsCV;
|
||
|
std::atomic<std::uint64_t> NextResultId;
|
||
|
bool StrictOrdering;
|
||
|
};
|
||
|
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename R, typename... Args>
|
||
|
vtkThreadedTaskQueue<R, Args...>::vtkThreadedTaskQueue(
|
||
|
std::function<R(Args...)> worker, bool strict_ordering, int buffer_size, int max_concurrent_tasks)
|
||
|
: Worker(worker)
|
||
|
, Tasks(new vtkThreadedTaskQueueInternals::TaskQueue<R>(
|
||
|
std::max(0, strict_ordering ? 0 : buffer_size)))
|
||
|
, Results(new vtkThreadedTaskQueueInternals::ResultQueue<R>(strict_ordering))
|
||
|
, NumberOfThreads(max_concurrent_tasks <= 0 ? vtkMultiThreader::GetGlobalDefaultNumberOfThreads()
|
||
|
: max_concurrent_tasks)
|
||
|
, Threads{ new std::thread[this->NumberOfThreads] }
|
||
|
{
|
||
|
auto f = [this](int thread_id) {
|
||
|
vtkLogger::SetThreadName("ttq::worker" + std::to_string(thread_id));
|
||
|
while (true)
|
||
|
{
|
||
|
std::function<R()> task;
|
||
|
std::uint64_t task_id;
|
||
|
if (this->Tasks->Pop(task_id, task))
|
||
|
{
|
||
|
this->Results->Push(task_id, task());
|
||
|
continue;
|
||
|
}
|
||
|
else
|
||
|
{
|
||
|
break;
|
||
|
}
|
||
|
}
|
||
|
// vtkLogF(INFO, "done");
|
||
|
};
|
||
|
|
||
|
for (int cc = 0; cc < this->NumberOfThreads; ++cc)
|
||
|
{
|
||
|
this->Threads[cc] = std::thread(f, cc);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename R, typename... Args>
|
||
|
vtkThreadedTaskQueue<R, Args...>::~vtkThreadedTaskQueue()
|
||
|
{
|
||
|
this->Tasks->MarkDone();
|
||
|
for (int cc = 0; cc < this->NumberOfThreads; ++cc)
|
||
|
{
|
||
|
this->Threads[cc].join();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename R, typename... Args>
|
||
|
void vtkThreadedTaskQueue<R, Args...>::Push(Args&&... args)
|
||
|
{
|
||
|
this->Tasks->Push(std::bind(this->Worker, args...));
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename R, typename... Args>
|
||
|
bool vtkThreadedTaskQueue<R, Args...>::TryPop(R& result)
|
||
|
{
|
||
|
return this->Results->TryPop(result);
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename R, typename... Args>
|
||
|
bool vtkThreadedTaskQueue<R, Args...>::Pop(R& result)
|
||
|
{
|
||
|
if (this->IsEmpty())
|
||
|
{
|
||
|
return false;
|
||
|
}
|
||
|
|
||
|
return this->Results->Pop(result);
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename R, typename... Args>
|
||
|
bool vtkThreadedTaskQueue<R, Args...>::IsEmpty() const
|
||
|
{
|
||
|
return this->Results->GetNextResultId() == this->Tasks->GetNextTaskId();
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename R, typename... Args>
|
||
|
void vtkThreadedTaskQueue<R, Args...>::Flush()
|
||
|
{
|
||
|
R tmp;
|
||
|
while (!this->IsEmpty())
|
||
|
{
|
||
|
this->Pop(tmp);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//=============================================================================
|
||
|
// ** specialization for `void` returns types.
|
||
|
//=============================================================================
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename... Args>
|
||
|
vtkThreadedTaskQueue<void, Args...>::vtkThreadedTaskQueue(std::function<void(Args...)> worker,
|
||
|
bool strict_ordering, int buffer_size, int max_concurrent_tasks)
|
||
|
: Worker(worker)
|
||
|
, Tasks(new vtkThreadedTaskQueueInternals::TaskQueue<void>(
|
||
|
std::max(0, strict_ordering ? 0 : buffer_size)))
|
||
|
, NextResultId(0)
|
||
|
, NumberOfThreads(max_concurrent_tasks <= 0 ? vtkMultiThreader::GetGlobalDefaultNumberOfThreads()
|
||
|
: max_concurrent_tasks)
|
||
|
, Threads{ new std::thread[this->NumberOfThreads] }
|
||
|
{
|
||
|
auto f = [this](int thread_id) {
|
||
|
vtkLogger::SetThreadName("ttq::worker" + std::to_string(thread_id));
|
||
|
while (true)
|
||
|
{
|
||
|
std::function<void()> task;
|
||
|
std::uint64_t task_id;
|
||
|
if (this->Tasks->Pop(task_id, task))
|
||
|
{
|
||
|
task();
|
||
|
|
||
|
std::unique_lock<std::mutex> lk(this->NextResultIdMutex);
|
||
|
this->NextResultId = std::max(static_cast<std::uint64_t>(this->NextResultId), task_id + 1);
|
||
|
lk.unlock();
|
||
|
this->ResultsCV.notify_all();
|
||
|
continue;
|
||
|
}
|
||
|
else
|
||
|
{
|
||
|
break;
|
||
|
}
|
||
|
}
|
||
|
this->ResultsCV.notify_all();
|
||
|
// vtkLogF(INFO, "done");
|
||
|
};
|
||
|
|
||
|
for (int cc = 0; cc < this->NumberOfThreads; ++cc)
|
||
|
{
|
||
|
this->Threads[cc] = std::thread(f, cc);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename... Args>
|
||
|
vtkThreadedTaskQueue<void, Args...>::~vtkThreadedTaskQueue()
|
||
|
{
|
||
|
this->Tasks->MarkDone();
|
||
|
for (int cc = 0; cc < this->NumberOfThreads; ++cc)
|
||
|
{
|
||
|
this->Threads[cc].join();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename... Args>
|
||
|
void vtkThreadedTaskQueue<void, Args...>::Push(Args&&... args)
|
||
|
{
|
||
|
this->Tasks->Push(std::bind(this->Worker, args...));
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename... Args>
|
||
|
bool vtkThreadedTaskQueue<void, Args...>::IsEmpty() const
|
||
|
{
|
||
|
return this->NextResultId == this->Tasks->GetNextTaskId();
|
||
|
}
|
||
|
|
||
|
//-----------------------------------------------------------------------------
|
||
|
template <typename... Args>
|
||
|
void vtkThreadedTaskQueue<void, Args...>::Flush()
|
||
|
{
|
||
|
if (this->IsEmpty())
|
||
|
{
|
||
|
return;
|
||
|
}
|
||
|
std::unique_lock<std::mutex> lk(this->NextResultIdMutex);
|
||
|
this->ResultsCV.wait(lk, [this] { return this->IsEmpty(); });
|
||
|
}
|