I built and installed JAX and jax-metal on an M2 Pro Mac mini, following the instructions here: https://developer.apple.com/metal/jax/
However, the following check suggests that XLA is running on the CPU rather than the GPU:
>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
cpu
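The same thing can be checked through JAX's top-level API. The sketch below uses `jax.default_backend()` and `jax.devices()` to report which backend JAX selected and which devices it can see. If the default backend is still `cpu`, the Metal plugin most likely did not load. That can happen, for example, when the installed jax/jaxlib versions do not match the version jax-metal expects. The helper name `describe_jax_backend` is hypothetical. The exact platform string jax-metal reports (something like `METAL`) is also an assumption here, not something confirmed by the post.

```python
import importlib.util


def describe_jax_backend():
    """Return (default_backend, device_strings), or None if JAX cannot be imported.

    describe_jax_backend is a hypothetical helper for illustration.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    try:
        import jax
    except ImportError:
        # e.g. a jax/jaxlib version mismatch can make the import itself fail
        return None
    backend = jax.default_backend()  # "cpu" here means no accelerator plugin is active
    devices = [str(d) for d in jax.devices()]  # every device on the default backend
    return backend, devices


if __name__ == "__main__":
    info = describe_jax_backend()
    if info is None:
        print("JAX is not importable in this interpreter")
    else:
        backend, devices = info
        print("default backend:", backend)
        print("devices:", devices)
        if backend == "cpu":
            print("Only the CPU backend is active; the Metal plugin was probably not loaded.")
```

Running this in the same virtual environment where jax-metal was installed also confirms that the interpreter being used is the one the plugin went into.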
Has anyone managed to get it running on the GPU?
Thanks in advance!