jax-metal error on jax.numpy.einsum

Problem

I am trying to use the jax.numpy.einsum function (https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.einsum.html). However, for some subscripts it fails with a "failed to legalize operation 'mhlo.dot_general'" error.

Hardware

Apple M1 Max, 32GB RAM

Steps to Reproduce

  1. Follow the installation steps from https://developer.apple.com/metal/jax/
conda create -n 'jax_metal_demo' python=3.11
conda activate jax_metal_demo
python -m pip install numpy wheel ml-dtypes==0.2.0
python -m pip install jax-metal
  2. Save the following code in a file called minimal_example.py
import numpy as np
from jax import device_put
import jax.numpy as jnp

np.random.seed(0)
a = np.random.rand(11, 12, 13, 11, 12)
b = np.random.rand(11, 12, 13)

subscripts = 'ijklm,ijk->lmk'
# intended result
print(np.einsum(subscripts, a, b))

# will cause crash
a, b = device_put(a), device_put(b)
print(jnp.einsum(subscripts, a, b))
  3. Run the code
python minimal_example.py

Output

I was expecting the same result as np.einsum. Instead, the jnp.einsum call fails with:

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
2024-02-12 16:45:34.684973: W pjrt_plugin/src/mps_client.cc:563] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Max

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB

Traceback (most recent call last):
  File "/Users/linus/workspace/minimal_example.py", line 15, in <module>
    print(jnp.einsum(subscripts, a, b))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/linus/miniforge3/envs/jax_metal_demo/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py", line 3369, in einsum
    return _einsum_computation(operands, contractions, precision,  # type: ignore[operator]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/linus/miniforge3/envs/jax_metal_demo/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/linus/workspace/minimal_example.py:15:6: error: failed to legalize operation 'mhlo.dot_general'
print(jnp.einsum(subscripts, a, b))
     ^
/Users/linus/workspace/minimal_example.py:15:6: note: see current operation: %0 = "mhlo.dot_general"(%arg1, %arg0) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [2], rhs_batching_dimensions = [2], lhs_contracting_dimensions = [0, 1], rhs_contracting_dimensions = [0, 1]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<11x12x13xf32>, tensor<11x12x13x11x12xf32>) -> tensor<13x11x12xf32>

--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
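The note: line shows what jnp.einsum lowered to: b (11x12x13) is the lhs and a is the rhs, dimension 2 (k) is a batching dimension on both, and dimensions 0 and 1 (i and j) are both contracted, producing a 13x11x12 result laid out as (k, l, m). The NumPy sketch below (runnable without Metal, with smaller sizes chosen only for illustration) spells out that dot_general in einsum terms. It assumes, since the log shows only the dot_general itself, that the result is afterwards transposed to the requested lmk layout.

```python
import numpy as np

rng = np.random.default_rng(0)
# Smaller stand-ins for the reported shapes: a is (i, j, k, l, m), b is (i, j, k).
a = rng.random((3, 4, 5, 3, 4))
b = rng.random((3, 4, 5))

# dot_general(lhs=b, rhs=a): batch over k (dim 2 of both), contract over i and j
# (dims 0 and 1 of both). Output layout is (batch, lhs free, rhs free) = (k, l, m).
dot_general_out = np.einsum('ijk,ijklm->klm', b, a)

# Moving k to the last position recovers the 'ijklm,ijk->lmk' result.
result = dot_general_out.transpose(1, 2, 0)
expected = np.einsum('ijklm,ijk->lmk', a, b)
print(np.allclose(result, expected))  # True
```

Two contracting dimensions (contracting_dimensions = [0, 1]) is exactly the case the error is about.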

Conclusion

I would greatly appreciate any ideas for workarounds.

This is reproducible with jax-metal 0.0.5. The lowering pattern needs to be expanded to handle dot_general ops with more than one contracting dimension (contracting_dimensions size > 1). As a workaround, could you make the changes below and give it a try:

a = np.random.rand(11, 12, 13, 11, 12).reshape(132, 13, 11, 12)
b = np.random.rand(11, 12, 13).reshape(132, 13)

#subscripts = 'ijklm,ijk->lmk'
subscripts = 'iklm,ik->lmk'

It produces a matching result on my side.
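Why the reshape works: reshape uses C (row-major) order by default in both NumPy and jax.numpy, so the two leading axes i and j of a and b can be merged into one axis of length 11 * 12 = 132 without reordering any elements, and contracting over that merged axis equals contracting over i and j separately. The resulting einsum has a single contracting dimension. A minimal NumPy sketch of the equivalence follows; the helper name einsum_merged_ij is hypothetical, and on jax-metal one would pass the reshaped arrays to jnp.einsum instead.

```python
import numpy as np

def einsum_merged_ij(a, b):
    """Hypothetical helper: compute einsum('ijklm,ijk->lmk', a, b) with a
    single contracting axis by merging the leading i and j axes first."""
    ni, nj = a.shape[:2]
    # C-order reshape: i and j are adjacent leading axes, so merging them
    # just views the same elements as one axis of length ni * nj.
    a2 = a.reshape(ni * nj, *a.shape[2:])  # (ij, k, l, m)
    b2 = b.reshape(ni * nj, *b.shape[2:])  # (ij, k)
    return np.einsum('iklm,ik->lmk', a2, b2)

np.random.seed(0)
a = np.random.rand(11, 12, 13, 11, 12)
b = np.random.rand(11, 12, 13)
print(np.allclose(einsum_merged_ij(a, b), np.einsum('ijklm,ijk->lmk', a, b)))  # True
```

This only works because the contracted axes are adjacent and appear in the same order in both operands; otherwise a transpose is needed first.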

I'm getting what looks like the same error during the backward pass for a transformer, but in this case it's less clear how to work around it:

layers.py:108:15: error: failed to legalize operation 'mhlo.dot_general'
    attended = jnp.einsum('bsSh,bShd->bshd', weights, v)
              ^
layers.py:108:15: note: see current operation: %0 = "mhlo.dot_general"(%arg2, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1, 2], rhs_contracting_dimensions = [1, 3]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<16x256x4x64xf32>, tensor<16x256x256x4xf32>) -> tensor<16x64x256xf32>
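The same idea generalizes when the contracted dimensions are not adjacent, as in this dot_general, where the rhs contracts over dimensions 1 and 3: transpose each operand so its batch, free, and contracting dimensions are grouped, collapse each group into one axis, and do a single batched matrix multiply. The NumPy sketch below emulates dot_general's output layout (batch dims, lhs free dims, rhs free dims) this way. The function name is hypothetical, the shapes are scaled-down stand-ins for the ones in the log, and whether the jnp versions of these calls lower cleanly on jax-metal is an assumption that has not been tested here.

```python
import numpy as np

def dot_general_merged(lhs, rhs, lhs_batch, rhs_batch, lhs_contract, rhs_contract):
    """Hypothetical helper: dot_general-style contraction that uses only one
    (merged) contracting axis. Output layout: batch, lhs free, rhs free."""
    lhs_free = [d for d in range(lhs.ndim) if d not in lhs_batch and d not in lhs_contract]
    rhs_free = [d for d in range(rhs.ndim) if d not in rhs_batch and d not in rhs_contract]
    batch_shape = [lhs.shape[d] for d in lhs_batch]
    lhs_free_shape = [lhs.shape[d] for d in lhs_free]
    rhs_free_shape = [rhs.shape[d] for d in rhs_free]
    nb = int(np.prod(batch_shape))      # np.prod([]) is 1.0, so empty groups give size 1
    nl = int(np.prod(lhs_free_shape))
    nr = int(np.prod(rhs_free_shape))
    nc = int(np.prod([lhs.shape[d] for d in lhs_contract]))
    # Group axes as (batch, free, contract) and (batch, contract, free); the
    # contracting axes keep their pairwise order, so the flattened indices match.
    lt = lhs.transpose(list(lhs_batch) + lhs_free + list(lhs_contract)).reshape(nb, nl, nc)
    rt = rhs.transpose(list(rhs_batch) + list(rhs_contract) + rhs_free).reshape(nb, nc, nr)
    out = np.matmul(lt, rt)  # one contracting axis of length nc
    return out.reshape(batch_shape + lhs_free_shape + rhs_free_shape)

rng = np.random.default_rng(0)
lhs = rng.random((2, 5, 3, 4))  # stands in for 16x256x4x64
rhs = rng.random((2, 5, 6, 3))  # stands in for 16x256x256x4
out = dot_general_merged(lhs, rhs, [0], [0], [1, 2], [1, 3])
ref = np.einsum('bxhd,bxyh->bdy', lhs, rhs)
print(out.shape, np.allclose(out, ref))  # (2, 4, 6) True
```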

Edit: in my case the issue seems to be due to broadcasting. If I manually broadcast using jnp.repeat first, the issue goes away:

    if weights.shape[3] != v.shape[2]:
      v = jnp.repeat(v, weights.shape[3] // v.shape[2], axis=2)
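A self-contained NumPy sketch of that fix. It assumes, which the snippet does not state, that v carries fewer heads than weights and that the head count of weights is a multiple of it (as in grouped- or multi-query attention); attend is a hypothetical name. jnp.repeat accepts the same (array, repeats, axis) arguments as np.repeat, so the JAX version is a direct substitution.

```python
import numpy as np

def attend(weights, v):
    """Hypothetical helper. weights: (batch, s, S, heads); v: (batch, S, kv_heads, d).
    Assumes heads is a multiple of kv_heads."""
    heads, kv_heads = weights.shape[3], v.shape[2]
    if heads != kv_heads:
        # np.repeat duplicates each kv head in place: [v0, v0, v1, v1, ...],
        # so head h of weights pairs with kv head h // (heads // kv_heads).
        v = np.repeat(v, heads // kv_heads, axis=2)
    return np.einsum('bsSh,bShd->bshd', weights, v)

rng = np.random.default_rng(0)
weights = rng.random((2, 5, 5, 4))  # 4 query heads
v = rng.random((2, 5, 2, 3))        # 2 kv heads
out = attend(weights, v)
# Reference: compute each head separately against the kv head it shares.
ref = np.stack([np.einsum('bsS,bSd->bsd', weights[..., h], v[:, :, h // 2, :])
                for h in range(4)], axis=2)
print(out.shape, np.allclose(out, ref))  # (2, 5, 4, 3) True
```

After the repeat, both operands have the same head size, so no size-1 broadcasting is left for the einsum (or its gradient) to handle.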