mirror of https://github.com/PCSX2/pcsx2.git
D3D12: Add compute shader support
This commit is contained in:
parent
d132ddefef
commit
5363a90c6b
|
@ -223,6 +223,56 @@ void GraphicsPipelineBuilder::SetDepthStencilFormat(DXGI_FORMAT format)
|
|||
m_desc.DSVFormat = format;
|
||||
}
|
||||
|
||||
|
||||
ComputePipelineBuilder::ComputePipelineBuilder()
|
||||
{
|
||||
Clear();
|
||||
}
|
||||
|
||||
void ComputePipelineBuilder::Clear()
|
||||
{
|
||||
std::memset(&m_desc, 0, sizeof(m_desc));
|
||||
}
|
||||
|
||||
wil::com_ptr_nothrow<ID3D12PipelineState> ComputePipelineBuilder::Create(ID3D12Device* device, bool clear /*= true*/)
|
||||
{
|
||||
wil::com_ptr_nothrow<ID3D12PipelineState> ps;
|
||||
HRESULT hr = device->CreateComputePipelineState(&m_desc, IID_PPV_ARGS(ps.put()));
|
||||
if (FAILED(hr))
|
||||
{
|
||||
Console.Error("CreateComputePipelineState() failed: %08X", hr);
|
||||
return {};
|
||||
}
|
||||
|
||||
if (clear)
|
||||
Clear();
|
||||
|
||||
return ps;
|
||||
}
|
||||
|
||||
wil::com_ptr_nothrow<ID3D12PipelineState> ComputePipelineBuilder::Create(ID3D12Device* device, ShaderCache& cache, bool clear /*= true*/)
|
||||
{
|
||||
wil::com_ptr_nothrow<ID3D12PipelineState> pso = cache.GetPipelineState(device, m_desc);
|
||||
if (!pso)
|
||||
return {};
|
||||
|
||||
if (clear)
|
||||
Clear();
|
||||
|
||||
return pso;
|
||||
}
|
||||
|
||||
void ComputePipelineBuilder::SetRootSignature(ID3D12RootSignature* rs)
|
||||
{
|
||||
m_desc.pRootSignature = rs;
|
||||
}
|
||||
|
||||
void ComputePipelineBuilder::SetShader(const void* data, u32 data_size)
|
||||
{
|
||||
m_desc.CS.pShaderBytecode = data;
|
||||
m_desc.CS.BytecodeLength = data_size;
|
||||
}
|
||||
|
||||
RootSignatureBuilder::RootSignatureBuilder()
|
||||
{
|
||||
Clear();
|
||||
|
|
|
@ -114,8 +114,26 @@ namespace D3D12
|
|||
void SetDepthStencilFormat(DXGI_FORMAT format);
|
||||
|
||||
private:
|
||||
D3D12_GRAPHICS_PIPELINE_STATE_DESC m_desc{};
|
||||
std::array<D3D12_INPUT_ELEMENT_DESC, MAX_VERTEX_ATTRIBUTES> m_input_elements{};
|
||||
D3D12_GRAPHICS_PIPELINE_STATE_DESC m_desc;
|
||||
std::array<D3D12_INPUT_ELEMENT_DESC, MAX_VERTEX_ATTRIBUTES> m_input_elements;
|
||||
};
|
||||
|
||||
class ComputePipelineBuilder
|
||||
{
|
||||
public:
|
||||
ComputePipelineBuilder();
|
||||
~ComputePipelineBuilder() = default;
|
||||
|
||||
void Clear();
|
||||
|
||||
wil::com_ptr_nothrow<ID3D12PipelineState> Create(ID3D12Device* device, bool clear = true);
|
||||
wil::com_ptr_nothrow<ID3D12PipelineState> Create(ID3D12Device* device, ShaderCache& cache, bool clear = true);
|
||||
|
||||
void SetRootSignature(ID3D12RootSignature* rs);
|
||||
|
||||
void SetShader(const void* data, u32 data_size);
|
||||
private:
|
||||
D3D12_COMPUTE_PIPELINE_STATE_DESC m_desc;
|
||||
};
|
||||
|
||||
} // namespace D3D12
|
|
@ -387,6 +387,23 @@ ShaderCache::CacheIndexKey ShaderCache::GetPipelineCacheKey(const D3D12_GRAPHICS
|
|||
return CacheIndexKey{h.low, h.high, 0, 0, 0, 0, length, EntryType::GraphicsPipeline};
|
||||
}
|
||||
|
||||
ShaderCache::CacheIndexKey ShaderCache::GetPipelineCacheKey(const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc)
|
||||
{
|
||||
MD5Digest digest;
|
||||
u32 length = sizeof(D3D12_GRAPHICS_PIPELINE_STATE_DESC);
|
||||
|
||||
if (gpdesc.CS.BytecodeLength > 0)
|
||||
{
|
||||
digest.Update(gpdesc.CS.pShaderBytecode, static_cast<u32>(gpdesc.CS.BytecodeLength));
|
||||
length += static_cast<u32>(gpdesc.CS.BytecodeLength);
|
||||
}
|
||||
|
||||
MD5Hash h;
|
||||
digest.Final(h.hash);
|
||||
|
||||
return CacheIndexKey{h.low, h.high, 0, 0, 0, 0, length, EntryType::ComputePipeline};
|
||||
}
|
||||
|
||||
ShaderCache::ComPtr<ID3DBlob> ShaderCache::GetShaderBlob(EntryType type, std::string_view shader_code,
|
||||
const D3D_SHADER_MACRO* macros /* = nullptr */, const char* entry_point /* = "main" */)
|
||||
{
|
||||
|
@ -441,6 +458,40 @@ ShaderCache::ComPtr<ID3D12PipelineState> ShaderCache::GetPipelineState(ID3D12Dev
|
|||
return pso;
|
||||
}
|
||||
|
||||
ShaderCache::ComPtr<ID3D12PipelineState> ShaderCache::GetPipelineState(ID3D12Device* device,
|
||||
const D3D12_COMPUTE_PIPELINE_STATE_DESC& desc)
|
||||
{
|
||||
const auto key = GetPipelineCacheKey(desc);
|
||||
|
||||
auto iter = m_pipeline_index.find(key);
|
||||
if (iter == m_pipeline_index.end())
|
||||
return CompileAndAddPipeline(device, key, desc);
|
||||
|
||||
ComPtr<ID3DBlob> blob;
|
||||
HRESULT hr = D3DCreateBlob(iter->second.blob_size, blob.put());
|
||||
if (FAILED(hr) || std::fseek(m_pipeline_blob_file, iter->second.file_offset, SEEK_SET) != 0 ||
|
||||
std::fread(blob->GetBufferPointer(), 1, iter->second.blob_size, m_pipeline_blob_file) != iter->second.blob_size)
|
||||
{
|
||||
Console.Error("Read blob from file failed");
|
||||
return {};
|
||||
}
|
||||
|
||||
D3D12_COMPUTE_PIPELINE_STATE_DESC desc_with_blob(desc);
|
||||
desc_with_blob.CachedPSO.pCachedBlob = blob->GetBufferPointer();
|
||||
desc_with_blob.CachedPSO.CachedBlobSizeInBytes = blob->GetBufferSize();
|
||||
|
||||
ComPtr<ID3D12PipelineState> pso;
|
||||
hr = device->CreateComputePipelineState(&desc_with_blob, IID_PPV_ARGS(pso.put()));
|
||||
if (FAILED(hr))
|
||||
{
|
||||
Console.Warning("Creating cached PSO failed: %08X. Invalidating cache.", hr);
|
||||
InvalidatePipelineCache();
|
||||
pso = CompileAndAddPipeline(device, key, desc);
|
||||
}
|
||||
|
||||
return pso;
|
||||
}
|
||||
|
||||
ShaderCache::ComPtr<ID3DBlob> ShaderCache::CompileAndAddShaderBlob(const CacheIndexKey& key, std::string_view shader_code,
|
||||
const D3D_SHADER_MACRO* macros, const char* entry_point)
|
||||
{
|
||||
|
@ -457,6 +508,9 @@ ShaderCache::ComPtr<ID3DBlob> ShaderCache::CompileAndAddShaderBlob(const CacheIn
|
|||
case EntryType::PixelShader:
|
||||
blob = D3D11::ShaderCompiler::CompileShader(D3D11::ShaderCompiler::Type::Pixel, m_feature_level, m_debug, shader_code, macros, entry_point);
|
||||
break;
|
||||
case EntryType::ComputeShader:
|
||||
blob = D3D11::ShaderCompiler::CompileShader(D3D11::ShaderCompiler::Type::Compute, m_feature_level, m_debug, shader_code, macros, entry_point);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
@ -507,15 +561,37 @@ ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& ke
|
|||
return {};
|
||||
}
|
||||
|
||||
AddPipelineToBlob(key, pso.get());
|
||||
return pso;
|
||||
}
|
||||
|
||||
ShaderCache::ComPtr<ID3D12PipelineState>
|
||||
ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key,
|
||||
const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc)
|
||||
{
|
||||
ComPtr<ID3D12PipelineState> pso;
|
||||
HRESULT hr = device->CreateComputePipelineState(&gpdesc, IID_PPV_ARGS(pso.put()));
|
||||
if (FAILED(hr))
|
||||
{
|
||||
Console.Error("Creating cached compute PSO failed: %08X", hr);
|
||||
return {};
|
||||
}
|
||||
|
||||
AddPipelineToBlob(key, pso.get());
|
||||
return pso;
|
||||
}
|
||||
|
||||
bool ShaderCache::AddPipelineToBlob(const CacheIndexKey& key, ID3D12PipelineState* pso)
|
||||
{
|
||||
if (!m_pipeline_blob_file || std::fseek(m_pipeline_blob_file, 0, SEEK_END) != 0)
|
||||
return pso;
|
||||
return false;
|
||||
|
||||
ComPtr<ID3DBlob> blob;
|
||||
hr = pso->GetCachedBlob(blob.put());
|
||||
HRESULT hr = pso->GetCachedBlob(blob.put());
|
||||
if (FAILED(hr))
|
||||
{
|
||||
Console.Warning("Failed to get cached PSO data: %08X", hr);
|
||||
return pso;
|
||||
return false;
|
||||
}
|
||||
|
||||
CacheIndexData data;
|
||||
|
@ -535,9 +611,9 @@ ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& ke
|
|||
std::fflush(m_pipeline_index_file) != 0)
|
||||
{
|
||||
Console.Error("Failed to write pipeline blob to file");
|
||||
return pso;
|
||||
return false;
|
||||
}
|
||||
|
||||
m_shader_index.emplace(key, data);
|
||||
return pso;
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -41,6 +41,7 @@ namespace D3D12
|
|||
PixelShader,
|
||||
ComputeShader,
|
||||
GraphicsPipeline,
|
||||
ComputePipeline,
|
||||
};
|
||||
|
||||
ShaderCache();
|
||||
|
@ -77,6 +78,7 @@ namespace D3D12
|
|||
const D3D_SHADER_MACRO* macros = nullptr, const char* entry_point = "main");
|
||||
|
||||
ComPtr<ID3D12PipelineState> GetPipelineState(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc);
|
||||
ComPtr<ID3D12PipelineState> GetPipelineState(ID3D12Device* device, const D3D12_COMPUTE_PIPELINE_STATE_DESC& desc);
|
||||
|
||||
private:
|
||||
static constexpr u32 FILE_VERSION = 1;
|
||||
|
@ -120,6 +122,7 @@ namespace D3D12
|
|||
static CacheIndexKey GetShaderCacheKey(EntryType type, const std::string_view& shader_code,
|
||||
const D3D_SHADER_MACRO* macros, const char* entry_point);
|
||||
static CacheIndexKey GetPipelineCacheKey(const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc);
|
||||
static CacheIndexKey GetPipelineCacheKey(const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc);
|
||||
|
||||
bool CreateNew(const std::string& index_filename, const std::string& blob_filename, std::FILE*& index_file,
|
||||
std::FILE*& blob_file);
|
||||
|
@ -132,6 +135,9 @@ namespace D3D12
|
|||
const D3D_SHADER_MACRO* macros, const char* entry_point);
|
||||
ComPtr<ID3D12PipelineState> CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key,
|
||||
const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc);
|
||||
ComPtr<ID3D12PipelineState> CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key,
|
||||
const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc);
|
||||
bool AddPipelineToBlob(const CacheIndexKey& key, ID3D12PipelineState* pso);
|
||||
|
||||
std::string m_base_path;
|
||||
|
||||
|
|
Loading…
Reference in New Issue