Any stride other than (1, 1) gives me that. I upscale manually instead:
import jax.numpy as jnp
from jax import lax

def upscale_nearest_neighbor(x, scale_factor=2):
    # Assuming x has shape (batch, height, width, channels)
    b, h, w, c = x.shape
    # Add singleton axes after height and width, then broadcast them so
    # every pixel is repeated scale_factor times in each direction.
    x = x.reshape(b, h, 1, w, 1, c)
    x = jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)

def deconv2d(x, w):
    # Upscale first, then convolve with stride (1, 1)
    x_upscaled = upscale_nearest_neighbor(x)
    return lax.conv_transpose(
        x_upscaled, w,
        strides=(1, 1),
        padding='SAME',
        dimension_numbers=("NHWC", "HWIO", "NHWC"))
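In case it helps anyone trying this, here is a rough sanity check of the workaround. It compares the reshape-and-broadcast upscaling against a plain `jnp.repeat` version and then runs the stride (1, 1) transposed convolution on the result. The input shape, the all-ones kernel, and the `upscale_reference` helper are made up for illustration, and I have only reasoned about this on the default backend, not on Metal specifically.

```python
import jax.numpy as jnp
from jax import lax


def upscale_nearest_neighbor(x, scale_factor=2):
    # x is assumed to have shape (batch, height, width, channels)
    b, h, w, c = x.shape
    x = x.reshape(b, h, 1, w, 1, c)
    x = jnp.broadcast_to(x, (b, h, scale_factor, w, scale_factor, c))
    return x.reshape(b, h * scale_factor, w * scale_factor, c)


def upscale_reference(x, scale_factor=2):
    # Hypothetical reference: repeat rows, then columns
    x = jnp.repeat(x, scale_factor, axis=1)
    return jnp.repeat(x, scale_factor, axis=2)


def deconv2d(x, w):
    x_upscaled = upscale_nearest_neighbor(x)
    return lax.conv_transpose(
        x_upscaled, w,
        strides=(1, 1),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"))


# Illustrative input: batch 2, 3x4 pixels, 5 channels
x = jnp.arange(2 * 3 * 4 * 5, dtype=jnp.float32).reshape(2, 3, 4, 5)
up = upscale_nearest_neighbor(x)
matches_reference = bool(jnp.array_equal(up, upscale_reference(x)))

# Illustrative 3x3 kernel in HWIO layout: 5 input channels, 8 output channels
w = jnp.ones((3, 3, 5, 8), dtype=jnp.float32)
y = deconv2d(x, w)

print(matches_reference, up.shape, y.shape)
```

With stride (1, 1) and 'SAME' padding the convolution keeps the spatial size, so all of the 2x growth in height and width comes from the manual upscaling step.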
I get this problem when using conv2dtranspose. It seems Metal does not support all operations yet. Did you find a fix?