Jax-metal - whisper-jax

I'm testing out https://developer.apple.com/metal/jax/, mainly to try running whisper-jax (https://github.com/sanchit-gandhi/whisper-jax/tree/main) on my M2.

The jax-metal plugin seems to install without issues, and the basic test code runs fine. However, the whisper-jax code fails with the following error when trying to encode a file:

error: failed to legalize operation 'mhlo.convolution'

/Users/pere/jax-metal/lib/python3.10/site-packages/whisper_jax/layers.py:1236:0: note: see current operation: %111 = "mhlo.convolution"(%110, <<UNKNOWN SSA VALUE>>) {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<1xi64>, padding = dense<1> : tensor<1x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<1xi64>, window_reversal = dense<false> : tensor<1xi1>, window_strides = dense<1> : tensor<1xi64>} : (tensor<1x3000x80xf32>, tensor<3x80x384xf32>) -> tensor<1x3000x384xf32>
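Reading the attributes of the failing op, `[b, 0, f]x[0, i, o]->[b, 0, f]` with a `1x3000x80` input and a `3x80x384` kernel looks like a plain 1D convolution (batch, time, channels), kernel width 3, stride 1, padding 1 on each side. A minimal sketch that builds the same op outside whisper-jax, assuming the op alone is what triggers the failure, could help isolate it. It runs on CPU; on jax-metal it would presumably hit the same "failed to legalize" error, but that is not verified here.

```python
import jax
import jax.numpy as jnp

# Shapes copied from the failing op in the log:
#   input  tensor<1x3000x80xf32>  -> (batch, time, channels), i.e. NWC
#   kernel tensor<3x80x384xf32>   -> (width, in_features, out_features), i.e. WIO
x = jnp.zeros((1, 3000, 80), dtype=jnp.float32)
w = jnp.zeros((3, 80, 384), dtype=jnp.float32)

@jax.jit
def conv1d(x, w):
    # Stride 1, padding (1, 1), no dilation: mirrors window_strides,
    # padding and lhs/rhs_dilation in the mhlo.convolution attributes.
    return jax.lax.conv_general_dilated(
        x,
        w,
        window_strides=(1,),
        padding=[(1, 1)],
        dimension_numbers=("NWC", "WIO", "NWC"),
    )

y = conv1d(x, w)
print(y.shape)  # (1, 3000, 384)
```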

Same problem here on my M2 Max.

I get this problem when using Conv2DTranspose. It seems Metal does not support all operations yet. Did you find a fix?

I also tested this on the new jax-metal 0.0.5 and got basically the same error:

XlaRuntimeError: UNKNOWN: <ipython-input-3-1d989b5d528f>:1:0: error: 'mhlo.convolution' op Not supported: ConvolutionOp other than Conv2d.
<ipython-input-3-1d989b5d528f>:1:0: note: see current operation: %153 = "mhlo.convolution"(%152, <<UNKNOWN SSA VALUE>>) {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<1xi64>, padding = dense<1> : tensor<1x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<1xi64>, window_reversal = dense<false> : tensor<1xi1>, window_strides = dense<1> : tensor<1xi64>} : (tensor<1x3000x80xf32>, tensor<3x80x768xf32>) -> tensor<1x3000x768xf32>
<ipython-input-3-1d989b5d528f>:1:0: error: failed to legalize operation 'mhlo.convolution'
<ipython-input-3-1d989b5d528f>:1:0: note: see current operation: %153 = "mhlo.convolution"(%152, <<UNKNOWN SSA VALUE>>) {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<1xi64>, padding = dense<1> : tensor<1x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<1xi64>, window_reversal = dense<false> : tensor<1xi1>, window_strides = dense<1> : tensor<1xi64>} : (tensor<1x3000x80xf32>, tensor<3x80x768xf32>) -> tensor<1x3000x768xf32>
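The 0.0.5 message says only Conv2d is supported. One possible workaround, untested on Metal and offered only as a sketch, is to rewrite the 1D convolution as a 2D one with a dummy height axis of size 1. The helper `conv1d_via_conv2d` below is a hypothetical name, not part of whisper-jax or jax-metal, and whether the Metal plugin accepts the resulting op is an assumption. The script checks on CPU that the rewrite matches the original 1D convolution.

```python
import jax
import jax.numpy as jnp
import numpy as np

def conv1d_reference(x, w, padding=1):
    # The original 1D form: NWC input, WIO kernel (same as in the error log).
    return jax.lax.conv_general_dilated(
        x, w, window_strides=(1,), padding=[(padding, padding)],
        dimension_numbers=("NWC", "WIO", "NWC"),
    )

def conv1d_via_conv2d(x, w, padding=1):
    # Hypothetical helper: same result, expressed as a 2D convolution.
    x2 = x[:, None, :, :]  # (N, W, C)   -> (N, 1, W, C), i.e. NHWC
    w2 = w[None, :, :, :]  # (KW, I, O)  -> (1, KW, I, O), i.e. HWIO
    y2 = jax.lax.conv_general_dilated(
        x2, w2, window_strides=(1, 1),
        padding=[(0, 0), (padding, padding)],  # no padding on the dummy axis
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y2[:, 0, :, :]  # drop the dummy height axis -> (N, W, O)

key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (1, 3000, 80), dtype=jnp.float32)
w = jax.random.normal(key_w, (3, 80, 384), dtype=jnp.float32)

y2d = jax.jit(conv1d_via_conv2d)(x, w)
y1d = conv1d_reference(x, w)
print(y2d.shape)  # (1, 3000, 384)
print(np.allclose(np.asarray(y2d), np.asarray(y1d), rtol=1e-4, atol=1e-4))
```

Getting this into whisper-jax would still mean patching the layer that emits the convolution, which is not shown here.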