Reorganize so that we're calling any memory protectives from the arena

This commit is contained in:
Jared M. White 2024-07-03 06:29:15 -05:00
parent 4b53e9b52b
commit e252c4a729
5 changed files with 25 additions and 18 deletions

View File

@ -113,12 +113,13 @@ public:
void UnmapFromMemoryRegion(void* view, size_t size); void UnmapFromMemoryRegion(void* view, size_t size);
/// ///
/// Write protect a section from the memory region previously mapped by CreateView. /// Virtual protect a section from the memory region previously mapped by CreateView.
/// ///
/// @param data Pointer to data to protect. /// @param data Pointer to data to protect.
/// @param size Size of the protection. /// @param size Size of the protection.
/// @param flag What new permission to protect with.
/// ///
bool WriteProtectMemoryRegion(u8* data, size_t size); bool VirtualProtectMemoryRegion(u8* data, size_t size, u64 flag);
private: private:
#ifdef _WIN32 #ifdef _WIN32

View File

@ -215,11 +215,11 @@ void MemArena::ReleaseMemoryRegion()
} }
} }
bool MemArena::WriteProtectMemoryRegion(u8* data, size_t size) bool MemArena::VirtualProtectMemoryRegion(u8* data, size_t size, u64 flag)
{ {
DWORD lpflOldProtect = 0; DWORD lpflOldProtect = 0;
return static_cast<PVirtualProtect>(m_memory_functions.m_address_VirtualProtect)( return static_cast<PVirtualProtect>(m_memory_functions.m_address_VirtualProtect)(
data, size, PAGE_READONLY, &lpflOldProtect); data, size, flag, &lpflOldProtect);
} }
WindowsMemoryRegion* MemArena::EnsureSplitRegionForMapping(void* start_address, size_t size) WindowsMemoryRegion* MemArena::EnsureSplitRegionForMapping(void* start_address, size_t size)

View File

@ -54,7 +54,12 @@ std::optional<size_t> MemoryManager::GetDirtyPageIndexFromAddress(u64 address)
return (address & ~page_mask) >> 12; return (address & ~page_mask) >> 12;
} }
void MemoryManager::WriteProtectMemory() bool MemoryManager::VirtualProtectMemory(u8* data, size_t size, u64 flag)
{
return m_arena.VirtualProtectMemoryRegion(data, size, flag);
}
void MemoryManager::WriteProtectPhysicalMemoryRegions()
{ {
for (const PhysicalMemoryRegion& region : m_physical_regions) for (const PhysicalMemoryRegion& region : m_physical_regions)
{ {
@ -63,12 +68,11 @@ void MemoryManager::WriteProtectMemory()
size_t page_size = Common::PageSize(); size_t page_size = Common::PageSize();
for (size_t i = 0; i < region.size; i += page_size) for (size_t i = 0; i < region.size; i += page_size)
{ {
bool change_protection = bool change_protection = VirtualProtectMemory((*region.out_pointer) + i, page_size, PAGE_READONLY);
m_arena.WriteProtectMemoryRegion((*region.out_pointer) + i, page_size);
if (!change_protection) if (!change_protection)
{ {
PanicAlertFmt( PanicAlertFmt(
"Memory::Init(): Failed to write protect for this block of memory at 0x{:08X}.", "Memory::WriteProtectPhysicalMemoryRegions(): Failed to write protect for this block of memory at 0x{:08X}.",
reinterpret_cast<u64>(*region.out_pointer)); reinterpret_cast<u64>(*region.out_pointer));
} }
std::optional<size_t> index = std::optional<size_t> index =
@ -105,7 +109,7 @@ void MemoryManager::InitMMIO(bool is_wii)
void MemoryManager::InitDirtyPages() void MemoryManager::InitDirtyPages()
{ {
WriteProtectMemory(); WriteProtectPhysicalMemoryRegions();
} }
void MemoryManager::Init() void MemoryManager::Init()
@ -675,7 +679,7 @@ void MemoryManager::SetPageDirtyBit(uintptr_t address, size_t size, bool dirty)
void MemoryManager::ResetDirtyPages() void MemoryManager::ResetDirtyPages()
{ {
WriteProtectMemory(); WriteProtectPhysicalMemoryRegions();
} }
} // namespace Memory } // namespace Memory

View File

@ -135,6 +135,7 @@ public:
bool IsPageDirty(uintptr_t address); bool IsPageDirty(uintptr_t address);
void SetPageDirtyBit(uintptr_t address, size_t size, bool dirty); void SetPageDirtyBit(uintptr_t address, size_t size, bool dirty);
void ResetDirtyPages(); void ResetDirtyPages();
bool VirtualProtectMemory(u8* data, size_t size, u64 flag);
std::map<u64, u8>& GetDirtyPages() { return m_dirty_pages; } std::map<u64, u8>& GetDirtyPages() { return m_dirty_pages; }
@ -265,7 +266,8 @@ private:
std::map<u64, u8> m_dirty_pages; std::map<u64, u8> m_dirty_pages;
std::optional<size_t> GetDirtyPageIndexFromAddress(u64 address); std::optional<size_t> GetDirtyPageIndexFromAddress(u64 address);
void WriteProtectMemory(); void WriteProtectPhysicalMemoryRegions();
void InitMMIO(bool is_wii); void InitMMIO(bool is_wii);
}; };
} // namespace Memory } // namespace Memory

View File

@ -62,21 +62,21 @@ static LONG NTAPI Handler(PEXCEPTION_POINTERS pPtrs)
// virtual address of the inaccessible data // virtual address of the inaccessible data
uintptr_t fault_address = (uintptr_t)pPtrs->ExceptionRecord->ExceptionInformation[1]; uintptr_t fault_address = (uintptr_t)pPtrs->ExceptionRecord->ExceptionInformation[1];
SContext* ctx = pPtrs->ContextRecord; SContext* ctx = pPtrs->ContextRecord;
if (Core::System::GetInstance().GetJitInterface().HandleFault(fault_address, ctx)) Core::System& system = Core::System::GetInstance();
Memory::MemoryManager& memory = system.GetMemory();
if (system.GetJitInterface().HandleFault(fault_address, ctx))
{ {
return EXCEPTION_CONTINUE_EXECUTION; return EXCEPTION_CONTINUE_EXECUTION;
} }
else if (!Core::System::GetInstance().GetMemory().IsPageDirty(fault_address)) else if (!memory.IsPageDirty(fault_address))
{ {
Core::System::GetInstance().GetMemory().SetPageDirtyBit(fault_address, 1, true); memory.SetPageDirtyBit(fault_address, 1, true);
size_t page_size = Common::PageSize(); size_t page_size = Common::PageSize();
size_t page_mask = page_size - 1; size_t page_mask = page_size - 1;
u64 page_index = fault_address & page_mask; u64 page_index = fault_address & page_mask;
DWORD lpflOldProtect = 0; bool change_protection = memory.VirtualProtectMemory(reinterpret_cast<u8*>(fault_address),
bool change_protection = page_size - page_index, PAGE_READWRITE);
VirtualProtect(reinterpret_cast<u8*>(fault_address), page_size - page_index,
PAGE_READWRITE, &lpflOldProtect);
if (!change_protection) if (!change_protection)
{ {
return EXCEPTION_CONTINUE_SEARCH; return EXCEPTION_CONTINUE_SEARCH;