Implemented thread_ctrl::interrupt

This commit is contained in:
Nekotekina 2016-07-16 20:58:42 +03:00
parent 96728a83f6
commit 59433bfcd5
2 changed files with 219 additions and 21 deletions

View File

@ -1609,6 +1609,8 @@ static void prepare_throw_access_violation(x64_context* context, const char* cau
RIP(context) = (u64)std::addressof(throw_access_violation); RIP(context) = (u64)std::addressof(throw_access_violation);
} }
static void _handle_interrupt(x64_context* ctx);
#ifdef _WIN32 #ifdef _WIN32
static LONG exception_handler(PEXCEPTION_POINTERS pExp) static LONG exception_handler(PEXCEPTION_POINTERS pExp)
@ -1698,6 +1700,11 @@ static void signal_handler(int sig, siginfo_t* info, void* uct)
{ {
x64_context* context = (ucontext_t*)uct; x64_context* context = (ucontext_t*)uct;
if (sig == SIGUSR1)
{
return _handle_interrupt(context);
}
#ifdef __APPLE__ #ifdef __APPLE__
const bool is_writing = context->uc_mcontext->__es.__err & 0x2; const bool is_writing = context->uc_mcontext->__es.__err & 0x2;
#else #else
@ -1735,7 +1742,15 @@ const bool s_exception_handler_set = []() -> bool
if (::sigaction(SIGSEGV, &sa, NULL) == -1) if (::sigaction(SIGSEGV, &sa, NULL) == -1)
{ {
std::printf("sigaction() failed (0x%x).", errno); std::printf("sigaction(SIGSEGV) failed (0x%x).", errno);
std::abort();
}
sa.sa_sigaction = signal_handler;
if (::sigaction(SIGUSR1, &sa, NULL) == -1)
{
std::printf("sigaction(SIGUSR1) failed (0x%x).", errno);
std::abort(); std::abort();
} }
@ -1767,13 +1782,23 @@ struct thread_ctrl::internal
{ {
std::mutex mutex; std::mutex mutex;
std::condition_variable cond; std::condition_variable cond;
std::condition_variable join; // Allows simultaneous joining std::condition_variable jcv; // Allows simultaneous joining
std::condition_variable icv;
task_stack atexit; task_stack atexit;
std::exception_ptr exception; // Caught exception std::exception_ptr exception; // Stored exception
std::chrono::high_resolution_clock::time_point time_limit; std::chrono::high_resolution_clock::time_point time_limit;
#ifdef _WIN32
DWORD thread_id = 0;
x64_context _context{};
#endif
x64_context* thread_ctx{};
atomic_t<void(*)()> interrupt{}; // Interrupt function
}; };
thread_local thread_ctrl::internal* g_tls_internal = nullptr; thread_local thread_ctrl::internal* g_tls_internal = nullptr;
@ -1804,7 +1829,6 @@ void thread_ctrl::start(const std::shared_ptr<thread_ctrl>& ctrl, task_stack tas
} }
catch (...) catch (...)
{ {
ctrl->initialize_once();
ctrl->m_data->exception = std::current_exception(); ctrl->m_data->exception = std::current_exception();
} }
@ -1814,15 +1838,11 @@ void thread_ctrl::start(const std::shared_ptr<thread_ctrl>& ctrl, task_stack tas
void thread_ctrl::wait_start(u64 timeout) void thread_ctrl::wait_start(u64 timeout)
{ {
initialize_once();
m_data->time_limit = std::chrono::high_resolution_clock::now() + std::chrono::microseconds(timeout); m_data->time_limit = std::chrono::high_resolution_clock::now() + std::chrono::microseconds(timeout);
} }
bool thread_ctrl::wait_wait(u64 timeout) bool thread_ctrl::wait_wait(u64 timeout)
{ {
initialize_once();
std::unique_lock<std::mutex> lock(m_data->mutex, std::adopt_lock); std::unique_lock<std::mutex> lock(m_data->mutex, std::adopt_lock);
if (timeout && m_data->cond.wait_until(lock, m_data->time_limit) == std::cv_status::timeout) if (timeout && m_data->cond.wait_until(lock, m_data->time_limit) == std::cv_status::timeout)
@ -1846,11 +1866,12 @@ void thread_ctrl::test()
void thread_ctrl::initialize() void thread_ctrl::initialize()
{ {
initialize_once(); // TODO (temporarily)
// Initialize TLS variable // Initialize TLS variable
g_tls_this_thread = this; g_tls_this_thread = this;
g_tls_internal = this->m_data; g_tls_internal = this->m_data;
#ifdef _WIN32
m_data->thread_id = GetCurrentThreadId();
#endif
g_tls_log_prefix = [] g_tls_log_prefix = []
{ {
@ -1892,6 +1913,10 @@ void thread_ctrl::initialize()
void thread_ctrl::finalize() noexcept void thread_ctrl::finalize() noexcept
{ {
// Disable and discard possible interrupts
interrupt_disable();
test_interrupt();
// TODO // TODO
vm::reservation_free(); vm::reservation_free();
@ -1909,7 +1934,6 @@ void thread_ctrl::finalize() noexcept
void thread_ctrl::push_atexit(task_stack task) void thread_ctrl::push_atexit(task_stack task)
{ {
initialize_once();
m_data->atexit.push(std::move(task)); m_data->atexit.push(std::move(task));
} }
@ -1922,6 +1946,8 @@ thread_ctrl::thread_ctrl(std::string&& name)
#undef new #undef new
new (&m_thread) std::thread; new (&m_thread) std::thread;
#pragma pop_macro("new") #pragma pop_macro("new")
initialize_once();
} }
thread_ctrl::~thread_ctrl() thread_ctrl::~thread_ctrl()
@ -1967,24 +1993,20 @@ void thread_ctrl::join()
// Notify others if necessary // Notify others if necessary
if (UNLIKELY(m_joining.exchange(0x80000000) != 1)) if (UNLIKELY(m_joining.exchange(0x80000000) != 1))
{ {
initialize_once();
// Serialize for reliable notification // Serialize for reliable notification
m_data->mutex.lock(); m_data->mutex.lock();
m_data->mutex.unlock(); m_data->mutex.unlock();
m_data->join.notify_all(); m_data->jcv.notify_all();
} }
} }
else else
{ {
// Hard way // Hard way
initialize_once();
std::unique_lock<std::mutex> lock(m_data->mutex); std::unique_lock<std::mutex> lock(m_data->mutex);
m_data->join.wait(lock, WRAP_EXPR(m_joining >= 0x80000000)); m_data->jcv.wait(lock, WRAP_EXPR(m_joining >= 0x80000000));
} }
if (UNLIKELY(m_data && m_data->exception)) if (UNLIKELY(m_data && m_data->exception && !std::uncaught_exception()))
{ {
std::rethrow_exception(m_data->exception); std::rethrow_exception(m_data->exception);
} }
@ -1992,7 +2014,6 @@ void thread_ctrl::join()
void thread_ctrl::lock() void thread_ctrl::lock()
{ {
initialize_once();
m_data->mutex.lock(); m_data->mutex.lock();
} }
@ -2008,8 +2029,6 @@ void thread_ctrl::lock_notify()
return; return;
} }
initialize_once();
// Serialize for reliable notification, condition is assumed to be changed externally // Serialize for reliable notification, condition is assumed to be changed externally
m_data->mutex.lock(); m_data->mutex.lock();
m_data->mutex.unlock(); m_data->mutex.unlock();
@ -2026,6 +2045,116 @@ void thread_ctrl::set_exception(std::exception_ptr e)
m_data->exception = e; m_data->exception = e;
} }
static void _handle_interrupt(x64_context* ctx)
{
g_tls_internal->thread_ctx = ctx;
thread_ctrl::handle_interrupt();
}
void thread_ctrl::handle_interrupt()
{
const auto _this = g_tls_this_thread;
const auto ctx = g_tls_internal->thread_ctx;
if (_this->m_guard & 0x80000000)
{
// Discard interrupt if interrupts are disabled
if (g_tls_internal->interrupt.exchange(nullptr))
{
_this->lock();
_this->unlock();
g_tls_internal->icv.notify_one();
}
}
else if (_this->m_guard == 0)
{
// Set interrupt immediately if no guard set
if (const auto handler = g_tls_internal->interrupt.exchange(nullptr))
{
_this->lock();
_this->unlock();
g_tls_internal->icv.notify_one();
// Install function call
*--(u64*&)(RSP(ctx)) = RIP(ctx);
RIP(ctx) = (u64)handler;
}
}
else
{
// Set delayed interrupt otherwise
_this->m_guard |= 0x40000000;
}
#ifdef _WIN32
RtlRestoreContext(ctx, nullptr);
#endif
}
void thread_ctrl::interrupt(void(*handler)())
{
VERIFY(this != g_tls_this_thread); // TODO: self-interrupt
VERIFY(m_data->interrupt.compare_and_swap_test(nullptr, handler)); // TODO: multiple interrupts
#ifdef _WIN32
const auto ctx = &m_data->_context;
m_data->thread_ctx = ctx;
const HANDLE nt = OpenThread(THREAD_ALL_ACCESS, FALSE, m_data->thread_id);
VERIFY(nt);
VERIFY(SuspendThread(nt) != -1);
ctx->ContextFlags = CONTEXT_FULL;
VERIFY(GetThreadContext(nt, ctx));
ctx->ContextFlags = CONTEXT_FULL;
const u64 _rip = RIP(ctx);
RIP(ctx) = (u64)std::addressof(thread_ctrl::handle_interrupt);
VERIFY(SetThreadContext(nt, ctx));
RIP(ctx) = _rip;
VERIFY(ResumeThread(nt) != -1);
CloseHandle(nt);
#else
pthread_kill(reinterpret_cast<std::thread&>(m_thread).native_handle(), SIGUSR1);
#endif
std::unique_lock<std::mutex> lock(m_data->mutex, std::adopt_lock);
while (m_data->interrupt)
{
m_data->icv.wait(lock);
}
lock.release();
}
void thread_ctrl::test_interrupt()
{
if (m_guard & 0x80000000)
{
if (m_data->interrupt.exchange(nullptr))
{
lock(), unlock(), m_data->icv.notify_one();
}
return;
}
if (m_guard == 0x40000000 && !std::uncaught_exception())
{
m_guard = 0;
// Execute delayed interrupt handler
if (const auto handler = m_data->interrupt.exchange(nullptr))
{
lock(), unlock(), m_data->icv.notify_one();
return handler();
}
}
}
void thread_ctrl::sleep(u64 useconds) void thread_ctrl::sleep(u64 useconds)
{ {
std::this_thread::sleep_for(std::chrono::microseconds(useconds)); std::this_thread::sleep_for(std::chrono::microseconds(useconds));

View File

@ -94,6 +94,9 @@ private:
// Thread join contention counter // Thread join contention counter
atomic_t<u32> m_joining{}; atomic_t<u32> m_joining{};
// Thread interrupt guard counter
u32 m_guard = 0x80000000;
// Thread internals // Thread internals
atomic_t<internal*> m_data{}; atomic_t<internal*> m_data{};
@ -187,6 +190,42 @@ public:
// Set exception (internal data must be initialized, thread mutex must be locked) // Set exception (internal data must be initialized, thread mutex must be locked)
void set_exception(std::exception_ptr); void set_exception(std::exception_ptr);
// Internal
static void handle_interrupt();
// Interrupt thread with specified handler call (thread mutex must be locked)
void interrupt(void(*handler)());
// Interrupt guard recursive enter
void guard_enter()
{
m_guard++;
}
// Interrupt guard recursive leave
void guard_leave()
{
if (UNLIKELY(--m_guard & 0x40000000))
{
test_interrupt();
}
}
// Allow interrupts
void interrupt_enable()
{
m_guard &= ~0x80000000;
}
// Disable and discard any interrupt
void interrupt_disable()
{
m_guard |= 0x80000000;
}
// Check interrupt if delayed by guard scope
void test_interrupt();
// Current thread sleeps for specified amount of microseconds. // Current thread sleeps for specified amount of microseconds.
// Wrapper for std::this_thread::sleep, doesn't require valid thread_ctrl. // Wrapper for std::this_thread::sleep, doesn't require valid thread_ctrl.
[[deprecated]] static void sleep(u64 useconds); [[deprecated]] static void sleep(u64 useconds);
@ -352,6 +391,36 @@ public:
} }
}; };
// Interrupt guard scope
class thread_guard final
{
thread_ctrl* m_thread;
public:
thread_guard(const thread_guard&) = delete;
thread_guard(thread_ctrl* thread)
: m_thread(thread)
{
m_thread->guard_enter();
}
thread_guard(named_thread& thread)
: thread_guard(thread.operator->())
{
}
thread_guard()
: thread_guard(thread_ctrl::get_current())
{
}
~thread_guard() noexcept(false)
{
m_thread->guard_leave();
}
};
// Wrapper for named thread, joins automatically in the destructor, can only be used in function scope // Wrapper for named thread, joins automatically in the destructor, can only be used in function scope
class scope_thread final class scope_thread final
{ {