[SPIR-V] Loops

This commit is contained in:
Triang3l 2020-10-25 20:24:48 +03:00
parent a5410ada01
commit 556c8de2ab
2 changed files with 279 additions and 19 deletions

View File

@ -54,6 +54,7 @@ void SpirvShaderTranslator::StartTranslation() {
type_int_ = builder_->makeIntType(32);
type_int4_ = builder_->makeVectorType(type_int_, 4);
type_uint_ = builder_->makeUintType(32);
type_uint3_ = builder_->makeVectorType(type_uint_, 3);
type_uint4_ = builder_->makeVectorType(type_uint_, 4);
type_float_ = builder_->makeFloatType(32);
type_float2_ = builder_->makeVectorType(type_float_, 2);
@ -61,13 +62,20 @@ void SpirvShaderTranslator::StartTranslation() {
type_float4_ = builder_->makeVectorType(type_float_, 4);
const_int_0_ = builder_->makeIntConstant(0);
const_uint_0_ = builder_->makeUintConstant(0);
id_vector_temp_.clear();
id_vector_temp_.reserve(4);
for (uint32_t i = 0; i < 4; ++i) {
id_vector_temp_.push_back(const_int_0_);
}
const_int4_0_ = builder_->makeCompositeConstant(type_int4_, id_vector_temp_);
const_uint_0_ = builder_->makeUintConstant(0);
id_vector_temp_.clear();
id_vector_temp_.reserve(4);
for (uint32_t i = 0; i < 4; ++i) {
id_vector_temp_.push_back(const_uint_0_);
}
const_uint4_0_ =
builder_->makeCompositeConstant(type_uint4_, id_vector_temp_);
const_float_0_ = builder_->makeFloatConstant(0.0f);
id_vector_temp_.clear();
id_vector_temp_.reserve(4);
@ -128,6 +136,9 @@ void SpirvShaderTranslator::StartTranslation() {
var_main_predicate_ = builder_->createVariable(
spv::NoPrecision, spv::StorageClassFunction, type_bool_,
"xe_var_predicate", builder_->makeBoolConstant(false));
var_main_loop_count_ = builder_->createVariable(
spv::NoPrecision, spv::StorageClassFunction, type_uint4_,
"xe_var_loop_count", const_uint4_0_);
var_main_address_absolute_ = builder_->createVariable(
spv::NoPrecision, spv::StorageClassFunction, type_int_,
"xe_var_address_absolute", const_int_0_);
@ -179,6 +190,7 @@ void SpirvShaderTranslator::StartTranslation() {
builder_->setBuildPoint(main_loop_header_);
spv::Id main_loop_pc_current = 0;
if (has_main_switch) {
// OpPhi must be the first in the block.
id_vector_temp_.clear();
id_vector_temp_.reserve(4);
id_vector_temp_.push_back(const_int_0_);
@ -259,6 +271,7 @@ std::vector<uint8_t> SpirvShaderTranslator::CompleteTranslation() {
function_main_->addBlock(main_loop_continue_);
builder_->setBuildPoint(main_loop_continue_);
if (has_main_switch) {
// OpPhi, if added, must be the first in the block.
// If labels were added, but not jumps (for example, due to the call
// instruction not being implemented as of October 18, 2020), send an
// impossible program counter value (-1) to the OpPhi at the next iteration.
@ -367,6 +380,253 @@ void SpirvShaderTranslator::ProcessExecInstructionEnd(
instr.condition);
}
void SpirvShaderTranslator::ProcessLoopStartInstruction(
const ParsedLoopStartInstruction& instr) {
// loop il<idx>, L<idx> - loop with loop data il<idx>, end @ L<idx>
// Loop control is outside execs - actually close the last exec.
CloseExecConditionals();
EnsureBuildPointAvailable();
id_vector_temp_.clear();
id_vector_temp_.reserve(3);
// Loop constants (member 1).
id_vector_temp_.push_back(builder_->makeIntConstant(1));
// 4-component vector.
id_vector_temp_.push_back(
builder_->makeIntConstant(int(instr.loop_constant_index >> 2)));
// Scalar within the vector.
id_vector_temp_.push_back(
builder_->makeIntConstant(int(instr.loop_constant_index & 3)));
// Count (unsigned) in bits 0:7 of the loop constant (struct member 1),
// initial aL (unsigned) in 8:15.
spv::Id loop_constant =
builder_->createLoad(builder_->createAccessChain(
spv::StorageClassUniform,
uniform_bool_loop_constants_, id_vector_temp_),
spv::NoPrecision);
spv::Id const_int_8 = builder_->makeIntConstant(8);
// Push the count to the loop count stack - move XYZ to YZW and set X to the
// new iteration count (swizzling the way glslang does it for similar GLSL).
spv::Id loop_count_stack_old =
builder_->createLoad(var_main_loop_count_, spv::NoPrecision);
spv::Id loop_count_new =
builder_->createTriOp(spv::OpBitFieldUExtract, type_uint_, loop_constant,
const_int_0_, const_int_8);
id_vector_temp_.clear();
id_vector_temp_.reserve(4);
id_vector_temp_.push_back(loop_count_new);
for (unsigned int i = 0; i < 3; ++i) {
id_vector_temp_.push_back(
builder_->createCompositeExtract(loop_count_stack_old, type_uint_, i));
}
builder_->createStore(
builder_->createCompositeConstruct(type_uint4_, id_vector_temp_),
var_main_loop_count_);
// Push aL - keep the same value as in the previous loop if repeating, or the
// new one otherwise.
spv::Id address_relative_stack_old =
builder_->createLoad(var_main_address_relative_, spv::NoPrecision);
id_vector_temp_.clear();
id_vector_temp_.reserve(4);
if (instr.is_repeat) {
id_vector_temp_.emplace_back();
} else {
id_vector_temp_.push_back(builder_->createUnaryOp(
spv::OpBitcast, type_int_,
builder_->createTriOp(spv::OpBitFieldUExtract, type_uint_,
loop_constant, const_int_8, const_int_8)));
}
for (unsigned int i = 0; i < 3; ++i) {
id_vector_temp_.push_back(builder_->createCompositeExtract(
address_relative_stack_old, type_int_, i));
}
if (instr.is_repeat) {
id_vector_temp_[0] = id_vector_temp_[1];
}
builder_->createStore(
builder_->createCompositeConstruct(type_int4_, id_vector_temp_),
var_main_address_relative_);
// Break (jump to the skip label) if the loop counter is 0 (since the
// condition is checked in the end).
spv::Block& head_block = *builder_->getBuildPoint();
spv::Id loop_count_zero = builder_->createBinOp(
spv::OpIEqual, type_bool_, loop_count_new, const_uint_0_);
spv::Block& skip_block = builder_->makeNewBlock();
spv::Block& body_block = builder_->makeNewBlock();
{
std::unique_ptr<spv::Instruction> selection_merge_op =
std::make_unique<spv::Instruction>(spv::OpSelectionMerge);
selection_merge_op->addIdOperand(body_block.getId());
selection_merge_op->addImmediateOperand(spv::SelectionControlMaskNone);
head_block.addInstruction(std::move(selection_merge_op));
}
{
std::unique_ptr<spv::Instruction> branch_conditional_op =
std::make_unique<spv::Instruction>(spv::OpBranchConditional);
branch_conditional_op->addIdOperand(loop_count_zero);
branch_conditional_op->addIdOperand(skip_block.getId());
branch_conditional_op->addIdOperand(body_block.getId());
// More likely to enter than to skip.
branch_conditional_op->addImmediateOperand(1);
branch_conditional_op->addImmediateOperand(2);
head_block.addInstruction(std::move(branch_conditional_op));
}
skip_block.addPredecessor(&head_block);
body_block.addPredecessor(&head_block);
builder_->setBuildPoint(&skip_block);
main_switch_next_pc_phi_operands_.push_back(
builder_->makeIntConstant(int(instr.loop_skip_address)));
main_switch_next_pc_phi_operands_.push_back(
builder_->getBuildPoint()->getId());
builder_->createBranch(main_loop_continue_);
builder_->setBuildPoint(&body_block);
}
void SpirvShaderTranslator::ProcessLoopEndInstruction(
const ParsedLoopEndInstruction& instr) {
// endloop il<idx>, L<idx> - end loop w/ data il<idx>, head @ L<idx>
// Loop control is outside execs - actually close the last exec.
CloseExecConditionals();
EnsureBuildPointAvailable();
// Subtract 1 from the loop counter (will store later).
spv::Id loop_count_stack_old =
builder_->createLoad(var_main_loop_count_, spv::NoPrecision);
spv::Id loop_count = builder_->createBinOp(
spv::OpISub, type_uint_,
builder_->createCompositeExtract(loop_count_stack_old, type_uint_, 0),
builder_->makeUintConstant(1));
spv::Id address_relative_stack_old =
builder_->createLoad(var_main_address_relative_, spv::NoPrecision);
// Predicated break works like break if (loop_count == 0 || [!]p0).
// Three options, due to logical operations usage (so OpLogicalNot is not
// required):
// - Continue if (loop_count != 0).
// - Continue if (loop_count != 0 && p0), if breaking if !p0.
// - Break if (loop_count == 0 || p0), if breaking if p0.
bool break_is_true = instr.is_predicated_break && instr.predicate_condition;
spv::Id condition =
builder_->createBinOp(break_is_true ? spv::OpIEqual : spv::OpINotEqual,
type_bool_, loop_count, const_uint_0_);
if (instr.is_predicated_break) {
condition = builder_->createBinOp(
instr.predicate_condition ? spv::OpLogicalOr : spv::OpLogicalAnd,
type_bool_, condition,
builder_->createLoad(var_main_predicate_, spv::NoPrecision));
}
spv::Block& body_block = *builder_->getBuildPoint();
spv::Block& continue_block = builder_->makeNewBlock();
spv::Block& break_block = builder_->makeNewBlock();
{
std::unique_ptr<spv::Instruction> selection_merge_op =
std::make_unique<spv::Instruction>(spv::OpSelectionMerge);
selection_merge_op->addIdOperand(break_block.getId());
selection_merge_op->addImmediateOperand(spv::SelectionControlMaskNone);
body_block.addInstruction(std::move(selection_merge_op));
}
{
std::unique_ptr<spv::Instruction> branch_conditional_op =
std::make_unique<spv::Instruction>(spv::OpBranchConditional);
branch_conditional_op->addIdOperand(condition);
// More likely to continue than to break.
if (break_is_true) {
branch_conditional_op->addIdOperand(break_block.getId());
branch_conditional_op->addIdOperand(continue_block.getId());
branch_conditional_op->addImmediateOperand(1);
branch_conditional_op->addImmediateOperand(2);
} else {
branch_conditional_op->addIdOperand(continue_block.getId());
branch_conditional_op->addIdOperand(break_block.getId());
branch_conditional_op->addImmediateOperand(2);
branch_conditional_op->addImmediateOperand(1);
}
body_block.addInstruction(std::move(branch_conditional_op));
}
continue_block.addPredecessor(&body_block);
break_block.addPredecessor(&body_block);
// Continue case.
builder_->setBuildPoint(&continue_block);
// Store the loop count with 1 subtracted.
builder_->createStore(builder_->createCompositeInsert(
loop_count, loop_count_stack_old, type_uint4_, 0),
var_main_loop_count_);
// Extract the value to add to aL (signed, in bits 16:23 of the loop
// constant).
id_vector_temp_.clear();
id_vector_temp_.reserve(3);
// Loop constants (member 1).
id_vector_temp_.push_back(builder_->makeIntConstant(1));
// 4-component vector.
id_vector_temp_.push_back(
builder_->makeIntConstant(int(instr.loop_constant_index >> 2)));
// Scalar within the vector.
id_vector_temp_.push_back(
builder_->makeIntConstant(int(instr.loop_constant_index & 3)));
spv::Id loop_constant =
builder_->createLoad(builder_->createAccessChain(
spv::StorageClassUniform,
uniform_bool_loop_constants_, id_vector_temp_),
spv::NoPrecision);
spv::Id address_relative_old = builder_->createCompositeExtract(
address_relative_stack_old, type_int_, 0);
builder_->createStore(
builder_->createCompositeInsert(
builder_->createBinOp(
spv::OpIAdd, type_int_, address_relative_old,
builder_->createTriOp(
spv::OpBitFieldSExtract, type_int_,
builder_->createUnaryOp(spv::OpBitcast, type_int_,
loop_constant),
builder_->makeIntConstant(16), builder_->makeIntConstant(8))),
address_relative_stack_old, type_int4_, 0),
var_main_address_relative_);
// Jump back to the beginning of the loop body.
main_switch_next_pc_phi_operands_.push_back(
builder_->makeIntConstant(int(instr.loop_body_address)));
main_switch_next_pc_phi_operands_.push_back(
builder_->getBuildPoint()->getId());
builder_->createBranch(main_loop_continue_);
// Break case.
builder_->setBuildPoint(&break_block);
// Pop the current loop off the loop counter and the relative address stacks -
// move YZW to XYZ and set W to 0.
id_vector_temp_.clear();
id_vector_temp_.reserve(4);
for (unsigned int i = 1; i < 4; ++i) {
id_vector_temp_.push_back(
builder_->createCompositeExtract(loop_count_stack_old, type_uint_, i));
}
id_vector_temp_.push_back(const_uint_0_);
builder_->createStore(
builder_->createCompositeConstruct(type_uint4_, id_vector_temp_),
var_main_loop_count_);
id_vector_temp_.clear();
id_vector_temp_.reserve(4);
for (unsigned int i = 1; i < 4; ++i) {
id_vector_temp_.push_back(builder_->createCompositeExtract(
address_relative_stack_old, type_int_, i));
}
id_vector_temp_.push_back(const_int_0_);
builder_->createStore(
builder_->createCompositeConstruct(type_int4_, id_vector_temp_),
var_main_address_relative_);
id_vector_temp_.clear();
id_vector_temp_.reserve(4);
// Now going to fall through to the next control flow instruction.
}
void SpirvShaderTranslator::ProcessJumpInstruction(
const ParsedJumpInstruction& instr) {
// Treat like exec, merge with execs if possible, since it's an if too.
@ -386,7 +646,15 @@ void SpirvShaderTranslator::ProcessJumpInstruction(
// on the control flow level too.
CloseInstructionPredication();
JumpToLabel(instr.target_address);
if (builder_->getBuildPoint()->isTerminated()) {
// Unreachable for some reason.
return;
}
main_switch_next_pc_phi_operands_.push_back(
builder_->makeIntConstant(int(instr.target_address)));
main_switch_next_pc_phi_operands_.push_back(
builder_->getBuildPoint()->getId());
builder_->createBranch(main_loop_continue_);
}
void SpirvShaderTranslator::EnsureBuildPointAvailable() {
@ -520,7 +788,7 @@ void SpirvShaderTranslator::UpdateExecConditionals(
builder_->makeIntConstant(int(bool_constant_index >> 7)));
// 32-bit scalar of a 128-bit vector.
id_vector_temp_.push_back(
builder_->makeIntConstant(int((bool_constant_index >> 5) & 2)));
builder_->makeIntConstant(int((bool_constant_index >> 5) & 3)));
spv::Id bool_constant_scalar =
builder_->createLoad(builder_->createAccessChain(
spv::StorageClassUniform,
@ -589,18 +857,5 @@ void SpirvShaderTranslator::CloseExecConditionals() {
cf_exec_predicate_written_ = false;
}
void SpirvShaderTranslator::JumpToLabel(uint32_t address) {
assert_false(label_addresses().empty());
spv::Block& origin_block = *builder_->getBuildPoint();
if (origin_block.isTerminated()) {
// Unreachable jump for some reason.
return;
}
main_switch_next_pc_phi_operands_.push_back(
builder_->makeIntConstant(int(address)));
main_switch_next_pc_phi_operands_.push_back(origin_block.getId());
builder_->createBranch(main_loop_continue_);
}
} // namespace gpu
} // namespace xe

View File

@ -60,6 +60,10 @@ class SpirvShaderTranslator : public ShaderTranslator {
void ProcessExecInstructionBegin(const ParsedExecInstruction& instr) override;
void ProcessExecInstructionEnd(const ParsedExecInstruction& instr) override;
void ProcessLoopStartInstruction(
const ParsedLoopStartInstruction& instr) override;
void ProcessLoopEndInstruction(
const ParsedLoopEndInstruction& instr) override;
void ProcessJumpInstruction(const ParsedJumpInstruction& instr) override;
private:
@ -99,9 +103,6 @@ class SpirvShaderTranslator : public ShaderTranslator {
// Closes conditionals opened by exec and instructions within them (but not by
// labels) and updates the state accordingly.
void CloseExecConditionals();
// Sets the next iteration's program counter value (adding it to phi operands)
// and closes the current block.
void JumpToLabel(uint32_t address);
bool supports_clip_distance_;
bool supports_cull_distance_;
@ -118,6 +119,7 @@ class SpirvShaderTranslator : public ShaderTranslator {
spv::Id type_int_;
spv::Id type_int4_;
spv::Id type_uint_;
spv::Id type_uint3_;
spv::Id type_uint4_;
spv::Id type_float_;
spv::Id type_float2_;
@ -127,6 +129,7 @@ class SpirvShaderTranslator : public ShaderTranslator {
spv::Id const_int_0_;
spv::Id const_int4_0_;
spv::Id const_uint_0_;
spv::Id const_uint4_0_;
spv::Id const_float_0_;
spv::Id const_float4_0_;
@ -149,6 +152,8 @@ class SpirvShaderTranslator : public ShaderTranslator {
spv::Function* function_main_;
// bool.
spv::Id var_main_predicate_;
// uint4.
spv::Id var_main_loop_count_;
// int4.
spv::Id var_main_address_relative_;
// int.