Reply to Jax-metal on M2 Pro does not recognize GPU
Hi, I just ran into this issue on an M1 14" MBP and got it to install and run correctly. Instructions here. The key is that, for now, it needs jaxlib 0.4.10 but jax 0.4.11. Jax seems to accept a jaxlib that is one patch version behind, so this configuration works. The key instructions are below:

    # obtain JAX source code
    git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch
    cd jax
    # build jaxlib from source, with capability to load plugin
    python build/build.py --bazel_options=--@xla//xla/python:enable_tpu=true
    # install jaxlib
    python -m pip install dist/*.whl

You also need Bazel 5.1.1 to build jaxlib (the build script gives you instructions if it can't find it) and Python 3.10, or the jaxlib wheel won't install. If you're using Anaconda, you'll have to create an environment with Python 3.10 and not any other version.

At this point the instructions tell you to install Jax via pip, but don't do that: it will default to 0.4.10, which is the wrong version. Instead, download the source code zip for the 0.4.11 release of Jax: https://github.com/google/jax/releases/tag/jax-v0.4.11

    # make sure you're in the jax-v0.4.11 folder
    pip install -e .

This should install correctly as long as it finds the version of jaxlib it wants. From there you should be able to load Jax and confirm it is using the GPU by running this:

    from jax.lib import xla_bridge
    print(xla_bridge.get_backend().platform)

Good luck.
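If it helps, here is a rough sketch of a sanity check for the version combination described in this reply (Python 3.10, jax 0.4.11, jaxlib 0.4.10). The names `EXPECTED`, `installed_version` and `find_problems` are made up for illustration, and the expected versions are only what worked in this one setup, not an official compatibility list.

```python
import sys
from importlib import metadata

# Versions reported to work in this reply. This is one person's working
# configuration (an assumption), not an official compatibility matrix.
EXPECTED = {"python": (3, 10), "jax": "0.4.11", "jaxlib": "0.4.10"}


def installed_version(package):
    """Return the installed version string of `package`, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def find_problems(python_version, jax_version, jaxlib_version):
    """Compare the given versions against EXPECTED and list any mismatches."""
    problems = []
    if tuple(python_version[:2]) != EXPECTED["python"]:
        problems.append(
            f"Python is {python_version[0]}.{python_version[1]}, expected 3.10"
        )
    if jax_version != EXPECTED["jax"]:
        problems.append(f"jax is {jax_version}, expected {EXPECTED['jax']}")
    if jaxlib_version != EXPECTED["jaxlib"]:
        problems.append(f"jaxlib is {jaxlib_version}, expected {EXPECTED['jaxlib']}")
    return problems


if __name__ == "__main__":
    problems = find_problems(
        sys.version_info, installed_version("jax"), installed_version("jaxlib")
    )
    for problem in problems:
        print("WARNING:", problem)
    if not problems:
        print("Versions match the configuration from this reply.")

    # Only try the backend check if jax can actually be imported.
    try:
        import jax

        print("Default backend:", jax.default_backend())
        print("Devices:", jax.devices())
    except Exception as exc:  # a version mismatch can also fail at import time
        print("Could not import jax:", exc)
```

Run it inside the same environment you installed into (for example the Anaconda 3.10 environment). A `None` in a warning means the package was not found in that environment at all.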
Jun ’23