diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs index 98e681f3a..4580df203 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Declarations.cs @@ -65,11 +65,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl { if (stage == ShaderStage.Vertex) { - context.AppendLine("VertexOutput out;"); + context.AppendLine("VertexOut out;"); } else if (stage == ShaderStage.Fragment) { - context.AppendLine("FragmentOutput out;"); + context.AppendLine("FragmentOut out;"); } foreach (AstOperand decl in function.Locals) @@ -120,17 +120,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl switch (context.Definitions.Stage) { case ShaderStage.Vertex: - prefix = "Vertex"; + context.AppendLine($"struct VertexIn"); break; case ShaderStage.Fragment: - prefix = "Fragment"; + context.AppendLine($"struct VertexOut"); break; case ShaderStage.Compute: - prefix = "Compute"; + context.AppendLine($"struct ComputeIn"); break; } - context.AppendLine($"struct {prefix}In"); context.EnterScope(); foreach (var ioDefinition in inputs.OrderBy(x => x.Location)) @@ -162,31 +161,38 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl switch (context.Definitions.Stage) { case ShaderStage.Vertex: - prefix = "Vertex"; + context.AppendLine($"struct VertexOut"); break; case ShaderStage.Fragment: - prefix = "Fragment"; + context.AppendLine($"struct FragmentOut"); break; case ShaderStage.Compute: - prefix = "Compute"; + context.AppendLine($"struct ComputeOut"); break; } - context.AppendLine($"struct {prefix}Output"); context.EnterScope(); foreach (var ioDefinition in inputs.OrderBy(x => x.Location)) { - string type = GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: true)); + string type = ioDefinition.IoVariable switch + { + IoVariable.Position => "float4", + IoVariable.PointSize => "float", + _ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: true)) + }; string name = ioDefinition.IoVariable switch { IoVariable.Position => "position", + IoVariable.PointSize => "point_size", IoVariable.FragmentOutputColor => "color", _ => $"{DefaultNames.OAttributePrefix}{ioDefinition.Location}" }; string suffix = ioDefinition.IoVariable switch { IoVariable.Position => " [[position]]", + IoVariable.PointSize => " [[point_size]]", + IoVariable.FragmentOutputColor => $" [[color({ioDefinition.Location})]]", _ => "" }; diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs index 5c7800de8..2ec7a1779 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/Instructions/IoMap.cs @@ -24,7 +24,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions IoVariable.FrontFacing => ("front_facing", AggregateType.Bool), IoVariable.InstanceId => ("instance_id", AggregateType.S32), IoVariable.PointCoord => ("point_coord", AggregateType.Vector2), - IoVariable.PointSize => ("point_size", AggregateType.FP32), + IoVariable.PointSize => ("out.point_size", AggregateType.FP32), IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32), IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32), IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch), diff --git a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs index 70f4d5080..0e56629fe 100644 --- a/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs +++ b/src/Ryujinx.Graphics.Shader/CodeGen/Msl/MslGenerator.cs @@ -85,13 +85,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl { funcKeyword = "vertex"; funcName = "vertexMain"; - returnType = "VertexOutput"; + returnType = "VertexOut"; } else if (stage == ShaderStage.Fragment) { funcKeyword = "fragment"; funcName = "fragmentMain"; - returnType = "FragmentOutput"; + returnType = "FragmentOut"; } else if (stage == ShaderStage.Compute) { @@ -106,7 +106,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl } else if (stage == ShaderStage.Fragment) { - args = args.Prepend("FragmentIn in").ToArray(); + args = args.Prepend("VertexOut in").ToArray(); } else if (stage == ShaderStage.Compute) {