Jax-metal not recognising GPU

Intel GPU here. I just installed jax-metal, and it runs fine. However, when I try the following code, it raises the RuntimeError below:

import jax, jax.numpy as jnp
jax.device_put(jnp.ones(1), device=jax.devices('gpu')[0])

RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: interpreter,cpu
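If you want code that runs whether or not a GPU backend is present, one option is to catch that error and fall back to the default backend's first device. This is a minimal sketch, assuming (as the traceback above shows) that `jax.devices('gpu')` raises `RuntimeError` when no GPU platform is registered; `pick_device` is a hypothetical helper name. It does not make an unsupported GPU usable, it only avoids the crash.

```python
import jax
import jax.numpy as jnp


def pick_device():
    """Return the first 'gpu' device if one exists, else the first default device.

    Assumes jax.devices('gpu') raises RuntimeError when no GPU platform is
    present, as in the error above.
    """
    try:
        return jax.devices("gpu")[0]
    except RuntimeError:
        # No 'gpu' platform: use whatever the default backend offers (e.g. CPU).
        return jax.devices()[0]


device = pick_device()
x = jax.device_put(jnp.ones(1), device=device)
print(device.platform, x)
```

On the machine above this would simply print a CPU device, which at least makes the missing GPU visible without an exception.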

Unfortunately, the requirements clearly state:

Mac computers with Apple silicon or AMD GPUs

https://developer.apple.com/metal/jax/
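Since support is tied to the hardware, a quick pre-check can save a build. A rough sketch (the helper name is hypothetical), assuming Python's `platform` module reports `Darwin` and `arm64` for a native Apple-silicon interpreter. It cannot detect an AMD GPU on an Intel Mac, so on those machines it only tells you that you are not on Apple silicon.

```python
import platform


def looks_like_apple_silicon(system=None, machine=None):
    """Rough check for a native Apple-silicon Python on macOS.

    Only checks OS and CPU architecture; it cannot tell whether an Intel Mac
    has a supported AMD GPU. An x86_64 Python running under Rosetta on an
    Apple-silicon Mac also reports 'x86_64', so it returns False there too.
    """
    system = system if system is not None else platform.system()
    machine = machine if machine is not None else platform.machine()
    return system == "Darwin" and machine == "arm64"


print(looks_like_apple_silicon())
```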

I'm having the same issue, using M1 Pro on 16" MBP.

Hi, I just ran into this issue on an M1 14" MBP. I got it to install and run correctly. Instructions here.

The key, for now, is that it needs jaxlib 0.4.10 but jax 0.4.11. Jax seems to allow a jaxlib that is one patch version behind, so this combination works.
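To confirm which versions actually ended up installed, you can read the installed package metadata. A minimal sketch, assuming both packages are installed under the distribution names `jax` and `jaxlib`; the helper names are hypothetical. The suffix stripping is there because a source build may carry a dev or local version suffix (an assumption; check what your build reports).

```python
import re
from importlib.metadata import PackageNotFoundError, version

# Combination reported to work in this answer (jax one patch release ahead of jaxlib).
EXPECTED = {"jax": "0.4.11", "jaxlib": "0.4.10"}


def release(ver):
    """Keep only major.minor.patch, e.g. '0.4.10.dev20230601' -> '0.4.10'."""
    m = re.match(r"\d+\.\d+\.\d+", ver)
    return m.group(0) if m else ver


def version_mismatches(expected=EXPECTED):
    """Map each package whose installed release differs to (installed, expected)."""
    out = {}
    for pkg, want in expected.items():
        try:
            have = release(version(pkg))
        except PackageNotFoundError:
            have = None  # package not installed at all
        if have != want:
            out[pkg] = (have, want)
    return out


print(version_mismatches() or "jax/jaxlib match the expected versions")
```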

The essential build steps are:

# obtain JAX source code
git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch
cd jax
# build jaxlib from source, with capability to load plugin
python build/build.py --bazel_options=--@xla//xla/python:enable_tpu=true
# install jaxlib
python -m pip install dist/*.whl

You also need Bazel 5.1.1 to build jaxlib (the build script will tell you what to do if it can't find it), and Python 3.10, otherwise the jaxlib wheel won't install. If you're using Anaconda, you'll have to create an environment with Python 3.10 specifically, not any other version.
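Before starting a long build, it can help to check both prerequisites from the interpreter you plan to use. A rough sketch with hypothetical helper names, assuming `bazel --version` prints something like `bazel 5.1.1`. A missing Bazel is reported only as a warning, since the build script itself should tell you how to get it.

```python
import re
import shutil
import subprocess
import sys


def parse_bazel_version(text):
    """Extract (major, minor, patch) from output like 'bazel 5.1.1'."""
    m = re.search(r"(\d+)\.(\d+)\.(\d+)", text)
    return tuple(int(p) for p in m.groups()) if m else None


def check_prereqs():
    """Return a list of problems with the build prerequisites named above."""
    problems = []
    major, minor = sys.version_info[:2]
    if (major, minor) != (3, 10):
        problems.append(f"Python is {major}.{minor}, expected 3.10")
    bazel = shutil.which("bazel")
    if bazel is None:
        problems.append("bazel not found on PATH (the build script should say what to do)")
    else:
        out = subprocess.run([bazel, "--version"], capture_output=True, text=True).stdout
        if parse_bazel_version(out) != (5, 1, 1):
            problems.append(f"bazel reports {out.strip()!r}, expected 5.1.1")
    return problems


for problem in check_prereqs():
    print("WARNING:", problem)
```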

At this point the instructions tell you to install Jax via pip, but don't do that: pip will default to 0.4.10, which is the wrong version. Instead, download the source code zip for the 0.4.11 release of Jax: https://github.com/google/jax/releases/tag/jax-v0.4.11

# make sure you're in the jax-v0.4.11 folder
pip install -e .

This should install correctly as long as it finds the jaxlib version it expects. From there you should be able to import Jax and confirm that it is using the GPU by running:

from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
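Newer code often uses the public helpers `jax.default_backend()` and `jax.devices()` for the same check, and running a tiny computation confirms the backend actually works rather than just loading. A sketch; the exact platform string jax-metal reports is not stated here, so compare against what your install prints rather than hard-coding it.

```python
import jax
import jax.numpy as jnp

backend = jax.default_backend()  # name of the platform JAX uses by default
devices = jax.devices()          # devices on that default backend
print("default backend:", backend)
print("devices:", devices)

# Tiny end-to-end computation on the default device: 0 + 1 + 4 + 9.
x = jnp.arange(4.0)
result = jnp.dot(x, x)
print("dot:", float(result))  # 14.0
```

If `backend` still comes back as `cpu`, the plugin was not picked up and the version checks above are the first thing to revisit.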

Good luck.
