D3D12: Add compute shader support

This commit is contained in:
Connor McLaughlin 2022-11-21 18:46:09 +10:00 committed by refractionpcsx2
parent d132ddefef
commit 5363a90c6b
4 changed files with 157 additions and 7 deletions

View File

@ -223,6 +223,56 @@ void GraphicsPipelineBuilder::SetDepthStencilFormat(DXGI_FORMAT format)
m_desc.DSVFormat = format; m_desc.DSVFormat = format;
} }
ComputePipelineBuilder::ComputePipelineBuilder()
{
Clear();
}
void ComputePipelineBuilder::Clear()
{
std::memset(&m_desc, 0, sizeof(m_desc));
}
wil::com_ptr_nothrow<ID3D12PipelineState> ComputePipelineBuilder::Create(ID3D12Device* device, bool clear /*= true*/)
{
wil::com_ptr_nothrow<ID3D12PipelineState> ps;
HRESULT hr = device->CreateComputePipelineState(&m_desc, IID_PPV_ARGS(ps.put()));
if (FAILED(hr))
{
Console.Error("CreateComputePipelineState() failed: %08X", hr);
return {};
}
if (clear)
Clear();
return ps;
}
wil::com_ptr_nothrow<ID3D12PipelineState> ComputePipelineBuilder::Create(ID3D12Device* device, ShaderCache& cache, bool clear /*= true*/)
{
wil::com_ptr_nothrow<ID3D12PipelineState> pso = cache.GetPipelineState(device, m_desc);
if (!pso)
return {};
if (clear)
Clear();
return pso;
}
void ComputePipelineBuilder::SetRootSignature(ID3D12RootSignature* rs)
{
m_desc.pRootSignature = rs;
}
void ComputePipelineBuilder::SetShader(const void* data, u32 data_size)
{
m_desc.CS.pShaderBytecode = data;
m_desc.CS.BytecodeLength = data_size;
}
RootSignatureBuilder::RootSignatureBuilder() RootSignatureBuilder::RootSignatureBuilder()
{ {
Clear(); Clear();

View File

@ -114,8 +114,26 @@ namespace D3D12
void SetDepthStencilFormat(DXGI_FORMAT format); void SetDepthStencilFormat(DXGI_FORMAT format);
private: private:
D3D12_GRAPHICS_PIPELINE_STATE_DESC m_desc{}; D3D12_GRAPHICS_PIPELINE_STATE_DESC m_desc;
std::array<D3D12_INPUT_ELEMENT_DESC, MAX_VERTEX_ATTRIBUTES> m_input_elements{}; std::array<D3D12_INPUT_ELEMENT_DESC, MAX_VERTEX_ATTRIBUTES> m_input_elements;
};
class ComputePipelineBuilder
{
public:
ComputePipelineBuilder();
~ComputePipelineBuilder() = default;
void Clear();
wil::com_ptr_nothrow<ID3D12PipelineState> Create(ID3D12Device* device, bool clear = true);
wil::com_ptr_nothrow<ID3D12PipelineState> Create(ID3D12Device* device, ShaderCache& cache, bool clear = true);
void SetRootSignature(ID3D12RootSignature* rs);
void SetShader(const void* data, u32 data_size);
private:
D3D12_COMPUTE_PIPELINE_STATE_DESC m_desc;
}; };
} // namespace D3D12 } // namespace D3D12

View File

