Problem

I am trying to use the jax.numpy.einsum function (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.einsum.html). However, it fails for some subscripts.

Hardware

Apple M1 Max, 32GB RAM

Steps to Reproduce

1. Follow the installation steps from https://developer.apple.com/metal/jax/:

    conda create -n 'jax_metal_demo' python=3.11
    conda activate jax_metal_demo
    python -m pip install numpy wheel ml-dtypes==0.2.0
    python -m pip install jax-metal

2. Save the following code in a file called minimal_example.py:

    import numpy as np
    from jax import device_put
    import jax.numpy as jnp

    np.random.seed(0)
    a = np.random.rand(11, 12, 13, 11, 12)
    b = np.random.rand(11, 12, 13)
    subscripts = 'ijklm,ijk->lmk'

    # intended result
    print(np.einsum(subscripts, a, b))

    # will cause crash
    a, b = device_put(a), device_put(b)
    print(jnp.einsum(subscripts, a, b))

3. Run the code:

    python minimal_example.py

Output

I was expecting the jnp.einsum call to print the same result as the np.einsum call. Instead, it crashes with the following output:

    Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
    2024-02-12 16:45:34.684973: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
    Metal device set to: Apple M1 Max

    systemMemory: 32.00 GB
    maxCacheSize: 10.67 GB

    Traceback (most recent call last):
      File "/Users/linus/workspace/minimal_example.py", line 15, in <module>
        print(jnp.einsum(subscripts, a, b))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/linus/miniforge3/envs/jax_metal_demo/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 3369, in einsum
        return _einsum_computation(operands, contractions, precision,  # type: ignore[operator]
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/linus/miniforge3/envs/jax_metal_demo/lib/python3.11/contextlib.py", line 81, in inner
        return func(*args, **kwds)
               ^^^^^^^^^^^^^^^^^^^
    jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/linus/workspace/minimal_example.py:15:6: error: failed to legalize operation 'mhlo.dot_general'
    print(jnp.einsum(subscripts, a, b))
         ^
    /Users/linus/workspace/minimal_example.py:15:6: note: see current operation: %0 = "mhlo.dot_general"(%arg1, %arg0) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [2], rhs_batching_dimensions = [2], lhs_contracting_dimensions = [0, 1], rhs_contracting_dimensions = [0, 1]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<11x12x13xf32>, tensor<11x12x13x11x12xf32>) -> tensor<13x11x12xf32>
    --------------------
    For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

Conclusion

I would greatly appreciate any ideas for workarounds.
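One possible workaround, sketched here under assumptions: the error comes from legalizing 'mhlo.dot_general', so the same contraction can be written as a broadcasted multiply followed by a sum, which does not go through einsum at all. The helper name einsum_ijklm_ijk_to_lmk is hypothetical, and this has not been verified on the Metal backend; it only demonstrates that the rewrite matches np.einsum numerically.

```python
import numpy as np
import jax.numpy as jnp


def einsum_ijklm_ijk_to_lmk(a, b):
    """Compute 'ijklm,ijk->lmk' without calling einsum (hypothetical helper)."""
    # Append two singleton axes so b (i, j, k) broadcasts against a (i, j, k, l, m).
    prod = a * b[:, :, :, None, None]
    # Sum out i and j; the remaining axes are (k, l, m).
    summed = jnp.sum(prod, axis=(0, 1))
    # Reorder (k, l, m) -> (l, m, k) to match the requested output subscripts.
    return jnp.transpose(summed, (1, 2, 0))


np.random.seed(0)
a = np.random.rand(11, 12, 13, 11, 12)
b = np.random.rand(11, 12, 13)

expected = np.einsum('ijklm,ijk->lmk', a, b)
result = einsum_ijklm_ijk_to_lmk(jnp.asarray(a), jnp.asarray(b))

# JAX defaults to float32 while NumPy uses float64, so compare with a loose tolerance.
matches = np.allclose(np.asarray(result), expected, rtol=1e-4)
print(result.shape, matches)
```

The trade-off is that the full elementwise product is materialized before the reduction, which is fine at these sizes but may use noticeably more memory for larger operands.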