[Base] Refactor POSIX timers, fix user-after-free

Since timer_delete does not clean up already queued signals, signal info
data needs to be retained after timer deletion and object destruction in
order to circumvent use-after-free bugs.
This commit is contained in:
Joel Linn 2022-03-01 23:41:03 +01:00 committed by Rick Gibbed
parent 257b904a5e
commit b72ab7b4a4
1 changed files with 131 additions and 44 deletions

View File

@ -10,7 +10,7 @@
#include "xenia/base/threading.h" #include "xenia/base/threading.h"
#include "xenia/base/assert.h" #include "xenia/base/assert.h"
#include "xenia/base/logging.h" #include "xenia/base/delay_scheduler.h"
#include "xenia/base/platform.h" #include "xenia/base/platform.h"
#include <pthread.h> #include <pthread.h>
@ -23,7 +23,6 @@
#include <unistd.h> #include <unistd.h>
#include <array> #include <array>
#include <cstddef> #include <cstddef>
#include <cstring>
#include <ctime> #include <ctime>
#include <memory> #include <memory>
@ -79,6 +78,47 @@ void AndroidShutdown() {
} }
#endif #endif
// This is separately allocated for each (`HighResolution`)`Timer` object. It
// will be cleaned up some time (`timers_garbage_collector_delay`) after the
// posix timer was canceled because posix `timer_delete(...)` does not remove
// pending timer signals.
// https://stackoverflow.com/questions/49756114/linux-timer-pending-signal
struct timer_callback_info_t {
std::atomic_bool disarmed;
#if !XE_HAS_SIGEV_THREAD_ID
pthread_t target_thread;
#endif
std::function<void()> callback;
void* userdata;
timer_callback_info_t(std::function<void()> callback)
: disarmed(false),
#if !XE_HAS_SIGEV_THREAD_ID
target_thread(),
#endif
callback(callback),
userdata(nullptr) {
}
};
// GC for timer signal info structs:
constexpr uint_fast8_t timers_garbage_collector_scale_ =
#if XE_HAS_SIGEV_THREAD_ID
1;
#else
2;
#endif
DelayScheduler<timer_callback_info_t> timers_garbage_collector_(
512 * timers_garbage_collector_scale_,
[](timer_callback_info_t* info) {
assert_not_null(info);
delete info;
},
true);
// Delay we have to assume it takes to clear all pending signals (maximum):
constexpr auto timers_garbage_collector_delay =
std::chrono::milliseconds(100 * timers_garbage_collector_scale_);
template <typename _Rep, typename _Period> template <typename _Rep, typename _Period>
inline timespec DurationToTimeSpec( inline timespec DurationToTimeSpec(
std::chrono::duration<_Rep, _Period> duration) { std::chrono::duration<_Rep, _Period> duration) {
@ -195,9 +235,20 @@ bool SetTlsValue(TlsHandle handle, uintptr_t value) {
class PosixHighResolutionTimer : public HighResolutionTimer { class PosixHighResolutionTimer : public HighResolutionTimer {
public: public:
explicit PosixHighResolutionTimer(std::function<void()> callback) explicit PosixHighResolutionTimer(std::function<void()> callback)
: callback_(std::move(callback)), valid_(false) {} : valid_(false) {
callback_info_ = new timer_callback_info_t(std::move(callback));
}
~PosixHighResolutionTimer() override { ~PosixHighResolutionTimer() override {
if (valid_) timer_delete(timer_); if (valid_) {
callback_info_->disarmed = true;
timer_delete(timerid_);
// Deliberately leaks memory when wait queue is full instead of blogs,
// check logs
static_cast<void>(timers_garbage_collector_.TryScheduleAfter(
callback_info_, timers_garbage_collector_delay));
} else {
delete callback_info_;
}
} }
bool Initialize(std::chrono::milliseconds period) { bool Initialize(std::chrono::milliseconds period) {
@ -216,20 +267,23 @@ class PosixHighResolutionTimer : public HighResolutionTimer {
callback_info_->target_thread = pthread_self(); callback_info_->target_thread = pthread_self();
#endif #endif
sev.sigev_signo = GetSystemSignal(SignalType::kHighResolutionTimer); sev.sigev_signo = GetSystemSignal(SignalType::kHighResolutionTimer);
sev.sigev_value.sival_ptr = (void*)&callback_; sev.sigev_value.sival_ptr = callback_info_;
if (timer_create(CLOCK_MONOTONIC, &sev, &timer_) == -1) return false; if (timer_create(CLOCK_MONOTONIC, &sev, &timerid_) == -1) return false;
// Start timer // Start timer
itimerspec its{}; itimerspec its{};
its.it_value = DurationToTimeSpec(period); its.it_value = DurationToTimeSpec(period);
its.it_interval = its.it_value; its.it_interval = its.it_value;
valid_ = timer_settime(timer_, 0, &its, nullptr) != -1; valid_ = timer_settime(timerid_, 0, &its, nullptr) != -1;
if (!valid_) {
timer_delete(timerid_);
}
return valid_; return valid_;
} }
private: private:
std::function<void()> callback_; timer_callback_info_t* callback_info_;
timer_t timer_; timer_t timerid_;
bool valid_; // all values for timer_t are legal so we need this bool valid_; // all values for timer_t are legal so we need this
}; };
@ -441,28 +495,34 @@ template <>
class PosixCondition<Timer> : public PosixConditionBase { class PosixCondition<Timer> : public PosixConditionBase {
public: public:
explicit PosixCondition(bool manual_reset) explicit PosixCondition(bool manual_reset)
: callback_(), : timer_(nullptr),
timer_(nullptr), callback_info_(nullptr),
signal_(false), signal_(false),
manual_reset_(manual_reset) {} manual_reset_(manual_reset) {}
virtual ~PosixCondition() { Cancel(); } virtual ~PosixCondition() { Cancel(); }
bool Signal() override { bool Signal() override {
CompletionRoutine(); std::lock_guard<std::mutex> lock(mutex_);
signal_ = true;
cond_.notify_all();
return true; return true;
} }
// TODO(bwrsandman): due_times of under 1ms deadlock under travis // TODO(bwrsandman): due_times of under 1ms deadlock under travis
// TODO(joellinn): This is likely due to deadlock on mutex_ if Signal() is
// called from signal_handler running in Thread A while thread A was still in
// Set(...) routine inside the lock
bool Set(std::chrono::nanoseconds due_time, std::chrono::milliseconds period, bool Set(std::chrono::nanoseconds due_time, std::chrono::milliseconds period,
std::function<void()> opt_callback = nullptr) { std::function<void()> opt_callback = nullptr) {
std::lock_guard<std::mutex> lock(mutex_); Cancel();
callback_ = std::move(opt_callback); std::lock_guard<std::mutex> lock(mutex_);
callback_info_ = new timer_callback_info_t(std::move(opt_callback));
callback_info_->userdata = this;
signal_ = false; signal_ = false;
// Create timer // Create timer
if (timer_ == nullptr) {
sigevent sev{}; sigevent sev{};
#if XE_HAS_SIGEV_THREAD_ID #if XE_HAS_SIGEV_THREAD_ID
sev.sigev_notify = SIGEV_SIGNAL | SIGEV_THREAD_ID; sev.sigev_notify = SIGEV_SIGNAL | SIGEV_THREAD_ID;
@ -472,8 +532,10 @@ class PosixCondition<Timer> : public PosixConditionBase {
callback_info_->target_thread = pthread_self(); callback_info_->target_thread = pthread_self();
#endif #endif
sev.sigev_signo = GetSystemSignal(SignalType::kTimer); sev.sigev_signo = GetSystemSignal(SignalType::kTimer);
sev.sigev_value.sival_ptr = this; sev.sigev_value.sival_ptr = callback_info_;
if (timer_create(CLOCK_MONOTONIC, &sev, &timer_) == -1) return false; if (timer_create(CLOCK_MONOTONIC, &sev, &timer_) == -1) {
delete callback_info_;
return false;
} }
// Start timer // Start timer
@ -483,26 +545,16 @@ class PosixCondition<Timer> : public PosixConditionBase {
return timer_settime(timer_, 0, &its, nullptr) == 0; return timer_settime(timer_, 0, &its, nullptr) == 0;
} }
void CompletionRoutine() {
// As the callback may reset the timer, store local.
std::function<void()> callback;
{
std::lock_guard<std::mutex> lock(mutex_);
// Store callback
if (callback_) callback = callback_;
signal_ = true;
cond_.notify_all();
}
// Call callback
if (callback) callback();
}
bool Cancel() { bool Cancel() {
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock(mutex_);
bool result = true; bool result = true;
if (timer_) { if (timer_) {
callback_info_->disarmed = true;
result = timer_delete(timer_) == 0; result = timer_delete(timer_) == 0;
timer_ = nullptr; timer_ = nullptr;
static_cast<void>(timers_garbage_collector_.TryScheduleAfter(
callback_info_, timers_garbage_collector_delay));
callback_info_ = nullptr;
} }
return result; return result;
} }
@ -518,8 +570,8 @@ class PosixCondition<Timer> : public PosixConditionBase {
signal_ = false; signal_ = false;
} }
} }
std::function<void()> callback_;
timer_t timer_; timer_t timer_;
timer_callback_info_t* callback_info_;
volatile bool signal_; volatile bool signal_;
const bool manual_reset_; const bool manual_reset_;
}; };
@ -1202,15 +1254,50 @@ static void signal_handler(int signal, siginfo_t* info, void* /*context*/) {
switch (GetSystemSignalType(signal)) { switch (GetSystemSignalType(signal)) {
case SignalType::kHighResolutionTimer: { case SignalType::kHighResolutionTimer: {
assert_not_null(info->si_value.sival_ptr); assert_not_null(info->si_value.sival_ptr);
auto callback = auto timer_info =
*static_cast<std::function<void()>*>(info->si_value.sival_ptr); reinterpret_cast<timer_callback_info_t*>(info->si_value.sival_ptr);
callback(); if (!timer_info->disarmed) {
#if XE_HAS_SIGEV_THREAD_ID
{
#else
if (pthread_self() != timer_info->target_thread) {
sigval info_inner{};
info_inner.sival_ptr = timer_info;
const auto queueres = pthread_sigqueue(
timer_info->target_thread,
GetSystemSignal(SignalType::kHighResolutionTimer), info_inner);
assert_zero(queueres);
} else {
#endif
timer_info->callback();
}
}
} break; } break;
case SignalType::kTimer: { case SignalType::kTimer: {
assert_not_null(info->si_value.sival_ptr); assert_not_null(info->si_value.sival_ptr);
auto pTimer = auto timer_info =
static_cast<PosixCondition<Timer>*>(info->si_value.sival_ptr); reinterpret_cast<timer_callback_info_t*>(info->si_value.sival_ptr);
pTimer->CompletionRoutine(); if (!timer_info->disarmed) {
assert_not_null(timer_info->userdata);
auto timer = static_cast<PosixCondition<Timer>*>(timer_info->userdata);
#if XE_HAS_SIGEV_THREAD_ID
{
#else
if (pthread_self() != timer_info->target_thread) {
sigval info_inner{};
info_inner.sival_ptr = timer_info;
const auto queueres =
pthread_sigqueue(timer_info->target_thread,
GetSystemSignal(SignalType::kTimer), info_inner);
assert_zero(queueres);
} else {
#endif
timer->Signal();
if (timer_info->callback) {
timer_info->callback();
}
}
}
} break; } break;
case SignalType::kThreadSuspend: { case SignalType::kThreadSuspend: {
assert_not_null(current_thread_); assert_not_null(current_thread_);