Hi all,
I'm having trouble even getting the latest version of jax-metal to install on my M1 MacBook Pro. In a clean conda environment, I run `pip install jax-metal`
and then get the following error as soon as I try to use JAX:
In [1]: import jax; print(jax.numpy.arange(10))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
[... skipping hidden 1 frame]
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/xla_bridge.py:977, in _init_backend(platform)
976 logger.debug("Initializing backend '%s'", platform)
--> 977 backend = registration.factory()
978 # TODO(skye): consider raising more descriptive errors directly from backend
979 # factories instead of returning None.
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/xla_bridge.py:666, in register_plugin.<locals>.factory()
665 if not xla_client.pjrt_plugin_initialized(plugin_name):
--> 666 xla_client.initialize_pjrt_plugin(plugin_name)
667 updated_options = {}
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jaxlib/xla_client.py:176, in initialize_pjrt_plugin(plugin_name)
169 """Initializes a PJRT plugin.
170
171 The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or
(...)
174 plugin_name: the name of the PJRT plugin.
175 """
--> 176 _xla.initialize_pjrt_plugin(plugin_name)
XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).
During handling of the above exception, another exception occurred:
RuntimeError Traceback (most recent call last)
Cell In[1], line 1
----> 1 import jax; print(jax.numpy.arange(10))
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2952, in arange(start, stop, step, dtype)
2950 ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil
2951 start = ceil_(start).astype(int) # type: ignore
-> 2952 return lax.iota(dtype, start)
2953 else:
2954 if step is None and start == 0 and stop is not None:
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/lax/lax.py:1282, in iota(dtype, size)
1277 def iota(dtype: DTypeLike, size: int) -> Array:
1278 """Wraps XLA's `Iota
1279 <https://www.tensorflow.org/xla/operation_semantics#iota>`_
1280 operator.
1281 """
-> 1282 return broadcasted_iota(dtype, (size,), 0)
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/lax/lax.py:1292, in broadcasted_iota(dtype, shape, dimension)
1289 static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
1290 dimension = core.concrete_or_error(
1291 int, dimension, "dimension argument of lax.broadcasted_iota")
-> 1292 return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
1293 dimension=dimension)
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/core.py:387, in Primitive.bind(self, *args, **params)
384 def bind(self, *args, **params):
385 assert (not config.enable_checks.value or
386 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 387 return self.bind_with_trace(find_top_trace(args), args, params)
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/core.py:391, in Primitive.bind_with_trace(self, trace, args, params)
389 def bind_with_trace(self, trace, args, params):
390 with pop_level(trace.level):
--> 391 out = trace.process_primitive(self, map(trace.full_raise, args), params)
392 return map(full_lower, out) if self.multiple_results else full_lower(out)
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/core.py:879, in EvalTrace.process_primitive(self, primitive, tracers, params)
877 return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
878 else:
--> 879 return primitive.impl(*tracers, **params)
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/dispatch.py:86, in apply_primitive(prim, *args, **params)
84 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
85 try:
---> 86 outs = fun(*args)
87 finally:
88 lib.jax_jit.swap_thread_local_state_disable_jit(prev)
[... skipping hidden 17 frame]
File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/xla_bridge.py:902, in backends()
900 else:
901 err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
--> 902 raise RuntimeError(err_msg)
904 assert _default_backend is not None
905 if not config.jax_platforms.value:
RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
For reference, `jax.__version__` is 0.4.27.