[CPU] Handle constant multiply in fmadd/fmsub in constant propagation pass

This commit is contained in:
DrChat 2018-03-01 20:04:56 -06:00
parent d31db60a75
commit 6fd75cea91
3 changed files with 23 additions and 71 deletions

View File

@@ -4614,16 +4614,6 @@ EMITTER_OPCODE_TABLE(OPCODE_DIV, DIV_I8, DIV_I16, DIV_I32, DIV_I64, DIV_F32,
struct MUL_ADD_F32
: Sequence<MUL_ADD_F32, I<OPCODE_MUL_ADD, F32Op, F32Op, F32Op, F32Op>> {
static void Emit(X64Emitter& e, const EmitArgType& i) {
// Calculate the multiply part if it's constant.
// TODO: Do this in the constant propagation pass.
if (i.src1.is_constant && i.src2.is_constant) {
float mul = i.src1.constant() * i.src2.constant();
e.LoadConstantXmm(e.xmm0, mul);
e.vaddss(i.dest, e.xmm0, i.src3);
return;
}
// FMA extension
if (e.IsFeatureEnabled(kX64EmitFMA)) {
EmitCommutativeBinaryXmmOp(e, i,
@@ -4673,16 +4663,6 @@ struct MUL_ADD_F32
struct MUL_ADD_F64
: Sequence<MUL_ADD_F64, I<OPCODE_MUL_ADD, F64Op, F64Op, F64Op, F64Op>> {
static void Emit(X64Emitter& e, const EmitArgType& i) {
// Calculate the multiply part if it's constant.
// TODO: Do this in the constant propagation pass.
if (i.src1.is_constant && i.src2.is_constant) {
double mul = i.src1.constant() * i.src2.constant();
e.LoadConstantXmm(e.xmm0, mul);
e.vaddsd(i.dest, e.xmm0, i.src3);
return;
}
// FMA extension
if (e.IsFeatureEnabled(kX64EmitFMA)) {
EmitCommutativeBinaryXmmOp(e, i,
@@ -4733,19 +4713,6 @@ struct MUL_ADD_V128
: Sequence<MUL_ADD_V128,
I<OPCODE_MUL_ADD, V128Op, V128Op, V128Op, V128Op>> {
static void Emit(X64Emitter& e, const EmitArgType& i) {
// Calculate the multiply part if it's constant.
// TODO: Do this in the constant propagation pass.
if (i.src1.is_constant && i.src2.is_constant) {
vec128_t mul;
for (int n = 0; n < 4; n++) {
mul.f32[n] = i.src1.constant().f32[n] * i.src2.constant().f32[n];
}
e.LoadConstantXmm(e.xmm0, mul);
e.vaddps(i.dest, e.xmm0, i.src3);
return;
}
// TODO(benvanik): the vfmadd sequence produces slightly different results
// than vmul+vadd and it'd be nice to know why. Until we know, it's
// disabled so tests pass.
@@ -4811,16 +4778,6 @@ EMITTER_OPCODE_TABLE(OPCODE_MUL_ADD, MUL_ADD_F32, MUL_ADD_F64, MUL_ADD_V128);
struct MUL_SUB_F32
: Sequence<MUL_SUB_F32, I<OPCODE_MUL_SUB, F32Op, F32Op, F32Op, F32Op>> {
static void Emit(X64Emitter& e, const EmitArgType& i) {
// Calculate the multiply part if it's constant.
// TODO: Do this in the constant propagation pass.
if (i.src1.is_constant && i.src2.is_constant) {
float mul = i.src1.constant() * i.src2.constant();
e.LoadConstantXmm(e.xmm0, mul);
e.vsubss(i.dest, e.xmm0, i.src3);
return;
}
// FMA extension
if (e.IsFeatureEnabled(kX64EmitFMA)) {
EmitCommutativeBinaryXmmOp(e, i,
@@ -4870,16 +4827,6 @@ struct MUL_SUB_F32
struct MUL_SUB_F64
: Sequence<MUL_SUB_F64, I<OPCODE_MUL_SUB, F64Op, F64Op, F64Op, F64Op>> {
static void Emit(X64Emitter& e, const EmitArgType& i) {
// Calculate the multiply part if it's constant.
// TODO: Do this in the constant propagation pass.
if (i.src1.is_constant && i.src2.is_constant) {
double mul = i.src1.constant() * i.src2.constant();
e.LoadConstantXmm(e.xmm0, mul);
e.vsubsd(i.dest, e.xmm0, i.src3);
return;
}
// FMA extension
if (e.IsFeatureEnabled(kX64EmitFMA)) {
EmitCommutativeBinaryXmmOp(e, i,
@@ -4930,19 +4877,6 @@ struct MUL_SUB_V128
: Sequence<MUL_SUB_V128,
I<OPCODE_MUL_SUB, V128Op, V128Op, V128Op, V128Op>> {
static void Emit(X64Emitter& e, const EmitArgType& i) {
// Calculate the multiply part if it's constant.
// TODO: Do this in the constant propagation pass.
if (i.src1.is_constant && i.src2.is_constant) {
vec128_t mul;
for (int n = 0; n < 4; n++) {
mul.f32[n] = i.src1.constant().f32[n] * i.src2.constant().f32[n];
}
e.LoadConstantXmm(e.xmm0, mul);
e.vsubps(i.dest, e.xmm0, i.src3);
return;
}
// FMA extension
if (e.IsFeatureEnabled(kX64EmitFMA)) {
EmitCommutativeBinaryXmmOp(e, i,

View File

@@ -499,11 +499,20 @@ bool ConstantPropagationPass::Run(HIRBuilder* builder) {
break;
case OPCODE_MUL_ADD:
if (i->src1.value->IsConstant() && i->src2.value->IsConstant()) {
// Multiply part is constant.
if (i->src3.value->IsConstant()) {
v->set_from(i->src1.value);
Value::MulAdd(v, i->src1.value, i->src2.value, i->src3.value);
i->Remove();
} else {
// Multiply part is constant.
Value* mul = builder->AllocValue();
mul->set_from(i->src1.value);
mul->Mul(i->src2.value);
Value* add = i->src3.value;
i->Replace(&OPCODE_ADD_info, 0);
i->set_src1(mul);
i->set_src2(add);
}
}
break;
@@ -514,6 +523,16 @@ bool ConstantPropagationPass::Run(HIRBuilder* builder) {
v->set_from(i->src1.value);
Value::MulSub(v, i->src1.value, i->src2.value, i->src3.value);
i->Remove();
} else {
// Multiply part is constant.
Value* mul = builder->AllocValue();
mul->set_from(i->src1.value);
mul->Mul(i->src2.value);
Value* add = i->src3.value;
i->Replace(&OPCODE_SUB_info, 0);
i->set_src1(mul);
i->set_src2(add);
}
}
break;

View File

@@ -97,8 +97,10 @@ class HIRBuilder {
void BranchTrue(Value* cond, Label* label, uint16_t branch_flags = 0);
void BranchFalse(Value* cond, Label* label, uint16_t branch_flags = 0);
// phi type_name, Block* b1, Value* v1, Block* b2, Value* v2, etc
Value* AllocValue(TypeName type = INT64_TYPE);
Value* CloneValue(Value* source);
// phi type_name, Block* b1, Value* v1, Block* b2, Value* v2, etc
Value* Assign(Value* value);
Value* Cast(Value* value, TypeName target_type);
Value* ZeroExtend(Value* value, TypeName target_type);
@@ -253,9 +255,6 @@ class HIRBuilder {
void DumpValue(StringBuffer* str, Value* value);
void DumpOp(StringBuffer* str, OpcodeSignatureType sig_type, Instr::Op* op);
Value* AllocValue(TypeName type = INT64_TYPE);
Value* CloneValue(Value* source);
private:
Block* AppendBlock();
void EndBlock();