Merge pull request #12235 from JosJuice/jitarm64-float-cls
JitArm64: Use LSL+CLS for classifying floats
This commit is contained in:
commit
bfc6bca583
|
@ -23,11 +23,15 @@ static constexpr u64 DOUBLE_SIGN = 0x8000000000000000ULL;
|
||||||
static constexpr u64 DOUBLE_EXP = 0x7FF0000000000000ULL;
|
static constexpr u64 DOUBLE_EXP = 0x7FF0000000000000ULL;
|
||||||
static constexpr u64 DOUBLE_FRAC = 0x000FFFFFFFFFFFFFULL;
|
static constexpr u64 DOUBLE_FRAC = 0x000FFFFFFFFFFFFFULL;
|
||||||
static constexpr u64 DOUBLE_ZERO = 0x0000000000000000ULL;
|
static constexpr u64 DOUBLE_ZERO = 0x0000000000000000ULL;
|
||||||
|
static constexpr int DOUBLE_EXP_WIDTH = 11;
|
||||||
|
static constexpr int DOUBLE_FRAC_WIDTH = 52;
|
||||||
|
|
||||||
static constexpr u32 FLOAT_SIGN = 0x80000000;
|
static constexpr u32 FLOAT_SIGN = 0x80000000;
|
||||||
static constexpr u32 FLOAT_EXP = 0x7F800000;
|
static constexpr u32 FLOAT_EXP = 0x7F800000;
|
||||||
static constexpr u32 FLOAT_FRAC = 0x007FFFFF;
|
static constexpr u32 FLOAT_FRAC = 0x007FFFFF;
|
||||||
static constexpr u32 FLOAT_ZERO = 0x00000000;
|
static constexpr u32 FLOAT_ZERO = 0x00000000;
|
||||||
|
static constexpr int FLOAT_EXP_WIDTH = 8;
|
||||||
|
static constexpr int FLOAT_FRAC_WIDTH = 23;
|
||||||
|
|
||||||
inline bool IsQNAN(double d)
|
inline bool IsQNAN(double d)
|
||||||
{
|
{
|
||||||
|
|
|
@ -315,25 +315,14 @@ void JitArm64::GenerateFrsqrte()
|
||||||
// inf, even the mantissa matches. But the mantissa does not match for most other inputs, so in
|
// inf, even the mantissa matches. But the mantissa does not match for most other inputs, so in
|
||||||
// the normal case we calculate the mantissa using the table-based algorithm from the interpreter.
|
// the normal case we calculate the mantissa using the table-based algorithm from the interpreter.
|
||||||
|
|
||||||
|
LSL(ARM64Reg::X2, ARM64Reg::X1, 1);
|
||||||
m_float_emit.FMOV(ARM64Reg::X0, ARM64Reg::D0);
|
m_float_emit.FMOV(ARM64Reg::X0, ARM64Reg::D0);
|
||||||
TST(ARM64Reg::X1, LogicalImm(Common::DOUBLE_EXP | Common::DOUBLE_FRAC, 64));
|
CLS(ARM64Reg::X3, ARM64Reg::X2);
|
||||||
FixupBranch zero = B(CCFlags::CC_EQ);
|
TST(ARM64Reg::X1, LogicalImm(Common::DOUBLE_SIGN, 64));
|
||||||
AND(ARM64Reg::X2, ARM64Reg::X1, LogicalImm(Common::DOUBLE_EXP, 64));
|
CCMP(ARM64Reg::X3, Common::DOUBLE_EXP_WIDTH - 1, 0b0010, CCFlags::CC_EQ);
|
||||||
MOVI2R(ARM64Reg::X3, Common::DOUBLE_EXP);
|
FixupBranch not_positive_normal = B(CCFlags::CC_HS);
|
||||||
CMP(ARM64Reg::X2, ARM64Reg::X3);
|
|
||||||
FixupBranch nan_or_inf = B(CCFlags::CC_EQ);
|
|
||||||
FixupBranch negative = TBNZ(ARM64Reg::X1, 63);
|
|
||||||
FixupBranch normal = CBNZ(ARM64Reg::X2);
|
|
||||||
|
|
||||||
// "Normalize" denormal values.
|
const u8* positive_normal = GetCodePtr();
|
||||||
// The simplified calculation used here results in the upper 11 bits being incorrect,
|
|
||||||
// but that's fine, because the code below never reads those bits.
|
|
||||||
CLZ(ARM64Reg::X3, ARM64Reg::X1);
|
|
||||||
LSLV(ARM64Reg::X1, ARM64Reg::X1, ARM64Reg::X3);
|
|
||||||
LSR(ARM64Reg::X1, ARM64Reg::X1, 11);
|
|
||||||
BFI(ARM64Reg::X1, ARM64Reg::X3, 52, 12);
|
|
||||||
|
|
||||||
SetJumpTarget(normal);
|
|
||||||
UBFX(ARM64Reg::X2, ARM64Reg::X1, 48, 5);
|
UBFX(ARM64Reg::X2, ARM64Reg::X1, 48, 5);
|
||||||
MOVP2R(ARM64Reg::X3, &Common::frsqrte_expected);
|
MOVP2R(ARM64Reg::X3, &Common::frsqrte_expected);
|
||||||
ADD(ARM64Reg::X2, ARM64Reg::X3, ARM64Reg::X2, ArithOption(ARM64Reg::X2, ShiftType::LSL, 3));
|
ADD(ARM64Reg::X2, ARM64Reg::X3, ARM64Reg::X2, ArithOption(ARM64Reg::X2, ShiftType::LSL, 3));
|
||||||
|
@ -344,27 +333,41 @@ void JitArm64::GenerateFrsqrte()
|
||||||
ORR(ARM64Reg::X0, ARM64Reg::X0, ARM64Reg::X1, ArithOption(ARM64Reg::X1, ShiftType::LSL, 26));
|
ORR(ARM64Reg::X0, ARM64Reg::X0, ARM64Reg::X1, ArithOption(ARM64Reg::X1, ShiftType::LSL, 26));
|
||||||
RET();
|
RET();
|
||||||
|
|
||||||
SetJumpTarget(zero);
|
SetJumpTarget(not_positive_normal);
|
||||||
LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
||||||
|
FixupBranch not_positive_normal_not_zero = CBNZ(ARM64Reg::X2);
|
||||||
|
|
||||||
|
// Zero
|
||||||
FixupBranch skip_set_zx = TBNZ(ARM64Reg::W3, 26);
|
FixupBranch skip_set_zx = TBNZ(ARM64Reg::W3, 26);
|
||||||
ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_ZX, ARM64Reg::W2);
|
ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_ZX, ARM64Reg::W2);
|
||||||
|
const u8* store_fpscr = GetCodePtr();
|
||||||
STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
||||||
SetJumpTarget(skip_set_zx);
|
SetJumpTarget(skip_set_zx);
|
||||||
|
const u8* done = GetCodePtr();
|
||||||
RET();
|
RET();
|
||||||
|
|
||||||
|
SetJumpTarget(not_positive_normal_not_zero);
|
||||||
|
FixupBranch nan_or_inf = TBNZ(ARM64Reg::X1, 62);
|
||||||
|
FixupBranch negative = TBNZ(ARM64Reg::X1, 63);
|
||||||
|
|
||||||
|
// "Normalize" denormal values.
|
||||||
|
// The simplified calculation used here results in the upper 11 bits being incorrect,
|
||||||
|
// but that's fine, because the code we jump to never reads those bits.
|
||||||
|
CLZ(ARM64Reg::X3, ARM64Reg::X1);
|
||||||
|
LSLV(ARM64Reg::X1, ARM64Reg::X1, ARM64Reg::X3);
|
||||||
|
LSR(ARM64Reg::X1, ARM64Reg::X1, 11);
|
||||||
|
BFI(ARM64Reg::X1, ARM64Reg::X3, 52, 12);
|
||||||
|
B(positive_normal);
|
||||||
|
|
||||||
SetJumpTarget(nan_or_inf);
|
SetJumpTarget(nan_or_inf);
|
||||||
MOVI2R(ARM64Reg::X3, Common::BitCast<u64>(-std::numeric_limits<double>::infinity()));
|
MOVI2R(ARM64Reg::X2, Common::BitCast<u64>(-std::numeric_limits<double>::infinity()));
|
||||||
CMP(ARM64Reg::X1, ARM64Reg::X3);
|
CMP(ARM64Reg::X1, ARM64Reg::X2);
|
||||||
FixupBranch nan_or_positive_inf = B(CCFlags::CC_NEQ);
|
B(CCFlags::CC_NEQ, done);
|
||||||
|
|
||||||
SetJumpTarget(negative);
|
SetJumpTarget(negative);
|
||||||
LDR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
TBNZ(ARM64Reg::W3, 9, done);
|
||||||
FixupBranch skip_set_vxsqrt = TBNZ(ARM64Reg::W3, 9);
|
|
||||||
ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_VXSQRT, ARM64Reg::W2);
|
ORRI2R(ARM64Reg::W3, ARM64Reg::W3, FPSCR_FX | FPSCR_VXSQRT, ARM64Reg::W2);
|
||||||
STR(IndexType::Unsigned, ARM64Reg::W3, PPC_REG, PPCSTATE_OFF(fpscr));
|
B(store_fpscr);
|
||||||
SetJumpTarget(skip_set_vxsqrt);
|
|
||||||
SetJumpTarget(nan_or_positive_inf);
|
|
||||||
RET();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Input in X0, output in W1, clobbers X0-X3 and flags.
|
// Input in X0, output in W1, clobbers X0-X3 and flags.
|
||||||
|
@ -438,59 +441,50 @@ void JitArm64::GenerateFPRF(bool single)
|
||||||
const auto reg_encoder = single ? EncodeRegTo32 : EncodeRegTo64;
|
const auto reg_encoder = single ? EncodeRegTo32 : EncodeRegTo64;
|
||||||
|
|
||||||
const ARM64Reg input_reg = reg_encoder(ARM64Reg::W0);
|
const ARM64Reg input_reg = reg_encoder(ARM64Reg::W0);
|
||||||
const ARM64Reg temp_reg = reg_encoder(ARM64Reg::W1);
|
const ARM64Reg cls_reg = reg_encoder(ARM64Reg::W2);
|
||||||
const ARM64Reg exp_reg = reg_encoder(ARM64Reg::W2);
|
|
||||||
|
|
||||||
constexpr ARM64Reg fprf_reg = ARM64Reg::W3;
|
constexpr ARM64Reg fprf_reg = ARM64Reg::W3;
|
||||||
constexpr ARM64Reg fpscr_reg = ARM64Reg::W4;
|
constexpr ARM64Reg fpscr_reg = ARM64Reg::W4;
|
||||||
|
|
||||||
const int input_size = single ? 32 : 64;
|
const int input_size = single ? 32 : 64;
|
||||||
const u64 input_exp_mask = single ? Common::FLOAT_EXP : Common::DOUBLE_EXP;
|
const int input_exp_size = single ? Common::FLOAT_EXP_WIDTH : Common::DOUBLE_EXP_WIDTH;
|
||||||
const u64 input_frac_mask = single ? Common::FLOAT_FRAC : Common::DOUBLE_FRAC;
|
const u64 input_frac_mask = single ? Common::FLOAT_FRAC : Common::DOUBLE_FRAC;
|
||||||
constexpr u32 output_sign_mask = 0xC;
|
constexpr u32 output_sign_mask = 0xC;
|
||||||
|
|
||||||
// This code is duplicated for the most common cases for performance.
|
|
||||||
// For the less common cases, we branch to an existing copy of this code.
|
|
||||||
auto emit_write_fprf_and_ret = [&] {
|
|
||||||
BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH);
|
|
||||||
STR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
|
|
||||||
RET();
|
|
||||||
};
|
|
||||||
|
|
||||||
// First of all, start the load of the old FPSCR value, in case it takes a while
|
// First of all, start the load of the old FPSCR value, in case it takes a while
|
||||||
LDR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
|
LDR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
|
||||||
|
|
||||||
// Most branches handle the sign in the same way. Perform that handling before branching
|
|
||||||
MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PN);
|
|
||||||
MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NN);
|
|
||||||
CMP(input_reg, 0); // Grab sign bit (conveniently the same bit for floats as for integers)
|
CMP(input_reg, 0); // Grab sign bit (conveniently the same bit for floats as for integers)
|
||||||
CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT);
|
LSL(cls_reg, input_reg, 1);
|
||||||
|
FixupBranch not_zero = CBNZ(cls_reg);
|
||||||
AND(exp_reg, input_reg, LogicalImm(input_exp_mask, input_size)); // Grab exponent
|
|
||||||
FixupBranch zero_or_denormal = CBZ(exp_reg);
|
|
||||||
|
|
||||||
// exp != 0
|
|
||||||
MOVI2R(temp_reg, input_exp_mask);
|
|
||||||
CMP(exp_reg, temp_reg);
|
|
||||||
FixupBranch nan_or_inf = B(CCFlags::CC_EQ);
|
|
||||||
|
|
||||||
// exp != 0 && exp != EXP_MASK
|
|
||||||
emit_write_fprf_and_ret();
|
|
||||||
|
|
||||||
// exp == 0
|
|
||||||
SetJumpTarget(zero_or_denormal);
|
|
||||||
TST(input_reg, LogicalImm(input_frac_mask, input_size));
|
|
||||||
FixupBranch denormal = B(CCFlags::CC_NEQ);
|
|
||||||
|
|
||||||
// exp == 0 && frac == 0
|
// exp == 0 && frac == 0
|
||||||
LSR(ARM64Reg::W1, fprf_reg, 3);
|
MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PZ);
|
||||||
MOVI2R(fprf_reg, Common::PPC_FPCLASS_PZ & ~output_sign_mask);
|
MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NZ);
|
||||||
BFI(fprf_reg, ARM64Reg::W1, 4, 1);
|
CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT);
|
||||||
|
|
||||||
const u8* write_fprf_and_ret = GetCodePtr();
|
const u8* write_fprf_and_ret = GetCodePtr();
|
||||||
emit_write_fprf_and_ret();
|
BFI(fpscr_reg, fprf_reg, FPRF_SHIFT, FPRF_WIDTH);
|
||||||
|
STR(IndexType::Unsigned, fpscr_reg, PPC_REG, PPCSTATE_OFF(fpscr));
|
||||||
|
RET();
|
||||||
|
|
||||||
|
// exp != 0 || frac != 0
|
||||||
|
SetJumpTarget(not_zero);
|
||||||
|
CLS(cls_reg, cls_reg);
|
||||||
|
|
||||||
|
// All branches except the zero branch handle the sign in the same way.
|
||||||
|
// Perform that handling before branching further
|
||||||
|
MOVI2R(ARM64Reg::W3, Common::PPC_FPCLASS_PN);
|
||||||
|
MOVI2R(ARM64Reg::W1, Common::PPC_FPCLASS_NN);
|
||||||
|
CSEL(fprf_reg, ARM64Reg::W1, ARM64Reg::W3, CCFlags::CC_LT);
|
||||||
|
|
||||||
|
CMP(cls_reg, input_exp_size - 1);
|
||||||
|
B(CCFlags::CC_LO, write_fprf_and_ret); // Branch if input is normal
|
||||||
|
|
||||||
|
// exp == EXP_MASK || (exp == 0 && frac != 0)
|
||||||
|
FixupBranch nan_or_inf = TBNZ(input_reg, input_size - 2);
|
||||||
|
|
||||||
// exp == 0 && frac != 0
|
// exp == 0 && frac != 0
|
||||||
SetJumpTarget(denormal);
|
|
||||||
ORR(fprf_reg, fprf_reg, LogicalImm(Common::PPC_FPCLASS_PD & ~output_sign_mask, 32));
|
ORR(fprf_reg, fprf_reg, LogicalImm(Common::PPC_FPCLASS_PD & ~output_sign_mask, 32));
|
||||||
B(write_fprf_and_ret);
|
B(write_fprf_and_ret);
|
||||||
|
|
||||||
|
|
|
@ -84,6 +84,8 @@ TEST(JitArm64, FPRF)
|
||||||
const u32 expected_double = RunUpdateFPRF(
|
const u32 expected_double = RunUpdateFPRF(
|
||||||
ppc_state, [&] { ppc_state.UpdateFPRFDouble(Common::BitCast<double>(double_input)); });
|
ppc_state, [&] { ppc_state.UpdateFPRFDouble(Common::BitCast<double>(double_input)); });
|
||||||
const u32 actual_double = RunUpdateFPRF(ppc_state, [&] { test.fprf_double(double_input); });
|
const u32 actual_double = RunUpdateFPRF(ppc_state, [&] { test.fprf_double(double_input); });
|
||||||
|
if (expected_double != actual_double)
|
||||||
|
fmt::print("{:016x} -> {:08x} == {:08x}\n", double_input, actual_double, expected_double);
|
||||||
EXPECT_EQ(expected_double, actual_double);
|
EXPECT_EQ(expected_double, actual_double);
|
||||||
|
|
||||||
const u32 single_input = ConvertToSingle(double_input);
|
const u32 single_input = ConvertToSingle(double_input);
|
||||||
|
@ -91,6 +93,8 @@ TEST(JitArm64, FPRF)
|
||||||
const u32 expected_single = RunUpdateFPRF(
|
const u32 expected_single = RunUpdateFPRF(
|
||||||
ppc_state, [&] { ppc_state.UpdateFPRFSingle(Common::BitCast<float>(single_input)); });
|
ppc_state, [&] { ppc_state.UpdateFPRFSingle(Common::BitCast<float>(single_input)); });
|
||||||
const u32 actual_single = RunUpdateFPRF(ppc_state, [&] { test.fprf_single(single_input); });
|
const u32 actual_single = RunUpdateFPRF(ppc_state, [&] { test.fprf_single(single_input); });
|
||||||
|
if (expected_single != actual_single)
|
||||||
|
fmt::print("{:08x} -> {:08x} == {:08x}\n", single_input, actual_single, expected_single);
|
||||||
EXPECT_EQ(expected_single, actual_single);
|
EXPECT_EQ(expected_single, actual_single);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue