Implementation of some core functions of jax-metal
It appears that some of the JAX core functions (in pjit, mlir) are not supported by jax-metal. Is this something that will be supported in the future? For example, I tested this diffrax example:

    from diffrax import diffeqsolve, ODETerm, Dopri5
    import jax.numpy as jnp

    def f(t, y, args):
        return -y

    term = ODETerm(f)
    solver = Dopri5()
    y0 = jnp.array([2., 3.])
    solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)

It raises an error saying that EmitPythonCallback is not supported on the METAL backend:

    File ~/anaconda3/envs/jax-metal-0410/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1787 in emit_python_callback
        raise ValueError(
    ValueError: `EmitPythonCallback` not supported on METAL backend.

I understand that, currently, no M1 or M2 chip has multiple devices or can be arranged that way, so it may not be necessary to fully implement the p*** functions (pmap, pjit, etc.). But some powerful libraries use them, so it would be great if at least some workaround for the core functions were implemented. Or is there an easy fix for this?
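For anyone hitting the same error, one possible stopgap is to pin the affected computation to the CPU backend, so that lowering does not target METAL at all. The sketch below is only an assumption, not something confirmed for jax-metal: it assumes the CPU backend is still registered when the Metal plugin is installed, and it uses a hypothetical `euler_step` helper as a stand-in for the diffrax call.

```python
import jax
import jax.numpy as jnp

# Assumption: a CPU device is available alongside the METAL device.
cpu = jax.devices("cpu")[0]

def f(t, y, args):
    # Same right-hand side as in the example above: dy/dt = -y
    return -y

@jax.jit
def euler_step(y, t, dt):
    # Hypothetical stand-in for the solver: one explicit Euler step.
    return y + dt * f(t, y, None)

# Arrays created and functions compiled inside this block default to the CPU.
with jax.default_device(cpu):
    y0 = jnp.array([2.0, 3.0])
    y1 = euler_step(y0, 0.0, 0.1)

print(y1)
```

If this approach works on a given setup, the `diffeqsolve` call would go inside the same `with jax.default_device(cpu):` block. The cost is that the computation no longer runs on the GPU.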
Jul ’23