Commit 13f30eb0 authored by Xavier Thompson's avatar Xavier Thompson

Refactor and improve scheduler.hpp

- Centralize the logic in scheduler.hpp
- Remove worker.hpp and introduce pool.hpp
- Allow a different number of pools than there are threads
parent ebf375d2
#ifndef TYPON_CORE_POOL_HPP_INCLUDED
#define TYPON_CORE_POOL_HPP_INCLUDED
#include <mutex>
#include <vector>
#include <typon/core/stack.hpp>
namespace typon
{
struct Pool
{
using Vector = std::vector<Stack *>;
using Size = Vector::size_type;
std::mutex _mutex;
Vector _pool;
auto lock_guard() noexcept
{
return std::lock_guard(_mutex);
}
auto adopt_lock_guard() noexcept
{
return std::lock_guard(_mutex, std::adopt_lock);
}
bool try_lock() noexcept
{
return _mutex.try_lock();
}
void add(Stack * stack) noexcept
{
_pool.push_back(stack);
}
auto size() noexcept
{
return _pool.size();
}
Stack * get(Size index) noexcept
{
return _pool[index];
}
void remove(Size index) noexcept
{
_pool[index] = _pool.back();
_pool.pop_back();
}
~Pool()
{
for (auto & stack : _pool)
{
delete stack;
}
}
};
}
#endif // TYPON_CORE_POOL_HPP_INCLUDED
...@@ -2,10 +2,12 @@ ...@@ -2,10 +2,12 @@
#define TYPON_CORE_SCHEDULER_HPP_INCLUDED #define TYPON_CORE_SCHEDULER_HPP_INCLUDED
#include <atomic> #include <atomic>
#include <bit>
#include <coroutine> #include <coroutine>
#include <cstdint> #include <cstdint>
#include <exception> #include <exception>
#include <thread> #include <thread>
#include <utility>
#include <vector> #include <vector>
#include <typon/fundamental/event_count.hpp> #include <typon/fundamental/event_count.hpp>
...@@ -13,8 +15,8 @@ ...@@ -13,8 +15,8 @@
#include <typon/fundamental/random.hpp> #include <typon/fundamental/random.hpp>
#include <typon/core/continuation.hpp> #include <typon/core/continuation.hpp>
#include <typon/core/pool.hpp>
#include <typon/core/stack.hpp> #include <typon/core/stack.hpp>
#include <typon/core/worker.hpp>
namespace typon namespace typon
...@@ -24,8 +26,17 @@ namespace typon ...@@ -24,8 +26,17 @@ namespace typon
{ {
using uint = unsigned int; using uint = unsigned int;
using u64 = std::uint_fast64_t; using u64 = std::uint_fast64_t;
using Work = Worker::Work;
using garbage_collector = fdt::lock_free::garbage_collector; std::atomic_bool _done {false};
const uint _concurrency;
const uint _mask;
std::atomic<uint> _thieves = 0;
std::atomic<u64> _potential = 0;
fdt::lock_free::event_count<> _notifyer;
std::vector<Pool> _pool;
std::vector<Stack *> _stack;
std::vector<std::thread> _thread;
fdt::lock_free::garbage_collector _gc;
static inline thread_local uint thread_id; static inline thread_local uint thread_id;
...@@ -35,123 +46,90 @@ namespace typon ...@@ -35,123 +46,90 @@ namespace typon
return scheduler; return scheduler;
} }
static void schedule(std::coroutine_handle<> task) noexcept static void schedule(std::coroutine_handle<> coroutine) noexcept
{ {
uint id = fdt::random::random() % get()._concurrency; Pool & pool = get().random();
get()._worker[id].add(new Stack(task)); Stack * stack = new Stack(Stack::READY);
{
auto lock_guard = pool.lock_guard();
pool.add(stack);
}
stack->_coroutine = coroutine;
get()._potential.fetch_add(1); get()._potential.fetch_add(1);
get()._notifyer.notify_one(); get()._notifyer.notify_one();
} }
static void push(Continuation task) noexcept static void push(Continuation task) noexcept
{ {
get()._worker[thread_id].push(task); get()._stack[thread_id]->push(task);
} }
static bool pop() noexcept static bool pop() noexcept
{ {
return get()._worker[thread_id].pop(); bool result = get()._stack[thread_id]->pop();
if (auto garbage = get()._stack[thread_id]->reclaim())
{
get()._gc.retire(garbage);
}
return result;
} }
static auto suspend(std::coroutine_handle<> coroutine) noexcept static auto suspend(std::coroutine_handle<> coroutine) noexcept
{ {
Worker & worker = get()._worker[thread_id]; Stack * stack = std::exchange(get()._stack[thread_id], nullptr);
auto stack = worker.suspend(coroutine); stack->_coroutine = coroutine;
uint id = fdt::random::random() % get()._concurrency; stack->_state.store(Stack::WAITING);
get()._worker[id].add(stack);
return stack; return stack;
} }
static void enable(Stack * stack) noexcept static void enable(Stack * stack) noexcept
{ {
auto state = stack->_state.exchange(Stack::Resumable); auto state = stack->_state.exchange(Stack::READY);
if (state == Stack::Empty) if (state == Stack::EMPTY)
{ {
uint id = fdt::random::random() % get()._concurrency; Pool & pool = get().random();
get()._worker[id].add(stack); {
get()._potential.fetch_add(1); auto lock_guard = pool.lock_guard();
pool.add(stack);
} }
get()._potential.fetch_add(1);
get()._notifyer.notify_one(); get()._notifyer.notify_one();
} }
}
std::atomic<uint> _thieves = 0;
std::atomic<u64> _potential = 0;
std::vector<Worker> _worker;
std::vector<std::thread> _thread;
std::atomic_bool _done {false};
fdt::lock_free::event_count<> _notifyer;
const uint _concurrency;
garbage_collector _gc;
Scheduler(uint concurrency) noexcept Scheduler(uint concurrency) noexcept
: _worker(concurrency) : _concurrency(concurrency)
, _concurrency(concurrency) , _mask((1 << std::bit_width(concurrency)) - 1)
, _pool(this->_mask + 1)
, _stack(concurrency, nullptr)
, _gc(concurrency) , _gc(concurrency)
{ {
for (uint i = 0; i < concurrency; i++)
{
_worker[i]._gc = &(_gc);
_worker[i]._potential = &(_potential);
}
thread_id = concurrency;
for (uint id = 0; id < concurrency; id++) for (uint id = 0; id < concurrency; id++)
{ {
_thread.emplace_back([this, id]() { _thread.emplace_back([this, id]() {
thread_id = id; thread_id = id;
Work work {};
for(;;) for(;;)
{ {
if (!wait_for_work(work)) std::coroutine_handle<> coroutine;
if (!wait_for_work(coroutine))
{ {
break; break;
} }
exploit_work(work); coroutine.resume();
} }
}); });
} }
} }
~Scheduler() noexcept bool wait_for_work(std::coroutine_handle<> & coroutine) noexcept
{
_done.store(true);
_notifyer.notify_all();
for (auto & t : _thread)
{
t.join();
}
}
void exploit_work(Work & work) noexcept
{
_worker[thread_id].resume(work);
}
void explore_work(Work & work) noexcept
{
auto epoch = _gc.epoch(thread_id);
for (uint i = 0; i < _concurrency * 2 + 1; i++)
{
uint id = fdt::random::random() % _concurrency;
work = _worker[id].steal();
if (work)
{
break;
}
}
}
bool wait_for_work(Work & work) noexcept
{ {
work = {};
for(;;) for(;;)
{ {
_thieves.fetch_add(1); _thieves.fetch_add(1);
explore_work(work); find_work(coroutine);
if (work) if (coroutine)
{ {
prepare_stack();
if (_thieves.fetch_sub(1) == 1) if (_thieves.fetch_sub(1) == 1)
{ {
_notifyer.notify_one(); _notifyer.notify_one();
...@@ -177,6 +155,107 @@ namespace typon ...@@ -177,6 +155,107 @@ namespace typon
_notifyer.wait(key); _notifyer.wait(key);
} }
} }
void find_work(std::coroutine_handle<> & coroutine) noexcept
{
auto epoch = _gc.epoch(thread_id);
for (uint i = 0; i < _concurrency * 2 + 1; i++)
{
Pool & pool = random();
if (!pool.try_lock())
{
continue;
}
auto lock_guard = pool.adopt_lock_guard();
auto size = pool.size();
if (!size)
{
continue;
}
auto index = size > 1 ? fdt::random::random() % size : 0;
auto stack = pool.get(index);
auto state = stack->_state.load();
if (state == Stack::ACTIVE)
{
if (auto task = stack->steal())
{
task.thefts()++;
coroutine = task;
}
return;
}
if (state == Stack::WAITING)
{
if (auto task = stack->pop_top())
{
task.thefts()++;
coroutine = task;
return;
}
if (stack->_state.compare_exchange_strong(state, Stack::EMPTY))
{
_potential.fetch_sub(1);
pool.remove(index);
return;
}
adopt_stack(stack);
coroutine = stack->_coroutine;
return;
}
if (state == Stack::READY)
{
adopt_stack(stack);
coroutine = stack->_coroutine;
return;
}
if (state == Stack::DONE)
{
pool.remove(index);
delete stack;
return;
}
}
}
void adopt_stack(Stack * stack) noexcept
{
stack->_state.store(Stack::ACTIVE);
if (auto old = std::exchange(_stack[thread_id], stack))
{
_potential.fetch_sub(1);
old->_state.store(Stack::DONE);
}
}
void prepare_stack() noexcept
{
if (!_stack[thread_id])
{
auto stack = _stack[thread_id] = new Stack(Stack::ACTIVE);
Pool & pool = random();
{
auto lock_guard = pool.lock_guard();
pool.add(stack);
}
_potential.fetch_add(1);
_notifyer.notify_one();
}
}
Pool & random() noexcept
{
return _pool[fdt::random::random() & _mask];
}
~Scheduler() noexcept
{
_done.store(true);
_notifyer.notify_all();
for (auto & t : _thread)
{
t.join();
}
}
}; };
} }
......
...@@ -20,19 +20,17 @@ namespace typon ...@@ -20,19 +20,17 @@ namespace typon
using enum std::memory_order; using enum std::memory_order;
enum State : unsigned char { Empty, Suspended, Resumable }; enum State : unsigned char { ACTIVE, WAITING, EMPTY, READY, DONE };
std::atomic<u64> _top {1}; std::atomic<u64> _top {1};
std::atomic<u64> _bottom {1}; std::atomic<u64> _bottom {1};
std::atomic<ring_buffer *> _buffer { new ring_buffer(3) }; std::atomic<ring_buffer *> _buffer;
std::coroutine_handle<> _coroutine;
std::atomic<State> _state; std::atomic<State> _state;
std::coroutine_handle<> _coroutine;
Stack() noexcept {} Stack(State state) noexcept
: _buffer(new ring_buffer(3))
Stack(std::coroutine_handle<> coroutine) noexcept , _state(state)
: _coroutine(coroutine)
, _state(Resumable)
{} {}
~Stack() ~Stack()
...@@ -98,24 +96,17 @@ namespace typon ...@@ -98,24 +96,17 @@ namespace typon
u64 top = _top.load(relaxed); u64 top = _top.load(relaxed);
u64 bottom = _bottom.load(relaxed); u64 bottom = _bottom.load(relaxed);
auto buffer = _buffer.load(relaxed); auto buffer = _buffer.load(relaxed);
Continuation x { nullptr };
if (top < bottom) if (top < bottom)
{ {
Continuation x = buffer->get(top); x = buffer->get(top);
_top.store(top + 1, relaxed); _top.store(top + 1, relaxed);
return x;
} }
return { nullptr }; if (auto garbage = reclaim())
}
void suspend(std::coroutine_handle<> coroutine) noexcept
{ {
_state.store(Suspended); delete garbage;
_coroutine = coroutine;
} }
return x;
void resume() noexcept
{
_coroutine.resume();
} }
ring_buffer * reclaim() noexcept ring_buffer * reclaim() noexcept
......
#ifndef TYPON_CORE_WORKER_HPP_INCLUDED
#define TYPON_CORE_WORKER_HPP_INCLUDED
#include <atomic>
#include <coroutine>
#include <mutex>
#include <type_traits>
#include <utility>
#include <variant>
#include <vector>
#include <typon/fundamental/garbage_collector.hpp>
#include <typon/fundamental/random.hpp>
#include <typon/core/continuation.hpp>
namespace typon
{
struct Worker
{
struct Work
{
static_assert(std::is_trivially_destructible_v<Continuation>);
enum State : char { Empty, Resumable, Stolen };
State _state;
union
{
Stack * _stack;
Continuation _task;
};
Work() noexcept : _state(Empty) {}
Work(Stack * stack) noexcept : _state(Resumable), _stack(stack) {}
Work(Continuation task) noexcept : _state(Stolen), _task(task) {}
operator bool() noexcept
{
return _state != Empty;
}
};
std::mutex _mutex;
std::atomic<Stack *> _stack {nullptr};
std::vector<Stack *> _pool;
std::atomic_uint_fast64_t * _potential;
fdt::lock_free::garbage_collector * _gc;
~Worker()
{
for (auto & stack : _pool)
{
delete stack;
}
if (auto stack = _stack.load())
{
delete stack;
}
}
void add(Stack * stack) noexcept
{
std::lock_guard lock(_mutex);
_pool.push_back(stack);
}
bool try_add(Stack * stack) noexcept
{
if (!_mutex.try_lock())
{
return false;
}
std::lock_guard lock(_mutex, std::adopt_lock);
_pool.push_back(stack);
return true;
}
auto suspend(std::coroutine_handle<> coroutine) noexcept
{
auto stack = _stack.load();
_stack.store(nullptr);
stack->suspend(coroutine);
return stack;
}
void resume(Work & work) noexcept
{
if (work._state == Work::Resumable)
{
auto stack = _stack.load();
_stack.store(work._stack);
if (stack)
{
_gc->retire(stack);
}
_potential->fetch_add(1);
work._stack->resume();
}
else
{
if (!_stack.load())
{
_stack.store(new Stack());
}
_potential->fetch_add(1);
work._task.resume();
}
if (_stack.load())
{
_potential->fetch_sub(1);
}
}
void push(Continuation task) noexcept
{
_stack.load()->push(task);
}
bool pop() noexcept
{
Stack * stack = _stack.load();
bool result = stack->pop();
if (auto garbage = stack->reclaim())
{
_gc->retire(garbage);
}
return result;
}
Work steal() noexcept
{
if (!_mutex.try_lock())
{
return {};
}
std::lock_guard lock(_mutex, std::adopt_lock);
auto stack = _stack.load();
auto total = _pool.size() + bool(stack);
if (total == 0)
{
return {};
}
auto index = fdt::random::random64() % total;
if (index == _pool.size())
{
if (auto task = stack->steal())
{
task.thefts()++;
return task;
}
return {};
}
stack = _pool[index];
if (stack->_state.load() == Stack::Resumable)
{
if (index < _pool.size() - 1)
{
_pool[index] = _pool.back();
}
_pool.pop_back();
return stack;
}
auto task = stack->pop_top();
if (auto garbage = stack->reclaim())
{
delete garbage;
}
if (task)
{
task.thefts()++;
return task;
}
if (index < _pool.size() - 1)
{
_pool[index] = _pool.back();
}
_pool.pop_back();
Stack::State expected = Stack::Suspended;
if (!stack->_state.compare_exchange_strong(expected, Stack::Empty))
{
return stack;
}
_potential->fetch_sub(1);
return {};
}
};
}
#endif // TYPON_CORE_WORKER_HPP_INCLUDED
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment