I also tested this on the new jax-metal 0.0.5 and got essentially the same error:
XlaRuntimeError: UNKNOWN: <ipython-input-3-1d989b5d528f>:1:0: error: 'mhlo.convolution' op Not supported: ConvolutionOp other than Conv2d.
<ipython-input-3-1d989b5d528f>:1:0: note: see current operation: %153 = "mhlo.convolution"(%152, <<UNKNOWN SSA VALUE>>) {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<1xi64>, padding = dense<1> : tensor<1x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<1xi64>, window_reversal = dense<false> : tensor<1xi1>, window_strides = dense<1> : tensor<1xi64>} : (tensor<1x3000x80xf32>, tensor<3x80x768xf32>) -> tensor<1x3000x768xf32>
<ipython-input-3-1d989b5d528f>:1:0: error: failed to legalize operation 'mhlo.convolution'
<ipython-input-3-1d989b5d528f>:1:0: note: see current operation: %153 = "mhlo.convolution"(%152, <<UNKNOWN SSA VALUE>>) {batch_group_count = 1 : i64, dimension_numbers = #mhlo.conv<[b, 0, f]x[0, i, o]->[b, 0, f]>, feature_group_count = 1 : i64, lhs_dilation = dense<1> : tensor<1xi64>, padding = dense<1> : tensor<1x2xi64>, precision_config = [#mhlo<precision DEFAULT>, #mhlo<precision DEFAULT>], rhs_dilation = dense<1> : tensor<1xi64>, window_reversal = dense<false> : tensor<1xi1>, window_strides = dense<1> : tensor<1xi64>} : (tensor<1x3000x80xf32>, tensor<3x80x768xf32>) -> tensor<1x3000x768xf32>
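The op being rejected is a 1D convolution. Reading the dimension numbers `[b, 0, f]x[0, i, o]->[b, 0, f]`, the input is (batch, width, features) = (1, 3000, 80) and the kernel is (width, in, out) = (3, 80, 768), with stride 1 and padding 1 on each side. Because the message says only Conv2d is supported, one possible workaround is to insert a dummy height-1 spatial axis and issue a 2D convolution instead. The sketch below reproduces the op with `jax.lax.conv_general_dilated` and shows that 2D equivalent. The function names are hypothetical, and I have not confirmed that jax-metal 0.0.5 accepts the resulting 2D conv; that part is an assumption.

```python
import jax.numpy as jnp
from jax import lax


def conv1d_nwc(x, w):
    # Reproduces the failing op: NWC input, WIO kernel, NWC output,
    # stride 1, padding 1 on both sides of the single spatial axis.
    return lax.conv_general_dilated(
        x,
        w,
        window_strides=(1,),
        padding=((1, 1),),
        dimension_numbers=("NWC", "WIO", "NWC"),
    )


def conv1d_via_conv2d(x, w):
    # Hypothetical workaround (untested on jax-metal): add a dummy height-1
    # axis so the same computation is a 2D convolution (NHWC x HWIO -> NHWC).
    x2 = x[:, None, :, :]  # (N, W, C) -> (N, 1, W, C)
    w2 = w[None, :, :, :]  # (K, I, O) -> (1, K, I, O)
    y2 = lax.conv_general_dilated(
        x2,
        w2,
        window_strides=(1, 1),
        padding=((0, 0), (1, 1)),  # no padding along the dummy axis
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )
    return y2[:, 0, :, :]  # drop the dummy axis again


if __name__ == "__main__":
    # Shapes taken from the error message above.
    x = jnp.zeros((1, 3000, 80), dtype=jnp.float32)
    w = jnp.zeros((3, 80, 768), dtype=jnp.float32)
    print(conv1d_via_conv2d(x, w).shape)  # (1, 3000, 768)
```

On a backend that supports both, the two functions should return the same values; only the 2D form avoids the `mhlo.convolution` variant the Metal plugin reports as unsupported.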