[SPIR-V] Add support for loops
This commit is contained in:
parent
80b0b66e5d
commit
fb89973266
|
@ -102,6 +102,8 @@ void SpirvShaderTranslator::StartTranslation() {
|
|||
aL_ = b.createVariable(spv::StorageClass::StorageClassFunction,
|
||||
vec4_uint_type_, "aL");
|
||||
|
||||
loop_count_ = b.createVariable(spv::StorageClass::StorageClassFunction,
|
||||
vec4_uint_type_, "loop_count");
|
||||
p0_ = b.createVariable(spv::StorageClass::StorageClassFunction, bool_type_,
|
||||
"p0");
|
||||
ps_ = b.createVariable(spv::StorageClass::StorageClassFunction, float_type_,
|
||||
|
@ -689,32 +691,49 @@ void SpirvShaderTranslator::PreProcessControlFlowInstructions(
|
|||
auto& instr = instrs[i];
|
||||
if (instr.opcode() == ucode::ControlFlowOpcode::kCondJmp) {
|
||||
uint32_t address = instr.cond_jmp.address();
|
||||
cf_blocks_[address].labelled = true;
|
||||
|
||||
if (!cf_blocks_[address].labelled) {
|
||||
cf_blocks_[address].labelled = true;
|
||||
operands.push_back(address);
|
||||
operands.push_back(cf_blocks_[address].block->getId());
|
||||
cf_blocks_[address].block->addPredecessor(loop_body_block_);
|
||||
}
|
||||
|
||||
if (!cf_blocks_[i + 1].labelled) {
|
||||
cf_blocks_[i + 1].labelled = true;
|
||||
operands.push_back(uint32_t(i + 1));
|
||||
operands.push_back(cf_blocks_[i + 1].block->getId());
|
||||
cf_blocks_[i + 1].block->addPredecessor(loop_body_block_);
|
||||
}
|
||||
} else if (instr.opcode() == ucode::ControlFlowOpcode::kLoopStart) {
|
||||
uint32_t address = instr.loop_start.address();
|
||||
cf_blocks_[address].labelled = true;
|
||||
|
||||
// Label the loop skip address.
|
||||
if (!cf_blocks_[address].labelled) {
|
||||
cf_blocks_[address].labelled = true;
|
||||
operands.push_back(address);
|
||||
operands.push_back(cf_blocks_[address].block->getId());
|
||||
cf_blocks_[address].block->addPredecessor(loop_body_block_);
|
||||
}
|
||||
|
||||
// Label the body
|
||||
if (!cf_blocks_[i + 1].labelled) {
|
||||
cf_blocks_[i + 1].labelled = true;
|
||||
operands.push_back(uint32_t(i + 1));
|
||||
operands.push_back(cf_blocks_[i + 1].block->getId());
|
||||
cf_blocks_[i + 1].block->addPredecessor(loop_body_block_);
|
||||
}
|
||||
} else if (instr.opcode() == ucode::ControlFlowOpcode::kLoopEnd) {
|
||||
uint32_t address = instr.loop_end.address();
|
||||
cf_blocks_[address].labelled = true;
|
||||
|
||||
if (!cf_blocks_[address].labelled) {
|
||||
cf_blocks_[address].labelled = true;
|
||||
operands.push_back(address);
|
||||
operands.push_back(cf_blocks_[address].block->getId());
|
||||
cf_blocks_[address].block->addPredecessor(loop_body_block_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
b.createSelectionMerge(switch_break_block_, 0);
|
||||
b.createNoResultOp(spv::Op::OpSwitch, operands);
|
||||
|
@ -865,13 +884,52 @@ void SpirvShaderTranslator::ProcessLoopStartInstruction(
|
|||
auto head = cf_blocks_[instr.dword_index].block;
|
||||
b.setBuildPoint(head);
|
||||
|
||||
// TODO: Emit a spv LoopMerge
|
||||
// (need to know the continue target and merge target beforehand though)
|
||||
// loop il<idx>, L<idx> - loop with loop data il<idx>, end @ L<idx>
|
||||
|
||||
EmitUnimplementedTranslationError();
|
||||
std::vector<Id> offsets;
|
||||
offsets.push_back(b.makeUintConstant(1)); // loop_consts
|
||||
offsets.push_back(b.makeUintConstant(instr.loop_constant_index));
|
||||
auto loop_const = b.createAccessChain(spv::StorageClass::StorageClassUniform,
|
||||
consts_, offsets);
|
||||
loop_const = b.createLoad(loop_const);
|
||||
|
||||
assert_true(cf_blocks_.size() > instr.dword_index + 1);
|
||||
b.createBranch(cf_blocks_[instr.dword_index + 1].block);
|
||||
// uint loop_count_value = loop_const & 0xFF;
|
||||
auto loop_count_value = b.createBinOp(spv::Op::OpBitwiseAnd, uint_type_,
|
||||
loop_const, b.makeUintConstant(0xFF));
|
||||
|
||||
// uint loop_aL_value = (loop_const >> 8) & 0xFF;
|
||||
auto loop_aL_value = b.createBinOp(spv::Op::OpShiftRightLogical, uint_type_,
|
||||
loop_const, b.makeUintConstant(8));
|
||||
loop_aL_value = b.createBinOp(spv::Op::OpBitwiseAnd, uint_type_,
|
||||
loop_aL_value, b.makeUintConstant(0xFF));
|
||||
|
||||
// loop_count_ = uvec4(loop_count_value, loop_count_.xyz);
|
||||
auto loop_count = b.createLoad(loop_count_);
|
||||
loop_count =
|
||||
b.createRvalueSwizzle(spv::NoPrecision, vec4_uint_type_, loop_count,
|
||||
std::vector<uint32_t>({0, 0, 1, 2}));
|
||||
loop_count =
|
||||
b.createCompositeInsert(loop_count_value, loop_count, vec4_uint_type_, 0);
|
||||
b.createStore(loop_count, loop_count_);
|
||||
|
||||
// aL = aL.xxyz;
|
||||
auto aL = b.createLoad(aL_);
|
||||
aL = b.createRvalueSwizzle(spv::NoPrecision, vec4_uint_type_, aL,
|
||||
std::vector<uint32_t>({0, 0, 1, 2}));
|
||||
if (!instr.is_repeat) {
|
||||
// aL.x = loop_aL_value;
|
||||
aL = b.createCompositeInsert(loop_aL_value, aL, vec4_uint_type_, 0);
|
||||
}
|
||||
b.createStore(aL, aL_);
|
||||
|
||||
// Short-circuit if loop counter is 0
|
||||
auto cond = b.createBinOp(spv::Op::OpIEqual, bool_type_, loop_count_value,
|
||||
b.makeUintConstant(0));
|
||||
auto next_pc = b.createTriOp(spv::Op::OpSelect, int_type_, cond,
|
||||
b.makeIntConstant(instr.loop_skip_address),
|
||||
b.makeIntConstant(instr.dword_index + 1));
|
||||
b.createStore(next_pc, pc_);
|
||||
b.createBranch(switch_break_block_);
|
||||
}
|
||||
|
||||
void SpirvShaderTranslator::ProcessLoopEndInstruction(
|
||||
|
@ -881,10 +939,83 @@ void SpirvShaderTranslator::ProcessLoopEndInstruction(
|
|||
auto head = cf_blocks_[instr.dword_index].block;
|
||||
b.setBuildPoint(head);
|
||||
|
||||
EmitUnimplementedTranslationError();
|
||||
// endloop il<idx>, L<idx> - end loop w/ data il<idx>, head @ L<idx>
|
||||
auto loop_count = b.createLoad(loop_count_);
|
||||
auto count = b.createCompositeExtract(loop_count, uint_type_, 0);
|
||||
count =
|
||||
b.createBinOp(spv::Op::OpISub, uint_type_, count, b.makeUintConstant(1));
|
||||
loop_count = b.createCompositeInsert(count, loop_count, vec4_uint_type_, 0);
|
||||
b.createStore(loop_count, loop_count_);
|
||||
|
||||
assert_true(cf_blocks_.size() > instr.dword_index + 1);
|
||||
b.createBranch(cf_blocks_[instr.dword_index + 1].block);
|
||||
// if (--loop_count_.x == 0 || [!]p0)
|
||||
auto c1 = b.createBinOp(spv::Op::OpIEqual, bool_type_, count,
|
||||
b.makeUintConstant(0));
|
||||
auto c2 =
|
||||
b.createBinOp(spv::Op::OpLogicalEqual, bool_type_, b.createLoad(p0_),
|
||||
b.makeBoolConstant(instr.predicate_condition));
|
||||
auto cond = b.createBinOp(spv::Op::OpLogicalOr, bool_type_, c1, c2);
|
||||
|
||||
auto loop = &b.makeNewBlock();
|
||||
auto end = &b.makeNewBlock();
|
||||
auto tail = &b.makeNewBlock();
|
||||
b.createSelectionMerge(tail, spv::SelectionControlMaskNone);
|
||||
b.createConditionalBranch(cond, end, loop);
|
||||
|
||||
// ================================================
|
||||
// Loop completed - pop the current loop off the stack and exit
|
||||
b.setBuildPoint(end);
|
||||
loop_count = b.createLoad(loop_count_);
|
||||
auto aL = b.createLoad(aL_);
|
||||
|
||||
// loop_count = loop_count.yzw0
|
||||
loop_count =
|
||||
b.createRvalueSwizzle(spv::NoPrecision, vec4_uint_type_, loop_count,
|
||||
std::vector<uint32_t>({1, 2, 3, 3}));
|
||||
loop_count = b.createCompositeInsert(b.makeUintConstant(0), loop_count,
|
||||
vec4_uint_type_, 3);
|
||||
b.createStore(loop_count, loop_count_);
|
||||
|
||||
// aL = aL.yzw0
|
||||
aL = b.createRvalueSwizzle(spv::NoPrecision, vec4_uint_type_, aL,
|
||||
std::vector<uint32_t>({1, 2, 3, 3}));
|
||||
aL = b.createCompositeInsert(b.makeUintConstant(0), aL, vec4_uint_type_, 3);
|
||||
b.createStore(aL, aL_);
|
||||
|
||||
// Update pc with the next block
|
||||
// pc_ = instr.dword_index + 1
|
||||
b.createStore(b.makeIntConstant(instr.dword_index + 1), pc_);
|
||||
b.createBranch(tail);
|
||||
|
||||
// ================================================
|
||||
// Still looping - increment aL and loop
|
||||
b.setBuildPoint(loop);
|
||||
aL = b.createLoad(aL_);
|
||||
auto aL_x = b.createCompositeExtract(aL, uint_type_, 0);
|
||||
|
||||
std::vector<Id> offsets;
|
||||
offsets.push_back(b.makeUintConstant(1)); // loop_consts
|
||||
offsets.push_back(b.makeUintConstant(instr.loop_constant_index));
|
||||
auto loop_const = b.createAccessChain(spv::StorageClass::StorageClassUniform,
|
||||
consts_, offsets);
|
||||
loop_const = b.createLoad(loop_const);
|
||||
|
||||
// uint loop_aL_value = (loop_const >> 16) & 0xFF;
|
||||
auto loop_aL_value = b.createBinOp(spv::Op::OpShiftRightLogical, uint_type_,
|
||||
loop_const, b.makeUintConstant(16));
|
||||
loop_aL_value = b.createBinOp(spv::Op::OpBitwiseAnd, uint_type_,
|
||||
loop_aL_value, b.makeUintConstant(0xFF));
|
||||
|
||||
aL_x = b.createBinOp(spv::Op::OpIAdd, uint_type_, aL_x, loop_aL_value);
|
||||
aL = b.createCompositeInsert(aL_x, aL, vec4_uint_type_, 0);
|
||||
b.createStore(aL, aL_);
|
||||
|
||||
// pc_ = instr.loop_body_address;
|
||||
b.createStore(b.makeIntConstant(instr.loop_body_address), pc_);
|
||||
b.createBranch(tail);
|
||||
|
||||
// ================================================
|
||||
b.setBuildPoint(tail);
|
||||
b.createBranch(switch_break_block_);
|
||||
}
|
||||
|
||||
void SpirvShaderTranslator::ProcessCallInstruction(
|
||||
|
@ -2557,10 +2688,10 @@ Id SpirvShaderTranslator::LoadFromOperand(const InstructionOperand& op) {
|
|||
b.makeUintConstant(storage_base + op.storage_index));
|
||||
} break;
|
||||
case InstructionStorageAddressingMode::kAddressRelative: {
|
||||
// TODO: Based on loop index
|
||||
// storage_index + aL.x
|
||||
auto idx = b.createCompositeExtract(b.createLoad(aL_), uint_type_, 0);
|
||||
storage_index =
|
||||
b.createBinOp(spv::Op::OpIAdd, uint_type_, b.makeUintConstant(0),
|
||||
b.createBinOp(spv::Op::OpIAdd, uint_type_, idx,
|
||||
b.makeUintConstant(storage_base + op.storage_index));
|
||||
} break;
|
||||
default:
|
||||
|
@ -2716,7 +2847,9 @@ void SpirvShaderTranslator::StoreToResult(Id source_value_id,
|
|||
} break;
|
||||
case InstructionStorageAddressingMode::kAddressRelative: {
|
||||
// storage_index + aL.x
|
||||
// TODO
|
||||
auto idx = b.createCompositeExtract(b.createLoad(aL_), uint_type_, 0);
|
||||
storage_index = b.createBinOp(spv::Op::OpIAdd, uint_type_, idx,
|
||||
b.makeUintConstant(result.storage_index));
|
||||
} break;
|
||||
default:
|
||||
assert_always();
|
||||
|
|
|
@ -145,7 +145,9 @@ class SpirvShaderTranslator : public ShaderTranslator {
|
|||
// Array of AMD registers.
|
||||
// These values are all pointers.
|
||||
spv::Id registers_ptr_ = 0, registers_type_ = 0;
|
||||
spv::Id consts_ = 0, a0_ = 0, aL_ = 0, p0_ = 0;
|
||||
spv::Id consts_ = 0, a0_ = 0, p0_ = 0;
|
||||
spv::Id aL_ = 0; // Loop index stack - .x is active loop
|
||||
spv::Id loop_count_ = 0; // Loop counter stack
|
||||
spv::Id ps_ = 0, pv_ = 0; // IDs of previous results
|
||||
spv::Id pc_ = 0; // Program counter
|
||||
spv::Id pos_ = 0;
|
||||
|
|
Loading…
Reference in New Issue