I was able to fix the issue on my 14" MacBook Pro (M3 Pro) running macOS Sonoma 14.4.1 by installing the minimum versions listed at https://pypi.org/project/jax-metal/:
1. pip install jax-metal
2. pip install jax==0.4.26 jaxlib==0.4.26
I also ran into an error about deprecated types with ml_dtypes versions newer than 0.2.0, so I pinned ml_dtypes as well:
3. pip install ml_dtypes==0.2.0
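To confirm the pins actually took effect in the active environment, something like the following could be run before importing JAX. This is a minimal sketch using only the standard library's importlib.metadata. The PINS mapping just mirrors the versions from the steps above, and the helper name check_pins is hypothetical.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Versions taken from the install steps above; None means "any installed
# version is fine". Adjust to whatever the jax-metal PyPI page lists.
PINS: dict[str, str | None] = {
    "jax-metal": None,
    "jax": "0.4.26",
    "jaxlib": "0.4.26",
    "ml_dtypes": "0.2.0",
}


def check_pins(pins: dict[str, str | None]) -> dict[str, tuple[str | None, bool]]:
    """Hypothetical helper: map each package to (installed_version, matches_pin)."""
    report: dict[str, tuple[str | None, bool]] = {}
    for name, wanted in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        ok = installed is not None and (wanted is None or installed == wanted)
        report[name] = (installed, ok)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_pins(PINS).items():
        status = "OK" if ok else "MISMATCH"
        print(f"{name}: installed={installed} pinned={PINS[name]} -> {status}")
```

Any MISMATCH line points at a package that pip resolved to a different version than intended, which is worth fixing before debugging JAX itself.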
Now it works as expected:
import jax.numpy as jnp; jnp.arange(5)
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1715346567.395054 2183826 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Pro
systemMemory: 18.00 GB
maxCacheSize: 6.00 GB
I0000 00:00:1715346567.407453 2183826 service.cc:145] XLA service 0x600001124c00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1715346567.407588 2183826 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1715346567.409986 2183826 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1715346567.409992 2183826 mps_client.cc:384] XLA backend will use up to 12883132416 bytes on device 0 for SimpleAllocator.
Array([0, 1, 2, 3, 4], dtype=int32)
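Since jnp.arange alone exercises very little, a slightly broader smoke test might also report which backend JAX selected and run a small jit-compiled computation. This is a sketch under assumptions: the helper name metal_smoke_test is hypothetical, and it assumes jax.default_backend() reports the Metal platform when jax-metal is active (the log above calls it 'METAL'); the exact string may differ between releases.

```python
from __future__ import annotations


def metal_smoke_test() -> tuple[str, bool] | None:
    """Hypothetical helper: return (backend_name, jit_result_ok), or None if JAX is missing."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return None

    # Platform JAX picked by default; assumed to be the METAL platform
    # when the jax-metal plugin is loaded.
    backend = jax.default_backend()

    # Tiny jit-compiled elementwise op; every value is exact in float32,
    # so an exact comparison is safe here.
    x = jnp.arange(5, dtype=jnp.float32)
    y = jax.jit(lambda v: v * 2.0 + 1.0)(x)
    ok = y.tolist() == [1.0, 3.0, 5.0, 7.0, 9.0]
    return backend, ok


if __name__ == "__main__":
    result = metal_smoke_test()
    if result is None:
        print("JAX is not installed in this environment")
    else:
        backend, ok = result
        print(f"default backend: {backend}, jit result correct: {ok}")
```

If the reported backend is cpu rather than the Metal platform, the plugin is probably not being picked up in that environment.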
I face the same issue on a 14" MacBook Pro (M3 Pro).