diff --git a/rpcs3/util/atomic.cpp b/rpcs3/util/atomic.cpp index 64d4babf21..3692369898 100644 --- a/rpcs3/util/atomic.cpp +++ b/rpcs3/util/atomic.cpp @@ -104,8 +104,17 @@ namespace // Find lowest clear bit const auto sema = &sema_data[std::countr_one(bits)]; -#if defined(USE_FUTEX) || defined(_WIN32) +#if defined(USE_FUTEX) sema->release(1); +#elif defined(_WIN32) + if (NtWaitForAlertByThreadId) + { + sema->release(GetCurrentThreadId()); + } + else + { + sema->release(1); + } #endif return sema; @@ -278,13 +287,8 @@ static void slot_free(std::uintptr_t iptr, sync_var* loc, u64 lv = 0) } } -void atomic_storage_futex::wait(const void* data, std::size_t size, u64 old_value, u64 timeout, u64 mask) +SAFE_BUFFERS void atomic_storage_futex::wait(const void* data, std::size_t size, u64 old_value, u64 timeout, u64 mask) { - if (!timeout) - { - return; - } - const std::uintptr_t iptr = reinterpret_cast(data); // Allocated slot index @@ -380,6 +384,11 @@ void atomic_storage_futex::wait(const void* data, std::size_t size, u64 old_valu lv = eq_bits + 1; } +#ifdef _WIN32 + // May be used by NtWaitForAlertByThreadId + u32 thread_id[16]{GetCurrentThreadId()}; +#endif + auto sema = slot->sema_alloc(); while (!sema) @@ -396,9 +405,9 @@ void atomic_storage_futex::wait(const void* data, std::size_t size, u64 old_valu } // Can skip unqueue process if true -#ifdef USE_FUTEX - bool fallback = true; -#else +#if defined(USE_FUTEX) + const bool fallback = true; +#elif defined(_WIN32) bool fallback = false; #endif @@ -428,18 +437,46 @@ void atomic_storage_futex::wait(const void* data, std::size_t size, u64 old_valu qw.QuadPart -= 1; } - if (fallback) + if (NtWaitForAlertByThreadId) { - // Restart waiting - verify(HERE), sema->load() == 2; - sema->release(1); - fallback = false; - } + if (fallback) [[unlikely]] + { + // Restart waiting + if (sema->load() == umax) + { + sema->release(thread_id[0]); + } - if (!NtWaitForKeyedEvent(nullptr, sema, false, timeout + 1 ? &qw : nullptr)) + fallback = false; + } + + // Let's assume it can return spuriously + switch (DWORD status = NtWaitForAlertByThreadId(thread_id, timeout + 1 ? &qw : nullptr)) + { + case NTSTATUS_ALERTED: fallback = true; break; + case NTSTATUS_TIMEOUT: break; + default: + { + SetLastError(status); + fmt::raw_verify_error("Unexpected NtWaitForAlertByThreadId result.", nullptr, 0); + } + } + } + else { - // Error code assumed to be timeout - fallback = true; + if (fallback) + { + // Restart waiting + verify(HERE), sema->load() == 2; + sema->release(1); + fallback = false; + } + + if (!NtWaitForKeyedEvent(nullptr, sema, false, timeout + 1 ? &qw : nullptr)) + { + // Error code assumed to be timeout + fallback = true; + } } #endif @@ -455,6 +492,21 @@ void atomic_storage_futex::wait(const void* data, std::size_t size, u64 old_valu #if defined(_WIN32) static LARGE_INTEGER instant{}; + if (NtWaitForAlertByThreadId) + { + if (sema->compare_and_swap_test(thread_id[0], -1)) + { + break; + } + + if (NtWaitForAlertByThreadId(thread_id, &instant) == NTSTATUS_ALERTED) + { + break; + } + + continue; + } + if (sema->compare_and_swap_test(1, 2)) { // Succeeded in self-notifying @@ -469,6 +521,10 @@ void atomic_storage_futex::wait(const void* data, std::size_t size, u64 old_valu #endif } +#ifdef _WIN32 + verify(HERE), thread_id[0] == GetCurrentThreadId(); +#endif + slot->sema_free(sema); slot_free(iptr, &s_hashtable[iptr % s_hashtable_size]); @@ -487,6 +543,22 @@ static inline bool alert_sema(atomic_t* sema) return true; } #elif defined(_WIN32) + if (NtWaitForAlertByThreadId) + { + u32 tid = sema->load(); + + // Check if tid is neither 0 nor -1 + if (tid + 1 > 1 && sema->compare_and_swap_test(tid, -1)) + { + if (NtAlertThreadByThreadId(tid) == NTSTATUS_SUCCESS) + { + return true; + } + } + + return false; + } + if (sema->load() == 1 && sema->compare_and_swap_test(1, 2)) { // Can wait in rare cases, which is its annoying weakness @@ -548,13 +620,13 @@ void atomic_storage_futex::notify_all(const void* data) } #if defined(_WIN32) && !defined(USE_FUTEX) - if (true) + if (!NtAlertThreadByThreadId) { // Make a copy to filter out waiters that fail some checks u64 copy = slot->sema_bits.load(); // Used for making non-blocking syscall - LARGE_INTEGER instant{}; + static LARGE_INTEGER instant{}; for (u64 bits = copy; bits; bits &= bits - 1) {