Problem
I am trying to use the jax.numpy.einsum function (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.einsum.html). However, with the jax-metal backend it fails for some subscripts.
Hardware
Apple M1 Max, 32GB RAM
Steps to Reproduce
- Follow the installation steps from https://developer.apple.com/metal/jax/
conda create -n 'jax_metal_demo' python=3.11
conda activate jax_metal_demo
python -m pip install numpy wheel ml-dtypes==0.2.0
python -m pip install jax-metal
- Save the following code in a file called minimal_example.py
import numpy as np
from jax import device_put
import jax.numpy as jnp
np.random.seed(0)
a = np.random.rand(11, 12, 13, 11, 12)
b = np.random.rand(11, 12, 13)
subscripts = 'ijklm,ijk->lmk'
# intended result
print(np.einsum(subscripts, a, b))
# fails on the Metal backend (see Output)
a, b = device_put(a), device_put(b)
print(jnp.einsum(subscripts, a, b))
- Run the code:
python minimal_example.py
Output
I was expecting the same result as np.einsum, but instead got:
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-02-12 16:45:34.684973: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max
systemMemory: 32.00 GB
maxCacheSize: 10.67 GB
Traceback (most recent call last):
File "/Users/linus/workspace/minimal_example.py", line 15, in <module>
print(jnp.einsum(subscripts, a, b))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/linus/miniforge3/envs/jax_metal_demo/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 3369, in einsum
return _einsum_computation(operands, contractions, precision, # type: ignore[operator]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/linus/miniforge3/envs/jax_metal_demo/lib/python3.11/contextlib.py", line 81, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/linus/workspace/minimal_example.py:15:6: error: failed to legalize operation 'mhlo.dot_general'
print(jnp.einsum(subscripts, a, b))
^
/Users/linus/workspace/minimal_example.py:15:6: note: see current operation: %0 = "mhlo.dot_general"(%arg1, %arg0) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [2], rhs_batching_dimensions = [2], lhs_contracting_dimensions = [0, 1], rhs_contracting_dimensions = [0, 1]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<11x12x13xf32>, tensor<11x12x13x11x12xf32>) -> tensor<13x11x12xf32>
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
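For triage, the failing operation seems to be reproducible without einsum: the note in the error gives the exact dimension numbers, which can be passed to jax.lax.dot_general directly. The sketch below does that. The arrays are cast to float32 up front because JAX (with x64 disabled, the default) already converts the float64 NumPy inputs to float32, as the `xf32` tensor types in the error show. I have not confirmed that this standalone call fails on Metal in the same way; that is the assumption it is meant to test. It does run on the CPU backend.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax

np.random.seed(0)
a = np.random.rand(11, 12, 13, 11, 12).astype(np.float32)
b = np.random.rand(11, 12, 13).astype(np.float32)

# Same dimension numbers as the failing "mhlo.dot_general" in the error:
# lhs = b (11x12x13), rhs = a (11x12x13x11x12),
# contracting dims (0, 1) on both operands, batching dim (2,) on both.
dimension_numbers = (((0, 1), (0, 1)), ((2,), (2,)))
out = lax.dot_general(jnp.asarray(b), jnp.asarray(a), dimension_numbers)
print(out.shape)  # (13, 11, 12): batch dim k first, then a's free dims l, m

# einsum then transposes the dot_general result into the requested 'lmk' order.
result = jnp.transpose(out, (1, 2, 0))
print(np.allclose(np.asarray(result), np.einsum('ijklm,ijk->lmk', a, b), rtol=1e-4))
```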
Conclusion
I would greatly appreciate any ideas for workarounds.
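Two possible workarounds, sketched under the assumption (not verified on Metal) that the problem is specific to this batched dot_general lowering: (1) rewrite the contraction as a broadcast multiply followed by a sum, which avoids dot_general entirely at the cost of materializing the full (i, j, k, l, m) product; (2) run just this einsum on the CPU backend by moving the operands there with jax.device_put(x, jax.devices('cpu')[0]), assuming the CPU backend is available alongside the jax-metal plugin. The function names `einsum_lmk_no_dot` and `einsum_on_cpu` are hypothetical helpers for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp

SUBSCRIPTS = 'ijklm,ijk->lmk'

def einsum_lmk_no_dot(a, b):
    """Hypothetical helper: 'ijklm,ijk->lmk' as broadcast multiply + sum (no dot_general)."""
    # b[:, :, :, None, None] has shape (i, j, k, 1, 1) and broadcasts against a.
    prod = a * b[:, :, :, None, None]        # shape (i, j, k, l, m)
    summed = prod.sum(axis=(0, 1))           # contract i and j -> shape (k, l, m)
    return jnp.transpose(summed, (1, 2, 0))  # reorder to (l, m, k)

def einsum_on_cpu(subscripts, *operands):
    """Hypothetical helper: run jnp.einsum with all operands committed to the CPU device."""
    cpu = jax.devices('cpu')[0]
    ops = [jax.device_put(x, cpu) for x in operands]
    return jnp.einsum(subscripts, *ops)

np.random.seed(0)
a = np.random.rand(11, 12, 13, 11, 12).astype(np.float32)
b = np.random.rand(11, 12, 13).astype(np.float32)
expected = np.einsum(SUBSCRIPTS, a, b)

print(np.allclose(np.asarray(einsum_lmk_no_dot(jnp.asarray(a), jnp.asarray(b))), expected, rtol=1e-4))
print(np.allclose(np.asarray(einsum_on_cpu(SUBSCRIPTS, a, b)), expected, rtol=1e-4))
```

For the shapes in the example the intermediate product in option (1) has about 226k elements, so the extra memory is negligible; for much larger arrays option (2) may be the cheaper choice.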