Jit: Check MSR state in BLR optimization

When we execute a JIT block, we have to make sure that both the PC and
the DR/IR bits of MSR are the same as they were when the block was
compiled. When jumping to a block from the dispatcher, this is done in
the way you would expect: By checking the PC and the relevant MSR bits.
However, when returning to a block using the BLR optimization, we only
check the PC. Instead of checking the MSR bits at that point, we reset
the stack whenever the MSR changes, so that subsequent PC checks fail.

Except... We were only resetting the stack on rfi instructions. There
are actually many more ways for the MSR to change, and we weren't
covering those at all. I looked into resetting the stack on all of them,
but it would be pretty cumbersome both in terms of writing the code and
in terms of how often at runtime we'd have to reset the stack, so I
think the better option would be to check the MSR bits along with the
PC. That's what this commit implements.
This commit is contained in:
JosJuice 2023-08-27 13:48:22 +02:00
parent 7ac0db70c6
commit 8c2c665af3
5 changed files with 69 additions and 27 deletions

View File

@@ -482,7 +482,8 @@ void Jit64::FakeBLCall(u32 after)
// We may need to fake the BLR stack on inlined CALL instructions. // We may need to fake the BLR stack on inlined CALL instructions.
// Else we can't return to this location any more. // Else we can't return to this location any more.
MOV(32, R(RSCRATCH2), Imm32(after)); MOV(64, R(RSCRATCH2),
Imm64(u64(m_ppc_state.msr.Hex & JitBaseBlockCache::JIT_CACHE_MSR_MASK) << 32 | after));
PUSH(RSCRATCH2); PUSH(RSCRATCH2);
FixupBranch skip_exit = CALL(); FixupBranch skip_exit = CALL();
POP(RSCRATCH2); POP(RSCRATCH2);
@@ -514,7 +515,8 @@ void Jit64::WriteExit(u32 destination, bool bl, u32 after)
if (bl) if (bl)
{ {
MOV(32, R(RSCRATCH2), Imm32(after)); MOV(64, R(RSCRATCH2),
Imm64(u64(m_ppc_state.msr.Hex & JitBaseBlockCache::JIT_CACHE_MSR_MASK) << 32 | after));
PUSH(RSCRATCH2); PUSH(RSCRATCH2);
} }
@@ -571,7 +573,8 @@ void Jit64::WriteExitDestInRSCRATCH(bool bl, u32 after)
if (bl) if (bl)
{ {
MOV(32, R(RSCRATCH2), Imm32(after)); MOV(64, R(RSCRATCH2),
Imm64(u64(m_ppc_state.msr.Hex & JitBaseBlockCache::JIT_CACHE_MSR_MASK) << 32 | after));
PUSH(RSCRATCH2); PUSH(RSCRATCH2);
} }
@@ -599,6 +602,13 @@ void Jit64::WriteBLRExit()
bool disturbed = Cleanup(); bool disturbed = Cleanup();
if (disturbed) if (disturbed)
MOV(32, R(RSCRATCH), PPCSTATE(pc)); MOV(32, R(RSCRATCH), PPCSTATE(pc));
const u64 msr_bits = m_ppc_state.msr.Hex & JitBaseBlockCache::JIT_CACHE_MSR_MASK;
if (msr_bits != 0)
{
MOV(32, R(RSCRATCH2), Imm32(msr_bits));
SHL(64, R(RSCRATCH2), Imm8(32));
OR(64, R(RSCRATCH), R(RSCRATCH2));
}
MOV(32, R(RSCRATCH2), Imm32(js.downcountAmount)); MOV(32, R(RSCRATCH2), Imm32(js.downcountAmount));
CMP(64, R(RSCRATCH), MDisp(RSP, 8)); CMP(64, R(RSCRATCH), MDisp(RSP, 8));
J_CC(CC_NE, asm_routines.dispatcher_mispredicted_blr); J_CC(CC_NE, asm_routines.dispatcher_mispredicted_blr);

View File

@@ -445,10 +445,6 @@ void Jit64::mtmsr(UGeckoInstruction inst)
gpr.Flush(); gpr.Flush();
fpr.Flush(); fpr.Flush();
// Our jit cache also stores some MSR bits, as they have changed, we either
// have to validate them in the BLR/RET check, or just flush the stack here.
asm_routines.ResetStack(*this);
// If some exceptions are pending and EE are now enabled, force checking // If some exceptions are pending and EE are now enabled, force checking
// external exceptions when going out of mtmsr in order to execute delayed // external exceptions when going out of mtmsr in order to execute delayed
// interrupts as soon as possible. // interrupts as soon as possible.

