How to efficiently perform compute reductions?

So I’ve got 2 ways to do a reduction across a large buffer.

It’s a little complex, so I think it’s best to just give the GLSL (these specific cases compute the sum of the absolute values of an array).

The 1st involves restricting the number of invocations to 1 workgroup to allow synchronisation:

#version 450
#extension GL_KHR_shader_subgroup_arithmetic : enable

layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) buffer Buffer0 {
    float x[];
};
layout(binding = 1) buffer Output {
    float total; // total sum
};
// Length of `x`
layout(push_constant) uniform PushConsts {
    uint n; // n=len(x)
};

shared float sums[1024]; // gl_WorkGroupSize.x = 1024
shared float sdata[16]; // gl_WorkGroupSize.x / gl_SubgroupSize = 1024 / 64 = 16 (assumes a subgroup size of 64)

// This should only be called with 1 workgroup
// gl_LocalInvocationID.x === gl_GlobalInvocationID.x
void main() {
    uint indx = gl_LocalInvocationID.x;
    float sum = 0;

    // n -> gl_WorkGroupSize.x
    // ---------------------------
    const uint elementsPer = n / gl_WorkGroupSize.x;
    for(uint i=0;i<elementsPer;++i) {
        sum += abs(x[elementsPer * indx + i]);
    }
    barrier();

    // gl_WorkGroupSize.x -> 1
    // ---------------------------
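    // subgroupAdd only sums the active invocations, so it has to run outside the
    // subgroupElect() branch (inside it, only the elected invocation is active).
    float subgroupSum = subgroupAdd(sum);
    if (subgroupElect()) sdata[gl_SubgroupID] = subgroupSum;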
    barrier();

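    if (gl_SubgroupID == 0) {
        sum = gl_SubgroupInvocationID < gl_NumSubgroups ? sdata[gl_SubgroupInvocationID] : 0;
        // All of subgroup 0 takes part in the add; a single invocation writes the result.
        float workgroupSum = subgroupAdd(sum);
        if (subgroupElect()) total = workgroupSum;
    }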
}
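To check the shader's output on the host, the three stages above can be mirrored on the CPU. This is only a rough reference model, not Vulkan code: the function name `single_workgroup_abs_sum` and its parameters are made up for illustration, and the subgroup size of 64 is the same assumption the shader's `sdata[16]` makes. It also reports how many trailing elements the truncating `n / gl_WorkGroupSize.x` division leaves unread when `n` is not a multiple of the workgroup size.

```python
def single_workgroup_abs_sum(x, w=1024, subgroup_size=64):
    """CPU model of approach 1 (hypothetical helper, not part of any Vulkan API).

    Returns (total, skipped), where `skipped` is the number of trailing
    elements that the truncating n / w division never visits.
    """
    n = len(x)
    elements_per = n // w  # same truncating division as the shader

    # Stage 1 (n -> w): invocation i sums its own contiguous chunk.
    per_invocation = [
        sum(abs(v) for v in x[i * elements_per:(i + 1) * elements_per])
        for i in range(w)
    ]

    # Stage 2 (w -> w / subgroup_size): one partial sum per subgroup, as in sdata[].
    sdata = [
        sum(per_invocation[s:s + subgroup_size])
        for s in range(0, w, subgroup_size)
    ]

    # Stage 3: subgroup 0 adds the per-subgroup partials.
    total = sum(sdata)
    skipped = n - elements_per * w
    return total, skipped


# Small sizes so the result is easy to check by hand.
print(single_workgroup_abs_sum([-1.0] * 8, w=4, subgroup_size=2))   # (8.0, 0)
print(single_workgroup_abs_sum([-1.0] * 10, w=4, subgroup_size=2))  # (8.0, 2)
```

The second call shows the remainder case: with 10 elements and 4 invocations, each invocation reads 2 elements and the last 2 are never summed.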

The 2nd involves sequentially applying the shader (each time reducing the data by the workgroup size):

#version 450
#extension GL_KHR_shader_subgroup_arithmetic : enable

layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;

layout(binding = 0) buffer Buffer0 {
    float x[];
};
layout(binding = 1) buffer Output {
    float total[]; // total sum
};

shared float sdata[16]; // gl_WorkGroupSize.x / gl_SubgroupSize = 1024 / 64 = 16 (assumes a subgroup size of 64)

void main() {
    uint indx = gl_GlobalInvocationID.x;
    float sum = abs(x[gl_GlobalInvocationID.x]);

    // len(x) -> len(x) / gl_WorkGroupSize.x
    // ---------------------------
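    // subgroupAdd only sums the active invocations, so it has to run outside the
    // subgroupElect() branch (inside it, only the elected invocation is active).
    float subgroupSum = subgroupAdd(sum);
    if (subgroupElect()) sdata[gl_SubgroupID] = subgroupSum;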
    barrier();

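    if (gl_SubgroupID == 0) {
        sum = gl_SubgroupInvocationID < gl_NumSubgroups ? sdata[gl_SubgroupInvocationID] : 0;
        // All of subgroup 0 takes part in the add; a single invocation writes the result.
        float workgroupSum = subgroupAdd(sum);
        if (subgroupElect()) total[gl_WorkGroupID.x] = workgroupSum;
    }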
}

I have yet to properly implement the 2nd, as I’m struggling with how to efficiently chain shaders the way this approach requires (but from my admittedly poor knowledge right now, I think it is possible, right?).
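For what it's worth, the usual way to chain compute passes in Vulkan is to record one vkCmdDispatch per pass into a single command buffer, with a compute-to-compute memory barrier (vkCmdPipelineBarrier) between consecutive dispatches, and with the output buffer of one pass bound as the input of the next. The sketch below is not Vulkan code; it only models, on the CPU, which dispatches such a chain would need and what each one produces. The names `plan_passes`, `reduce_pass` and `multi_pass_abs_sum` are hypothetical helpers for illustration.

```python
def plan_passes(n, w):
    """Return (input_length, workgroup_count) for each dispatch in the chain."""
    assert w >= 2, "each pass must shrink the data"
    passes = []
    length = n
    while length > 0:
        groups = (length + w - 1) // w  # ceil(length / w) workgroups
        passes.append((length, groups))
        if groups == 1:
            break  # a single workgroup wrote the final total
        length = groups  # next pass reads one partial sum per workgroup
    return passes


def reduce_pass(values, w):
    """Emulate one dispatch: each workgroup writes the sum of |x| over its slice."""
    return [sum(abs(v) for v in values[i:i + w]) for i in range(0, len(values), w)]


def multi_pass_abs_sum(values, w):
    """Run the whole chain on the CPU, feeding each pass's output to the next."""
    data = list(values)
    for _length, _groups in plan_passes(len(data), w):
        data = reduce_pass(data, w)
    return data[0] if data else 0.0


print(plan_passes(1024 ** 3, 1024))
# [(1073741824, 1048576), (1048576, 1024), (1024, 1)]
print(multi_pass_abs_sum([1.0, -2.0, 3.0, -4.0, 5.0], 2))  # 15.0
```

Two details worth noting under these assumptions: the last workgroup of a pass may be only partly filled, so the shader would need a bounds check on its read (the model gets this for free from slicing), and re-applying abs() on later passes is harmless because the partial sums are already non-negative.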

So for a bit of analysis, the O’notation of these 2 approaches is:

  1. O(log2(w) + n/w)
  2. O(log2(n))

Where n equals the number of elements and w equals the workgroup size.

Here I’m using the O’notation to mean the maximum number of elements any single invocation needs to reduce (in this case, sum).

With a large buffer the difference in the maximum work per invocation becomes notable:

  • Given w=1024, n=1024^3
  • O(log2(w) + n/w) = 10 + 1024^2
  • O(log2(n)) = 30
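The figures above can be reproduced with a few lines of arithmetic. This is just a sketch of the cost model as defined in this post (maximum elements reduced by any one invocation), with hypothetical function names; it says nothing about real GPU timings, memory bandwidth, or dispatch overhead.

```python
import math


def cost_single_workgroup(n, w):
    """Approach 1: each invocation sums ceil(n / w) elements, then log2(w) tree steps."""
    return (n + w - 1) // w + math.ceil(math.log2(w))


def cost_multi_pass(n, w):
    """Approach 2: about log_w(n) passes, each costing about log2(w) steps."""
    passes = 0
    remaining = n
    while remaining > 1:
        remaining = (remaining + w - 1) // w  # one partial sum per workgroup
        passes += 1
    return passes * math.ceil(math.log2(w))


w, n = 1024, 1024 ** 3
print(cost_single_workgroup(n, w))  # 1048586  (10 + 1024^2)
print(cost_multi_pass(n, w))        # 30       (3 passes * 10)
```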

So this leads me to think the expense of sequentially running shaders may be worth it.

  • Is my O’notation sensible?
  • Are my approaches sensible?
  • Could the 2nd be more efficient?

And most importantly, does anyone know of any examples of how one might implement the 2nd?

I’ve been struggling with this for a while, so I made a template to learn from: JonathanWoollett-Light/vulkan-blas-l1 (github.com), with the goal of running set.comp and then plus.comp a set number of times in sequence afterwards.

(As always, if I’m missing something, drop a comment and I’ll try to add more info. I would really appreciate any help here; this feels like the one roadblock I’m stuck behind in learning Vulkan.)

Example diagrams for each method:

  • circles represent currently active invocations
  • squares represent memory and values within
  • numbers on lines are the gl_GlobalInvocationID.x values of the invocations carrying out that operation
  • { represents workgroups (red if any invocation within it is active at any point)
  • |𑁋⏐ represents subgroups (red if any invocation within it is active at any point)
  • Dotted line in 2 represents 1 shader finishing and the next beginning.

1.

2.

(here since I can’t include 2 embedded things in 1)