Good evening!
I tried to use Flax's nn.ConvTranspose, which calls jax.lax.conv_transpose under the hood, but it doesn't seem to be implemented correctly for the METAL backend. The same code works fine on CPU, but on METAL it fails with the following error:
File "/Users/cemlyn/Documents/VCLless/mnist_vae/venv/lib/python3.11/site-packages/flax/linen/linear.py", line 768, in __call__
y = lax.conv_transpose(
^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<1x8x8x64xf32>') doesn't match function result type ('tensor<1x14x14x64xf32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%0) : (tensor<1x8x8x64xf32>) -> ()
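A minimal reproducer sketch follows. It calls lax.conv_transpose directly, roughly the way nn.ConvTranspose does (see flax/linen/linear.py in the traceback), and runs it once pinned to the CPU backend and once on the default backend. The input shape (1, 7, 7, 32), the 3x3 kernel, strides (2, 2) and SAME padding are assumptions: the post only shows the expected result shape 1x14x14x64, and this is just one configuration that produces it (with SAME padding the transposed conv's output size is input size times stride, so 7 * 2 = 14).

```python
import jax
import jax.numpy as jnp
from jax import lax

# Assumed (hypothetical) shapes: only the expected output 1x14x14x64
# comes from the error message; the input/kernel sizes are illustrative.
x = jnp.ones((1, 7, 7, 32), dtype=jnp.float32)        # NHWC input
kernel = jnp.ones((3, 3, 32, 64), dtype=jnp.float32)  # HWIO kernel


def upsample(x, kernel):
    # Transposed convolution with stride 2 and SAME padding:
    # each spatial dimension should grow from 7 to 7 * 2 = 14.
    return lax.conv_transpose(
        x,
        kernel,
        strides=(2, 2),
        padding="SAME",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


# Reference run: commit the inputs to the CPU device, which the post
# reports as working, so the jitted call executes there.
cpu = jax.devices("cpu")[0]
y_cpu = jax.jit(upsample)(jax.device_put(x, cpu), jax.device_put(kernel, cpu))

# Run on the default backend (METAL when jax-metal is installed).
y_default = jax.jit(upsample)(x, kernel)

print(jax.default_backend(), y_cpu.shape, y_default.shape)
```

On a CPU-only install both shapes should be (1, 14, 14, 64). On an affected jax-metal setup the second call would presumably raise the XlaRuntimeError above instead of printing, which isolates the problem to lax.conv_transpose rather than Flax.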
Versions:
pip list | grep jax
jax 0.4.11
jax-metal 0.0.4
jaxlib 0.4.11