Hello,
I'm interested in trying the new JAX Metal plug-in and followed the steps at https://developer.apple.com/metal/jax/. After installation, however, the backend device that JAX detects looks identical to a pure CPU setup:
>>> import jax
>>> jax.devices()
[CpuDevice(id=0)]
>>> jax.devices()[0].platform
'cpu'
>>> jax.devices()[0].device_kind
'cpu'
>>> jax.devices()[0].client.platform
'cpu'
>>> jax.devices()[0].client.runtime_type
'tfrt'
Is this really using a Metal backend? How can I determine for sure?
Thank you!
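One way to check is to inspect the platform string of every device JAX reports, plus `jax.default_backend()`. The sketch below assumes that with the jax-metal plug-in active, a device reports its platform as "METAL" (compared case-insensitively), rather than "cpu" as in the session above. The helper names `uses_metal` and `report` are hypothetical.

```python
"""Sketch: tell whether JAX picked up a Metal device or fell back to CPU.

Assumption (not confirmed by the thread): with jax-metal active, the device
platform is reported as "METAL" (case may vary); a CPU-only setup reports "cpu".
"""
from typing import Iterable


def uses_metal(platforms: Iterable[str]) -> bool:
    # True if any reported device platform names Metal (case-insensitive).
    return any(p.lower() == "metal" for p in platforms)


def report() -> str:
    # Import lazily so this file can be loaded even without jax installed.
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    devices = jax.devices()
    platforms = [d.platform for d in devices]
    if uses_metal(platforms):
        return f"Metal backend active: {devices}"
    return f"No Metal device found; default backend is {jax.default_backend()!r}: {devices}"


if __name__ == "__main__":
    print(report())
```

If this prints the CPU message, JAX did not load the plug-in and is running on the CPU backend.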
JAX v0.4.11 is required; either of the following suggestions fixes this:
- https://developer.apple.com/forums/thread/731465?answerId=756330022#756330022
- https://developer.apple.com/forums/thread/731465?answerId=756319022#756319022
(JAX v0.4.12 doesn't work with Metal, and v0.4.13 requires a newer version of jaxlib.)
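To confirm the version pin described above is actually in place, a small check like this can help. It is a sketch that takes the required version (0.4.11) from this answer. Other jax-metal releases may need a different pin. The helper names `installed_version` and `version_ok` are hypothetical. It only reads installed package metadata and does not install anything.

```python
"""Sketch: check the installed jax version against the pin this answer reports
jax-metal needs (0.4.11). Later jax-metal releases may require a different pin.
"""
from importlib import metadata
from typing import Optional

REQUIRED_JAX = "0.4.11"


def installed_version(dist: str) -> Optional[str]:
    # Return the installed distribution's version, or None if it isn't installed.
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def version_ok(installed: Optional[str], required: str = REQUIRED_JAX) -> bool:
    # Exact match: per the answer, 0.4.12 fails and 0.4.13 needs a newer jaxlib.
    return installed == required


if __name__ == "__main__":
    for dist in ("jax", "jaxlib", "jax-metal"):
        print(dist, installed_version(dist))
    print("jax version OK for jax-metal:", version_ok(installed_version("jax")))
```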