jax.numpy.insert returns incorrect results when jitted

Hi,

I just noticed that jax.numpy.insert() returns an incorrect result (the array is zero-padded instead of receiving the inserted values) when compiled with jax.jit. Without jit, the result is correct.

Config:

  • M1 Pro MacBook Pro 2021
  • Python 3.12.3 ; jax-metal 0.0.6 ; jax 0.4.26 ; jaxlib 0.4.23

MWE:

import jax
import jax.numpy as jnp

x = jnp.arange(20).reshape(5, 4)
print(f"{x=}\n")

def return_arr_with_ins(arr, ins):
    return jnp.insert(arr, 2, ins, axis=1)

x2 = return_arr_with_ins(x, 99)
print(f"{x2=}\n")

return_arr_with_ins_jit = jax.jit(return_arr_with_ins)
x3 = return_arr_with_ins_jit(x, 99)
print(f"{x3=}\n")

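Output: x2 (computed with the non-jitted function) is correct. x3 (computed with the jitted function) leaves the original columns in place and appends a column of zeros at the end, instead of inserting a column of 99s at index 2.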

x=Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19]], dtype=int32)

x2=Array([[ 0,  1, 99,  2,  3],
       [ 4,  5, 99,  6,  7],
       [ 8,  9, 99, 10, 11],
       [12, 13, 99, 14, 15],
       [16, 17, 99, 18, 19]], dtype=int32)

x3=Array([[ 0,  1,  2,  3,  0],
       [ 4,  5,  6,  7,  0],
       [ 8,  9, 10, 11,  0],
       [12, 13, 14, 15,  0],
       [16, 17, 18, 19,  0]], dtype=int32)
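To make the mismatch easier to check programmatically, here is a small sketch that compares both the eager and the jitted results against NumPy's np.insert as a reference. The helper name insert_col is just for illustration. On an unaffected backend both comparisons should print True; on the Metal setup above, the jitted comparison would presumably print False.

```python
import numpy as np
import jax
import jax.numpy as jnp

def insert_col(arr, value):
    # Same operation as the MWE: insert `value` as a new column at index 2.
    return jnp.insert(arr, 2, value, axis=1)

x = jnp.arange(20).reshape(5, 4)

# NumPy reference result for the same insertion.
expected = np.insert(np.asarray(x), 2, 99, axis=1)

eager = insert_col(x, 99)
jitted = jax.jit(insert_col)(x, 99)

# Compare values element-wise against the NumPy reference.
print("eager matches NumPy: ", np.array_equal(np.asarray(eager), expected))
print("jitted matches NumPy:", np.array_equal(np.asarray(jitted), expected))
```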

The same code run on a non-Metal machine gives the correct results:

x=Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19]], dtype=int32)

x2=Array([[ 0,  1, 99,  2,  3],
       [ 4,  5, 99,  6,  7],
       [ 8,  9, 99, 10, 11],
       [12, 13, 99, 14, 15],
       [16, 17, 99, 18, 19]], dtype=int32)

x3=Array([[ 0,  1, 99,  2,  3],
       [ 4,  5, 99,  6,  7],
       [ 8,  9, 99, 10, 11],
       [12, 13, 99, 14, 15],
       [16, 17, 99, 18, 19]], dtype=int32)
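A possible workaround, untested on jax-metal and offered only as an assumption that the problem is specific to how jnp.insert is lowered, is to express the same insertion with slicing and jnp.concatenate. The function name return_arr_with_ins_concat and the INDEX constant are hypothetical; INDEX is a plain Python int, so the slices stay static under jit.

```python
import jax
import jax.numpy as jnp

INDEX = 2  # column position used in the MWE (hypothetical constant)

def return_arr_with_ins_concat(arr, ins):
    # Build a (rows, 1) column filled with `ins`, matching arr's dtype.
    col = jnp.full((arr.shape[0], 1), ins, dtype=arr.dtype)
    # Splice the new column between the left and right slices along axis 1.
    return jnp.concatenate([arr[:, :INDEX], col, arr[:, INDEX:]], axis=1)

x = jnp.arange(20).reshape(5, 4)
x4 = jax.jit(return_arr_with_ins_concat)(x, 99)
print(f"{x4=}\n")
```

If x4 matches x2 on the Metal backend, that would also help narrow down whether the issue lies in jnp.insert's lowering rather than in jit itself.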

I'm not sure whether this is the correct channel for bug reports; please feel free to let me know if there's a more appropriate place!