Merge pull request #12235 from JosJuice/jitarm64-float-cls

JitArm64: Use LSL+CLS for classifying floats
This commit is contained in:
Mai 2023-11-28 19:20:01 +01:00 committed by GitHub
commit bfc6bca583
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 63 deletions

View File

@ -23,11 +23,15 @@ static constexpr u64 DOUBLE_SIGN = 0x8000000000000000ULL;
static constexpr u64 DOUBLE_EXP = 0x7FF0000000000000ULL;
static constexpr u64 DOUBLE_FRAC = 0x000FFFFFFFFFFFFFULL;
static constexpr u64 DOUBLE_ZERO = 0x0000000000000000ULL;
static constexpr int DOUBLE_EXP_WIDTH = 11;
static constexpr int DOUBLE_FRAC_WIDTH = 52;
static constexpr u32 FLOAT_SIGN = 0x80000000;
static constexpr u32 FLOAT_EXP = 0x7F800000;
static constexpr u32 FLOAT_FRAC = 0x007FFFFF;
static constexpr u32 FLOAT_ZERO = 0x00000000;
static constexpr int FLOAT_EXP_WIDTH = 8;
static constexpr int FLOAT_FRAC_WIDTH = 23;
inline bool IsQNAN(double d)
{

View File

@ -315,25 +315,14 @@ void JitArm64::GenerateFrsqrte()
// inf, even the mantissa matches. But the mantissa does not match for most other inputs, so in
// the normal case we calculate the mantissa using the table-based algorithm from the interpreter.
LSL(ARM64Reg::X2, ARM64Reg::X1, 1);
m_float_emit.FMOV(ARM64Reg::X0, ARM64Reg::D0);
TST(ARM64Reg::X1, LogicalImm(Common::DOUBLE_EXP | Common::DOUBLE_FRAC, 64));
FixupBranch zero = B(CCFlags::CC_EQ);
AND(ARM64Reg::X2, ARM64Reg::X1, LogicalImm(Common::DOUBLE_EXP, 64));
MOVI2R(ARM64Reg::X3, Common::DOUBLE_EXP);
CMP(ARM64Reg::X2, ARM64Reg::X3);
FixupBranch nan_or_inf = B(CCFlags::CC_EQ);
FixupBranch negative = TBNZ(ARM64Reg::X1, 63);
FixupBranch normal = CBNZ(ARM64Reg::X2);
CLS(ARM64Reg::X3, ARM64Reg::X2);
TST(ARM64Reg::X1, LogicalImm(Common::DOUBLE_SIGN, 64));
CCMP(ARM64Reg::X3, Common::DOUBLE_EXP_WIDTH - 1, 0b0010, CCFlags::CC_EQ);
FixupBranch not_positive_normal = B(CCFlags::CC_HS);
// "Normalize" denormal values.
// The simplified calculation used here results in the upper 11 bits being incorrect,
// but that's fine, because the code below never reads those bits.
CLZ(ARM64Reg::X3, ARM64Reg::X1);
LSLV(ARM64Reg::X1, ARM64Reg::X1, ARM64Reg::X3);
LSR(ARM64Reg::X1, ARM64Reg::X1, 11);
BFI(ARM64Reg::X1, ARM64Reg::X3, 52, 12);
SetJumpTarget(normal);
const u8* positive_normal = GetCodePtr();
UBFX(ARM64Reg::X2, ARM64Reg::X1, 48, 5);
MOVP2R(ARM64Reg::X3, &Common::frsqrte_expected);
ADD(ARM64Reg::X2, ARM64Reg::X3, ARM64Reg::X2, ArithOption(ARM64Reg::X2, ShiftType::LSL, 3));
@ -344,27 +333,41 @@ void JitArm64::GenerateFrsqrte()
ORR(ARM64Reg::X0, ARM64Reg::X0, ARM64Reg::X1, ArithOption(ARM64Reg::X1, ShiftType::LSL, 26));
RET();
SetJumpTarget(zero);
SetJumpTarget(not_positive_normal);
LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
FixupBranch not_positive_normal_not_zero = CBNZ(ARM64Reg::X2);
// Zero
FixupBranch skip_set_zx = TBNZ(ARM64Reg::W3, 26);
ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_ZX, ARM64Reg::W2);
const u8* store_fpscr = GetCodePtr();
STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
SetJumpTarget(skip_set_zx);
const u8* done = GetCodePtr();
RET();
SetJumpTarget(not_positive_normal_not_zero);
FixupBranch nan_or_inf = TBNZ(ARM64Reg::X1, 62);
FixupBranch negative = TBNZ(ARM64Reg::X1, 63);
// "Normalize" denormal values.
// The simplified calculation used here results in the upper 11 bits being incorrect,
// but that's fine, because the code we jump to never reads those bits.
CLZ(ARM64Reg::X3, ARM64Reg::X1);
LSLV(ARM64Reg::X1, ARM64Reg::X1, ARM64Reg::X3);
LSR(ARM64Reg::X1, ARM64Reg::X1, 11);
BFI(ARM64Reg::X1, ARM64Reg::X3, 52, 12);
B(positive_normal);
SetJumpTarget(nan_or_inf);
MOVI2R(ARM64Reg::X3, Common::BitCast<u64>(-std::numeric_limits<double>::infinity()));
CMP(ARM64Reg::X1, ARM64Reg::X3);
FixupBranch nan_or_positive_inf = B(CCFlags::CC_NEQ);
MOVI2R(ARM64Reg::X2, Common::BitCast<u64>(-std::numeric_limits<double>::infinity()));
CMP(ARM64Reg::X1, ARM64Reg::X2);
B(CCFlags::CC_NEQ, done);
SetJumpTarget(negative);
LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
FixupBranch skip_set_vxsqrt = TBNZ(ARM64Reg::W3, 9);
TBNZ(ARM64Reg::W3, 9, done);
ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_VXSQRT, ARM64Reg::W2);
STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
SetJumpTarget(skip_set_vxsqrt);
SetJumpTarget(nan_or_positive_inf);
RET();
B(store_fpscr);
}
// Input in X0, output in W1, clobbers X0-X3 and flags.
@ -438,59 +441,50 @@ void JitArm64::GenerateFPRF(bool single)
const auto reg_encoder = single ? EncodeRegTo32 : EncodeRegTo64;
const ARM64Reg input_reg = reg_encoder(ARM64Reg::W0);
const ARM64Reg temp_reg = reg_encoder(ARM64Reg::W1);
const ARM64Reg exp_reg = reg_encoder(ARM64Reg::W2);
const ARM64Reg cls_reg = reg_encoder(ARM64Reg::W2);
constexpr ARM64Reg fprf_reg = ARM64Reg::W3;
constexpr ARM64Reg fpscr_reg = ARM64Reg::W4;
const int input_size = single ? 32 : 64;
const u64 input_exp_mask = single ? Common::FLOAT_EXP : Common::DOUBLE_EXP;
const int input_exp_size = single ? Common::FLOAT_EXP_WIDTH : Common::DOUBLE_EXP_WIDTH;
const u64 input_frac_mask = single ? Common::FLOAT_FRAC : Common::DOUBLE_FRAC;
constexpr u32 output_sign_mask = 0xC;
// This code is duplicated for the most common cases for performance.
// For the less common cases, we branch to an existing copy of this code.
auto emit_write_fprf_and_ret = [&] {
BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH);
STR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
RET();
};
// First of all, start the load of the old FPSCR value, in case it takes a while
LDR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
// Most branches handle the sign in the same way. Perform that handling before branching
MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PN);
MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NN);
CMP(input_reg, 0); // Grab sign bit (conveniently the same bit for floats as for integers)
CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT);
AND(exp_reg, input_reg, LogicalImm(input_exp_mask, input_size)); // Grab exponent
FixupBranch zero_or_denormal = CBZ(exp_reg);
// exp != 0
MOVI2R(temp_reg, input_exp_mask);
CMP(exp_reg, temp_reg);
FixupBranch nan_or_inf = B(CCFlags::CC_EQ);
// exp != 0 && exp != EXP_MASK
emit_write_fprf_and_ret();
// exp == 0
SetJumpTarget(zero_or_denormal);
TST(input_reg, LogicalImm(input_frac_mask, input_size));
FixupBranch denormal = B(CCFlags::CC_NEQ);
LSL(cls_reg, input_reg, 1);
FixupBranch not_zero = CBNZ(cls_reg);
// exp == 0 && frac == 0
LSR(ARM64Reg::W1, fprf_reg, 3);
MOVI2R(fprf_reg, Common::PPC_FPCLASS_PZ & ~output_sign_mask);
BFI(fprf_reg, ARM64Reg::W1, 4, 1);
MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PZ);
MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NZ);
CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT);
const u8* write_fprf_and_ret = GetCodePtr();
emit_write_fprf_and_ret();
BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH);
STR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
RET();
// exp != 0 || frac != 0
SetJumpTarget(not_zero);
CLS(cls_reg, cls_reg);
// All branches except the zero branch handle the sign in the same way.
// Perform that handling before branching further
MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PN);
MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NN);
CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT);
CMP(cls_reg, input_exp_size - 1);
B(CCFlags::CC_LO, write_fprf_and_ret); // Branch if input is normal
// exp == EXP_MASK || (exp == 0 && frac != 0)
FixupBranch nan_or_inf = TBNZ(input_reg, input_size - 2);
// exp == 0 && frac != 0
SetJumpTarget(denormal);
ORR(fprf_reg, fprf_reg, LogicalImm(Common::PPC_FPCLASS_PD & ~output_sign_mask, 32));
B(write_fprf_and_ret);

View File

@ -84,6 +84,8 @@ TEST(JitArm64, FPRF)
const u32 expected_double = RunUpdateFPRF(
ppc_state, [&] { ppc_state.UpdateFPRFDouble(Common::BitCast<double>(double_input)); });
const u32 actual_double = RunUpdateFPRF(ppc_state, [&] { test.fprf_double(double_input); });
if (expected_double != actual_double)
fmt::print("{:016x} -> {:08x} == {:08x}\n", double_input, actual_double, expected_double);
EXPECT_EQ(expected_double, actual_double);
const u32 single_input = ConvertToSingle(double_input);
@ -91,6 +93,8 @@ TEST(JitArm64, FPRF)
const u32 expected_single = RunUpdateFPRF(
ppc_state, [&] { ppc_state.UpdateFPRFSingle(Common::BitCast<float>(single_input)); });
const u32 actual_single = RunUpdateFPRF(ppc_state, [&] { test.fprf_single(single_input); });
if (expected_single != actual_single)
fmt::print("{:08x} -> {:08x} == {:08x}\n", single_input, actual_single, expected_single);
EXPECT_EQ(expected_single, actual_single);
}
}