[Memory] Add HostToGuestVirtual and use it in a couple of places

This commit is contained in:
Triang3l 2019-08-13 23:49:49 +03:00
parent f1b9e1afce
commit 741b5ae2ec
7 changed files with 80 additions and 40 deletions

View File

@ -25,6 +25,8 @@ MMIOHandler* MMIOHandler::global_handler_ = nullptr;
std::unique_ptr<MMIOHandler> MMIOHandler::Install(
uint8_t* virtual_membase, uint8_t* physical_membase, uint8_t* membase_end,
HostToGuestVirtual host_to_guest_virtual,
const void* host_to_guest_virtual_context,
AccessViolationCallback access_violation_callback,
void* access_violation_callback_context) {
// There can be only one handler at a time.
@ -34,7 +36,8 @@ std::unique_ptr<MMIOHandler> MMIOHandler::Install(
}
auto handler = std::unique_ptr<MMIOHandler>(new MMIOHandler(
virtual_membase, physical_membase, membase_end, access_violation_callback,
virtual_membase, physical_membase, membase_end, host_to_guest_virtual,
host_to_guest_virtual_context, access_violation_callback,
access_violation_callback_context));
// Install the exception handler directed at the MMIOHandler.
@ -46,11 +49,15 @@ std::unique_ptr<MMIOHandler> MMIOHandler::Install(
MMIOHandler::MMIOHandler(uint8_t* virtual_membase, uint8_t* physical_membase,
uint8_t* membase_end,
HostToGuestVirtual host_to_guest_virtual,
const void* host_to_guest_virtual_context,
AccessViolationCallback access_violation_callback,
void* access_violation_callback_context)
: virtual_membase_(virtual_membase),
physical_membase_(physical_membase),
memory_end_(membase_end),
host_to_guest_virtual_(host_to_guest_virtual),
host_to_guest_virtual_context_(host_to_guest_virtual_context),
access_violation_callback_(access_violation_callback),
access_violation_callback_context_(access_violation_callback_context) {}
@ -279,14 +286,16 @@ bool MMIOHandler::ExceptionCallback(Exception* ex) {
// Quick kill anything outside our mapping.
return false;
}
void* fault_host_address = reinterpret_cast<void*>(ex->fault_address());
// Access violations are pretty rare, so we can do a linear search here.
// Only check if in the virtual range, as we only support virtual ranges.
const MMIORange* range = nullptr;
if (ex->fault_address() < uint64_t(physical_membase_)) {
uint32_t fault_virtual_address = host_to_guest_virtual_(
host_to_guest_virtual_context_, fault_host_address);
for (const auto& test_range : mapped_ranges_) {
if ((static_cast<uint32_t>(ex->fault_address()) & test_range.mask) ==
test_range.address) {
if ((fault_virtual_address & test_range.mask) == test_range.address) {
// Address is within the range of this mapping.
range = &test_range;
break;
@ -300,8 +309,7 @@ bool MMIOHandler::ExceptionCallback(Exception* ex) {
auto lock = global_critical_region_.Acquire();
memory::PageAccess cur_access;
size_t page_length = memory::page_size();
memory::QueryProtect(reinterpret_cast<void*>(ex->fault_address()),
page_length, cur_access);
memory::QueryProtect(fault_host_address, page_length, cur_access);
if (cur_access != memory::PageAccess::kReadOnly &&
cur_access != memory::PageAccess::kNoAccess) {
// Another thread has cleared this write watch. Abort.
@ -314,10 +322,10 @@ bool MMIOHandler::ExceptionCallback(Exception* ex) {
switch (ex->access_violation_operation()) {
case Exception::AccessViolationOperation::kRead:
return access_violation_callback_(access_violation_callback_context_,
size_t(ex->fault_address()), false);
fault_host_address, false);
case Exception::AccessViolationOperation::kWrite:
return access_violation_callback_(access_violation_callback_context_,
size_t(ex->fault_address()), true);
fault_host_address, true);
default:
// Data Execution Prevention or something else uninteresting.
break;

View File

@ -27,8 +27,6 @@ typedef uint32_t (*MMIOReadCallback)(void* ppc_context, void* callback_context,
uint32_t addr);
typedef void (*MMIOWriteCallback)(void* ppc_context, void* callback_context,
uint32_t addr, uint32_t value);
typedef void (*AccessWatchCallback)(void* context_ptr, void* data_ptr,
uint32_t address);
struct MMIORange {
uint32_t address;
@ -44,7 +42,9 @@ class MMIOHandler {
public:
virtual ~MMIOHandler();
typedef bool (*AccessViolationCallback)(void* context, size_t host_address,
typedef uint32_t (*HostToGuestVirtual)(const void* context,
const void* host_address);
typedef bool (*AccessViolationCallback)(void* context, void* host_address,
bool is_write);
// access_violation_callback is called in global_critical_region, so if
@ -52,6 +52,8 @@ class MMIOHandler {
// will be called only once.
static std::unique_ptr<MMIOHandler> Install(
uint8_t* virtual_membase, uint8_t* physical_membase, uint8_t* membase_end,
HostToGuestVirtual host_to_guest_virtual,
const void* host_to_guest_virtual_context,
AccessViolationCallback access_violation_callback,
void* access_violation_callback_context);
static MMIOHandler* global_handler() { return global_handler_; }
@ -66,7 +68,8 @@ class MMIOHandler {
protected:
MMIOHandler(uint8_t* virtual_membase, uint8_t* physical_membase,
uint8_t* membase_end,
uint8_t* membase_end, HostToGuestVirtual host_to_guest_virtual,
const void* host_to_guest_virtual_context,
AccessViolationCallback access_violation_callback,
void* access_violation_callback_context);
@ -79,6 +82,9 @@ class MMIOHandler {
std::vector<MMIORange> mapped_ranges_;
HostToGuestVirtual host_to_guest_virtual_;
const void* host_to_guest_virtual_context_;
AccessViolationCallback access_violation_callback_;
void* access_violation_callback_context_;

View File

@ -314,11 +314,11 @@ X_STATUS UserModule::GetOptHeader(xex2_header_keys key,
if (!header) {
return X_STATUS_UNSUCCESSFUL;
}
return GetOptHeader(memory()->virtual_membase(), header, key,
out_header_guest_ptr);
return GetOptHeader(memory(), header, key, out_header_guest_ptr);
}
X_STATUS UserModule::GetOptHeader(uint8_t* membase, const xex2_header* header,
X_STATUS UserModule::GetOptHeader(const Memory* memory,
const xex2_header* header,
xex2_header_keys key,
uint32_t* out_header_guest_ptr) {
assert_not_null(out_header_guest_ptr);
@ -337,14 +337,11 @@ X_STATUS UserModule::GetOptHeader(uint8_t* membase, const xex2_header* header,
break;
case 0x01:
// Return pointer to data stored in header value.
field_value = static_cast<uint32_t>(
reinterpret_cast<const uint8_t*>(&opt_header.value) - membase);
field_value = memory->HostToGuestVirtual(&opt_header.value);
break;
default:
// Data stored at offset to header.
field_value = static_cast<uint32_t>(
reinterpret_cast<const uint8_t*>(header) - membase) +
opt_header.offset;
field_value = memory->HostToGuestVirtual(header) + opt_header.offset;
break;
}
break;

View File

@ -81,7 +81,7 @@ class UserModule : public XModule {
// Get optional header that can safely be returned to guest code.
X_STATUS GetOptHeader(xex2_header_keys key, uint32_t* out_header_guest_ptr);
static X_STATUS GetOptHeader(uint8_t* membase, const xex2_header* header,
static X_STATUS GetOptHeader(const Memory* memory, const xex2_header* header,
xex2_header_keys key,
uint32_t* out_header_guest_ptr);

View File

@ -341,8 +341,8 @@ pointer_result_t RtlImageXexHeaderField(pointer_t<xex2_header> xex_header,
uint32_t field_value = 0;
uint32_t field = field_dword; // VS acts weird going from dword_t -> enum
UserModule::GetOptHeader(kernel_memory()->virtual_membase(), xex_header,
xex2_header_keys(field), &field_value);
UserModule::GetOptHeader(kernel_memory(), xex_header, xex2_header_keys(field),
&field_value);
return field_value;
}

View File

@ -195,9 +195,9 @@ bool Memory::Initialize() {
kMemoryProtectRead | kMemoryProtectWrite);
// Add handlers for MMIO.
mmio_handler_ = cpu::MMIOHandler::Install(virtual_membase_, physical_membase_,
physical_membase_ + 0x1FFFFFFF,
AccessViolationCallbackThunk, this);
mmio_handler_ = cpu::MMIOHandler::Install(
virtual_membase_, physical_membase_, physical_membase_ + 0x1FFFFFFF,
HostToGuestVirtualThunk, this, AccessViolationCallbackThunk, this);
if (!mmio_handler_) {
XELOGE("Unable to install MMIO handlers");
assert_always();
@ -349,6 +349,26 @@ BaseHeap* Memory::LookupHeapByType(bool physical, uint32_t page_size) {
VirtualHeap* Memory::GetPhysicalHeap() { return &heaps_.physical; }
uint32_t Memory::HostToGuestVirtual(const void* host_address) const {
size_t virtual_address = reinterpret_cast<size_t>(host_address) -
reinterpret_cast<size_t>(virtual_membase_);
uint32_t vE0000000_host_offset = heaps_.vE0000000.host_address_offset();
size_t vE0000000_host_base =
size_t(heaps_.vE0000000.heap_base()) + vE0000000_host_offset;
if (virtual_address >= vE0000000_host_base &&
virtual_address <=
(vE0000000_host_base + heaps_.vE0000000.heap_size() - 1)) {
virtual_address -= vE0000000_host_offset;
}
return uint32_t(virtual_address);
}
uint32_t Memory::HostToGuestVirtualThunk(const void* context,
const void* host_address) {
return reinterpret_cast<const Memory*>(context)->HostToGuestVirtual(
host_address);
}
void Memory::Zero(uint32_t address, uint32_t size) {
std::memset(TranslateVirtual(address), 0, size);
}
@ -405,7 +425,7 @@ cpu::MMIORange* Memory::LookupVirtualMappedRange(uint32_t virtual_address) {
return mmio_handler_->LookupRange(virtual_address);
}
bool Memory::AccessViolationCallback(size_t host_address, bool is_write) {
bool Memory::AccessViolationCallback(void* host_address, bool is_write) {
if (!is_write) {
// TODO(Triang3l): Handle GPU readback.
return false;
@ -413,8 +433,10 @@ bool Memory::AccessViolationCallback(size_t host_address, bool is_write) {
// Access via physical_membase_ is special, when need to bypass everything,
// so only watching virtual memory regions.
if (host_address < reinterpret_cast<size_t>(virtual_membase_) ||
host_address >= reinterpret_cast<size_t>(physical_membase_)) {
if (reinterpret_cast<size_t>(host_address) <
reinterpret_cast<size_t>(virtual_membase_) ||
reinterpret_cast<size_t>(host_address) >=
reinterpret_cast<size_t>(physical_membase_)) {
return false;
}
@ -445,7 +467,7 @@ bool Memory::AccessViolationCallback(size_t host_address, bool is_write) {
return false;
}
bool Memory::AccessViolationCallbackThunk(void* context, size_t host_address,
bool Memory::AccessViolationCallbackThunk(void* context, void* host_address,
bool is_write) {
return reinterpret_cast<Memory*>(context)->AccessViolationCallback(
host_address, is_write);

View File

@ -95,6 +95,13 @@ class BaseHeap {
public:
virtual ~BaseHeap();
// Offset of the heap in relative to membase, without host_address_offset
// adjustment.
uint32_t heap_base() const { return heap_base_; }
// Length of the heap range.
uint32_t heap_size() const { return heap_size_; }
// Size of each page within the heap range in bytes.
uint32_t page_size() const { return page_size_; }
@ -277,10 +284,7 @@ class Memory {
// Translates a guest virtual address to a host address that can be accessed
// as a normal pointer.
// Note that the contents at the specified host address are big-endian.
inline uint8_t* TranslateVirtual(uint32_t guest_address) const {
return virtual_membase_ + guest_address;
}
template <typename T>
template <typename T = uint8_t*>
inline T TranslateVirtual(uint32_t guest_address) const {
return reinterpret_cast<T>(virtual_membase_ + guest_address);
}
@ -292,15 +296,16 @@ class Memory {
// Translates a guest physical address to a host address that can be accessed
// as a normal pointer.
// Note that the contents at the specified host address are big-endian.
inline uint8_t* TranslatePhysical(uint32_t guest_address) const {
return physical_membase_ + (guest_address & 0x1FFFFFFF);
}
template <typename T>
template <typename T = uint8_t*>
inline T TranslatePhysical(uint32_t guest_address) const {
return reinterpret_cast<T>(physical_membase_ +
(guest_address & 0x1FFFFFFF));
}
// Translates a host address to a guest virtual address.
// Note that the contents at the returned host address are big-endian.
uint32_t HostToGuestVirtual(const void* host_address) const;
// Zeros out a range of memory at the given guest address.
void Zero(uint32_t address, uint32_t size);
@ -412,11 +417,13 @@ class Memory {
int MapViews(uint8_t* mapping_base);
void UnmapViews();
bool AccessViolationCallback(size_t host_address, bool is_write);
static bool AccessViolationCallbackThunk(void* context, size_t host_address,
static uint32_t HostToGuestVirtualThunk(const void* context,
const void* host_address);
bool AccessViolationCallback(void* host_address, bool is_write);
static bool AccessViolationCallbackThunk(void* context, void* host_address,
bool is_write);
private:
std::wstring file_name_;
uint32_t system_page_size_ = 0;
uint32_t system_allocation_granularity_ = 0;