jax-metal fails to install on M1 clean environment

Hi all,

I'm having trouble getting even the latest version of jax-metal to work on my M1 MacBook Pro. In a clean conda environment, I ran pip install jax-metal and got:

In [1]: import jax; print(jax.numpy.arange(10))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
    [... skipping hidden 1 frame]

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/xla_bridge.py:977, in _init_backend(platform)
    976 logger.debug("Initializing backend '%s'", platform)
--> 977 backend = registration.factory()
    978 # TODO(skye): consider raising more descriptive errors directly from backend
    979 # factories instead of returning None.

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/xla_bridge.py:666, in register_plugin.<locals>.factory()
    665 if not xla_client.pjrt_plugin_initialized(plugin_name):
--> 666   xla_client.initialize_pjrt_plugin(plugin_name)
    667 updated_options = {}

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jaxlib/xla_client.py:176, in initialize_pjrt_plugin(plugin_name)
    169 """Initializes a PJRT plugin.
    170 
    171 The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or
   (...)
    174   plugin_name: the name of the PJRT plugin.
    175 """
--> 176 _xla.initialize_pjrt_plugin(plugin_name)

XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[1], line 1
----> 1 import jax; print(jax.numpy.arange(10))

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2952, in arange(start, stop, step, dtype)
   2950     ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil
   2951     start = ceil_(start).astype(int)  # type: ignore
-> 2952   return lax.iota(dtype, start)
   2953 else:
   2954   if step is None and start == 0 and stop is not None:

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/lax/lax.py:1282, in iota(dtype, size)
   1277 def iota(dtype: DTypeLike, size: int) -> Array:
   1278   """Wraps XLA's `Iota
   1279   <https://www.tensorflow.org/xla/operation_semantics#iota>`_
   1280   operator.
   1281   """
-> 1282   return broadcasted_iota(dtype, (size,), 0)

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/lax/lax.py:1292, in broadcasted_iota(dtype, shape, dimension)
   1289 static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
   1290 dimension = core.concrete_or_error(
   1291     int, dimension, "dimension argument of lax.broadcasted_iota")
-> 1292 return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
   1293                    dimension=dimension)

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/core.py:387, in Primitive.bind(self, *args, **params)
    384 def bind(self, *args, **params):
    385   assert (not config.enable_checks.value or
    386           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 387   return self.bind_with_trace(find_top_trace(args), args, params)

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/core.py:391, in Primitive.bind_with_trace(self, trace, args, params)
    389 def bind_with_trace(self, trace, args, params):
    390   with pop_level(trace.level):
--> 391     out = trace.process_primitive(self, map(trace.full_raise, args), params)
    392   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/core.py:879, in EvalTrace.process_primitive(self, primitive, tracers, params)
    877   return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
    878 else:
--> 879   return primitive.impl(*tracers, **params)

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/dispatch.py:86, in apply_primitive(prim, *args, **params)
     84 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     85 try:
---> 86   outs = fun(*args)
     87 finally:
     88   lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 17 frame]

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/xla_bridge.py:902, in backends()
    900       else:
    901         err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
--> 902       raise RuntimeError(err_msg)
    904 assert _default_backend is not None
    905 if not config.jax_platforms.value:

RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)

jax.__version__ is 0.4.27.
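When reporting this kind of mismatch, it helps to list the exact versions of every package involved. Below is a minimal sketch that uses only the standard library's importlib.metadata, so jax itself, and therefore the broken METAL backend, is never imported. The package list is my assumption about what is relevant here.

```python
from importlib import metadata

# Packages that appear in this thread; the list itself is an assumption.
PACKAGES = ("jax", "jaxlib", "jax-metal", "ml_dtypes")


def installed_versions(packages=PACKAGES):
    """Return {package: version or None} without importing jax.

    Reading distribution metadata avoids triggering backend
    initialization, so this works even when the METAL plugin is broken.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:10s} {version or 'not installed'}")
```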

I face the same issue on an M3 Pro MacBook Pro 14".

I was able to fix the issue on my MBP M3 Pro / 14" on Sonoma 14.4.1 by installing the minimum versions listed at https://pypi.org/project/jax-metal/:

  1. pip install jax-metal
  2. pip install jax==0.4.26 jaxlib==0.4.26

I additionally ran into a deprecated-types error with ml_dtypes versions > 0.2.0, so I also pinned it:

  3. pip install ml_dtypes==0.2.0
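To confirm the environment actually ended up on the pinned versions (installing another package later can pull jax forward again), a small check like the following may help. It is a sketch: the pin values are copied from this workaround and may not hold for newer jax-metal releases. The get_version parameter is a hypothetical hook I added so the check can be exercised without a real environment.

```python
from importlib import metadata

# Pins copied from the workaround above; re-check PyPI for newer releases.
PINS = {"jax": "0.4.26", "jaxlib": "0.4.26", "ml_dtypes": "0.2.0"}


def pin_mismatches(pins=PINS, get_version=metadata.version):
    """Return {package: (expected, found)} for every off-pin package.

    `found` is None when the package is not installed. `get_version`
    is injectable (hypothetical helper hook) for testing.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            found = get_version(name)
        except metadata.PackageNotFoundError:
            found = None
        if found != expected:
            mismatches[name] = (expected, found)
    return mismatches


if __name__ == "__main__":
    problems = pin_mismatches()
    if problems:
        for name, (expected, found) in problems.items():
            print(f"{name}: expected {expected}, found {found or 'nothing'}")
    else:
        print("All pinned versions match.")
```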

Now, as expected, I get

import jax.numpy as jnp; jnp.arange(5)
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1715346567.395054 2183826 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M3 Pro

systemMemory: 18.00 GB
maxCacheSize: 6.00 GB

I0000 00:00:1715346567.407453 2183826 service.cc:145] XLA service 0x600001124c00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1715346567.407588 2183826 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1715346567.409986 2183826 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1715346567.409992 2183826 mps_client.cc:384] XLA backend will use up to 12883132416 bytes on device 0 for SimpleAllocator.
Array([0, 1, 2, 3, 4], dtype=int32)

Thank you for this awesome advice!

The jax-metal PyPI package should pin its dependencies to jax==0.4.26 and jaxlib==0.4.26 rather than using >=. The devs could fix the existing PyPI release by publishing a post-release version.
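To illustrate the point: with a >= lower bound, pip is free to resolve jax to the newest available release (0.4.27 in the original report), whose PJRT API version (0.51) no longer matches the plugin's (0.47). An exact == pin keeps it at 0.4.26. The snippet below is a deliberately simplified model of that resolution step, not pip's actual algorithm. It only understands == and >= on plain dotted-integer versions, whereas real tools use the packaging library's SpecifierSet.

```python
# Simplified stand-in for PEP 440 specifiers: handles only "==" and ">="
# on plain dotted-integer versions such as "0.4.26".
def parse(version):
    return tuple(int(part) for part in version.split("."))


def satisfies(version, spec):
    op, target = spec[:2], spec[2:]
    if op == "==":
        return parse(version) == parse(target)
    if op == ">=":
        return parse(version) >= parse(target)
    raise ValueError(f"unsupported specifier: {spec!r}")


def pick(spec, available):
    """Return the newest available version satisfying `spec`, or None."""
    candidates = [v for v in available if satisfies(v, spec)]
    return max(candidates, key=parse) if candidates else None


# jax releases relevant to this thread.
available = ["0.4.25", "0.4.26", "0.4.27"]

for spec in (">=0.4.26", "==0.4.26"):
    print(f"jax{spec}: picks {pick(spec, available)}")
```

With the lower bound, the newer jax wins; with the exact pin, the plugin and framework stay on matching versions.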
