From 22310dfc3e895ce271748e566405d8d394f28e4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?L=C3=A9o=20Lam?= Date: Fri, 3 Nov 2017 12:23:57 +0100 Subject: [PATCH] USBv5: Read transfer parameters from the correct vector This is why static analysis is essential. --- Source/Core/Core/IOS/Device.cpp | 9 +++++++++ Source/Core/Core/IOS/Device.h | 1 + Source/Core/Core/IOS/USB/USBV5.cpp | 19 +++++++++---------- 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/Source/Core/Core/IOS/Device.cpp b/Source/Core/Core/IOS/Device.cpp index c5da4c662c..467ce5cfd1 100644 --- a/Source/Core/Core/IOS/Device.cpp +++ b/Source/Core/Core/IOS/Device.cpp @@ -7,6 +7,7 @@ #include #include +#include "Common/Assert.h" #include "Common/Logging/Log.h" #include "Common/StringUtil.h" #include "Core/HW/Memmap.h" @@ -77,6 +78,14 @@ IOCtlVRequest::IOCtlVRequest(const u32 address_) : Request(address_) } } +const IOCtlVRequest::IOVector* IOCtlVRequest::GetVector(size_t index) const +{ + _assert_(index < (in_vectors.size() + io_vectors.size())); + if (index < in_vectors.size()) + return &in_vectors[index]; + return &io_vectors[index - in_vectors.size()]; +} + bool IOCtlVRequest::HasNumberOfValidVectors(const size_t in_count, const size_t io_count) const { if (in_vectors.size() != in_count || io_vectors.size() != io_count) diff --git a/Source/Core/Core/IOS/Device.h b/Source/Core/Core/IOS/Device.h index e9191f163c..0723972be4 100644 --- a/Source/Core/Core/IOS/Device.h +++ b/Source/Core/Core/IOS/Device.h @@ -156,6 +156,7 @@ struct IOCtlVRequest final : Request // merging them into a single std::vector would make using the first out vector more complicated. std::vector in_vectors; std::vector io_vectors; + const IOVector* GetVector(size_t index) const; explicit IOCtlVRequest(u32 address); bool HasNumberOfValidVectors(size_t in_count, size_t io_count) const; void Dump(const std::string& description, LogTypes::LOG_TYPE type = LogTypes::IOS, diff --git a/Source/Core/Core/IOS/USB/USBV5.cpp b/Source/Core/Core/IOS/USB/USBV5.cpp index 804989b75a..424a3dc98f 100644 --- a/Source/Core/Core/IOS/USB/USBV5.cpp +++ b/Source/Core/Core/IOS/USB/USBV5.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include "Common/ChunkFile.h" #include "Common/Logging/Log.h" @@ -22,38 +21,38 @@ namespace HLE namespace USB { V5CtrlMessage::V5CtrlMessage(Kernel& ios, const IOCtlVRequest& ioctlv) - : CtrlMessage(ios, ioctlv, Memory::Read_U32(ioctlv.in_vectors[0].address + 16)) + : CtrlMessage(ios, ioctlv, ioctlv.GetVector(1)->address) { request_type = Memory::Read_U8(ioctlv.in_vectors[0].address + 8); request = Memory::Read_U8(ioctlv.in_vectors[0].address + 9); value = Memory::Read_U16(ioctlv.in_vectors[0].address + 10); index = Memory::Read_U16(ioctlv.in_vectors[0].address + 12); - length = Memory::Read_U16(ioctlv.in_vectors[0].address + 14); + length = static_cast(ioctlv.GetVector(1)->size); } V5BulkMessage::V5BulkMessage(Kernel& ios, const IOCtlVRequest& ioctlv) - : BulkMessage(ios, ioctlv, Memory::Read_U32(ioctlv.in_vectors[0].address + 8)) + : BulkMessage(ios, ioctlv, ioctlv.GetVector(1)->address) { - length = Memory::Read_U16(ioctlv.in_vectors[0].address + 12); + length = static_cast(ioctlv.GetVector(1)->size); endpoint = Memory::Read_U8(ioctlv.in_vectors[0].address + 18); } V5IntrMessage::V5IntrMessage(Kernel& ios, const IOCtlVRequest& ioctlv) - : IntrMessage(ios, ioctlv, Memory::Read_U32(ioctlv.in_vectors[0].address + 8)) + : IntrMessage(ios, ioctlv, ioctlv.GetVector(1)->address) { - length = Memory::Read_U16(ioctlv.in_vectors[0].address + 12); + length = static_cast(ioctlv.GetVector(1)->size); endpoint = Memory::Read_U8(ioctlv.in_vectors[0].address + 14); } V5IsoMessage::V5IsoMessage(Kernel& ios, const IOCtlVRequest& ioctlv) - : IsoMessage(ios, ioctlv, Memory::Read_U32(ioctlv.in_vectors[0].address + 8)) + : IsoMessage(ios, ioctlv, ioctlv.GetVector(2)->address) { num_packets = Memory::Read_U8(ioctlv.in_vectors[0].address + 16); endpoint = Memory::Read_U8(ioctlv.in_vectors[0].address + 17); - packet_sizes_addr = Memory::Read_U32(ioctlv.in_vectors[0].address + 12); + packet_sizes_addr = ioctlv.GetVector(1)->address; for (size_t i = 0; i < num_packets; ++i) packet_sizes.push_back(Memory::Read_U16(static_cast(packet_sizes_addr + i * sizeof(u16)))); - length = std::accumulate(packet_sizes.begin(), packet_sizes.end(), 0); + length = static_cast(ioctlv.GetVector(2)->size); } } // namespace USB