jax.lax.conv_transpose not correctly implemented

Good evening!

I tried to use Flax's nn.ConvTranspose, which calls jax.lax.conv_transpose, but the op doesn't appear to be implemented correctly for the Metal backend. The same code works fine on CPU.

File "/Users/cemlyn/Documents/VCLless/mnist_vae/venv/lib/python3.11/site-packages/flax/linen/linear.py", line 768, in __call__
    y = lax.conv_transpose(
        ^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<1x8x8x64xf32>') doesn't match function result type ('tensor<1x14x14x64xf32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%0) : (tensor<1x8x8x64xf32>) -> ()
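For reference, here is a minimal sketch that reproduces the expected shape on CPU, where the op is reported to work. It assumes a hypothetical 7x7 input with 32 channels, a 3x3 kernel and strides of (2, 2). Only the 64 output channels and the 14x14 result come from the error above. With padding='SAME', lax.conv_transpose multiplies each spatial dimension by the stride, so 7x7 becomes 14x14. The Metal backend returned 8x8 instead.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical shapes: 7x7 input is assumed because SAME padding with
# stride 2 gives 7 * 2 = 14, the result type the compiler expected.
x = jnp.ones((1, 7, 7, 32), dtype=jnp.float32)   # NHWC
w = jnp.ones((3, 3, 32, 64), dtype=jnp.float32)  # HWIO

# Commit both operands to the CPU device so the op runs there.
cpu = jax.devices("cpu")[0]
y = lax.conv_transpose(
    jax.device_put(x, cpu),
    jax.device_put(w, cpu),
    strides=(2, 2),
    padding="SAME",
    dimension_numbers=("NHWC", "HWIO", "NHWC"),
)
print(y.shape)  # (1, 14, 14, 64)
```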

Versions:

pip list | grep jax
jax                       0.4.11
jax-metal                 0.0.4
jaxlib                    0.4.11

Any stride other than (1, 1) triggers this error for me. As a workaround, I upscale manually and then use a stride-1 transposed convolution:

import jax.numpy as jnp
from jax import lax

def upscale_nearest_neighbor(x, scale_factor=2):
    # Assuming x has shape (batch, height, width, channels)
    b, h, w, c = x.shape
    x = x.reshape(b, h, 1, w, 1, c)
    # lax.tie_in (a deprecated no-op) is dropped; a plain broadcast is enough
    x = jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)

def deconv2d(x, w):
    # Upscale first, then apply a stride-1 transposed convolution
    x_upscaled = upscale_nearest_neighbor(x)
    return lax.conv_transpose(
        x_upscaled, w, strides=(1, 1), padding='SAME',
        dimension_numbers=("NHWC", "HWIO", "NHWC"))
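A quick sanity check of this workaround might look like the sketch below. It confirms that the broadcast-and-reshape upscale matches repeating each pixel along height and width with jnp.repeat. It also confirms that deconv2d produces the 14x14 output the strided version was expected to give. The 7x7x32 input and 3x3x32x64 kernel are hypothetical shapes chosen for illustration. Upsample-then-convolve is not numerically identical to a stride-2 transposed convolution, which dilates the input with zeros. The two layers have the same output shape but are not interchangeable weight-for-weight.

```python
import jax.numpy as jnp
from jax import lax

def upscale_nearest_neighbor(x, scale_factor=2):
    # Broadcast-and-reshape nearest-neighbor upscale, NHWC layout assumed.
    b, h, w, c = x.shape
    x = x.reshape(b, h, 1, w, 1, c)
    x = jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)

def deconv2d(x, w):
    # Upscale by 2, then a stride-1 transposed convolution.
    x_upscaled = upscale_nearest_neighbor(x)
    return lax.conv_transpose(
        x_upscaled, w, strides=(1, 1), padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )

# 1) The upscale should equal repeating each pixel twice along H and W.
x = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5)
ref = jnp.repeat(jnp.repeat(x, 2, axis=1), 2, axis=2)
print(bool(jnp.array_equal(upscale_nearest_neighbor(x), ref)))  # True

# 2) Hypothetical shapes: 7x7 input, 32 -> 64 channels, 3x3 kernel.
y = deconv2d(jnp.ones((1, 7, 7, 32)), jnp.ones((3, 3, 32, 64)))
print(y.shape)  # (1, 14, 14, 64)
```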

The issue has been fixed in jax-metal 0.0.7.
