I'm getting what looks like the same error during the backward pass for a transformer, but in this case it's less clear how to work around it:
layers.py:108:15: error: failed to legalize operation 'mhlo.dot_general'
attended = jnp.einsum('bsSh,bShd->bshd', weights, v)
^
layers.py:108:15: note: see current operation: %0 = "mhlo.dot_general"(%arg2, %arg1) {dot_dimension_numbers = #mhlo.dot<lhs_batching_dimensions = [0], rhs_batching_dimensions = [0], lhs_contracting_dimensions = [1, 2], rhs_contracting_dimensions = [1, 3]>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>]} : (tensor<16x256x4x64xf32>, tensor<16x256x256x4xf32>) -> tensor<16x64x256xf32>
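One way to check whether this contraction fails on its own, independently of the rest of the model, is to differentiate only the einsum from layers.py:108. The sketch below takes its default shapes (batch 16, sequence 256, 4 heads, head dim 64) from the tensor types in the error message. `attention_grads` is a hypothetical helper name, and it is an assumption that this alone reproduces the jax-metal failure. On the CPU backend it should simply run.

```python
import jax
import jax.numpy as jnp


def attention_grads(batch=16, seq=256, heads=4, head_dim=64):
    """Hypothetical harness: differentiate only the weights/v einsum from layers.py."""
    key_w, key_v = jax.random.split(jax.random.PRNGKey(0))
    # weights: [batch, query_len, key_len, heads]; v: [batch, key_len, heads, head_dim]
    weights = jax.random.normal(key_w, (batch, seq, seq, heads), jnp.float32)
    v = jax.random.normal(key_v, (batch, seq, heads, head_dim), jnp.float32)

    def loss(weights, v):
        attended = jnp.einsum('bsSh,bShd->bshd', weights, v)
        return jnp.sum(attended)

    # The reported error shows up in the backward pass, so take gradients
    # with respect to both einsum operands.
    return jax.grad(loss, argnums=(0, 1))(weights, v)


if __name__ == "__main__":
    grad_weights, grad_v = attention_grads()
    print(grad_weights.shape, grad_v.shape)  # (16, 256, 256, 4) (16, 256, 4, 64)
```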
Edit: in my case the issue seems to be caused by broadcasting. If I manually broadcast with jnp.repeat first, the issue goes away:
if weights.shape[3] != v.shape[2]:
    v = jnp.repeat(v, weights.shape[3] // v.shape[2], axis=2)
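Wrapped in a function, the same workaround might look like this sketch. It assumes `v` has a number of heads that evenly divides the head count of `weights`, as in multi-query or grouped-query attention. `jnp.repeat` along axis 2 then places consecutive copies of each key/value head next to each other, so query head `i` reads key/value head `i // (num_heads // kv_heads)`. `attend` is a hypothetical name. The point is that the einsum then sees matching head dimensions and never has to broadcast a head axis.

```python
import jax
import jax.numpy as jnp


def attend(weights, v):
    """weights: [b, s, S, h]; v: [b, S, kv_heads, d] with h % kv_heads == 0."""
    num_heads, kv_heads = weights.shape[3], v.shape[2]
    if num_heads != kv_heads:
        if num_heads % kv_heads:
            raise ValueError(f"{num_heads} heads not divisible by {kv_heads} kv heads")
        # Materialize the broadcast explicitly: each kv head is repeated for a
        # block of consecutive query heads.
        v = jnp.repeat(v, num_heads // kv_heads, axis=2)
    return jnp.einsum('bsSh,bShd->bshd', weights, v)


if __name__ == "__main__":
    key_w, key_v = jax.random.split(jax.random.PRNGKey(0))
    # Attention probabilities over the key axis S.
    weights = jax.nn.softmax(jax.random.normal(key_w, (16, 256, 256, 4)), axis=2)
    v = jax.random.normal(key_v, (16, 256, 1, 64))  # a single shared kv head
    grads = jax.grad(lambda w, v: attend(w, v).sum(), argnums=(0, 1))(weights, v)
    print(attend(weights, v).shape, grads[1].shape)  # (16, 256, 4, 64) (16, 256, 1, 64)
```

The gradient with respect to `v` comes back in its original (unrepeated) shape, because differentiating `jnp.repeat` sums the contributions of the copies.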
To reproduce, first download the model checkpoint from https://www.kaggle.com/models/google/gemma/flax/2b-it
Clone the repository and install the dependencies:
git clone https://github.com/google-deepmind/gemma.git
cd gemma
python3 -m venv .
./bin/pip install jax-metal absl-py sentencepiece orbax chex flax
Patch it to use float32 params:
sed -i.bu 's/param_state = jax.tree_util.tree_map(jnp.array, params)/param_state = jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)/' gemma/params.py
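In Python terms, the sed substitution replaces a plain `jnp.array` copy, which keeps whatever dtype each checkpoint array was stored in, with a conversion that casts every leaf of the parameter pytree to float32. The sketch below isolates that conversion. `to_float32` and the toy parameter dict are hypothetical stand-ins, not the actual Gemma parameter names or dtypes.

```python
import jax
import jax.numpy as jnp


def to_float32(params):
    """Cast every array leaf of a parameter pytree to float32, like the patched line."""
    return jax.tree_util.tree_map(lambda p: jnp.array(p, jnp.float32), params)


if __name__ == "__main__":
    # Toy stand-in for a loaded checkpoint; real names and dtypes may differ.
    params = {
        "embedder": {"input_embedding": jnp.ones((8, 4), jnp.bfloat16)},
        "layer_0": {"w": jnp.zeros((4, 4), jnp.bfloat16)},
    }
    cast = to_float32(params)
    print(jax.tree_util.tree_map(lambda p: p.dtype, cast))
```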
Run sampling and observe the segfault (paths here must reference the checkpoint downloaded in the first step):
PYTHONPATH=$(pwd) ./bin/python3 examples/sampling.py --path_checkpoint ~/models/gemma_2b_it/2b-it --path_tokenizer ~/models/gemma_2b_it/tokenizer.model