diff --git a/third_party/glslang-spirv/SpvBuilder.cpp b/third_party/glslang-spirv/SpvBuilder.cpp index 0a2fa2139..13a6c946a 100644 --- a/third_party/glslang-spirv/SpvBuilder.cpp +++ b/third_party/glslang-spirv/SpvBuilder.cpp @@ -1166,6 +1166,7 @@ void Builder::createMemoryBarrier(unsigned executionScope, unsigned memorySemant // An opcode that has one operands, a result id, and a type Id Builder::createUnaryOp(Op opCode, Id typeId, Id operand) { + assert(operand != 0); Instruction* op = new Instruction(getUniqueId(), typeId, opCode); op->addIdOperand(operand); buildPoint->addInstruction(std::unique_ptr(op)); @@ -1175,6 +1176,8 @@ Id Builder::createUnaryOp(Op opCode, Id typeId, Id operand) Id Builder::createBinOp(Op opCode, Id typeId, Id left, Id right) { + assert(left != 0); + assert(right != 0); Instruction* op = new Instruction(getUniqueId(), typeId, opCode); op->addIdOperand(left); op->addIdOperand(right); @@ -1185,6 +1188,9 @@ Id Builder::createBinOp(Op opCode, Id typeId, Id left, Id right) Id Builder::createTriOp(Op opCode, Id typeId, Id op1, Id op2, Id op3) { + assert(op1 != 0); + assert(op2 != 0); + assert(op3 != 0); Instruction* op = new Instruction(getUniqueId(), typeId, opCode); op->addIdOperand(op1); op->addIdOperand(op2); diff --git a/third_party/glslang-spirv/SpvBuilder.h b/third_party/glslang-spirv/SpvBuilder.h index d6dc61218..7eae4fe91 100644 --- a/third_party/glslang-spirv/SpvBuilder.h +++ b/third_party/glslang-spirv/SpvBuilder.h @@ -93,6 +93,8 @@ public: return id; } + Module* getModule() { return &module; } + // For creating new types (will return old type if the requested one was already made). Id makeVoidType(); Id makeBoolType(); @@ -517,6 +519,7 @@ public: void createBranch(Block* block); void createConditionalBranch(Id condition, Block* thenBlock, Block* elseBlock); void createLoopMerge(Block* mergeBlock, Block* continueBlock, unsigned int control); + void createSelectionMerge(Block* mergeBlock, unsigned int control); protected: Id makeIntConstant(Id typeId, unsigned value, bool specConstant); @@ -527,7 +530,6 @@ public: void transferAccessChainSwizzle(bool dynamic); void simplifyAccessChainSwizzle(); void createAndSetNoPredecessorBlock(const char*); - void createSelectionMerge(Block* mergeBlock, unsigned int control); void dumpInstructions(std::vector&, const std::vector >&) const; SourceLanguage source; diff --git a/third_party/glslang-spirv/spvIR.h b/third_party/glslang-spirv/spvIR.h index 98f4971b4..63e460ebb 100644 --- a/third_party/glslang-spirv/spvIR.h +++ b/third_party/glslang-spirv/spvIR.h @@ -180,6 +180,11 @@ public: void addInstruction(std::unique_ptr inst); void addPredecessor(Block* pred) { predecessors.push_back(pred); pred->successors.push_back(this);} void addLocalVariable(std::unique_ptr inst) { localVariables.push_back(std::move(inst)); } + void insertInstruction(size_t pos, std::unique_ptr inst); + + size_t getInstructionCount() { return instructions.size(); } + Instruction* getInstruction(size_t i) { return instructions[i].get(); } + void removeInstruction(size_t i) { instructions.erase(instructions.begin() + i); } const std::vector& getPredecessors() const { return predecessors; } const std::vector& getSuccessors() const { return successors; } void setUnreachable() { unreachable = true; } @@ -200,6 +205,10 @@ public: bool isTerminated() const { + if (instructions.size() == 0) { + return false; + } + switch (instructions.back()->getOpCode()) { case OpBranch: case OpBranchConditional: @@ -215,6 +224,7 @@ public: void dump(std::vector& out) const { + // OpLabel instructions[0]->dump(out); for (int i = 0; i < (int)localVariables.size(); ++i) localVariables[i]->dump(out); @@ -222,7 +232,51 @@ public: instructions[i]->dump(out); } -protected: + // Moves all instructions from a target block into this block, and removes + // the target block from our list of successors. + // This function assumes this block unconditionally branches to the target + // block directly. + void merge(Block* target_block) { + if (isTerminated()) { + instructions.erase(instructions.end() - 1); + } + + // Find the target block in our successors first. + for (auto it = successors.begin(); it != successors.end(); ++it) { + if (*it == target_block) { + it = successors.erase(it); + break; + } + } + + // Add target block's successors to our successors. + successors.insert(successors.end(), target_block->successors.begin(), + target_block->successors.end()); + + // For each successor, replace the target block in their predecessors with + // us. + for (auto block : successors) { + std::replace(block->predecessors.begin(), block->predecessors.end(), + target_block, this); + } + + // Move instructions from target block into this block. + for (auto it = target_block->instructions.begin(); + it != target_block->instructions.end();) { + if ((*it)->getOpCode() == spv::Op::OpLabel) { + ++it; + continue; + } + + instructions.push_back(std::move(*it)); + it = target_block->instructions.erase(it); + } + + target_block->predecessors.clear(); + target_block->successors.clear(); + } + + protected: Block(const Block&); Block& operator=(Block&); @@ -275,6 +329,17 @@ public: Module& getParent() const { return parent; } Block* getEntryBlock() const { return blocks.front(); } Block* getLastBlock() const { return blocks.back(); } + Block* findBlockById(Id id) + { + for (auto block : blocks) { + if (block->getId() == id) { + return block; + } + } + + return nullptr; + } + std::vector& getBlocks() { return blocks; } void addLocalVariable(std::unique_ptr inst); Id getReturnType() const { return functionInstruction.getTypeId(); } void dump(std::vector& out) const @@ -315,6 +380,8 @@ public: } void addFunction(Function *fun) { functions.push_back(fun); } + const std::vector& getFunctions() const { return functions; } + std::vector& getFunctions() { return functions; } void mapInstruction(Instruction *instruction) { @@ -398,6 +465,14 @@ __inline void Block::addInstruction(std::unique_ptr inst) parent.getParent().mapInstruction(raw_instruction); } +__inline void Block::insertInstruction(size_t pos, std::unique_ptr inst) { + Instruction* raw_instruction = inst.get(); + instructions.insert(instructions.begin() + pos, std::move(inst)); + raw_instruction->setBlock(this); + if (raw_instruction->getResultId()) + parent.getParent().mapInstruction(raw_instruction); +} + }; // end spv namespace #endif // spvIR_H