Reply to Jax-metal on M2 Pro does not recognize GPU
Thanks! That helped, and JAX now recognizes the GPU. Unfortunately, when I tried to run a simple example of the Newton algorithm from https://jax.quantecon.org/newtons_method.html, it fails:

    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "<stdin>", line 11, in newton
    jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:4:0: error: failed to legalize operation 'mhlo.scatter'
    <stdin>:4:0: note: called from
    <stdin>:4:0: note: see current operation: %2177 = "mhlo.scatter"(%2052, %2176, %2167) ({
    ^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):
      "mhlo.return"(%arg7) : (tensor<f32>) -> ()
    }) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<5000x128xf32>, tensor<1xsi32>, tensor<5000xf32>) -> tensor<5000x128xf32>

Running the same code on the CPU works fine.
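Since the same code runs on the CPU, one stopgap until the Metal plugin can lower mhlo.scatter (the op JAX emits for indexed updates such as x.at[i].set(v)) is to pin the failing computation to the CPU device with jax.default_device. This is a minimal sketch, assuming jax.devices("cpu") is available alongside the Metal backend. The newton function here is a hypothetical scalar Newton iteration written for illustration, not the QuantEcon lecture's code.

```python
import jax
import jax.numpy as jnp


def newton(f, x0, tol=1e-6, max_iter=50):
    """Hypothetical scalar Newton iteration: x_{k+1} = x_k - f(x_k) / f'(x_k).

    Illustrative only; this is not the QuantEcon lecture's implementation.
    """
    f_prime = jax.grad(f)  # derivative of f via automatic differentiation
    x = x0
    for _ in range(max_iter):
        x_new = x - f(x) / f_prime(x)
        # Stop once successive iterates agree to within tol
        if jnp.abs(x_new - x) < tol:
            return x_new
        x = x_new
    raise RuntimeError("Newton's method did not converge")


# Assumption: a CPU device is always present next to the Metal device.
cpu = jax.devices("cpu")[0]

# Inside this block, newly created arrays and dispatched ops default to the CPU.
with jax.default_device(cpu):
    root = newton(lambda x: x**2 - 2.0, 1.0)

print(root)  # approximately sqrt(2)
```

The context manager only changes the default placement inside the with block, so the rest of a program can keep running on the Metal device.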
Jun ’23
Reply to Jax-metal on M2 Pro does not recognize GPU
Same problem here, but when I changed the install instructions to

    git clone https://github.com/google/jax.git --branch jaxlib-v0.4.11 --single-branch
    python -m pip install jax==v0.4.11

it now seems to recognize the GPU:

    >>> from jax.lib import xla_bridge
    >>> print(xla_bridge.get_backend().platform)
    Metal device set to: Apple M2 Max
    systemMemory: 96.00 GB
    maxCacheSize: 36.00 GB
    METAL
    >>> import jax
    >>> jax.devices()
    [MetalDevice(id=0, process_index=0)]
    >>> jax.devices()[0].platform
    'METAL'
    >>> jax.devices()[0].device_kind
    'Metal'
    >>> jax.devices()[0].client.platform
    'METAL'
    >>> jax.devices()[0].client.runtime_type
    'tfrt'

But now, x = jnp.ones((10000, 10000)) generates this error:

    jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: -:0:0: error: bytecode version 5 is newer than the current version 1
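The message "bytecode version 5 is newer than the current version 1" likely indicates a version mismatch: the installed jaxlib appears to serialize the program in a newer format than the jax-metal plugin can read. This is an interpretation, not something confirmed in the thread. When pinning versions by hand, it can help to print what is actually installed. The minimal sketch below uses only importlib.metadata and public JAX calls. jax.default_backend() reports the same platform name as xla_bridge.get_backend().platform. The distribution names jax, jaxlib and jax-metal are the PyPI package names, and installed_version() and report() are hypothetical helpers.

```python
from importlib import metadata

import jax


def installed_version(dist: str) -> str:
    """Return the installed version of a distribution, or a placeholder."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return "not installed"


def report() -> dict:
    """Collect package versions plus the backend and devices JAX is using."""
    info = {dist: installed_version(dist) for dist in ("jax", "jaxlib", "jax-metal")}
    # Public equivalent of xla_bridge.get_backend().platform
    info["default_backend"] = jax.default_backend()
    info["devices"] = [(d.id, d.platform, d.device_kind) for d in jax.devices()]
    return info


if __name__ == "__main__":
    for key, value in report().items():
        print(f"{key}: {value}")
```

Running this before and after reinstalling makes it easy to see whether jax, jaxlib and jax-metal come from compatible releases.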
Jun ’23