View File

@@ -386,12 +386,21 @@ void JitArm64::WriteExit(u32 destination, bool LK, u32 exit_address_after_return
const u8* host_address_after_return; const u8* host_address_after_return;
if (LK) if (LK)
{ {
// Push {ARM_PC; PPC_PC} on the stack // Push {ARM_PC (64-bit); PPC_PC (32-bit); MSR_BITS (32-bit)} on the stack
ARM64Reg reg_to_push = exit_address_after_return_reg; ARM64Reg reg_to_push = ARM64Reg::X1;
const u64 msr_bits = m_ppc_state.msr.Hex & JitBaseBlockCache::JIT_CACHE_MSR_MASK;
if (exit_address_after_return_reg == ARM64Reg::INVALID_REG) if (exit_address_after_return_reg == ARM64Reg::INVALID_REG)
{ {
MOVI2R(ARM64Reg::X1, exit_address_after_return); MOVI2R(ARM64Reg::X1, msr_bits << 32 | exit_address_after_return);
reg_to_push = ARM64Reg::X1; }
else if (msr_bits == 0)
{
reg_to_push = EncodeRegTo64(exit_address_after_return_reg);
}
else
{
ORRI2R(ARM64Reg::X1, EncodeRegTo64(exit_address_after_return_reg), msr_bits << 32,
ARM64Reg::X1);
} }
constexpr s32 adr_offset = JitArm64BlockCache::BLOCK_LINK_SIZE + sizeof(u32) * 2; constexpr s32 adr_offset = JitArm64BlockCache::BLOCK_LINK_SIZE + sizeof(u32) * 2;
host_address_after_return = GetCodePtr() + adr_offset; host_address_after_return = GetCodePtr() + adr_offset;
@@ -481,14 +490,22 @@ void JitArm64::WriteExit(Arm64Gen::ARM64Reg dest, bool LK, u32 exit_address_afte
} }
else else
{ {
// Push {ARM_PC, PPC_PC} on the stack // Push {ARM_PC (64-bit); PPC_PC (32-bit); MSR_BITS (32-bit)} on the stack
ARM64Reg reg_to_push = exit_address_after_return_reg; ARM64Reg reg_to_push = ARM64Reg::X1;
const u64 msr_bits = m_ppc_state.msr.Hex & JitBaseBlockCache::JIT_CACHE_MSR_MASK;
if (exit_address_after_return_reg == ARM64Reg::INVALID_REG) if (exit_address_after_return_reg == ARM64Reg::INVALID_REG)
{ {
MOVI2R(ARM64Reg::X1, exit_address_after_return); MOVI2R(ARM64Reg::X1, msr_bits << 32 | exit_address_after_return);
reg_to_push = ARM64Reg::X1; }
else if (msr_bits == 0)
{
reg_to_push = EncodeRegTo64(exit_address_after_return_reg);
}
else
{
ORRI2R(ARM64Reg::X1, EncodeRegTo64(exit_address_after_return_reg), msr_bits << 32,
ARM64Reg::X1);
} }
MOVI2R(ARM64Reg::X1, exit_address_after_return);
constexpr s32 adr_offset = sizeof(u32) * 3; constexpr s32 adr_offset = sizeof(u32) * 3;
const u8* host_address_after_return = GetCodePtr() + adr_offset; const u8* host_address_after_return = GetCodePtr() + adr_offset;
ADR(ARM64Reg::X0, adr_offset); ADR(ARM64Reg::X0, adr_offset);
@@ -544,19 +561,33 @@ void JitArm64::FakeLKExit(u32 exit_address_after_return, ARM64Reg exit_address_a
// function has been called! // function has been called!
gpr.Lock(ARM64Reg::W30); gpr.Lock(ARM64Reg::W30);
} }
ARM64Reg after_reg = exit_address_after_return_reg; // Push {ARM_PC (64-bit); PPC_PC (32-bit); MSR_BITS (32-bit)} on the stack
ARM64Reg after_reg = ARM64Reg::INVALID_REG;
ARM64Reg reg_to_push;
const u64 msr_bits = m_ppc_state.msr.Hex & JitBaseBlockCache::JIT_CACHE_MSR_MASK;
if (exit_address_after_return_reg == ARM64Reg::INVALID_REG) if (exit_address_after_return_reg == ARM64Reg::INVALID_REG)
{ {
after_reg = gpr.GetReg(); after_reg = gpr.GetReg();
MOVI2R(after_reg, exit_address_after_return); reg_to_push = EncodeRegTo64(after_reg);
MOVI2R(reg_to_push, msr_bits << 32 | exit_address_after_return);
}
else if (msr_bits == 0)
{
reg_to_push = EncodeRegTo64(exit_address_after_return_reg);
}
else
{
after_reg = gpr.GetReg();
reg_to_push = EncodeRegTo64(after_reg);
ORRI2R(reg_to_push, EncodeRegTo64(exit_address_after_return_reg), msr_bits << 32, reg_to_push);
} }
ARM64Reg code_reg = gpr.GetReg(); ARM64Reg code_reg = gpr.GetReg();
constexpr s32 adr_offset = sizeof(u32) * 3; constexpr s32 adr_offset = sizeof(u32) * 3;
const u8* host_address_after_return = GetCodePtr() + adr_offset; const u8* host_address_after_return = GetCodePtr() + adr_offset;
ADR(EncodeRegTo64(code_reg), adr_offset); ADR(EncodeRegTo64(code_reg), adr_offset);
STP(IndexType::Pre, EncodeRegTo64(code_reg), EncodeRegTo64(after_reg), ARM64Reg::SP, -16); STP(IndexType::Pre, EncodeRegTo64(code_reg), reg_to_push, ARM64Reg::SP, -16);
gpr.Unlock(code_reg); gpr.Unlock(code_reg);
if (after_reg != exit_address_after_return_reg) if (after_reg != ARM64Reg::INVALID_REG)
gpr.Unlock(after_reg); gpr.Unlock(after_reg);
FixupBranch skip_exit = BL(); FixupBranch skip_exit = BL();
@@ -612,9 +643,18 @@ void JitArm64::WriteBLRExit(Arm64Gen::ARM64Reg dest)
Cleanup(); Cleanup();
EndTimeProfile(js.curBlock); EndTimeProfile(js.curBlock);
// Check if {ARM_PC, PPC_PC} matches the current state. // Check if {PPC_PC, MSR_BITS} matches the current state, then RET to ARM_PC.
LDP(IndexType::Post, ARM64Reg::X2, ARM64Reg::X1, ARM64Reg::SP, 16); LDP(IndexType::Post, ARM64Reg::X2, ARM64Reg::X1, ARM64Reg::SP, 16);
CMP(ARM64Reg::W1, DISPATCHER_PC); const u64 msr_bits = m_ppc_state.msr.Hex & JitBaseBlockCache::JIT_CACHE_MSR_MASK;
if (msr_bits == 0)
{
CMP(ARM64Reg::X1, EncodeRegTo64(DISPATCHER_PC));
}
else
{
ORRI2R(ARM64Reg::X0, EncodeRegTo64(DISPATCHER_PC), msr_bits << 32, ARM64Reg::X0);
CMP(ARM64Reg::X1, ARM64Reg::X0);
}
FixupBranch no_match = B(CC_NEQ); FixupBranch no_match = B(CC_NEQ);
DoDownCount(); // overwrites X0 + X1 DoDownCount(); // overwrites X0 + X1

View File

@@ -99,10 +99,6 @@ void JitArm64::mtmsr(UGeckoInstruction inst)
gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG); gpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG); fpr.Flush(FlushMode::All, ARM64Reg::INVALID_REG);
// Our jit cache also stores some MSR bits, as they have changed, we either
// have to validate them in the BLR/RET check, or just flush the stack here.
ResetStack();
WriteExceptionExit(js.compilerPC + 4, true); WriteExceptionExit(js.compilerPC + 4, true);
} }

View File

@@ -50,7 +50,7 @@ void JitArm64::GenerateAsm()
STR(IndexType::Unsigned, ARM64Reg::X0, PPC_REG, PPCSTATE_OFF(stored_stack_pointer)); STR(IndexType::Unsigned, ARM64Reg::X0, PPC_REG, PPCSTATE_OFF(stored_stack_pointer));
// Push {nullptr; -1} as invalid destination on the stack. // Push {nullptr; -1} as invalid destination on the stack.
MOVI2R(ARM64Reg::X0, 0xFFFFFFFF); MOVI2R(ARM64Reg::X0, 0xFFFF'FFFF'FFFF'FFFF);
STP(IndexType::Pre, ARM64Reg::ZR, ARM64Reg::X0, ARM64Reg::SP, -16); STP(IndexType::Pre, ARM64Reg::ZR, ARM64Reg::X0, ARM64Reg::SP, -16);
// The PC will be loaded into DISPATCHER_PC after the call to CoreTiming::Advance(). // The PC will be loaded into DISPATCHER_PC after the call to CoreTiming::Advance().