diff --git a/common/Vulkan/Builders.cpp b/common/Vulkan/Builders.cpp index ab0ad20d41..ec0026dfa3 100644 --- a/common/Vulkan/Builders.cpp +++ b/common/Vulkan/Builders.cpp @@ -436,6 +436,65 @@ namespace Vulkan m_provoking_vertex.provokingVertexMode = mode; } + ComputePipelineBuilder::ComputePipelineBuilder() { Clear(); } + + void ComputePipelineBuilder::Clear() + { + m_ci = {}; + m_ci.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; + m_si = {}; + m_smap_entries = {}; + m_smap_constants = {}; + } + + VkPipeline ComputePipelineBuilder::Create(VkDevice device, VkPipelineCache pipeline_cache /*= VK_NULL_HANDLE*/, bool clear /*= true*/) + { + VkPipeline pipeline; + VkResult res = vkCreateComputePipelines(device, pipeline_cache, 1, &m_ci, nullptr, &pipeline); + if (res != VK_SUCCESS) + { + LOG_VULKAN_ERROR(res, "vkCreateComputePipelines() failed: "); + return VK_NULL_HANDLE; + } + + if (clear) + Clear(); + + return pipeline; + } + + void ComputePipelineBuilder::SetShader(VkShaderModule module, const char* entry_point) + { + m_ci.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; + m_ci.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; + m_ci.stage.module = module; + m_ci.stage.pName = entry_point; + } + + void ComputePipelineBuilder::SetPipelineLayout(VkPipelineLayout layout) + { + m_ci.layout = layout; + } + + void ComputePipelineBuilder::SetSpecializationBool(u32 index, bool value) + { + const u32 u32_value = static_cast(value); + SetSpecializationValue(index, u32_value); + } + + void ComputePipelineBuilder::SetSpecializationValue(u32 index, u32 value) + { + if (m_si.mapEntryCount == 0) + { + m_si.pMapEntries = m_smap_entries.data(); + m_si.pData = m_smap_constants.data(); + m_ci.stage.pSpecializationInfo = &m_si; + } + + m_smap_entries[m_si.mapEntryCount++] = {index, index * SPECIALIZATION_CONSTANT_SIZE, SPECIALIZATION_CONSTANT_SIZE}; + m_si.dataSize += SPECIALIZATION_CONSTANT_SIZE; + } + SamplerBuilder::SamplerBuilder() { Clear(); } void SamplerBuilder::Clear() diff --git a/common/Vulkan/Builders.h b/common/Vulkan/Builders.h index a9199cede5..86a0f884dc 100644 --- a/common/Vulkan/Builders.h +++ b/common/Vulkan/Builders.h @@ -159,6 +159,37 @@ namespace Vulkan VkPipelineRasterizationProvokingVertexStateCreateInfoEXT m_provoking_vertex; }; + class ComputePipelineBuilder + { + public: + enum : u32 + { + SPECIALIZATION_CONSTANT_SIZE = 4, + MAX_SPECIALIZATION_CONSTANTS = 4, + }; + + ComputePipelineBuilder(); + + void Clear(); + + VkPipeline Create(VkDevice device, VkPipelineCache pipeline_cache = VK_NULL_HANDLE, bool clear = true); + + void SetShader(VkShaderModule module, const char* entry_point); + + void SetPipelineLayout(VkPipelineLayout layout); + + void SetSpecializationBool(u32 index, bool value); + + private: + void SetSpecializationValue(u32 index, u32 value); + + VkComputePipelineCreateInfo m_ci; + + VkSpecializationInfo m_si; + std::array m_smap_entries; + std::array m_smap_constants; + }; + class SamplerBuilder { public: