From 3dd717aca8ea68eea74e2eaf93f45247e4840969 Mon Sep 17 00:00:00 2001 From: Connor McLaughlin Date: Sun, 6 Sep 2020 21:07:03 +1000 Subject: [PATCH] Common: Add memory arena and page fault handler classes --- src/common/CMakeLists.txt | 9 ++ src/common/common.vcxproj | 6 +- src/common/common.vcxproj.filters | 6 +- src/common/memory_arena.cpp | 213 ++++++++++++++++++++++++++++++ src/common/memory_arena.h | 58 ++++++++ src/common/page_fault_handler.cpp | 185 ++++++++++++++++++++++++++ src/common/page_fault_handler.h | 18 +++ 7 files changed, 493 insertions(+), 2 deletions(-) create mode 100644 src/common/memory_arena.cpp create mode 100644 src/common/memory_arena.h create mode 100644 src/common/page_fault_handler.cpp create mode 100644 src/common/page_fault_handler.h diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index b7600fd75..f224c63d8 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -56,6 +56,10 @@ add_library(common minizip_helpers.h null_audio_stream.cpp null_audio_stream.h + memory_arena.cpp + memory_arena.h + page_fault_handler.cpp + page_fault_handler.h rectangle.h progress_callback.cpp progress_callback.h @@ -180,3 +184,8 @@ if(APPLE AND NOT BUILD_LIBRETRO_CORE) gl/context_agl.h ) endif() + +if(${CMAKE_SYSTEM_NAME} STREQUAL "Linux") + # We need -lrt for shm_unlink + target_link_libraries(common PRIVATE rt) +endif() diff --git a/src/common/common.vcxproj b/src/common/common.vcxproj index f81abf94c..fd08eb5c1 100644 --- a/src/common/common.vcxproj +++ b/src/common/common.vcxproj @@ -70,6 +70,8 @@ + + @@ -130,6 +132,8 @@ + + @@ -543,4 +547,4 @@ - \ No newline at end of file + diff --git a/src/common/common.vcxproj.filters b/src/common/common.vcxproj.filters index 2903f3d34..62268c8f4 100644 --- a/src/common/common.vcxproj.filters +++ b/src/common/common.vcxproj.filters @@ -103,6 +103,8 @@ + + @@ -198,6 +200,8 @@ + + @@ -213,4 +217,4 @@ {642ff5eb-af39-4aab-a42f-6eb8188a11d7} - \ No newline at end of file + diff --git a/src/common/memory_arena.cpp b/src/common/memory_arena.cpp new file mode 100644 index 000000000..cdb52ff11 --- /dev/null +++ b/src/common/memory_arena.cpp @@ -0,0 +1,213 @@ +#include "memory_arena.h" +#include "common/assert.h" +#include "common/log.h" +#include "common/string_util.h" +Log_SetChannel(Common::MemoryArena); + +#if defined(WIN32) +#include "common/windows_headers.h" +#elif defined(__linux__) || defined(__ANDROID__) +#include +#include +#include +#include +#endif + +namespace Common { + +MemoryArena::MemoryArena() = default; + +MemoryArena::~MemoryArena() +{ +#if defined(WIN32) + if (m_file_handle) + CloseHandle(m_file_handle); +#elif defined(__linux__) + if (m_shmem_fd > 0) + close(m_shmem_fd); +#endif +} + +void* MemoryArena::FindBaseAddressForMapping(size_t size) +{ + void* base_address; +#if defined(WIN32) + base_address = VirtualAlloc(nullptr, size, MEM_RESERVE, PAGE_READWRITE); + if (base_address) + VirtualFree(base_address, 0, MEM_RELEASE); +#elif defined(__linux__) + base_address = mmap(nullptr, size, PROT_NONE, MAP_ANON | MAP_PRIVATE, -1, 0); + if (base_address) + munmap(base_address, size); +#elif defined(__ANDROID__) + base_address = mmap(nullptr, size, PROT_NONE, MAP_ANON | MAP_SHARED, -1, 0); + if (base_address) + munmap(base_address, size); +#else + base_address = nullptr; +#endif + + if (!base_address) + { + Log_ErrorPrintf("Failed to get base address for memory mapping of size %zu", size); + return nullptr; + } + + return base_address; +} + +bool MemoryArena::Create(size_t size, bool writable, bool executable) +{ +#if defined(WIN32) + const std::string file_mapping_name = + StringUtil::StdStringFromFormat("common_memory_arena_%zu_%u", size, GetCurrentProcessId()); + + const DWORD protect = (writable ? (executable ? PAGE_EXECUTE_READWRITE : PAGE_READWRITE) : PAGE_READONLY); + m_file_handle = CreateFileMappingA(INVALID_HANDLE_VALUE, nullptr, protect, Truncate32(size >> 32), Truncate32(size), + file_mapping_name.c_str()); + if (!m_file_handle) + { + Log_ErrorPrintf("CreateFileMapping failed: %u", GetLastError()); + return false; + } + + return true; +#elif defined(__linux__) + const std::string file_mapping_name = + StringUtil::StdStringFromFormat("common_memory_arena_%zu_%u", size, static_cast(getpid())); + m_shmem_fd = shm_open(file_mapping_name.c_str(), O_CREAT | O_EXCL | (writable ? O_RDWR : O_RDONLY), 0600); + if (m_shmem_fd < 0) + { + Log_ErrorPrintf("shm_open failed: %d", errno); + return false; + } + + // we're not going to be opening this mapping in other processes, so remove the file + shm_unlink(file_mapping_name.c_str()); + + // ensure it's the correct size + if (ftruncate64(m_shmem_fd, static_cast(size)) < 0) + { + Log_ErrorPrintf("ftruncate64(%zu) failed: %d", size, errno); + return false; + } + + return true; +#else + return false; +#endif +} + +std::optional MemoryArena::CreateView(size_t offset, size_t size, bool writable, bool executable, + void* fixed_address) +{ + void* base_pointer = CreateViewPtr(offset, size, writable, executable, fixed_address); + if (!base_pointer) + return std::nullopt; + + return View(this, base_pointer, offset, size, writable); +} + +void* MemoryArena::CreateViewPtr(size_t offset, size_t size, bool writable, bool executable, + void* fixed_address /*= nullptr*/) +{ + void* base_pointer; +#if defined(WIN32) + const DWORD desired_access = FILE_MAP_READ | (writable ? FILE_MAP_WRITE : 0) | (executable ? FILE_MAP_EXECUTE : 0); + base_pointer = + MapViewOfFileEx(m_file_handle, desired_access, Truncate32(offset >> 32), Truncate32(offset), size, fixed_address); + if (!base_pointer) + return nullptr; +#elif defined(__linux__) + const int flags = (fixed_address != nullptr) ? (MAP_SHARED | MAP_FIXED) : MAP_SHARED; + const int prot = PROT_READ | (writable ? PROT_WRITE : 0) | (executable ? PROT_EXEC : 0); + base_pointer = mmap64(fixed_address, size, prot, flags, m_shmem_fd, static_cast(offset)); + if (base_pointer == reinterpret_cast(-1)) + return nullptr; +#else + return nullptr; +#endif + + m_num_views.fetch_add(1); + return base_pointer; +} + +bool MemoryArena::FlushViewPtr(void* address, size_t size) +{ +#if defined(WIN32) + return FlushViewOfFile(address, size); +#elif defined(__linux__) + return (msync(address, size, 0) >= 0); +#else + return false; +#endif +} + +bool MemoryArena::ReleaseViewPtr(void* address, size_t size) +{ + bool result; +#if defined(WIN32) + result = static_cast(UnmapViewOfFile(address)); +#elif defined(__linux__) + result = (munmap(address, size) >= 0); +#else + result = false; +#endif + + if (!result) + { + Log_ErrorPrintf("Failed to unmap previously-created view at %p", address); + return false; + } + + const size_t prev_count = m_num_views.fetch_sub(1); + Assert(prev_count > 0); + return true; +} + +bool MemoryArena::SetPageProtection(void* address, size_t length, bool readable, bool writable, bool executable) +{ +#if defined(WIN32) + static constexpr DWORD protection_table[2][2][2] = { + {{PAGE_NOACCESS, PAGE_EXECUTE}, {PAGE_WRITECOPY, PAGE_EXECUTE_WRITECOPY}}, + {{PAGE_READONLY, PAGE_EXECUTE_READ}, {PAGE_READWRITE, PAGE_EXECUTE_READWRITE}}}; + + DWORD old_protect; + return static_cast( + VirtualProtect(address, length, protection_table[readable][writable][executable], &old_protect)); +#elif defined(__linux__) || defined(__ANDROID__) + const int prot = (readable ? PROT_READ : 0) | (writable ? PROT_WRITE : 0) | (executable ? PROT_EXEC : 0); + return (mprotect(address, length, prot) >= 0); +#else + return false; +#endif +} + +MemoryArena::View::View(MemoryArena* parent, void* base_pointer, size_t arena_offset, size_t mapping_size, + bool writable) + : m_parent(parent), m_base_pointer(base_pointer), m_arena_offset(arena_offset), m_mapping_size(mapping_size), + m_writable(writable) +{ +} + +MemoryArena::View::View(View&& view) + : m_parent(view.m_parent), m_base_pointer(view.m_base_pointer), m_arena_offset(view.m_arena_offset), + m_mapping_size(view.m_mapping_size) +{ + view.m_parent = nullptr; + view.m_base_pointer = nullptr; + view.m_arena_offset = 0; + view.m_mapping_size = 0; +} + +MemoryArena::View::~View() +{ + if (m_parent) + { + if (m_writable && !m_parent->FlushViewPtr(m_base_pointer, m_mapping_size)) + Panic("Failed to flush previously-created view"); + if (!m_parent->ReleaseViewPtr(m_base_pointer, m_mapping_size)) + Panic("Failed to unmap previously-created view"); + } +} +} // namespace Common diff --git a/src/common/memory_arena.h b/src/common/memory_arena.h new file mode 100644 index 000000000..8e175bd47 --- /dev/null +++ b/src/common/memory_arena.h @@ -0,0 +1,58 @@ +#pragma once +#include "types.h" +#include +#include + +namespace Common { +class MemoryArena +{ +public: + class View + { + public: + View(MemoryArena* parent, void* base_pointer, size_t arena_offset, size_t mapping_size, bool writable); + View(View&& view); + ~View(); + + void* GetBasePointer() const { return m_base_pointer; } + size_t GetArenaOffset() const { return m_arena_offset; } + size_t GetMappingSize() const { return m_mapping_size; } + bool IsWritable() const { return m_writable; } + + private: + MemoryArena* m_parent; + void* m_base_pointer; + size_t m_arena_offset; + size_t m_mapping_size; + bool m_writable; + }; + + MemoryArena(); + ~MemoryArena(); + + static void* FindBaseAddressForMapping(size_t size); + + bool Create(size_t size, bool writable, bool executable); + + std::optional CreateView(size_t offset, size_t size, bool writable, bool executable, + void* fixed_address = nullptr); + + void* CreateViewPtr(size_t offset, size_t size, bool writable, bool executable, void* fixed_address = nullptr); + bool FlushViewPtr(void* address, size_t size); + bool ReleaseViewPtr(void* address, size_t size); + + static bool SetPageProtection(void* address, size_t length, bool readable, bool writable, bool executable); + +private: +#if defined(WIN32) + void* m_file_handle = nullptr; +#elif defined(__linux__) + int m_shmem_fd = -1; +#endif + + std::atomic_size_t m_num_views{0}; + size_t m_size = 0; + bool m_writable = false; + bool m_executable = false; +}; +} // namespace Common diff --git a/src/common/page_fault_handler.cpp b/src/common/page_fault_handler.cpp new file mode 100644 index 000000000..67d3192b2 --- /dev/null +++ b/src/common/page_fault_handler.cpp @@ -0,0 +1,185 @@ +#include "page_fault_handler.h" +#include "common/log.h" +#include +#include +#include +Log_SetChannel(Common::PageFaultHandler); + +#if defined(WIN32) +#include "common/windows_headers.h" +#elif defined(__linux__) || defined(__ANDROID__) +#include +#include +#include +#define USE_SIGSEGV 1 +#endif + +namespace Common::PageFaultHandler { + +struct RegisteredHandler +{ + void* owner; + Callback callback; +}; +static std::vector m_handlers; +static std::mutex m_handler_lock; +static thread_local bool s_in_handler; + +#if defined(WIN32) +static PVOID s_veh_handle; + +static LONG ExceptionHandler(PEXCEPTION_POINTERS exi) +{ + if (exi->ExceptionRecord->ExceptionCode != EXCEPTION_ACCESS_VIOLATION || s_in_handler) + return EXCEPTION_CONTINUE_SEARCH; + + s_in_handler = true; + + void* const exception_pc = reinterpret_cast(exi->ContextRecord->Rip); + void* const exception_address = reinterpret_cast(exi->ExceptionRecord->ExceptionInformation[1]); + bool const is_write = exi->ExceptionRecord->ExceptionInformation[0] == 1; + + std::lock_guard guard(m_handler_lock); + for (const RegisteredHandler& rh : m_handlers) + { + if (rh.callback(exception_pc, exception_address, is_write) == HandlerResult::ContinueExecution) + { + s_in_handler = false; + return EXCEPTION_CONTINUE_EXECUTION; + } + } + + s_in_handler = false; + return EXCEPTION_CONTINUE_SEARCH; +} + +#elif defined(USE_SIGSEGV) + +static struct sigaction s_old_sigsegv_action; + +static void SIGSEGVHandler(int sig, siginfo_t* info, void* ctx) +{ + if ((info->si_code != SEGV_MAPERR && info->si_code != SEGV_ACCERR) || s_in_handler) + return; + + void* const exception_address = reinterpret_cast(info->si_addr); + +#if defined(__x86_64__) + void* const exception_pc = reinterpret_cast(static_cast(ctx)->uc_mcontext.gregs[REG_RIP]); + const bool is_write = (static_cast(ctx)->uc_mcontext.gregs[REG_ERR] & 2) != 0; +#elif defined(__aarch64__) + void* const exception_pc = reinterpret_cast(static_cast(ctx)->uc_mcontext.pc); + const bool is_write = false; +#else + void* const exception_pc = nullptr; + const bool is_write = false; +#endif + + std::lock_guard guard(m_handler_lock); + for (const RegisteredHandler& rh : m_handlers) + { + if (rh.callback(exception_pc, exception_address, is_write) == HandlerResult::ContinueExecution) + { + s_in_handler = false; + return; + } + } + + // call old signal handler + if (s_old_sigsegv_action.sa_flags & SA_SIGINFO) + s_old_sigsegv_action.sa_sigaction(sig, info, ctx); + else if (s_old_sigsegv_action.sa_handler == SIG_DFL) + signal(sig, SIG_DFL); + else if (s_old_sigsegv_action.sa_handler == SIG_IGN) + return; + else + s_old_sigsegv_action.sa_handler(sig); +} + +#endif + +bool InstallHandler(void* owner, Callback callback) +{ + bool was_empty; + { + std::lock_guard guard(m_handler_lock); + if (std::find_if(m_handlers.begin(), m_handlers.end(), + [owner](const RegisteredHandler& rh) { return rh.owner == owner; }) != m_handlers.end()) + { + return false; + } + + was_empty = m_handlers.empty(); + m_handlers.push_back(RegisteredHandler{owner, std::move(callback)}); + } + + if (was_empty) + { +#if defined(WIN32) + s_veh_handle = AddVectoredExceptionHandler(1, ExceptionHandler); + if (!s_veh_handle) + { + Log_ErrorPrint("Failed to add vectored exception handler"); + return false; + } +#elif defined(USE_SIGSEGV) +#if 0 + // TODO: Is this needed? + stack_t signal_stack = {}; + signal_stack.ss_sp = malloc(SIGSTKSZ); + signal_stack.ss_size = SIGSTKSZ; + if (sigaltstack(&signal_stack, nullptr)) + { + Log_ErrorPrintf("signaltstack() failed: %d", errno); + return false; + } +#endif + + struct sigaction sa = {}; + sa.sa_sigaction = SIGSEGVHandler; + sa.sa_flags = SA_SIGINFO; + sigemptyset(&sa.sa_mask); + if (sigaction(SIGSEGV, &sa, &s_old_sigsegv_action) < 0) + { + Log_ErrorPrintf("sigaction() failed: %d", errno); + return false; + } +#else + return false; +#endif + } + + return true; +} + +bool RemoveHandler(void* owner) +{ + std::lock_guard guard(m_handler_lock); + auto it = std::find_if(m_handlers.begin(), m_handlers.end(), + [owner](const RegisteredHandler& rh) { return rh.owner == owner; }); + if (it == m_handlers.end()) + return false; + + m_handlers.erase(it); + + if (m_handlers.empty()) + { +#if defined(WIN32) + RemoveVectoredExceptionHandler(s_veh_handle); + s_veh_handle = nullptr; +#else + // restore old signal handler + if (sigaction(SIGSEGV, &s_old_sigsegv_action, nullptr) < 0) + { + Log_ErrorPrintf("sigaction() failed: %d", errno); + return false; + } + + s_old_sigsegv_action = {}; +#endif + } + + return true; +} + +} // namespace Common::PageFaultHandler diff --git a/src/common/page_fault_handler.h b/src/common/page_fault_handler.h new file mode 100644 index 000000000..b2c4f9040 --- /dev/null +++ b/src/common/page_fault_handler.h @@ -0,0 +1,18 @@ +#pragma once +#include "types.h" +#include + +namespace Common::PageFaultHandler { +enum class HandlerResult +{ + ContinueExecution, + ExecuteNextHandler, +}; + +using Callback = std::function; +using Handle = void*; + +bool InstallHandler(void* owner, Callback callback); +bool RemoveHandler(void* owner); + +} // namespace Common::PageFaultHandler