kernel/address_arbiter: Convert the address arbiter into a class

Places all of the functions for address arbiter operation into a class.
This will be necessary for future deglobalizing efforts related to both
the memory and system itself.
This commit is contained in:
Lioncash 2019-03-05 11:54:06 -05:00
parent cc92c054ec
commit ec6664f6d6
5 changed files with 137 additions and 84 deletions

View File

@ -9,6 +9,7 @@
#include "common/common_types.h" #include "common/common_types.h"
#include "core/core.h" #include "core/core.h"
#include "core/core_cpu.h" #include "core/core_cpu.h"
#include "core/hle/kernel/address_arbiter.h"
#include "core/hle/kernel/errors.h" #include "core/hle/kernel/errors.h"
#include "core/hle/kernel/object.h" #include "core/hle/kernel/object.h"
#include "core/hle/kernel/process.h" #include "core/hle/kernel/process.h"
@ -17,53 +18,10 @@
#include "core/hle/result.h" #include "core/hle/result.h"
#include "core/memory.h" #include "core/memory.h"
namespace Kernel::AddressArbiter { namespace Kernel {
namespace {
// Performs actual address waiting logic.
static ResultCode WaitForAddress(VAddr address, s64 timeout) {
SharedPtr<Thread> current_thread = GetCurrentThread();
current_thread->SetArbiterWaitAddress(address);
current_thread->SetStatus(ThreadStatus::WaitArb);
current_thread->InvalidateWakeupCallback();
current_thread->WakeAfterDelay(timeout);
Core::System::GetInstance().CpuCore(current_thread->GetProcessorID()).PrepareReschedule();
return RESULT_TIMEOUT;
}
// Gets the threads waiting on an address.
static std::vector<SharedPtr<Thread>> GetThreadsWaitingOnAddress(VAddr address) {
const auto RetrieveWaitingThreads = [](std::size_t core_index,
std::vector<SharedPtr<Thread>>& waiting_threads,
VAddr arb_addr) {
const auto& scheduler = Core::System::GetInstance().Scheduler(core_index);
const auto& thread_list = scheduler.GetThreadList();
for (const auto& thread : thread_list) {
if (thread->GetArbiterWaitAddress() == arb_addr)
waiting_threads.push_back(thread);
}
};
// Retrieve all threads that are waiting for this address.
std::vector<SharedPtr<Thread>> threads;
RetrieveWaitingThreads(0, threads, address);
RetrieveWaitingThreads(1, threads, address);
RetrieveWaitingThreads(2, threads, address);
RetrieveWaitingThreads(3, threads, address);
// Sort them by priority, such that the highest priority ones come first.
std::sort(threads.begin(), threads.end(),
[](const SharedPtr<Thread>& lhs, const SharedPtr<Thread>& rhs) {
return lhs->GetPriority() < rhs->GetPriority();
});
return threads;
}
// Wake up num_to_wake (or all) threads in a vector. // Wake up num_to_wake (or all) threads in a vector.
static void WakeThreads(std::vector<SharedPtr<Thread>>& waiting_threads, s32 num_to_wake) { void WakeThreads(std::vector<SharedPtr<Thread>>& waiting_threads, s32 num_to_wake) {
// Only process up to 'target' threads, unless 'target' is <= 0, in which case process // Only process up to 'target' threads, unless 'target' is <= 0, in which case process
// them all. // them all.
std::size_t last = waiting_threads.size(); std::size_t last = waiting_threads.size();
@ -78,17 +36,20 @@ static void WakeThreads(std::vector<SharedPtr<Thread>>& waiting_threads, s32 num
waiting_threads[i]->ResumeFromWait(); waiting_threads[i]->ResumeFromWait();
} }
} }
} // Anonymous namespace
// Signals an address being waited on. AddressArbiter::AddressArbiter() = default;
ResultCode SignalToAddress(VAddr address, s32 num_to_wake) { AddressArbiter::~AddressArbiter() = default;
ResultCode AddressArbiter::SignalToAddress(VAddr address, s32 num_to_wake) {
std::vector<SharedPtr<Thread>> waiting_threads = GetThreadsWaitingOnAddress(address); std::vector<SharedPtr<Thread>> waiting_threads = GetThreadsWaitingOnAddress(address);
WakeThreads(waiting_threads, num_to_wake); WakeThreads(waiting_threads, num_to_wake);
return RESULT_SUCCESS; return RESULT_SUCCESS;
} }
// Signals an address being waited on and increments its value if equal to the value argument. ResultCode AddressArbiter::IncrementAndSignalToAddressIfEqual(VAddr address, s32 value,
ResultCode IncrementAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake) { s32 num_to_wake) {
// Ensure that we can write to the address. // Ensure that we can write to the address.
if (!Memory::IsValidVirtualAddress(address)) { if (!Memory::IsValidVirtualAddress(address)) {
return ERR_INVALID_ADDRESS_STATE; return ERR_INVALID_ADDRESS_STATE;
@ -103,9 +64,7 @@ ResultCode IncrementAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_
return SignalToAddress(address, num_to_wake); return SignalToAddress(address, num_to_wake);
} }
// Signals an address being waited on and modifies its value based on waiting thread count if equal ResultCode AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value,
// to the value argument.
ResultCode ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value,
s32 num_to_wake) { s32 num_to_wake) {
// Ensure that we can write to the address. // Ensure that we can write to the address.
if (!Memory::IsValidVirtualAddress(address)) { if (!Memory::IsValidVirtualAddress(address)) {
@ -135,8 +94,8 @@ ResultCode ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 valu
return RESULT_SUCCESS; return RESULT_SUCCESS;
} }
// Waits on an address if the value passed is less than the argument value, optionally decrementing. ResultCode AddressArbiter::WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout,
ResultCode WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout, bool should_decrement) { bool should_decrement) {
// Ensure that we can read the address. // Ensure that we can read the address.
if (!Memory::IsValidVirtualAddress(address)) { if (!Memory::IsValidVirtualAddress(address)) {
return ERR_INVALID_ADDRESS_STATE; return ERR_INVALID_ADDRESS_STATE;
@ -158,8 +117,7 @@ ResultCode WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout, bool
return WaitForAddress(address, timeout); return WaitForAddress(address, timeout);
} }
// Waits on an address if the value passed is equal to the argument value. ResultCode AddressArbiter::WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout) {
ResultCode WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout) {
// Ensure that we can read the address. // Ensure that we can read the address.
if (!Memory::IsValidVirtualAddress(address)) { if (!Memory::IsValidVirtualAddress(address)) {
return ERR_INVALID_ADDRESS_STATE; return ERR_INVALID_ADDRESS_STATE;
@ -175,4 +133,45 @@ ResultCode WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout) {
return WaitForAddress(address, timeout); return WaitForAddress(address, timeout);
} }
} // namespace Kernel::AddressArbiter
ResultCode AddressArbiter::WaitForAddress(VAddr address, s64 timeout) {
SharedPtr<Thread> current_thread = GetCurrentThread();
current_thread->SetArbiterWaitAddress(address);
current_thread->SetStatus(ThreadStatus::WaitArb);
current_thread->InvalidateWakeupCallback();
current_thread->WakeAfterDelay(timeout);
Core::System::GetInstance().CpuCore(current_thread->GetProcessorID()).PrepareReschedule();
return RESULT_TIMEOUT;
}
std::vector<SharedPtr<Thread>> AddressArbiter::GetThreadsWaitingOnAddress(VAddr address) const {
const auto RetrieveWaitingThreads = [](std::size_t core_index,
std::vector<SharedPtr<Thread>>& waiting_threads,
VAddr arb_addr) {
const auto& scheduler = Core::System::GetInstance().Scheduler(core_index);
const auto& thread_list = scheduler.GetThreadList();
for (const auto& thread : thread_list) {
if (thread->GetArbiterWaitAddress() == arb_addr)
waiting_threads.push_back(thread);
}
};
// Retrieve all threads that are waiting for this address.
std::vector<SharedPtr<Thread>> threads;
RetrieveWaitingThreads(0, threads, address);
RetrieveWaitingThreads(1, threads, address);
RetrieveWaitingThreads(2, threads, address);
RetrieveWaitingThreads(3, threads, address);
// Sort them by priority, such that the highest priority ones come first.
std::sort(threads.begin(), threads.end(),
[](const SharedPtr<Thread>& lhs, const SharedPtr<Thread>& rhs) {
return lhs->GetPriority() < rhs->GetPriority();
});
return threads;
}
} // namespace Kernel

