[CPU] Handle constant multiply in fmadd/fmsub in constant propagation pass
This commit is contained in:
parent
d31db60a75
commit
6fd75cea91
|
@ -4614,16 +4614,6 @@ EMITTER_OPCODE_TABLE(OPCODE_DIV, DIV_I8, DIV_I16, DIV_I32, DIV_I64, DIV_F32,
|
|||
struct MUL_ADD_F32
|
||||
: Sequence<MUL_ADD_F32, I<OPCODE_MUL_ADD, F32Op, F32Op, F32Op, F32Op>> {
|
||||
static void Emit(X64Emitter& e, const EmitArgType& i) {
|
||||
// Calculate the multiply part if it's constant.
|
||||
// TODO: Do this in the constant propagation pass.
|
||||
if (i.src1.is_constant && i.src2.is_constant) {
|
||||
float mul = i.src1.constant() * i.src2.constant();
|
||||
|
||||
e.LoadConstantXmm(e.xmm0, mul);
|
||||
e.vaddss(i.dest, e.xmm0, i.src3);
|
||||
return;
|
||||
}
|
||||
|
||||
// FMA extension
|
||||
if (e.IsFeatureEnabled(kX64EmitFMA)) {
|
||||
EmitCommutativeBinaryXmmOp(e, i,
|
||||
|
@ -4673,16 +4663,6 @@ struct MUL_ADD_F32
|
|||
struct MUL_ADD_F64
|
||||
: Sequence<MUL_ADD_F64, I<OPCODE_MUL_ADD, F64Op, F64Op, F64Op, F64Op>> {
|
||||
static void Emit(X64Emitter& e, const EmitArgType& i) {
|
||||
// Calculate the multiply part if it's constant.
|
||||
// TODO: Do this in the constant propagation pass.
|
||||
if (i.src1.is_constant && i.src2.is_constant) {
|
||||
double mul = i.src1.constant() * i.src2.constant();
|
||||
|
||||
e.LoadConstantXmm(e.xmm0, mul);
|
||||
e.vaddsd(i.dest, e.xmm0, i.src3);
|
||||
return;
|
||||
}
|
||||
|
||||
// FMA extension
|
||||
if (e.IsFeatureEnabled(kX64EmitFMA)) {
|
||||
EmitCommutativeBinaryXmmOp(e, i,
|
||||
|
@ -4733,19 +4713,6 @@ struct MUL_ADD_V128
|
|||
: Sequence<MUL_ADD_V128,
|
||||
I<OPCODE_MUL_ADD, V128Op, V128Op, V128Op, V128Op>> {
|
||||
static void Emit(X64Emitter& e, const EmitArgType& i) {
|
||||
// Calculate the multiply part if it's constant.
|
||||
// TODO: Do this in the constant propagation pass.
|
||||
if (i.src1.is_constant && i.src2.is_constant) {
|
||||
vec128_t mul;
|
||||
for (int n = 0; n < 4; n++) {
|
||||
mul.f32[n] = i.src1.constant().f32[n] * i.src2.constant().f32[n];
|
||||
}
|
||||
|
||||
e.LoadConstantXmm(e.xmm0, mul);
|
||||
e.vaddps(i.dest, e.xmm0, i.src3);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO(benvanik): the vfmadd sequence produces slightly different results
|
||||
// than vmul+vadd and it'd be nice to know why. Until we know, it's
|
||||
// disabled so tests pass.
|
||||
|
@ -4811,16 +4778,6 @@ EMITTER_OPCODE_TABLE(OPCODE_MUL_ADD, MUL_ADD_F32, MUL_ADD_F64, MUL_ADD_V128);
|
|||
struct MUL_SUB_F32
|
||||
: Sequence<MUL_SUB_F32, I<OPCODE_MUL_SUB, F32Op, F32Op, F32Op, F32Op>> {
|
||||
static void Emit(X64Emitter& e, const EmitArgType& i) {
|
||||
// Calculate the multiply part if it's constant.
|
||||
// TODO: Do this in the constant propagation pass.
|
||||
if (i.src1.is_constant && i.src2.is_constant) {
|
||||
float mul = i.src1.constant() * i.src2.constant();
|
||||
|
||||
e.LoadConstantXmm(e.xmm0, mul);
|
||||
e.vsubss(i.dest, e.xmm0, i.src3);
|
||||
return;
|
||||
}
|
||||
|
||||
// FMA extension
|
||||
if (e.IsFeatureEnabled(kX64EmitFMA)) {
|
||||
EmitCommutativeBinaryXmmOp(e, i,
|
||||
|
@ -4870,16 +4827,6 @@ struct MUL_SUB_F32
|
|||
struct MUL_SUB_F64
|
||||
: Sequence<MUL_SUB_F64, I<OPCODE_MUL_SUB, F64Op, F64Op, F64Op, F64Op>> {
|
||||
static void Emit(X64Emitter& e, const EmitArgType& i) {
|
||||
// Calculate the multiply part if it's constant.
|
||||
// TODO: Do this in the constant propagation pass.
|
||||
if (i.src1.is_constant && i.src2.is_constant) {
|
||||
double mul = i.src1.constant() * i.src2.constant();
|
||||
|
||||
e.LoadConstantXmm(e.xmm0, mul);
|
||||
e.vsubsd(i.dest, e.xmm0, i.src3);
|
||||
return;
|
||||
}
|
||||
|
||||
// FMA extension
|
||||
if (e.IsFeatureEnabled(kX64EmitFMA)) {
|
||||
EmitCommutativeBinaryXmmOp(e, i,
|
||||
|
@ -4930,19 +4877,6 @@ struct MUL_SUB_V128
|
|||
: Sequence<MUL_SUB_V128,
|
||||
I<OPCODE_MUL_SUB, V128Op, V128Op, V128Op, V128Op>> {
|
||||
static void Emit(X64Emitter& e, const EmitArgType& i) {
|
||||
// Calculate the multiply part if it's constant.
|
||||
// TODO: Do this in the constant propagation pass.
|
||||
if (i.src1.is_constant && i.src2.is_constant) {
|
||||
vec128_t mul;
|
||||
for (int n = 0; n < 4; n++) {
|
||||
mul.f32[n] = i.src1.constant().f32[n] * i.src2.constant().f32[n];
|
||||
}
|
||||
|
||||
e.LoadConstantXmm(e.xmm0, mul);
|
||||
e.vsubps(i.dest, e.xmm0, i.src3);
|
||||
return;
|
||||
}
|
||||
|
||||
// FMA extension
|
||||
if (e.IsFeatureEnabled(kX64EmitFMA)) {
|
||||
EmitCommutativeBinaryXmmOp(e, i,
|
||||
|
|
|
@ -499,11 +499,20 @@ bool ConstantPropagationPass::Run(HIRBuilder* builder) {
|
|||
break;
|
||||
case OPCODE_MUL_ADD:
|
||||
if (i->src1.value->IsConstant() && i->src2.value->IsConstant()) {
|
||||
// Multiply part is constant.
|
||||
if (i->src3.value->IsConstant()) {
|
||||
v->set_from(i->src1.value);
|
||||
Value::MulAdd(v, i->src1.value, i->src2.value, i->src3.value);
|
||||
i->Remove();
|
||||
} else {
|
||||
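// Only the multiply is constant (src3 is not): fold src1 * src2 into a new constant value and rewrite this instruction as ADD(mul, src3).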
|
||||
Value* mul = builder->AllocValue();
|
||||
mul->set_from(i->src1.value);
|
||||
mul->Mul(i->src2.value);
|
||||
|
||||
Value* add = i->src3.value;
|
||||
i->Replace(&OPCODE_ADD_info, 0);
|
||||
i->set_src1(mul);
|
||||
i->set_src2(add);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
@ -514,6 +523,16 @@ bool ConstantPropagationPass::Run(HIRBuilder* builder) {
|
|||
v->set_from(i->src1.value);
|
||||
Value::MulSub(v, i->src1.value, i->src2.value, i->src3.value);
|
||||
i->Remove();
|
||||
} else {
|
||||
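// Only the multiply is constant (src3 is not): fold src1 * src2 into a new constant value and rewrite this instruction as SUB(mul, src3); the local named `add` below holds src3.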
|
||||
Value* mul = builder->AllocValue();
|
||||
mul->set_from(i->src1.value);
|
||||
mul->Mul(i->src2.value);
|
||||
|
||||
Value* add = i->src3.value;
|
||||
i->Replace(&OPCODE_SUB_info, 0);
|
||||
i->set_src1(mul);
|
||||
i->set_src2(add);
|
||||
}
|
||||
}
|
||||
break;
|
||||
|
|
|
@ -97,8 +97,10 @@ class HIRBuilder {
|
|||
void BranchTrue(Value* cond, Label* label, uint16_t branch_flags = 0);
|
||||
void BranchFalse(Value* cond, Label* label, uint16_t branch_flags = 0);
|
||||
|
||||
// phi type_name, Block* b1, Value* v1, Block* b2, Value* v2, etc
|
||||
Value* AllocValue(TypeName type = INT64_TYPE);
|
||||
Value* CloneValue(Value* source);
|
||||
|
||||
// phi type_name, Block* b1, Value* v1, Block* b2, Value* v2, etc
|
||||
Value* Assign(Value* value);
|
||||
Value* Cast(Value* value, TypeName target_type);
|
||||
Value* ZeroExtend(Value* value, TypeName target_type);
|
||||
|
@ -253,9 +255,6 @@ class HIRBuilder {
|
|||
void DumpValue(StringBuffer* str, Value* value);
|
||||
void DumpOp(StringBuffer* str, OpcodeSignatureType sig_type, Instr::Op* op);
|
||||
|
||||
Value* AllocValue(TypeName type = INT64_TYPE);
|
||||
Value* CloneValue(Value* source);
|
||||
|
||||
private:
|
||||
Block* AppendBlock();
|
||||
void EndBlock();
|
||||
|
|
Loading…
Reference in New Issue