After building jaxlib as per the instructions and installing jax-metal, I tested an existing model that works fine on CPU (and on GPU under Linux), and got the following error:
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: /Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: error: failed to legalize operation 'mhlo.cholesky'
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: called from
/Users/adam/Developer/Pycharm Projects/gpy_flow_test/sparse_gps.py:66:0: note: see current operation: %406 = "mhlo.cholesky"(%405) {lower = true} : (tensor<50x50xf32>) -> tensor<50x50xf32>
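The error is raised while the backend legalizes the `mhlo.cholesky` op, so it can help to look at the module JAX produces before device compilation. This is a minimal sketch using JAX's ahead-of-time API (`jax.jit(...).lower(...)` and `.as_text()`). The function `chol_50` is hypothetical: it only mirrors the op and shape from the error (a lower Cholesky of a 50x50 float32 matrix), not the actual code at `sparse_gps.py:66`.

```python
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

def chol_50(x):
    # Hypothetical stand-in: lower Cholesky of a 50x50 float32 SPD matrix,
    # matching the op and shape reported in the error.
    spd = x @ x.T + 1e-3 * jnp.eye(50, dtype=x.dtype)
    return jsl.cholesky(spd, lower=True)

# Lower from an abstract shape/dtype; no data and no device compilation needed.
spec = jax.ShapeDtypeStruct((50, 50), jnp.float32)
lowered = jax.jit(chol_50).lower(spec)

# Module text handed to the backend compiler. How the Cholesky appears in it
# depends on the backend JAX is targeting.
ir_text = lowered.as_text()
print(ir_text)
```

Because lowering stops before device compilation, this is expected to print even on jax-metal. Calling `lowered.compile()` on the same object is then where a backend that cannot legalize the op would presumably fail.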
I have tried to reproduce this with the following minimal example, but it works fine:
from jax import jit
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy as jsp

key = jnr.PRNGKey(0)
A = jnr.normal(key, (100, 100))

def calc_cholesky_decomp(test_matrix):
    # A @ A.T is symmetric positive (semi-)definite, so it has a Cholesky factor
    psd_test_matrix = test_matrix @ test_matrix.T
    col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
    return col_decomp

# Eager call first, then the same function under jit
calc_cholesky_decomp(A)
jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
jitted_calc_cholesky_decomp(A)
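A variant of the example above that sticks more closely to the failing case might be worth trying. It uses the 50x50 float32 shape from the error message and reports which backend and device actually ran the computation, to rule out a silent CPU fallback. This is a sketch: the 1e-3 diagonal jitter is a hypothetical addition to keep the float32 matrix well conditioned. `jax.default_backend()`, `jax.devices()` and `jax.Array.devices()` are standard JAX calls, but the exact strings they print for jax-metal are not assumed here.

```python
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy.linalg as jsl

# Report which backend JAX selected; with jax-metal installed this is
# expected to be the Metal device rather than the CPU.
print("default backend:", jax.default_backend())
print("devices:", jax.devices())

key = jnr.PRNGKey(0)
# 50x50 float32 mirrors tensor<50x50xf32> in the error message.
A = jnr.normal(key, (50, 50), dtype=jnp.float32)

@jax.jit
def calc_cholesky_decomp(test_matrix):
    # Hypothetical 1e-3 jitter on the diagonal keeps the float32 matrix
    # comfortably positive definite.
    n = test_matrix.shape[0]
    psd_test_matrix = test_matrix @ test_matrix.T + 1e-3 * jnp.eye(n, dtype=test_matrix.dtype)
    return jsl.cholesky(psd_test_matrix, lower=True)

chol = calc_cholesky_decomp(A)
chol.block_until_ready()
# Device(s) holding the result, to confirm where the computation ran.
print("result on:", chol.devices())
```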
I am unable to attach the full error message, as it exceeds the size limits for attachments to a post.
I am more than happy to try a more complex model if you have any suggestions.
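One possible "more complex" test, sketched under an assumption: the file name `sparse_gps.py` suggests a sparse GP, where the Cholesky at line 66 would typically factor a kernel matrix over 50 inducing points and be differentiated during training. The actual code is not shown, so `rbf_kernel`, `toy_objective` and the data below are all hypothetical. The objective is not a real ELBO or marginal likelihood. It only chains the same kinds of ops (`jax.scipy.linalg.cholesky`, `cho_solve`, a log-determinant) under `jax.jit` and `jax.value_and_grad`.

```python
import jax
import jax.numpy as jnp
import jax.random as jnr
import jax.scipy.linalg as jsl

def rbf_kernel(x1, x2, lengthscale, variance):
    # Squared-exponential kernel between two sets of 1-D inputs.
    sqdist = (x1[:, None] - x2[None, :]) ** 2
    return variance * jnp.exp(-0.5 * sqdist / lengthscale**2)

def toy_objective(log_params, z, x, y):
    # Hypothetical scalar built from sparse-GP-style pieces; NOT a real ELBO.
    lengthscale, variance = jnp.exp(log_params)
    kuu = rbf_kernel(z, z, lengthscale, variance) + 1e-3 * jnp.eye(z.shape[0])
    kuf = rbf_kernel(z, x, lengthscale, variance)
    luu = jsl.cholesky(kuu, lower=True)       # 50x50 float32, lower=True
    q = kuf @ y
    alpha = jsl.cho_solve((luu, True), q)     # solves kuu @ alpha = q
    # Quadratic term plus half the log-determinant of kuu.
    return 0.5 * q @ alpha + jnp.sum(jnp.log(jnp.diag(luu)))

key = jnr.PRNGKey(0)
x = jnp.linspace(0.0, 10.0, 200)
y = jnp.sin(x) + 0.1 * jnr.normal(key, x.shape)
z = jnp.linspace(0.0, 10.0, 50)  # 50 inducing points -> 50x50 kuu

# Differentiate under jit, as a training loop would.
objective_and_grad = jax.jit(jax.value_and_grad(toy_objective))
value, grads = objective_and_grad(jnp.zeros(2), z, x, y)
print(value, grads)
```

If this runs on CPU but fails on Metal, the gradient path through the Cholesky, rather than the forward factorization alone, would be a candidate to narrow down further.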