How to optimize ReductionSum with Metal?

I would like to write a ReductionSum Metal Shader like this:

https://github.com/alibaba/MNN/blob/master/source/backend/metal/MetalReduction.metal#L32

Sometimes the reduced dimension is large while the other dimensions are small, so only a few threads can be launched and the kernel runs inefficiently.

Is there any way to optimize it?

I suggest reading the comprehensive article "Optimizing Parallel Reduction in Metal for Apple M1" by Matthew Kieber-Emmons: https://kieber-emmons.medium.com/optimizing-parallel-reduction-in-metal-for-apple-m1-8e8677b49b01
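
To make the general idea concrete, here is a minimal sketch of a cooperative reduction. It is not taken from the article or from MNN. It assumes the input is laid out as [outer, axis] in row-major order with an inner size of 1, that one threadgroup is dispatched per output element, and that the threadgroup size is a power of two no larger than 256. The kernel name, buffer indices, and the `axis_size` parameter are hypothetical.

```metal
#include <metal_stdlib>
using namespace metal;

// Assumed layout: in is [outer, axis_size] row-major, out is [outer].
// Assumed dispatch: one threadgroup per output element,
// threadgroup size is a power of two and <= 256.
kernel void reduce_sum_axis(device const float *in        [[buffer(0)]],
                            device float       *out       [[buffer(1)]],
                            constant uint      &axis_size [[buffer(2)]],
                            uint tg_id   [[threadgroup_position_in_grid]],
                            uint tid     [[thread_index_in_threadgroup]],
                            uint tg_size [[threads_per_threadgroup]])
{
    threadgroup float partial[256];

    // Step 1: each thread sums a strided slice of the reduced dimension,
    // so many threads share the work for a single output element.
    float acc = 0.0f;
    for (uint i = tid; i < axis_size; i += tg_size) {
        acc += in[tg_id * axis_size + i];
    }
    partial[tid] = acc;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // Step 2: tree reduction of the per-thread partial sums in threadgroup memory.
    for (uint s = tg_size / 2; s > 0; s >>= 1) {
        if (tid < s) {
            partial[tid] += partial[tid + s];
        }
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    // Step 3: thread 0 writes the result for this output element.
    if (tid == 0) {
        out[tg_id] = partial[0];
    }
}
```

On the host side this would be dispatched with `dispatchThreadgroups(_:threadsPerThreadgroup:)`, using `outer` threadgroups of, say, 256 threads each. With this arrangement the number of active threads no longer depends on how small the non-reduced dimensions are. On GPUs that support SIMD-group functions, the tree loop in step 2 could likely be shortened with `simd_sum`, which is the kind of refinement the article above goes through.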
