LLVM DSL: expression matching (preview 2)

Implement more instructions.
This commit is contained in:
Nekotekina 2019-04-25 03:33:18 +03:00
parent aca61fdcf9
commit 2ade3c594c
2 changed files with 348 additions and 40 deletions

View File

@ -584,6 +584,34 @@ struct llvm_sum
llvm_match_tuple<A1, A2, A3> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
llvm::Value* v3 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == llvm::Instruction::Add)
{
v3 = i->getOperand(1);
if (auto r3 = a3.match(v3); v3)
{
i = llvm::dyn_cast<llvm::BinaryOperator>(i->getOperand(0));
if (i && i->getOpcode() == llvm::Instruction::Add)
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2, r3);
}
}
}
}
}
value = nullptr;
return {};
}
@ -781,6 +809,21 @@ struct llvm_neg
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::BinaryOperator>(value); i && i->getOpcode() == opc)
{
v1 = i->getOperand(1);
if (i->getOperand(0) == llvm::ConstantFP::getZeroValueForNegation(v1->getType()))
{
if (auto r1 = a1.match(v1); v1)
{
return r1;
}
}
}
value = nullptr;
return {};
}
@ -945,6 +988,28 @@ struct llvm_fshl
llvm_match_tuple<A1, A2, A3> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
llvm::Value* v3 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::CallInst>(value); i && i->getIntrinsicID() == llvm::Intrinsic::fshl)
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
v3 = i->getOperand(2);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
if (auto r3 = a3.match(v3); v3)
{
return std::tuple_cat(r1, r2, r3);
}
}
}
}
value = nullptr;
return {};
}
@ -995,6 +1060,28 @@ struct llvm_fshr
llvm_match_tuple<A1, A2, A3> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
llvm::Value* v3 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::CallInst>(value); i && i->getIntrinsicID() == llvm::Intrinsic::fshr)
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
v3 = i->getOperand(2);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
if (auto r3 = a3.match(v3); v3)
{
return std::tuple_cat(r1, r2, r3);
}
}
}
}
value = nullptr;
return {};
}
@ -1027,6 +1114,26 @@ struct llvm_rol
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::CallInst>(value); i && i->getIntrinsicID() == llvm::Intrinsic::fshl)
{
v1 = i->getOperand(0);
v2 = i->getOperand(2);
if (i->getOperand(1) == v1)
{
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
}
value = nullptr;
return {};
}
@ -1229,6 +1336,23 @@ struct llvm_cmp
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::ICmpInst>(value); i && i->getOpcode() == pred)
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
return {};
}
@ -1272,6 +1396,23 @@ struct llvm_ord
llvm_match_tuple<Cmp> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::FCmpInst>(value); i && i->getOpcode() == pred)
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = cmp.a1.match(v1); v1)
{
if (auto r2 = cmp.a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
return {};
}
@ -1308,6 +1449,23 @@ struct llvm_uno
llvm_match_tuple<Cmp> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::FCmpInst>(value); i && i->getOpcode() == pred)
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = cmp.a1.match(v1); v1)
{
if (auto r2 = cmp.a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
return {};
}
@ -1414,6 +1572,14 @@ struct llvm_noncast
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
if (value)
{
if (auto r1 = a1.match(value); value)
{
return r1;
}
}
value = nullptr;
return {};
}
@ -1432,6 +1598,8 @@ struct llvm_bitcast
static_assert(bitsize0 == bitsize1, "llvm_bitcast<>: invalid type (size mismatch)");
static_assert(llvm_value_t<T>::is_int || llvm_value_t<T>::is_float, "llvm_bitcast<>: invalid type");
static_assert(llvm_value_t<U>::is_int || llvm_value_t<U>::is_float, "llvm_bitcast<>: invalid result type");
static_assert(llvm_value_t<T>::is_int != llvm_value_t<U>::is_int || llvm_value_t<T>::is_vector != llvm_value_t<U>::is_vector,
"llvm_bitcast<>: no-op cast (use noncast)");
static constexpr bool is_ok =
bitsize0 && bitsize0 == bitsize1 &&
@ -1443,12 +1611,6 @@ struct llvm_bitcast
const auto v1 = a1.eval(ir);
const auto rt = llvm_value_t<U>::get_type(ir->getContext());
if constexpr (llvm_value_t<T>::is_int == llvm_value_t<U>::is_int && llvm_value_t<T>::is_vector == llvm_value_t<U>::is_vector)
{
// No-op case
return v1;
}
if (const auto c1 = llvm::dyn_cast<llvm::Constant>(v1))
{
const auto module = ir->GetInsertBlock()->getParent()->getParent();
@ -1465,6 +1627,36 @@ struct llvm_bitcast
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::CastInst>(value); i && i->getOpcode() == llvm::Instruction::BitCast)
{
v1 = i->getOperand(0);
if (llvm_value_t<U>::get_type(v1->getContext()) == i->getDestTy())
{
if (auto r1 = a1.match(v1); v1)
{
return r1;
}
}
}
if (auto c = llvm::dyn_cast_or_null<llvm::Constant>(value))
{
// TODO
// const auto target = llvm_value_t<T>::get_type(c->getContext());
// // Reverse bitcast on a constant
// if (llvm::Value* cv = llvm::ConstantFoldCastOperand(llvm::Instruction::BitCast, c, target, module->getDataLayout()))
// {
// if (auto r1 = a1.match(cv); cv)
// {
// return r1;
// }
// }
}
value = nullptr;
return {};
}
@ -1495,6 +1687,21 @@ struct llvm_trunc
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::CastInst>(value); i && i->getOpcode() == llvm::Instruction::Trunc)
{
v1 = i->getOperand(0);
if (llvm_value_t<U>::get_type(v1->getContext()) == i->getDestTy())
{
if (auto r1 = a1.match(v1); v1)
{
return r1;
}
}
}
value = nullptr;
return {};
}
@ -1525,6 +1732,21 @@ struct llvm_sext
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::CastInst>(value); i && i->getOpcode() == llvm::Instruction::SExt)
{
v1 = i->getOperand(0);
if (llvm_value_t<U>::get_type(v1->getContext()) == i->getDestTy())
{
if (auto r1 = a1.match(v1); v1)
{
return r1;
}
}
}
value = nullptr;
return {};
}
@ -1555,6 +1777,21 @@ struct llvm_zext
llvm_match_tuple<A1> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::CastInst>(value); i && i->getOpcode() == llvm::Instruction::ZExt)
{
v1 = i->getOperand(0);
if (llvm_value_t<U>::get_type(v1->getContext()) == i->getDestTy())
{
if (auto r1 = a1.match(v1); v1)
{
return r1;
}
}
}
value = nullptr;
return {};
}
@ -1583,6 +1820,28 @@ struct llvm_select
llvm_match_tuple<A1, A2, A3> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
llvm::Value* v3 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::SelectInst>(value))
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
v3 = i->getOperand(2);
if (auto r1 = cond.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
if (auto r3 = a3.match(v3); v3)
{
return std::tuple_cat(r1, r2, r3);
}
}
}
}
value = nullptr;
return {};
}
@ -1668,16 +1927,12 @@ struct llvm_add_sat
static constexpr bool is_ok = llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint;
static llvm::Function* get_sadd_sat(llvm::IRBuilder<>* ir)
{
const auto module = ir->GetInsertBlock()->getParent()->getParent();
return llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::sadd_sat, {llvm_value_t<T>::get_type(ir->getContext())});
}
static constexpr auto intr = llvm_value_t<T>::is_sint ? llvm::Intrinsic::sadd_sat : llvm::Intrinsic::uadd_sat;
static llvm::Function* get_uadd_sat(llvm::IRBuilder<>* ir)
static llvm::Function* get_add_sat(llvm::IRBuilder<>* ir)
{
const auto module = ir->GetInsertBlock()->getParent()->getParent();
return llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::uadd_sat, {llvm_value_t<T>::get_type(ir->getContext())});
return llvm::Intrinsic::getDeclaration(module, intr, {llvm_value_t<T>::get_type(ir->getContext())});
}
llvm::Value* eval(llvm::IRBuilder<>* ir) const
@ -1704,19 +1959,28 @@ struct llvm_add_sat
}
}
if constexpr (llvm_value_t<T>::is_sint)
{
return ir->CreateCall(get_sadd_sat(ir), {v1, v2});
}
if constexpr (llvm_value_t<T>::is_uint)
{
return ir->CreateCall(get_uadd_sat(ir), {v1, v2});
}
return ir->CreateCall(get_add_sat(ir), {v1, v2});
}
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::CallInst>(value); i && i->getIntrinsicID() == intr)
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
return {};
}
@ -1734,16 +1998,12 @@ struct llvm_sub_sat
static constexpr bool is_ok = llvm_value_t<T>::is_sint || llvm_value_t<T>::is_uint;
static llvm::Function* get_ssub_sat(llvm::IRBuilder<>* ir)
{
const auto module = ir->GetInsertBlock()->getParent()->getParent();
return llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::ssub_sat, {llvm_value_t<T>::get_type(ir->getContext())});
}
static constexpr auto intr = llvm_value_t<T>::is_sint ? llvm::Intrinsic::ssub_sat : llvm::Intrinsic::usub_sat;
static llvm::Function* get_usub_sat(llvm::IRBuilder<>* ir)
static llvm::Function* get_sub_sat(llvm::IRBuilder<>* ir)
{
const auto module = ir->GetInsertBlock()->getParent()->getParent();
return llvm::Intrinsic::getDeclaration(module, llvm::Intrinsic::usub_sat, {llvm_value_t<T>::get_type(ir->getContext())});
return llvm::Intrinsic::getDeclaration(module, intr, {llvm_value_t<T>::get_type(ir->getContext())});
}
llvm::Value* eval(llvm::IRBuilder<>* ir) const
@ -1769,19 +2029,28 @@ struct llvm_sub_sat
}
}
if constexpr (llvm_value_t<T>::is_sint)
{
return ir->CreateCall(get_ssub_sat(ir), {v1, v2});
}
if constexpr (llvm_value_t<T>::is_uint)
{
return ir->CreateCall(get_usub_sat(ir), {v1, v2});
}
return ir->CreateCall(get_sub_sat(ir), {v1, v2});
}
llvm_match_tuple<A1, A2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::CallInst>(value); i && i->getIntrinsicID() == intr)
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = a2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
return {};
}
@ -1811,6 +2080,23 @@ struct llvm_extract
llvm_match_tuple<A1, I2> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::ExtractElementInst>(value))
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = i2.match(v2); v2)
{
return std::tuple_cat(r1, r2);
}
}
}
value = nullptr;
return {};
}
@ -1843,6 +2129,28 @@ struct llvm_insert
llvm_match_tuple<A1, I2, A3> match(llvm::Value*& value) const
{
llvm::Value* v1 = {};
llvm::Value* v2 = {};
llvm::Value* v3 = {};
if (auto i = llvm::dyn_cast_or_null<llvm::InsertElementInst>(value))
{
v1 = i->getOperand(0);
v2 = i->getOperand(1);
v3 = i->getOperand(2);
if (auto r1 = a1.match(v1); v1)
{
if (auto r2 = i2.match(v2); v2)
{
if (auto r3 = a3.match(v3); v3)
{
return std::tuple_cat(r1, r2, r3);
}
}
}
}
value = nullptr;
return {};
}

View File

@ -1109,8 +1109,8 @@ void PPUTranslator::VMSUMUHS(ppu_opcode_t op)
const auto a = get_vr<u32[4]>(op.va);
const auto b = get_vr<u32[4]>(op.vb);
const auto c = get_vr<u32[4]>(op.vc);
const auto ml = bitcast<u32[4]>((a << 16 >> 16) * (b << 16 >> 16));
const auto mh = bitcast<u32[4]>((a >> 16) * (b >> 16));
const auto ml = noncast<u32[4]>((a << 16 >> 16) * (b << 16 >> 16));
const auto mh = noncast<u32[4]>((a >> 16) * (b >> 16));
const auto s = eval(ml + mh);
const auto s2 = eval(s + c);
const auto x = eval(s < ml | s2 < s);