JitArm64: Use LSL+CLS for classifying floats
This is a little trick I came up with that lets us restructure our float classification code so we can exit earlier when the float is normal, which is the case more often than not. First we shift left by 1 to get rid of the sign bit, and then we count the number of leading sign bits. If the result is less than 10 (for doubles) or 7 (for floats), the float is normal. This is because, if the float isn't normal, the exponent is either all zeroes or all ones.
This commit is contained in:
parent
5d9838548b
commit
255ee3fdce
|
@ -23,11 +23,15 @@ static constexpr u64 DOUBLE_SIGN = 0x8000000000000000ULL;
|
||||||
static constexpr u64 DOUBLE_EXP = 0x7FF0000000000000ULL;
|
static constexpr u64 DOUBLE_EXP = 0x7FF0000000000000ULL;
|
||||||
static constexpr u64 DOUBLE_FRAC = 0x000FFFFFFFFFFFFFULL;
|
static constexpr u64 DOUBLE_FRAC = 0x000FFFFFFFFFFFFFULL;
|
||||||
static constexpr u64 DOUBLE_ZERO = 0x0000000000000000ULL;
|
static constexpr u64 DOUBLE_ZERO = 0x0000000000000000ULL;
|
||||||
|
static constexpr int DOUBLE_EXP_WIDTH = 11;
|
||||||
|
static constexpr int DOUBLE_FRAC_WIDTH = 52;
|
||||||
|
|
||||||
static constexpr u32 FLOAT_SIGN = 0x80000000;
|
static constexpr u32 FLOAT_SIGN = 0x80000000;
|
||||||
static constexpr u32 FLOAT_EXP = 0x7F800000;
|
static constexpr u32 FLOAT_EXP = 0x7F800000;
|
||||||
static constexpr u32 FLOAT_FRAC = 0x007FFFFF;
|
static constexpr u32 FLOAT_FRAC = 0x007FFFFF;
|
||||||
static constexpr u32 FLOAT_ZERO = 0x00000000;
|
static constexpr u32 FLOAT_ZERO = 0x00000000;
|
||||||
|
static constexpr int FLOAT_EXP_WIDTH = 8;
|
||||||
|
static constexpr int FLOAT_FRAC_WIDTH = 23;
|
||||||
|
|
||||||
inline bool IsQNAN(double d)
|
inline bool IsQNAN(double d)
|
||||||
{
|
{
|
||||||
|
|
|
@ -315,25 +315,14 @@ void JitArm64::GenerateFrsqrte()
|
||||||
// inf, even the mantissa matches. But the mantissa does not match for most other inputs, so in
|
// inf, even the mantissa matches. But the mantissa does not match for most other inputs, so in
|
||||||
// the normal case we calculate the mantissa using the table-based algorithm from the interpreter.
|
// the normal case we calculate the mantissa using the table-based algorithm from the interpreter.
|
||||||
|
|
||||||
|
LSL(ARM64Reg::X2, ARM64Reg::X1, 1);
|
||||||
m_float_emit.FMOV(ARM64Reg::X0, ARM64Reg::D0);
|
m_float_emit.FMOV(ARM64Reg::X0, ARM64Reg::D0);
|
||||||
TST(ARM64Reg::X1, LogicalImm(Common::DOUBLE_EXP | Common::DOUBLE_FRAC, 64));
|
CLS(ARM64Reg::X3, ARM64Reg::X2);
|
||||||
FixupBranch zero = B(CCFlags::CC_EQ);
|
TST(ARM64Reg::X1, LogicalImm(Common::DOUBLE_SIGN, 64));
|
||||||
AND(ARM64Reg::X2, ARM64Reg::X1, LogicalImm(Common::DOUBLE_EXP, 64));
|
CCMP(ARM64Reg::X3, Common::DOUBLE_EXP_WIDTH - 1, 0b0010, CCFlags::CC_EQ);
|
||||||
MOVI2R(ARM64Reg::X3, Common::DOUBLE_EXP);
|
FixupBranch not_positive_normal = B(CCFlags::CC_HS);
|
||||||
CMP(ARM64Reg::X2, ARM64Reg::X3);
|
|
||||||
FixupBranch nan_or_inf = B(CCFlags::CC_EQ);
|
|
||||||
FixupBranch negative = TBNZ(ARM64Reg::X1, 63);
|
|
||||||
FixupBranch normal = CBNZ(ARM64Reg::X2);
|
|
||||||
|
|
||||||
// "Normalize" denormal values.
|
const u8* positive_normal = GetCodePtr();
|
||||||
// The simplified calculation used here results in the upper 11 bits being incorrect,
|
|
||||||
// but that's fine, because the code below never reads those bits.
|
|
||||||
CLZ(ARM64Reg::X3, ARM64Reg::X1);
|
|
||||||
LSLV(ARM64Reg::X1, ARM64Reg::X1, ARM64Reg::X3);
|
|
||||||
LSR(ARM64Reg::X1, ARM64Reg::X1, 11);
|
|
||||||
BFI(ARM64Reg::X1, ARM64Reg::X3, 52, 12);
|
|
||||||
|
|
||||||
SetJumpTarget(normal);
|
|
||||||
UBFX(ARM64Reg::X2, ARM64Reg::X1, 48, 5);
|
UBFX(ARM64Reg::X2, ARM64Reg::X1, 48, 5);
|
||||||
MOVP2R(ARM64Reg::X3, &Common::frsqrte_expected);
|
MOVP2R(ARM64Reg::X3, &Common::frsqrte_expected);
|
||||||
ADD(ARM64Reg::X2, ARM64Reg::X3, ARM64Reg::X2, ArithOption(ARM64Reg::X2, ShiftType::LSL, 3));
|
ADD(ARM64Reg::X2, ARM64Reg::X3, ARM64Reg::X2, ArithOption(ARM64Reg::X2, ShiftType::LSL, 3));
|
||||||
|
@ -344,27 +333,41 @@ void JitArm64::GenerateFrsqrte()
|
||||||
ORR(ARM64Reg::X0, ARM64Reg::X0, ARM64Reg::X1, ArithOption(ARM64Reg::X1, ShiftType::LSL, 26));
|
ORR(ARM64Reg::X0, ARM64Reg::X0, ARM64Reg::X1, ArithOption(ARM64Reg::X1, ShiftType::LSL, 26));
|
||||||
RET();
|
RET();
|
||||||
|
|
||||||
SetJumpTarget(zero);
|
SetJumpTarget(not_positive_normal);
|
||||||
LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
||||||
|
FixupBranch not_positive_normal_not_zero = CBNZ(ARM64Reg::X2);
|
||||||
|
|
||||||
|
// Zero
|
||||||
FixupBranch skip_set_zx = TBNZ(ARM64Reg::W3, 26);
|
FixupBranch skip_set_zx = TBNZ(ARM64Reg::W3, 26);
|
||||||
ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_ZX, ARM64Reg::W2);
|
ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_ZX, ARM64Reg::W2);
|
||||||
|
const u8* store_fpscr = GetCodePtr();
|
||||||
STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
||||||
SetJumpTarget(skip_set_zx);
|
SetJumpTarget(skip_set_zx);
|
||||||
|
const u8* done = GetCodePtr();
|
||||||
RET();
|
RET();
|
||||||
|
|
||||||
|
SetJumpTarget(not_positive_normal_not_zero);
|
||||||
|
FixupBranch nan_or_inf = TBNZ(ARM64Reg::X1, 62);
|
||||||
|
FixupBranch negative = TBNZ(ARM64Reg::X1, 63);
|
||||||
|
|
||||||
|
// "Normalize" denormal values.
|
||||||
|
// The simplified calculation used here results in the upper 11 bits being incorrect,
|
||||||
|
// but that's fine, because the code we jump to never reads those bits.
|
||||||
|
CLZ(ARM64Reg::X3, ARM64Reg::X1);
|
||||||
|
LSLV(ARM64Reg::X1, ARM64Reg::X1, ARM64Reg::X3);
|
||||||
|
LSR(ARM64Reg::X1, ARM64Reg::X1, 11);
|
||||||
|
BFI(ARM64Reg::X1, ARM64Reg::X3, 52, 12);
|
||||||
|
B(positive_normal);
|
||||||
|
|
||||||
SetJumpTarget(nan_or_inf);
|
SetJumpTarget(nan_or_inf);
|
||||||
MOVI2R(ARM64Reg::X3, Common::BitCast<u64>(-std::numeric_limits<double>::infinity()));
|
MOVI2R(ARM64Reg::X2, Common::BitCast<u64>(-std::numeric_limits<double>::infinity()));
|
||||||
CMP(ARM64Reg::X1, ARM64Reg::X3);
|
CMP(ARM64Reg::X1, ARM64Reg::X2);
|
||||||
FixupBranch nan_or_positive_inf = B(CCFlags::CC_NEQ);
|
B(CCFlags::CC_NEQ, done);
|
||||||
|
|
||||||
SetJumpTarget(negative);
|
SetJumpTarget(negative);
|
||||||
LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
TBNZ(ARM64Reg::W3, 9, done);
|
||||||
FixupBranch skip_set_vxsqrt = TBNZ(ARM64Reg::W3, 9);
|
|
||||||
ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_VXSQRT, ARM64Reg::W2);
|
ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_VXSQRT, ARM64Reg::W2);
|
||||||
STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
B(store_fpscr);
|
||||||
SetJumpTarget(skip_set_vxsqrt);
|
|
||||||
SetJumpTarget(nan_or_positive_inf);
|
|
||||||
RET();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Input in X0, output in W1, clobbers X0-X3 and flags.
|
// Input in X0, output in W1, clobbers X0-X3 and flags.
|
||||||
|
@ -438,25 +441,17 @@ void JitArm64::GenerateFPRF(bool single)
|
||||||
const auto reg_encoder = single ? EncodeRegTo32 : EncodeRegTo64;
|
const auto reg_encoder = single ? EncodeRegTo32 : EncodeRegTo64;
|
||||||
|
|
||||||
const ARM64Reg input_reg = reg_encoder(ARM64Reg::W0);
|
const ARM64Reg input_reg = reg_encoder(ARM64Reg::W0);
|
||||||
const ARM64Reg temp_reg = reg_encoder(ARM64Reg::W1);
|
const ARM64Reg cls_reg = reg_encoder(ARM64Reg::W1);
|
||||||
const ARM64Reg exp_reg = reg_encoder(ARM64Reg::W2);
|
const ARM64Reg exp_and_frac_reg = reg_encoder(ARM64Reg::W2);
|
||||||
|
|
||||||
constexpr ARM64Reg fprf_reg = ARM64Reg::W3;
|
constexpr ARM64Reg fprf_reg = ARM64Reg::W3;
|
||||||
constexpr ARM64Reg fpscr_reg = ARM64Reg::W4;
|
constexpr ARM64Reg fpscr_reg = ARM64Reg::W4;
|
||||||
|
|
||||||
const int input_size = single ? 32 : 64;
|
const int input_size = single ? 32 : 64;
|
||||||
const u64 input_exp_mask = single ? Common::FLOAT_EXP : Common::DOUBLE_EXP;
|
const int input_exp_size = single ? Common::FLOAT_EXP_WIDTH : Common::DOUBLE_EXP_WIDTH;
|
||||||
const u64 input_frac_mask = single ? Common::FLOAT_FRAC : Common::DOUBLE_FRAC;
|
const u64 input_frac_mask = single ? Common::FLOAT_FRAC : Common::DOUBLE_FRAC;
|
||||||
constexpr u32 output_sign_mask = 0xC;
|
constexpr u32 output_sign_mask = 0xC;
|
||||||
|
|
||||||
// This code is duplicated for the most common cases for performance.
|
|
||||||
// For the less common cases, we branch to an existing copy of this code.
|
|
||||||
auto emit_write_fprf_and_ret = [&] {
|
|
||||||
BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH);
|
|
||||||
STR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
|
|
||||||
RET();
|
|
||||||
};
|
|
||||||
|
|
||||||
// First of all, start the load of the old FPSCR value, in case it takes a while
|
// First of all, start the load of the old FPSCR value, in case it takes a while
|
||||||
LDR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
|
LDR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
|
||||||
|
|
||||||
|
@ -464,33 +459,30 @@ void JitArm64::GenerateFPRF(bool single)
|
||||||
MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PN);
|
MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PN);
|
||||||
MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NN);
|
MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NN);
|
||||||
CMP(input_reg, 0); // Grab sign bit (conveniently the same bit for floats as for integers)
|
CMP(input_reg, 0); // Grab sign bit (conveniently the same bit for floats as for integers)
|
||||||
|
LSL(exp_and_frac_reg, input_reg, 1);
|
||||||
CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT);
|
CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT);
|
||||||
|
CLS(cls_reg, exp_and_frac_reg);
|
||||||
AND(exp_reg, input_reg, LogicalImm(input_exp_mask, input_size)); // Grab exponent
|
FixupBranch not_zero = CBNZ(exp_and_frac_reg);
|
||||||
FixupBranch zero_or_denormal = CBZ(exp_reg);
|
|
||||||
|
|
||||||
// exp != 0
|
|
||||||
MOVI2R(temp_reg, input_exp_mask);
|
|
||||||
CMP(exp_reg, temp_reg);
|
|
||||||
FixupBranch nan_or_inf = B(CCFlags::CC_EQ);
|
|
||||||
|
|
||||||
// exp != 0 && exp != EXP_MASK
|
|
||||||
emit_write_fprf_and_ret();
|
|
||||||
|
|
||||||
// exp == 0
|
|
||||||
SetJumpTarget(zero_or_denormal);
|
|
||||||
TST(input_reg, LogicalImm(input_frac_mask, input_size));
|
|
||||||
FixupBranch denormal = B(CCFlags::CC_NEQ);
|
|
||||||
|
|
||||||
// exp == 0 && frac == 0
|
// exp == 0 && frac == 0
|
||||||
LSR(ARM64Reg::W1, fprf_reg, 3);
|
LSR(ARM64Reg::W1, fprf_reg, 3);
|
||||||
MOVI2R(fprf_reg, Common::PPC_FPCLASS_PZ & ~output_sign_mask);
|
MOVI2R(fprf_reg, Common::PPC_FPCLASS_PZ & ~output_sign_mask);
|
||||||
BFI(fprf_reg, ARM64Reg::W1, 4, 1);
|
BFI(fprf_reg, ARM64Reg::W1, 4, 1);
|
||||||
|
|
||||||
const u8* write_fprf_and_ret = GetCodePtr();
|
const u8* write_fprf_and_ret = GetCodePtr();
|
||||||
emit_write_fprf_and_ret();
|
BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH);
|
||||||
|
STR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
|
||||||
|
RET();
|
||||||
|
|
||||||
|
// exp != 0 || frac != 0
|
||||||
|
SetJumpTarget(not_zero);
|
||||||
|
CMP(cls_reg, input_exp_size - 1);
|
||||||
|
B(CCFlags::CC_LO, write_fprf_and_ret); // Branch if input is normal
|
||||||
|
|
||||||
|
// exp == EXP_MASK || (exp == 0 && frac != 0)
|
||||||
|
FixupBranch nan_or_inf = TBNZ(input_reg, input_size - 2);
|
||||||
|
|
||||||
// exp == 0 && frac != 0
|
// exp == 0 && frac != 0
|
||||||
SetJumpTarget(denormal);
|
|
||||||
ORR(fprf_reg, fprf_reg, LogicalImm(Common::PPC_FPCLASS_PD & ~output_sign_mask, 32));
|
ORR(fprf_reg, fprf_reg, LogicalImm(Common::PPC_FPCLASS_PD & ~output_sign_mask, 32));
|
||||||
B(write_fprf_and_ret);
|
B(write_fprf_and_ret);
|
||||||
|
|
||||||
|
|
|
@ -84,6 +84,8 @@ TEST(JitArm64, FPRF)
|
||||||
const u32 expected_double = RunUpdateFPRF(
|
const u32 expected_double = RunUpdateFPRF(
|
||||||
ppc_state, [&] { ppc_state.UpdateFPRFDouble(Common::BitCast<double>(double_input)); });
|
ppc_state, [&] { ppc_state.UpdateFPRFDouble(Common::BitCast<double>(double_input)); });
|
||||||
const u32 actual_double = RunUpdateFPRF(ppc_state, [&] { test.fprf_double(double_input); });
|
const u32 actual_double = RunUpdateFPRF(ppc_state, [&] { test.fprf_double(double_input); });
|
||||||
|
if (expected_double != actual_double)
|
||||||
|
fmt::print("{:016x} -> {:08x} == {:08x}\n", double_input, actual_double, expected_double);
|
||||||
EXPECT_EQ(expected_double, actual_double);
|
EXPECT_EQ(expected_double, actual_double);
|
||||||
|
|
||||||
const u32 single_input = ConvertToSingle(double_input);
|
const u32 single_input = ConvertToSingle(double_input);
|
||||||
|
@ -91,6 +93,8 @@ TEST(JitArm64, FPRF)
|
||||||
const u32 expected_single = RunUpdateFPRF(
|
const u32 expected_single = RunUpdateFPRF(
|
||||||
ppc_state, [&] { ppc_state.UpdateFPRFSingle(Common::BitCast<float>(single_input)); });
|
ppc_state, [&] { ppc_state.UpdateFPRFSingle(Common::BitCast<float>(single_input)); });
|
||||||
const u32 actual_single = RunUpdateFPRF(ppc_state, [&] { test.fprf_single(single_input); });
|
const u32 actual_single = RunUpdateFPRF(ppc_state, [&] { test.fprf_single(single_input); });
|
||||||
|
if (expected_single != actual_single)
|
||||||
|
fmt::print("{:08x} -> {:08x} == {:08x}\n", single_input, actual_single, expected_single);
|
||||||
EXPECT_EQ(expected_single, actual_single);
|
EXPECT_EQ(expected_single, actual_single);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue