tensorflow-metal 0.3.0 error when training model

I installed TensorFlow (version 2.7.0) as per the instructions here and ran an example Keras model (https://keras.io/examples/vision/3D_image_classification). Here is the full Colab code: (https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/3D_image_classification.ipynb).

The error occurs when I train the model:

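from tensorflow import keras

# `model`, `train_dataset` and `validation_dataset` are defined earlier in the Keras example.
# Compile model.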
initial_learning_rate = 0.0001
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=100000, decay_rate=0.96, staircase=True
)
model.compile(
    loss="binary_crossentropy",
    optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
    metrics=["acc"],
)

# Define callbacks.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    "3d_image_classification.h5", save_best_only=True
)
early_stopping_cb = keras.callbacks.EarlyStopping(monitor="val_acc", patience=15)

# Train the model, doing validation at the end of each epoch.
epochs = 100

model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=epochs,
    shuffle=True,
    verbose=2,
    callbacks=[checkpoint_cb, early_stopping_cb],
)

The error log is quite long, so I attached it as a txt file. The error occurs only when I train the model on the GPU. It does not occur when I wrap training in `with tf.device('/cpu:0'):`. CPU training works and gives results similar to those I get on Google Colab, so the issue appears to be specific to tensorflow-metal. My tensorflow-metal version is 0.3.0.
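The CPU fallback described above can be sketched as follows. This is a minimal, self-contained sketch, not the Keras example itself: the random data and the small `Conv3D` model are hypothetical stand-ins for the CT-scan volumes and the larger 3D CNN, and the model is built inside the device scope so that its variables land on the CPU along with the training ops.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical stand-in data: 8 random single-channel volumes of shape
# (depth, height, width, channels) with binary labels.
rng = np.random.default_rng(0)
x = rng.random((8, 16, 16, 16, 1), dtype=np.float32)
y = rng.integers(0, 2, size=(8, 1)).astype("float32")
train_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(4)

# Ops and variables created inside this block are placed on the CPU
# rather than on the Metal GPU device.
with tf.device("/cpu:0"):
    # Hypothetical small 3D CNN standing in for the example's larger model.
    model = keras.Sequential([
        keras.Input(shape=(16, 16, 16, 1)),
        keras.layers.Conv3D(8, kernel_size=3, activation="relu"),
        keras.layers.BatchNormalization(),
        keras.layers.GlobalAveragePooling3D(),
        keras.layers.Dense(1, activation="sigmoid"),
    ])
    model.compile(
        loss="binary_crossentropy",
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
        metrics=["acc"],
    )
    history = model.fit(train_dataset, epochs=2, verbose=0)

print(history.history["loss"])
```

Training on the CPU is much slower for full-size 3D volumes, so this is a stopgap for confirming that the model itself is fine, not a fix for the GPU error.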

This is very useful, thanks!

This problem persists with TensorFlow 2.10 and tensorflow-metal 0.6.0. I'm training a 3D U-Net; the input to the batch norm layer is 5-dimensional, and I get the same error. The same model trains fine on Windows.
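To check whether batch normalization on 5-dimensional input alone triggers the failure, one option is an isolated test like the sketch below. It assumes a hypothetical feature map of shape (batch, depth, height, width, channels) = (2, 8, 8, 8, 4), runs one training-mode forward and backward pass of `BatchNormalization` on the CPU, and repeats it on the GPU only if TensorFlow reports one, so the two results can be compared. `batchnorm_step` is an illustrative helper, not a TensorFlow API.

```python
import numpy as np
import tensorflow as tf

# Hypothetical 5D input shaped like a 3D U-Net feature map:
# (batch, depth, height, width, channels).
rng = np.random.default_rng(0)
x = rng.random((2, 8, 8, 8, 4), dtype=np.float32)
# Fixed random weights so the gradient w.r.t. the inputs is not trivially zero.
w = rng.random((2, 8, 8, 8, 4), dtype=np.float32)


def batchnorm_step(device_name):
    """Run one training-mode BatchNormalization forward/backward pass on `device_name`."""
    with tf.device(device_name):
        layer = tf.keras.layers.BatchNormalization()
        inputs = tf.constant(x)
        with tf.GradientTape() as tape:
            tape.watch(inputs)
            outputs = layer(inputs, training=True)
            loss = tf.reduce_sum(outputs * w)
        # Gradients w.r.t. the input and the layer's trainable gamma/beta.
        grads = tape.gradient(loss, [inputs] + list(layer.trainable_variables))
    return outputs.numpy(), [g.numpy() for g in grads]


cpu_out, cpu_grads = batchnorm_step("/cpu:0")

if tf.config.list_physical_devices("GPU"):
    gpu_out, gpu_grads = batchnorm_step("/gpu:0")
    print("max |CPU - GPU| output difference:", np.abs(cpu_out - gpu_out).max())
else:
    print("No GPU visible to TensorFlow; only the CPU pass was run.")
```

If the GPU call raises the same error as full training, that points at the batch norm kernel for 5D tensors rather than at the rest of the model.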

Hello, I seem to be hitting the same issue with tensorflow-macos 2.9.0 and tensorflow-metal 0.5.0. I am working on a time-sensitive project, so I was wondering whether there are any known fixes. Thanks.
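When reporting which versions are affected, as in the replies here, a stdlib-only sketch like the following prints what is actually installed. The distribution names are those published on PyPI (`tensorflow-macos`, `tensorflow-metal`, `tensorflow`); `installed_version` is a hypothetical helper, not part of TensorFlow.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version of `dist_name`, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


# Distribution names as published on PyPI.
for name in ("tensorflow-macos", "tensorflow-metal", "tensorflow"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```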
