Hi, I’m currently adapting the Vulkan mini path tracer tutorial to work with a sphere. However, I cannot get my custom intersection shader to be called. I believe I have set up the shader modules and stages correctly, and that the SBT and the ray stride/offsets are set correctly as well. Can I get some guidance on why the intersection shader does not seem to be called?
/// Creates the ReSTIR ray-tracing pipeline layout and pipeline, then builds its SBT.
///
/// Shader groups are created in this order (the SBT, the TLAS instances'
/// hitGroupId and the traceRayEXT sbtRecordOffset/sbtRecordStride must agree):
///   group 0          : raygen       restir.rgen
///   group 1 (miss 0) : miss         restir.rmiss
///   group 2 (miss 1) : miss         restirShadow.rmiss
///   group 3 (hit 0)  : TRIANGLES    closest hit restir.rchit
///   group 4 (hit 1)  : PROCEDURAL   closest hit raytrace2.rchit + intersection raytrace.rint
///   group 5 (hit 2)  : TRIANGLES    closest hit restir.rchit
///   group 6 (hit 3)  : PROCEDURAL   closest hit raytrace2.rchit + intersection raytrace.rint
///
/// The hit record selected for a hit is
///   instance hitGroupId (instanceShaderBindingTableRecordOffset)
///   + sbtRecordOffset + geometryIndex * sbtRecordStride
/// where sbtRecordOffset/sbtRecordStride are the traceRayEXT arguments. The
/// intersection shader can only run when that sum selects hit 1 or hit 3 (the
/// PROCEDURAL groups) AND the instance's BLAS is made of AABB geometry.
///
/// @param uniformDescSetLayout  descriptor set layout bound at set 0
/// @param sceneDescSetLayout    descriptor set layout bound at set 1
/// @param lightDescSetLayout    descriptor set layout bound at set 2
/// @param restirDescSetLayout   descriptor set layout bound at set 3
void RestirPass::createPipeline(const vk::DescriptorSetLayout& uniformDescSetLayout, const vk::DescriptorSetLayout& sceneDescSetLayout, const vk::DescriptorSetLayout& lightDescSetLayout, const vk::DescriptorSetLayout& restirDescSetLayout) {
  // Indices into `stages`; shader groups reference stages by these indices.
  enum StageIndices : uint32_t
  {
    eRaygen,
    eMiss,
    eMiss2,
    eClosestHit,
    eClosestHit2,
    eIntersection,
    eShaderStageCount
  };

  std::vector<std::string> paths = defaultSearchPaths;

  // Loads one SPIR-V file and wraps it in a shader module.
  auto loadModule = [&](const char* file) {
    return vk::ShaderModule(nvvk::createShaderModule(m_device, nvh::loadFile(file, true, paths, true)));
  };
  // Describes one stage; all stages use the "main" entry point.
  auto makeStage = [](vk::ShaderModule module, vk::ShaderStageFlagBits stageFlag) {
    vk::PipelineShaderStageCreateInfo info;
    info.setModule(module);
    info.setStage(stageFlag);
    info.setPName("main");
    return info;
  };

  // Each SPIR-V file is loaded exactly once; both triangle hit groups and both
  // procedural hit groups below reuse the same stages.
  std::array<vk::PipelineShaderStageCreateInfo, eShaderStageCount> stages{};
  stages[eRaygen]       = makeStage(loadModule("src/shaders/restir.rgen.spv"), vk::ShaderStageFlagBits::eRaygenKHR);
  stages[eMiss]         = makeStage(loadModule("src/shaders/restir.rmiss.spv"), vk::ShaderStageFlagBits::eMissKHR);
  // Second miss shader: used when a shadow ray misses all geometry (no occlusion found).
  stages[eMiss2]        = makeStage(loadModule("src/shaders/restirShadow.rmiss.spv"), vk::ShaderStageFlagBits::eMissKHR);
  stages[eClosestHit]   = makeStage(loadModule("src/shaders/restir.rchit.spv"), vk::ShaderStageFlagBits::eClosestHitKHR);
  stages[eClosestHit2]  = makeStage(loadModule("src/shaders/raytrace2.rchit.spv"), vk::ShaderStageFlagBits::eClosestHitKHR);
  stages[eIntersection] = makeStage(loadModule("src/shaders/raytrace.rint.spv"), vk::ShaderStageFlagBits::eIntersectionKHR);

  // Rebuild the group list from scratch so calling this again does not append duplicates.
  m_rtShaderGroups.clear();
  auto addGroup = [this](VkRayTracingShaderGroupTypeKHR type, uint32_t generalShader, uint32_t closestHitShader,
                         uint32_t intersectionShader) {
    VkRayTracingShaderGroupCreateInfoKHR group{ VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR };
    group.type               = type;
    group.generalShader      = generalShader;
    group.closestHitShader   = closestHitShader;
    group.anyHitShader       = VK_SHADER_UNUSED_KHR;
    group.intersectionShader = intersectionShader;
    m_rtShaderGroups.push_back(group);
  };
  constexpr uint32_t kUnused = VK_SHADER_UNUSED_KHR;

  // Raygen and miss records (general groups).
  addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR, eRaygen, kUnused, kUnused);
  addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR, eMiss, kUnused, kUnused);
  addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR, eMiss2, kUnused, kUnused);
  // Hit records 0..3. Only PROCEDURAL groups (hit 1, hit 3) carry the intersection shader;
  // a triangle group must never be selected for AABB geometry and vice versa.
  addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR, kUnused, eClosestHit, kUnused);
  addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR, kUnused, eClosestHit2, eIntersection);
  addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR, kUnused, eClosestHit, kUnused);
  addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR, kUnused, eClosestHit2, eIntersection);

  // Descriptor sets 0..3 in this order.
  vk::PipelineLayoutCreateInfo pipelineLayoutCreateInfo;
  std::vector<vk::DescriptorSetLayout> rtDescSetLayouts{
    uniformDescSetLayout,
    sceneDescSetLayout,
    lightDescSetLayout,
    restirDescSetLayout};
  pipelineLayoutCreateInfo.setSetLayouts(rtDescSetLayouts);
  m_pipelineLayout = m_device.createPipelineLayout(pipelineLayoutCreateInfo);

  vk::RayTracingPipelineCreateInfoKHR rayPipelineInfo;
  rayPipelineInfo.setStageCount(static_cast<uint32_t>(stages.size()));
  rayPipelineInfo.setPStages(stages.data());
  rayPipelineInfo.setGroupCount(static_cast<uint32_t>(m_rtShaderGroups.size()));  // 1 raygen, 2 miss, 4 hit
  rayPipelineInfo.setPGroups(m_rtShaderGroups.data());
  rayPipelineInfo.setMaxPipelineRayRecursionDepth(2);  // Ray depth
  rayPipelineInfo.setLayout(m_pipelineLayout);

  auto pipelineResult = m_device.createRayTracingPipelineKHR({}, {}, rayPipelineInfo);
  assert(pipelineResult.result == vk::Result::eSuccess);
  m_pipeline = pipelineResult.value;

  // The Vulkan spec allows a shader module to be destroyed once the pipelines
  // using it have been created; keeping them alive here would only leak them.
  for (const auto& stage : stages)
    m_device.destroyShaderModule(stage.module);

  _createShaderBindingTable();
}
void RestirPass::_createShaderBindingTable()
{
uint32_t missCount{ 2 };
uint32_t hitCount{ 4 };
auto handleCount = 1 + missCount + hitCount;
uint32_t handleSize = m_rtProperties.shaderGroupHandleSize;
// The SBT (buffer) need to have starting groups to be aligned and handles in the group to be aligned.
uint32_t handleSizeAligned = nvh::align_up(handleSize, m_rtProperties.shaderGroupHandleAlignment);
m_rgenRegion.stride = nvh::align_up(handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
m_rgenRegion.size = m_rgenRegion.stride; // The size member of pRayGenShaderBindingTable must be equal to its stride member
m_missRegion.stride = handleSizeAligned;
m_missRegion.size = nvh::align_up(missCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
m_hitRegion.stride = handleSizeAligned;
m_hitRegion.size = nvh::align_up(hitCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
// Get the shader group handles
uint32_t dataSize = handleCount * handleSize;
std::vector<uint8_t> handles(dataSize);
auto result = vkGetRayTracingShaderGroupHandlesKHR(m_device, m_pipeline, 0, handleCount, dataSize, handles.data());
assert(result == VK_SUCCESS);
// Allocate a buffer for storing the SBT.
VkDeviceSize sbtSize = m_rgenRegion.size + m_missRegion.size + m_hitRegion.size + m_callRegion.size;
m_SBTBuffer = m_alloc->createBuffer(sbtSize,
VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT
| VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR,
VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
//m_debug.setObjectName(m_rtSBTBuffer.buffer, std::string("SBT")); // Give it a debug name for NSight.
// Find the SBT addresses of each group
VkBufferDeviceAddressInfo info{ VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, nullptr, m_SBTBuffer.buffer };
VkDeviceAddress sbtAddress = vkGetBufferDeviceAddress(m_device, &info);
m_rgenRegion.deviceAddress = sbtAddress;
m_missRegion.deviceAddress = sbtAddress + m_rgenRegion.size;
m_hitRegion.deviceAddress = sbtAddress + m_rgenRegion.size + m_missRegion.size;
// Helper to retrieve the handle data
auto getHandle = [&](int i) { return handles.data() + i * handleSize; };
// Map the SBT buffer and write in the handles.
auto* pSBTBuffer = reinterpret_cast<uint8_t*>(m_alloc->map(m_SBTBuffer));
uint8_t* pData{ nullptr };
uint32_t handleIdx{ 0 };
// Raygen
pData = pSBTBuffer;
memcpy(pData, getHandle(handleIdx++), handleSize);
// Miss
pData = pSBTBuffer + m_rgenRegion.size;
for (uint32_t c = 0; c < missCount; c++)
{
memcpy(pData, getHandle(handleIdx++), handleSize);
pData += m_missRegion.stride;
}
// Hit
pData = pSBTBuffer + m_rgenRegion.size + m_missRegion.size;
for (uint32_t c = 0; c < hitCount; c++)
{
memcpy(pData, getHandle(handleIdx++), handleSize);
pData += m_hitRegion.stride;
}
m_alloc->unmap(m_SBTBuffer);
m_alloc->finalizeAndReleaseStaging();
}
/// Records the ReSTIR ray-tracing dispatch into cmdBuf: a full barrier, the pipeline
/// and descriptor sets 0..3, then a m_size.width x m_size.height x 1 traceRaysKHR
/// using the SBT regions filled by _createShaderBindingTable.
void RestirPass::run(const vk::CommandBuffer& cmdBuf, const vk::DescriptorSet& uniformDescSet, const vk::DescriptorSet& sceneDescSet, const vk::DescriptorSet& lightDescSet, const vk::DescriptorSet& restirDescSet) {
  // An execution-only dependency does not make earlier writes available/visible,
  // so attach a global memory barrier: all prior writes -> all later reads/writes.
  const vk::MemoryBarrier fullBarrier{ vk::AccessFlagBits::eMemoryWrite,
                                       vk::AccessFlagBits::eMemoryRead | vk::AccessFlagBits::eMemoryWrite };
  cmdBuf.pipelineBarrier(
    vk::PipelineStageFlagBits::eAllCommands,
    vk::PipelineStageFlagBits::eAllCommands,
    {}, fullBarrier, {}, {}
  );

  cmdBuf.bindPipeline(vk::PipelineBindPoint::eRayTracingKHR, m_pipeline);
  cmdBuf.bindDescriptorSets(vk::PipelineBindPoint::eRayTracingKHR, m_pipelineLayout, 0,
                            { uniformDescSet, sceneDescSet, lightDescSet, restirDescSet }, {});

  // Region addresses/strides were computed once in _createShaderBindingTable.
  cmdBuf.traceRaysKHR(
    &m_rgenRegion,
    &m_missRegion,
    &m_hitRegion,
    &m_callRegion,
    m_size.width,
    m_size.height,
    1);
}
/// Builds the BLASes, the TLAS and the primitive lookup buffer (m_primlooks).
///
/// BLAS order: one BLAS per glTF primitive mesh (only when VOLUME_RESTIR_USE_GLTF is
/// defined), then the sphere's BLAS last. primLookup is filled in the same order, so
/// the sphere's BLAS index is also its primLookup index. The sphere instance's
/// instanceCustomId (gl_InstanceCustomIndexEXT) is set to that index, matching how the
/// glTF instances use node.primMesh for both blasId and instanceCustomId.
///
/// @param gltfScene  scene to convert; only read when VOLUME_RESTIR_USE_GLTF is defined.
void _createRtBuffer([[maybe_unused]] const nvh::GltfScene& gltfScene) {
  nvvk::CommandPool cmdBufGet(m_device, m_graphicsQueueIndex);
  vk::CommandBuffer cmdBuf = cmdBufGet.createCommandBuffer();

  auto properties = m_physicalDevice.getProperties2<vk::PhysicalDeviceProperties2,
                                                    vk::PhysicalDeviceRayTracingPipelinePropertiesKHR>();
  m_rtProperties = properties.get<vk::PhysicalDeviceRayTracingPipelinePropertiesKHR>();
  m_rtBuilder.setup(m_device, m_alloc, m_graphicsQueueIndex);

  // ---- BLAS: each glTF primitive mesh in its own BLAS, then the sphere ----
  std::vector<nvvk::RaytracingBuilderKHR::BlasInput> allBlas;
#ifdef VOLUME_RESTIR_USE_GLTF
  allBlas.reserve(gltfScene.m_primMeshes.size() + 1);
  for (auto& primMesh : gltfScene.m_primMeshes)
  {
    auto geo = _primitiveToGeometry(m_device, primMesh);
    allBlas.push_back({ geo });
  }
#else
  allBlas.reserve(1);
#endif
  {
    // NOTE(review): the intersection shader only runs for AABB geometry; confirm that
    // sphereAabbToVkGeometryKHR (defined elsewhere) builds a VK_GEOMETRY_TYPE_AABBS_KHR
    // geometry with valid min <= max bounds.
    nvvk::RaytracingBuilderKHR::BlasInput blas = sphereAabbToVkGeometryKHR();
    allBlas.push_back(blas);
  }
  const auto sphereIndex = static_cast<uint32_t>(allBlas.size() - 1);
  m_rtBuilder.buildBlas(allBlas, vk::BuildAccelerationStructureFlagBitsKHR::ePreferFastTrace);

  // ---- TLAS: one instance per glTF node, then the sphere instance ----
  std::vector<nvvk::RaytracingBuilderKHR::Instance> tlas;
#ifdef VOLUME_RESTIR_USE_GLTF
  tlas.reserve(gltfScene.m_nodes.size() + 1);  // +1 for the sphere instance
  for (auto& node : gltfScene.m_nodes)
  {
    nvvk::RaytracingBuilderKHR::Instance rayInst;
    rayInst.transform        = node.worldMatrix;
    rayInst.instanceCustomId = node.primMesh;  // gl_InstanceCustomIndexEXT: to find which primitive
    rayInst.blasId           = node.primMesh;
    rayInst.flags            = VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR;
    rayInst.hitGroupId       = 0;  // hit record 0 (a TRIANGLES group in RestirPass::createPipeline)
    tlas.emplace_back(rayInst);
  }
#else
  tlas.reserve(1);
#endif
  {
    nvvk::RaytracingBuilderKHR::Instance sphereInst;
    sphereInst.instanceCustomId = sphereIndex;  // same index as the sphere's primLookup entry below
    sphereInst.transform        = nvmath::mat4f(1);
    sphereInst.blasId           = sphereIndex;
    // The facing-cull flag only affects triangle geometry; it is harmless for AABBs.
    sphereInst.flags = VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR;
    sphereInst.mask  = 0xFF;
    // Hit record 1, which RestirPass::createPipeline makes a PROCEDURAL group
    // (closest hit + intersection). The traceRayEXT sbtRecordOffset/stride are
    // added on top of this, so they must also land on a PROCEDURAL record.
    sphereInst.hitGroupId = 1;
    tlas.emplace_back(sphereInst);
  }
  m_rtBuilder.buildTlas(tlas, vk::BuildAccelerationStructureFlagBitsKHR::ePreferFastTrace);

  // ---- Primitive lookup: one entry per BLAS, in BLAS order ----
  // Presumably indexed by gl_InstanceCustomIndexEXT in the hit shaders (TODO confirm).
  std::vector<shader::RtPrimitiveLookup> primLookup;
  primLookup.reserve(static_cast<size_t>(sphereIndex) + 1);
#ifdef VOLUME_RESTIR_USE_GLTF
  for (auto& primMesh : gltfScene.m_primMeshes)
    primLookup.push_back({ primMesh.firstIndex, primMesh.vertexOffset, primMesh.materialIndex });
#endif
  // Sphere entry (value-initialized so no field is left indeterminate).
  shader::RtPrimitiveLookup sphereData{};
  sphereData.indexOffset  = 0;
  sphereData.vertexOffset = 0;
#ifdef VOLUME_RESTIR_USE_GLTF
  // NOTE(review): this is one past the last glTF material; confirm the material buffer has an entry there.
  sphereData.materialIndex = static_cast<decltype(sphereData.materialIndex)>(gltfScene.m_materials.size());
#else
  sphereData.materialIndex = 0;
#endif
  primLookup.push_back(sphereData);
  assert(primLookup.size() == static_cast<size_t>(sphereIndex) + 1 && "primLookup must be parallel to allBlas");

  m_primlooks = m_alloc->createBuffer(cmdBuf, primLookup, vk::BufferUsageFlagBits::eStorageBuffer);
  cmdBufGet.submitAndWait(cmdBuf);
  m_alloc->finalizeAndReleaseStaging();
}