Reply to Jax-metal on M2 Pro does not recognize GPU
On my MBP (M1 Max), jax-metal works with the following setup. (I followed the instructions at https://developer.apple.com/metal/jax/, with the modifications shown in parentheses below, and I used Anaconda environments.)

- jaxlib 0.4.10 (git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch)
- jax 0.4.11 (pip install jax==0.4.11)
- jax-metal 0.0.2

Some catches:

This setup does not work well in Jupyter, but it works from the command line, so I use it in the Spyder editor (with IPython).

Also, some basic jax.numpy operations may have bugs. For example, the last line of the following code does not work:

    import jax.numpy as jnp
    from jax import random

    key = random.PRNGKey(1)
    x = random.normal(key, (200, 100))
    y = random.normal(key, (100,))

    a = jnp.dot(y, x.T)  # OK
    print(a.shape)  # (200,)

    b = jnp.dot(x, y[:, jnp.newaxis])  # OK, but the result is 2-D
    print(b.shape)  # (200, 1)

    c = jnp.dot(x, y)  # Error with jax-metal 0.0.2 on the GPU

The same code works fine on the CPU (with jaxlib 0.4.12, which does not work on the GPU), but running on the CPU defeats the purpose of jax-metal.
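A minimal sketch, assuming the setup above. It first prints which backend and devices JAX selected (the exact Metal device name is not checked). It then defines a hypothetical helper `matvec` that avoids the failing 1-D call `jnp.dot(x, y)` by going through the 2-D form `jnp.dot(x, y[:, jnp.newaxis])`, which is reported above to work, and dropping the extra axis. This has not been verified on jax-metal 0.0.2 itself; it only relies on the behavior described in the post.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import random

# Show which backend JAX selected; with jax-metal installed this is expected
# to be the Metal GPU rather than the CPU (the exact name is not checked here).
print("default backend:", jax.default_backend())
print("devices:", jax.devices())


def matvec(m, v):
    """Hypothetical workaround: return m @ v for a 2-D m and a 1-D v.

    Avoids the 1-D call jnp.dot(m, v), reported to fail on jax-metal 0.0.2,
    by using the 2-D form jnp.dot(m, v[:, jnp.newaxis]) (reported to work)
    and then dropping the trailing length-1 axis.
    """
    return jnp.dot(m, v[:, jnp.newaxis])[:, 0]


# Same shapes as the snippet above; split the key so x and y are independent.
key_x, key_y = random.split(random.PRNGKey(1))
x = random.normal(key_x, (200, 100))
y = random.normal(key_y, (100,))

c = matvec(x, y)
print(c.shape)  # (200,)

# Cross-check against a NumPy reference computed on the host.
ref = np.asarray(x) @ np.asarray(y)
print("matches NumPy:", np.allclose(np.asarray(c), ref, rtol=1e-4, atol=1e-3))
```

Splitting the key instead of reusing it for both `x` and `y` follows the usual JAX PRNG convention; it only changes the random values, not the shapes involved in the `jnp.dot` issue.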
Jul ’23