XlaRuntimeError                           Traceback (most recent call last)
Cell In[49], line 4
      1 arr = jnp.array([7, 8, 9])
      3 # Find indices where the condition is True
----> 4 indices = jnp.where(arr > 1)
      6 print(indices)

XlaRuntimeError: UNKNOWN
jnp.round doesn't work either...
Same here. More specifically, the error is:
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <ipython-input-5-64a76e03061b>:1:0: error: failed to legalize operation 'mhlo.pad'
and for jnp.round the error is:
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <ipython-input-9-ea6c0ef3275e>:1:0: error: failed to legalize operation 'mhlo.round_nearest_even'
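Since the failing op is round_nearest_even (round half to even), one possible stopgap is to build the same rounding from floor, mod and the three-argument jnp.where. This is a minimal sketch, assuming those ops lower fine on the affected backend; that has not been verified here, and round_half_even is a hypothetical helper name, not part of JAX.

```python
import jax.numpy as jnp

def round_half_even(x):
    """Hypothetical helper: round to nearest integer, ties to even, without jnp.round."""
    f = jnp.floor(x)
    frac = x - f
    # Round up when the fractional part exceeds 0.5, or when it is exactly 0.5
    # and the floor is odd (so ties land on the even neighbour).
    round_up = (frac > 0.5) | ((frac == 0.5) & (jnp.mod(f, 2) == 1))
    # Three-argument where: elementwise select with a fixed output shape.
    return jnp.where(round_up, f + 1, f)

x = jnp.array([0.5, 1.5, 2.5, -0.5, -1.5, 2.4, 2.6])
print(round_half_even(x))
```

One difference from jnp.round: a tie such as -0.5 comes back as 0.0 rather than -0.0, which compares equal but prints differently.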
N.B. For jnp.where(), passing explicit x and y arguments (instead of leaving them as None) avoids the error.
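A minimal sketch of that workaround for the original example, assuming (per the note above) that the three-argument form avoids the failing mhlo.pad lowering; this has not been verified on the Metal backend. The -1 sentinel and the host-side NumPy filtering are illustrative choices, not a JAX API.

```python
import jax.numpy as jnp
import numpy as np

arr = jnp.array([7, 8, 9])
mask = arr > 1

# Three-argument form: an elementwise select with a fixed output shape.
# Positions where the condition is False are marked with a sentinel of -1.
marked = jnp.where(mask, jnp.arange(arr.shape[0]), -1)

# Drop the sentinels on the host with NumPy to recover the indices that
# the one-argument jnp.where(arr > 1) would return (here as a single array,
# not a tuple of arrays).
marked_np = np.asarray(marked)
indices = marked_np[marked_np >= 0]
print(indices)  # [0 1 2]
```

The filtering step runs in NumPy because its output size depends on the data, which is exactly what the one-argument jnp.where has to handle on the device.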
Same issue. +1
Same issue with many JAX-based packages! +1
Related: GitHub issue #16366, "Apple Silicon: error: failed to legalize operation 'mhlo.pad'"