Jax-Metal - error: failed to legalize operation 'mhlo.cholesky'
After building jaxlib as per the instructions and installing jax-metal, I tested an existing model that works fine on the CPU (and on the GPU under Linux), and got the following error:

    jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: error: failed to legalize operation 'mhlo.cholesky'
    /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: called from
    /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: see current operation: %406 = "mhlo.cholesky"(%405) {lower = true} : (tensor<50x50xf32>) -> tensor<50x50xf32>

I have tried to reproduce this with the following minimal example, but it works fine:

    from jax import jit
    import jax.numpy as jnp
    import jax.random as jnr
    import jax.scipy as jsp

    key = jnr.PRNGKey(0)
    A = jnr.normal(key, (100, 100))

    def calc_cholesky_decomp(test_matrix):
        # A @ A.T is symmetric positive semi-definite, so Cholesky applies
        psd_test_matrix = test_matrix @ test_matrix.T
        col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
        return col_decomp

    # Eager call
    calc_cholesky_decomp(A)

    # Same function compiled with jit
    jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
    jitted_calc_cholesky_decomp(A)

I am unable to attach the full error message, as it exceeds the size limits placed on attachments to a post. I am more than happy to try a more complex model if you have any suggestions.
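One way to narrow down what differs between the model and the minimal example is to run the decomposition at the exact shape and dtype from the error (50x50, float32) in several tracing contexts: eagerly, under `jit`, and under `jit(grad(...))`. This is a minimal sketch, not a known reproduction. The gradient variant assumes the model in sparse_gps.py differentiates through the Cholesky factor during training, which the post does not state, and the helper names (`make_psd`, `chol`, `chol_loss`, `run_variants`) are hypothetical. Each variant is wrapped in a try/except so that one failure does not hide the others.

```python
# Hypothetical diagnostic sketch: check whether the 'mhlo.cholesky'
# legalization failure depends on how the op is traced. It targets the
# 50x50 float32 shape reported in the error message.
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy.linalg as jsl


def make_psd(key, n):
    """Build a symmetric positive-definite n x n float32 matrix."""
    a = jnr.normal(key, (n, n), dtype=jnp.float32)
    # Adding n * I keeps the matrix safely positive definite.
    return a @ a.T + n * jnp.eye(n, dtype=jnp.float32)


def chol(m):
    # Same call as in the minimal example: lower-triangular factor.
    return jsl.cholesky(m, lower=True)


def chol_loss(m):
    # Scalar function of the factor, so jax.grad traces the Cholesky op
    # the way a training step over the model (assumed) would.
    return jnp.sum(chol(m))


def run_variants(n=50):
    m = make_psd(jnr.PRNGKey(0), n)
    variants = {
        "eager": lambda: chol(m),
        "jit": lambda: jax.jit(chol)(m),
        "grad_jit": lambda: jax.jit(jax.grad(chol_loss))(m),
    }
    platform = jax.devices()[0].platform
    results = {}
    for name, fn in variants.items():
        try:
            out = fn()
            out.block_until_ready()  # force execution before reporting
            results[name] = out
            print(f"{name}: ok on {platform}")
        except Exception as err:  # e.g. XlaRuntimeError on failure
            results[name] = None
            print(f"{name}: failed with {type(err).__name__}: {err}")
    return m, results


if __name__ == "__main__":
    run_variants()
```

With jax-metal installed, the printed lines should indicate which variants, if any, hit the legalization failure; on the CPU backend all three are expected to complete.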
Replies: 7 · Boosts: 3 · Views: 1.3k · Last activity: Jun ’23