TensorFlow-Metal breaks TF Percentile Function

Hi all, I was experimenting with the tf-metal (v0.4) framework and noticed an odd interaction with the percentile function from the tensorflow_probability package:

import tensorflow as tf
import numpy as np
from tensorflow_probability.python.stats import percentile

if __name__ == "__main__":
    data = np.array([0.12941672, 0.22039098, 0.33956015, 0.3787993, 0.5329178, 0.62175393, 0.5906472, 0.97234255, 0.7709932, 0.76639813, 1.0468946, 1.1515584, 1.0470238, 1.1140094, 1.2083299, 1.051311, 1.0782655, 1.0192754, 0.8690998, 0.9439713, 0.6992503, 0.7017522, 0.6524739, 0.536425, 0.47863948, 0.46657538, 0.45757294, 0.2988146, 0.19273241, 0.1494804, 0., 0.], dtype=np.float64)
    data16 = tf.convert_to_tensor(data, dtype=tf.float16)
    data32 = tf.convert_to_tensor(data, dtype=tf.float32)
    data64 = tf.convert_to_tensor(data, dtype=tf.float64)

    p = percentile(data, 99, keepdims=True, interpolation="lower")
    print(f"Percentile based on Numpy array (float64): {p}")

    p = percentile(data16, 99, keepdims=True, interpolation="lower")
    print(f"Percentile based on TF (float16): {p}")

    p = percentile(data32, 99, keepdims=True, interpolation="lower")
    print(f"Percentile based on TF (float32): {p}")

    p = percentile(data64, 99, keepdims=True, interpolation="lower")
    print(f"Percentile based on TF (float64): {p}")

This results in:

Percentile based on Numpy array (float64): [1.1515584]
Percentile based on TF (float16): [1.151]
Percentile based on TF (float32): [-0.]
Percentile based on TF (float64): [1.1515584]

The float32 value is clearly corrupted, whereas the others are correct (presumably because only float32 ops are dispatched to the GPU?). When I uninstall tf-metal, the float32 value is computed correctly. Any idea when a fix might be available? Also, is there a timeline for float16 support on the GPU?

Answered by Frameworks Engineer

The TopK op bug with K > 16 is fixed in tensorflow-metal==0.5.1, which also addresses the bug seen here.
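If you are unsure which version is installed, a small check like the following can tell you whether it is at least 0.5.1. This is a sketch that assumes Python 3.8+ (for `importlib.metadata`) and the PyPI distribution name `tensorflow-metal`; the helper names `parse_version` and `has_topk_fix` are hypothetical, and the parsing is deliberately simplified (pre-release tags such as `rc1` are treated like final releases).

```python
import re
from importlib import metadata

# Version reported in this thread as containing the TopK (K > 16) fix.
FIXED_IN = (0, 5, 1)


def parse_version(text):
    """Hypothetical helper: turn '0.5.1' into (0, 5, 1).

    Only the leading digits of the first three dot-separated pieces are
    kept, and missing pieces are padded with 0 (a simplification).
    """
    parts = []
    for piece in text.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    return tuple(parts + [0] * (3 - len(parts)))


def has_topk_fix(version_text):
    """Hypothetical helper: True if the version is at least FIXED_IN."""
    return parse_version(version_text) >= FIXED_IN


if __name__ == "__main__":
    try:
        installed = metadata.version("tensorflow-metal")
    except metadata.PackageNotFoundError:
        print("tensorflow-metal is not installed")
    else:
        status = "includes" if has_topk_fix(installed) else "predates"
        print(f"tensorflow-metal {installed} {status} the TopK fix (0.5.1)")
```

Tuples compare element by element, so (0, 4, 0) sorts below (0, 5, 1) while (1, 0, 0) sorts above it, which is all this check needs.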

Under the hood, TF Percentile uses the TopK op to perform its computations. The problem appears to stem from a bug in the TopK op that we can reproduce locally. We will update here once the fix is in.
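To illustrate why a TopK bug shows up here: if percentile sorts its whole input with a top-k call, then K equals the input length, which is 32 in the example above and therefore past the K > 16 threshold. The NumPy sketch below reproduces the "lower" interpolation on the CPU in simplified form. It is not TFP's actual implementation; the helper `percentile_lower` is hypothetical, and the rank formula floor((n - 1) * q / 100) is an assumption about how the "lower" mode picks its element.

```python
import math

import numpy as np

# Data from the question: 32 values, so sorting via top-k needs K = 32 (> 16).
data = np.array([0.12941672, 0.22039098, 0.33956015, 0.3787993, 0.5329178,
                 0.62175393, 0.5906472, 0.97234255, 0.7709932, 0.76639813,
                 1.0468946, 1.1515584, 1.0470238, 1.1140094, 1.2083299,
                 1.051311, 1.0782655, 1.0192754, 0.8690998, 0.9439713,
                 0.6992503, 0.7017522, 0.6524739, 0.536425, 0.47863948,
                 0.46657538, 0.45757294, 0.2988146, 0.19273241, 0.1494804,
                 0., 0.], dtype=np.float64)


def percentile_lower(x, q):
    """Hypothetical, simplified 'lower' percentile of a 1-D array."""
    x = np.asarray(x, dtype=np.float64)
    n = x.shape[-1]
    # Descending sort: the same ordering tf.math.top_k(x, k=n) returns.
    desc = np.sort(x)[::-1]
    # Assumed 'lower' rule: element at ascending rank floor((n - 1) * q / 100).
    asc_rank = math.floor((n - 1) * q / 100.0)
    # Convert the ascending rank into an index into the descending array.
    return desc[n - 1 - asc_rank]


print(percentile_lower(data, 99))
```

For q = 99 and n = 32 the rank is floor(31 * 0.99) = 30, which is the second-largest value, 1.1515584. That matches the float64 result in the question, so a correct full sort is all the "lower" mode needs; if the K = 32 sort itself returns wrong values, the selected element is wrong too.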
