diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs index 7428c3a31..3c9859a2f 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs @@ -112,12 +112,12 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl switch (stage) { case ShaderStage.Vertex: - context.AppendLine("VertexOut out;"); + context.AppendLine("VertexOut out = {};"); // TODO: Only add if necessary context.AppendLine("uint instance_index = instance_id + base_instance;"); break; case ShaderStage.Fragment: - context.AppendLine("FragmentOut out;"); + context.AppendLine("FragmentOut out = {};"); break; } @@ -420,6 +420,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable.PointSize => "float", IoVariable.FragmentOutputColor => GetVarTypeName(context.Definitions.GetFragmentOutputColorType(ioDefinition.Location)), IoVariable.FragmentOutputDepth => "float", + IoVariable.ClipDistance => "float", _ => GetVarTypeName(context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: true)) }; string name = ioDefinition.IoVariable switch @@ -428,6 +429,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable.PointSize => "point_size", IoVariable.FragmentOutputColor => $"color{ioDefinition.Location}", IoVariable.FragmentOutputDepth => "depth", + IoVariable.ClipDistance => "clip_distance", _ => $"{Defaults.OAttributePrefix}{ioDefinition.Location}" }; string suffix = ioDefinition.IoVariable switch @@ -437,6 +439,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl IoVariable.UserDefined => $"[[user(loc{ioDefinition.Location})]]", IoVariable.FragmentOutputColor => $"[[color({ioDefinition.Location})]]", IoVariable.FragmentOutputDepth => "[[depth(any)]]", + IoVariable.ClipDistance => $"[[clip_distance]][{Defaults.TotalClipDistances}]", _ => "" }; diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Defaults.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Defaults.cs index f43f5f255..a78de36ce 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Defaults.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Defaults.cs @@ -22,5 +22,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl public const uint StorageBuffersIndex = 21; public const uint TexturesIndex = 22; public const uint ImagesIndex = 23; + + public const int TotalClipDistances = 8; } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs index 0bad36f73..44881deee 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/InstGenCall.cs @@ -10,10 +10,18 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions { AstOperand funcId = (AstOperand)operation.GetSource(0); - var functon = context.GetFunction(funcId.Value); + var function = context.GetFunction(funcId.Value); int argCount = operation.SourcesCount - 1; int additionalArgCount = CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0); + bool needsThreadIndex = false; + + // TODO: Replace this with a proper flag + if (function.Name.Contains("Shuffle")) + { + needsThreadIndex = true; + additionalArgCount++; + } string[] args = new string[argCount + additionalArgCount]; @@ -23,20 +31,30 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions args[0] = "in"; args[1] = "constant_buffers"; args[2] = "storage_buffers"; + + if (needsThreadIndex) + { + args[3] = "thread_index_in_simdgroup"; + } } else { args[0] = "constant_buffers"; args[1] = "storage_buffers"; + + if (needsThreadIndex) + { + args[2] = "thread_index_in_simdgroup"; + } } int argIndex = additionalArgCount; for (int i = 0; i < argCount; i++) { - args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), functon.GetArgumentType(i)); + args[argIndex++] = GetSourceExpr(context, operation.GetSource(i + 1), function.GetArgumentType(i)); } - return $"{functon.Name}({string.Join(", ", args)})"; + return $"{function.Name}({string.Join(", ", args)})"; } } } diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs index bd7a15e2f..57c180fb4 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs @@ -20,7 +20,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions IoVariable.BaseInstance => ("base_instance", AggregateType.U32), IoVariable.BaseVertex => ("base_vertex", AggregateType.U32), IoVariable.CtaId => ("threadgroup_position_in_grid", AggregateType.Vector3 | AggregateType.U32), - IoVariable.ClipDistance => ("clip_distance", AggregateType.Array | AggregateType.FP32), + IoVariable.ClipDistance => ("out.clip_distance", AggregateType.Array | AggregateType.FP32), IoVariable.FragmentOutputColor => ($"out.color{location}", definitions.GetFragmentOutputColorType(location)), IoVariable.FragmentOutputDepth => ("out.depth", AggregateType.FP32), IoVariable.FrontFacing => ("in.front_facing", AggregateType.Bool), diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs index 757abffdc..28a69c508 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs @@ -64,6 +64,14 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl bool isMainFunc = false) { int additionalArgCount = isMainFunc ? 0 : CodeGenContext.AdditionalArgCount + (context.Definitions.Stage != ShaderStage.Compute ? 1 : 0); + bool needsThreadIndex = false; + + // TODO: Replace this with a proper flag + if (function.Name.Contains("Shuffle")) + { + needsThreadIndex = true; + additionalArgCount++; + } string[] args = new string[additionalArgCount + function.InArguments.Length + function.OutArguments.Length]; @@ -75,11 +83,21 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl args[0] = stage == ShaderStage.Vertex ? "VertexIn in" : "FragmentIn in"; args[1] = "constant ConstantBuffers &constant_buffers"; args[2] = "device StorageBuffers &storage_buffers"; + + if (needsThreadIndex) + { + args[3] = "uint thread_index_in_simdgroup"; + } } else { args[0] = "constant ConstantBuffers &constant_buffers"; args[1] = "device StorageBuffers &storage_buffers"; + + if (needsThreadIndex) + { + args[2] = "uint thread_index_in_simdgroup"; + } } } @@ -93,8 +111,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl { int j = i + function.InArguments.Length; - // Likely need to be made into pointers - args[argIndex++] = $"out {Declarations.GetVarTypeName(function.OutArguments[i])} {OperandManager.GetArgumentName(j)}"; + args[argIndex++] = $"thread {Declarations.GetVarTypeName(function.OutArguments[i])} &{OperandManager.GetArgumentName(j)}"; } string funcKeyword = "inline";