Implementation of some core functions of jax-metal

It appears that some JAX core functionality (in pjit and mlir) is not supported on the Metal backend. Is support for it planned?

For example, when I ran this diffrax example:

from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)

It fails with an error saying that EmitPythonCallback is not supported on the Metal backend:


  File ~/anaconda3/envs/jax-metal-0410/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1787 in emit_python_callback
    raise ValueError(

ValueError: `EmitPythonCallback` not supported on METAL backend.
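To isolate the failure from diffrax, here is a minimal sketch: a jit-compiled function that calls jax.pure_callback. It assumes a recent JAX version, and the function names are hypothetical. JAX lowers host callbacks such as jax.pure_callback and jax.debug.callback through mlir.emit_python_callback, the function in the traceback above. On the METAL backend I would therefore expect this to raise the same ValueError, while on the CPU backend it runs normally.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_double(x):
    # Runs in the Python interpreter on the host and receives NumPy arrays.
    return np.asarray(x) * 2


@jax.jit
def double_via_callback(x):
    # Describe the callback's output so JAX can trace through it.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # Lowering pure_callback goes through mlir.emit_python_callback,
    # the step that raises ValueError on the METAL backend.
    return jax.pure_callback(host_double, out_spec, x)


x = jnp.arange(3.0)
try:
    print(double_via_callback(x))
except ValueError as err:
    print("host callback unsupported on this backend:", err)
```

On a CPU-only install, this prints `[0. 2. 4.]`.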

I understand that no current M1 or M2 chip exposes more than one device or can be combined into a multi-device setup. Therefore, it may not be necessary to fully implement the p-prefixed functions (pmap, pjit, etc.). However, some powerful libraries use them, so it would be great if at least workarounds for these core functions were implemented.

Or is there an easy fix for this?
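One possible stopgap is to commit the inputs to the CPU device so that jit compiles and runs the function there, where host callbacks are supported. This assumes the CPU backend is still registered alongside the jax-metal plugin, which I have not confirmed for every jax-metal version. The helper name run_on_cpu is hypothetical, and the pure_callback stands in for whatever callback diffrax issues internally.

```python
import jax
import jax.numpy as jnp
import numpy as np


def run_on_cpu(fn, *args):
    """Hypothetical helper: jit-compile `fn` and execute it on the CPU backend."""
    cpu = jax.devices("cpu")[0]
    # Inputs committed to the CPU device make jit compile and run there,
    # so host callbacks are lowered for CPU instead of METAL.
    cpu_args = jax.device_put(args, cpu)
    return jax.jit(fn)(*cpu_args)


def negate_on_host(y):
    # A host callback inside the traced function, standing in for the one
    # that fails to lower on the METAL backend.
    out_spec = jax.ShapeDtypeStruct(y.shape, y.dtype)
    return jax.pure_callback(np.negative, out_spec, y)


result = run_on_cpu(negate_on_host, jnp.array([2.0, 3.0]))
print(result, result.devices())
```

For the diffrax example, the same idea should amount to wrapping the diffeqsolve call in `with jax.default_device(jax.devices("cpu")[0]):`, at the cost of running the whole solve on the CPU rather than the GPU.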
