Reply to jax-metal error on jax.numpy.einsum
I'm getting what looks like the same error during the backward pass for a transformer, but in this case it's less clear how to work around it:

    layers.py:108:15: error: failed to legalize operation 'mhlo.dot_general'
        attended = jnp.einsum('bsSh,bShd->bshd', weights, v)
                   ^
    layers.py:108:15: note: see current operation: %0 = "mhlo.dot_general"(%arg2, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1, 2], rhs_contracting_dimensions = [1, 3]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<16x256x4x64xf32>, tensor<16x256x256x4xf32>) -> tensor<16x64x256xf32>

Edit: in my case the issue seems to be caused by broadcasting. If I manually broadcast with jnp.repeat first, the issue goes away:

    if weights.shape[3] != v.shape[2]:
        v = jnp.repeat(v, weights.shape[3] // v.shape[2], axis=2)
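A minimal sketch of that workaround as a standalone function, assuming `weights` has shape (batch, query_len, key_len, heads) and `v` has shape (batch, key_len, kv_heads, head_dim), as implied by the einsum labels, with `heads` a multiple of `kv_heads`. The function name `attend` and the toy shapes are hypothetical. That this avoids the jax-metal legalization failure is only what the post reports; the sketch just makes the head broadcast explicit so the einsum and its gradient see matching head dimensions.

```python
import jax
import jax.numpy as jnp


def attend(weights, v):
    """Hypothetical helper.

    weights: (batch, q_len, kv_len, heads)
    v:       (batch, kv_len, kv_heads, head_dim), heads % kv_heads == 0
    """
    # Workaround reported in the post: repeat the KV heads explicitly so the
    # einsum never relies on implicit broadcasting along the head axis.
    if weights.shape[3] != v.shape[2]:
        v = jnp.repeat(v, weights.shape[3] // v.shape[2], axis=2)
    # Contract over the key/value positions (S) for every head (h).
    return jnp.einsum('bsSh,bShd->bshd', weights, v)


# Toy inputs: 4 query heads sharing a single KV head.
key_w, key_v = jax.random.split(jax.random.PRNGKey(0))
weights = jax.nn.softmax(jax.random.normal(key_w, (2, 8, 8, 4)), axis=2)  # normalize over key positions
v = jax.random.normal(key_v, (2, 8, 1, 4))

out = attend(weights, v)  # forward pass
grad_v = jax.grad(lambda v_: attend(weights, v_).sum())(v)  # backward pass, where the error appeared
```

Because `jnp.repeat` is differentiable, the gradient flows back to the original single-head `v` and is summed over the repeated copies.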
Feb ’24
Reply to jax-metal segfaults when running Gemma inference
To reproduce, first download the model checkpoint from https://www.kaggle.com/models/google/gemma/flax/2b-it

Clone the repository and install the dependencies:

    git clone https://github.com/google-deepmind/gemma.git
    cd gemma
    python3 -m venv .
    ./bin/pip install jax-metal absl-py sentencepiece orbax chex flax

Patch it to use float32 params:

    sed -i.bu 's/param_state = jax.tree_util.tree_map(jnp.array, params)/param_state = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)/' gemma/params.py

Run sampling and observe the segfault (paths here must reference the checkpoint downloaded in the first step):

    PYTHONPATH=$(pwd) ./bin/python3 examples/sampling.py --path_checkpoint ~/models/gemma_2b_it/2b-it --path_tokenizer ~/models/gemma_2b_it/tokenizer.model
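What the sed patch changes can be illustrated on a toy parameter tree. This is a sketch: the nested dict and its key names are hypothetical stand-ins for the loaded checkpoint, and bfloat16 leaves are used only to make the cast visible, not as a claim about how the Gemma checkpoint is stored. The only lines taken from the post are the two `tree_map` calls, before and after the patch.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy parameter tree standing in for the checkpoint's params.
params = {
    "embedder": {"input_embedding": jnp.ones((4, 8), dtype=jnp.bfloat16)},
    "layer_0": {"attn": {"w": jnp.zeros((8, 8), dtype=jnp.bfloat16)}},
}

# Before the patch: each leaf is converted with jnp.array and keeps its stored dtype.
param_state = jax.tree_util.tree_map(jnp.array, params)

# After the patch: each leaf is converted and cast to float32.
param_state_f32 = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)

for leaf in jax.tree_util.tree_leaves(param_state_f32):
    print(leaf.shape, leaf.dtype)
```

Since `tree_map` preserves the tree structure, the patched `param_state` can be used anywhere the original was; only the leaf dtypes change.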
Apr ’24