diff --git a/src/xenia/base/threading.h b/src/xenia/base/threading.h index 3278f9fcb..62fb5e815 100644 --- a/src/xenia/base/threading.h +++ b/src/xenia/base/threading.h @@ -345,7 +345,7 @@ class Thread : public WaitHandle { // within that thread. static std::unique_ptr Create(CreationParameters params, std::function start_routine); - static std::unique_ptr GetCurrentThread(); + static Thread* GetCurrentThread(); // Ends the calling thread. // No destructors are called, and this function does not return. diff --git a/src/xenia/base/threading_win.cc b/src/xenia/base/threading_win.cc index f6b472cd6..d36ee76e1 100644 --- a/src/xenia/base/threading_win.cc +++ b/src/xenia/base/threading_win.cc @@ -426,10 +426,14 @@ class Win32Thread : public Win32Handle { } }; +thread_local std::unique_ptr current_thread_ = nullptr; + struct ThreadStartData { std::function start_routine; }; DWORD WINAPI ThreadStartRoutine(LPVOID parameter) { + current_thread_ = std::make_unique(::GetCurrentThread()); + auto start_data = reinterpret_cast(parameter); start_data->start_routine(); delete start_data; @@ -449,17 +453,22 @@ std::unique_ptr Thread::Create(CreationParameters params, delete start_data; return nullptr; } - GetThreadId(handle); + return std::make_unique(handle); } -std::unique_ptr Thread::GetCurrentThread() { +Thread* Thread::GetCurrentThread() { + if (current_thread_) { + return current_thread_.get(); + } + HANDLE handle = ::GetCurrentThread(); if (handle == INVALID_HANDLE_VALUE) { return nullptr; } - return std::make_unique(handle); + current_thread_ = std::make_unique(handle); + return current_thread_.get(); } void Thread::Exit(int exit_code) { ExitThread(exit_code); }