Jax-metal dependencies issue - now requires ml_dtypes==0.2.0
The ml_dtypes package used by JAX was recently updated to 0.3.0. As part of this release, the deprecated 'float8_e4m3b11' dtype was removed, and newer versions of JAX reflect this change. However, ml_dtypes 0.3.0 appears to be incompatible with JAX v0.4.11, which still references that dtype. Since jax-metal currently requires JAX v0.4.11, perhaps its dependency list should pin ml_dtypes==0.2.0 to prevent the following import error:

AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'

This error makes JAX unusable as soon as it is imported. Downgrading with pip install ml_dtypes==0.2.0 appears to fix it.
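Until the dependency list is updated, a quick pre-flight check can flag the broken combination before importing JAX. The sketch below is only an illustration: `parse_version`, `ml_dtypes_compatible` and `check_installed` are hypothetical helpers, not part of jax or jax-metal, and the rule encodes just the single combination reported in this post (JAX 0.4.11 with ml_dtypes 0.3.0 or later). It reads installed versions with the standard-library `importlib.metadata`, so it does not import JAX itself.

```python
from __future__ import annotations

import importlib.metadata


def parse_version(v: str) -> tuple[int, ...]:
    """Keep the leading numeric parts of a version, e.g. "0.3.0rc1" -> (0, 3, 0)."""
    parts = []
    for piece in v.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break  # stop at suffixes such as "rc1" or "dev"
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def ml_dtypes_compatible(jax_version: str, ml_dtypes_version: str) -> bool:
    """Assumption: only JAX 0.4.11 paired with ml_dtypes >= 0.3.0 is treated as broken."""
    if parse_version(jax_version)[:3] == (0, 4, 11):
        # JAX 0.4.11 still references ml_dtypes.float8_e4m3b11, which 0.3.0 removed.
        return parse_version(ml_dtypes_version) < (0, 3, 0)
    return True


def check_installed() -> str:
    """Report whether the installed jax / ml_dtypes pair hits the import error."""
    try:
        jax_v = importlib.metadata.version("jax")
        mld_v = importlib.metadata.version("ml_dtypes")
    except importlib.metadata.PackageNotFoundError as exc:
        return f"not installed: {exc}"
    if ml_dtypes_compatible(jax_v, mld_v):
        return f"ok: jax {jax_v}, ml_dtypes {mld_v}"
    return (
        f"jax {jax_v} expects ml_dtypes.float8_e4m3b11, which ml_dtypes {mld_v} "
        "no longer provides; try: pip install ml_dtypes==0.2.0"
    )


if __name__ == "__main__":
    print(check_installed())
```

Comparing version metadata rather than importing JAX keeps the check usable even when the import itself is what fails.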
Activity: Sep ’23