Inconsistent int64 tensor device after abs

Despite wrapping the computation in tf.device('/GPU:0'), the output of tf.abs on an int64 GPU tensor unexpectedly ends up on the CPU on an M1 Pro. By contrast, device placement is consistent when running on an Nvidia GPU on Google Colab. As discussed in the related PyTorch issue, is this a limitation of int64 on MPS, or will it be fixed in the near future?

import tensorflow as tf

with tf.device('/GPU:0'):
    a = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
    b = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    c = tf.matmul(a, b)
    cabs = tf.abs(c)
    # Cast the float result to integer dtypes, then take abs of each.
    c64 = tf.cast(c, dtype=tf.int64)
    c32 = tf.cast(c, dtype=tf.int32)
    c64abs = tf.abs(c64)
    c32abs = tf.abs(c32)
    # Report which device each tensor was placed on.
    print('c float:', c.device)
    print('c int32:', c32.device)
    print('c int64:', c64.device)
    print('c float abs:', cabs.device)
    print('c int32 abs:', c32abs.device)
    print('c int64 abs:', c64abs.device)

Output:

c float: /job:localhost/replica:0/task:0/device:GPU:0
c int32: /job:localhost/replica:0/task:0/device:GPU:0
c int64: /job:localhost/replica:0/task:0/device:GPU:0
c float abs: /job:localhost/replica:0/task:0/device:GPU:0
c int32 abs: /job:localhost/replica:0/task:0/device:GPU:0
c int64 abs: /job:localhost/replica:0/task:0/device:CPU:0
                                                    ^^^^^

Hi @farleylai!

We are currently working on making more data types available for the GPU ops in tensorflow-metal. The issue mentioned in the PyTorch thread was specifically a correctness issue on AMD machines, but Apple GPUs do not have the same limitation with int64. It's just a matter of enabling the ops and verifying their correctness. I've made a note that int64 support has been requested, to raise the priority of these ops. I'll update here once we have the support out.
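In the meantime, one possible workaround is to route the int64 computation through int32, which the output above shows does stay on the GPU, and cast back afterwards. This is only safe when the values fit in the int32 range; the helper name below is hypothetical, not part of any TensorFlow API:

```python
import tensorflow as tf

def abs_int64_via_int32(x):
    """Hypothetical workaround sketch: compute |x| for an int64 tensor
    by downcasting to int32 (so tf.abs can run on the Metal GPU),
    then casting back to int64. Assumes all values fit in int32."""
    return tf.cast(tf.abs(tf.cast(x, tf.int32)), tf.int64)

x = tf.constant([[-1, 2], [3, -4]], dtype=tf.int64)
y = abs_int64_via_int32(x)
print(y.dtype)           # the result is restored to int64
print(y.numpy().tolist())
```

The same shape of workaround applies to other ops that fall back to the CPU for int64 until the Metal kernels are enabled.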