View File

@ -5,11 +5,16 @@
#pragma once #pragma once
#include "common/common_types.h" #include "common/common_types.h"
#include "core/hle/kernel/address_arbiter.h"
union ResultCode; union ResultCode;
namespace Kernel::AddressArbiter { namespace Kernel {
class Thread;
class AddressArbiter {
public:
enum class ArbitrationType { enum class ArbitrationType {
WaitIfLessThan = 0, WaitIfLessThan = 0,
DecrementAndWaitIfLessThan = 1, DecrementAndWaitIfLessThan = 1,
@ -22,11 +27,40 @@ enum class SignalType {
ModifyByWaitingCountAndSignalIfEqual = 2, ModifyByWaitingCountAndSignalIfEqual = 2,
}; };
ResultCode SignalToAddress(VAddr address, s32 num_to_wake); AddressArbiter();
ResultCode IncrementAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake); ~AddressArbiter();
ResultCode ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake);
ResultCode WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout, bool should_decrement); AddressArbiter(const AddressArbiter&) = delete;
AddressArbiter& operator=(const AddressArbiter&) = delete;
AddressArbiter(AddressArbiter&&) = default;
AddressArbiter& operator=(AddressArbiter&&) = delete;
/// Signals an address being waited on.
ResultCode SignalToAddress(VAddr address, s32 num_to_wake);
/// Signals an address being waited on and increments its value if equal to the value argument.
ResultCode IncrementAndSignalToAddressIfEqual(VAddr address, s32 value, s32 num_to_wake);
/// Signals an address being waited on and modifies its value based on waiting thread count if
/// equal to the value argument.
ResultCode ModifyByWaitingCountAndSignalToAddressIfEqual(VAddr address, s32 value,
s32 num_to_wake);
/// Waits on an address if the value passed is less than the argument value,
/// optionally decrementing.
ResultCode WaitForAddressIfLessThan(VAddr address, s32 value, s64 timeout,
bool should_decrement);
/// Waits on an address if the value passed is equal to the argument value.
ResultCode WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout); ResultCode WaitForAddressIfEqual(VAddr address, s32 value, s64 timeout);
} // namespace Kernel::AddressArbiter private:
// Waits on the given address with a timeout in nanoseconds
ResultCode WaitForAddress(VAddr address, s64 timeout);
// Gets the threads waiting on an address.
std::vector<SharedPtr<Thread>> GetThreadsWaitingOnAddress(VAddr address) const;
};
} // namespace Kernel

