diff --git a/ARMeilleure/Signal/NativeSignalHandler.cs b/ARMeilleure/Signal/NativeSignalHandler.cs index cad0d4202..0257f4403 100644 --- a/ARMeilleure/Signal/NativeSignalHandler.cs +++ b/ARMeilleure/Signal/NativeSignalHandler.cs @@ -197,12 +197,29 @@ namespace ARMeilleure.Signal // Only call tracking if in range. context.BranchIfFalse(nextLabel, inRange, BasicBlockFrequency.Cold); - context.Copy(inRegionLocal, Const(1)); Operand offset = context.BitwiseAnd(context.Subtract(faultAddress, rangeAddress), Const(~PageMask)); // Call the tracking action, with the pointer's relative offset to the base address. Operand trackingActionPtr = context.Load(OperandType.I64, Const((ulong)signalStructPtr + rangeBaseOffset + 20)); - context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0)); + + context.Copy(inRegionLocal, Const(0)); + + Operand skipActionLabel = Label(); + + // Tracking action should be non-null to call it, otherwise assume false return. + context.BranchIfFalse(skipActionLabel, trackingActionPtr); + Operand result = context.Call(trackingActionPtr, OperandType.I32, offset, Const(PageSize), isWrite, Const(0)); + context.Copy(inRegionLocal, result); + + context.MarkLabel(skipActionLabel); + + // If the tracking action returns false or does not exist, it might be an invalid access due to a partial overlap on Windows. + if (OperatingSystem.IsWindows()) + { + context.BranchIfTrue(endLabel, inRegionLocal); + + context.Copy(inRegionLocal, WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context)); + } context.Branch(endLabel); diff --git a/ARMeilleure/Signal/TestMethods.cs b/ARMeilleure/Signal/TestMethods.cs new file mode 100644 index 000000000..2d7cef166 --- /dev/null +++ b/ARMeilleure/Signal/TestMethods.cs @@ -0,0 +1,84 @@ +using ARMeilleure.IntermediateRepresentation; +using ARMeilleure.Translation; +using System; + +using static ARMeilleure.IntermediateRepresentation.Operand.Factory; + +namespace ARMeilleure.Signal +{ + public struct NativeWriteLoopState + { + public int Running; + public int Error; + } + + public static class TestMethods + { + public delegate bool DebugPartialUnmap(); + public delegate int DebugThreadLocalMapGetOrReserve(int threadId, int initialState); + public delegate void DebugNativeWriteLoop(IntPtr nativeWriteLoopPtr, IntPtr writePtr); + + public static DebugPartialUnmap GenerateDebugPartialUnmap() + { + EmitterContext context = new EmitterContext(); + + var result = WindowsPartialUnmapHandler.EmitRetryFromAccessViolation(context); + + context.Return(result); + + // Compile and return the function. + + ControlFlowGraph cfg = context.GetControlFlowGraph(); + + OperandType[] argTypes = new OperandType[] { OperandType.I64 }; + + return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map(); + } + + public static DebugThreadLocalMapGetOrReserve GenerateDebugThreadLocalMapGetOrReserve(IntPtr structPtr) + { + EmitterContext context = new EmitterContext(); + + var result = WindowsPartialUnmapHandler.EmitThreadLocalMapIntGetOrReserve(context, structPtr, context.LoadArgument(OperandType.I32, 0), context.LoadArgument(OperandType.I32, 1)); + + context.Return(result); + + // Compile and return the function. + + ControlFlowGraph cfg = context.GetControlFlowGraph(); + + OperandType[] argTypes = new OperandType[] { OperandType.I64 }; + + return Compiler.Compile(cfg, argTypes, OperandType.I32, CompilerOptions.HighCq).Map(); + } + + public static DebugNativeWriteLoop GenerateDebugNativeWriteLoop() + { + EmitterContext context = new EmitterContext(); + + // Loop a write to the target address until "running" is false. + + Operand structPtr = context.Copy(context.LoadArgument(OperandType.I64, 0)); + Operand writePtr = context.Copy(context.LoadArgument(OperandType.I64, 1)); + + Operand loopLabel = Label(); + context.MarkLabel(loopLabel); + + context.Store(writePtr, Const(12345)); + + Operand running = context.Load(OperandType.I32, structPtr); + + context.BranchIfTrue(loopLabel, running); + + context.Return(); + + // Compile and return the function. + + ControlFlowGraph cfg = context.GetControlFlowGraph(); + + OperandType[] argTypes = new OperandType[] { OperandType.I64 }; + + return Compiler.Compile(cfg, argTypes, OperandType.None, CompilerOptions.HighCq).Map(); + } + } +} diff --git a/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs b/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs new file mode 100644 index 000000000..941e36e58 --- /dev/null +++ b/ARMeilleure/Signal/WindowsPartialUnmapHandler.cs @@ -0,0 +1,186 @@ +using ARMeilleure.IntermediateRepresentation; +using ARMeilleure.Translation; +using Ryujinx.Common.Memory.PartialUnmaps; +using System; + +using static ARMeilleure.IntermediateRepresentation.Operand.Factory; + +namespace ARMeilleure.Signal +{ + /// + /// Methods to handle signals caused by partial unmaps. See the structs for C# implementations of the methods. + /// + internal static class WindowsPartialUnmapHandler + { + public static Operand EmitRetryFromAccessViolation(EmitterContext context) + { + IntPtr partialRemapStatePtr = PartialUnmapState.GlobalState; + IntPtr localCountsPtr = IntPtr.Add(partialRemapStatePtr, PartialUnmapState.LocalCountsOffset); + + // Get the lock first. + EmitNativeReaderLockAcquire(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset)); + + IntPtr getCurrentThreadId = WindowsSignalHandlerRegistration.GetCurrentThreadIdFunc(); + Operand threadId = context.Call(Const((ulong)getCurrentThreadId), OperandType.I32); + Operand threadIndex = EmitThreadLocalMapIntGetOrReserve(context, localCountsPtr, threadId, Const(0)); + + Operand endLabel = Label(); + Operand retry = context.AllocateLocal(OperandType.I32); + Operand threadIndexValidLabel = Label(); + + context.BranchIfFalse(threadIndexValidLabel, context.ICompareEqual(threadIndex, Const(-1))); + + context.Copy(retry, Const(1)); // Always retry when thread local cannot be allocated. + + context.Branch(endLabel); + + context.MarkLabel(threadIndexValidLabel); + + Operand threadLocalPartialUnmapsPtr = EmitThreadLocalMapIntGetValuePtr(context, localCountsPtr, threadIndex); + Operand threadLocalPartialUnmaps = context.Load(OperandType.I32, threadLocalPartialUnmapsPtr); + Operand partialUnmapsCount = context.Load(OperandType.I32, Const((ulong)IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapsCountOffset))); + + context.Copy(retry, context.ICompareNotEqual(threadLocalPartialUnmaps, partialUnmapsCount)); + + Operand noRetryLabel = Label(); + + context.BranchIfFalse(noRetryLabel, retry); + + // if (retry) { + + context.Store(threadLocalPartialUnmapsPtr, partialUnmapsCount); + + context.Branch(endLabel); + + context.MarkLabel(noRetryLabel); + + // } + + context.MarkLabel(endLabel); + + // Finally, release the lock and return the retry value. + EmitNativeReaderLockRelease(context, IntPtr.Add(partialRemapStatePtr, PartialUnmapState.PartialUnmapLockOffset)); + + return retry; + } + + public static Operand EmitThreadLocalMapIntGetOrReserve(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand initialState) + { + Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap.ThreadIdsOffset)); + + Operand i = context.AllocateLocal(OperandType.I32); + + context.Copy(i, Const(0)); + + // (Loop 1) Check all slots for a matching Thread ID (while also trying to allocate) + + Operand endLabel = Label(); + + Operand loopLabel = Label(); + context.MarkLabel(loopLabel); + + Operand offset = context.Multiply(i, Const(sizeof(int))); + Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset)); + + // Check that this slot has the thread ID. + Operand existingId = context.CompareAndSwap(idPtr, threadId, threadId); + + // If it was already the thread ID, then we just need to return i. + context.BranchIfTrue(endLabel, context.ICompareEqual(existingId, threadId)); + + context.Copy(i, context.Add(i, Const(1))); + + context.BranchIfTrue(loopLabel, context.ICompareLess(i, Const(ThreadLocalMap.MapSize))); + + // (Loop 2) Try take a slot that is 0 with our Thread ID. + + context.Copy(i, Const(0)); // Reset i. + + Operand loop2Label = Label(); + context.MarkLabel(loop2Label); + + Operand offset2 = context.Multiply(i, Const(sizeof(int))); + Operand idPtr2 = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset2)); + + // Try and swap in the thread id on top of 0. + Operand existingId2 = context.CompareAndSwap(idPtr2, Const(0), threadId); + + Operand idNot0Label = Label(); + + // If it was 0, then we need to initialize the struct entry and return i. + context.BranchIfFalse(idNot0Label, context.ICompareEqual(existingId2, Const(0))); + + Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap.StructsOffset)); + Operand structPtr = context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset2)); + context.Store(structPtr, initialState); + + context.Branch(endLabel); + + context.MarkLabel(idNot0Label); + + context.Copy(i, context.Add(i, Const(1))); + + context.BranchIfTrue(loop2Label, context.ICompareLess(i, Const(ThreadLocalMap.MapSize))); + + context.Copy(i, Const(-1)); // Could not place the thread in the list. + + context.MarkLabel(endLabel); + + return context.Copy(i); + } + + private static Operand EmitThreadLocalMapIntGetValuePtr(EmitterContext context, IntPtr threadLocalMapPtr, Operand index) + { + Operand offset = context.Multiply(index, Const(sizeof(int))); + Operand structsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap.StructsOffset)); + + return context.Add(structsPtr, context.SignExtend32(OperandType.I64, offset)); + } + + private static void EmitThreadLocalMapIntRelease(EmitterContext context, IntPtr threadLocalMapPtr, Operand threadId, Operand index) + { + Operand offset = context.Multiply(index, Const(sizeof(int))); + Operand idsPtr = Const((ulong)IntPtr.Add(threadLocalMapPtr, ThreadLocalMap.ThreadIdsOffset)); + Operand idPtr = context.Add(idsPtr, context.SignExtend32(OperandType.I64, offset)); + + context.CompareAndSwap(idPtr, threadId, Const(0)); + } + + private static void EmitAtomicAddI32(EmitterContext context, Operand ptr, Operand additive) + { + Operand loop = Label(); + context.MarkLabel(loop); + + Operand initial = context.Load(OperandType.I32, ptr); + Operand newValue = context.Add(initial, additive); + + Operand replaced = context.CompareAndSwap(ptr, initial, newValue); + + context.BranchIfFalse(loop, context.ICompareEqual(initial, replaced)); + } + + private static void EmitNativeReaderLockAcquire(EmitterContext context, IntPtr nativeReaderLockPtr) + { + Operand writeLockPtr = Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.WriteLockOffset)); + + // Spin until we can acquire the write lock. + Operand spinLabel = Label(); + context.MarkLabel(spinLabel); + + // Old value must be 0 to continue (we gained the write lock) + context.BranchIfTrue(spinLabel, context.CompareAndSwap(writeLockPtr, Const(0), Const(1))); + + // Increment reader count. + EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(1)); + + // Release write lock. + context.CompareAndSwap(writeLockPtr, Const(1), Const(0)); + } + + private static void EmitNativeReaderLockRelease(EmitterContext context, IntPtr nativeReaderLockPtr) + { + // Decrement reader count. + EmitAtomicAddI32(context, Const((ulong)IntPtr.Add(nativeReaderLockPtr, NativeReaderWriterLock.ReaderCountOffset)), Const(-1)); + } + } +} diff --git a/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs b/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs index 959d1c477..513829a6e 100644 --- a/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs +++ b/ARMeilleure/Signal/WindowsSignalHandlerRegistration.cs @@ -1,9 +1,10 @@ using System; +using System.Runtime.CompilerServices; using System.Runtime.InteropServices; namespace ARMeilleure.Signal { - class WindowsSignalHandlerRegistration + unsafe class WindowsSignalHandlerRegistration { [DllImport("kernel32.dll")] private static extern IntPtr AddVectoredExceptionHandler(uint first, IntPtr handler); @@ -11,6 +12,14 @@ namespace ARMeilleure.Signal [DllImport("kernel32.dll")] private static extern ulong RemoveVectoredExceptionHandler(IntPtr handle); + [DllImport("kernel32.dll", SetLastError = true, CharSet = CharSet.Ansi)] + static extern IntPtr LoadLibrary([MarshalAs(UnmanagedType.LPStr)] string lpFileName); + + [DllImport("kernel32.dll", CharSet = CharSet.Ansi, ExactSpelling = true, SetLastError = true)] + private static extern IntPtr GetProcAddress(IntPtr hModule, string procName); + + private static IntPtr _getCurrentThreadIdPtr; + public static IntPtr RegisterExceptionHandler(IntPtr action) { return AddVectoredExceptionHandler(1, action); @@ -20,5 +29,17 @@ namespace ARMeilleure.Signal { return RemoveVectoredExceptionHandler(handle) != 0; } + + public static IntPtr GetCurrentThreadIdFunc() + { + if (_getCurrentThreadIdPtr == IntPtr.Zero) + { + IntPtr handle = LoadLibrary("kernel32.dll"); + + _getCurrentThreadIdPtr = GetProcAddress(handle, "GetCurrentThreadId"); + } + + return _getCurrentThreadIdPtr; + } } } diff --git a/Ryujinx.Common/Memory/PartialUnmaps/NativeReaderWriterLock.cs b/Ryujinx.Common/Memory/PartialUnmaps/NativeReaderWriterLock.cs new file mode 100644 index 000000000..5419b3405 --- /dev/null +++ b/Ryujinx.Common/Memory/PartialUnmaps/NativeReaderWriterLock.cs @@ -0,0 +1,80 @@ +using System.Runtime.InteropServices; +using System.Threading; + +using static Ryujinx.Common.Memory.PartialUnmaps.PartialUnmapHelpers; + +namespace Ryujinx.Common.Memory.PartialUnmaps +{ + /// + /// A simple implementation of a ReaderWriterLock which can be used from native code. + /// + [StructLayout(LayoutKind.Sequential, Pack = 1)] + public struct NativeReaderWriterLock + { + public int WriteLock; + public int ReaderCount; + + public static int WriteLockOffset; + public static int ReaderCountOffset; + + /// + /// Populates the field offsets for use when emitting native code. + /// + static NativeReaderWriterLock() + { + NativeReaderWriterLock instance = new NativeReaderWriterLock(); + + WriteLockOffset = OffsetOf(ref instance, ref instance.WriteLock); + ReaderCountOffset = OffsetOf(ref instance, ref instance.ReaderCount); + } + + /// + /// Acquires the reader lock. + /// + public void AcquireReaderLock() + { + // Must take write lock for a very short time to become a reader. + + while (Interlocked.CompareExchange(ref WriteLock, 1, 0) != 0) { } + + Interlocked.Increment(ref ReaderCount); + + Interlocked.Exchange(ref WriteLock, 0); + } + + /// + /// Releases the reader lock. + /// + public void ReleaseReaderLock() + { + Interlocked.Decrement(ref ReaderCount); + } + + /// + /// Upgrades to a writer lock. The reader lock is temporarily released while obtaining the writer lock. + /// + public void UpgradeToWriterLock() + { + // Prevent any more threads from entering reader. + // If the write lock is already taken, wait for it to not be taken. + + Interlocked.Decrement(ref ReaderCount); + + while (Interlocked.CompareExchange(ref WriteLock, 1, 0) != 0) { } + + // Wait for reader count to drop to 0, then take the lock again as the only reader. + + while (Interlocked.CompareExchange(ref ReaderCount, 1, 0) != 0) { } + } + + /// + /// Downgrades from a writer lock, back to a reader one. + /// + public void DowngradeFromWriterLock() + { + // Release the WriteLock. + + Interlocked.Exchange(ref WriteLock, 0); + } + } +} diff --git a/Ryujinx.Common/Memory/PartialUnmaps/PartialUnmapHelpers.cs b/Ryujinx.Common/Memory/PartialUnmaps/PartialUnmapHelpers.cs new file mode 100644 index 000000000..e650de068 --- /dev/null +++ b/Ryujinx.Common/Memory/PartialUnmaps/PartialUnmapHelpers.cs @@ -0,0 +1,20 @@ +using System.Runtime.CompilerServices; + +namespace Ryujinx.Common.Memory.PartialUnmaps +{ + static class PartialUnmapHelpers + { + /// + /// Calculates a byte offset of a given field within a struct. + /// + /// Struct type + /// Field type + /// Parent struct + /// Field + /// The byte offset of the given field in the given struct + public static int OffsetOf(ref T2 storage, ref T target) + { + return (int)Unsafe.ByteOffset(ref Unsafe.As(ref storage), ref target); + } + } +} diff --git a/Ryujinx.Common/Memory/PartialUnmaps/PartialUnmapState.cs b/Ryujinx.Common/Memory/PartialUnmaps/PartialUnmapState.cs new file mode 100644 index 000000000..3b42e140b --- /dev/null +++ b/Ryujinx.Common/Memory/PartialUnmaps/PartialUnmapState.cs @@ -0,0 +1,160 @@ +using System; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Runtime.Versioning; +using System.Threading; + +using static Ryujinx.Common.Memory.PartialUnmaps.PartialUnmapHelpers; + +namespace Ryujinx.Common.Memory.PartialUnmaps +{ + /// + /// State for partial unmaps. Intended to be used on Windows. + /// + [StructLayout(LayoutKind.Sequential, Pack = 1)] + public struct PartialUnmapState + { + public NativeReaderWriterLock PartialUnmapLock; + public int PartialUnmapsCount; + public ThreadLocalMap LocalCounts; + + public readonly static int PartialUnmapLockOffset; + public readonly static int PartialUnmapsCountOffset; + public readonly static int LocalCountsOffset; + + public readonly static IntPtr GlobalState; + + [SupportedOSPlatform("windows")] + [DllImport("kernel32.dll")] + public static extern int GetCurrentThreadId(); + + [SupportedOSPlatform("windows")] + [DllImport("kernel32.dll", SetLastError = true)] + static extern IntPtr OpenThread(int dwDesiredAccess, bool bInheritHandle, uint dwThreadId); + + [SupportedOSPlatform("windows")] + [DllImport("kernel32.dll", SetLastError = true)] + public static extern bool CloseHandle(IntPtr hObject); + + [SupportedOSPlatform("windows")] + [DllImport("kernel32.dll", SetLastError = true)] + static extern bool GetExitCodeThread(IntPtr hThread, out uint lpExitCode); + + /// + /// Creates a global static PartialUnmapState and populates the field offsets. + /// + static unsafe PartialUnmapState() + { + PartialUnmapState instance = new PartialUnmapState(); + + PartialUnmapLockOffset = OffsetOf(ref instance, ref instance.PartialUnmapLock); + PartialUnmapsCountOffset = OffsetOf(ref instance, ref instance.PartialUnmapsCount); + LocalCountsOffset = OffsetOf(ref instance, ref instance.LocalCounts); + + int size = Unsafe.SizeOf(); + GlobalState = Marshal.AllocHGlobal(size); + Unsafe.InitBlockUnaligned((void*)GlobalState, 0, (uint)size); + } + + /// + /// Resets the global state. + /// + public static unsafe void Reset() + { + int size = Unsafe.SizeOf(); + Unsafe.InitBlockUnaligned((void*)GlobalState, 0, (uint)size); + } + + /// + /// Gets a reference to the global state. + /// + /// A reference to the global state + public static unsafe ref PartialUnmapState GetRef() + { + return ref Unsafe.AsRef((void*)GlobalState); + } + + /// + /// Checks if an access violation handler should retry execution due to a fault caused by partial unmap. + /// + /// + /// Due to Windows limitations, might need to unmap more memory than requested. + /// The additional memory that was unmapped is later remapped, however this leaves a time gap where the + /// memory might be accessed but is unmapped. Users of the API must compensate for that by catching the + /// access violation and retrying if it happened between the unmap and remap operation. + /// This method can be used to decide if retrying in such cases is necessary or not. + /// + /// This version of the function is not used, but serves as a reference for the native + /// implementation in ARMeilleure. + /// + /// True if execution should be retried, false otherwise + [SupportedOSPlatform("windows")] + public bool RetryFromAccessViolation() + { + PartialUnmapLock.AcquireReaderLock(); + + int threadID = GetCurrentThreadId(); + int threadIndex = LocalCounts.GetOrReserve(threadID, 0); + + if (threadIndex == -1) + { + // Out of thread local space... try again later. + + PartialUnmapLock.ReleaseReaderLock(); + + return true; + } + + ref int threadLocalPartialUnmapsCount = ref LocalCounts.GetValue(threadIndex); + + bool retry = threadLocalPartialUnmapsCount != PartialUnmapsCount; + if (retry) + { + threadLocalPartialUnmapsCount = PartialUnmapsCount; + } + + PartialUnmapLock.ReleaseReaderLock(); + + return retry; + } + + /// + /// Iterates and trims threads in the thread -> count map that + /// are no longer active. + /// + [SupportedOSPlatform("windows")] + public void TrimThreads() + { + const uint ExitCodeStillActive = 259; + const int ThreadQueryInformation = 0x40; + + Span ids = LocalCounts.ThreadIds.ToSpan(); + + for (int i = 0; i < ids.Length; i++) + { + int id = ids[i]; + + if (id != 0) + { + IntPtr handle = OpenThread(ThreadQueryInformation, false, (uint)id); + + if (handle == IntPtr.Zero) + { + Interlocked.CompareExchange(ref ids[i], 0, id); + } + else + { + GetExitCodeThread(handle, out uint exitCode); + + if (exitCode != ExitCodeStillActive) + { + Interlocked.CompareExchange(ref ids[i], 0, id); + } + + CloseHandle(handle); + } + } + } + } + } +} diff --git a/Ryujinx.Common/Memory/PartialUnmaps/ThreadLocalMap.cs b/Ryujinx.Common/Memory/PartialUnmaps/ThreadLocalMap.cs new file mode 100644 index 000000000..a3bd5be85 --- /dev/null +++ b/Ryujinx.Common/Memory/PartialUnmaps/ThreadLocalMap.cs @@ -0,0 +1,92 @@ +using System.Runtime.InteropServices; +using System.Threading; + +using static Ryujinx.Common.Memory.PartialUnmaps.PartialUnmapHelpers; + +namespace Ryujinx.Common.Memory.PartialUnmaps +{ + /// + /// A simple fixed size thread safe map that can be used from native code. + /// Integer thread IDs map to corresponding structs. + /// + /// The value type for the map + [StructLayout(LayoutKind.Sequential, Pack = 1)] + public struct ThreadLocalMap where T : unmanaged + { + public const int MapSize = 20; + + public Array20 ThreadIds; + public Array20 Structs; + + public static int ThreadIdsOffset; + public static int StructsOffset; + + /// + /// Populates the field offsets for use when emitting native code. + /// + static ThreadLocalMap() + { + ThreadLocalMap instance = new ThreadLocalMap(); + + ThreadIdsOffset = OffsetOf(ref instance, ref instance.ThreadIds); + StructsOffset = OffsetOf(ref instance, ref instance.Structs); + } + + /// + /// Gets the index of a given thread ID in the map, or reserves one. + /// When reserving a struct, its value is set to the given initial value. + /// Returns -1 when there is no space to reserve a new entry. + /// + /// Thread ID to use as a key + /// Initial value of the associated struct. + /// The index of the entry, or -1 if none + public int GetOrReserve(int threadId, T initial) + { + // Try get a match first. + + for (int i = 0; i < MapSize; i++) + { + int compare = Interlocked.CompareExchange(ref ThreadIds[i], threadId, threadId); + + if (compare == threadId) + { + return i; + } + } + + // Try get a free entry. Since the id is assumed to be unique to this thread, we know it doesn't exist yet. + + for (int i = 0; i < MapSize; i++) + { + int compare = Interlocked.CompareExchange(ref ThreadIds[i], threadId, 0); + + if (compare == 0) + { + Structs[i] = initial; + return i; + } + } + + return -1; + } + + /// + /// Gets the struct value for a given map entry. + /// + /// Index of the entry + /// A reference to the struct value + public ref T GetValue(int index) + { + return ref Structs[index]; + } + + /// + /// Releases an entry from the map. + /// + /// Index of the entry to release + public void Release(int index) + { + Interlocked.Exchange(ref ThreadIds[index], 0); + } + } +} diff --git a/Ryujinx.Cpu/Jit/MemoryManagerHostMapped.cs b/Ryujinx.Cpu/Jit/MemoryManagerHostMapped.cs index 5961e3773..4df29699d 100644 --- a/Ryujinx.Cpu/Jit/MemoryManagerHostMapped.cs +++ b/Ryujinx.Cpu/Jit/MemoryManagerHostMapped.cs @@ -89,10 +89,10 @@ namespace Ryujinx.Cpu.Jit MemoryAllocationFlags asFlags = MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible; _addressSpace = new MemoryBlock(asSize, asFlags); - _addressSpaceMirror = new MemoryBlock(asSize, asFlags | MemoryAllocationFlags.ForceWindows4KBViewMapping); + _addressSpaceMirror = new MemoryBlock(asSize, asFlags); Tracking = new MemoryTracking(this, PageSize, invalidAccessHandler); - _memoryEh = new MemoryEhMeilleure(_addressSpace, Tracking); + _memoryEh = new MemoryEhMeilleure(_addressSpace, _addressSpaceMirror, Tracking); } /// diff --git a/Ryujinx.Cpu/MemoryEhMeilleure.cs b/Ryujinx.Cpu/MemoryEhMeilleure.cs index a82295819..806ef8113 100644 --- a/Ryujinx.Cpu/MemoryEhMeilleure.cs +++ b/Ryujinx.Cpu/MemoryEhMeilleure.cs @@ -6,36 +6,57 @@ using System.Runtime.InteropServices; namespace Ryujinx.Cpu { - class MemoryEhMeilleure : IDisposable + public class MemoryEhMeilleure : IDisposable { private delegate bool TrackingEventDelegate(ulong address, ulong size, bool write, bool precise = false); - private readonly MemoryBlock _addressSpace; private readonly MemoryTracking _tracking; private readonly TrackingEventDelegate _trackingEvent; private readonly ulong _baseAddress; + private readonly ulong _mirrorAddress; - public MemoryEhMeilleure(MemoryBlock addressSpace, MemoryTracking tracking) + public MemoryEhMeilleure(MemoryBlock addressSpace, MemoryBlock addressSpaceMirror, MemoryTracking tracking) { - _addressSpace = addressSpace; _tracking = tracking; - _baseAddress = (ulong)_addressSpace.Pointer; + _baseAddress = (ulong)addressSpace.Pointer; ulong endAddress = _baseAddress + addressSpace.Size; - _trackingEvent = new TrackingEventDelegate(tracking.VirtualMemoryEventEh); + _trackingEvent = new TrackingEventDelegate(tracking.VirtualMemoryEvent); bool added = NativeSignalHandler.AddTrackedRegion((nuint)_baseAddress, (nuint)endAddress, Marshal.GetFunctionPointerForDelegate(_trackingEvent)); if (!added) { throw new InvalidOperationException("Number of allowed tracked regions exceeded."); } + + if (OperatingSystem.IsWindows()) + { + // Add a tracking event with no signal handler for the mirror on Windows. + // The native handler has its own code to check for the partial overlap race when regions are protected by accident, + // and when there is no signal handler present. + + _mirrorAddress = (ulong)addressSpaceMirror.Pointer; + ulong endAddressMirror = _mirrorAddress + addressSpace.Size; + + bool addedMirror = NativeSignalHandler.AddTrackedRegion((nuint)_mirrorAddress, (nuint)endAddressMirror, IntPtr.Zero); + + if (!addedMirror) + { + throw new InvalidOperationException("Number of allowed tracked regions exceeded."); + } + } } public void Dispose() { NativeSignalHandler.RemoveTrackedRegion((nuint)_baseAddress); + + if (_mirrorAddress != 0) + { + NativeSignalHandler.RemoveTrackedRegion((nuint)_mirrorAddress); + } } } } diff --git a/Ryujinx.Memory.Tests/MockVirtualMemoryManager.cs b/Ryujinx.Memory.Tests/MockVirtualMemoryManager.cs index cad0c2b53..29922f898 100644 --- a/Ryujinx.Memory.Tests/MockVirtualMemoryManager.cs +++ b/Ryujinx.Memory.Tests/MockVirtualMemoryManager.cs @@ -4,7 +4,7 @@ using System.Collections.Generic; namespace Ryujinx.Memory.Tests { - class MockVirtualMemoryManager : IVirtualMemoryManager + public class MockVirtualMemoryManager : IVirtualMemoryManager { public bool NoMappings = false; diff --git a/Ryujinx.Memory.Tests/Tests.cs b/Ryujinx.Memory.Tests/Tests.cs index aa20c38a8..45d00e51d 100644 --- a/Ryujinx.Memory.Tests/Tests.cs +++ b/Ryujinx.Memory.Tests/Tests.cs @@ -38,9 +38,15 @@ namespace Ryujinx.Memory.Tests Assert.AreEqual(Marshal.ReadInt32(_memoryBlock.Pointer, 0x2040), 0xbadc0de); } - [Test, Explicit] + [Test] public void Test_Alias() { + if (OperatingSystem.IsMacOS()) + { + // Memory aliasing tests fail on CI at the moment. + return; + } + using MemoryBlock backing = new MemoryBlock(0x10000, MemoryAllocationFlags.Mirrorable); using MemoryBlock toAlias = new MemoryBlock(0x10000, MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible); @@ -51,9 +57,15 @@ namespace Ryujinx.Memory.Tests Assert.AreEqual(Marshal.ReadInt32(backing.Pointer, 0x1000), 0xbadc0de); } - [Test, Explicit] + [Test] public void Test_AliasRandom() { + if (OperatingSystem.IsMacOS()) + { + // Memory aliasing tests fail on CI at the moment. + return; + } + using MemoryBlock backing = new MemoryBlock(0x80000, MemoryAllocationFlags.Mirrorable); using MemoryBlock toAlias = new MemoryBlock(0x80000, MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible); diff --git a/Ryujinx.Memory/MemoryAllocationFlags.cs b/Ryujinx.Memory/MemoryAllocationFlags.cs index 8706a25b8..313f33e5f 100644 --- a/Ryujinx.Memory/MemoryAllocationFlags.cs +++ b/Ryujinx.Memory/MemoryAllocationFlags.cs @@ -35,12 +35,6 @@ namespace Ryujinx.Memory /// Indicates that the memory block should support mapping views of a mirrorable memory block. /// The block that is to have their views mapped should be created with the flag. /// - ViewCompatible = 1 << 3, - - /// - /// Forces views to be mapped page by page on Windows. When partial unmaps are done, this avoids the need - /// to unmap the full range and remap sub-ranges, which creates a time window with incorrectly unmapped memory. - /// - ForceWindows4KBViewMapping = 1 << 4 + ViewCompatible = 1 << 3 } } diff --git a/Ryujinx.Memory/MemoryBlock.cs b/Ryujinx.Memory/MemoryBlock.cs index b68a10005..79a5cfe7c 100644 --- a/Ryujinx.Memory/MemoryBlock.cs +++ b/Ryujinx.Memory/MemoryBlock.cs @@ -13,14 +13,11 @@ namespace Ryujinx.Memory private readonly bool _usesSharedMemory; private readonly bool _isMirror; private readonly bool _viewCompatible; - private readonly bool _forceWindows4KBView; private IntPtr _sharedMemory; private IntPtr _pointer; private ConcurrentDictionary _viewStorages; private int _viewCount; - internal bool ForceWindows4KBView => _forceWindows4KBView; - /// /// Pointer to the memory block data. /// @@ -49,8 +46,7 @@ namespace Ryujinx.Memory else if (flags.HasFlag(MemoryAllocationFlags.Reserve)) { _viewCompatible = flags.HasFlag(MemoryAllocationFlags.ViewCompatible); - _forceWindows4KBView = flags.HasFlag(MemoryAllocationFlags.ForceWindows4KBViewMapping); - _pointer = MemoryManagement.Reserve(size, _viewCompatible, _forceWindows4KBView); + _pointer = MemoryManagement.Reserve(size, _viewCompatible); } else { @@ -173,7 +169,7 @@ namespace Ryujinx.Memory /// Throw when is invalid public void Reprotect(ulong offset, ulong size, MemoryPermission permission, bool throwOnFail = true) { - MemoryManagement.Reprotect(GetPointerInternal(offset, size), size, permission, _viewCompatible, _forceWindows4KBView, throwOnFail); + MemoryManagement.Reprotect(GetPointerInternal(offset, size), size, permission, _viewCompatible, throwOnFail); } /// @@ -406,7 +402,7 @@ namespace Ryujinx.Memory } else { - MemoryManagement.Free(ptr, Size, _forceWindows4KBView); + MemoryManagement.Free(ptr, Size); } foreach (MemoryBlock viewStorage in _viewStorages.Keys) diff --git a/Ryujinx.Memory/MemoryManagement.cs b/Ryujinx.Memory/MemoryManagement.cs index 77a8a1efb..7c042eba3 100644 --- a/Ryujinx.Memory/MemoryManagement.cs +++ b/Ryujinx.Memory/MemoryManagement.cs @@ -20,11 +20,11 @@ namespace Ryujinx.Memory } } - public static IntPtr Reserve(ulong size, bool viewCompatible, bool force4KBMap) + public static IntPtr Reserve(ulong size, bool viewCompatible) { if (OperatingSystem.IsWindows()) { - return MemoryManagementWindows.Reserve((IntPtr)size, viewCompatible, force4KBMap); + return MemoryManagementWindows.Reserve((IntPtr)size, viewCompatible); } else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) { @@ -72,14 +72,7 @@ namespace Ryujinx.Memory { if (OperatingSystem.IsWindows()) { - if (owner.ForceWindows4KBView) - { - MemoryManagementWindows.MapView4KB(sharedMemory, srcOffset, address, (IntPtr)size); - } - else - { - MemoryManagementWindows.MapView(sharedMemory, srcOffset, address, (IntPtr)size, owner); - } + MemoryManagementWindows.MapView(sharedMemory, srcOffset, address, (IntPtr)size, owner); } else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) { @@ -95,14 +88,7 @@ namespace Ryujinx.Memory { if (OperatingSystem.IsWindows()) { - if (owner.ForceWindows4KBView) - { - MemoryManagementWindows.UnmapView4KB(address, (IntPtr)size); - } - else - { - MemoryManagementWindows.UnmapView(sharedMemory, address, (IntPtr)size, owner); - } + MemoryManagementWindows.UnmapView(sharedMemory, address, (IntPtr)size, owner); } else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) { @@ -114,20 +100,13 @@ namespace Ryujinx.Memory } } - public static void Reprotect(IntPtr address, ulong size, MemoryPermission permission, bool forView, bool force4KBMap, bool throwOnFail) + public static void Reprotect(IntPtr address, ulong size, MemoryPermission permission, bool forView, bool throwOnFail) { bool result; if (OperatingSystem.IsWindows()) { - if (forView && force4KBMap) - { - result = MemoryManagementWindows.Reprotect4KB(address, (IntPtr)size, permission, forView); - } - else - { - result = MemoryManagementWindows.Reprotect(address, (IntPtr)size, permission, forView); - } + result = MemoryManagementWindows.Reprotect(address, (IntPtr)size, permission, forView); } else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) { @@ -144,11 +123,11 @@ namespace Ryujinx.Memory } } - public static bool Free(IntPtr address, ulong size, bool force4KBMap) + public static bool Free(IntPtr address, ulong size) { if (OperatingSystem.IsWindows()) { - return MemoryManagementWindows.Free(address, (IntPtr)size, force4KBMap); + return MemoryManagementWindows.Free(address, (IntPtr)size); } else if (OperatingSystem.IsLinux() || OperatingSystem.IsMacOS()) { diff --git a/Ryujinx.Memory/MemoryManagementWindows.cs b/Ryujinx.Memory/MemoryManagementWindows.cs index 6b4e7fbee..2f89a921c 100644 --- a/Ryujinx.Memory/MemoryManagementWindows.cs +++ b/Ryujinx.Memory/MemoryManagementWindows.cs @@ -10,23 +10,19 @@ namespace Ryujinx.Memory public const int PageSize = 0x1000; private static readonly PlaceholderManager _placeholders = new PlaceholderManager(); - private static readonly PlaceholderManager4KB _placeholders4KB = new PlaceholderManager4KB(); public static IntPtr Allocate(IntPtr size) { return AllocateInternal(size, AllocationType.Reserve | AllocationType.Commit); } - public static IntPtr Reserve(IntPtr size, bool viewCompatible, bool force4KBMap) + public static IntPtr Reserve(IntPtr size, bool viewCompatible) { if (viewCompatible) { IntPtr baseAddress = AllocateInternal2(size, AllocationType.Reserve | AllocationType.ReservePlaceholder); - if (!force4KBMap) - { - _placeholders.ReserveRange((ulong)baseAddress, (ulong)size); - } + _placeholders.ReserveRange((ulong)baseAddress, (ulong)size); return baseAddress; } @@ -73,49 +69,11 @@ namespace Ryujinx.Memory _placeholders.MapView(sharedMemory, srcOffset, location, size, owner); } - public static void MapView4KB(IntPtr sharedMemory, ulong srcOffset, IntPtr location, IntPtr size) - { - _placeholders4KB.UnmapAndMarkRangeAsMapped(location, size); - - ulong uaddress = (ulong)location; - ulong usize = (ulong)size; - IntPtr endLocation = (IntPtr)(uaddress + usize); - - while (location != endLocation) - { - WindowsApi.VirtualFree(location, (IntPtr)PageSize, AllocationType.Release | AllocationType.PreservePlaceholder); - - var ptr = WindowsApi.MapViewOfFile3( - sharedMemory, - WindowsApi.CurrentProcessHandle, - location, - srcOffset, - (IntPtr)PageSize, - 0x4000, - MemoryProtection.ReadWrite, - IntPtr.Zero, - 0); - - if (ptr == IntPtr.Zero) - { - throw new WindowsApiException("MapViewOfFile3"); - } - - location += PageSize; - srcOffset += PageSize; - } - } - public static void UnmapView(IntPtr sharedMemory, IntPtr location, IntPtr size, MemoryBlock owner) { _placeholders.UnmapView(sharedMemory, location, size, owner); } - public static void UnmapView4KB(IntPtr location, IntPtr size) - { - _placeholders4KB.UnmapView(location, size); - } - public static bool Reprotect(IntPtr address, IntPtr size, MemoryPermission permission, bool forView) { if (forView) @@ -128,34 +86,9 @@ namespace Ryujinx.Memory } } - public static bool Reprotect4KB(IntPtr address, IntPtr size, MemoryPermission permission, bool forView) + public static bool Free(IntPtr address, IntPtr size) { - ulong uaddress = (ulong)address; - ulong usize = (ulong)size; - while (usize > 0) - { - if (!WindowsApi.VirtualProtect((IntPtr)uaddress, (IntPtr)PageSize, WindowsApi.GetProtection(permission), out _)) - { - return false; - } - - uaddress += PageSize; - usize -= PageSize; - } - - return true; - } - - public static bool Free(IntPtr address, IntPtr size, bool force4KBMap) - { - if (force4KBMap) - { - _placeholders4KB.UnmapRange(address, size); - } - else - { - _placeholders.UnreserveRange((ulong)address, (ulong)size); - } + _placeholders.UnreserveRange((ulong)address, (ulong)size); return WindowsApi.VirtualFree(address, IntPtr.Zero, AllocationType.Release); } @@ -207,10 +140,5 @@ namespace Ryujinx.Memory throw new ArgumentException("Invalid address.", nameof(address)); } } - - public static bool RetryFromAccessViolation() - { - return _placeholders.RetryFromAccessViolation(); - } } } \ No newline at end of file diff --git a/Ryujinx.Memory/Tracking/MemoryTracking.cs b/Ryujinx.Memory/Tracking/MemoryTracking.cs index c5abb5765..ec75e3d0e 100644 --- a/Ryujinx.Memory/Tracking/MemoryTracking.cs +++ b/Ryujinx.Memory/Tracking/MemoryTracking.cs @@ -188,30 +188,6 @@ namespace Ryujinx.Memory.Tracking return VirtualMemoryEvent(address, 1, write); } - /// - /// Signal that a virtual memory event happened at the given location. - /// This is similar VirtualMemoryEvent, but on Windows, it might also return true after a partial unmap. - /// This should only be called from the exception handler. - /// - /// Virtual address accessed - /// Size of the region affected in bytes - /// Whether the region was written to or read - /// True if the access is precise, false otherwise - /// True if the event triggered any tracking regions, false otherwise - public bool VirtualMemoryEventEh(ulong address, ulong size, bool write, bool precise = false) - { - // Windows has a limitation, it can't do partial unmaps. - // For this reason, we need to unmap the whole range and then remap the sub-ranges. - // When this happens, we might have caused a undesirable access violation from the time that the range was unmapped. - // In this case, try again as the memory might be mapped now. - if (OperatingSystem.IsWindows() && MemoryManagementWindows.RetryFromAccessViolation()) - { - return true; - } - - return VirtualMemoryEvent(address, size, write, precise); - } - /// /// Signal that a virtual memory event happened at the given location. /// This can be flagged as a precise event, which will avoid reprotection and call special handlers if possible. @@ -237,10 +213,12 @@ namespace Ryujinx.Memory.Tracking if (count == 0 && !precise) { - if (_memoryManager.IsMapped(address)) + if (_memoryManager.IsRangeMapped(address, size)) { + // TODO: There is currently the possibility that a page can be protected after its virtual region is removed. + // This code handles that case when it happens, but it would be better to find out how this happens. _memoryManager.TrackingReprotect(address & ~(ulong)(_pageSize - 1), (ulong)_pageSize, MemoryPermission.ReadAndWrite); - return false; // We can't handle this - it's probably a real invalid access. + return true; // This memory _should_ be mapped, so we need to try again. } else { diff --git a/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs b/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs index 1b425d66f..0937d4623 100644 --- a/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs +++ b/Ryujinx.Memory/WindowsShared/PlaceholderManager.cs @@ -1,5 +1,7 @@ +using Ryujinx.Common.Memory.PartialUnmaps; using System; using System.Diagnostics; +using System.Runtime.CompilerServices; using System.Runtime.Versioning; using System.Threading; @@ -13,13 +15,10 @@ namespace Ryujinx.Memory.WindowsShared { private const ulong MinimumPageSize = 0x1000; - [ThreadStatic] - private static int _threadLocalPartialUnmapsCount; - private readonly IntervalTree _mappings; private readonly IntervalTree _protections; - private readonly ReaderWriterLock _partialUnmapLock; - private int _partialUnmapsCount; + private readonly IntPtr _partialUnmapStatePtr; + private readonly Thread _partialUnmapTrimThread; /// /// Creates a new instance of the Windows memory placeholder manager. @@ -28,7 +27,35 @@ namespace Ryujinx.Memory.WindowsShared { _mappings = new IntervalTree(); _protections = new IntervalTree(); - _partialUnmapLock = new ReaderWriterLock(); + + _partialUnmapStatePtr = PartialUnmapState.GlobalState; + + _partialUnmapTrimThread = new Thread(TrimThreadLocalMapLoop); + _partialUnmapTrimThread.Name = "CPU.PartialUnmapTrimThread"; + _partialUnmapTrimThread.IsBackground = true; + _partialUnmapTrimThread.Start(); + } + + /// + /// Gets a reference to the partial unmap state struct. + /// + /// A reference to the partial unmap state struct + private unsafe ref PartialUnmapState GetPartialUnmapState() + { + return ref Unsafe.AsRef((void*)_partialUnmapStatePtr); + } + + /// + /// Trims inactive threads from the partial unmap state's thread mapping every few seconds. + /// Should be run in a Background thread so that it doesn't stop the program from closing. + /// + private void TrimThreadLocalMapLoop() + { + while (true) + { + Thread.Sleep(2000); + GetPartialUnmapState().TrimThreads(); + } } /// @@ -98,7 +125,8 @@ namespace Ryujinx.Memory.WindowsShared /// Memory block that owns the mapping public void MapView(IntPtr sharedMemory, ulong srcOffset, IntPtr location, IntPtr size, MemoryBlock owner) { - _partialUnmapLock.AcquireReaderLock(Timeout.Infinite); + ref var partialUnmapLock = ref GetPartialUnmapState().PartialUnmapLock; + partialUnmapLock.AcquireReaderLock(); try { @@ -107,7 +135,7 @@ namespace Ryujinx.Memory.WindowsShared } finally { - _partialUnmapLock.ReleaseReaderLock(); + partialUnmapLock.ReleaseReaderLock(); } } @@ -221,7 +249,8 @@ namespace Ryujinx.Memory.WindowsShared /// Memory block that owns the mapping public void UnmapView(IntPtr sharedMemory, IntPtr location, IntPtr size, MemoryBlock owner) { - _partialUnmapLock.AcquireReaderLock(Timeout.Infinite); + ref var partialUnmapLock = ref GetPartialUnmapState().PartialUnmapLock; + partialUnmapLock.AcquireReaderLock(); try { @@ -229,7 +258,7 @@ namespace Ryujinx.Memory.WindowsShared } finally { - _partialUnmapLock.ReleaseReaderLock(); + partialUnmapLock.ReleaseReaderLock(); } } @@ -265,11 +294,6 @@ namespace Ryujinx.Memory.WindowsShared if (IsMapped(overlap.Value)) { - if (!WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)overlap.Start, 2)) - { - throw new WindowsApiException("UnmapViewOfFile2"); - } - // Tree operations might modify the node start/end values, so save a copy before we modify the tree. ulong overlapStart = overlap.Start; ulong overlapEnd = overlap.End; @@ -291,30 +315,46 @@ namespace Ryujinx.Memory.WindowsShared // This is necessary because Windows does not support partial view unmaps. // That is, you can only fully unmap a view that was previously mapped, you can't just unmap a chunck of it. - LockCookie lockCookie = _partialUnmapLock.UpgradeToWriterLock(Timeout.Infinite); + ref var partialUnmapState = ref GetPartialUnmapState(); + ref var partialUnmapLock = ref partialUnmapState.PartialUnmapLock; + partialUnmapLock.UpgradeToWriterLock(); - _partialUnmapsCount++; - - if (overlapStartsBefore) + try { - ulong remapSize = startAddress - overlapStart; + partialUnmapState.PartialUnmapsCount++; - MapViewInternal(sharedMemory, overlapValue, (IntPtr)overlapStart, (IntPtr)remapSize); - RestoreRangeProtection(overlapStart, remapSize); + if (!WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)overlapStart, 2)) + { + throw new WindowsApiException("UnmapViewOfFile2"); + } + + if (overlapStartsBefore) + { + ulong remapSize = startAddress - overlapStart; + + MapViewInternal(sharedMemory, overlapValue, (IntPtr)overlapStart, (IntPtr)remapSize); + RestoreRangeProtection(overlapStart, remapSize); + } + + if (overlapEndsAfter) + { + ulong overlappedSize = endAddress - overlapStart; + ulong remapBackingOffset = overlapValue + overlappedSize; + ulong remapAddress = overlapStart + overlappedSize; + ulong remapSize = overlapEnd - endAddress; + + MapViewInternal(sharedMemory, remapBackingOffset, (IntPtr)remapAddress, (IntPtr)remapSize); + RestoreRangeProtection(remapAddress, remapSize); + } } - - if (overlapEndsAfter) + finally { - ulong overlappedSize = endAddress - overlapStart; - ulong remapBackingOffset = overlapValue + overlappedSize; - ulong remapAddress = overlapStart + overlappedSize; - ulong remapSize = overlapEnd - endAddress; - - MapViewInternal(sharedMemory, remapBackingOffset, (IntPtr)remapAddress, (IntPtr)remapSize); - RestoreRangeProtection(remapAddress, remapSize); + partialUnmapLock.DowngradeFromWriterLock(); } - - _partialUnmapLock.DowngradeFromWriterLock(ref lockCookie); + } + else if (!WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)overlapStart, 2)) + { + throw new WindowsApiException("UnmapViewOfFile2"); } } } @@ -394,7 +434,8 @@ namespace Ryujinx.Memory.WindowsShared /// True if the reprotection was successful, false otherwise public bool ReprotectView(IntPtr address, IntPtr size, MemoryPermission permission) { - _partialUnmapLock.AcquireReaderLock(Timeout.Infinite); + ref var partialUnmapLock = ref GetPartialUnmapState().PartialUnmapLock; + partialUnmapLock.AcquireReaderLock(); try { @@ -402,7 +443,7 @@ namespace Ryujinx.Memory.WindowsShared } finally { - _partialUnmapLock.ReleaseReaderLock(); + partialUnmapLock.ReleaseReaderLock(); } } @@ -659,31 +700,5 @@ namespace Ryujinx.Memory.WindowsShared ReprotectViewInternal((IntPtr)protAddress, (IntPtr)(protEndAddress - protAddress), protection.Value, true); } } - - /// - /// Checks if an access violation handler should retry execution due to a fault caused by partial unmap. - /// - /// - /// Due to Windows limitations, might need to unmap more memory than requested. - /// The additional memory that was unmapped is later remapped, however this leaves a time gap where the - /// memory might be accessed but is unmapped. Users of the API must compensate for that by catching the - /// access violation and retrying if it happened between the unmap and remap operation. - /// This method can be used to decide if retrying in such cases is necessary or not. - /// - /// True if execution should be retried, false otherwise - public bool RetryFromAccessViolation() - { - _partialUnmapLock.AcquireReaderLock(Timeout.Infinite); - - bool retry = _threadLocalPartialUnmapsCount != _partialUnmapsCount; - if (retry) - { - _threadLocalPartialUnmapsCount = _partialUnmapsCount; - } - - _partialUnmapLock.ReleaseReaderLock(); - - return retry; - } } } \ No newline at end of file diff --git a/Ryujinx.Memory/WindowsShared/PlaceholderManager4KB.cs b/Ryujinx.Memory/WindowsShared/PlaceholderManager4KB.cs deleted file mode 100644 index fc056a2f7..000000000 --- a/Ryujinx.Memory/WindowsShared/PlaceholderManager4KB.cs +++ /dev/null @@ -1,170 +0,0 @@ -using System; -using System.Runtime.Versioning; - -namespace Ryujinx.Memory.WindowsShared -{ - /// - /// Windows 4KB memory placeholder manager. - /// - [SupportedOSPlatform("windows")] - class PlaceholderManager4KB - { - private const int PageSize = MemoryManagementWindows.PageSize; - - private readonly IntervalTree _mappings; - - /// - /// Creates a new instance of the Windows 4KB memory placeholder manager. - /// - public PlaceholderManager4KB() - { - _mappings = new IntervalTree(); - } - - /// - /// Unmaps the specified range of memory and marks it as mapped internally. - /// - /// - /// Since this marks the range as mapped, the expectation is that the range will be mapped after calling this method. - /// - /// Memory address to unmap and mark as mapped - /// Size of the range in bytes - public void UnmapAndMarkRangeAsMapped(IntPtr location, IntPtr size) - { - ulong startAddress = (ulong)location; - ulong unmapSize = (ulong)size; - ulong endAddress = startAddress + unmapSize; - - var overlaps = Array.Empty>(); - int count = 0; - - lock (_mappings) - { - count = _mappings.Get(startAddress, endAddress, ref overlaps); - } - - for (int index = 0; index < count; index++) - { - var overlap = overlaps[index]; - - // Tree operations might modify the node start/end values, so save a copy before we modify the tree. - ulong overlapStart = overlap.Start; - ulong overlapEnd = overlap.End; - ulong overlapValue = overlap.Value; - - _mappings.Remove(overlap); - - ulong unmapStart = Math.Max(overlapStart, startAddress); - ulong unmapEnd = Math.Min(overlapEnd, endAddress); - - if (overlapStart < startAddress) - { - startAddress = overlapStart; - } - - if (overlapEnd > endAddress) - { - endAddress = overlapEnd; - } - - ulong currentAddress = unmapStart; - while (currentAddress < unmapEnd) - { - WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2); - currentAddress += PageSize; - } - } - - _mappings.Add(startAddress, endAddress, 0); - } - - /// - /// Unmaps views at the specified memory range. - /// - /// Address of the range - /// Size of the range in bytes - public void UnmapView(IntPtr location, IntPtr size) - { - ulong startAddress = (ulong)location; - ulong unmapSize = (ulong)size; - ulong endAddress = startAddress + unmapSize; - - var overlaps = Array.Empty>(); - int count = 0; - - lock (_mappings) - { - count = _mappings.Get(startAddress, endAddress, ref overlaps); - } - - for (int index = 0; index < count; index++) - { - var overlap = overlaps[index]; - - // Tree operations might modify the node start/end values, so save a copy before we modify the tree. - ulong overlapStart = overlap.Start; - ulong overlapEnd = overlap.End; - - _mappings.Remove(overlap); - - if (overlapStart < startAddress) - { - _mappings.Add(overlapStart, startAddress, 0); - } - - if (overlapEnd > endAddress) - { - _mappings.Add(endAddress, overlapEnd, 0); - } - - ulong unmapStart = Math.Max(overlapStart, startAddress); - ulong unmapEnd = Math.Min(overlapEnd, endAddress); - - ulong currentAddress = unmapStart; - while (currentAddress < unmapEnd) - { - WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2); - currentAddress += PageSize; - } - } - } - - /// - /// Unmaps mapped memory at a given range. - /// - /// Address of the range - /// Size of the range in bytes - public void UnmapRange(IntPtr location, IntPtr size) - { - ulong startAddress = (ulong)location; - ulong unmapSize = (ulong)size; - ulong endAddress = startAddress + unmapSize; - - var overlaps = Array.Empty>(); - int count = 0; - - lock (_mappings) - { - count = _mappings.Get(startAddress, endAddress, ref overlaps); - } - - for (int index = 0; index < count; index++) - { - var overlap = overlaps[index]; - - // Tree operations might modify the node start/end values, so save a copy before we modify the tree. - ulong unmapStart = Math.Max(overlap.Start, startAddress); - ulong unmapEnd = Math.Min(overlap.End, endAddress); - - _mappings.Remove(overlap); - - ulong currentAddress = unmapStart; - while (currentAddress < unmapEnd) - { - WindowsApi.UnmapViewOfFile2(WindowsApi.CurrentProcessHandle, (IntPtr)currentAddress, 2); - currentAddress += PageSize; - } - } - } - } -} \ No newline at end of file diff --git a/Ryujinx.Memory/WindowsShared/WindowsApi.cs b/Ryujinx.Memory/WindowsShared/WindowsApi.cs index 297bd1eee..cbb7d99e6 100644 --- a/Ryujinx.Memory/WindowsShared/WindowsApi.cs +++ b/Ryujinx.Memory/WindowsShared/WindowsApi.cs @@ -76,6 +76,9 @@ namespace Ryujinx.Memory.WindowsShared [DllImport("kernel32.dll")] public static extern uint GetLastError(); + [DllImport("kernel32.dll")] + public static extern int GetCurrentThreadId(); + public static MemoryProtection GetProtection(MemoryPermission permission) { return permission switch diff --git a/Ryujinx.Tests/.runsettings b/Ryujinx.Tests/.runsettings new file mode 100644 index 000000000..ca70d359e --- /dev/null +++ b/Ryujinx.Tests/.runsettings @@ -0,0 +1,8 @@ + + + + + 1 + + + diff --git a/Ryujinx.Tests/Memory/MockMemoryManager.cs b/Ryujinx.Tests/Memory/MockMemoryManager.cs new file mode 100644 index 000000000..3f7692636 --- /dev/null +++ b/Ryujinx.Tests/Memory/MockMemoryManager.cs @@ -0,0 +1,53 @@ +using ARMeilleure.Memory; +using System; + +namespace Ryujinx.Tests.Memory +{ + internal class MockMemoryManager : IMemoryManager + { + public int AddressSpaceBits => throw new NotImplementedException(); + + public IntPtr PageTablePointer => throw new NotImplementedException(); + + public MemoryManagerType Type => MemoryManagerType.HostMappedUnsafe; + +#pragma warning disable CS0067 + public event Action UnmapEvent; +#pragma warning restore CS0067 + + public ref T GetRef(ulong va) where T : unmanaged + { + throw new NotImplementedException(); + } + + public ReadOnlySpan GetSpan(ulong va, int size, bool tracked = false) + { + throw new NotImplementedException(); + } + + public bool IsMapped(ulong va) + { + throw new NotImplementedException(); + } + + public T Read(ulong va) where T : unmanaged + { + throw new NotImplementedException(); + } + + public T ReadTracked(ulong va) where T : unmanaged + { + throw new NotImplementedException(); + } + + public void SignalMemoryTracking(ulong va, ulong size, bool write, bool precise = false) + { + throw new NotImplementedException(); + } + + public void Write(ulong va, T value) where T : unmanaged + { + throw new NotImplementedException(); + } + } +} diff --git a/Ryujinx.Tests/Memory/PartialUnmaps.cs b/Ryujinx.Tests/Memory/PartialUnmaps.cs new file mode 100644 index 000000000..1088b52c4 --- /dev/null +++ b/Ryujinx.Tests/Memory/PartialUnmaps.cs @@ -0,0 +1,484 @@ +using ARMeilleure.Signal; +using ARMeilleure.Translation; +using NUnit.Framework; +using Ryujinx.Common.Memory.PartialUnmaps; +using Ryujinx.Cpu; +using Ryujinx.Cpu.Jit; +using Ryujinx.Memory; +using Ryujinx.Memory.Tests; +using Ryujinx.Memory.Tracking; +using System; +using System.Collections.Generic; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; +using System.Threading; + +namespace Ryujinx.Tests.Memory +{ + [TestFixture] + internal class PartialUnmaps + { + private static Translator _translator; + + private (MemoryBlock virt, MemoryBlock mirror, MemoryEhMeilleure exceptionHandler) GetVirtual(ulong asSize) + { + MemoryAllocationFlags asFlags = MemoryAllocationFlags.Reserve | MemoryAllocationFlags.ViewCompatible; + + var addressSpace = new MemoryBlock(asSize, asFlags); + var addressSpaceMirror = new MemoryBlock(asSize, asFlags); + + var tracking = new MemoryTracking(new MockVirtualMemoryManager(asSize, 0x1000), 0x1000); + var exceptionHandler = new MemoryEhMeilleure(addressSpace, addressSpaceMirror, tracking); + + return (addressSpace, addressSpaceMirror, exceptionHandler); + } + + private int CountThreads(ref PartialUnmapState state) + { + int count = 0; + + ref var ids = ref state.LocalCounts.ThreadIds; + + for (int i = 0; i < ids.Length; i++) + { + if (ids[i] != 0) + { + count++; + } + } + + return count; + } + + private void EnsureTranslator() + { + // Create a translator, as one is needed to register the signal handler or emit methods. + _translator ??= new Translator(new JitMemoryAllocator(), new MockMemoryManager(), true); + } + + [Test] + public void PartialUnmap([Values] bool readOnly) + { + if (OperatingSystem.IsMacOS()) + { + // Memory aliasing tests fail on CI at the moment. + return; + } + + // Set up an address space to test partial unmapping. + // Should register the signal handler to deal with this on Windows. + ulong vaSize = 0x100000; + + // The first 0x100000 is mapped to start. It is replaced from the center with the 0x200000 mapping. + var backing = new MemoryBlock(vaSize * 2, MemoryAllocationFlags.Mirrorable); + + (MemoryBlock unusedMainMemory, MemoryBlock memory, MemoryEhMeilleure exceptionHandler) = GetVirtual(vaSize * 2); + + EnsureTranslator(); + + ref var state = ref PartialUnmapState.GetRef(); + + try + { + // Globally reset the struct for handling partial unmap races. + PartialUnmapState.Reset(); + bool shouldAccess = true; + bool error = false; + + // Create a large mapping. + memory.MapView(backing, 0, 0, vaSize); + + if (readOnly) + { + memory.Reprotect(0, vaSize, MemoryPermission.Read); + } + + Thread testThread; + + if (readOnly) + { + // Write a value to the physical memory, then try to read it repeately from virtual. + // It should not change. + testThread = new Thread(() => + { + int i = 12345; + backing.Write(vaSize - 0x1000, i); + + while (shouldAccess) + { + if (memory.Read(vaSize - 0x1000) != i) + { + error = true; + shouldAccess = false; + } + } + }); + } + else + { + // Repeatedly write and check the value on the last page of the mapping on another thread. + testThread = new Thread(() => + { + int i = 0; + while (shouldAccess) + { + memory.Write(vaSize - 0x1000, i); + if (memory.Read(vaSize - 0x1000) != i) + { + error = true; + shouldAccess = false; + } + + i++; + } + }); + } + + testThread.Start(); + + // Create a smaller mapping, covering the larger mapping. + // Immediately try to write to the part of the larger mapping that did not change. + // Do this a lot, with the smaller mapping gradually increasing in size. Should not crash, data should not be lost. + + ulong pageSize = 0x1000; + int mappingExpandCount = (int)(vaSize / (pageSize * 2)) - 1; + ulong vaCenter = vaSize / 2; + + for (int i = 1; i <= mappingExpandCount; i++) + { + ulong start = vaCenter - (pageSize * (ulong)i); + ulong size = pageSize * (ulong)i * 2; + + ulong startPa = start + vaSize; + + memory.MapView(backing, startPa, start, size); + } + + // On Windows, this should put unmap counts on the thread local map. + if (OperatingSystem.IsWindows()) + { + // One thread should be present on the thread local map. Trimming should remove it. + Assert.AreEqual(1, CountThreads(ref state)); + } + + shouldAccess = false; + testThread.Join(); + + Assert.False(error); + + string test = null; + + try + { + test.IndexOf('1'); + } + catch (NullReferenceException) + { + // This shouldn't freeze. + } + + if (OperatingSystem.IsWindows()) + { + state.TrimThreads(); + + Assert.AreEqual(0, CountThreads(ref state)); + } + + /* + * Use this to test invalid access. Can't put this in the test suite unfortunately as invalid access crashes the test process. + * memory.Reprotect(vaSize - 0x1000, 0x1000, MemoryPermission.None); + * //memory.UnmapView(backing, vaSize - 0x1000, 0x1000); + * memory.Read(vaSize - 0x1000); + */ + } + finally + { + exceptionHandler.Dispose(); + unusedMainMemory.Dispose(); + memory.Dispose(); + backing.Dispose(); + } + } + + [Test] + public unsafe void PartialUnmapNative() + { + if (OperatingSystem.IsMacOS()) + { + // Memory aliasing tests fail on CI at the moment. + return; + } + + // Set up an address space to test partial unmapping. + // Should register the signal handler to deal with this on Windows. + ulong vaSize = 0x100000; + + // The first 0x100000 is mapped to start. It is replaced from the center with the 0x200000 mapping. + var backing = new MemoryBlock(vaSize * 2, MemoryAllocationFlags.Mirrorable); + + (MemoryBlock mainMemory, MemoryBlock unusedMirror, MemoryEhMeilleure exceptionHandler) = GetVirtual(vaSize * 2); + + EnsureTranslator(); + + ref var state = ref PartialUnmapState.GetRef(); + + // Create some state to be used for managing the native writing loop. + int stateSize = Unsafe.SizeOf(); + var statePtr = Marshal.AllocHGlobal(stateSize); + Unsafe.InitBlockUnaligned((void*)statePtr, 0, (uint)stateSize); + + ref NativeWriteLoopState writeLoopState = ref Unsafe.AsRef((void*)statePtr); + writeLoopState.Running = 1; + writeLoopState.Error = 0; + + try + { + // Globally reset the struct for handling partial unmap races. + PartialUnmapState.Reset(); + + // Create a large mapping. + mainMemory.MapView(backing, 0, 0, vaSize); + + var writeFunc = TestMethods.GenerateDebugNativeWriteLoop(); + IntPtr writePtr = mainMemory.GetPointer(vaSize - 0x1000, 4); + + Thread testThread = new Thread(() => + { + writeFunc(statePtr, writePtr); + }); + + testThread.Start(); + + // Create a smaller mapping, covering the larger mapping. + // Immediately try to write to the part of the larger mapping that did not change. + // Do this a lot, with the smaller mapping gradually increasing in size. Should not crash, data should not be lost. + + ulong pageSize = 0x1000; + int mappingExpandCount = (int)(vaSize / (pageSize * 2)) - 1; + ulong vaCenter = vaSize / 2; + + for (int i = 1; i <= mappingExpandCount; i++) + { + ulong start = vaCenter - (pageSize * (ulong)i); + ulong size = pageSize * (ulong)i * 2; + + ulong startPa = start + vaSize; + + mainMemory.MapView(backing, startPa, start, size); + } + + writeLoopState.Running = 0; + testThread.Join(); + + Assert.False(writeLoopState.Error != 0); + } + finally + { + Marshal.FreeHGlobal(statePtr); + + exceptionHandler.Dispose(); + mainMemory.Dispose(); + unusedMirror.Dispose(); + backing.Dispose(); + } + } + + [Test] + public void ThreadLocalMap() + { + if (!OperatingSystem.IsWindows()) + { + // Only test in Windows, as this is only used on Windows and uses Windows APIs for trimming. + return; + } + + PartialUnmapState.Reset(); + ref var state = ref PartialUnmapState.GetRef(); + + bool running = true; + var testThread = new Thread(() => + { + if (!OperatingSystem.IsWindows()) + { + // Need this here to avoid a warning. + return; + } + + PartialUnmapState.GetRef().RetryFromAccessViolation(); + while (running) + { + Thread.Sleep(1); + } + }); + + testThread.Start(); + Thread.Sleep(200); + + Assert.AreEqual(1, CountThreads(ref state)); + + // Trimming should not remove the thread as it's still active. + state.TrimThreads(); + Assert.AreEqual(1, CountThreads(ref state)); + + running = false; + + testThread.Join(); + + // Should trim now that it's inactive. + state.TrimThreads(); + Assert.AreEqual(0, CountThreads(ref state)); + } + + [Test] + public unsafe void ThreadLocalMapNative() + { + if (!OperatingSystem.IsWindows()) + { + // Only test in Windows, as this is only used on Windows and uses Windows APIs for trimming. + return; + } + + EnsureTranslator(); + + PartialUnmapState.Reset(); + + ref var state = ref PartialUnmapState.GetRef(); + + fixed (void* localMap = &state.LocalCounts) + { + var getOrReserve = TestMethods.GenerateDebugThreadLocalMapGetOrReserve((IntPtr)localMap); + + for (int i = 0; i < ThreadLocalMap.MapSize; i++) + { + // Should obtain the index matching the call #. + Assert.AreEqual(i, getOrReserve(i + 1, i)); + + // Check that this and all previously reserved thread IDs and struct contents are intact. + for (int j = 0; j <= i; j++) + { + Assert.AreEqual(j + 1, state.LocalCounts.ThreadIds[j]); + Assert.AreEqual(j, state.LocalCounts.Structs[j]); + } + } + + // Trying to reserve again when the map is full should return -1. + Assert.AreEqual(-1, getOrReserve(200, 0)); + + for (int i = 0; i < ThreadLocalMap.MapSize; i++) + { + // Should obtain the index matching the call #, as it already exists. + Assert.AreEqual(i, getOrReserve(i + 1, -1)); + + // The struct should not be reset to -1. + Assert.AreEqual(i, state.LocalCounts.Structs[i]); + } + + // Clear one of the ids as if it were freed. + state.LocalCounts.ThreadIds[13] = 0; + + // GetOrReserve should now obtain and return 13. + Assert.AreEqual(13, getOrReserve(300, 301)); + Assert.AreEqual(300, state.LocalCounts.ThreadIds[13]); + Assert.AreEqual(301, state.LocalCounts.Structs[13]); + } + } + + [Test] + public void NativeReaderWriterLock() + { + var rwLock = new NativeReaderWriterLock(); + var threads = new List(); + + int value = 0; + + bool running = true; + bool error = false; + int readersAllowed = 1; + + for (int i = 0; i < 5; i++) + { + var readThread = new Thread(() => + { + int count = 0; + while (running) + { + rwLock.AcquireReaderLock(); + + int originalValue = Thread.VolatileRead(ref value); + + count++; + + // Spin a bit. + for (int i = 0; i < 100; i++) + { + if (Thread.VolatileRead(ref readersAllowed) == 0) + { + error = true; + running = false; + } + } + + // Should not change while the lock is held. + if (Thread.VolatileRead(ref value) != originalValue) + { + error = true; + running = false; + } + + rwLock.ReleaseReaderLock(); + } + }); + + threads.Add(readThread); + } + + for (int i = 0; i < 2; i++) + { + var writeThread = new Thread(() => + { + int count = 0; + while (running) + { + rwLock.AcquireReaderLock(); + rwLock.UpgradeToWriterLock(); + + Thread.Sleep(2); + count++; + + Interlocked.Exchange(ref readersAllowed, 0); + + for (int i = 0; i < 10; i++) + { + Interlocked.Increment(ref value); + } + + Interlocked.Exchange(ref readersAllowed, 1); + + rwLock.DowngradeFromWriterLock(); + rwLock.ReleaseReaderLock(); + + Thread.Sleep(1); + } + }); + + threads.Add(writeThread); + } + + foreach (var thread in threads) + { + thread.Start(); + } + + Thread.Sleep(1000); + + running = false; + + foreach (var thread in threads) + { + thread.Join(); + } + + Assert.False(error); + } + } +} diff --git a/Ryujinx.Tests/Ryujinx.Tests.csproj b/Ryujinx.Tests/Ryujinx.Tests.csproj index ec191a8ee..42a35e9ea 100644 --- a/Ryujinx.Tests/Ryujinx.Tests.csproj +++ b/Ryujinx.Tests/Ryujinx.Tests.csproj @@ -1,4 +1,4 @@ - + net6.0 @@ -9,10 +9,12 @@ osx linux Debug;Release + $(MSBuildProjectDirectory)\.runsettings false + True @@ -25,6 +27,7 @@ +