jax-metal error jax.numpy.linalg.inv
Hi, I have an issue with jax.numpy.linalg.inv(a). The following code:

    import jax.numpy as jnp
    import jax.numpy.linalg as jnpl

    B = jnp.identity(2)
    jnpl.inv(B)

throws the following error:

    XlaRuntimeError: UNKNOWN: /var/folders/pw/wk5rfkjj6qggqp8r8zb2bw8w0000gn/T/ipykernel_34334/2572982404.py:9:0: error: failed to legalize operation 'mhlo.triangular_solve'
    /var/folders/pw/wk5rfkjj6qggqp8r8zb2bw8w0000gn/T/ipykernel_34334/2572982404.py:9:0: note: called from
    /var/folders/pw/wk5rfkjj6qggqp8r8zb2bw8w0000gn/T/ipykernel_34334/2572982404.py:9:0: note: see current operation: %120 = "mhlo.triangular_solve"(%42#4, %119) {left_side = true, lower = true, transpose_a = #mhlo<transpose NO_TRANSPOSE>, unit_diagonal = true} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>

Any ideas what could be the issue or how to solve it?
Feb ’24
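
The "failed to legalize operation 'mhlo.triangular_solve'" message suggests that the Metal backend has no lowering for the triangular solve that inv uses internally, though the post does not confirm this. One possible workaround, sketched here on the assumption that a CPU backend is available alongside the Metal device, is to run just the inversion on the CPU. The names cpu and B_inv are illustrative only:

```python
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpl

# Assumption: a CPU backend is installed next to the Metal one.
cpu = jax.devices("cpu")[0]

# Created on the default device (Metal, when jax-metal is active).
B = jnp.identity(2)

# Move the input to the CPU and run the inversion there, so the
# triangular solve is compiled for the CPU backend instead of Metal.
with jax.default_device(cpu):
    B_inv = jnpl.inv(jax.device_put(B, cpu))

print(B_inv)
```

The result stays on the CPU device. If later steps should run on the Metal device again, it can be moved back with jax.device_put, at the cost of a host-to-device copy.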