I'm testing the jax-metal plugin (https://developer.apple.com/metal/jax/), mainly to run whisper-jax (https://github.com/sanchit-gandhi/whisper-jax/tree/main) on my M2.
The jax-metal plugin installs without issues, and the basic test code runs fine. However, whisper-jax fails when trying to encode a file, with the following error:
error: failed to legalize operation 'mhlo.convolution'
/Users/pere/jax-metal/lib/python3.10/site-packages/whisper_jax/layers.py:1236:0: note: see current operation: %111 = "mhlo.convolution"(%110, <<UNKNOWN SSA VALUE>>) {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<1xi64>, padding = dense<1> : tensor<1x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<1xi64>, window_reversal = dense<false> : tensor<1xi1>, window_strides = dense<1> : tensor<1xi64>} : (tensor<1x3000x80xf32>, tensor<3x80x384xf32>) -> tensor<1x3000x384xf32>
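The operation named in the note can be sketched in isolation, which may help narrow the problem down. The snippet below is a minimal sketch, not the whisper-jax code itself: it builds a 1-D convolution with the shapes, padding, stride, and layout read off the error message ([b, 0, f] x [0, i, o] -> [b, 0, f], written as "NWC"/"WIO"/"NWC" in JAX). The function name and the zero-filled inputs are made up for illustration, and whether this standalone call triggers the same legalization failure on jax-metal is an assumption.

```python
import jax
import jax.numpy as jnp
from jax import lax


def conv1d_like_error(lhs, rhs):
    """Hypothetical helper: a 1-D convolution mirroring the attributes
    printed in the 'mhlo.convolution' note (stride 1, padding 1, no dilation)."""
    return lax.conv_general_dilated(
        lhs,
        rhs,
        window_strides=(1,),
        padding=[(1, 1)],
        lhs_dilation=(1,),
        rhs_dilation=(1,),
        # [b, 0, f] x [0, i, o] -> [b, 0, f] in the error message
        dimension_numbers=("NWC", "WIO", "NWC"),
        feature_group_count=1,
        batch_group_count=1,
    )


# Shapes taken from the error: tensor<1x3000x80xf32> and tensor<3x80x384xf32>.
# The values are placeholders; only the shapes and dtype matter here.
lhs = jnp.zeros((1, 3000, 80), dtype=jnp.float32)
rhs = jnp.zeros((3, 80, 384), dtype=jnp.float32)

print(jax.devices())  # shows which backend JAX actually picked
out = jax.jit(conv1d_like_error)(lhs, rhs)
print(out.shape)
```

On a backend that supports the op, the result has shape (1, 3000, 384), matching the output type in the error. If the same script fails on the Metal device, that would suggest the problem lies in the plugin's handling of this convolution layout rather than in whisper-jax.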