diff --git a/common/D3D12/Builders.cpp b/common/D3D12/Builders.cpp index b48c2d6677..7e0431f5f4 100644 --- a/common/D3D12/Builders.cpp +++ b/common/D3D12/Builders.cpp @@ -223,6 +223,56 @@ void GraphicsPipelineBuilder::SetDepthStencilFormat(DXGI_FORMAT format) m_desc.DSVFormat = format; } + +ComputePipelineBuilder::ComputePipelineBuilder() +{ + Clear(); +} + +void ComputePipelineBuilder::Clear() +{ + std::memset(&m_desc, 0, sizeof(m_desc)); +} + +wil::com_ptr_nothrow ComputePipelineBuilder::Create(ID3D12Device* device, bool clear /*= true*/) +{ + wil::com_ptr_nothrow ps; + HRESULT hr = device->CreateComputePipelineState(&m_desc, IID_PPV_ARGS(ps.put())); + if (FAILED(hr)) + { + Console.Error("CreateComputePipelineState() failed: %08X", hr); + return {}; + } + + if (clear) + Clear(); + + return ps; +} + +wil::com_ptr_nothrow ComputePipelineBuilder::Create(ID3D12Device* device, ShaderCache& cache, bool clear /*= true*/) +{ + wil::com_ptr_nothrow pso = cache.GetPipelineState(device, m_desc); + if (!pso) + return {}; + + if (clear) + Clear(); + + return pso; +} + +void ComputePipelineBuilder::SetRootSignature(ID3D12RootSignature* rs) +{ + m_desc.pRootSignature = rs; +} + +void ComputePipelineBuilder::SetShader(const void* data, u32 data_size) +{ + m_desc.CS.pShaderBytecode = data; + m_desc.CS.BytecodeLength = data_size; +} + RootSignatureBuilder::RootSignatureBuilder() { Clear(); diff --git a/common/D3D12/Builders.h b/common/D3D12/Builders.h index fe111a7358..625b1c0706 100644 --- a/common/D3D12/Builders.h +++ b/common/D3D12/Builders.h @@ -114,8 +114,26 @@ namespace D3D12 void SetDepthStencilFormat(DXGI_FORMAT format); private: - D3D12_GRAPHICS_PIPELINE_STATE_DESC m_desc{}; - std::array m_input_elements{}; + D3D12_GRAPHICS_PIPELINE_STATE_DESC m_desc; + std::array m_input_elements; + }; + + class ComputePipelineBuilder + { + public: + ComputePipelineBuilder(); + ~ComputePipelineBuilder() = default; + + void Clear(); + + wil::com_ptr_nothrow Create(ID3D12Device* device, bool clear = true); + wil::com_ptr_nothrow Create(ID3D12Device* device, ShaderCache& cache, bool clear = true); + + void SetRootSignature(ID3D12RootSignature* rs); + + void SetShader(const void* data, u32 data_size); + private: + D3D12_COMPUTE_PIPELINE_STATE_DESC m_desc; }; } // namespace D3D12 \ No newline at end of file diff --git a/common/D3D12/ShaderCache.cpp b/common/D3D12/ShaderCache.cpp index 72f694d016..f145ab6e08 100644 --- a/common/D3D12/ShaderCache.cpp +++ b/common/D3D12/ShaderCache.cpp @@ -387,6 +387,23 @@ ShaderCache::CacheIndexKey ShaderCache::GetPipelineCacheKey(const D3D12_GRAPHICS return CacheIndexKey{h.low, h.high, 0, 0, 0, 0, length, EntryType::GraphicsPipeline}; } +ShaderCache::CacheIndexKey ShaderCache::GetPipelineCacheKey(const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc) +{ + MD5Digest digest; + u32 length = sizeof(D3D12_GRAPHICS_PIPELINE_STATE_DESC); + + if (gpdesc.CS.BytecodeLength > 0) + { + digest.Update(gpdesc.CS.pShaderBytecode, static_cast(gpdesc.CS.BytecodeLength)); + length += static_cast(gpdesc.CS.BytecodeLength); + } + + MD5Hash h; + digest.Final(h.hash); + + return CacheIndexKey{h.low, h.high, 0, 0, 0, 0, length, EntryType::ComputePipeline}; +} + ShaderCache::ComPtr ShaderCache::GetShaderBlob(EntryType type, std::string_view shader_code, const D3D_SHADER_MACRO* macros /* = nullptr */, const char* entry_point /* = "main" */) { @@ -441,6 +458,40 @@ ShaderCache::ComPtr ShaderCache::GetPipelineState(ID3D12Dev return pso; } +ShaderCache::ComPtr ShaderCache::GetPipelineState(ID3D12Device* device, + const D3D12_COMPUTE_PIPELINE_STATE_DESC& desc) +{ + const auto key = GetPipelineCacheKey(desc); + + auto iter = m_pipeline_index.find(key); + if (iter == m_pipeline_index.end()) + return CompileAndAddPipeline(device, key, desc); + + ComPtr blob; + HRESULT hr = D3DCreateBlob(iter->second.blob_size, blob.put()); + if (FAILED(hr) || std::fseek(m_pipeline_blob_file, iter->second.file_offset, SEEK_SET) != 0 || + std::fread(blob->GetBufferPointer(), 1, iter->second.blob_size, m_pipeline_blob_file) != iter->second.blob_size) + { + Console.Error("Read blob from file failed"); + return {}; + } + + D3D12_COMPUTE_PIPELINE_STATE_DESC desc_with_blob(desc); + desc_with_blob.CachedPSO.pCachedBlob = blob->GetBufferPointer(); + desc_with_blob.CachedPSO.CachedBlobSizeInBytes = blob->GetBufferSize(); + + ComPtr pso; + hr = device->CreateComputePipelineState(&desc_with_blob, IID_PPV_ARGS(pso.put())); + if (FAILED(hr)) + { + Console.Warning("Creating cached PSO failed: %08X. Invalidating cache.", hr); + InvalidatePipelineCache(); + pso = CompileAndAddPipeline(device, key, desc); + } + + return pso; +} + ShaderCache::ComPtr ShaderCache::CompileAndAddShaderBlob(const CacheIndexKey& key, std::string_view shader_code, const D3D_SHADER_MACRO* macros, const char* entry_point) { @@ -457,6 +508,9 @@ ShaderCache::ComPtr ShaderCache::CompileAndAddShaderBlob(const CacheIn case EntryType::PixelShader: blob = D3D11::ShaderCompiler::CompileShader(D3D11::ShaderCompiler::Type::Pixel, m_feature_level, m_debug, shader_code, macros, entry_point); break; + case EntryType::ComputeShader: + blob = D3D11::ShaderCompiler::CompileShader(D3D11::ShaderCompiler::Type::Compute, m_feature_level, m_debug, shader_code, macros, entry_point); + break; default: break; } @@ -507,15 +561,37 @@ ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& ke return {}; } + AddPipelineToBlob(key, pso.get()); + return pso; +} + +ShaderCache::ComPtr +ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key, + const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc) +{ + ComPtr pso; + HRESULT hr = device->CreateComputePipelineState(&gpdesc, IID_PPV_ARGS(pso.put())); + if (FAILED(hr)) + { + Console.Error("Creating cached compute PSO failed: %08X", hr); + return {}; + } + + AddPipelineToBlob(key, pso.get()); + return pso; +} + +bool ShaderCache::AddPipelineToBlob(const CacheIndexKey& key, ID3D12PipelineState* pso) +{ if (!m_pipeline_blob_file || std::fseek(m_pipeline_blob_file, 0, SEEK_END) != 0) - return pso; + return false; ComPtr blob; - hr = pso->GetCachedBlob(blob.put()); + HRESULT hr = pso->GetCachedBlob(blob.put()); if (FAILED(hr)) { Console.Warning("Failed to get cached PSO data: %08X", hr); - return pso; + return false; } CacheIndexData data; @@ -535,9 +611,9 @@ ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& ke std::fflush(m_pipeline_index_file) != 0) { Console.Error("Failed to write pipeline blob to file"); - return pso; + return false; } m_shader_index.emplace(key, data); - return pso; + return true; } diff --git a/common/D3D12/ShaderCache.h b/common/D3D12/ShaderCache.h index e033d9524a..48f5b9afa9 100644 --- a/common/D3D12/ShaderCache.h +++ b/common/D3D12/ShaderCache.h @@ -41,6 +41,7 @@ namespace D3D12 PixelShader, ComputeShader, GraphicsPipeline, + ComputePipeline, }; ShaderCache(); @@ -77,6 +78,7 @@ namespace D3D12 const D3D_SHADER_MACRO* macros = nullptr, const char* entry_point = "main"); ComPtr GetPipelineState(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc); + ComPtr GetPipelineState(ID3D12Device* device, const D3D12_COMPUTE_PIPELINE_STATE_DESC& desc); private: static constexpr u32 FILE_VERSION = 1; @@ -120,6 +122,7 @@ namespace D3D12 static CacheIndexKey GetShaderCacheKey(EntryType type, const std::string_view& shader_code, const D3D_SHADER_MACRO* macros, const char* entry_point); static CacheIndexKey GetPipelineCacheKey(const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc); + static CacheIndexKey GetPipelineCacheKey(const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc); bool CreateNew(const std::string& index_filename, const std::string& blob_filename, std::FILE*& index_file, std::FILE*& blob_file); @@ -132,6 +135,9 @@ namespace D3D12 const D3D_SHADER_MACRO* macros, const char* entry_point); ComPtr CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc); + ComPtr CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key, + const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc); + bool AddPipelineToBlob(const CacheIndexKey& key, ID3D12PipelineState* pso); std::string m_base_path;