jax-metal segfaults when running Gemma inference
I tried running inference with the 2B model from https://github.com/google-deepmind/gemma on my M2 MacBook Pro using jax-metal, but it segfaults during sampling. Crash output: https://pastebin.com/KECyz60T

Note: out of the box, the code tries to load bfloat16 weights, which fails. To avoid this, I patched line 30 in gemma/params.py to explicitly cast every parameter to float32:

param_state = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)
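For anyone trying to reproduce the workaround without editing the repo, here is a minimal standalone sketch of the same cast. It is an illustration, not the code in gemma/params.py: the helper name `cast_floats_to_float32` and the `toy_params` dictionary are made up for the example, and the toy tree is assumed to be created on a backend that can hold bfloat16 arrays (for example CPU). Unlike the one-line patch, it leaves non-float leaves untouched.

```python
import jax
import jax.numpy as jnp


def cast_floats_to_float32(params):
    """Return a copy of the pytree with every floating-point leaf cast to float32.

    Hypothetical helper for illustration; it is not part of the gemma repo.
    """
    def cast(leaf):
        leaf = jnp.asarray(leaf)
        # Only touch floating-point leaves (bfloat16, float16, ...);
        # integer leaves keep their dtype.
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(jnp.float32)
        return leaf

    return jax.tree_util.tree_map(cast, params)


# Made-up stand-in for a loaded checkpoint: a nested dict of bfloat16 arrays
# plus one integer leaf. Real Gemma parameter names and shapes differ.
toy_params = {
    "embedder": {"input_embedding": jnp.ones((4, 8), dtype=jnp.bfloat16)},
    "layer_0": {"scale": jnp.zeros((8,), dtype=jnp.bfloat16)},
    "step": jnp.array(3, dtype=jnp.int32),
}

param_state = cast_floats_to_float32(toy_params)

# Print the dtype of every leaf to confirm the cast took effect.
print(jax.tree_util.tree_map(lambda p: p.dtype, param_state))
```

Checking the leaf dtypes like this before sampling at least confirms that no bfloat16 arrays are left in the tree, which helps separate the dtype problem from the segfault itself.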
Mar ’24