mirror of https://github.com/PCSX2/pcsx2.git
D3D12: Add compute shader support
This commit is contained in:
parent
d132ddefef
commit
5363a90c6b
|
@ -223,6 +223,56 @@ void GraphicsPipelineBuilder::SetDepthStencilFormat(DXGI_FORMAT format)
|
||||||
m_desc.DSVFormat = format;
|
m_desc.DSVFormat = format;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
ComputePipelineBuilder::ComputePipelineBuilder()
|
||||||
|
{
|
||||||
|
Clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
void ComputePipelineBuilder::Clear()
|
||||||
|
{
|
||||||
|
std::memset(&m_desc, 0, sizeof(m_desc));
|
||||||
|
}
|
||||||
|
|
||||||
|
wil::com_ptr_nothrow<ID3D12PipelineState> ComputePipelineBuilder::Create(ID3D12Device* device, bool clear /*= true*/)
|
||||||
|
{
|
||||||
|
wil::com_ptr_nothrow<ID3D12PipelineState> ps;
|
||||||
|
HRESULT hr = device->CreateComputePipelineState(&m_desc, IID_PPV_ARGS(ps.put()));
|
||||||
|
if (FAILED(hr))
|
||||||
|
{
|
||||||
|
Console.Error("CreateComputePipelineState() failed: %08X", hr);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (clear)
|
||||||
|
Clear();
|
||||||
|
|
||||||
|
return ps;
|
||||||
|
}
|
||||||
|
|
||||||
|
wil::com_ptr_nothrow<ID3D12PipelineState> ComputePipelineBuilder::Create(ID3D12Device* device, ShaderCache& cache, bool clear /*= true*/)
|
||||||
|
{
|
||||||
|
wil::com_ptr_nothrow<ID3D12PipelineState> pso = cache.GetPipelineState(device, m_desc);
|
||||||
|
if (!pso)
|
||||||
|
return {};
|
||||||
|
|
||||||
|
if (clear)
|
||||||
|
Clear();
|
||||||
|
|
||||||
|
return pso;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ComputePipelineBuilder::SetRootSignature(ID3D12RootSignature* rs)
|
||||||
|
{
|
||||||
|
m_desc.pRootSignature = rs;
|
||||||
|
}
|
||||||
|
|
||||||
|
void ComputePipelineBuilder::SetShader(const void* data, u32 data_size)
|
||||||
|
{
|
||||||
|
m_desc.CS.pShaderBytecode = data;
|
||||||
|
m_desc.CS.BytecodeLength = data_size;
|
||||||
|
}
|
||||||
|
|
||||||
RootSignatureBuilder::RootSignatureBuilder()
|
RootSignatureBuilder::RootSignatureBuilder()
|
||||||
{
|
{
|
||||||
Clear();
|
Clear();
|
||||||
|
|
|
@ -114,8 +114,26 @@ namespace D3D12
|
||||||
void SetDepthStencilFormat(DXGI_FORMAT format);
|
void SetDepthStencilFormat(DXGI_FORMAT format);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
D3D12_GRAPHICS_PIPELINE_STATE_DESC m_desc{};
|
D3D12_GRAPHICS_PIPELINE_STATE_DESC m_desc;
|
||||||
std::array<D3D12_INPUT_ELEMENT_DESC, MAX_VERTEX_ATTRIBUTES> m_input_elements{};
|
std::array<D3D12_INPUT_ELEMENT_DESC, MAX_VERTEX_ATTRIBUTES> m_input_elements;
|
||||||
|
};
|
||||||
|
|
||||||
|
class ComputePipelineBuilder
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
ComputePipelineBuilder();
|
||||||
|
~ComputePipelineBuilder() = default;
|
||||||
|
|
||||||
|
void Clear();
|
||||||
|
|
||||||
|
wil::com_ptr_nothrow<ID3D12PipelineState> Create(ID3D12Device* device, bool clear = true);
|
||||||
|
wil::com_ptr_nothrow<ID3D12PipelineState> Create(ID3D12Device* device, ShaderCache& cache, bool clear = true);
|
||||||
|
|
||||||
|
void SetRootSignature(ID3D12RootSignature* rs);
|
||||||
|
|
||||||
|
void SetShader(const void* data, u32 data_size);
|
||||||
|
private:
|
||||||
|
D3D12_COMPUTE_PIPELINE_STATE_DESC m_desc;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace D3D12
|
} // namespace D3D12
|
|
@ -387,6 +387,23 @@ ShaderCache::CacheIndexKey ShaderCache::GetPipelineCacheKey(const D3D12_GRAPHICS
|
||||||
return CacheIndexKey{h.low, h.high, 0, 0, 0, 0, length, EntryType::GraphicsPipeline};
|
return CacheIndexKey{h.low, h.high, 0, 0, 0, 0, length, EntryType::GraphicsPipeline};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ShaderCache::CacheIndexKey ShaderCache::GetPipelineCacheKey(const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc)
|
||||||
|
{
|
||||||
|
MD5Digest digest;
|
||||||
|
u32 length = sizeof(D3D12_GRAPHICS_PIPELINE_STATE_DESC);
|
||||||
|
|
||||||
|
if (gpdesc.CS.BytecodeLength > 0)
|
||||||
|
{
|
||||||
|
digest.Update(gpdesc.CS.pShaderBytecode, static_cast<u32>(gpdesc.CS.BytecodeLength));
|
||||||
|
length += static_cast<u32>(gpdesc.CS.BytecodeLength);
|
||||||
|
}
|
||||||
|
|
||||||
|
MD5Hash h;
|
||||||
|
digest.Final(h.hash);
|
||||||
|
|
||||||
|
return CacheIndexKey{h.low, h.high, 0, 0, 0, 0, length, EntryType::ComputePipeline};
|
||||||
|
}
|
||||||
|
|
||||||
ShaderCache::ComPtr<ID3DBlob> ShaderCache::GetShaderBlob(EntryType type, std::string_view shader_code,
|
ShaderCache::ComPtr<ID3DBlob> ShaderCache::GetShaderBlob(EntryType type, std::string_view shader_code,
|
||||||
const D3D_SHADER_MACRO* macros /* = nullptr */, const char* entry_point /* = "main" */)
|
const D3D_SHADER_MACRO* macros /* = nullptr */, const char* entry_point /* = "main" */)
|
||||||
{
|
{
|
||||||
|
@ -441,6 +458,40 @@ ShaderCache::ComPtr<ID3D12PipelineState> ShaderCache::GetPipelineState(ID3D12Dev
|
||||||
return pso;
|
return pso;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ShaderCache::ComPtr<ID3D12PipelineState> ShaderCache::GetPipelineState(ID3D12Device* device,
|
||||||
|
const D3D12_COMPUTE_PIPELINE_STATE_DESC& desc)
|
||||||
|
{
|
||||||
|
const auto key = GetPipelineCacheKey(desc);
|
||||||
|
|
||||||
|
auto iter = m_pipeline_index.find(key);
|
||||||
|
if (iter == m_pipeline_index.end())
|
||||||
|
return CompileAndAddPipeline(device, key, desc);
|
||||||
|
|
||||||
|
ComPtr<ID3DBlob> blob;
|
||||||
|
HRESULT hr = D3DCreateBlob(iter->second.blob_size, blob.put());
|
||||||
|
if (FAILED(hr) || std::fseek(m_pipeline_blob_file, iter->second.file_offset, SEEK_SET) != 0 ||
|
||||||
|
std::fread(blob->GetBufferPointer(), 1, iter->second.blob_size, m_pipeline_blob_file) != iter->second.blob_size)
|
||||||
|
{
|
||||||
|
Console.Error("Read blob from file failed");
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
D3D12_COMPUTE_PIPELINE_STATE_DESC desc_with_blob(desc);
|
||||||
|
desc_with_blob.CachedPSO.pCachedBlob = blob->GetBufferPointer();
|
||||||
|
desc_with_blob.CachedPSO.CachedBlobSizeInBytes = blob->GetBufferSize();
|
||||||
|
|
||||||
|
ComPtr<ID3D12PipelineState> pso;
|
||||||
|
hr = device->CreateComputePipelineState(&desc_with_blob, IID_PPV_ARGS(pso.put()));
|
||||||
|
if (FAILED(hr))
|
||||||
|
{
|
||||||
|
Console.Warning("Creating cached PSO failed: %08X. Invalidating cache.", hr);
|
||||||
|
InvalidatePipelineCache();
|
||||||
|
pso = CompileAndAddPipeline(device, key, desc);
|
||||||
|
}
|
||||||
|
|
||||||
|
return pso;
|
||||||
|
}
|
||||||
|
|
||||||
ShaderCache::ComPtr<ID3DBlob> ShaderCache::CompileAndAddShaderBlob(const CacheIndexKey& key, std::string_view shader_code,
|
ShaderCache::ComPtr<ID3DBlob> ShaderCache::CompileAndAddShaderBlob(const CacheIndexKey& key, std::string_view shader_code,
|
||||||
const D3D_SHADER_MACRO* macros, const char* entry_point)
|
const D3D_SHADER_MACRO* macros, const char* entry_point)
|
||||||
{
|
{
|
||||||
|
@ -457,6 +508,9 @@ ShaderCache::ComPtr<ID3DBlob> ShaderCache::CompileAndAddShaderBlob(const CacheIn
|
||||||
case EntryType::PixelShader:
|
case EntryType::PixelShader:
|
||||||
blob = D3D11::ShaderCompiler::CompileShader(D3D11::ShaderCompiler::Type::Pixel, m_feature_level, m_debug, shader_code, macros, entry_point);
|
blob = D3D11::ShaderCompiler::CompileShader(D3D11::ShaderCompiler::Type::Pixel, m_feature_level, m_debug, shader_code, macros, entry_point);
|
||||||
break;
|
break;
|
||||||
|
case EntryType::ComputeShader:
|
||||||
|
blob = D3D11::ShaderCompiler::CompileShader(D3D11::ShaderCompiler::Type::Compute, m_feature_level, m_debug, shader_code, macros, entry_point);
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -507,15 +561,37 @@ ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& ke
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AddPipelineToBlob(key, pso.get());
|
||||||
|
return pso;
|
||||||
|
}
|
||||||
|
|
||||||
|
ShaderCache::ComPtr<ID3D12PipelineState>
|
||||||
|
ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key,
|
||||||
|
const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc)
|
||||||
|
{
|
||||||
|
ComPtr<ID3D12PipelineState> pso;
|
||||||
|
HRESULT hr = device->CreateComputePipelineState(&gpdesc, IID_PPV_ARGS(pso.put()));
|
||||||
|
if (FAILED(hr))
|
||||||
|
{
|
||||||
|
Console.Error("Creating cached compute PSO failed: %08X", hr);
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
AddPipelineToBlob(key, pso.get());
|
||||||
|
return pso;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ShaderCache::AddPipelineToBlob(const CacheIndexKey& key, ID3D12PipelineState* pso)
|
||||||
|
{
|
||||||
if (!m_pipeline_blob_file || std::fseek(m_pipeline_blob_file, 0, SEEK_END) != 0)
|
if (!m_pipeline_blob_file || std::fseek(m_pipeline_blob_file, 0, SEEK_END) != 0)
|
||||||
return pso;
|
return false;
|
||||||
|
|
||||||
ComPtr<ID3DBlob> blob;
|
ComPtr<ID3DBlob> blob;
|
||||||
hr = pso->GetCachedBlob(blob.put());
|
HRESULT hr = pso->GetCachedBlob(blob.put());
|
||||||
if (FAILED(hr))
|
if (FAILED(hr))
|
||||||
{
|
{
|
||||||
Console.Warning("Failed to get cached PSO data: %08X", hr);
|
Console.Warning("Failed to get cached PSO data: %08X", hr);
|
||||||
return pso;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
CacheIndexData data;
|
CacheIndexData data;
|
||||||
|
@ -535,9 +611,9 @@ ShaderCache::CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& ke
|
||||||
std::fflush(m_pipeline_index_file) != 0)
|
std::fflush(m_pipeline_index_file) != 0)
|
||||||
{
|
{
|
||||||
Console.Error("Failed to write pipeline blob to file");
|
Console.Error("Failed to write pipeline blob to file");
|
||||||
return pso;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
m_shader_index.emplace(key, data);
|
m_shader_index.emplace(key, data);
|
||||||
return pso;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,6 +41,7 @@ namespace D3D12
|
||||||
PixelShader,
|
PixelShader,
|
||||||
ComputeShader,
|
ComputeShader,
|
||||||
GraphicsPipeline,
|
GraphicsPipeline,
|
||||||
|
ComputePipeline,
|
||||||
};
|
};
|
||||||
|
|
||||||
ShaderCache();
|
ShaderCache();
|
||||||
|
@ -77,6 +78,7 @@ namespace D3D12
|
||||||
const D3D_SHADER_MACRO* macros = nullptr, const char* entry_point = "main");
|
const D3D_SHADER_MACRO* macros = nullptr, const char* entry_point = "main");
|
||||||
|
|
||||||
ComPtr<ID3D12PipelineState> GetPipelineState(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc);
|
ComPtr<ID3D12PipelineState> GetPipelineState(ID3D12Device* device, const D3D12_GRAPHICS_PIPELINE_STATE_DESC& desc);
|
||||||
|
ComPtr<ID3D12PipelineState> GetPipelineState(ID3D12Device* device, const D3D12_COMPUTE_PIPELINE_STATE_DESC& desc);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
static constexpr u32 FILE_VERSION = 1;
|
static constexpr u32 FILE_VERSION = 1;
|
||||||
|
@ -120,6 +122,7 @@ namespace D3D12
|
||||||
static CacheIndexKey GetShaderCacheKey(EntryType type, const std::string_view& shader_code,
|
static CacheIndexKey GetShaderCacheKey(EntryType type, const std::string_view& shader_code,
|
||||||
const D3D_SHADER_MACRO* macros, const char* entry_point);
|
const D3D_SHADER_MACRO* macros, const char* entry_point);
|
||||||
static CacheIndexKey GetPipelineCacheKey(const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc);
|
static CacheIndexKey GetPipelineCacheKey(const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc);
|
||||||
|
static CacheIndexKey GetPipelineCacheKey(const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc);
|
||||||
|
|
||||||
bool CreateNew(const std::string& index_filename, const std::string& blob_filename, std::FILE*& index_file,
|
bool CreateNew(const std::string& index_filename, const std::string& blob_filename, std::FILE*& index_file,
|
||||||
std::FILE*& blob_file);
|
std::FILE*& blob_file);
|
||||||
|
@ -132,6 +135,9 @@ namespace D3D12
|
||||||
const D3D_SHADER_MACRO* macros, const char* entry_point);
|
const D3D_SHADER_MACRO* macros, const char* entry_point);
|
||||||
ComPtr<ID3D12PipelineState> CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key,
|
ComPtr<ID3D12PipelineState> CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key,
|
||||||
const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc);
|
const D3D12_GRAPHICS_PIPELINE_STATE_DESC& gpdesc);
|
||||||
|
ComPtr<ID3D12PipelineState> CompileAndAddPipeline(ID3D12Device* device, const CacheIndexKey& key,
|
||||||
|
const D3D12_COMPUTE_PIPELINE_STATE_DESC& gpdesc);
|
||||||
|
bool AddPipelineToBlob(const CacheIndexKey& key, ID3D12PipelineState* pso);
|
||||||
|
|
||||||
std::string m_base_path;
|
std::string m_base_path;
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue