From 255ee3fdce7bcdd230357d2bc8f3da143e510494 Mon Sep 17 00:00:00 2001 From: JosJuice Date: Fri, 13 Oct 2023 19:27:03 +0200 Subject: [PATCH 1/2] JitArm64: Use LSL+CLS for classifying floats This is a little trick I came up with that lets us restructure our float classification code so we can exit earlier when the float is normal, which is the case more often than not. First we shift left by 1 to get rid of the sign bit, and then we count the number of leading sign bits. If the result is less than 10 (for doubles) or 7 (for floats), the float is normal. This is because, if the float isn't normal, the exponent is either all zeroes or all ones. --- Source/Core/Common/FloatUtils.h | 4 + Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp | 104 ++++++++---------- .../UnitTests/Core/PowerPC/JitArm64/FPRF.cpp | 4 + 3 files changed, 56 insertions(+), 56 deletions(-) diff --git a/Source/Core/Common/FloatUtils.h b/Source/Core/Common/FloatUtils.h index b1c16d172c..82942c4b19 100644 --- a/Source/Core/Common/FloatUtils.h +++ b/Source/Core/Common/FloatUtils.h @@ -23,11 +23,15 @@ static constexpr u64 DOUBLE_SIGN = 0x8000000000000000ULL; static constexpr u64 DOUBLE_EXP = 0x7FF0000000000000ULL; static constexpr u64 DOUBLE_FRAC = 0x000FFFFFFFFFFFFFULL; static constexpr u64 DOUBLE_ZERO = 0x0000000000000000ULL; +static constexpr int DOUBLE_EXP_WIDTH = 11; +static constexpr int DOUBLE_FRAC_WIDTH = 52; static constexpr u32 FLOAT_SIGN = 0x80000000; static constexpr u32 FLOAT_EXP = 0x7F800000; static constexpr u32 FLOAT_FRAC = 0x007FFFFF; static constexpr u32 FLOAT_ZERO = 0x00000000; +static constexpr int FLOAT_EXP_WIDTH = 8; +static constexpr int FLOAT_FRAC_WIDTH = 23; inline bool IsQNAN(double d) { diff --git a/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp b/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp index 5c5033d1e3..f72ef3bb1d 100644 --- a/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp +++ b/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp @@ -315,25 +315,14 @@ void JitArm64::GenerateFrsqrte() // inf, even the mantissa matches. But the mantissa does not match for most other inputs, so in // the normal case we calculate the mantissa using the table-based algorithm from the interpreter. + LSL(ARM64Reg::X2, ARM64Reg::X1, 1); m_float_emit.FMOV(ARM64Reg::X0, ARM64Reg::D0); - TST(ARM64Reg::X1, LogicalImm(Common::DOUBLE_EXP | Common::DOUBLE_FRAC, 64)); - FixupBranch zero = B(CCFlags::CC_EQ); - AND(ARM64Reg::X2, ARM64Reg::X1, LogicalImm(Common::DOUBLE_EXP, 64)); - MOVI2R(ARM64Reg::X3, Common::DOUBLE_EXP); - CMP(ARM64Reg::X2, ARM64Reg::X3); - FixupBranch nan_or_inf = B(CCFlags::CC_EQ); - FixupBranch negative = TBNZ(ARM64Reg::X1, 63); - FixupBranch normal = CBNZ(ARM64Reg::X2); + CLS(ARM64Reg::X3, ARM64Reg::X2); + TST(ARM64Reg::X1, LogicalImm(Common::DOUBLE_SIGN, 64)); + CCMP(ARM64Reg::X3, Common::DOUBLE_EXP_WIDTH - 1, 0b0010, CCFlags::CC_EQ); + FixupBranch not_positive_normal = B(CCFlags::CC_HS); - // "Normalize" denormal values. - // The simplified calculation used here results in the upper 11 bits being incorrect, - // but that's fine, because the code below never reads those bits. - CLZ(ARM64Reg::X3, ARM64Reg::X1); - LSLV(ARM64Reg::X1, ARM64Reg::X1, ARM64Reg::X3); - LSR(ARM64Reg::X1, ARM64Reg::X1, 11); - BFI(ARM64Reg::X1, ARM64Reg::X3, 52, 12); - - SetJumpTarget(normal); + const u8* positive_normal = GetCodePtr(); UBFX(ARM64Reg::X2, ARM64Reg::X1, 48, 5); MOVP2R(ARM64Reg::X3, &Common::frsqrte_expected); ADD(ARM64Reg::X2, ARM64Reg::X3, ARM64Reg::X2, ArithOption(ARM64Reg::X2, ShiftType::LSL, 3)); @@ -344,27 +333,41 @@ void JitArm64::GenerateFrsqrte() ORR(ARM64Reg::X0, ARM64Reg::X0, ARM64Reg::X1, ArithOption(ARM64Reg::X1, ShiftType::LSL, 26)); RET(); - SetJumpTarget(zero); + SetJumpTarget(not_positive_normal); LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr)); + FixupBranch not_positive_normal_not_zero = CBNZ(ARM64Reg::X2); + + // Zero FixupBranch skip_set_zx = TBNZ(ARM64Reg::W3, 26); ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_ZX, ARM64Reg::W2); + const u8* store_fpscr = GetCodePtr(); STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr)); SetJumpTarget(skip_set_zx); + const u8* done = GetCodePtr(); RET(); + SetJumpTarget(not_positive_normal_not_zero); + FixupBranch nan_or_inf = TBNZ(ARM64Reg::X1, 62); + FixupBranch negative = TBNZ(ARM64Reg::X1, 63); + + // "Normalize" denormal values. + // The simplified calculation used here results in the upper 11 bits being incorrect, + // but that's fine, because the code we jump to never reads those bits. + CLZ(ARM64Reg::X3, ARM64Reg::X1); + LSLV(ARM64Reg::X1, ARM64Reg::X1, ARM64Reg::X3); + LSR(ARM64Reg::X1, ARM64Reg::X1, 11); + BFI(ARM64Reg::X1, ARM64Reg::X3, 52, 12); + B(positive_normal); + SetJumpTarget(nan_or_inf); - MOVI2R(ARM64Reg::X3, Common::BitCast(-std::numeric_limits::infinity())); - CMP(ARM64Reg::X1, ARM64Reg::X3); - FixupBranch nan_or_positive_inf = B(CCFlags::CC_NEQ); + MOVI2R(ARM64Reg::X2, Common::BitCast(-std::numeric_limits::infinity())); + CMP(ARM64Reg::X1, ARM64Reg::X2); + B(CCFlags::CC_NEQ, done); SetJumpTarget(negative); - LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr)); - FixupBranch skip_set_vxsqrt = TBNZ(ARM64Reg::W3, 9); + TBNZ(ARM64Reg::W3, 9, done); ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_VXSQRT, ARM64Reg::W2); - STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr)); - SetJumpTarget(skip_set_vxsqrt); - SetJumpTarget(nan_or_positive_inf); - RET(); + B(store_fpscr); } // Input in X0, output in W1, clobbers X0-X3 and flags. @@ -438,25 +441,17 @@ void JitArm64::GenerateFPRF(bool single) const auto reg_encoder = single ? EncodeRegTo32 : EncodeRegTo64; const ARM64Reg input_reg = reg_encoder(ARM64Reg::W0); - const ARM64Reg temp_reg = reg_encoder(ARM64Reg::W1); - const ARM64Reg exp_reg = reg_encoder(ARM64Reg::W2); + const ARM64Reg cls_reg = reg_encoder(ARM64Reg::W1); + const ARM64Reg exp_and_frac_reg = reg_encoder(ARM64Reg::W2); constexpr ARM64Reg fprf_reg = ARM64Reg::W3; constexpr ARM64Reg fpscr_reg = ARM64Reg::W4; const int input_size = single ? 32 : 64; - const u64 input_exp_mask = single ? Common::FLOAT_EXP : Common::DOUBLE_EXP; + const int input_exp_size = single ? Common::FLOAT_EXP_WIDTH : Common::DOUBLE_EXP_WIDTH; const u64 input_frac_mask = single ? Common::FLOAT_FRAC : Common::DOUBLE_FRAC; constexpr u32 output_sign_mask = 0xC; - // This code is duplicated for the most common cases for performance. - // For the less common cases, we branch to an existing copy of this code. - auto emit_write_fprf_and_ret = [&] { - BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH); - STR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr)); - RET(); - }; - // First of all, start the load of the old FPSCR value, in case it takes a while LDR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr)); @@ -464,33 +459,30 @@ void JitArm64::GenerateFPRF(bool single) MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PN); MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NN); CMP(input_reg, 0); // Grab sign bit (conveniently the same bit for floats as for integers) + LSL(exp_and_frac_reg, input_reg, 1); CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT); - - AND(exp_reg, input_reg, LogicalImm(input_exp_mask, input_size)); // Grab exponent - FixupBranch zero_or_denormal = CBZ(exp_reg); - - // exp != 0 - MOVI2R(temp_reg, input_exp_mask); - CMP(exp_reg, temp_reg); - FixupBranch nan_or_inf = B(CCFlags::CC_EQ); - - // exp != 0 && exp != EXP_MASK - emit_write_fprf_and_ret(); - - // exp == 0 - SetJumpTarget(zero_or_denormal); - TST(input_reg, LogicalImm(input_frac_mask, input_size)); - FixupBranch denormal = B(CCFlags::CC_NEQ); + CLS(cls_reg, exp_and_frac_reg); + FixupBranch not_zero = CBNZ(exp_and_frac_reg); // exp == 0 && frac == 0 LSR(ARM64Reg::W1, fprf_reg, 3); MOVI2R(fprf_reg, Common::PPC_FPCLASS_PZ & ~output_sign_mask); BFI(fprf_reg, ARM64Reg::W1, 4, 1); + const u8* write_fprf_and_ret = GetCodePtr(); - emit_write_fprf_and_ret(); + BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH); + STR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr)); + RET(); + + // exp != 0 || frac != 0 + SetJumpTarget(not_zero); + CMP(cls_reg, input_exp_size - 1); + B(CCFlags::CC_LO, write_fprf_and_ret); // Branch if input is normal + + // exp == EXP_MASK || (exp == 0 && frac != 0) + FixupBranch nan_or_inf = TBNZ(input_reg, input_size - 2); // exp == 0 && frac != 0 - SetJumpTarget(denormal); ORR(fprf_reg, fprf_reg, LogicalImm(Common::PPC_FPCLASS_PD & ~output_sign_mask, 32)); B(write_fprf_and_ret); diff --git a/Source/UnitTests/Core/PowerPC/JitArm64/FPRF.cpp b/Source/UnitTests/Core/PowerPC/JitArm64/FPRF.cpp index 0b7e250ae6..cb19980949 100644 --- a/Source/UnitTests/Core/PowerPC/JitArm64/FPRF.cpp +++ b/Source/UnitTests/Core/PowerPC/JitArm64/FPRF.cpp @@ -84,6 +84,8 @@ TEST(JitArm64, FPRF) const u32 expected_double = RunUpdateFPRF( ppc_state, [&] { ppc_state.UpdateFPRFDouble(Common::BitCast(double_input)); }); const u32 actual_double = RunUpdateFPRF(ppc_state, [&] { test.fprf_double(double_input); }); + if (expected_double != actual_double) + fmt::print("{:016x} -> {:08x} == {:08x}\n", double_input, actual_double, expected_double); EXPECT_EQ(expected_double, actual_double); const u32 single_input = ConvertToSingle(double_input); @@ -91,6 +93,8 @@ TEST(JitArm64, FPRF) const u32 expected_single = RunUpdateFPRF( ppc_state, [&] { ppc_state.UpdateFPRFSingle(Common::BitCast(single_input)); }); const u32 actual_single = RunUpdateFPRF(ppc_state, [&] { test.fprf_single(single_input); }); + if (expected_single != actual_single) + fmt::print("{:08x} -> {:08x} == {:08x}\n", single_input, actual_single, expected_single); EXPECT_EQ(expected_single, actual_single); } } From d5ec5c005a028a80f78eb1726739a5d4b6c6e9c0 Mon Sep 17 00:00:00 2001 From: JosJuice Date: Fri, 13 Oct 2023 20:17:33 +0200 Subject: [PATCH 2/2] JitArm64: Some more FPRF optimization By using MOVI2R+MOVI2R+CSEL in the zero case instead of doing bitwise operations on the output of the other MOVI2R+MOVI2R+CSEL, we avoid using BFI, an instruction that takes two cycles on most CPUs. The instruction count is the same and the pipelining should be at least equally good. --- Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp | 26 +++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp b/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp index f72ef3bb1d..a10f4ae5d0 100644 --- a/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp +++ b/Source/Core/Core/PowerPC/JitArm64/JitAsm.cpp @@ -441,8 +441,7 @@ void JitArm64::GenerateFPRF(bool single) const auto reg_encoder = single ? EncodeRegTo32 : EncodeRegTo64; const ARM64Reg input_reg = reg_encoder(ARM64Reg::W0); - const ARM64Reg cls_reg = reg_encoder(ARM64Reg::W1); - const ARM64Reg exp_and_frac_reg = reg_encoder(ARM64Reg::W2); + const ARM64Reg cls_reg = reg_encoder(ARM64Reg::W2); constexpr ARM64Reg fprf_reg = ARM64Reg::W3; constexpr ARM64Reg fpscr_reg = ARM64Reg::W4; @@ -455,19 +454,14 @@ void JitArm64::GenerateFPRF(bool single) // First of all, start the load of the old FPSCR value, in case it takes a while LDR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr)); - // Most branches handle the sign in the same way. Perform that handling before branching - MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PN); - MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NN); CMP(input_reg, 0); // Grab sign bit (conveniently the same bit for floats as for integers) - LSL(exp_and_frac_reg, input_reg, 1); - CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT); - CLS(cls_reg, exp_and_frac_reg); - FixupBranch not_zero = CBNZ(exp_and_frac_reg); + LSL(cls_reg, input_reg, 1); + FixupBranch not_zero = CBNZ(cls_reg); // exp == 0 && frac == 0 - LSR(ARM64Reg::W1, fprf_reg, 3); - MOVI2R(fprf_reg, Common::PPC_FPCLASS_PZ & ~output_sign_mask); - BFI(fprf_reg, ARM64Reg::W1, 4, 1); + MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PZ); + MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NZ); + CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT); const u8* write_fprf_and_ret = GetCodePtr(); BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH); @@ -476,6 +470,14 @@ void JitArm64::GenerateFPRF(bool single) // exp != 0 || frac != 0 SetJumpTarget(not_zero); + CLS(cls_reg, cls_reg); + + // All branches except the zero branch handle the sign in the same way. + // Perform that handling before branching further + MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PN); + MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NN); + CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT); + CMP(cls_reg, input_exp_size - 1); B(CCFlags::CC_LO, write_fprf_and_ret); // Branch if input is normal