jax.numpy.dot and jax.numpy.matmul crash with jax-metal

I am seeing an issue with jax.numpy.dot and jax.numpy.matmul, illustrated by this jax.numpy.dot example:

import jax.numpy as jnp
import numpy as np

x = np.array(np.random.rand(3, 3))
y = np.array(np.random.rand(3))
z = np.array(np.random.rand(3))

print("X: ", x)
print("Y: ", y)
print("Z: ", z)

print("Numpy 1D*1D: ", np.dot(y, z))
print("Jax Numpy 1D*1D: ", jnp.dot(y, z))
print("Numpy 2D*1D: ", np.dot(x, y))
print("Jax Numpy 2D*1D: ", jnp.dot(x, y))

Running this (as test.py) aborts with:

loc("-":4:5): error: type of return operand 0 ('tensor<*xf32>') doesn't match function result type ('tensor<3xf32>') in function @main
/AppleInternal/Library/BuildRoots/1a7a4148-f669-11ed-9d56-f6357a1003e8/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1950: failed assertion `Error: MLIR pass manager failed'
zsh: abort      python test.py

The dot product of two 1D arrays works with both NumPy and jax.numpy. The 2D*1D product, however, works only with NumPy; with jax.numpy the process aborts with the MLIR pass manager error shown above.
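
To help narrow this down, here is a minimal sketch that runs the same 2D*1D product on the CPU backend and compares it with NumPy. It assumes a CPU device is still available through jax.devices("cpu") when jax-metal is installed, which I have not confirmed for every setup. The fixed seed and float32 inputs are my own choices for repeatability and are not part of the original script.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Fixed seed so the comparison is repeatable (an illustrative choice).
rng = np.random.default_rng(0)
x = rng.random((3, 3), dtype=np.float32)
y = rng.random(3, dtype=np.float32)

# NumPy reference result for the 2D*1D product; shape is (3,).
expected = np.dot(x, y)

# Assumption: a CPU device is exposed alongside the Metal backend.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    result = jnp.dot(x, y)
    result_matmul = jnp.matmul(x, y)

print("NumPy 2D*1D:   ", expected)
print("JAX CPU dot:   ", result)
print("JAX CPU matmul:", result_matmul)
print("Match:", np.allclose(expected, result, atol=1e-5))
```

If this version runs cleanly and matches NumPy, that would suggest the failure is specific to the Metal backend rather than to jax.numpy.dot itself.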

I am using jax 0.4.11, jax-metal 0.0.2, and jaxlib 0.4.10.

Has anyone else seen this issue?
