diff --git a/src/xenia/base/testing/threading_test.cc b/src/xenia/base/testing/threading_test.cc index 2d355da42..ad663812f 100644 --- a/src/xenia/base/testing/threading_test.cc +++ b/src/xenia/base/testing/threading_test.cc @@ -849,6 +849,41 @@ TEST_CASE("Test Suspending Thread", "Thread") { thread->Resume(); result = threading::Wait(thread.get(), false, 50ms); REQUIRE(result == threading::WaitResult::kSuccess); + + // Test recursive suspend + thread = threading::Thread::Create(params, func); + thread->Suspend(); + thread->Suspend(); + result = threading::Wait(thread.get(), false, 50ms); + REQUIRE(result == threading::WaitResult::kTimeout); + thread->Resume(); + result = threading::Wait(thread.get(), false, 50ms); + REQUIRE(result == threading::WaitResult::kTimeout); + thread->Resume(); + result = threading::Wait(thread.get(), false, 50ms); + REQUIRE(result == threading::WaitResult::kSuccess); + + // Test suspend count + uint32_t suspend_count = 0; + thread = threading::Thread::Create(params, func); + thread->Suspend(&suspend_count); + REQUIRE(suspend_count == 0); + thread->Suspend(&suspend_count); + REQUIRE(suspend_count == 1); + thread->Suspend(&suspend_count); + REQUIRE(suspend_count == 2); + thread->Resume(&suspend_count); + REQUIRE(suspend_count == 3); + thread->Resume(&suspend_count); + REQUIRE(suspend_count == 2); + thread->Resume(&suspend_count); + REQUIRE(suspend_count == 1); + thread->Suspend(&suspend_count); + REQUIRE(suspend_count == 0); + thread->Resume(&suspend_count); + REQUIRE(suspend_count == 1); + result = threading::Wait(thread.get(), false, 50ms); + REQUIRE(result == threading::WaitResult::kSuccess); } TEST_CASE("Test Thread QueueUserCallback", "Thread") { diff --git a/src/xenia/base/threading.h b/src/xenia/base/threading.h index 790539141..1e10be22b 100644 --- a/src/xenia/base/threading.h +++ b/src/xenia/base/threading.h @@ -389,7 +389,7 @@ class Thread : public WaitHandle { // Decrements a thread's suspend count. When the suspend count is decremented // to zero, the execution of the thread is resumed. - virtual bool Resume(uint32_t* out_new_suspend_count = nullptr) = 0; + virtual bool Resume(uint32_t* out_previous_suspend_count = nullptr) = 0; // Suspends the specified thread. virtual bool Suspend(uint32_t* out_previous_suspend_count = nullptr) = 0; diff --git a/src/xenia/base/threading_posix.cc b/src/xenia/base/threading_posix.cc index 23653a968..21476b544 100644 --- a/src/xenia/base/threading_posix.cc +++ b/src/xenia/base/threading_posix.cc @@ -473,7 +473,8 @@ class PosixCondition : public PosixConditionBase { : thread_(0), signaled_(false), exit_code_(0), - state_(State::kUninitialized) {} + state_(State::kUninitialized), + suspend_count_(0) {} bool Initialize(Thread::CreationParameters params, ThreadStartData* start_data) { start_data->create_suspended = params.create_suspended; @@ -608,21 +609,33 @@ class PosixCondition : public PosixConditionBase { user_callback_(); } - bool Resume(uint32_t* out_new_suspend_count = nullptr) { - // TODO(bwrsandman): implement suspend_count - assert_null(out_new_suspend_count); + bool Resume(uint32_t* out_previous_suspend_count = nullptr) { + if (out_previous_suspend_count) { + *out_previous_suspend_count = 0; + } WaitStarted(); std::unique_lock lock(state_mutex_); if (state_ != State::kSuspended) return false; - state_ = State::kRunning; + if (out_previous_suspend_count) { + *out_previous_suspend_count = suspend_count_; + } + --suspend_count_; state_signal_.notify_all(); return true; } bool Suspend(uint32_t* out_previous_suspend_count = nullptr) { - // TODO(bwrsandman): implement suspend_count - assert_null(out_previous_suspend_count); + if (out_previous_suspend_count) { + *out_previous_suspend_count = 0; + } WaitStarted(); + { + if (out_previous_suspend_count) { + *out_previous_suspend_count = suspend_count_; + } + state_ = State::kSuspended; + ++suspend_count_; + } int result = pthread_kill(thread_, GetSystemSignal(SignalType::kThreadSuspend)); return result == 0; @@ -656,8 +669,8 @@ class PosixCondition : public PosixConditionBase { /// Set state to suspended and wait until it reset by another thread void WaitSuspended() { std::unique_lock lock(state_mutex_); - state_ = State::kSuspended; - state_signal_.wait(lock, [this] { return state_ != State::kSuspended; }); + state_signal_.wait(lock, [this] { return suspend_count_ == 0; }); + state_ = State::kRunning; } private: @@ -673,6 +686,7 @@ class PosixCondition : public PosixConditionBase { bool signaled_; int exit_code_; volatile State state_; + volatile uint32_t suspend_count_; mutable std::mutex state_mutex_; mutable std::mutex callback_mutex_; mutable std::condition_variable state_signal_; @@ -883,8 +897,8 @@ class PosixThread : public PosixConditionHandle { handle_.QueueUserCallback(std::move(callback)); } - bool Resume(uint32_t* out_new_suspend_count) override { - return handle_.Resume(out_new_suspend_count); + bool Resume(uint32_t* out_previous_suspend_count) override { + return handle_.Resume(out_previous_suspend_count); } bool Suspend(uint32_t* out_previous_suspend_count) override { @@ -923,8 +937,9 @@ void* PosixCondition::ThreadStartRoutine(void* parameter) { if (create_suspended) { std::unique_lock lock(thread->handle_.state_mutex_); + thread->handle_.suspend_count_ = 1; thread->handle_.state_signal_.wait( - lock, [thread] { return thread->handle_.state_ != State::kSuspended; }); + lock, [thread] { return thread->handle_.suspend_count_ == 0; }); } start_routine(); diff --git a/src/xenia/base/threading_win.cc b/src/xenia/base/threading_win.cc index 605c2ccbf..6b4e31a99 100644 --- a/src/xenia/base/threading_win.cc +++ b/src/xenia/base/threading_win.cc @@ -388,16 +388,16 @@ class Win32Thread : public Win32Handle { QueueUserAPC(DispatchApc, handle_, reinterpret_cast(apc_data)); } - bool Resume(uint32_t* out_new_suspend_count = nullptr) override { - if (out_new_suspend_count) { - *out_new_suspend_count = 0; + bool Resume(uint32_t* out_previous_suspend_count = nullptr) override { + if (out_previous_suspend_count) { + *out_previous_suspend_count = 0; } DWORD result = ResumeThread(handle_); if (result == UINT_MAX) { return false; } - if (out_new_suspend_count) { - *out_new_suspend_count = result; + if (out_previous_suspend_count) { + *out_previous_suspend_count = result; } return true; }