View File

@ -12,6 +12,7 @@
#include "core/core.h" #include "core/core.h"
#include "core/core_timing.h" #include "core/core_timing.h"
#include "core/hle/kernel/address_arbiter.h"
#include "core/hle/kernel/client_port.h" #include "core/hle/kernel/client_port.h"
#include "core/hle/kernel/handle_table.h" #include "core/hle/kernel/handle_table.h"
#include "core/hle/kernel/kernel.h" #include "core/hle/kernel/kernel.h"
@ -135,6 +136,8 @@ struct KernelCore::Impl {
std::vector<SharedPtr<Process>> process_list; std::vector<SharedPtr<Process>> process_list;
Process* current_process = nullptr; Process* current_process = nullptr;
Kernel::AddressArbiter address_arbiter;
SharedPtr<ResourceLimit> system_resource_limit; SharedPtr<ResourceLimit> system_resource_limit;
Core::Timing::EventType* thread_wakeup_event_type = nullptr; Core::Timing::EventType* thread_wakeup_event_type = nullptr;
@ -184,6 +187,14 @@ const Process* KernelCore::CurrentProcess() const {
return impl->current_process; return impl->current_process;
} }
AddressArbiter& KernelCore::AddressArbiter() {
return impl->address_arbiter;
}
const AddressArbiter& KernelCore::AddressArbiter() const {
return impl->address_arbiter;
}
void KernelCore::AddNamedPort(std::string name, SharedPtr<ClientPort> port) { void KernelCore::AddNamedPort(std::string name, SharedPtr<ClientPort> port) {
impl->named_ports.emplace(std::move(name), std::move(port)); impl->named_ports.emplace(std::move(name), std::move(port));
} }

