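Hi — a fix for a minor bug in the sample code is on its way. Until it lands, you can unblock yourself with the patch below.

The patch changes where the `MPSCommandBuffer` is created. The sample currently wraps the plain `MTLCommandBuffer` in a new `MPSCommandBuffer` inside each `graph.encode(to:...)` call. With the patch, `ViewController` creates the `MPSCommandBuffer` once per batch and passes that same buffer to `encodeTrainingBatch` / `encodeInferenceBatch`.

To apply it, save the diff below as `fix.diff` and run this from the repository root: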
git apply fix.diff
diff --git a/MPSGraphClassifier/MNISTClassifierGraph.swift b/MPSGraphClassifier/MNISTClassifierGraph.swift
index 0f93b6d..2d6e6ef 100644
--- a/MPSGraphClassifier/MNISTClassifierGraph.swift
+++ b/MPSGraphClassifier/MNISTClassifierGraph.swift
@@ -209,7 +209,7 @@ class MNISTClassifierGraph: NSObject {
let doubleBufferingSemaphore = DispatchSemaphore(value: 2)
// Encode training batch to command buffer using double buffering
- func encodeTrainingBatch(commandBuffer: MTLCommandBuffer,
+ func encodeTrainingBatch(commandBuffer: MPSCommandBuffer,
sourceTensorData: MPSGraphTensorData,
labelsTensorData: MPSGraphTensorData,
completion: ((Float) -> Void)?) -> MPSGraphTensorData {
@@ -237,7 +237,7 @@ class MNISTClassifierGraph: NSObject {
let feed = [sourcePlaceholderTensor: sourceTensorData,
labelsPlaceholderTensor: labelsTensorData]
- let fetch = graph.encode(to: MPSCommandBuffer(commandBuffer: commandBuffer),
+ let fetch = graph.encode(to: commandBuffer,
feeds: feed,
targetTensors: targetTrainingTensors,
targetOperations: targetTrainingOps,
@@ -247,7 +247,7 @@ class MNISTClassifierGraph: NSObject {
}
// Encode inference batch to command buffer using double buffering
- func encodeInferenceBatch(commandBuffer: MTLCommandBuffer,
+ func encodeInferenceBatch(commandBuffer: MPSCommandBuffer,
sourceTensorData: MPSGraphTensorData,
labelsTensorData: MPSGraphTensorData) -> MPSGraphTensorData {
doubleBufferingSemaphore.wait()
@@ -286,7 +286,7 @@ class MNISTClassifierGraph: NSObject {
self.doubleBufferingSemaphore.signal()
}
- let fetch = graph.encode(to: MPSCommandBuffer(commandBuffer: commandBuffer),
+ let fetch = graph.encode(to: commandBuffer,
feeds: [sourcePlaceholderTensor: sourceTensorData,
labelsPlaceholderTensor: labelsTensorData],
targetTensors: targetInferenceTensors,
diff --git a/MPSGraphClassifier/ViewController.swift b/MPSGraphClassifier/ViewController.swift
index 6efc787..2aafe88 100644
--- a/MPSGraphClassifier/ViewController.swift
+++ b/MPSGraphClassifier/ViewController.swift
@@ -247,7 +247,7 @@ class ViewController: UIViewController, CanvasDelegate {
// Run a training iteration batch
func runTrainingIterationBatch() -> MTLCommandBuffer {
- let commandBuffer = gCommandQueue.makeCommandBuffer()!
+ let commandBuffer = MPSCommandBuffer(commandBuffer: gCommandQueue.makeCommandBuffer()!)
var yLabels: MPSNDArray? = nil
let xInput = dataset.getRandomTrainingBatch(device: gDevice, batchSize: batchSize, labels: &yLabels)
@@ -297,7 +297,7 @@ class ViewController: UIViewController, CanvasDelegate {
// encoding each image
for currImageIdx in stride(from: 0, to: dataset.totalNumberOfTestImages, by: Int(batchSize)) {
- let commandBuffer = gCommandQueue.makeCommandBuffer()!
+ let commandBuffer = MPSCommandBuffer(commandBuffer: gCommandQueue.makeCommandBuffer()!)
xInput = dataset.getTrainingBatchWithDevice(device: gDevice,
batchIndex: Int(currImageIdx) / Int(batchSize),