Hi, I am getting the same error. I am running Google's Neural Tangents cookbook tutorial (https://colab.research.google.com/github/google/neural-tangents/blob/master/notebooks/neural_tangents_cookbook.ipynb#scrollTo=zDIYbtgA_atG) locally after downloading it to my iMac with an M1 processor. Even the minimal example given above fails with the following error:
XlaRuntimeError Traceback (most recent call last)
Cell In[24], line 14
11 col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
12 return col_decomp
---> 14 calc_cholesky_decomp(A)
16 jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
17 jitted_calc_cholesky_decomp(A)
Cell In[24], line 11, in calc_cholesky_decomp(test_matrix)
9 def calc_cholesky_decomp(test_matrix):
10 psd_test_matrix = test_matrix @ test_matrix.T
---> 11 col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
12 return col_decomp
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/linalg.py:54, in cholesky(***failed resolving arguments***)
49 @_wraps(scipy.linalg.cholesky,
50 lax_description=_no_overwrite_and_chkfinite_doc, skip_params=('overwrite_a', 'check_finite'))
51 def cholesky(a: ArrayLike, lower: bool = False, overwrite_a: bool = False,
52 check_finite: bool = True) -> Array:
53 del overwrite_a, check_finite # Unused
---> 54 return _cholesky(a, lower)
[... skipping hidden 14 frame]
File ~/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py:465, in backend_compile(backend, module, options, host_callbacks)
460 return backend.compile(built_c, compile_options=options,
461 host_callbacks=host_callbacks)
462 # Some backends don't have `host_callbacks` option yet
463 # TODO(sharadmv): remove this fallback when all backends allow `compile`
464 # to take in `host_callbacks`
--> 465 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: UNKNOWN: /var/folders/jh/ms0xbzxj5cq7vswmxdsdvh2w0000gn/T/ipykernel_77014/3178916729.py:11:0: error: failed to legalize operation 'mhlo.cholesky'
/var/folders/jh/ms0xbzxj5cq7vswmxdsdvh2w0000gn/T/ipykernel_77014/3178916729.py:11:0: note: see current operation: %2 = "mhlo.cholesky"(%arg0) {lower = true} : (tensor<100x100xf32>) -> tensor<100x100xf32>
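For reference, here is a sketch of the failing cell, reconstructed from the frames in the traceback. The lines that define `A` are not visible there, so how `A` is built below is an assumption: any 100×k matrix produces the `tensor<100x100xf32>` Cholesky operand shown in the error, and a 100×200 random matrix is used here so that `A @ A.T` is comfortably positive definite in float32. Committing the input to the CPU device is only a possible stopgap (untested on jax-metal). It sidesteps the missing `mhlo.cholesky` lowering on Metal rather than fixing it.

```python
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from jax import jit

# Assumption: the original definition of `A` is not shown in the traceback.
# A 100x200 float32 matrix gives the 100x100 operand seen in the error, and
# A @ A.T is then well-conditioned enough for a float32 Cholesky factorization.
A = jax.random.normal(jax.random.PRNGKey(0), (100, 200), dtype=jnp.float32)

def calc_cholesky_decomp(test_matrix):
    # A @ A.T is symmetric positive definite here, so Cholesky applies.
    psd_test_matrix = test_matrix @ test_matrix.T
    col_decomp = jsp.linalg.cholesky(psd_test_matrix, lower=True)
    return col_decomp

# Possible stopgap (untested on jax-metal): commit the input to the CPU device
# so the computation, including the Cholesky op, runs on the CPU backend
# instead of being compiled for Metal.
cpu = jax.devices("cpu")[0]
A_cpu = jax.device_put(A, cpu)

L = calc_cholesky_decomp(A_cpu)

jitted_calc_cholesky_decomp = jit(calc_cholesky_decomp)
L_jit = jitted_calc_cholesky_decomp(A_cpu)
```

On a plain CPU install both calls should return the same lower-triangular factor `L` with `L @ L.T` approximately equal to `A @ A.T`.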
It would be great if this could be fixed soon.