The ml_dtypes package that JAX depends on was recently updated to 0.3.0. As part of this change, the 'float8_e4m3b11' dtype was deprecated and is no longer exported, and newer versions of JAX reflect this change. The new ml_dtypes version therefore appears to be incompatible with JAX v0.4.11.
As jax-metal currently requires JAX v0.4.11, perhaps its dependency list should be updated to pin ml_dtypes==0.2.0 in order to prevent the following import error:
AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'
Because the error is raised on import, JAX is essentially unusable. The problem appears to be fixed by running pip install ml_dtypes==0.2.0.
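Until the dependency is pinned, a small guard can turn the mismatch into a readable error before import jax fails. This is a minimal sketch, assuming (from this report, not from any JAX or jax-metal documentation) that every ml_dtypes release below 0.3.0 still exports float8_e4m3b11; the helper names parse_version, is_compatible and check_installed are hypothetical.

```python
"""Hypothetical pre-import guard for the jax-metal / ml_dtypes mismatch.

Assumption (from the report, not official docs): JAX v0.4.11 needs an
ml_dtypes release that still exports float8_e4m3b11, i.e. one below 0.3.0.
"""
import re
from importlib import metadata

# Assumption: 0.3.0 is the first ml_dtypes release without float8_e4m3b11.
FIRST_INCOMPATIBLE = (0, 3, 0)


def parse_version(version: str) -> tuple:
    """Turn '0.2.0' or '0.3.0rc1' into a comparable (major, minor, patch) tuple.

    Only the leading digits of each component are kept, so pre-release
    suffixes such as 'rc1' are ignored.
    """
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    while len(parts) < 3:
        parts.append(0)  # pad '1.0' to (1, 0, 0)
    return tuple(parts)


def is_compatible(ml_dtypes_version: str) -> bool:
    """True if this ml_dtypes version predates the removal of float8_e4m3b11."""
    return parse_version(ml_dtypes_version) < FIRST_INCOMPATIBLE


def check_installed() -> None:
    """Raise a clear error instead of the AttributeError raised by import jax."""
    try:
        installed = metadata.version("ml_dtypes")
    except metadata.PackageNotFoundError:
        return  # ml_dtypes is absent; let pip/JAX report that themselves
    if not is_compatible(installed):
        raise RuntimeError(
            f"ml_dtypes {installed} is installed, but JAX v0.4.11 expects "
            "ml_dtypes==0.2.0; try: pip install ml_dtypes==0.2.0"
        )
```

Calling check_installed() just before import jax would surface the version conflict and the suggested pip command, rather than the bare AttributeError above. Comparing version numbers keeps the check from importing ml_dtypes itself.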