@ -387,6 +387,23 @@ ShaderCache::CacheIndexKey ShaderCache::GetPipelineCacheKey(const D3D12_GRAPHICS
return CacheIndexKey{h.low, h.high, 0, 0, 0, 0, length, EntryType::GraphicsPipeline}; return CacheIndexKey{h.low, h.high, 0, 0, 0, 0, length, EntryType::GraphicsPipeline};
} }
ShaderCache::CacheIndexKey ShaderCache::GetPipelineCacheKey(const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc)
{
MD5Digest digest;
u32 length = sizeof(D3D12_GRAPHICS_PIPELINE_STATE_DESC);
if (gpdesc.CS.BytecodeLength > 0)
{
digest.Update(gpdesc.CS.pShaderBytecode, static_cast<u32>(gpdesc.CS.BytecodeLength));
length += static_cast<u32>(gpdesc.CS.BytecodeLength);
}
MD5Hash h;
digest.Final(h.hash);
return CacheIndexKey{h.low, h.high, 0, 0, 0, 0, length, EntryType::ComputePipeline};
}
ShaderCache::ComPtr<ID3DBlob> ShaderCache::GetShaderBlob(EntryType type, std::string_view shader_code, ShaderCache::ComPtr<ID3DBlob> ShaderCache::GetShaderBlob(EntryType type, std::string_view shader_code,
const D3D_SHADER_MACRO* macros /* = nullptr */, const char* entry_point /* = "main" */) const D3D_SHADER_MACRO* macros /* = nullptr */, const char* entry_point /* = "main" */)
{ {
@ -441,6 +458,40 @@ ShaderCache::ComPtr<ID3D12PipelineState> ShaderCache::GetPipelineState(ID3D12Dev
return pso; return pso;
} }
ShaderCache::ComPtr<ID3D12PipelineState> ShaderCache::GetPipelineState(ID3D12Device* device,
const D3D12_COMPUTE_PIPELINE_STATE_DESC& desc)
{
const auto key = GetPipelineCacheKey(desc);
auto iter = m_pipeline_index.find(key);
if (iter == m_pipeline_index.end())
return CompileAndAddPipeline(device, key, desc);
ComPtr<ID3DBlob> blob;
HRESULT hr = D3DCreateBlob(iter->second.blob_size, blob.put());
if (FAILED(hr) || std::fseek(m_pipeline_blob_file, iter->second.file_offset, SEEK_SET) != 0 ||
std::fread(blob->GetBufferPointer(), 1, iter->second.blob_size, m_pipeline_blob_file) != iter->second.blob_size)
{
Console.Error("Read blob from file failed");
return {};
}
D3D12_COMPUTE_PIPELINE_STATE_DESC desc_with_blob(desc);
desc_with_blob.CachedPSO.pCachedBlob = blob->GetBufferPointer();
desc_with_blob.CachedPSO.CachedBlobSizeInBytes = blob->GetBufferSize();
ComPtr<ID3D12PipelineState> pso;
hr = device->CreateComputePipelineState(&desc_with_blob, IID_PPV_ARGS(pso.put()));
if (FAILED(hr))
{
Console.Warning("Creating cached PSO failed: %08X. Invalidating cache.", hr);
InvalidatePipelineCache();
pso = CompileAndAddPipeline(device, key, desc);
}
return pso;
}
ShaderCache::ComPtr<ID3DBlob> ShaderCache::CompileAndAddShaderBlob(const CacheIndexKey& key, std::string_view shader_code, ShaderCache::ComPtr<ID3DBlob> ShaderCache::CompileAndAddShaderBlob(const CacheIndexKey& key, std::string_view shader_code,
const D3D_SHADER_MACRO* macros, const char* entry_point) const D3D_SHADER_MACRO* macros, const char* entry_point)
{ {
@ -457,6 +508,9 @@ ShaderCache::ComPtr<ID3DBlob> ShaderCache::CompileAndAddShaderBlob(const CacheIn
case EntryType::PixelShader: case EntryType::PixelShader:
blob = D3D11::ShaderCompiler::CompileShader(D3D11::ShaderCompiler::Type::Pixel, m_feature_level, m_debug, shader_code, macros, entry_point); blob = D3D11::ShaderCompiler::CompileShader(D3D11::ShaderCompiler::Type::Pixel, m_feature_level, m_debug, shader_code, macros, entry_point);
break; break;
case EntryType::ComputeShader:
blob = D3D11::ShaderCompiler::CompileShader(D3D11::ShaderCompiler::Type::Compute, m_feature_level, m_debug, shader_code, macros, entry_point);
break;
default: default:
break; break;
} }
@ -507,15 +561,37 @@ ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& ke
return {}; return {};
} }
AddPipelineToBlob(key, pso.get());
return pso;
}
ShaderCache::ComPtr<ID3D12PipelineState>
ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key,
const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc)
{
ComPtr<ID3D12PipelineState> pso;
HRESULT hr = device->CreateComputePipelineState(&gpdesc, IID_PPV_ARGS(pso.put()));
if (FAILED(hr))
{
Console.Error("Creating cached compute PSO failed: %08X", hr);
return {};
}
AddPipelineToBlob(key, pso.get());
return pso;
}
bool ShaderCache::AddPipelineToBlob(const CacheIndexKey& key, ID3D12PipelineState* pso)
{
if (!m_pipeline_blob_file || std::fseek(m_pipeline_blob_file, 0, SEEK_END) != 0) if (!m_pipeline_blob_file || std::fseek(m_pipeline_blob_file, 0, SEEK_END) != 0)
return pso; return false;
ComPtr<ID3DBlob> blob; ComPtr<ID3DBlob> blob;
hr = pso->GetCachedBlob(blob.put()); HRESULT hr = pso->GetCachedBlob(blob.put());
if (FAILED(hr)) if (FAILED(hr))
{ {
Console.Warning("Failed to get cached PSO data: %08X", hr); Console.Warning("Failed to get cached PSO data: %08X", hr);
return pso; return false;
} }
CacheIndexData data; CacheIndexData data;
@ -535,9 +611,9 @@ ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& ke
std::fflush(m_pipeline_index_file) != 0) std::fflush(m_pipeline_index_file) != 0)
{ {
Console.Error("Failed to write pipeline blob to file"); Console.Error("Failed to write pipeline blob to file");
return pso; return false;
} }
m_shader_index.emplace(key, data); m_shader_index.emplace(key, data);
return pso; return true;
} }

