tf.function decorator with tensorflow-metal breaks tf.signal.fft3d()
I consistently receive corrupted results from `tf.signal.fft3d()` when it is called inside a function decorated with `@tf.function`. All entries beyond a certain x index are zero (0.), as shown in the attached image. Surprisingly, the issue depends on the matrix size: for example, (1023, 1023, 287) works, but (1023, 1023, 575) does not. This is problematic because the failure is silent and does not occur for all matrix sizes, so it can easily slip through tests.

The error occurs only when tensorflow-metal is installed. The TensorFlow version is 2.16.1. My hardware is a MacBook Pro M3 Max with 40 GPU cores and 128 GB RAM, running macOS Sonoma 14.5 (23F79).

A Python environment to reproduce the bug can be created as follows:

```shell
conda create --name tfmetalbug python=3.11.9
conda activate tfmetalbug
pip install tensorflow tensorflow-metal
conda install matplotlib
```

The following code reproduces the issue:

```python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Wrap fft3d with tf.function
@tf.function
def fft3d_wrapper_function(x):
    return tf.signal.fft3d(x)

# Generate a random 3D image and convert it to complex values
img = tf.random.normal(shape=(1023, 1023, 575), stddev=1., dtype=tf.float32)
img = tf.cast(img, tf.complex64)

# Compute the 3D FFT
img_fft = fft3d_wrapper_function(img)

# Visualize one slice of the 3D FFT
plt.imshow(np.real(img_fft)[:, img_fft.shape[1]//2+10, :], cmap="gray", vmin=-0.001, vmax=0.001)
plt.savefig("fft3d_wrapper_function.png")
```

For me, removing the `@tf.function` decorator resolves the issue.
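Because the corruption is silent and size-dependent, a numerical check can make it visible in a test instead of only in a plot. The sketch below is not part of the original report: `trailing_zero_start` and `fft3d_relative_error` are hypothetical helper names, and the reference is NumPy's `np.fft.fftn`, assuming (as documented) that `tf.signal.fft3d` on a rank-3 tensor computes the same unnormalized forward DFT over all three axes. It uses only NumPy, so it can check any result converted with `.numpy()`.

```python
import numpy as np


def trailing_zero_start(y, axis=0):
    """Return the index along `axis` from which every slice is exactly zero,
    or None if the last slice contains a nonzero entry (hypothetical helper)."""
    y = np.asarray(y)
    # Move the chosen axis to the front and flatten the remaining axes.
    slices = np.moveaxis(y, axis, 0).reshape(y.shape[axis], -1)
    # True for each slice that has at least one nonzero entry.
    nonzero = np.any(slices != 0, axis=1)
    if nonzero[-1]:
        return None
    idx = np.flatnonzero(nonzero)
    # If every slice is zero, the zero run starts at index 0.
    return 0 if idx.size == 0 else int(idx[-1]) + 1


def fft3d_relative_error(x, y):
    """Relative L2 error of a candidate 3-D FFT `y` against NumPy's fftn(x),
    computed in double precision (hypothetical helper)."""
    ref = np.fft.fftn(np.asarray(x, dtype=np.complex128))
    return float(np.linalg.norm(np.asarray(y) - ref) / np.linalg.norm(ref))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((16, 8, 12)).astype(np.complex64)
    good = np.fft.fftn(x).astype(np.complex64)
    # Simulate the reported symptom: everything past index 10 is zeroed.
    bad = good.copy()
    bad[10:] = 0
    print(fft3d_relative_error(x, good), trailing_zero_start(good))
    print(fft3d_relative_error(x, bad), trailing_zero_start(bad))
```

Applied to the reproduction script, something like `trailing_zero_start(img_fft.numpy(), axis=0)` would report where the zero run begins (which axis the "x index" refers to is an assumption here). For the full (1023, 1023, 575) case the double-precision reference needs roughly 10 GB, so comparing the `@tf.function` result against the eager `tf.signal.fft3d(img)` result may be the cheaper check.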
Activity: Jun ’24