View File

@ -18,6 +18,7 @@ struct EventType;
namespace Kernel { namespace Kernel {
class AddressArbiter;
class ClientPort; class ClientPort;
class HandleTable; class HandleTable;
class Process; class Process;
@ -67,6 +68,12 @@ public:
/// Retrieves a const pointer to the current process. /// Retrieves a const pointer to the current process.
const Process* CurrentProcess() const; const Process* CurrentProcess() const;
/// Provides a reference to the kernel's address arbiter.
Kernel::AddressArbiter& AddressArbiter();
/// Provides a const reference to the kernel's address arbiter.
const Kernel::AddressArbiter& AddressArbiter() const;
/// Adds a port to the named port table /// Adds a port to the named port table
void AddNamedPort(std::string name, SharedPtr<ClientPort> port); void AddNamedPort(std::string name, SharedPtr<ClientPort> port);

View File

@ -1495,13 +1495,14 @@ static ResultCode WaitForAddress(VAddr address, u32 type, s32 value, s64 timeout
return ERR_INVALID_ADDRESS; return ERR_INVALID_ADDRESS;
} }
auto& address_arbiter = Core::System::GetInstance().Kernel().AddressArbiter();
switch (static_cast<AddressArbiter::ArbitrationType>(type)) { switch (static_cast<AddressArbiter::ArbitrationType>(type)) {
case AddressArbiter::ArbitrationType::WaitIfLessThan: case AddressArbiter::ArbitrationType::WaitIfLessThan:
return AddressArbiter::WaitForAddressIfLessThan(address, value, timeout, false); return address_arbiter.WaitForAddressIfLessThan(address, value, timeout, false);
case AddressArbiter::ArbitrationType::DecrementAndWaitIfLessThan: case AddressArbiter::ArbitrationType::DecrementAndWaitIfLessThan:
return AddressArbiter::WaitForAddressIfLessThan(address, value, timeout, true); return address_arbiter.WaitForAddressIfLessThan(address, value, timeout, true);
case AddressArbiter::ArbitrationType::WaitIfEqual: case AddressArbiter::ArbitrationType::WaitIfEqual:
return AddressArbiter::WaitForAddressIfEqual(address, value, timeout); return address_arbiter.WaitForAddressIfEqual(address, value, timeout);
default: default:
LOG_ERROR(Kernel_SVC, LOG_ERROR(Kernel_SVC,
"Invalid arbitration type, expected WaitIfLessThan, DecrementAndWaitIfLessThan " "Invalid arbitration type, expected WaitIfLessThan, DecrementAndWaitIfLessThan "
@ -1526,13 +1527,14 @@ static ResultCode SignalToAddress(VAddr address, u32 type, s32 value, s32 num_to
return ERR_INVALID_ADDRESS; return ERR_INVALID_ADDRESS;
} }
auto& address_arbiter = Core::System::GetInstance().Kernel().AddressArbiter();
switch (static_cast<AddressArbiter::SignalType>(type)) { switch (static_cast<AddressArbiter::SignalType>(type)) {
case AddressArbiter::SignalType::Signal: case AddressArbiter::SignalType::Signal:
return AddressArbiter::SignalToAddress(address, num_to_wake); return address_arbiter.SignalToAddress(address, num_to_wake);
case AddressArbiter::SignalType::IncrementAndSignalIfEqual: case AddressArbiter::SignalType::IncrementAndSignalIfEqual:
return AddressArbiter::IncrementAndSignalToAddressIfEqual(address, value, num_to_wake); return address_arbiter.IncrementAndSignalToAddressIfEqual(address, value, num_to_wake);
case AddressArbiter::SignalType::ModifyByWaitingCountAndSignalIfEqual: case AddressArbiter::SignalType::ModifyByWaitingCountAndSignalIfEqual:
return AddressArbiter::ModifyByWaitingCountAndSignalToAddressIfEqual(address, value, return address_arbiter.ModifyByWaitingCountAndSignalToAddressIfEqual(address, value,
num_to_wake); num_to_wake);
default: default:
LOG_ERROR(Kernel_SVC, LOG_ERROR(Kernel_SVC,