Hi, I just ran into this issue on an M1 14" MBP and got JAX to install and run correctly. Instructions below.
The key point is that, for now, it needs jaxlib 0.4.10 but jax 0.4.11. JAX seems to accept a jaxlib that is one patch version older, so this combination works.
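Here is a minimal sketch of that pairing rule, which you can run once both packages are installed. It encodes only the observation above (jaxlib exactly one patch release behind jax), not JAX's real compatibility logic, which is JAX's own import-time minimum-version check. `parse_version`, `is_expected_pair`, and `installed_pair` are hypothetical helpers, and the sketch assumes plain `X.Y.Z` version strings.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version


def parse_version(text: str) -> tuple[int, ...]:
    """Turn a release string such as "0.4.11" into (0, 4, 11).

    Any local suffix after "+" is dropped; pre-release strings are not handled.
    """
    release = text.split("+")[0]
    return tuple(int(part) for part in release.split("."))


def is_expected_pair(jax_version: str, jaxlib_version: str) -> bool:
    """True if jaxlib is exactly one patch release behind jax,
    e.g. jax 0.4.11 with jaxlib 0.4.10."""
    jax_v = parse_version(jax_version)
    lib_v = parse_version(jaxlib_version)
    return jax_v[:2] == lib_v[:2] and jax_v[2] - lib_v[2] == 1


def installed_pair() -> tuple[str, str] | None:
    """Return the installed (jax, jaxlib) versions, or None if either is missing."""
    try:
        return version("jax"), version("jaxlib")
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    pair = installed_pair()
    if pair is None:
        print("jax and/or jaxlib are not installed")
    else:
        try:
            verdict = "matches" if is_expected_pair(*pair) else "does not match"
        except ValueError:
            verdict = "has an unrecognised version format for"
        print(f"jax {pair[0]} / jaxlib {pair[1]}: {verdict} the one-patch-behind pattern")
```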
The essential steps are:
```bash
# obtain JAX source code
git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch
cd jax
# build jaxlib from source, with capability to load plugin
python build/build.py --bazel_options=--@xla//xla/python:enable_tpu=true
# install jaxlib
python -m pip install dist/*.whl
```
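The `dist/*.whl` glob installs whatever wheel the build produced. Because jaxlib is a compiled extension, its filename carries a CPython tag (for example `cp310`), and pip will only install it into a matching interpreter. A minimal sketch that reads that tag from a wheel filename, assuming the standard wheel naming layout `name-version[-build]-pythontag-abitag-platform.whl`; `python_tag` and `matches_running_python` are hypothetical helpers.

```python
from __future__ import annotations

import sys
from pathlib import Path


def python_tag(wheel_name: str) -> str:
    """Return the Python tag of a wheel filename.

    Layout: name-version[-build]-pythontag-abitag-platform.whl, so the
    Python tag is always the third dash-separated field from the end.
    """
    stem = Path(wheel_name).name.removesuffix(".whl")
    return stem.split("-")[-3]


def matches_running_python(wheel_name: str) -> bool:
    """True if the wheel's CPython tag matches this interpreter (e.g. cp310)."""
    expected = f"cp{sys.version_info.major}{sys.version_info.minor}"
    return python_tag(wheel_name) == expected


if __name__ == "__main__":
    # Run from the jax checkout after build.py has finished.
    for wheel in sorted(Path("dist").glob("*.whl")):
        status = "ok" if matches_running_python(wheel.name) else "built for a different Python"
        print(wheel.name, status)
```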
You also need Bazel 5.1.1 to build jaxlib (the build script tells you what to do if it can't find it) and Python 3.10, or pip won't install the jaxlib wheel. If you're using Anaconda, create an environment with Python 3.10 specifically, not any other version.
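Before starting the build, a quick preflight check of both prerequisites can save a failed run. A minimal sketch, assuming `bazel --version` prints something like `bazel 5.1.1`; the constants and helper names are hypothetical.

```python
from __future__ import annotations

import re
import shutil
import subprocess
import sys

REQUIRED_PYTHON = (3, 10)    # per the note above
REQUIRED_BAZEL = (5, 1, 1)   # per the note above


def parse_bazel_version(output: str) -> tuple[int, ...] | None:
    """Extract (major, minor, patch) from output such as 'bazel 5.1.1'."""
    match = re.search(r"(\d+)\.(\d+)\.(\d+)", output)
    return tuple(int(g) for g in match.groups()) if match else None


def installed_bazel() -> tuple[int, ...] | None:
    """Ask the bazel on PATH for its version; None if bazel isn't found."""
    exe = shutil.which("bazel")
    if exe is None:
        return None
    result = subprocess.run([exe, "--version"], capture_output=True, text=True)
    return parse_bazel_version(result.stdout)


if __name__ == "__main__":
    py = sys.version_info[:2]
    print("Python", ".".join(map(str, py)), "ok" if py == REQUIRED_PYTHON else "(expected 3.10)")
    bazel = installed_bazel()
    print("Bazel", bazel, "ok" if bazel == REQUIRED_BAZEL else "(expected 5.1.1)")
```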
At this point the instructions tell you to install JAX via pip, but don't do that, or pip will default to 0.4.10, which is the wrong version. Instead, download and unzip the source code for the JAX 0.4.11 release: https://github.com/google/jax/releases/tag/jax-v0.4.11
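If you'd rather script that download, here is a minimal sketch using only the standard library. It assumes GitHub's usual source-archive URL pattern (`/archive/refs/tags/<tag>.zip`); `archive_url` and `download_and_extract` are hypothetical helpers, and the extracted folder name is read from the zip rather than guessed, since it may not be exactly `jax-v0.4.11`.

```python
from __future__ import annotations

import urllib.request
import zipfile
from pathlib import Path

REPO = "google/jax"
TAG = "jax-v0.4.11"  # the release tag from the link above


def archive_url(repo: str, tag: str) -> str:
    """Source-archive URL for a tag (assumes GitHub's standard pattern)."""
    return f"https://github.com/{repo}/archive/refs/tags/{tag}.zip"


def download_and_extract(repo: str, tag: str, dest: Path) -> Path:
    """Download the tag's source zip into dest, unpack it, and return the
    top-level folder the archive contains (the folder to run pip in)."""
    dest.mkdir(parents=True, exist_ok=True)
    zip_path = dest / f"{tag}.zip"
    urllib.request.urlretrieve(archive_url(repo, tag), zip_path)
    with zipfile.ZipFile(zip_path) as archive:
        top_level = archive.namelist()[0].split("/")[0]
        archive.extractall(dest)
    return dest / top_level
```

Calling `download_and_extract(REPO, TAG, Path.cwd())` returns the folder to `cd` into before running `pip install -e .`.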
```bash
# make sure you're in the jax-v0.4.11 folder
# (python -m pip ensures the Python 3.10 environment's pip is used)
python -m pip install -e .
```
This should install cleanly as long as pip finds the jaxlib version it expects. From there you should be able to import JAX and confirm it is using the GPU by running the following in Python:
```python
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)
```
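An alternative check uses the top-level functions `jax.default_backend()` and `jax.devices()` rather than importing from `jax.lib`. A minimal sketch; `describe_backend` is a hypothetical wrapper that returns `None` instead of failing when JAX isn't importable. The exact platform string the Metal plugin reports isn't assumed here, so compare against what your own install prints.

```python
from __future__ import annotations


def describe_backend() -> str | None:
    """Print the visible devices and return the default backend's platform
    name, or return None if JAX cannot be imported."""
    try:
        import jax
    except ImportError:
        return None
    print(jax.devices())           # devices of the default backend
    return jax.default_backend()   # platform name, e.g. "cpu" on a CPU-only install


print(describe_backend())
```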
Good luck.