Keras with tensorflow-metal freezes during training with image augmentation

I am trying to train an image classification network in Keras with tensorflow-metal.

Training freezes after the first 2-3 epochs when image augmentation layers (RandomFlip, RandomContrast, RandomBrightness) are used.

The system appears to use both the GPU and the CPU (as indicated by Activity Monitor). Warnings also appear in both Jupyter and Terminal (see below).

When the image augmentation layers are removed (i.e. only the head is rebuilt and images are fed from disk), the CPU appears idle, no warnings appear, and training completes successfully.

Versions: Python 3.8, tensorflow-macos 2.11.0, tensorflow-metal 0.7.1

Sample code:

import tensorflow as tf
from tensorflow.keras import Model, Sequential, layers

img_augmentation = Sequential(
    [
        layers.RandomFlip(),
        layers.RandomBrightness(factor=0.2),
        layers.RandomContrast(factor=0.2),
    ],
    name="img_augmentation",
)

inputs = layers.Input(shape=(384, 384, 3))
x = img_augmentation(inputs)

model = tf.keras.applications.EfficientNetV2S(include_top=False, input_tensor=x, weights='imagenet')

model.trainable = False
x = tf.keras.layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
x = tf.keras.layers.BatchNormalization()(x)
top_dropout_rate = 0.2
x = tf.keras.layers.Dropout(top_dropout_rate, name="top_dropout")(x)
outputs = tf.keras.layers.Dense(179, activation="softmax", name="pred")(x)

newModel = Model(inputs=model.input, outputs=outputs, name="EfficientNet_DF20M_species")

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', factor=0.9, patience=2, verbose=1, min_lr=0.000001)

optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=0.01, momentum=0.9)

newModel.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

history = newModel.fit(x=train_ds, validation_data=val_ds, epochs=30, verbose=2, callbacks=[reduce_lr])
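For reference, the same augmentations can also be applied in the tf.data input pipeline, pinned to the CPU, instead of as layers inside the model; this keeps the Metal stateless-random kernels out of the model's forward pass. This is only a hedged sketch, not a confirmed fix, and `augment_batch` is a helper name I made up:

```python
import tensorflow as tf

# Same augmentations as in the model, but built as a standalone pipeline stage.
augment = tf.keras.Sequential(
    [
        tf.keras.layers.RandomFlip(),
        tf.keras.layers.RandomBrightness(factor=0.2),
        tf.keras.layers.RandomContrast(factor=0.2),
    ],
    name="img_augmentation",
)

def augment_batch(images, labels):
    # Pin the random ops to the CPU so the Metal plugin never runs them.
    with tf.device("/CPU:0"):
        return augment(images, training=True), labels

# Applied to the training dataset (train_ds stands in for the dataset above):
# train_ds = train_ds.map(augment_batch, num_parallel_calls=tf.data.AUTOTUNE)
```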

During training with image augmentation, Jupyter prints the following warnings during the first epoch:

WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting RngReadAndSkip cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting Bitcast cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomUniformFullIntV2 cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting StatelessRandomGetKeyCounter cause there is no registered converter for this op.
...

During training with image augmentation, Terminal keeps spamming the following warning:

2023-02-21 23:13:38.958633: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2023-02-21 23:13:38.958920: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2023-02-21 23:13:38.959071: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2023-02-21 23:13:38.959115: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
2023-02-21 23:13:38.959359: I metal_plugin/src/kernels/stateless_random_op.cc:282] Note the GPU implementation does not produce the same series as CPU implementation.
...

Any suggestions?

Replies

Were you able to find a solution? I am having the same problem.

I've only got a solution for the spam in Terminal (found here: https://stackoverflow.com/questions/35911252/disable-tensorflow-debugging-information ):

Those messages are INFO-level output from TensorFlow; I got rid of them by setting the TensorFlow log level to 1 in my environment variables: TF_CPP_MIN_LOG_LEVEL=1

0 = all messages are logged (default behavior)
1 = INFO messages are not printed
2 = INFO and WARNING messages are not printed
3 = INFO, WARNING, and ERROR messages are not printed
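Note that the variable only takes effect if it is set before TensorFlow is imported. A minimal sketch of doing this from Python (setting it in the shell before launching Jupyter works just as well):

```python
import os

# Must run before `import tensorflow` to take effect;
# "1" suppresses INFO-level messages such as the stateless_random_op spam.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"

# import tensorflow as tf  # import only after the variable is set
```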

  • I tried this, but it didn't solve my problem.


I got the same issue when using tfp.mcmc to sample from a distribution.