Thanks! That helped, and JAX now recognizes the GPU.
Unfortunately, when I tried to run a simple example of Newton's method from https://jax.quantecon.org/newtons_method.html, it failed:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "<stdin>", line 11, in newton
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:4:0: error: failed to legalize operation 'mhlo.scatter'
<stdin>:4:0: note: called from
<stdin>:4:0: note: see current operation:
%2177 = "mhlo.scatter"(%2052, %2176, %2167) ({
^bb0(%arg6: tensor<f32>, %arg7: tensor<f32>):
"mhlo.return"(%arg7) : (tensor<f32>) -> ()
}) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<5000x128xf32>, tensor<1xsi32>, tensor<5000xf32>) -> tensor<5000x128xf32>
Running the same code on the CPU works fine.
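For anyone trying to narrow this down: the op that fails to legalize is an mhlo.scatter that writes a 5000-element vector into one column of a 5000x128 array. That is the kind of op JAX emits for an indexed update such as x.at[:, j].set(col). The snippet below is a hypothetical minimal reproduction under that assumption (it is not taken from the QuantEcon code; set_column and the shapes are made up to mirror the tensor types in the log). It also shows one way to pin the computation to the CPU with jax.default_device, which matches the observation that the CPU path works.

```python
import jax
import jax.numpy as jnp


def set_column(x, j, col):
    # Functional indexed update: returns a copy of x with column j replaced by col.
    # JAX lowers this kind of update to a scatter op.
    return x.at[:, j].set(col)


# Hypothetical shapes chosen to mirror tensor<5000x128xf32> / tensor<5000xf32>.
N_ROWS, N_COLS = 5000, 128

# Pin allocation and execution to the first CPU device.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    x = jnp.zeros((N_ROWS, N_COLS), dtype=jnp.float32)
    col = jnp.ones((N_ROWS,), dtype=jnp.float32)
    y_cpu = jax.jit(set_column)(x, 3, col)
```

Running the same jitted call outside the with block sends it to the default backend (Metal in this thread), which is where the scatter would be expected to fail if it is indeed the culprit.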
Same problem here, but when I changed the installation commands to
git clone https://github.com/google/jax.git --branch jaxlib-v0.4.11 --single-branch
and
python -m pip install jax==v0.4.11
it now seems to recognize the GPU:
>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
Metal device set to: Apple M2 Max
systemMemory: 96.00 GB
maxCacheSize: 36.00 GB
METAL
>>> import jax
>>> jax.devices()
[MetalDevice(id=0, process_index=0)]
>>> jax.devices()[0].platform
'METAL'
>>> jax.devices()[0].device_kind
'Metal'
>>> jax.devices()[0].client.platform
'METAL'
>>> jax.devices()[0].client.runtime_type
'tfrt'
But now, even x = jnp.ones((10000, 10000)) fails with:
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: -:0:0: error: bytecode version 5 is newer than the current version 1
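The "bytecode version 5 is newer than the current version 1" message suggests that the program JAX serialized uses a newer MLIR bytecode format than the Metal backend can read, i.e. the installed jax Python package and the jaxlib/plugin build may come from different versions. That is an inference from the error text, not something confirmed in this thread. A small sketch for collecting the relevant versions before posting; report_versions is a hypothetical helper, and "jax-metal" is an assumed package name that a from-source build may not use.

```python
import importlib.metadata as md

import jax


def report_versions(packages=("jax", "jaxlib", "jax-metal")):
    # Map each distribution name to its installed version, or None if absent.
    # "jax-metal" is an assumed plugin package name; source builds may differ.
    versions = {}
    for name in packages:
        try:
            versions[name] = md.version(name)
        except md.PackageNotFoundError:
            versions[name] = None
    return versions


print(report_versions())
# jax.default_backend() returns the same platform string as
# xla_bridge.get_backend().platform in the session above.
print(jax.default_backend(), jax.devices())
```

If jax and jaxlib report different versions, aligning them is a reasonable first thing to try.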
Same here