[SPIR-V] Add support for loops

This commit is contained in:
DrChat 2017-12-22 22:23:28 -06:00
parent 80b0b66e5d
commit fb89973266
2 changed files with 163 additions and 28 deletions

View File

@ -102,6 +102,8 @@ void SpirvShaderTranslator::StartTranslation() {
aL_ = b.createVariable(spv::StorageClass::StorageClassFunction,
vec4_uint_type_, "aL");
loop_count_ = b.createVariable(spv::StorageClass::StorageClassFunction,
vec4_uint_type_, "loop_count");
p0_ = b.createVariable(spv::StorageClass::StorageClassFunction, bool_type_,
"p0");
ps_ = b.createVariable(spv::StorageClass::StorageClassFunction, float_type_,
@ -689,30 +691,47 @@ void SpirvShaderTranslator::PreProcessControlFlowInstructions(
auto& instr = instrs[i];
if (instr.opcode() == ucode::ControlFlowOpcode::kCondJmp) {
uint32_t address = instr.cond_jmp.address();
cf_blocks_[address].labelled = true;
operands.push_back(address);
operands.push_back(cf_blocks_[address].block->getId());
cf_blocks_[address].block->addPredecessor(loop_body_block_);
if (!cf_blocks_[address].labelled) {
cf_blocks_[address].labelled = true;
operands.push_back(address);
operands.push_back(cf_blocks_[address].block->getId());
cf_blocks_[address].block->addPredecessor(loop_body_block_);
}
cf_blocks_[i + 1].labelled = true;
operands.push_back(uint32_t(i + 1));
operands.push_back(cf_blocks_[i + 1].block->getId());
cf_blocks_[i + 1].block->addPredecessor(loop_body_block_);
if (!cf_blocks_[i + 1].labelled) {
cf_blocks_[i + 1].labelled = true;
operands.push_back(uint32_t(i + 1));
operands.push_back(cf_blocks_[i + 1].block->getId());
cf_blocks_[i + 1].block->addPredecessor(loop_body_block_);
}
} else if (instr.opcode() == ucode::ControlFlowOpcode::kLoopStart) {
uint32_t address = instr.loop_start.address();
cf_blocks_[address].labelled = true;
operands.push_back(address);
operands.push_back(cf_blocks_[address].block->getId());
cf_blocks_[address].block->addPredecessor(loop_body_block_);
// Label the loop skip address.
if (!cf_blocks_[address].labelled) {
cf_blocks_[address].labelled = true;
operands.push_back(address);
operands.push_back(cf_blocks_[address].block->getId());
cf_blocks_[address].block->addPredecessor(loop_body_block_);
}
// Label the body
if (!cf_blocks_[i + 1].labelled) {
cf_blocks_[i + 1].labelled = true;
operands.push_back(uint32_t(i + 1));
operands.push_back(cf_blocks_[i + 1].block->getId());
cf_blocks_[i + 1].block->addPredecessor(loop_body_block_);
}
} else if (instr.opcode() == ucode::ControlFlowOpcode::kLoopEnd) {
uint32_t address = instr.loop_end.address();
cf_blocks_[address].labelled = true;
operands.push_back(address);
operands.push_back(cf_blocks_[address].block->getId());
cf_blocks_[address].block->addPredecessor(loop_body_block_);
if (!cf_blocks_[address].labelled) {
cf_blocks_[address].labelled = true;
operands.push_back(address);
operands.push_back(cf_blocks_[address].block->getId());
cf_blocks_[address].block->addPredecessor(loop_body_block_);
}
}
}
@ -865,13 +884,52 @@ void SpirvShaderTranslator::ProcessLoopStartInstruction(
auto head = cf_blocks_[instr.dword_index].block;
b.setBuildPoint(head);
// TODO: Emit a spv LoopMerge
// (need to know the continue target and merge target beforehand though)
// loop il<idx>, L<idx> - loop with loop data il<idx>, end @ L<idx>
EmitUnimplementedTranslationError();
std::vector<Id> offsets;
offsets.push_back(b.makeUintConstant(1)); // loop_consts
offsets.push_back(b.makeUintConstant(instr.loop_constant_index));
auto loop_const = b.createAccessChain(spv::StorageClass::StorageClassUniform,
consts_, offsets);
loop_const = b.createLoad(loop_const);
assert_true(cf_blocks_.size() > instr.dword_index + 1);
b.createBranch(cf_blocks_[instr.dword_index + 1].block);
// uint loop_count_value = loop_const & 0xFF;
auto loop_count_value = b.createBinOp(spv::Op::OpBitwiseAnd, uint_type_,
loop_const, b.makeUintConstant(0xFF));
// uint loop_aL_value = (loop_const >> 8) & 0xFF;
auto loop_aL_value = b.createBinOp(spv::Op::OpShiftRightLogical, uint_type_,
loop_const, b.makeUintConstant(8));
loop_aL_value = b.createBinOp(spv::Op::OpBitwiseAnd, uint_type_,
loop_aL_value, b.makeUintConstant(0xFF));
// loop_count_ = uvec4(loop_count_value, loop_count_.xyz);
auto loop_count = b.createLoad(loop_count_);
loop_count =
b.createRvalueSwizzle(spv::NoPrecision, vec4_uint_type_, loop_count,
std::vector<uint32_t>({0, 0, 1, 2}));
loop_count =
b.createCompositeInsert(loop_count_value, loop_count, vec4_uint_type_, 0);
b.createStore(loop_count, loop_count_);
// aL = aL.xxyz;
auto aL = b.createLoad(aL_);
aL = b.createRvalueSwizzle(spv::NoPrecision, vec4_uint_type_, aL,
std::vector<uint32_t>({0, 0, 1, 2}));
if (!instr.is_repeat) {
// aL.x = loop_aL_value;
aL = b.createCompositeInsert(loop_aL_value, aL, vec4_uint_type_, 0);
}
b.createStore(aL, aL_);
// Short-circuit if loop counter is 0
auto cond = b.createBinOp(spv::Op::OpIEqual, bool_type_, loop_count_value,
b.makeUintConstant(0));
auto next_pc = b.createTriOp(spv::Op::OpSelect, int_type_, cond,
b.makeIntConstant(instr.loop_skip_address),
b.makeIntConstant(instr.dword_index + 1));
b.createStore(next_pc, pc_);
b.createBranch(switch_break_block_);
}
void SpirvShaderTranslator::ProcessLoopEndInstruction(
@ -881,10 +939,83 @@ void SpirvShaderTranslator::ProcessLoopEndInstruction(
auto head = cf_blocks_[instr.dword_index].block;
b.setBuildPoint(head);
EmitUnimplementedTranslationError();
// endloop il<idx>, L<idx> - end loop w/ data il<idx>, head @ L<idx>
auto loop_count = b.createLoad(loop_count_);
auto count = b.createCompositeExtract(loop_count, uint_type_, 0);
count =
b.createBinOp(spv::Op::OpISub, uint_type_, count, b.makeUintConstant(1));
loop_count = b.createCompositeInsert(count, loop_count, vec4_uint_type_, 0);
b.createStore(loop_count, loop_count_);
assert_true(cf_blocks_.size() > instr.dword_index + 1);
b.createBranch(cf_blocks_[instr.dword_index + 1].block);
// if (--loop_count_.x == 0 || [!]p0)
auto c1 = b.createBinOp(spv::Op::OpIEqual, bool_type_, count,
b.makeUintConstant(0));
auto c2 =
b.createBinOp(spv::Op::OpLogicalEqual, bool_type_, b.createLoad(p0_),
b.makeBoolConstant(instr.predicate_condition));
auto cond = b.createBinOp(spv::Op::OpLogicalOr, bool_type_, c1, c2);
auto loop = &b.makeNewBlock();
auto end = &b.makeNewBlock();
auto tail = &b.makeNewBlock();
b.createSelectionMerge(tail, spv::SelectionControlMaskNone);
b.createConditionalBranch(cond, end, loop);
// ================================================
// Loop completed - pop the current loop off the stack and exit
b.setBuildPoint(end);
loop_count = b.createLoad(loop_count_);
auto aL = b.createLoad(aL_);
// loop_count = loop_count.yzw0
loop_count =
b.createRvalueSwizzle(spv::NoPrecision, vec4_uint_type_, loop_count,
std::vector<uint32_t>({1, 2, 3, 3}));
loop_count = b.createCompositeInsert(b.makeUintConstant(0), loop_count,
vec4_uint_type_, 3);
b.createStore(loop_count, loop_count_);
// aL = aL.yzw0
aL = b.createRvalueSwizzle(spv::NoPrecision, vec4_uint_type_, aL,
std::vector<uint32_t>({1, 2, 3, 3}));
aL = b.createCompositeInsert(b.makeUintConstant(0), aL, vec4_uint_type_, 3);
b.createStore(aL, aL_);
// Update pc with the next block
// pc_ = instr.dword_index + 1
b.createStore(b.makeIntConstant(instr.dword_index + 1), pc_);
b.createBranch(tail);
// ================================================
// Still looping - increment aL and loop
b.setBuildPoint(loop);
aL = b.createLoad(aL_);
auto aL_x = b.createCompositeExtract(aL, uint_type_, 0);
std::vector<Id> offsets;
offsets.push_back(b.makeUintConstant(1)); // loop_consts
offsets.push_back(b.makeUintConstant(instr.loop_constant_index));
auto loop_const = b.createAccessChain(spv::StorageClass::StorageClassUniform,
consts_, offsets);
loop_const = b.createLoad(loop_const);
// uint loop_aL_value = (loop_const >> 16) & 0xFF;
auto loop_aL_value = b.createBinOp(spv::Op::OpShiftRightLogical, uint_type_,
loop_const, b.makeUintConstant(16));
loop_aL_value = b.createBinOp(spv::Op::OpBitwiseAnd, uint_type_,
loop_aL_value, b.makeUintConstant(0xFF));
aL_x = b.createBinOp(spv::Op::OpIAdd, uint_type_, aL_x, loop_aL_value);
aL = b.createCompositeInsert(aL_x, aL, vec4_uint_type_, 0);
b.createStore(aL, aL_);
// pc_ = instr.loop_body_address;
b.createStore(b.makeIntConstant(instr.loop_body_address), pc_);
b.createBranch(tail);
// ================================================
b.setBuildPoint(tail);
b.createBranch(switch_break_block_);
}
void SpirvShaderTranslator::ProcessCallInstruction(
@ -2557,10 +2688,10 @@ Id SpirvShaderTranslator::LoadFromOperand(const InstructionOperand& op) {
b.makeUintConstant(storage_base + op.storage_index));
} break;
case InstructionStorageAddressingMode::kAddressRelative: {
// TODO: Based on loop index
// storage_index + aL.x
auto idx = b.createCompositeExtract(b.createLoad(aL_), uint_type_, 0);
storage_index =
b.createBinOp(spv::Op::OpIAdd, uint_type_, b.makeUintConstant(0),
b.createBinOp(spv::Op::OpIAdd, uint_type_, idx,
b.makeUintConstant(storage_base + op.storage_index));
} break;
default:
@ -2716,7 +2847,9 @@ void SpirvShaderTranslator::StoreToResult(Id source_value_id,
} break;
case InstructionStorageAddressingMode::kAddressRelative: {
// storage_index + aL.x
// TODO
auto idx = b.createCompositeExtract(b.createLoad(aL_), uint_type_, 0);
storage_index = b.createBinOp(spv::Op::OpIAdd, uint_type_, idx,
b.makeUintConstant(result.storage_index));
} break;
default:
assert_always();

View File

@ -145,7 +145,9 @@ class SpirvShaderTranslator : public ShaderTranslator {
// Array of AMD registers.
// These values are all pointers.
spv::Id registers_ptr_ = 0, registers_type_ = 0;
spv::Id consts_ = 0, a0_ = 0, aL_ = 0, p0_ = 0;
spv::Id consts_ = 0, a0_ = 0, p0_ = 0;
spv::Id aL_ = 0; // Loop index stack - .x is active loop
spv::Id loop_count_ = 0; // Loop counter stack
spv::Id ps_ = 0, pv_ = 0; // IDs of previous results
spv::Id pc_ = 0; // Program counter
spv::Id pos_ = 0;