Metal JAX device appears as `cpu`

Hello,

I'm interested in trying the new JAX Metal plug-in and followed the steps at https://developer.apple.com/metal/jax/. After installation, the backend device that JAX detects looks no different from a pure CPU setup:

>>> import jax
>>> jax.devices()
[CpuDevice(id=0)]
>>> jax.devices()[0].platform
'cpu'
>>> jax.devices()[0].device_kind
'cpu'
>>> jax.devices()[0].client.platform
'cpu'
>>> jax.devices()[0].client.runtime_type
'tfrt'

Is this really using a Metal backend? How can I determine for sure?
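
One way to check, as a minimal sketch: ask JAX for its default backend and list every device whose platform is not `cpu`. The helper name `accelerator_platforms` is made up for illustration. The sketch does not assume the exact platform string the Metal plug-in reports; it simply treats anything other than `cpu` as an accelerator.

```python
def accelerator_platforms(devices):
    """Return the sorted platform names of all non-CPU devices."""
    return sorted({d.platform for d in devices if d.platform.lower() != "cpu"})


try:
    import jax
except ImportError:
    jax = None  # JAX is not installed; the helper above still works on its own

if jax is not None:
    # "cpu" here means JAX will place computations on the CPU by default.
    print("default backend:", jax.default_backend())
    # An empty list means no accelerator plug-in was picked up.
    print("non-CPU platforms:", accelerator_platforms(jax.devices()))
```

If the default backend is `cpu` and the list is empty, as in the session above, JAX is not using the Metal plug-in.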

Thank you!

Answered by pcuenca in 756441022

JAX v0.4.11 is required, so any of the following suggestions fixes this:

(JAX v0.4.12 doesn't work with Metal, and v0.4.13 requires a newer version of jaxlib.)
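
To see which versions are actually installed, a short sketch along these lines may help. It only reads package metadata, so it works even if importing `jax` fails. The `REQUIRED_JAX` constant and the `is_required_version` helper are illustrative names, and the required version is taken from the answer above rather than from any official compatibility table.

```python
from importlib import metadata

REQUIRED_JAX = "0.4.11"  # version named in the answer above


def is_required_version(installed, required=REQUIRED_JAX):
    """True when `installed` equals `required`, ignoring any '+local' suffix."""
    return installed.split("+")[0] == required


# Report what is installed for each relevant distribution.
for package in ("jax", "jaxlib", "jax-metal"):
    try:
        print(package, metadata.version(package))
    except metadata.PackageNotFoundError:
        print(package, "not installed")

try:
    print("jax matches", REQUIRED_JAX, ":", is_required_version(metadata.version("jax")))
except metadata.PackageNotFoundError:
    print("jax is not installed, so there is nothing to compare")
```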

(M1 Max running Sonoma 14.0 and Xcode Version 15.0 beta (15A5160n))

I am facing the same issue.

Same here

I have the same issue. I have already tried installing and reinstalling three times, but the problem persists. I am using a conda environment, but I don't think that is the issue, since all the installations complete smoothly. I wonder when this issue will be resolved.
