jax-metal 0.0.3 appears to have fixed the jnp.dot(x, y) bug. It still requires jaxlib 0.4.10 and jax 0.4.11 to run in IPython. Performance is not great: only slightly better than the free Google Colab GPU. With a small model it was much slower than the CPU; with a bigger model it became faster than the CPU.
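Because jax-metal 0.0.3 only ran here with those exact jax/jaxlib versions, it may help to check an environment against the pins before debugging anything else. Below is a minimal sketch using only the standard library's `importlib.metadata`. The helper names `installed_versions` and `mismatches` are hypothetical, and the `REQUIRED` dict simply encodes the versions mentioned in this post; it is not an official compatibility table.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumption: these are just the pins reported in this post for jax-metal 0.0.3,
# not an official compatibility matrix.
REQUIRED = {"jax-metal": "0.0.3", "jax": "0.4.11", "jaxlib": "0.4.10"}


def installed_versions(names):
    """Return {package: installed version string, or None if not installed}."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def mismatches(installed, required=REQUIRED):
    """List readable problems wherever installed versions differ from the pins."""
    problems = []
    for name, want in required.items():
        have = installed.get(name)
        if have is None:
            problems.append(f"{name}: not installed (want {want})")
        elif have != want:
            problems.append(f"{name}: {have} installed (want {want})")
    return problems


if __name__ == "__main__":
    issues = mismatches(installed_versions(REQUIRED))
    print("\n".join(issues) if issues else "versions match the reported pins")
```

Running it inside the same IPython kernel you use for JAX checks the environment that actually matters, rather than whatever `pip` on your shell PATH points to.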