Vulkan with glTF and procedural geometries: intersection shader not being called

Hi, I'm currently adapting the Vulkan mini path tracer tutorial to render a sphere. However, I cannot get my custom intersection shader to be called. I believe I have set up the shader modules and stages correctly, along with the SBT and the ray stride/offsets. Can I get some guidance on why the intersection shader does not seem to be called?

/// Creates the ReSTIR ray-tracing pipeline (and its pipeline layout), then builds the SBT.
///
/// Shader-group order (this is also the order _createShaderBindingTable() writes into the SBT):
///   0        raygen                  restir.rgen
///   1, 2     miss                    restir.rmiss, restirShadow.rmiss
///   3 (h0)   triangles hit group     restir.rchit
///   4 (h1)   procedural hit group    raytrace2.rchit + raytrace.rint
///   5 (h2)   triangles hit group     restir.rchit
///   6 (h3)   procedural hit group    raytrace2.rchit + raytrace.rint
/// The hit group used for an intersection is
///   instance SBT offset (hitGroupId) + sbtRecordOffset + sbtRecordStride * geometryIndex,
/// where the last two terms come from traceRayEXT. An AABB instance therefore only reaches
/// raytrace.rint when that sum lands on h1 or h3.
///
/// @param uniformDescSetLayout  layout bound at set 0
/// @param sceneDescSetLayout    layout bound at set 1
/// @param lightDescSetLayout    layout bound at set 2
/// @param restirDescSetLayout   layout bound at set 3
void RestirPass::createPipeline(const vk::DescriptorSetLayout& uniformDescSetLayout, const vk::DescriptorSetLayout& sceneDescSetLayout, const vk::DescriptorSetLayout& lightDescSetLayout, const vk::DescriptorSetLayout& restirDescSetLayout) {

	enum StageIndices
	{
		eRaygen,
		eMiss,
		eMiss2,
		eClosestHit,
		eClosestHit2,
		eIntersection,
		eClosestHitt,
		eClosestHit2t,
		eIntersectiont,
		eShaderGroupCount
	};

	std::vector<std::string> paths = defaultSearchPaths;

	// Builds one stage from a SPIR-V file. All stages share the "main" entry point.
	auto makeStage = [&](const char* spvPath, vk::ShaderStageFlagBits stageFlag) {
		vk::PipelineShaderStageCreateInfo stage;
		stage.pName  = "main";
		stage.module = nvvk::createShaderModule(m_device, nvh::loadFile(spvPath, true, paths, true));
		stage.setStage(stageFlag);
		return stage;
	};

	// All stages
	std::array<vk::PipelineShaderStageCreateInfo, eShaderGroupCount> stages{};
	stages[eRaygen]        = makeStage("src/shaders/restir.rgen.spv", vk::ShaderStageFlagBits::eRaygenKHR);
	stages[eMiss]          = makeStage("src/shaders/restir.rmiss.spv", vk::ShaderStageFlagBits::eMissKHR);
	// Invoked when a shadow ray misses the geometry: signals that no occlusion was found.
	stages[eMiss2]         = makeStage("src/shaders/restirShadow.rmiss.spv", vk::ShaderStageFlagBits::eMissKHR);
	stages[eClosestHit]    = makeStage("src/shaders/restir.rchit.spv", vk::ShaderStageFlagBits::eClosestHitKHR);
	stages[eClosestHit2]   = makeStage("src/shaders/raytrace2.rchit.spv", vk::ShaderStageFlagBits::eClosestHitKHR);
	stages[eIntersection]  = makeStage("src/shaders/raytrace.rint.spv", vk::ShaderStageFlagBits::eIntersectionKHR);
	stages[eClosestHitt]   = makeStage("src/shaders/restir.rchit.spv", vk::ShaderStageFlagBits::eClosestHitKHR);
	stages[eClosestHit2t]  = makeStage("src/shaders/raytrace2.rchit.spv", vk::ShaderStageFlagBits::eClosestHitKHR);
	stages[eIntersectiont] = makeStage("src/shaders/raytrace.rint.spv", vk::ShaderStageFlagBits::eIntersectionKHR);

	// Shader groups. Rebuild from scratch so that calling this again does not append duplicates.
	m_rtShaderGroups.clear();
	auto addGroup = [&](VkRayTracingShaderGroupTypeKHR type, uint32_t general, uint32_t closestHit, uint32_t intersection) {
		VkRayTracingShaderGroupCreateInfoKHR group{ VK_STRUCTURE_TYPE_RAY_TRACING_SHADER_GROUP_CREATE_INFO_KHR };
		group.type               = type;
		group.generalShader      = general;
		group.closestHitShader   = closestHit;
		group.anyHitShader       = VK_SHADER_UNUSED_KHR;
		group.intersectionShader = intersection;
		m_rtShaderGroups.push_back(group);
	};

	// Raygen, then the two miss groups.
	addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR, eRaygen, VK_SHADER_UNUSED_KHR, VK_SHADER_UNUSED_KHR);
	addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR, eMiss, VK_SHADER_UNUSED_KHR, VK_SHADER_UNUSED_KHR);
	addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_GENERAL_KHR, eMiss2, VK_SHADER_UNUSED_KHR, VK_SHADER_UNUSED_KHR);
	// Hit groups h0..h3. Procedural groups are the only ones that can run an intersection shader,
	// and AABB geometry must be routed to one of them.
	addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR, VK_SHADER_UNUSED_KHR, eClosestHit, VK_SHADER_UNUSED_KHR);
	addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR, VK_SHADER_UNUSED_KHR, eClosestHit2, eIntersection);
	addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR, VK_SHADER_UNUSED_KHR, eClosestHitt, VK_SHADER_UNUSED_KHR);
	addGroup(VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR, VK_SHADER_UNUSED_KHR, eClosestHit2t, eIntersectiont);

	// Descriptor set layouts; vector order gives the set numbers 0..3.
	vk::PipelineLayoutCreateInfo		 pipelineLayoutCreateInfo;
	std::vector<vk::DescriptorSetLayout> rtDescSetLayouts{ 
				uniformDescSetLayout, 
				sceneDescSetLayout, 
				lightDescSetLayout,
				restirDescSetLayout};
	pipelineLayoutCreateInfo.setSetLayouts(rtDescSetLayouts);
	m_pipelineLayout = m_device.createPipelineLayout(pipelineLayoutCreateInfo);

	vk::RayTracingPipelineCreateInfoKHR rayPipelineInfo;
	rayPipelineInfo.setStageCount(static_cast<uint32_t>(stages.size()));  // Stages are shaders
	rayPipelineInfo.setPStages(stages.data());

	rayPipelineInfo.setGroupCount(static_cast<uint32_t>(
		m_rtShaderGroups.size()));  // 1-raygen, n-miss, n-(hit[+anyhit+intersect])
	rayPipelineInfo.setPGroups(m_rtShaderGroups.data());

	rayPipelineInfo.setMaxPipelineRayRecursionDepth(2);  // Ray depth
	rayPipelineInfo.setLayout(m_pipelineLayout);
	m_pipeline = static_cast<const vk::Pipeline&>(
		m_device.createRayTracingPipelineKHR({}, {}, rayPipelineInfo));

	// Shader modules may be destroyed once the pipeline has been created.
	// Previously every module created above leaked.
	for (auto& s : stages)
		m_device.destroy(s.module);

	_createShaderBindingTable();
}

