Batch normalization - update means and variances

Is this the right method (see below) to update means and variances in the callback updateMeanAndVarianceWithCommandBuffer:batchNormalizationState:?

Code Block
- (MPSCNNNormalizationMeanAndVarianceState *)updateMeanAndVarianceWithCommandBuffer:(id<MTLCommandBuffer>)commandBuffer
                                                            batchNormalizationState:(MPSCNNBatchNormalizationState *)batchNormalizationState
{
    // Wrap the batch statistics stored in the state's buffers in MPSVectors.
    MPSVectorDescriptor *descriptor =
        [MPSVectorDescriptor vectorDescriptorWithLength:[self featureChannels]
                                               dataType:[self dataType]];
    MPSVector *determinedMeans =
        [[MPSVector alloc] initWithBuffer:[batchNormalizationState mean]
                               descriptor:descriptor];
    MPSVector *determinedVariances =
        [[MPSVector alloc] initWithBuffer:[batchNormalizationState variance]
                               descriptor:descriptor];

    // Feed the batch statistics to the optimizers as "gradients" so they are
    // blended into the running means and variances in place.
    [[self meansOptimizer] encodeToCommandBuffer:commandBuffer
                             inputGradientVector:determinedMeans
                               inputValuesVector:[self meansVector]
                             inputMomentumVector:nil
                              resultValuesVector:[self meansVector]];
    [[self variancesOptimizer] encodeToCommandBuffer:commandBuffer
                             inputGradientVector:determinedVariances
                               inputValuesVector:[self variancesVector]
                             inputMomentumVector:nil
                              resultValuesVector:[self variancesVector]];

    // Manually decrement the read count of the batch normalization state.
    [batchNormalizationState setReadCount:[batchNormalizationState readCount] - 1];

    return [self meanAndVarianceState];
}
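
To sanity-check what the optimizers write back, the running statistics can be read on the CPU once the command buffer has completed. This is a minimal debugging sketch, not part of the callback above; it assumes the meansVector buffer is CPU-visible (shared storage; a managed buffer would need a blit synchronization first) and holds float32 elements, per the PPS below:

Code Block
[commandBuffer addCompletedHandler:^(id<MTLCommandBuffer> cb) {
    // Assumes MPSDataTypeFloat32 and a shared-storage MTLBuffer.
    const float *means = (const float *)[[self meansVector] data].contents;
    for (NSUInteger i = 0; i < [self featureChannels]; i++) {
        NSLog(@"running mean[%lu] = %f", (unsigned long)i, means[i]);
    }
}];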

The means and variances optimisers are initialised as follows:
Code Block
_meansOptimizer = [[MPSNNOptimizerStochasticGradientDescent alloc]
         initWithDevice:_device
          momentumScale:0.0
     useNestrovMomentum:NO
    optimizerDescriptor:[MPSNNOptimizerDescriptor
        optimizerDescriptorWithLearningRate:-0.1
                            gradientRescale:1.0f
                         regularizationType:MPSNNRegularizationTypeL2
                        regularizationScale:-1.0f]];

_variancesOptimizer = [[MPSNNOptimizerStochasticGradientDescent alloc]
         initWithDevice:_device
          momentumScale:0.0
     useNestrovMomentum:NO
    optimizerDescriptor:[MPSNNOptimizerDescriptor
        optimizerDescriptorWithLearningRate:-0.1
                            gradientRescale:1.0f
                         regularizationType:MPSNNRegularizationTypeL2
                        regularizationScale:-1.0f]];
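
If I read the documentation right, with zero momentum the SGD update is value ← value − learningRate · (gradientRescale · gradient + regularizationScale · value) under L2 regularization, so the negative learning rate and negative regularisation scale above should reduce to an exponential moving average of the batch statistics. A scalar sketch of that arithmetic, under that assumption (the helper name is made up):

Code Block
// Scalar model of one optimizer step with the hyperparameters above, assuming
// value -= learningRate * (gradientRescale * gradient + regularizationScale * value).
static float updatedRunningStatistic(float running, float batchStatistic) {
    const float learningRate        = -0.1f;
    const float gradientRescale     =  1.0f;
    const float regularizationScale = -1.0f;
    float effectiveGradient = gradientRescale * batchStatistic
                            + regularizationScale * running;
    // running - (-0.1) * (batch - running) = 0.9 * running + 0.1 * batch,
    // i.e. an exponential moving average with momentum 0.9.
    return running - learningRate * effectiveGradient;
}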


By using this method as found on GitHub, the callback no longer crashes, but I am not sure whether it is correct. In particular, is it OK that the read count has to be decremented manually?

PS: [self meansVector] and [self variancesVector] return MPSVector objects.
PPS: [self dataType] returns MPSDataTypeFloat32.