It appears that some core JAX functionality (in `pjit` and `mlir`) is not supported on the Metal backend. Is support for it planned?
For example, when I run the following diffrax example:
```python
from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
```
It fails with an error saying that `EmitPythonCallback` is not supported on the Metal backend:
```
File ~/anaconda3/envs/jax-metal-0410/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1787 in emit_python_callback
    raise ValueError(
ValueError: `EmitPythonCallback` not supported on METAL backend.
```
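A smaller reproduction may help show whether the failure comes from host callbacks in general rather than from diffrax itself. The sketch below assumes that any `jax.pure_callback` inside a jitted function goes through the same `emit_python_callback` lowering path seen in the traceback; that is an assumption, not something confirmed here. The helper name `host_negate` is hypothetical. On a CPU backend it should run normally, and on METAL it would be expected to raise the same `ValueError`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_negate(x):
    # Hypothetical helper: plain NumPy code that runs on the host,
    # which is what forces JAX to emit a Python callback.
    return np.negative(np.asarray(x))


@jax.jit
def f(y):
    # jax.pure_callback needs the result shape/dtype declared up front.
    out_spec = jax.ShapeDtypeStruct(y.shape, y.dtype)
    return jax.pure_callback(host_negate, out_spec, y)


y0 = jnp.array([2., 3.])
print(jax.default_backend())  # e.g. "cpu" or "METAL"
result = f(y0)                # assumed to fail at lowering time on METAL
print(result)
```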
I understand that, currently, no M1 or M2 chip has multiple devices or can be arranged that way, so it may not be necessary to fully implement the p* functions (pmap, pjit, etc.). However, some powerful libraries rely on them, so it would be great if at least a workaround for these core functions were implemented.
Or is there any easy fix for this?
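One possible stopgap, until callbacks are supported on Metal, is to run only the affected computation on the CPU backend. This is a sketch assuming that `jax.devices("cpu")` still returns a CPU device when the jax-metal plugin is installed; a coarser alternative is setting the environment variable `JAX_PLATFORMS=cpu` before importing JAX, which disables Metal for the whole process. The jitted `rhs` function with a callback is a hypothetical stand-in for library code such as diffrax, so the sketch runs without diffrax installed.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumption: the CPU backend is still registered alongside the Metal plugin.
cpu = jax.devices("cpu")[0]


def host_negate(x):
    # Hypothetical host-side NumPy code reached through a Python callback.
    return np.negative(np.asarray(x))


@jax.jit
def rhs(y):
    # Stand-in for library code that uses callbacks internally.
    return jax.pure_callback(host_negate, jax.ShapeDtypeStruct(y.shape, y.dtype), y)


# Arrays created and computations run inside this block default to the CPU
# device, so the callback is lowered for CPU rather than METAL.
with jax.default_device(cpu):
    y0 = jnp.array([2., 3.])
    out = rhs(y0)
    # The original diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
    # call could be placed here in the same way.

print(out.devices())
print(out)
```

The trade-off is that everything inside the `with` block loses GPU acceleration, so this only makes sense for the parts of a program that actually need callbacks.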