#pragma once #include "catch.hpp" #include #include #include #if WINAPI_FAMILY_PARTITION(WINAPI_PARTITION_DESKTOP | WINAPI_PARTITION_SYSTEM) // IMallocSpy requires you to implement all methods, but we often only want one or two... struct MallocSpy : Microsoft::WRL::RuntimeClass, IMallocSpy> { wistd::function PreAllocCallback; virtual SIZE_T STDMETHODCALLTYPE PreAlloc(SIZE_T requestSize) override { if (PreAllocCallback) { return PreAllocCallback(requestSize); } return requestSize; } wistd::function PostAllocCallback; virtual void* STDMETHODCALLTYPE PostAlloc(void* ptr) override { if (PostAllocCallback) { return PostAllocCallback(ptr); } return ptr; } wistd::function PreFreeCallback; virtual void* STDMETHODCALLTYPE PreFree(void* ptr, BOOL wasSpyed) override { if (wasSpyed && PreFreeCallback) { return PreFreeCallback(ptr); } return ptr; } virtual void STDMETHODCALLTYPE PostFree(BOOL /*wasSpyed*/) override { } wistd::function PreReallocCallback; virtual SIZE_T STDMETHODCALLTYPE PreRealloc(void* ptr, SIZE_T requestSize, void** newPtr, BOOL wasSpyed) override { *newPtr = ptr; if (wasSpyed && PreReallocCallback) { return PreReallocCallback(ptr, requestSize, newPtr); } return requestSize; } wistd::function PostReallocCallback; virtual void* STDMETHODCALLTYPE PostRealloc(void* ptr, BOOL wasSpyed) override { if (wasSpyed && PostReallocCallback) { return PostReallocCallback(ptr); } return ptr; } wistd::function PreGetSizeCallback; virtual void* STDMETHODCALLTYPE PreGetSize(void* ptr, BOOL wasSpyed) override { if (wasSpyed && PreGetSizeCallback) { return PreGetSizeCallback(ptr); } return ptr; } wistd::function PostGetSizeCallback; virtual SIZE_T STDMETHODCALLTYPE PostGetSize(SIZE_T size, BOOL wasSpyed) override { if (wasSpyed && PostGetSizeCallback) { return PostGetSizeCallback(size); } return size; } wistd::function PreDidAllocCallback; virtual void* STDMETHODCALLTYPE PreDidAlloc(void* ptr, BOOL wasSpyed) override { if (wasSpyed && PreDidAllocCallback) { return PreDidAllocCallback(ptr); } return ptr; } virtual int STDMETHODCALLTYPE PostDidAlloc(void* /*ptr*/, BOOL /*wasSpyed*/, int result) override { return result; } virtual void STDMETHODCALLTYPE PreHeapMinimize() override { } virtual void STDMETHODCALLTYPE PostHeapMinimize() override { } }; Microsoft::WRL::ComPtr MakeSecureDeleterMallocSpy() { using namespace Microsoft::WRL; auto result = Make(); REQUIRE(result); result->PreFreeCallback = [](void* ptr) { ComPtr malloc; if (SUCCEEDED(::CoGetMalloc(1, &malloc))) { auto size = malloc->GetSize(ptr); auto buffer = static_cast(ptr); for (size_t i = 0; i < size; ++i) { REQUIRE(buffer[i] == 0); } } return ptr; }; return result; } #endif