Hi,
I just noticed that jax.numpy.insert() returns an incorrect result on the Metal backend when the function is compiled with jax.jit: instead of the inserted values, the output is zero-padded. Without jit, the result is correct.
Config:
M1 Pro MacBook Pro (2021)
Python 3.12.3; jax-metal 0.0.6; jax 0.4.26; jaxlib 0.4.23
MWE:
import jax
import jax.numpy as jnp
x = jnp.arange(20).reshape(5, 4)
print(f"{x=}\n")
def return_arr_with_ins(arr, ins):
    return jnp.insert(arr, 2, ins, axis=1)
x2 = return_arr_with_ins(x, 99)
print(f"{x2=}\n")
return_arr_with_ins_jit = jax.jit(return_arr_with_ins)
x3 = return_arr_with_ins_jit(x, 99)
print(f"{x3=}\n")
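For a self-checking version of the MWE, the sketch below compares both results against NumPy's `np.insert`, which runs on the host and so serves as a reference independent of the JAX backend. On a CPU backend I would expect both checks to print `True`. Based on the output below, the jit check should print `False` under jax-metal 0.0.6, though I haven't re-run this exact script there.

```python
import jax
import jax.numpy as jnp
import numpy as np

def return_arr_with_ins(arr, ins):
    # Insert `ins` as a new column at index 2.
    return jnp.insert(arr, 2, ins, axis=1)

x = jnp.arange(20).reshape(5, 4)

# Host-side NumPy reference result.
expected = np.insert(np.arange(20).reshape(5, 4), 2, 99, axis=1)

eager = np.asarray(return_arr_with_ins(x, 99))
jitted = np.asarray(jax.jit(return_arr_with_ins)(x, 99))

print("backend:", jax.default_backend())
print("eager matches NumPy:", np.array_equal(eager, expected))
print("jit matches NumPy:  ", np.array_equal(jitted, expected))
```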
Output: x2 (computed with the non-jitted function) is correct, while x3 (jitted) keeps the original four columns and appends a column of zeros instead of inserting a column of 99 at index 2.
x=Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]], dtype=int32)
x2=Array([[ 0, 1, 99, 2, 3],
[ 4, 5, 99, 6, 7],
[ 8, 9, 99, 10, 11],
[12, 13, 99, 14, 15],
[16, 17, 99, 18, 19]], dtype=int32)
x3=Array([[ 0, 1, 2, 3, 0],
[ 4, 5, 6, 7, 0],
[ 8, 9, 10, 11, 0],
[12, 13, 14, 15, 0],
[16, 17, 18, 19, 0]], dtype=int32)
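A possible workaround, assuming the fault is specific to how `jnp.insert` is handled under jit on the Metal backend (not confirmed on jax-metal): build the new column with `jnp.full` and splice it in with `jnp.concatenate` using static slices. The name `insert_column_concat` is hypothetical, chosen only for this sketch.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def insert_column_concat(arr, val):
    # Hypothetical helper: a (rows, 1) column filled with `val`, in arr's dtype.
    col = jnp.full((arr.shape[0], 1), val, dtype=arr.dtype)
    # Splice the column between columns [0, 2) and [2, n) with static slices.
    return jnp.concatenate([arr[:, :2], col, arr[:, 2:]], axis=1)

x = jnp.arange(20).reshape(5, 4)
x4 = insert_column_concat(x, 99)
print(x4)
```

This avoids `jnp.insert` entirely, so if it gives the correct result on Metal under jit, that would narrow the bug down to the `jnp.insert` path.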
The same code run on a machine without jax-metal gives the correct result for both x2 and x3:
x=Array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19]], dtype=int32)
x2=Array([[ 0, 1, 99, 2, 3],
[ 4, 5, 99, 6, 7],
[ 8, 9, 99, 10, 11],
[12, 13, 99, 14, 15],
[16, 17, 99, 18, 19]], dtype=int32)
x3=Array([[ 0, 1, 99, 2, 3],
[ 4, 5, 99, 6, 7],
[ 8, 9, 99, 10, 11],
[12, 13, 99, 14, 15],
[16, 17, 99, 18, 19]], dtype=int32)
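To check that the difference comes from the Metal backend rather than from the machine, the jitted function can also be pinned to the CPU device on the M1 machine itself. This sketch uses `jax.default_device` with `jax.devices("cpu")`; I'm assuming the CPU backend is still available alongside the jax-metal plugin.

```python
import jax
import jax.numpy as jnp
import numpy as np

def return_arr_with_ins(arr, ins):
    return jnp.insert(arr, 2, ins, axis=1)

cpu = jax.devices("cpu")[0]

# Arrays created and computations run inside this block default to the CPU device.
with jax.default_device(cpu):
    x = jnp.arange(20).reshape(5, 4)
    x3_cpu = jax.jit(return_arr_with_ins)(x, 99)

print(np.asarray(x3_cpu))
```

If this prints the column of 99 while the default (Metal) run does not, that points at the Metal plugin rather than at `jnp.insert` itself.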
I'm not sure whether this is the right channel for bug reports; please let me know if there's a more appropriate place!