Posts

Post not yet marked as solved
0 Replies
111 Views
Hi, I just noticed that using the jax.numpy.insert() function returns an incorrect result (zero-padding the array) when compiled with jax.jit. When not jitted, the results are correct Config: M1 Pro Macbook Pro 2021 python 3.12.3 ; jax-metal 0.0.6 ; jax 0.4.26 ; jaxlib 0.4.23 MWE: import jax import jax.numpy as jnp x = jnp.arange(20).reshape(5, 4) print(f"{x=}\n") def return_arr_with_ins(arr, ins): return jnp.insert(arr, 2, ins, axis=1) x2 = return_arr_with_ins(x, 99) print(f"{x2=}\n") return_arr_with_ins_jit = jax.jit(return_arr_with_ins) x3 = return_arr_with_ins_jit(x, 99) print(f"{x3=}\n") Output: x2 (computed with the non-jitted function) is correct; x3 just has zero-padding instead of a column of 99 x=Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19]], dtype=int32) x2=Array([[ 0, 1, 99, 2, 3], [ 4, 5, 99, 6, 7], [ 8, 9, 99, 10, 11], [12, 13, 99, 14, 15], [16, 17, 99, 18, 19]], dtype=int32) x3=Array([[ 0, 1, 2, 3, 0], [ 4, 5, 6, 7, 0], [ 8, 9, 10, 11, 0], [12, 13, 14, 15, 0], [16, 17, 18, 19, 0]], dtype=int32) The same code run on a non-metal machine gives the correct results: x=Array([[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11], [12, 13, 14, 15], [16, 17, 18, 19]], dtype=int32) x2=Array([[ 0, 1, 99, 2, 3], [ 4, 5, 99, 6, 7], [ 8, 9, 99, 10, 11], [12, 13, 99, 14, 15], [16, 17, 99, 18, 19]], dtype=int32) x3=Array([[ 0, 1, 99, 2, 3], [ 4, 5, 99, 6, 7], [ 8, 9, 99, 10, 11], [12, 13, 99, 14, 15], [16, 17, 99, 18, 19]], dtype=int32) Not sure if this is the correct channel for bug reports, please feel free to let me know if there's a more appropriate place!
Posted
by fkeruzore.
Last updated
.