Intel GPU here. I just installed jax-metal, and it runs fine. However, when I run the following code, it raises the RuntimeError below:
import jax
import jax.numpy as jnp
jax.device_put(jnp.ones(1), device=jax.devices('gpu')[0])
RuntimeError: Unknown backend: 'gpu' requested, but no platforms that are instances of gpu are present. Platforms present are: interpreter,cpu
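The sketch below can help narrow this down. It prints which backend and devices JAX actually loaded. It also avoids hard-coding the 'gpu' platform name: it picks the first non-CPU device from the default backend and falls back to the CPU otherwise. It relies only on the public jax.default_backend, jax.devices and jax.device_put calls, and the helper pick_device is hypothetical. It does not assume any particular platform name for the jax-metal plugin. If, as the error above suggests, only interpreter and cpu are present, it simply places the array on the CPU. That result would mean the plugin was never registered, rather than that the name 'gpu' was wrong.

```python
import jax
import jax.numpy as jnp


def pick_device():
    """Hypothetical helper: first non-CPU device of the default backend, else CPU."""
    # jax.devices() with no argument lists devices of the default backend,
    # which is an accelerator plugin (if one loaded) before the CPU.
    for d in jax.devices():
        if d.platform != "cpu":
            return d
    # No accelerator plugin was registered; fall back to the CPU.
    return jax.devices("cpu")[0]


print("default backend:", jax.default_backend())
print("devices:", jax.devices())

dev = pick_device()
x = jax.device_put(jnp.ones(1), device=dev)
print("placed on:", x.devices())
```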