View File

@ -41,6 +41,7 @@ namespace D3D12
PixelShader, PixelShader,
ComputeShader, ComputeShader,
GraphicsPipeline, GraphicsPipeline,
ComputePipeline,
}; };
ShaderCache(); ShaderCache();
@ -77,6 +78,7 @@ namespace D3D12
const D3D_SHADER_MACRO* macros = nullptr, const char* entry_point = "main"); const D3D_SHADER_MACRO* macros = nullptr, const char* entry_point = "main");
ComPtr<ID3D12PipelineState> GetPipelineState(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc); ComPtr<ID3D12PipelineState> GetPipelineState(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc);
ComPtr<ID3D12PipelineState> GetPipelineState(ID3D12Device* device, const D3D12_COMPUTE_PIPELINE_STATE_DESC& desc);
private: private:
static constexpr u32 FILE_VERSION = 1; static constexpr u32 FILE_VERSION = 1;
@ -120,6 +122,7 @@ namespace D3D12
static CacheIndexKey GetShaderCacheKey(EntryType type, const std::string_view& shader_code, static CacheIndexKey GetShaderCacheKey(EntryType type, const std::string_view& shader_code,
const D3D_SHADER_MACRO* macros, const char* entry_point); const D3D_SHADER_MACRO* macros, const char* entry_point);
static CacheIndexKey GetPipelineCacheKey(const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc); static CacheIndexKey GetPipelineCacheKey(const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc);
static CacheIndexKey GetPipelineCacheKey(const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc);
bool CreateNew(const std::string& index_filename, const std::string& blob_filename, std::FILE*& index_file, bool CreateNew(const std::string& index_filename, const std::string& blob_filename, std::FILE*& index_file,
std::FILE*& blob_file); std::FILE*& blob_file);
@ -132,6 +135,9 @@ namespace D3D12
const D3D_SHADER_MACRO* macros, const char* entry_point); const D3D_SHADER_MACRO* macros, const char* entry_point);
ComPtr<ID3D12PipelineState> CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key, ComPtr<ID3D12PipelineState> CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key,
const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc); const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc);
ComPtr<ID3D12PipelineState> CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key,
const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc);
bool AddPipelineToBlob(const CacheIndexKey& key, ID3D12PipelineState* pso);
std::string m_base_path; std::string m_base_path;