On my MacBook Pro (M1 Max), jax-metal works with the setup below. I followed the instructions at https://developer.apple.com/metal/jax/, with the modifications shown in parentheses, and used Anaconda environments.
- jaxlib 0.4.10 (cloned with `git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch`)
- jax 0.4.11 (`pip install jax==0.4.11`)
- jax-metal 0.0.2
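To confirm which versions actually ended up in an environment, a small check like the one below may help. It is a sketch that uses only the standard library's `importlib.metadata`, plus `jax.default_backend()` and `jax.devices()` when jax is importable. The distribution names "jax", "jaxlib" and "jax-metal" are assumed to match what pip installed, and the helper name `installed_versions` is hypothetical.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "jax-metal")):
    """Hypothetical helper: map each distribution name to its version, or None."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Not installed (or installed under a different distribution name).
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
    try:
        import jax

        # Assumption: with jax-metal active, the listed devices should include
        # a Metal GPU device rather than only the CPU.
        print("default backend:", jax.default_backend())
        print("devices:", jax.devices())
    except ImportError:
        print("jax is not importable in this environment")
```

Running it from the command line (rather than Jupyter) should make it easy to compare against the versions listed above.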
A few caveats:
- It does not work well in Jupyter, but it works from the command line, so I use it in the Spyder editor (with IPython).
- Some basic jax.numpy operations may be buggy. For example, the last line of the following code fails:
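```python
import jax.numpy as jnp
from jax import random

# Split the key so x and y get independent random numbers.
key_x, key_y = random.split(random.PRNGKey(1))
x = random.normal(key_x, (200, 100))
y = random.normal(key_y, (100,))

a = jnp.dot(y, x.T)                # OK, shape (200,)
print(a.shape)
b = jnp.dot(x, y[:, jnp.newaxis])  # OK, but the result is 2-D: (200, 1)
print(b.shape)
c = jnp.dot(x, y)                  # Error with jax-metal 0.0.2 on the GPU
print(c.shape)
```

The last `jnp.dot(x, y)` works fine on the CPU with jaxlib 0.4.12, but that jaxlib version does not work on the GPU, and there is no point in this setup if it only runs on the CPU.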
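Since the first two products above reportedly work, one possible workaround (not verified here on jax-metal itself) is to express the 1-D matrix-vector product through one of those forms. Because (x y)_i = sum_j x_ij y_j = (y x^T)_i, `jnp.dot(y, x.T)` gives the same (200,) result as `jnp.dot(x, y)`. The helper names `matvec` and `matvec_2d` are hypothetical, and the NumPy cross-check is only a sanity test.

```python
import jax.numpy as jnp
import numpy as np
from jax import random


def matvec(x, y):
    """Hypothetical helper: x @ y for 2-D x and 1-D y, via the y @ x.T form,
    which is the variant reported to work on jax-metal 0.0.2."""
    return jnp.dot(y, x.T)


def matvec_2d(x, y):
    """Hypothetical alternative: use the (n, 1) column form, then drop the extra axis."""
    return jnp.dot(x, y[:, jnp.newaxis])[:, 0]


key_x, key_y = random.split(random.PRNGKey(1))
x = random.normal(key_x, (200, 100))
y = random.normal(key_y, (100,))

c1 = matvec(x, y)
c2 = matvec_2d(x, y)
print(c1.shape, c2.shape)  # (200,) (200,)

# Cross-check against NumPy on the host; float32 sums, so use loose tolerances.
reference = np.asarray(x) @ np.asarray(y)
print(np.allclose(c1, reference, rtol=1e-4, atol=1e-4),
      np.allclose(c2, reference, rtol=1e-4, atol=1e-4))
```

The first form avoids any extra indexing step, which may matter if other operations turn out to be buggy on the same backend.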