Reply to jax.lax.conv_transpose not correctly implemented
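Any stride other than (1, 1) gives me the same problem. I use manual nearest-neighbor upscaling followed by a stride-1 transposed convolution instead:

```python
import jax.numpy as jnp
from jax import lax


def upscale_nearest_neighbor(x, scale_factor=2):
    # Assuming x has shape (batch, height, width, channels)
    b, h, w, c = x.shape
    # Add singleton axes after H and W, then broadcast each to scale_factor
    x = x.reshape(b, h, 1, w, 1, c)
    x = jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c))
    # Merge the repeated axes back: (b, h * scale_factor, w * scale_factor, c)
    return x.reshape(b, h * scale_factor, w * scale_factor, c)


def deconv2d(x, w):
    # Upscale 2x, then a stride-1 transposed conv; 'SAME' keeps the upscaled size
    x_upscaled = upscale_nearest_neighbor(x)
    return lax.conv_transpose(
        x_upscaled, w,
        strides=(1, 1),
        padding='SAME',
        dimension_numbers=("NHWC", "HWIO", "NHWC"))
```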
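A quick sanity check of this workaround, as a sketch assuming NHWC input and an HWIO kernel whose input-channel dimension matches `x`. The reshape/broadcast upscale should agree exactly with `jnp.repeat` applied along H and W (`upscale_with_repeat` is a hypothetical helper added only for comparison), and with stride (1, 1) and `'SAME'` padding the output keeps the upscaled spatial size. The input and kernel shapes are arbitrary. Note that, as far as I understand `lax.conv_transpose`, a strided call dilates the input with zeros between pixels rather than repeating them, so this upscale-then-convolve layer is a different operation and its weights are not interchangeable with a strided transposed conv.

```python
import jax
import jax.numpy as jnp
from jax import lax


def upscale_nearest_neighbor(x, scale_factor=2):
    # NHWC: repeat every pixel scale_factor times along H and W
    b, h, w, c = x.shape
    x = x.reshape(b, h, 1, w, 1, c)
    x = jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)


def upscale_with_repeat(x, scale_factor=2):
    # Hypothetical comparison helper: the same upscale written with jnp.repeat
    x = jnp.repeat(x, scale_factor, axis=1)  # repeat rows
    return jnp.repeat(x, scale_factor, axis=2)  # repeat columns


def deconv2d(x, w):
    # Nearest-neighbor 2x upscale, then a stride-1 'SAME' transposed conv
    x_upscaled = upscale_nearest_neighbor(x)
    return lax.conv_transpose(
        x_upscaled, w,
        strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


x = jax.random.normal(jax.random.PRNGKey(0), (2, 4, 5, 3))  # (b, h, w, c_in)
w = jax.random.normal(jax.random.PRNGKey(1), (3, 3, 3, 8))  # (kh, kw, c_in, c_out)

same = bool(jnp.array_equal(upscale_nearest_neighbor(x), upscale_with_repeat(x)))
y = deconv2d(x, w)
print(same, y.shape)  # True (2, 8, 10, 8)
```

Since `jnp.repeat` gives the same result, it is a shorter alternative to the reshape/broadcast version if you prefer it.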
Oct ’23