Hi all, I am stuck on a kernel that works on a 2014 MacBook, a 2014 Mac mini, and a 2012 iMac, but doesn't work on the DTK, an M1 Mac mini, or an M1 MacBook Air.
It's a simple chained prefix sum ("StreamScan", for anyone who cares: 10.1145/2442516.2442539). Specifically, the first threadgroup does a prefix sum on its items, then updates a global array with its inclusive sum and sets a flag to report that it is finished and that the inclusive sum is available. Subsequent threadgroups wait on the flag and, when they see that the prior threadgroup is done, update their own global inclusive sum and status, and so on.
Nvidia showed that a similar but slightly more complex strategy (called decoupled look-back) was faster than multi-level scans, which is why I was trying it out.
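As a reference for the expected output, the chain described above can be sketched serially in C++. This is only a CPU stand-in, not the Metal code: the function name `chained_exclusive_scan_reference` and the `group_size` parameter (standing in for the threadgroup size) are made up for illustration, and the status flag is implicit because a serial loop always finishes group g - 1 before it starts group g.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical serial reference for the chained exclusive scan.
// group_size stands in for threads_per_threadgroup and is assumed to be > 0.
std::vector<std::uint32_t>
chained_exclusive_scan_reference(const std::vector<std::uint32_t>& input,
                                 std::size_t group_size)
{
    const std::size_t n = input.size();
    const std::size_t group_count = (n + group_size - 1) / group_size;

    std::vector<std::uint32_t> output(n, 0);
    // One published running total per group, like inclusive_sums in the kernel.
    std::vector<std::uint32_t> inclusive_sums(group_count, 0);

    for (std::size_t group = 0; group < group_count; ++group) {
        const std::size_t begin = group * group_size;
        const std::size_t end = (begin + group_size < n) ? begin + group_size : n;

        // On the GPU, this is the point where a group waits for its
        // predecessor's flag. Serially, the predecessor has already published.
        const std::uint32_t prefix = (group == 0) ? 0 : inclusive_sums[group - 1];

        // Local exclusive scan of this group's items, offset by the prefix.
        std::uint32_t running = 0;
        for (std::size_t i = begin; i < end; ++i) {
            output[i] = prefix + running;
            running += input[i];
        }

        // Publish this group's inclusive total for the next group in the chain.
        inclusive_sums[group] = prefix + running;
    }
    return output;
}
```

For example, calling it with the input {1, 2, 3, 4, 5} and a group size of 2 walks three groups and publishes the inclusive totals 3, 10, and 15 along the way. On the GPU, each of those hand-offs depends on the next threadgroup actually observing the previous one's write.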
#include <metal_stdlib>
using namespace metal;

kernel void
ChainedPrefixExclusiveSum(device uint* output_data,
                          volatile device uchar* status_flags,
                          volatile device uint* inclusive_sums,
                          const device uint* input_data,
                          constant uint& n,
                          uint global_id [[thread_position_in_grid]],
                          uint group_id [[threadgroup_position_in_grid]],
                          ushort local_id [[thread_position_in_threadgroup]],
                          ushort local_size [[threads_per_threadgroup]],
                          ushort simd_size [[threads_per_simdgroup]],
                          ushort simd_lane_id [[thread_index_in_simdgroup]],
                          ushort simd_group_id [[simdgroup_index_in_threadgroup]])
{
    // we are going to load from global into registers
    uint value = (global_id < n) ? input_data[global_id] : 0;

    // simple but slow - scan by warp and save inclusive sum into shared mem
    uint scan = simd_prefix_exclusive_sum(value);
    threadgroup uint local_partial_sums[32];
    if (simd_lane_id == simd_size - 1)
        local_partial_sums[simd_group_id] = scan + value;
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // scan the partial sums in shared mem and save local inclusive sum
    threadgroup uint prefix = 0;
    if (simd_group_id == 0) {
        uint aggregate = local_partial_sums[simd_size - 1];
        local_partial_sums[simd_lane_id] = simd_prefix_exclusive_sum(local_partial_sums[simd_lane_id]);
        aggregate += local_partial_sums[simd_size - 1];
        if (simd_lane_id == 0) {
            //==================================================//
            // THE WHILE LOOP HERE IS THE PROBLEM:
            if (group_id != 0) while (status_flags[group_id - 1] == 0) {}
            //==================================================//
            prefix = (group_id == 0) ? 0 : inclusive_sums[group_id - 1];
            inclusive_sums[group_id] = aggregate + prefix;
            status_flags[group_id] = 'X';
        }
    }
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // add per warp sum and prefix sum to thread scan value and store
    output_data[global_id] = scan + local_partial_sums[simd_group_id] + prefix;
}
I get an IOAF code 5 on the while loop on the DTK and M1. The debugger really isn't helping me much. Basically, it's an infinite wait because the device buffer, specifically inclusive_sums[group_id - 1], never gets a value.
Any ideas? Could this be something to do with tile memory and deferred writes? If so, is there a way to always force the write to device memory, perhaps via a buffer descriptor somewhere?
I did file a report (FB8967586), but on further reflection I was thinking that maybe this is actually TBDR working as expected, and maybe I am just not finding some API I need to use with a compute kernel. So I figured I would ask here!
Thanks in advance.