void RestirPass::_createShaderBindingTable()
{
	uint32_t missCount{ 2 };
	uint32_t hitCount{ 4 };
	auto     handleCount = 1 + missCount + hitCount;
	uint32_t handleSize = m_rtProperties.shaderGroupHandleSize;

	// The SBT (buffer) need to have starting groups to be aligned and handles in the group to be aligned.
	uint32_t handleSizeAligned = nvh::align_up(handleSize, m_rtProperties.shaderGroupHandleAlignment);

	m_rgenRegion.stride = nvh::align_up(handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
	m_rgenRegion.size	= m_rgenRegion.stride;  // The size member of pRayGenShaderBindingTable must be equal to its stride member
	m_missRegion.stride = handleSizeAligned;
	m_missRegion.size	= nvh::align_up(missCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);
	m_hitRegion.stride	= handleSizeAligned;
	m_hitRegion.size	= nvh::align_up(hitCount * handleSizeAligned, m_rtProperties.shaderGroupBaseAlignment);

	// Get the shader group handles
	uint32_t             dataSize = handleCount * handleSize;
	std::vector<uint8_t> handles(dataSize);
	auto result = vkGetRayTracingShaderGroupHandlesKHR(m_device, m_pipeline, 0, handleCount, dataSize, handles.data());
	assert(result == VK_SUCCESS);

	// Allocate a buffer for storing the SBT.
	VkDeviceSize sbtSize = m_rgenRegion.size + m_missRegion.size + m_hitRegion.size + m_callRegion.size; 
	m_SBTBuffer = m_alloc->createBuffer(sbtSize,
		VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT
		| VK_BUFFER_USAGE_SHADER_BINDING_TABLE_BIT_KHR,
		VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
	//m_debug.setObjectName(m_rtSBTBuffer.buffer, std::string("SBT"));  // Give it a debug name for NSight.

	// Find the SBT addresses of each group
	VkBufferDeviceAddressInfo info{ VK_STRUCTURE_TYPE_BUFFER_DEVICE_ADDRESS_INFO, nullptr, m_SBTBuffer.buffer };
	VkDeviceAddress           sbtAddress = vkGetBufferDeviceAddress(m_device, &info);
	m_rgenRegion.deviceAddress = sbtAddress;
	m_missRegion.deviceAddress = sbtAddress + m_rgenRegion.size;
	m_hitRegion.deviceAddress = sbtAddress + m_rgenRegion.size + m_missRegion.size;

	// Helper to retrieve the handle data
	auto getHandle = [&](int i) { return handles.data() + i * handleSize; };

	// Map the SBT buffer and write in the handles.
	auto* pSBTBuffer = reinterpret_cast<uint8_t*>(m_alloc->map(m_SBTBuffer));
	uint8_t* pData{ nullptr };
	uint32_t handleIdx{ 0 };
	// Raygen
	pData = pSBTBuffer;
	memcpy(pData, getHandle(handleIdx++), handleSize);
	// Miss
	pData = pSBTBuffer + m_rgenRegion.size;
	for (uint32_t c = 0; c < missCount; c++)
	{
		memcpy(pData, getHandle(handleIdx++), handleSize);
		pData += m_missRegion.stride;
	}
	// Hit
	pData = pSBTBuffer + m_rgenRegion.size + m_missRegion.size;
	for (uint32_t c = 0; c < hitCount; c++)
	{
		memcpy(pData, getHandle(handleIdx++), handleSize);
		pData += m_hitRegion.stride;
	}

	m_alloc->unmap(m_SBTBuffer);
	m_alloc->finalizeAndReleaseStaging();
}

/// Records the ReSTIR ray-tracing dispatch into `cmdBuf`.
///
/// Binds the pipeline and descriptor sets 0..3, then traces one ray-generation invocation per
/// pixel of m_size, using the SBT regions computed in _createShaderBindingTable().
void RestirPass::run(const vk::CommandBuffer& cmdBuf, const vk::DescriptorSet& uniformDescSet, const vk::DescriptorSet& sceneDescSet, const vk::DescriptorSet& lightDescSet, const vk::DescriptorSet& restirDescSet) {
	// Full execution dependency against previously recorded work.
	// NOTE(review): no memory/buffer/image barriers are passed, so this alone does not make earlier
	// writes visible to the shaders; confirm whether a memory barrier is needed here.
	cmdBuf.pipelineBarrier(
		vk::PipelineStageFlagBits::eAllCommands,
		vk::PipelineStageFlagBits::eAllCommands,
		{}, {}, {}, {}
	);
	cmdBuf.bindPipeline(vk::PipelineBindPoint::eRayTracingKHR, m_pipeline);
	cmdBuf.bindDescriptorSets(vk::PipelineBindPoint::eRayTracingKHR, m_pipelineLayout, 0,
		{ uniformDescSet, sceneDescSet, lightDescSet, restirDescSet }, {});

	// The SBT regions (addresses, strides, sizes) were fully computed when the SBT was built.
	// Nothing needs recomputing per dispatch.
	cmdBuf.traceRaysKHR(
		&m_rgenRegion,
		&m_missRegion,
		&m_hitRegion,
		&m_callRegion,
		m_size.width,
		m_size.height,
		1);
}

// Builds the ray-tracing acceleration structures and the per-primitive lookup buffer.
//  - BLAS: one per glTF primitive mesh (only with VOLUME_RESTIR_USE_GLTF), plus one final
//    AABB-based BLAS for the sphere (from sphereAabbToVkGeometryKHR()).
//  - TLAS: one instance per glTF node (hit group 0), plus one sphere instance (hit group 1).
//  - m_primlooks: one shader::RtPrimitiveLookup per glTF primitive mesh, followed by one entry
//    for the sphere. Each instance's instanceCustomId is the index of its entry here.
// Also queries m_rtProperties and sets up m_rtBuilder.
// ([[nodiscard]] removed: it has no meaning on a void function.)
void _createRtBuffer(const nvh::GltfScene& gltfScene) {
	nvvk::CommandPool cmdBufGet(m_device, m_graphicsQueueIndex);
	vk::CommandBuffer cmdBuf = cmdBufGet.createCommandBuffer();
	auto properties = m_physicalDevice.getProperties2<vk::PhysicalDeviceProperties2,
		vk::PhysicalDeviceRayTracingPipelinePropertiesKHR>();
	m_rtProperties = properties.get<vk::PhysicalDeviceRayTracingPipelinePropertiesKHR>();
	m_rtBuilder.setup(m_device, m_alloc, m_graphicsQueueIndex);

	// BLAS - Storing each primitive in a geometry
	std::vector<nvvk::RaytracingBuilderKHR::BlasInput> allBlas;

#ifdef VOLUME_RESTIR_USE_GLTF
	allBlas.reserve(gltfScene.m_primMeshes.size() + 1);
	for (auto& primMesh : gltfScene.m_primMeshes)
	{
		auto geo = _primitiveToGeometry(m_device, primMesh);
		allBlas.push_back({ geo });
	}
	// The sphere's primLookup entry is appended right after the glTF primitive entries.
	const uint32_t sphereLookupIndex = static_cast<uint32_t>(gltfScene.m_primMeshes.size());
#else
	allBlas.reserve(1);
	// The sphere is the only primLookup entry. Previously instanceCustomId was 1 here, which
	// indexes past the end of a one-element lookup buffer.
	const uint32_t sphereLookupIndex = 0;
#endif
	// Sphere BLAS (AABB geometry) is always the last one.
	{
		nvvk::RaytracingBuilderKHR::BlasInput blas = sphereAabbToVkGeometryKHR();
		allBlas.push_back(blas);
	}
	const uint32_t sphereBlasId = static_cast<uint32_t>(allBlas.size() - 1);

	m_rtBuilder.buildBlas(allBlas, vk::BuildAccelerationStructureFlagBitsKHR::ePreferFastTrace);

	std::vector<nvvk::RaytracingBuilderKHR::Instance> tlas;
#ifdef VOLUME_RESTIR_USE_GLTF
	tlas.reserve(gltfScene.m_nodes.size() + 1);  // +1 for the sphere instance
	for (auto& node : gltfScene.m_nodes)
	{
		nvvk::RaytracingBuilderKHR::Instance rayInst;
		rayInst.transform = node.worldMatrix;
		rayInst.instanceCustomId = node.primMesh;  // gl_InstanceCustomIndexEXT: to find which primitive
		rayInst.blasId = node.primMesh;
		rayInst.flags = VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR;
		rayInst.hitGroupId = 0;  // SBT offset 0: the triangles hit group (restir.rchit)
		tlas.emplace_back(rayInst);
	}
#else
	tlas.reserve(1);
#endif

	// Sphere instance.
	nvvk::RaytracingBuilderKHR::Instance rayInst;
	rayInst.instanceCustomId = sphereLookupIndex;  // must match the sphere's index in primLookup
	rayInst.transform = nvmath::mat4f(1);
	rayInst.blasId = sphereBlasId;
	rayInst.flags = VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT_KHR;
	rayInst.mask = 0xFF;
	// SBT offset 1: the procedural hit group (raytrace.rint + raytrace2.rchit). The final hit-group
	// index also adds the sbtRecordOffset / stride passed to traceRayEXT. That sum must land on a
	// procedural group, or raytrace.rint is never invoked for this AABB instance.
	rayInst.hitGroupId = 1;
	tlas.emplace_back(rayInst);

	m_rtBuilder.buildTlas(tlas, vk::BuildAccelerationStructureFlagBitsKHR::ePreferFastTrace);

	// Per-primitive lookup table: index offset, vertex offset and material index.
	// Entry i belongs to the instance(s) whose instanceCustomId is i.
	// NOTE(review): presumably read in the shaders via gl_InstanceCustomIndexEXT; confirm.
	std::vector<shader::RtPrimitiveLookup> primLookup;

#ifdef VOLUME_RESTIR_USE_GLTF
	primLookup.reserve(gltfScene.m_primMeshes.size() + 1);
	for (auto& primMesh : gltfScene.m_primMeshes)
		primLookup.push_back({ primMesh.firstIndex, primMesh.vertexOffset, primMesh.materialIndex });
#endif

	// Sphere data (value-initialized so any fields not set below are zero).
	shader::RtPrimitiveLookup sphereData{};
	sphereData.indexOffset = 0;
	sphereData.vertexOffset = 0;
#ifdef VOLUME_RESTIR_USE_GLTF
	// NOTE(review): one past the glTF materials; assumes a sphere material is appended to the
	// material buffer elsewhere. Confirm.
	sphereData.materialIndex = gltfScene.m_materials.size();
#else
	sphereData.materialIndex = 0;
#endif
	assert(primLookup.size() == sphereLookupIndex && "sphere instanceCustomId must index its primLookup entry");
	primLookup.push_back(sphereData); //TODO Push Custom Stuff for sphere here
	m_primlooks = m_alloc->createBuffer(cmdBuf, primLookup, vk::BufferUsageFlagBits::eStorageBuffer);

	cmdBufGet.submitAndWait(cmdBuf);
	m_alloc->finalizeAndReleaseStaging();
}