I'm trying to implement a PyTorch custom layer, [grid_sampler](https://pytorch.org/docs/1.9.1/generated/torch.nn.functional.grid_sample.html), on the GPU. Both of its inputs, `input` and `grid`, can be 5-D. My implementation of `encodeToCommandBuffer:`, a method of the `MLCustomLayer` protocol, is shown below. In my attempts so far, neither the value of `id<MTLTexture> input` nor that of `id<MTLTexture> grid` matches what I expect. So I wonder: can an `MTLTexture` be used to store a 5-D input tensor passed to `encodeToCommandBuffer:`? For example, a 5-D tensor of shape (N, C, D, H, W) has two dimensions beyond what a texture's width/height/depth can represent, so it is unclear how Core ML lays it out. Or can anybody show me how to use `MTLTexture` correctly here? Thanks a lot!
/// MLCustomLayer GPU entry point: encodes the grid_sample compute kernel onto
/// the supplied command buffer.
///
/// @param commandBuffer The command buffer to encode work onto (provided by Core ML).
/// @param inputs  Expected: inputs[0] = input tensor texture, inputs[1] = grid texture.
/// @param outputs Expected: outputs[0] = output texture.
/// @param error   Out-parameter; populated only on failure. May be NULL.
/// @return YES when encoding succeeded, NO otherwise (with *error set if non-NULL).
///
/// NOTE(review): an MTLTexture exposes at most width/height/depth plus
/// arrayLength, so a 5-D tensor cannot map one-to-one onto a single texture —
/// presumably Core ML folds the leading dimensions into arrayLength/slices,
/// which may be why the observed values "don't meet expectations". Confirm the
/// exact layout Core ML hands to MLCustomLayer before indexing in the kernel.
- (BOOL)encodeToCommandBuffer:(id<MTLCommandBuffer>)commandBuffer
                       inputs:(NSArray<id<MTLTexture>> *)inputs
                      outputs:(NSArray<id<MTLTexture>> *)outputs
                        error:(NSError * _Nullable *)error {
    NSLog(@"Dispatching to GPU");
    NSLog(@"inputs count %lu", (unsigned long)inputs.count);
    NSLog(@"outputs count %lu", (unsigned long)outputs.count);

    id<MTLComputeCommandEncoder> encoder =
        [commandBuffer computeCommandEncoderWithDispatchType:MTLDispatchTypeSerial];
    if (!encoder) {
        // Cocoa error contract: return NO AND populate *error (when the caller
        // passed a pointer) instead of failing silently. The original code
        // combined a debug-only assert with an unreported NO return.
        if (error) {
            *error = [NSError errorWithDomain:@"GridSampleCustomLayer"
                                         code:-1
                                     userInfo:@{
                NSLocalizedDescriptionKey :
                    @"Failed to create a compute command encoder for grid_sample."
            }];
        }
        return NO;
    }

    id<MTLTexture> input = inputs[0];
    id<MTLTexture> grid = inputs[1];
    id<MTLTexture> output = outputs[0];
    NSLog(@"inputs shape %lu, %lu, %lu, %lu", (unsigned long)input.width, (unsigned long)input.height, (unsigned long)input.depth, (unsigned long)input.arrayLength);
    NSLog(@"grid shape %lu, %lu, %lu, %lu", (unsigned long)grid.width, (unsigned long)grid.height, (unsigned long)grid.depth, (unsigned long)grid.arrayLength);

    [encoder setComputePipelineState:grid_sample_Pipeline];
    [encoder setTexture:input atIndex:0];
    [encoder setTexture:grid atIndex:1];
    [encoder setTexture:output atIndex:2];

    // Threadgroup sizing: fill one SIMD-width row, stack rows up to the
    // pipeline's per-threadgroup thread limit.
    NSUInteger wd = grid_sample_Pipeline.threadExecutionWidth;
    NSUInteger ht = grid_sample_Pipeline.maxTotalThreadsPerThreadgroup / wd;
    MTLSize threadsPerThreadgroup = MTLSizeMake(wd, ht, 1);
    // Round up so every texel is covered. NOTE(review): the grid is sized from
    // the INPUT texture; if the kernel writes one thread per OUTPUT texel,
    // size from `output` instead — verify against the kernel's indexing.
    MTLSize threadgroupsPerGrid = MTLSizeMake((input.width + wd - 1) / wd,
                                              (input.height + ht - 1) / ht,
                                              input.arrayLength);
    [encoder dispatchThreadgroups:threadgroupsPerGrid
            threadsPerThreadgroup:threadsPerThreadgroup];
    [encoder endEncoding];

    // Only dereference the error out-parameter if the caller supplied one.
    if (error) {
        *error = nil;
    }
    return YES;
}