Reply to jax-metal error on jax.numpy.einsum
This is reproducible with jax-metal 0.0.5. The lowering pattern needs to be expanded to handle more than one contracting dimension (contracting_dimensions of size > 1). As a workaround, could you make the changes below and give it a try?

    a = np.random.rand(11, 12, 13, 11, 12).reshape(132, 13, 11, 12)
    b = np.random.rand(11, 12, 13).reshape(132, 13)
    # subscripts = 'ijklm,ijk->lmk'
    subscripts = 'iklm,ik->lmk'

This merges the two contracted axes i and j (11 × 12 = 132) into a single axis, so only one contracting dimension remains. It generates a matching result on my side.
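The workaround can be wrapped in a small helper and checked against the original subscripts. This is a minimal sketch, assuming `a` has shape (I, J, K, L, M) and `b` has shape (I, J, K). The helper name `einsum_merged_contraction` and the `xp` parameter are hypothetical. It runs with plain NumPy here. Passing `xp=jax.numpy` is the assumed way to route it through JAX on jax-metal, and that path is not verified.

```python
import numpy as np


def einsum_merged_contraction(a, b, xp=np):
    """Compute einsum('ijklm,ijk->lmk', a, b) with axes i and j merged into one.

    Hypothetical helper. It assumes a.shape == (I, J, K, L, M) and
    b.shape == (I, J, K). xp is the array module (numpy by default; jax.numpy
    is the assumed drop-in for running under JAX).
    """
    I, J, K, L, M = a.shape
    assert b.shape == (I, J, K), "b must share the leading (i, j, k) axes with a"
    # A row-major reshape maps (i, j) to the single index i * J + j in the same
    # way for both operands, so summing over the merged axis is the same as
    # summing over i and j separately.
    a2 = xp.reshape(a, (I * J, K, L, M))
    b2 = xp.reshape(b, (I * J, K))
    # Only one contracted axis remains; k is shared and kept in the output.
    return xp.einsum('iklm,ik->lmk', a2, b2)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    a = rng.random((11, 12, 13, 11, 12))
    b = rng.random((11, 12, 13))
    expected = np.einsum('ijklm,ijk->lmk', a, b)  # original two-axis contraction
    result = einsum_merged_contraction(a, b)
    print(result.shape)                    # (11, 12, 13)
    print(np.allclose(result, expected))   # True
```

The merge is valid only because i and j are adjacent and appear in the same order in both operands. If they were not, you would need to transpose first so that the reshape flattens them identically.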
Feb ’24