forked from ShuriZma/suyu
shader/other: Implement thread comparisons (NV_shader_thread_group)
Hardware S2R special registers match gl_Thread*MaskNV. We can trivially implement these using Nvidia's extension on OpenGL or naively stubbing them with the ARB instructions to match. This might cause issues if the host device warp size doesn't match Nvidia's. That said, this is unlikely on proper shaders. Refer to the attached url for more documentation about these flags. https://www.khronos.org/registry/OpenGL/extensions/NV/NV_shader_thread_group.txt
This commit is contained in:
parent
cf4ee279c6
commit
e2b67a868b
|
@ -2309,6 +2309,18 @@ private:
|
||||||
return {"gl_SubGroupInvocationARB", Type::Uint};
|
return {"gl_SubGroupInvocationARB", Type::Uint};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <const std::string_view& comparison>
|
||||||
|
Expression ThreadMask(Operation) {
|
||||||
|
if (device.HasWarpIntrinsics()) {
|
||||||
|
return {fmt::format("gl_Thread{}MaskNV", comparison), Type::Uint};
|
||||||
|
}
|
||||||
|
if (device.HasShaderBallot()) {
|
||||||
|
return {fmt::format("uint(gl_SubGroup{}MaskARB)", comparison), Type::Uint};
|
||||||
|
}
|
||||||
|
LOG_ERROR(Render_OpenGL, "Thread mask intrinsics are required by the shader");
|
||||||
|
return {"0U", Type::Uint};
|
||||||
|
}
|
||||||
|
|
||||||
Expression ShuffleIndexed(Operation operation) {
|
Expression ShuffleIndexed(Operation operation) {
|
||||||
std::string value = VisitOperand(operation, 0).AsFloat();
|
std::string value = VisitOperand(operation, 0).AsFloat();
|
||||||
|
|
||||||
|
@ -2337,6 +2349,12 @@ private:
|
||||||
static constexpr std::string_view NotEqual = "!=";
|
static constexpr std::string_view NotEqual = "!=";
|
||||||
static constexpr std::string_view GreaterEqual = ">=";
|
static constexpr std::string_view GreaterEqual = ">=";
|
||||||
|
|
||||||
|
static constexpr std::string_view Eq = "Eq";
|
||||||
|
static constexpr std::string_view Ge = "Ge";
|
||||||
|
static constexpr std::string_view Gt = "Gt";
|
||||||
|
static constexpr std::string_view Le = "Le";
|
||||||
|
static constexpr std::string_view Lt = "Lt";
|
||||||
|
|
||||||
static constexpr std::string_view Add = "Add";
|
static constexpr std::string_view Add = "Add";
|
||||||
static constexpr std::string_view Min = "Min";
|
static constexpr std::string_view Min = "Min";
|
||||||
static constexpr std::string_view Max = "Max";
|
static constexpr std::string_view Max = "Max";
|
||||||
|
@ -2554,6 +2572,11 @@ private:
|
||||||
&GLSLDecompiler::VoteEqual,
|
&GLSLDecompiler::VoteEqual,
|
||||||
|
|
||||||
&GLSLDecompiler::ThreadId,
|
&GLSLDecompiler::ThreadId,
|
||||||
|
&GLSLDecompiler::ThreadMask<Func::Eq>,
|
||||||
|
&GLSLDecompiler::ThreadMask<Func::Ge>,
|
||||||
|
&GLSLDecompiler::ThreadMask<Func::Gt>,
|
||||||
|
&GLSLDecompiler::ThreadMask<Func::Le>,
|
||||||
|
&GLSLDecompiler::ThreadMask<Func::Lt>,
|
||||||
&GLSLDecompiler::ShuffleIndexed,
|
&GLSLDecompiler::ShuffleIndexed,
|
||||||
|
|
||||||
&GLSLDecompiler::MemoryBarrierGL,
|
&GLSLDecompiler::MemoryBarrierGL,
|
||||||
|
|
|
@ -515,6 +515,16 @@ private:
|
||||||
void DeclareCommon() {
|
void DeclareCommon() {
|
||||||
thread_id =
|
thread_id =
|
||||||
DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id");
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupLocalInvocationId, t_in_uint, "thread_id");
|
||||||
|
thread_masks[0] =
|
||||||
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupEqMask, t_in_uint4, "thread_eq_mask");
|
||||||
|
thread_masks[1] =
|
||||||
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupGeMask, t_in_uint4, "thread_ge_mask");
|
||||||
|
thread_masks[2] =
|
||||||
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupGtMask, t_in_uint4, "thread_gt_mask");
|
||||||
|
thread_masks[3] =
|
||||||
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupLeMask, t_in_uint4, "thread_le_mask");
|
||||||
|
thread_masks[4] =
|
||||||
|
DeclareInputBuiltIn(spv::BuiltIn::SubgroupLtMask, t_in_uint4, "thread_lt_mask");
|
||||||
}
|
}
|
||||||
|
|
||||||
void DeclareVertex() {
|
void DeclareVertex() {
|
||||||
|
@ -2175,6 +2185,13 @@ private:
|
||||||
return {OpLoad(t_uint, thread_id), Type::Uint};
|
return {OpLoad(t_uint, thread_id), Type::Uint};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <std::size_t index>
|
||||||
|
Expression ThreadMask(Operation) {
|
||||||
|
// TODO(Rodrigo): Handle devices with different warp sizes
|
||||||
|
const Id mask = thread_masks[index];
|
||||||
|
return {OpLoad(t_uint, AccessElement(t_in_uint, mask, 0)), Type::Uint};
|
||||||
|
}
|
||||||
|
|
||||||
Expression ShuffleIndexed(Operation operation) {
|
Expression ShuffleIndexed(Operation operation) {
|
||||||
const Id value = AsFloat(Visit(operation[0]));
|
const Id value = AsFloat(Visit(operation[0]));
|
||||||
const Id index = AsUint(Visit(operation[1]));
|
const Id index = AsUint(Visit(operation[1]));
|
||||||
|
@ -2639,6 +2656,11 @@ private:
|
||||||
&SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>,
|
&SPIRVDecompiler::Vote<&Module::OpSubgroupAllEqualKHR>,
|
||||||
|
|
||||||
&SPIRVDecompiler::ThreadId,
|
&SPIRVDecompiler::ThreadId,
|
||||||
|
&SPIRVDecompiler::ThreadMask<0>, // Eq
|
||||||
|
&SPIRVDecompiler::ThreadMask<1>, // Ge
|
||||||
|
&SPIRVDecompiler::ThreadMask<2>, // Gt
|
||||||
|
&SPIRVDecompiler::ThreadMask<3>, // Le
|
||||||
|
&SPIRVDecompiler::ThreadMask<4>, // Lt
|
||||||
&SPIRVDecompiler::ShuffleIndexed,
|
&SPIRVDecompiler::ShuffleIndexed,
|
||||||
|
|
||||||
&SPIRVDecompiler::MemoryBarrierGL,
|
&SPIRVDecompiler::MemoryBarrierGL,
|
||||||
|
@ -2763,6 +2785,7 @@ private:
|
||||||
Id workgroup_id{};
|
Id workgroup_id{};
|
||||||
Id local_invocation_id{};
|
Id local_invocation_id{};
|
||||||
Id thread_id{};
|
Id thread_id{};
|
||||||
|
std::array<Id, 5> thread_masks{}; // eq, ge, gt, le, lt
|
||||||
|
|
||||||
VertexIndices in_indices;
|
VertexIndices in_indices;
|
||||||
VertexIndices out_indices;
|
VertexIndices out_indices;
|
||||||
|
|
|
@ -109,6 +109,27 @@ u32 ShaderIR::DecodeOther(NodeBlock& bb, u32 pc) {
|
||||||
return Operation(OperationCode::WorkGroupIdY);
|
return Operation(OperationCode::WorkGroupIdY);
|
||||||
case SystemVariable::CtaIdZ:
|
case SystemVariable::CtaIdZ:
|
||||||
return Operation(OperationCode::WorkGroupIdZ);
|
return Operation(OperationCode::WorkGroupIdZ);
|
||||||
|
case SystemVariable::EqMask:
|
||||||
|
case SystemVariable::LtMask:
|
||||||
|
case SystemVariable::LeMask:
|
||||||
|
case SystemVariable::GtMask:
|
||||||
|
case SystemVariable::GeMask:
|
||||||
|
uses_warps = true;
|
||||||
|
switch (instr.sys20) {
|
||||||
|
case SystemVariable::EqMask:
|
||||||
|
return Operation(OperationCode::ThreadEqMask);
|
||||||
|
case SystemVariable::LtMask:
|
||||||
|
return Operation(OperationCode::ThreadLtMask);
|
||||||
|
case SystemVariable::LeMask:
|
||||||
|
return Operation(OperationCode::ThreadLeMask);
|
||||||
|
case SystemVariable::GtMask:
|
||||||
|
return Operation(OperationCode::ThreadGtMask);
|
||||||
|
case SystemVariable::GeMask:
|
||||||
|
return Operation(OperationCode::ThreadGeMask);
|
||||||
|
default:
|
||||||
|
UNREACHABLE();
|
||||||
|
return Immediate(0u);
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
UNIMPLEMENTED_MSG("Unhandled system move: {}",
|
UNIMPLEMENTED_MSG("Unhandled system move: {}",
|
||||||
static_cast<u32>(instr.sys20.Value()));
|
static_cast<u32>(instr.sys20.Value()));
|
||||||
|
|
|
@ -226,6 +226,11 @@ enum class OperationCode {
|
||||||
VoteEqual, /// (bool) -> bool
|
VoteEqual, /// (bool) -> bool
|
||||||
|
|
||||||
ThreadId, /// () -> uint
|
ThreadId, /// () -> uint
|
||||||
|
ThreadEqMask, /// () -> uint
|
||||||
|
ThreadGeMask, /// () -> uint
|
||||||
|
ThreadGtMask, /// () -> uint
|
||||||
|
ThreadLeMask, /// () -> uint
|
||||||
|
ThreadLtMask, /// () -> uint
|
||||||
ShuffleIndexed, /// (uint value, uint index) -> uint
|
ShuffleIndexed, /// (uint value, uint index) -> uint
|
||||||
|
|
||||||
MemoryBarrierGL, /// () -> void
|
MemoryBarrierGL, /// () -> void
|
||||||
|
|
Loading…
Reference in New Issue