Add basic constant propagation simplification

joseph 2016-12-18 16:51:05 -06:00 committed by Anthony Pesch
parent 45d5a9b5d8
commit 4ba9885aab
5 changed files with 86 additions and 282 deletions
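For context, constant propagation replaces an instruction whose operands are all compile-time constants with the precomputed value, so that later passes (expression simplification, dead code elimination) can drop the original instruction. A minimal standalone sketch of the idea in C follows; the types and helper here are illustrative stand-ins, not redream's IR structures.

/* sketch: fold an add of two constant operands into a single constant */
#include <stdint.h>
#include <assert.h>

struct toy_value {
  int is_constant;
  uint64_t constant;
};

/* returns 1 and writes the folded value only if both operands are constant */
static int toy_fold_add(const struct toy_value *lhs, const struct toy_value *rhs,
                        struct toy_value *out) {
  if (!lhs->is_constant || !rhs->is_constant) {
    return 0;
  }
  out->is_constant = 1;
  out->constant = lhs->constant + rhs->constant;
  return 1;
}

int main() {
  struct toy_value a = {1, 3}, b = {1, 4}, r;
  assert(toy_fold_add(&a, &b, &r) && r.constant == 7);
  return 0;
}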


@@ -194,7 +194,7 @@ set(REDREAM_SOURCES
src/jit/ir/ir.c
src/jit/ir/ir_read.c
src/jit/ir/ir_write.c
#src/jit/passes/constant_propagation_pass.c
src/jit/passes/constant_propagation_pass.c
src/jit/passes/conversion_elimination_pass.c
src/jit/passes/dead_code_elimination_pass.c
src/jit/passes/expression_simplification_pass.c


@@ -6,6 +6,7 @@
#include "jit/backend/jit_backend.h"
#include "jit/frontend/jit_frontend.h"
#include "jit/ir/ir.h"
#include "jit/passes/constant_propagation_pass.h"
#include "jit/passes/dead_code_elimination_pass.h"
#include "jit/passes/expression_simplification_pass.h"
#include "jit/passes/load_store_elimination_pass.h"
@@ -321,6 +322,7 @@ void jit_compile_block(struct jit *jit, uint32_t guest_addr) {
/* run optimization passes */
lse_run(&ir);
cpro_run(&ir);
esimp_run(&ir);
dce_run(&ir);
ra_run(&ir, jit->backend->registers, jit->backend->num_registers);
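The new pass is slotted between load/store elimination and expression simplification, presumably so that constants exposed by lse_run are folded before esimp_run runs and the now-dead producers are cleaned up by dce_run. A rough before/after illustration, written as C comments over hypothetical IR values (not redream's actual textual IR format):

/* before cpro_run: v2 depends on two constants                      */
/*   v0 = 3; v1 = 4; v2 = add v0, v1; store addr, v2                 */
/* after cpro_run: uses of v2 are replaced with the folded constant  */
/*   v0 = 3; v1 = 4; v2 = add v0, v1; store addr, 7                  */
/* after dce_run: the unused add and its operands are removed        */
/*   store addr, 7                                                   */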


@@ -1,286 +1,90 @@
#include <type_traits>
#include <unordered_map>
#include "jit/passes/constant_propagation_pass.h"
#include "jit/ir/ir.h"
#include "jit/pass_stats.h"
typedef void (*FoldFn)(IRBuilder &, Instr *i);
// specify which arguments must be constant in order for the fold operation to run
enum {
ARG0_CNST = 0x1,
ARG1_CNST = 0x2,
ARG2_CNST = 0x4,
};
DEFINE_STAT(constant_propagations_removed, "constant propagations removed");
DEFINE_STAT(could_optimize_binary_op, "constant binary operations possible");
DEFINE_STAT(could_optimize_unary_op, "constant unary operations possible");
// fold callbacks for each operation
std::unordered_map<int, FoldFn> fold_cbs;
int fold_masks[NUM_OPS];
void cpro_run(struct ir *ir) {
list_for_each_entry(instr, &ir->instrs, struct ir_instr, it) {
// OP_SELECT and OP_BRANCH_COND are the only instructions using arg2, and
// arg2's type always matches arg1's. because of this, arg2 isn't considered
// when generating the lookup table
#define CALLBACK_IDX(op, r, a0, a1) \
((op)*VALUE_NUM * VALUE_NUM * VALUE_NUM) + ((r)*VALUE_NUM * VALUE_NUM) + \
((a0)*VALUE_NUM) + (a1)
// declare a templated callback for an IR operation. note, declaring a
// callback does not actually register it. callbacks must be registered
// for a particular signature with REGISTER_FOLD
#define FOLD(op, mask) \
static struct _##op##_init { \
_##op##_init() { \
fold_masks[OP_##op] = mask; \
} \
} op##_init; \
template <typename R = struct ir_valueInfo<VALUE_V>, \
typename A0 = struct ir_valueInfo<VALUE_V>, \
typename A1 = struct ir_valueInfo<VALUE_V>> \
void Handle##op(struct ir *ir, Instr *instr)
// registers a fold callback for the specified signature
#define REGISTER_FOLD(op, r, a0, a1) \
static struct _cpp_##op##_##r##_##a0##_##a1##_init { \
_cpp_##op##_##r##_##a0##_##a1##_init() { \
fold_cbs[CALLBACK_IDX(OP_##op, VALUE_##r, VALUE_##a0, VALUE_##a1)] = \
&Handle##op<struct ir_valueInfo<VALUE_##r>, \
struct ir_valueInfo<VALUE_##a0>, \
struct ir_valueInfo<VALUE_##a1>>; \
} \
} cpp_##op##_##r##_##a0##_##a1##_init
// common helpers for fold functions
#define ARG0() (instr->arg[0]->*A0::fn)()
#define ARG1() (instr->arg[1]->*A1::fn)()
#define ARG2() (instr->arg[2]->*A1::fn)()
#define ARG0_UNSIGNED() static_cast<typename A0::unsigned_type>(ARG0())
#define ARG1_UNSIGNED() static_cast<typename A1::unsigned_type>(ARG1())
#define ARG2_UNSIGNED() static_cast<typename A1::unsigned_type>(ARG2())
#define RESULT(expr) \
ir_replace_uses(instr, ir_alloc_constant( \
ir, static_cast<typename R::signed_type>(expr))); \
ir_remove_instr(instr)
static FoldFn GetFoldFn(Instr *instr) {
auto it = fold_cbs.find(
CALLBACK_IDX(instr->op, instr->type,
instr->arg[0] ? (int)instr->arg[0]->type : VALUE_V,
instr->arg[1] ? (int)instr->arg[1]->type : VALUE_V));
if (it == fold_cbs.end()) {
return nullptr;
}
return it->second;
}
static int GetFoldMask(Instr *instr) {
return fold_masks[instr->op];
}
static int GetConstantSig(Instr *instr) {
int cnst_sig = 0;
if (instr->arg[0] && ir_is_constant(instr->arg[0])) {
cnst_sig |= ARG0_CNST;
}
if (instr->arg[1] && ir_is_constant(instr->arg[1])) {
cnst_sig |= ARG1_CNST;
}
if (instr->arg[2] && ir_is_constant(instr->arg[2])) {
cnst_sig |= ARG2_CNST;
}
return cnst_sig;
}
void ConstantPropagationPass::Run(struct ir *ir) {
list_for_each_entry_safe(instr, &ir->instrs, struct ir_instr, it) {
int fold_mask = GetFoldMask(instr);
int cnst_sig = GetConstantSig(instr);
if (!fold_mask || (cnst_sig & fold_mask) != fold_mask) {
/* Skip instructions which do not perform any operations */
if(instr->op == OP_DEBUG_INFO || instr->op == OP_LABEL)
continue;
/* Profile the number of possible constant propagation optimizations */
if (instr->arg[0] && instr->arg[1] && ir_is_constant(instr->arg[0]) &&
ir_is_constant(instr->arg[1])) {
STAT_could_optimize_binary_op++;
}
else if (instr->arg[0] && !instr->arg[1] && ir_is_constant(instr->arg[0])){
STAT_could_optimize_unary_op++;
}
FoldFn fold = GetFoldFn(instr);
if (!fold) {
continue;
/* Simplify binary ops with constant arguments */
if(instr->arg[0] && ir_is_constant(instr->arg[0]) &&
instr->arg[1] && ir_is_constant(instr->arg[1]))
{
uint64_t lhs = ir_zext_constant(instr->arg[0]);
uint64_t rhs = ir_zext_constant(instr->arg[1]);
struct ir_value *result;
switch(instr->op)
{
case OP_ADD:
result = ir_alloc_int(ir, lhs + rhs, instr->result->type);
break;
case OP_AND:
result = ir_alloc_int(ir, lhs & rhs, instr->result->type);
break;
case OP_DIV:
result = ir_alloc_int(ir, lhs / rhs, instr->result->type);
break;
case OP_LSHR:
result = ir_alloc_int(ir, lhs >> rhs, instr->result->type);
break;
case OP_OR:
result = ir_alloc_int(ir, lhs | rhs, instr->result->type);
break;
case OP_SHL:
result = ir_alloc_int(ir, lhs << rhs, instr->result->type);
break;
case OP_SUB:
result = ir_alloc_int(ir, lhs - rhs, instr->result->type);
break;
case OP_UMUL:
result = ir_alloc_int(ir, lhs * rhs, instr->result->type);
break;
case OP_XOR:
result = ir_alloc_int(ir, lhs ^ rhs, instr->result->type);
break;
default:
continue;
}
ir_replace_uses(instr->result, result);
STAT_constant_propagations_removed++;
}
/* Simplify constant unary ops */
else if(instr->arg[0] && !instr->arg[1] && ir_is_constant(instr->arg[0])) {
uint64_t arg = ir_zext_constant(instr->arg[0]);
struct ir_value *result;
switch(instr->op)
{
case OP_NEG:
result = ir_alloc_int(ir, 0 - arg, instr->result->type);
break;
case OP_NOT:
result = ir_alloc_int(ir, ~arg, instr->result->type);
break;
default:
continue;
}
ir_replace_uses(instr->result, result);
STAT_constant_propagations_removed++;
}
fold(builder, instr);
}
}
FOLD(SELECT, ARG0_CNST) {
ir_replace_uses(instr, ARG0() ? instr->arg[1] : instr->arg[2]);
ir_remove_instr(ir, instr);
}
REGISTER_FOLD(SELECT, I8, I8, I8);
REGISTER_FOLD(SELECT, I16, I16, I16);
REGISTER_FOLD(SELECT, I32, I32, I32);
REGISTER_FOLD(SELECT, I64, I64, I64);
FOLD(EQ, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() == ARG1());
}
REGISTER_FOLD(EQ, I8, I8, I8);
REGISTER_FOLD(EQ, I8, I16, I16);
REGISTER_FOLD(EQ, I8, I32, I32);
REGISTER_FOLD(EQ, I8, I64, I64);
REGISTER_FOLD(EQ, I8, F32, F32);
REGISTER_FOLD(EQ, I8, F64, F64);
FOLD(NE, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() != ARG1());
}
REGISTER_FOLD(NE, I8, I8, I8);
REGISTER_FOLD(NE, I8, I16, I16);
REGISTER_FOLD(NE, I8, I32, I32);
REGISTER_FOLD(NE, I8, I64, I64);
REGISTER_FOLD(NE, I8, F32, F32);
REGISTER_FOLD(NE, I8, F64, F64);
FOLD(SGE, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() >= ARG1());
}
REGISTER_FOLD(SGE, I8, I8, I8);
REGISTER_FOLD(SGE, I8, I16, I16);
REGISTER_FOLD(SGE, I8, I32, I32);
REGISTER_FOLD(SGE, I8, I64, I64);
REGISTER_FOLD(SGE, I8, F32, F32);
REGISTER_FOLD(SGE, I8, F64, F64);
FOLD(SGT, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() > ARG1());
}
REGISTER_FOLD(SGT, I8, I8, I8);
REGISTER_FOLD(SGT, I8, I16, I16);
REGISTER_FOLD(SGT, I8, I32, I32);
REGISTER_FOLD(SGT, I8, I64, I64);
REGISTER_FOLD(SGT, I8, F32, F32);
REGISTER_FOLD(SGT, I8, F64, F64);
// IR_OP(UGE)
// IR_OP(UGT)
FOLD(SLE, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() <= ARG1());
}
REGISTER_FOLD(SLE, I8, I8, I8);
REGISTER_FOLD(SLE, I8, I16, I16);
REGISTER_FOLD(SLE, I8, I32, I32);
REGISTER_FOLD(SLE, I8, I64, I64);
REGISTER_FOLD(SLE, I8, F32, F32);
REGISTER_FOLD(SLE, I8, F64, F64);
FOLD(SLT, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() < ARG1());
}
REGISTER_FOLD(SLT, I8, I8, I8);
REGISTER_FOLD(SLT, I8, I16, I16);
REGISTER_FOLD(SLT, I8, I32, I32);
REGISTER_FOLD(SLT, I8, I64, I64);
REGISTER_FOLD(SLT, I8, F32, F32);
REGISTER_FOLD(SLT, I8, F64, F64);
// IR_OP(ULE)
// IR_OP(ULT)
FOLD(ADD, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() + ARG1());
}
REGISTER_FOLD(ADD, I8, I8, I8);
REGISTER_FOLD(ADD, I16, I16, I16);
REGISTER_FOLD(ADD, I32, I32, I32);
REGISTER_FOLD(ADD, I64, I64, I64);
REGISTER_FOLD(ADD, F32, F32, F32);
REGISTER_FOLD(ADD, F64, F64, F64);
FOLD(SUB, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() - ARG1());
}
REGISTER_FOLD(SUB, I8, I8, I8);
REGISTER_FOLD(SUB, I16, I16, I16);
REGISTER_FOLD(SUB, I32, I32, I32);
REGISTER_FOLD(SUB, I64, I64, I64);
REGISTER_FOLD(SUB, F32, F32, F32);
REGISTER_FOLD(SUB, F64, F64, F64);
FOLD(SMUL, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() * ARG1());
}
REGISTER_FOLD(SMUL, I8, I8, I8);
REGISTER_FOLD(SMUL, I16, I16, I16);
REGISTER_FOLD(SMUL, I32, I32, I32);
REGISTER_FOLD(SMUL, I64, I64, I64);
REGISTER_FOLD(SMUL, F32, F32, F32);
REGISTER_FOLD(SMUL, F64, F64, F64);
FOLD(UMUL, ARG0_CNST | ARG1_CNST) {
auto lhs = ARG0_UNSIGNED();
auto rhs = ARG1_UNSIGNED();
RESULT(lhs * rhs);
}
REGISTER_FOLD(UMUL, I8, I8, I8);
REGISTER_FOLD(UMUL, I16, I16, I16);
REGISTER_FOLD(UMUL, I32, I32, I32);
REGISTER_FOLD(UMUL, I64, I64, I64);
// IR_OP(DIV)
// IR_OP(NEG)
// IR_OP(SQRT)
// IR_OP(ABS)
// IR_OP(SIN)
// IR_OP(COS)
FOLD(AND, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() & ARG1());
}
REGISTER_FOLD(AND, I8, I8, I8);
REGISTER_FOLD(AND, I16, I16, I16);
REGISTER_FOLD(AND, I32, I32, I32);
REGISTER_FOLD(AND, I64, I64, I64);
FOLD(OR, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() | ARG1());
}
REGISTER_FOLD(OR, I8, I8, I8);
REGISTER_FOLD(OR, I16, I16, I16);
REGISTER_FOLD(OR, I32, I32, I32);
REGISTER_FOLD(OR, I64, I64, I64);
FOLD(XOR, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() ^ ARG1());
}
REGISTER_FOLD(XOR, I8, I8, I8);
REGISTER_FOLD(XOR, I16, I16, I16);
REGISTER_FOLD(XOR, I32, I32, I32);
REGISTER_FOLD(XOR, I64, I64, I64);
FOLD(NOT, ARG0_CNST) {
RESULT(~ARG0());
}
REGISTER_FOLD(NOT, I8, I8, V);
REGISTER_FOLD(NOT, I16, I16, V);
REGISTER_FOLD(NOT, I32, I32, V);
REGISTER_FOLD(NOT, I64, I64, V);
FOLD(SHL, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0() << ARG1());
}
REGISTER_FOLD(SHL, I8, I8, I32);
REGISTER_FOLD(SHL, I16, I16, I32);
REGISTER_FOLD(SHL, I32, I32, I32);
REGISTER_FOLD(SHL, I64, I64, I32);
// IR_OP(ASHR)
FOLD(LSHR, ARG0_CNST | ARG1_CNST) {
RESULT(ARG0_UNSIGNED() >> ARG1());
}
REGISTER_FOLD(LSHR, I8, I8, I32);
REGISTER_FOLD(LSHR, I16, I16, I32);
REGISTER_FOLD(LSHR, I32, I32, I32);
REGISTER_FOLD(LSHR, I64, I64, I32);
// IR_OP(BRANCH)
// IR_OP(BRANCH_COND)
// IR_OP(CALL_EXTERNAL)
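Taken together, the new C pass replaces the old template-based fold machinery: it walks the instruction list, and for any supported integer op whose arguments are all constant it zero-extends the operands via ir_zext_constant, computes the result, allocates a new constant of the instruction's result type, and replaces all uses of the old result. Unlike the removed C++ code it does not remove the folded instruction itself, apparently relying on the dead code elimination pass that runs afterwards. A self-contained sketch of the folding core, using simplified stand-in types rather than redream's ir_value/ir_instr API (the real pass presumably truncates the 64-bit result to the destination type when allocating the new constant):

/* sketch: fold two constant operands that were zero-extended to 64 bits */
#include <stdint.h>
#include <stdio.h>

enum toy_op { TOY_ADD, TOY_SUB, TOY_AND, TOY_OR, TOY_XOR, TOY_SHL, TOY_LSHR };

static uint64_t toy_fold_binary(enum toy_op op, uint64_t lhs, uint64_t rhs) {
  switch (op) {
    case TOY_ADD:  return lhs + rhs;
    case TOY_SUB:  return lhs - rhs;
    case TOY_AND:  return lhs & rhs;
    case TOY_OR:   return lhs | rhs;
    case TOY_XOR:  return lhs ^ rhs;
    case TOY_SHL:  return lhs << rhs;
    case TOY_LSHR: return lhs >> rhs;
  }
  return 0;
}

int main() {
  /* e.g. an OP_AND of the constants 0xff and 0x0f folds to 0x0f */
  printf("0x%llx\n", (unsigned long long)toy_fold_binary(TOY_AND, 0xff, 0x0f));
  return 0;
}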


@@ -1,15 +1,10 @@
#ifndef CONSTANT_PROPAGATION_PASS_H
#define CONSTANT_PROPAGATION_PASS_H
class ConstantPropagationPass : public Pass {
public:
static const char *NAME = "constprop";
struct ir;
const char *name() {
return NAME;
}
void Run(struct ir *ir);
};
void cpro_run(struct ir *ir);
#endif


@@ -6,6 +6,7 @@
#include "jit/ir/ir.h"
#include "jit/jit.h"
#include "jit/pass_stats.h"
#include "jit/passes/constant_propagation_pass.h"
#include "jit/passes/conversion_elimination_pass.h"
#include "jit/passes/dead_code_elimination_pass.h"
#include "jit/passes/expression_simplification_pass.h"
@@ -14,7 +15,7 @@
#include "sys/filesystem.h"
DEFINE_OPTION_INT(help, 0, "Show help");
DEFINE_OPTION_STRING(pass, "lse,cve,esimp,dce,ra",
DEFINE_OPTION_STRING(pass, "lse,cpro,cve,esimp,dce,ra",
"Comma-separated list of passes to run");
DEFINE_STAT(ir_instrs_total, "total ir instructions");
@@ -77,6 +78,8 @@ static void process_file(struct jit *jit, const char *filename,
while (name) {
if (!strcmp(name, "lse")) {
lse_run(&ir);
} else if (!strcmp(name, "cpro")){
cpro_run(&ir);
} else if (!strcmp(name, "cve")) {
cve_run(&ir);
} else if (!strcmp(name, "dce")) {
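The recc change mirrors the runtime pipeline: the pass option is a comma-separated list of pass names dispatched by string comparison, with "cpro" now accepted and included in the default list. A minimal standalone sketch of that dispatch pattern, with stub pass functions; the parsing details are an assumption and may differ from recc's actual implementation:

#include <stdio.h>
#include <string.h>

static void run_pass_by_name(const char *name) {
  if (!strcmp(name, "lse")) {
    printf("load/store elimination\n");
  } else if (!strcmp(name, "cpro")) {
    printf("constant propagation\n");
  } else if (!strcmp(name, "dce")) {
    printf("dead code elimination\n");
  } else {
    printf("unknown pass '%s'\n", name);
  }
}

int main() {
  char passes[] = "lse,cpro,dce"; /* the new default in recc is "lse,cpro,cve,esimp,dce,ra" */
  char *name = strtok(passes, ",");
  while (name) {
    run_pass_by_name(name);
    name = strtok(NULL, ",");
  }
  return 0;
}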