I am seeing an issue in jax.numpy.dot and jax.numpy.matmul, as illustrated by this example using jax.numpy.dot:
import jax.numpy as jnp
import numpy as np
x = np.array(np.random.rand(3, 3))
y = np.array(np.random.rand(3))
z = np.array(np.random.rand(3))
print("X: ", x)
print("Y: ", y)
print("Z: ", z)
print("Numpy 1D*1D: ", np.dot(y, z))
print("Jax Numpy 1D*1D: ", jnp.dot(y, z))
print("Numpy 2D*1D: ", np.dot(x, y))
print("Jax Numpy 2D*1D: ", jnp.dot(x, y))
Running the script aborts with the following error:
loc("-":4:5): error: type of return operand 0 ('tensor<*xf32>') doesn't match function result type ('tensor<3xf32>') in function @main
/AppleInternal/Library/BuildRoots/1a7a4148-f669-11ed-9d56-f6357a1003e8/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:1950: failed assertion `Error: MLIR pass manager failed'
zsh: abort python test.py
The dot product of two 1D arrays works with both standard NumPy and jax.numpy. The 2D*1D product, however, works only with standard NumPy; with jax.numpy it aborts with the MLIR error shown above.
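To help narrow this down, here is a minimal sketch of two mathematically equivalent ways to write the 2D*1D product without calling dot on a 2D and a 1D operand. The helper names matvec_via_broadcast and matvec_via_2d and the xp parameter are hypothetical, introduced only for illustration. The sketch runs against NumPy by default; whether either form avoids the failure on jax-metal is an assumption I have not verified.

```python
import numpy as np


def matvec_via_broadcast(x, y, xp=np):
    """2D*1D product written as an elementwise multiply plus a row sum."""
    # x has shape (m, n) and y has shape (n,), so broadcasting scales
    # each row of x by y before the sum over axis 1.
    return xp.sum(x * y, axis=1)


def matvec_via_2d(x, y, xp=np):
    """2D*1D product with y promoted to a column, so the matmul is 2D*2D."""
    # y[:, None] has shape (n, 1); the result (m, 1) is sliced back to (m,).
    return xp.matmul(x, y[:, None])[:, 0]


rng = np.random.default_rng(0)
x = rng.random((3, 3), dtype=np.float32)
y = rng.random(3, dtype=np.float32)

reference = np.dot(x, y)
print("broadcast form matches np.dot:", np.allclose(matvec_via_broadcast(x, y), reference))
print("2D*2D form matches np.dot:", np.allclose(matvec_via_2d(x, y), reference))
```

To try the same formulations with JAX, pass xp=jnp after import jax.numpy as jnp, and convert x and y with jnp.asarray first if the multiply should also run on the device. If one form works where jnp.dot(x, y) aborts, that would suggest the problem is specific to how the 2D*1D case is lowered.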
I am using jax 0.4.11, jax-metal 0.0.2, and jaxlib 0.4.10.
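For anyone comparing setups, a small sketch like the following prints the installed versions using only the standard library. The helper name installed_versions is hypothetical, and the distribution names "jax", "jaxlib", and "jax-metal" are assumed to match what pip reports on your machine.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment.
            found[name] = None
    return found


for name, ver in installed_versions(["jax", "jaxlib", "jax-metal"]).items():
    print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Posting this output alongside the macOS version makes it easier to tell whether the failure is tied to a particular combination of packages.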
Has anyone else seen this issue?