JitArm64: Use LSL+CLS for classifying floats

This is a little trick I came up with that lets us restructure our float
classification code so we can exit earlier when the float is normal,
which is the case more often than not.

First we shift left by 1 to get rid of the sign bit, and then we count
the number of leading sign bits. If the result is less than 10 (for
doubles) or 7 (for floats), the float is normal. This is because, if
the float isn't normal, the exponent is either all zeroes or all ones,
so the leading sign bit count is at least the exponent width minus one.
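As an aside, here is a minimal C++20 sketch of the same test for doubles, emulating the AArch64 CLS instruction with std::countl_zero/std::countl_one. The helper names are illustrative and not part of the Dolphin codebase:

#include <bit>
#include <cstdint>

// Emulates AArch64 CLS: the number of consecutive bits below the MSB that
// are equal to the MSB.
int CountLeadingSignBits(uint64_t x)
{
  return (x >> 63) ? std::countl_one(x) - 1 : std::countl_zero(x) - 1;
}

// Illustrative version of the check for doubles: shift out the sign bit,
// then test whether the leading-sign-bit count stays below
// DOUBLE_EXP_WIDTH - 1 == 10. Normal values (the common case) pass.
bool IsNormalDouble(double d)
{
  const uint64_t shifted = std::bit_cast<uint64_t>(d) << 1;
  return CountLeadingSignBits(shifted) < 10;
}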
JosJuice 2023-10-13 19:27:03 +02:00
parent 5d9838548b
commit 255ee3fdce
3 changed files with 56 additions and 56 deletions

View File

@@ -23,11 +23,15 @@ static constexpr u64 DOUBLE_SIGN = 0x8000000000000000ULL;
 static constexpr u64 DOUBLE_EXP = 0x7FF0000000000000ULL;
 static constexpr u64 DOUBLE_FRAC = 0x000FFFFFFFFFFFFFULL;
 static constexpr u64 DOUBLE_ZERO = 0x0000000000000000ULL;
+static constexpr int DOUBLE_EXP_WIDTH = 11;
+static constexpr int DOUBLE_FRAC_WIDTH = 52;
 
 static constexpr u32 FLOAT_SIGN = 0x80000000;
 static constexpr u32 FLOAT_EXP = 0x7F800000;
 static constexpr u32 FLOAT_FRAC = 0x007FFFFF;
 static constexpr u32 FLOAT_ZERO = 0x00000000;
+static constexpr int FLOAT_EXP_WIDTH = 8;
+static constexpr int FLOAT_FRAC_WIDTH = 23;
 
 inline bool IsQNAN(double d)
 {

View File

@@ -315,25 +315,14 @@ void JitArm64::GenerateFrsqrte()
   // inf, even the mantissa matches. But the mantissa does not match for most other inputs, so in
   // the normal case we calculate the mantissa using the table-based algorithm from the interpreter.
 
+  LSL(ARM64Reg::X2, ARM64Reg::X1, 1);
   m_float_emit.FMOV(ARM64Reg::X0, ARM64Reg::D0);
-  TST(ARM64Reg::X1, LogicalImm(Common::DOUBLE_EXP | Common::DOUBLE_FRAC, 64));
-  FixupBranch zero = B(CCFlags::CC_EQ);
-
-  AND(ARM64Reg::X2, ARM64Reg::X1, LogicalImm(Common::DOUBLE_EXP, 64));
-  MOVI2R(ARM64Reg::X3, Common::DOUBLE_EXP);
-  CMP(ARM64Reg::X2, ARM64Reg::X3);
-  FixupBranch nan_or_inf = B(CCFlags::CC_EQ);
-  FixupBranch negative = TBNZ(ARM64Reg::X1, 63);
-  FixupBranch normal = CBNZ(ARM64Reg::X2);
-
-  // "Normalize" denormal values.
-  // The simplified calculation used here results in the upper 11 bits being incorrect,
-  // but that's fine, because the code below never reads those bits.
-  CLZ(ARM64Reg::X3, ARM64Reg::X1);
-  LSLV(ARM64Reg::X1, ARM64Reg::X1, ARM64Reg::X3);
-  LSR(ARM64Reg::X1, ARM64Reg::X1, 11);
-  BFI(ARM64Reg::X1, ARM64Reg::X3, 52, 12);
+  CLS(ARM64Reg::X3, ARM64Reg::X2);
+  TST(ARM64Reg::X1, LogicalImm(Common::DOUBLE_SIGN, 64));
+  CCMP(ARM64Reg::X3, Common::DOUBLE_EXP_WIDTH - 1, 0b0010, CCFlags::CC_EQ);
+  FixupBranch not_positive_normal = B(CCFlags::CC_HS);
 
-  SetJumpTarget(normal);
+  const u8* positive_normal = GetCodePtr();
   UBFX(ARM64Reg::X2, ARM64Reg::X1, 48, 5);
   MOVP2R(ARM64Reg::X3, &Common::frsqrte_expected);
   ADD(ARM64Reg::X2, ARM64Reg::X3, ARM64Reg::X2, ArithOption(ARM64Reg::X2, ShiftType::LSL, 3));
@@ -344,27 +333,41 @@ void JitArm64::GenerateFrsqrte()
   ORR(ARM64Reg::X0, ARM64Reg::X0, ARM64Reg::X1, ArithOption(ARM64Reg::X1, ShiftType::LSL, 26));
   RET();
 
-  SetJumpTarget(zero);
+  SetJumpTarget(not_positive_normal);
   LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
+  FixupBranch not_positive_normal_not_zero = CBNZ(ARM64Reg::X2);
+
+  // Zero
   FixupBranch skip_set_zx = TBNZ(ARM64Reg::W3, 26);
   ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_ZX, ARM64Reg::W2);
+  const u8* store_fpscr = GetCodePtr();
   STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
   SetJumpTarget(skip_set_zx);
+  const u8* done = GetCodePtr();
   RET();
 
+  SetJumpTarget(not_positive_normal_not_zero);
+  FixupBranch nan_or_inf = TBNZ(ARM64Reg::X1, 62);
+  FixupBranch negative = TBNZ(ARM64Reg::X1, 63);
+
+  // "Normalize" denormal values.
+  // The simplified calculation used here results in the upper 11 bits being incorrect,
+  // but that's fine, because the code we jump to never reads those bits.
+  CLZ(ARM64Reg::X3, ARM64Reg::X1);
+  LSLV(ARM64Reg::X1, ARM64Reg::X1, ARM64Reg::X3);
+  LSR(ARM64Reg::X1, ARM64Reg::X1, 11);
+  BFI(ARM64Reg::X1, ARM64Reg::X3, 52, 12);
+  B(positive_normal);
+
   SetJumpTarget(nan_or_inf);
-  MOVI2R(ARM64Reg::X3, Common::BitCast<u64>(-std::numeric_limits<double>::infinity()));
-  CMP(ARM64Reg::X1, ARM64Reg::X3);
-  FixupBranch nan_or_positive_inf = B(CCFlags::CC_NEQ);
+  MOVI2R(ARM64Reg::X2, Common::BitCast<u64>(-std::numeric_limits<double>::infinity()));
+  CMP(ARM64Reg::X1, ARM64Reg::X2);
+  B(CCFlags::CC_NEQ, done);
 
   SetJumpTarget(negative);
-  LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
-  FixupBranch skip_set_vxsqrt = TBNZ(ARM64Reg::W3, 9);
+  TBNZ(ARM64Reg::W3, 9, done);
   ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_VXSQRT, ARM64Reg::W2);
-  STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
-  SetJumpTarget(skip_set_vxsqrt);
-
-  SetJumpTarget(nan_or_positive_inf);
-  RET();
+  B(store_fpscr);
 }
 
 // Input in X0, output in W1, clobbers X0-X3 and flags.
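As an illustration of the TST/CCMP/B.HS sequence in the hunk above, here is a rough scalar model of the condition it evaluates, written in C++20; the function name and structure are illustrative, not code from the repository:

#include <bit>
#include <cstdint>

// Models the fast-path test emitted above: B.HS (branch if carry set) leaves
// the positive-normal path when the sign bit is set OR the leading-sign-bit
// count of the shifted value is >= DOUBLE_EXP_WIDTH - 1.
bool LeavesPositiveNormalFastPath(uint64_t bits)
{
  const uint64_t shifted = bits << 1;  // LSL X2, X1, #1
  const int cls = (shifted >> 63) ? std::countl_one(shifted) - 1
                                  : std::countl_zero(shifted) - 1;  // CLS X3, X2

  const bool sign_clear = (bits >> 63) == 0;  // TST X1, DOUBLE_SIGN sets EQ when clear
  // CCMP with condition EQ: if the sign was clear, compare cls against 10,
  // setting carry when cls >= 10; otherwise NZCV is forced to 0b0010, which
  // also sets carry. B.HS then tests that carry.
  return sign_clear ? (cls >= 10) : true;
}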
@@ -438,25 +441,17 @@ void JitArm64::GenerateFPRF(bool single)
   const auto reg_encoder = single ? EncodeRegTo32 : EncodeRegTo64;
 
   const ARM64Reg input_reg = reg_encoder(ARM64Reg::W0);
-  const ARM64Reg temp_reg = reg_encoder(ARM64Reg::W1);
-  const ARM64Reg exp_reg = reg_encoder(ARM64Reg::W2);
+  const ARM64Reg cls_reg = reg_encoder(ARM64Reg::W1);
+  const ARM64Reg exp_and_frac_reg = reg_encoder(ARM64Reg::W2);
   constexpr ARM64Reg fprf_reg = ARM64Reg::W3;
   constexpr ARM64Reg fpscr_reg = ARM64Reg::W4;
 
   const int input_size = single ? 32 : 64;
-  const u64 input_exp_mask = single ? Common::FLOAT_EXP : Common::DOUBLE_EXP;
+  const int input_exp_size = single ? Common::FLOAT_EXP_WIDTH : Common::DOUBLE_EXP_WIDTH;
   const u64 input_frac_mask = single ? Common::FLOAT_FRAC : Common::DOUBLE_FRAC;
   constexpr u32 output_sign_mask = 0xC;
 
-  // This code is duplicated for the most common cases for performance.
-  // For the less common cases, we branch to an existing copy of this code.
-  auto emit_write_fprf_and_ret = [&] {
-    BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH);
-    STR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
-    RET();
-  };
-
   // First of all, start the load of the old FPSCR value, in case it takes a while
   LDR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
@@ -464,33 +459,30 @@ void JitArm64::GenerateFPRF(bool single)
   MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PN);
   MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NN);
   CMP(input_reg, 0);  // Grab sign bit (conveniently the same bit for floats as for integers)
+  LSL(exp_and_frac_reg, input_reg, 1);
   CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT);
+  CLS(cls_reg, exp_and_frac_reg);
 
-  AND(exp_reg, input_reg, LogicalImm(input_exp_mask, input_size));  // Grab exponent
-  FixupBranch zero_or_denormal = CBZ(exp_reg);
-
-  // exp != 0
-  MOVI2R(temp_reg, input_exp_mask);
-  CMP(exp_reg, temp_reg);
-  FixupBranch nan_or_inf = B(CCFlags::CC_EQ);
-
-  // exp != 0 && exp != EXP_MASK
-  emit_write_fprf_and_ret();
-
-  // exp == 0
-  SetJumpTarget(zero_or_denormal);
-  TST(input_reg, LogicalImm(input_frac_mask, input_size));
-  FixupBranch denormal = B(CCFlags::CC_NEQ);
+  FixupBranch not_zero = CBNZ(exp_and_frac_reg);
 
   // exp == 0 && frac == 0
   LSR(ARM64Reg::W1, fprf_reg, 3);
   MOVI2R(fprf_reg, Common::PPC_FPCLASS_PZ & ~output_sign_mask);
   BFI(fprf_reg, ARM64Reg::W1, 4, 1);
 
   const u8* write_fprf_and_ret = GetCodePtr();
-  emit_write_fprf_and_ret();
+  BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH);
+  STR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
+  RET();
+
+  // exp != 0 || frac != 0
+  SetJumpTarget(not_zero);
+  CMP(cls_reg, input_exp_size - 1);
+  B(CCFlags::CC_LO, write_fprf_and_ret);  // Branch if input is normal
+
+  // exp == EXP_MASK || (exp == 0 && frac != 0)
+  FixupBranch nan_or_inf = TBNZ(input_reg, input_size - 2);
 
   // exp == 0 && frac != 0
-  SetJumpTarget(denormal);
   ORR(fprf_reg, fprf_reg, LogicalImm(Common::PPC_FPCLASS_PD & ~output_sign_mask, 32));
   B(write_fprf_and_ret);
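To summarize the restructured GenerateFPRF flow, here is a plain C++20 restatement of the decision tree the emitted code follows, for the double case only; the enum and function names are illustrative and not part of the codebase:

#include <bit>
#include <cstdint>

enum class FloatClass { Zero, Normal, Denormal, NaNOrInf };

// Mirrors the branch order of the code emitted above: zero first, then the
// CLS-based normal test, then bit 62 (the exponent MSB) to split NaN/inf
// from denormals, since by that point the exponent is all zeroes or all ones.
FloatClass ClassifyDouble(double d)
{
  const uint64_t bits = std::bit_cast<uint64_t>(d);
  const uint64_t exp_and_frac = bits << 1;  // LSL: drop the sign bit
  if (exp_and_frac == 0)                    // CBNZ not taken
    return FloatClass::Zero;

  const int cls = (exp_and_frac >> 63) ? std::countl_one(exp_and_frac) - 1
                                       : std::countl_zero(exp_and_frac) - 1;  // CLS
  if (cls < 10)                             // CMP cls, DOUBLE_EXP_WIDTH - 1; B.LO
    return FloatClass::Normal;

  if (bits & (uint64_t{1} << 62))           // TBNZ on bit input_size - 2
    return FloatClass::NaNOrInf;

  return FloatClass::Denormal;
}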

View File

@@ -84,6 +84,8 @@ TEST(JitArm64, FPRF)
     const u32 expected_double = RunUpdateFPRF(
         ppc_state, [&] { ppc_state.UpdateFPRFDouble(Common::BitCast<double>(double_input)); });
     const u32 actual_double = RunUpdateFPRF(ppc_state, [&] { test.fprf_double(double_input); });
+    if (expected_double != actual_double)
+      fmt::print("{:016x} -> {:08x} == {:08x}\n", double_input, actual_double, expected_double);
     EXPECT_EQ(expected_double, actual_double);
 
     const u32 single_input = ConvertToSingle(double_input);
@@ -91,6 +93,8 @@ TEST(JitArm64, FPRF)
     const u32 expected_single = RunUpdateFPRF(
         ppc_state, [&] { ppc_state.UpdateFPRFSingle(Common::BitCast<float>(single_input)); });
     const u32 actual_single = RunUpdateFPRF(ppc_state, [&] { test.fprf_single(single_input); });
+    if (expected_single != actual_single)
+      fmt::print("{:08x} -> {:08x} == {:08x}\n", single_input, actual_single, expected_single);
     EXPECT_EQ(expected_single, actual_single);
   }
 }