I was facing the same issue and appear to have solved it (although my machine is an M1 Max, so it may not match yours exactly). The combination that worked for me is jaxlib 0.4.10, jax 0.4.11, and jax-metal 0.0.2.
First, build and install jaxlib 0.4.10 from source:
# obtain JAX source code
git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch
cd jax
# build jaxlib from source, with the capability to load a plugin
python build/build.py --bazel_options=--@xla//xla/python:enable_tpu=true
# install jaxlib
python -m pip install dist/*.whl
Next, install jax 0.4.11 from source:
Download the source code zip for the jax-v0.4.11 release: https://github.com/google/jax/releases/tag/jax-v0.4.11
Unzip it, cd into the extracted directory, and install jax: python -m pip install -e .
Finally, install jax-metal 0.0.2: python -m pip install jax-metal==0.0.2
After that, the Metal device shows up in a Python session:
Python 3.10.9 (main, Mar 1 2023, 12:20:14) [Clang 14.0.6 ] on darwin
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.devices()
Metal device set to: Apple M1 Max
systemMemory: 32.00 GB
maxCacheSize: 10.67 GB
[MetalDevice(id=0, process_index=0)]
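If jax.devices() doesn't list a MetalDevice for you, it may be worth confirming that exactly this version combination is installed in the active environment. Below is a small diagnostic sketch of my own (not part of JAX or jax-metal): it reads installed versions with the standard library's importlib.metadata and, if jax imports, prints the default backend and devices. The helper names check_versions, base_release and backend_summary are made up for illustration. It compares only the leading release numbers, on the assumption that a source-built jaxlib or an editable jax install may report a version with a .dev or +local suffix.

```python
import re
from importlib.metadata import PackageNotFoundError, version

# The combination reported to work in this answer (on an M1 Max).
EXPECTED = {"jaxlib": "0.4.10", "jax": "0.4.11", "jax-metal": "0.0.2"}


def base_release(v):
    """Leading numeric release, e.g. '0.4.10.dev20230601+abc' -> (0, 4, 10).

    Assumption: source builds may append a .dev or +local suffix, so only
    the leading release numbers are compared.
    """
    m = re.match(r"\d+(?:\.\d+)*", v)
    return tuple(int(p) for p in m.group(0).split(".")) if m else ()


def check_versions(expected=EXPECTED):
    """Map each package to (installed version or None, matches expected?)."""
    report = {}
    for pkg, want in expected.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            have = None
        ok = have is not None and base_release(have) == base_release(want)
        report[pkg] = (have, ok)
    return report


def backend_summary():
    """Return (default backend name, device list), or None if jax is missing.

    Other errors (e.g. a jax/jaxlib version mismatch raised at import time)
    are deliberately not caught, so the full traceback stays visible.
    """
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend(), jax.devices()


if __name__ == "__main__":
    for pkg, (have, ok) in check_versions().items():
        status = "ok" if ok else f"expected {EXPECTED[pkg]}"
        print(f"{pkg:<10} {have or 'not installed':<28} {status}")
    summary = backend_summary()
    if summary is not None:
        backend, devices = summary
        print("default backend:", backend)
        print("devices:", devices)
```

Save it under any name (e.g. check_jax_metal.py, a hypothetical filename) and run it with the same python you used for the pip installs above, so it inspects the same environment.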
Hope this helps!