diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/CodeGenContext.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/CodeGenContext.cs index 79c13964c..0ae6313eb 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/CodeGenContext.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/CodeGenContext.cs @@ -9,7 +9,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public const string Tab = " "; // The number of additional arguments that every function (except for the main one) must have (for instance support_buffer) - public const int AdditionalArgCount = 2; + public const int AdditionalArgCount = 1; public StructuredFunction CurrentFunction { get; set; } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs index 8d4a9c877..60729ac60 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs @@ -64,9 +64,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl return ioDefinition.StorageKind == storageKind && ioDefinition.IoVariable == IoVariable.UserDefined; } - public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage) + public static void DeclareLocals(CodeGenContext context, StructuredFunction function, ShaderStage stage, bool isMainFunc = false) { - DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false); + if (isMainFunc) + { + DeclareMemories(context, context.Properties.LocalMemories.Values, isShared: false); + DeclareMemories(context, context.Properties.SharedMemories.Values, isShared: true); + } + switch (stage) { case ShaderStage.Vertex: @@ -112,6 +117,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl private static void DeclareMemories(CodeGenContext context, IEnumerable memories, bool isShared) { + string prefix = isShared ? "threadgroup " : string.Empty; + foreach (var memory in memories) { string arraySize = ""; @@ -120,7 +127,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl arraySize = $"[{memory.ArrayLength}]"; } var typeName = GetVarTypeName(context, memory.Type & ~AggregateType.Array); - context.AppendLine($"{typeName} {memory.Name}{arraySize};"); + context.AppendLine($"{prefix}{typeName} {memory.Name}{arraySize};"); } } @@ -128,23 +135,28 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl { foreach (BufferDefinition buffer in buffers) { - context.AppendLine($"struct Struct_{buffer.Name}"); + context.AppendLine($"struct {DefaultNames.StructPrefix}_{buffer.Name}"); context.EnterScope(); foreach (StructureField field in buffer.Type.Fields) { - if (field.Type.HasFlag(AggregateType.Array) && field.ArrayLength > 0) - { - string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array); + string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array); + string arraySuffix = ""; - context.AppendLine($"{typeName} {field.Name}[{field.ArrayLength}];"); - } - else + if (field.Type.HasFlag(AggregateType.Array)) { - string typeName = GetVarTypeName(context, field.Type & ~AggregateType.Array); - - context.AppendLine($"{typeName} {field.Name};"); + if (field.ArrayLength > 0) + { + arraySuffix = $"[{field.ArrayLength}]"; + } + else + { + // Probably UB, but this is the approach that MVK takes + arraySuffix = "[1]"; + } } + + context.AppendLine($"{typeName} {field.Name}{arraySuffix};"); } context.LeaveScope(";"); @@ -191,6 +203,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable.GlobalId => "uint3", IoVariable.VertexId => "uint", IoVariable.VertexIndex => "uint", + IoVariable.PointCoord => "float2", _ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: false)) }; string name = ioDefinition.IoVariable switch @@ -199,6 +212,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable.GlobalId => "global_id", IoVariable.VertexId => "vertex_id", IoVariable.VertexIndex => "vertex_index", + IoVariable.PointCoord => "point_coord", _ => $"{DefaultNames.IAttributePrefix}{ioDefinition.Location}" }; string suffix = ioDefinition.IoVariable switch @@ -208,6 +222,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable.VertexId => "[[vertex_id]]", // TODO: Avoid potential redeclaration IoVariable.VertexIndex => "[[vertex_id]]", + IoVariable.PointCoord => "[[point_coord]]", IoVariable.UserDefined => context.Definitions.Stage == ShaderStage.Fragment ? $"[[user(loc{ioDefinition.Location})]]" : $"[[attribute({ioDefinition.Location})]]", _ => "" }; diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/DefaultNames.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/DefaultNames.cs index 8a468395e..0b946c3aa 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/DefaultNames.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/DefaultNames.cs @@ -8,6 +8,8 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public const string IAttributePrefix = "inAttr"; public const string OAttributePrefix = "outAttr"; + public const string StructPrefix = "struct"; + public const string ArgumentNamePrefix = "a"; public const string UndefinedName = "0"; diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs index 8c101ad75..696564992 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGen.cs @@ -2,7 +2,7 @@ using Ryujinx.Graphics.Shader.IntermediateRepresentation; using Ryujinx.Graphics.Shader.StructuredIr; using Ryujinx.Graphics.Shader.Translation; using System; - +using System.Text; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenCall; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenHelper; using static Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions.InstGenMemory; @@ -39,11 +39,20 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions int arity = (int)(info.Type & InstType.ArityMask); - string args = string.Empty; + StringBuilder builder = new(); - if (atomic) + if (atomic && (operation.StorageKind == StorageKind.StorageBuffer || operation.StorageKind == StorageKind.SharedMemory)) { - // Hell + builder.Append(GenerateLoadOrStore(context, operation, isStore: false)); + + AggregateType dstType = operation.Inst == Instruction.AtomicMaxS32 || operation.Inst == Instruction.AtomicMinS32 + ? AggregateType.S32 + : AggregateType.U32; + + for (int argIndex = operation.SourcesCount - arity + 2; argIndex < operation.SourcesCount; argIndex++) + { + builder.Append($", {GetSourceExpr(context, operation.GetSource(argIndex), dstType)}"); + } } else { @@ -51,16 +60,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions { if (argIndex != 0) { - args += ", "; + builder.Append(", "); } AggregateType dstType = GetSrcVarType(inst, argIndex); - args += GetSourceExpr(context, operation.GetSource(argIndex), dstType); + builder.Append(GetSourceExpr(context, operation.GetSource(argIndex), dstType)); } } - return info.OpName + '(' + args + ')'; + return $"{info.OpName}({builder})"; } else if ((info.Type & InstType.Op) != 0) { @@ -110,7 +119,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions switch (inst & Instruction.Mask) { case Instruction.Barrier: - return "|| BARRIER ||"; + return "threadgroup_barrier(mem_flags::mem_threadgroup)"; case Instruction.Call: return Call(context, operation); case Instruction.FSIBegin: diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs index 5df3aa282..c063ff458 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs @@ -13,13 +13,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions var functon = context.GetFunction(funcId.Value); int argCount = operation.SourcesCount - 1; - string[] args = new string[argCount + CodeGenContext.AdditionalArgCount]; + int additionalArgCount = CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0); + + string[] args = new string[argCount + additionalArgCount]; // Additional arguments - args[0] = "in"; - args[1] = "support_buffer"; + if (context.Definitions.Stage != ShaderStage.Compute) + { + args[0] = "in"; + args[1] = "support_buffer"; + } + else + { + args[0] = "support_buffer"; + } - int argIndex = CodeGenContext.AdditionalArgCount; + int argIndex = additionalArgCount; for (int i = 0; i < argCount; i++) { args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i)); diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs index 406fda11a..014d070ef 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenHelper.cs @@ -109,10 +109,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions Add(Instruction.ShiftLeft, InstType.OpBinary, "<<", 3); Add(Instruction.ShiftRightS32, InstType.OpBinary, ">>", 3); Add(Instruction.ShiftRightU32, InstType.OpBinary, ">>", 3); - Add(Instruction.Shuffle, InstType.CallQuaternary, "simd_shuffle"); - Add(Instruction.ShuffleDown, InstType.CallQuaternary, "simd_shuffle_down"); - Add(Instruction.ShuffleUp, InstType.CallQuaternary, "simd_shuffle_up"); - Add(Instruction.ShuffleXor, InstType.CallQuaternary, "simd_shuffle_xor"); + Add(Instruction.Shuffle, InstType.CallBinary, "simd_shuffle"); + Add(Instruction.ShuffleDown, InstType.CallBinary, "simd_shuffle_down"); + Add(Instruction.ShuffleUp, InstType.CallBinary, "simd_shuffle_up"); + Add(Instruction.ShuffleXor, InstType.CallBinary, "simd_shuffle_xor"); Add(Instruction.Sine, InstType.CallUnary, "sin"); Add(Instruction.SquareRoot, InstType.CallUnary, "sqrt"); Add(Instruction.Store, InstType.Special); diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs index 135cd80e0..bb1a69939 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenMemory.cs @@ -47,15 +47,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions StructureField field = buffer.Type.Fields[fieldIndex.Value]; varName = buffer.Name; - if ((field.Type & AggregateType.Array) != 0 && field.ArrayLength == 0) - { - // Unsized array, the buffer is indexed instead of the field - fieldName = "." + field.Name; - } - else - { - varName += "->" + field.Name; - } + varName += "->" + field.Name; varType = field.Type; break; diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs index 2e93310aa..1561271d0 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs @@ -27,13 +27,19 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions IoVariable.GlobalId => ("thread_position_in_grid", AggregateType.Vector3 | AggregateType.U32), IoVariable.InstanceId => ("instance_id", AggregateType.S32), IoVariable.InvocationId => ("INVOCATION_ID", AggregateType.S32), - IoVariable.PointCoord => ("point_coord", AggregateType.Vector2), + IoVariable.PointCoord => ("point_coord", AggregateType.Vector2 | AggregateType.FP32), IoVariable.PointSize => ("out.point_size", AggregateType.FP32), IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32), IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32), + IoVariable.SubgroupEqMask => ("thread_index_in_simdgroup >= 32 ? uint4(0, (1 << (thread_index_in_simdgroup - 32)), uint2(0)) : uint4(1 << thread_index_in_simdgroup, uint3(0))", AggregateType.Vector4 | AggregateType.U32), + IoVariable.SubgroupGeMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup, 32 - thread_index_in_simdgroup), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32), + IoVariable.SubgroupGtMask => ("uint4(insert_bits(0u, 0xFFFFFFFF, thread_index_in_simdgroup + 1, 32 - thread_index_in_simdgroup - 1), uint3(0)) & (uint4((uint)((simd_vote::vote_t)simd_ballot(true) & 0xFFFFFFFF), (uint)(((simd_vote::vote_t)simd_ballot(true) >> 32) & 0xFFFFFFFF), 0, 0))", AggregateType.Vector4 | AggregateType.U32), + IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32), + IoVariable.SubgroupLeMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup + 1 - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32), + IoVariable.SubgroupLtMask => ("uint4(extract_bits(0xFFFFFFFF, 0, min(thread_index_in_simdgroup, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)thread_index_in_simdgroup - 32, 0)), uint2(0))", AggregateType.Vector4 | AggregateType.U32), + IoVariable.ThreadKill => ("simd_is_helper_thread()", AggregateType.Bool), IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch), IoVariable.ThreadId => ("thread_position_in_threadgroup", AggregateType.Vector3 | AggregateType.U32), - IoVariable.SubgroupLaneId => ("thread_index_in_simdgroup", AggregateType.U32), IoVariable.VertexId => ("vertex_id", AggregateType.S32), // gl_VertexIndex does not have a direct equivalent in MSL IoVariable.VertexIndex => ("vertex_id", AggregateType.U32), diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs index 87512a961..bb5ea5f6f 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs @@ -44,7 +44,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl context.AppendLine(GetFunctionSignature(context, function, stage, isMainFunc)); context.EnterScope(); - Declarations.DeclareLocals(context, function, stage); + Declarations.DeclareLocals(context, function, stage, isMainFunc); PrintBlock(context, function.MainBlock, isMainFunc); @@ -63,15 +63,22 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl ShaderStage stage, bool isMainFunc = false) { - int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount; + int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0); string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length]; // All non-main functions need to be able to access the support_buffer as well if (!isMainFunc) { - args[0] = "FragmentIn in"; - args[1] = "constant Struct_support_buffer* support_buffer"; + if (stage != ShaderStage.Compute) + { + args[0] = stage == ShaderStage.Vertex ? "VertexIn in" : "FragmentIn in"; + args[1] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer"; + } + else + { + args[0] = $"constant {DefaultNames.StructPrefix}_support_buffer* support_buffer"; + } } int argIndex = additionalArgCount; @@ -141,13 +148,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl foreach (var constantBuffer in context.Properties.ConstantBuffers.Values) { - args = args.Append($"constant Struct_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray(); + args = args.Append($"constant {DefaultNames.StructPrefix}_{constantBuffer.Name}* {constantBuffer.Name} [[buffer({constantBuffer.Binding})]]").ToArray(); } foreach (var storageBuffers in context.Properties.StorageBuffers.Values) { // Offset the binding by 15 to avoid clashing with the constant buffers - args = args.Append($"device Struct_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray(); + args = args.Append($"device {DefaultNames.StructPrefix}_{storageBuffers.Name}* {storageBuffers.Name} [[buffer({storageBuffers.Binding + 15})]]").ToArray(); } foreach (var texture in context.Properties.Textures.Values) @@ -162,7 +169,10 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl } } - return $"{funcKeyword} {returnType} {funcName ?? function.Name}({string.Join(", ", args)})"; + var funcPrefix = $"{funcKeyword} {returnType} {funcName ?? function.Name}("; + var indent = new string(' ', funcPrefix.Length); + + return $"{funcPrefix}{string.Join($", \n{indent}", args)})"; } private static void PrintBlock(CodeGenContext context, AstBlock block, bool isMainFunction) diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs index f086e7436..63ecbc0aa 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/NumberFormatter.cs @@ -10,25 +10,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public static bool TryFormat(int value, AggregateType dstType, out string formatted) { - if (dstType == AggregateType.FP32) + switch (dstType) { - return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted); - } - else if (dstType == AggregateType.S32) - { - formatted = FormatInt(value); - } - else if (dstType == AggregateType.U32) - { - formatted = FormatUint((uint)value); - } - else if (dstType == AggregateType.Bool) - { - formatted = value != 0 ? "true" : "false"; - } - else - { - throw new ArgumentException($"Invalid variable type \"{dstType}\"."); + case AggregateType.FP32: + return TryFormatFloat(BitConverter.Int32BitsToSingle(value), out formatted); + case AggregateType.S32: + formatted = FormatInt(value); + break; + case AggregateType.U32: + formatted = FormatUint((uint)value); + break; + case AggregateType.Bool: + formatted = value != 0 ? "true" : "false"; + break; + default: + throw new ArgumentException($"Invalid variable type \"{dstType}\"."); } return true; @@ -65,18 +61,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public static string FormatInt(int value, AggregateType dstType) { - if (dstType == AggregateType.S32) + return dstType switch { - return FormatInt(value); - } - else if (dstType == AggregateType.U32) - { - return FormatUint((uint)value); - } - else - { - throw new ArgumentException($"Invalid variable type \"{dstType}\"."); - } + AggregateType.S32 => FormatInt(value), + AggregateType.U32 => FormatUint((uint)value), + _ => throw new ArgumentException($"Invalid variable type \"{dstType}\".") + }; } public static string FormatInt(int value)