Post

Replies

Boosts

Views

Activity

jax-metal fails to install in a clean environment on M1
Hi all, I'm having trouble even getting the latest version of jax-metal to install on my M1 MacBook Pro. In a clean conda environment, I run `pip install jax-metal` and then get:

```
In [1]: import jax; print(jax.numpy.arange(10))
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
    [... skipping hidden 1 frame]

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/xla_bridge.py:977, in _init_backend(platform)
    976 logger.debug("Initializing backend '%s'", platform)
--> 977 backend = registration.factory()
    978 # TODO(skye): consider raising more descriptive errors directly from backend
    979 # factories instead of returning None.

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/xla_bridge.py:666, in register_plugin.<locals>.factory()
    665 if not xla_client.pjrt_plugin_initialized(plugin_name):
--> 666 xla_client.initialize_pjrt_plugin(plugin_name)
    667 updated_options = {}

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jaxlib/xla_client.py:176, in initialize_pjrt_plugin(plugin_name)
    169 """Initializes a PJRT plugin.
    170
    171 The plugin needs to be loaded first (through load_pjrt_plugin_dynamically or
    (...)
    174 plugin_name: the name of the PJRT plugin.
    175 """
--> 176 _xla.initialize_pjrt_plugin(plugin_name)

XlaRuntimeError: INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51).

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[1], line 1
----> 1 import jax; print(jax.numpy.arange(10))

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2952, in arange(start, stop, step, dtype)
   2950 ceil_ = ufuncs.ceil if isinstance(start, core.Tracer) else np.ceil
   2951 start = ceil_(start).astype(int) # type: ignore
-> 2952 return lax.iota(dtype, start)
   2953 else:
   2954 if step is None and start == 0 and stop is not None:

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/lax/lax.py:1282, in iota(dtype, size)
   1277 def iota(dtype: DTypeLike, size: int) -> Array:
   1278 """Wraps XLA's `Iota
   1279 <https://www.tensorflow.org/xla/operation_semantics#iota>`_
   1280 operator.
   1281 """
-> 1282 return broadcasted_iota(dtype, (size,), 0)

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/lax/lax.py:1292, in broadcasted_iota(dtype, shape, dimension)
   1289 static_shape = [None if isinstance(d, core.Tracer) else d for d in shape]
   1290 dimension = core.concrete_or_error(
   1291 int, dimension, "dimension argument of lax.broadcasted_iota")
-> 1292 return iota_p.bind(*dynamic_shape, dtype=dtype, shape=tuple(static_shape),
   1293 dimension=dimension)

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/core.py:387, in Primitive.bind(self, *args, **params)
    384 def bind(self, *args, **params):
    385 assert (not config.enable_checks.value or
    386 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 387 return self.bind_with_trace(find_top_trace(args), args, params)

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/core.py:391, in Primitive.bind_with_trace(self, trace, args, params)
    389 def bind_with_trace(self, trace, args, params):
    390 with pop_level(trace.level):
--> 391 out = trace.process_primitive(self, map(trace.full_raise, args), params)
    392 return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/core.py:879, in EvalTrace.process_primitive(self, primitive, tracers, params)
    877 return call_impl_with_key_reuse_checks(primitive, primitive.impl, *tracers, **params)
    878 else:
--> 879 return primitive.impl(*tracers, **params)

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/dispatch.py:86, in apply_primitive(prim, *args, **params)
     84 prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
     85 try:
---> 86 outs = fun(*args)
     87 finally:
     88 lib.jax_jit.swap_thread_local_state_disable_jit(prev)

    [... skipping hidden 17 frame]

File ~/opt/anaconda3/envs/metal/lib/python3.11/site-packages/jax/_src/xla_bridge.py:902, in backends()
    900 else:
    901 err_msg += " (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)"
--> 902 raise RuntimeError(err_msg)
    904 assert _default_backend is not None
    905 if not config.jax_platforms.value:

RuntimeError: Unable to initialize backend 'METAL': INVALID_ARGUMENT: Mismatched PJRT plugin PJRT API version (0.47) and framework PJRT API version 0.51). (you may need to uninstall the failing plugin package, or set JAX_PLATFORMS=cpu to skip this backend.)
```

`jax.__version__` is 0.4.27.
4
0
1.3k
May ’24