Will TensorFlow-Metal and JAX-Metal code be open sourced?
Reasons why I ask:
If it were open sourced on GitHub or a similar platform, it would be easier for people to find existing issues and open new ones when necessary. The open-source community might also be able to help ;)
I'd love to learn how you guys implement some of these operations :P (I know you made an Apple tutorial on implementing a custom TensorFlow op for Metal, which was fire: https://developer.apple.com/documentation/metal/metal_sample_code_library/customizing_a_tensorflow_operation)
Good evening!
I tried to use Flax's nn.ConvTranspose, which calls jax.lax.conv_transpose, but it looks like it isn't implemented correctly for the METAL backend. The same code works fine on CPU, but on Metal it fails with:
File "/Users/cemlyn/Documents/VCLless/mnist_vae/venv/lib/python3.11/site-packages/flax/linen/linear.py", line 768, in __call__
y = lax.conv_transpose(
^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<1x8x8x64xf32>') doesn't match function result type ('tensor<1x14x14x64xf32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%0) : (tensor<1x8x8x64xf32>) -> ()
Versions:
pip list | grep jax
jax 0.4.11
jax-metal 0.0.4
jaxlib 0.4.11
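The shapes in this error can be sanity-checked with a little arithmetic. Below is a plain-Python sketch that mirrors, as far as I can tell from JAX's source, how lax.conv_transpose sizes each spatial dimension: the input is dilated by the stride (lhs dilation), padded, and then run through an ordinary stride-1 convolution. The 7x7 input, kernel size 3, and stride 2 are hypothetical (the post doesn't give them); they were chosen because a stride-2 'SAME' transpose that doubles 7 to 14 is common in MNIST decoders. transpose_total_padding, conv_transpose_out_size, and apply_lhs_dilation are illustrative names, not JAX APIs.

```python
def transpose_total_padding(kernel: int, stride: int, padding: str) -> int:
    """Total (low + high) padding for one spatial dim, assumed to follow the
    rule jax.lax.conv_transpose uses for 'SAME' and 'VALID'."""
    if padding == "SAME":
        return kernel + stride - 2
    if padding == "VALID":
        return kernel + stride - 2 + max(kernel - stride, 0)
    raise ValueError("padding must be 'SAME' or 'VALID'")


def conv_transpose_out_size(in_size: int, kernel: int, stride: int,
                            padding: str = "SAME",
                            apply_lhs_dilation: bool = True) -> int:
    """Output length of one spatial dim of a transposed convolution.

    apply_lhs_dilation=False (hypothetical) models a backend that skips
    inserting (stride - 1) zeros between input elements.
    """
    dilation = stride if apply_lhs_dilation else 1
    dilated = (in_size - 1) * dilation + 1            # length after lhs dilation
    total_pad = transpose_total_padding(kernel, stride, padding)
    return dilated + total_pad - kernel + 1           # stride-1 conv over the padded input


print(conv_transpose_out_size(7, 3, 2))                            # 14
print(conv_transpose_out_size(7, 3, 2, apply_lhs_dilation=False))  # 8
```

Under these assumptions, 14 is the expected size (matching tensor<1x14x14x64xf32>), and 8 is exactly what you get if the stride dilation is skipped: with 'SAME' padding the undilated result is 7 + stride - 1 for any kernel size. That is only consistent with, not proof of, the Metal lowering dropping the lhs dilation.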
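I only get the following error when using the JAX Metal device (CPU is fine). It seems to happen whenever I update values of an array with .at[...].set(...). JAX arrays are immutable, so this returns a new array rather than modifying it in place, and it lowers to the mhlo.scatter op shown here: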
note: see current operation:
%2903 = "mhlo.scatter"(%arg3, %2902, %2893) ({
^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>):
"mhlo.return"(%arg5) : (tensor<f32>) -> ()
}) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<10x100x4xf32>, tensor<1xsi32>, tensor<10x4xf32>) -> tensor<10x100x4xf32>
blocks = blocks.at[i].set(
...
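Since .at[i].set(...) becomes a scatter, one thing that might be worth trying (untested on jax-metal 0.0.4, so treat it as a guess) is expressing the same single-row update with jax.lax.dynamic_update_slice, which lowers to a dynamic_update_slice op instead of a scatter. The sketch below assumes blocks is a (100, 4) array vmapped over a batch of 10 and each update is one row of 4 values. These shapes were chosen to resemble the 10x100x4 operand and 10x4 updates in the dump, and set_row_scatter and set_row_dus are hypothetical helper names. It checks that both versions agree on the default device (CPU when no Metal plugin is installed).

```python
import jax
import jax.numpy as jnp
from jax import lax


def set_row_scatter(blocks, i, row):
    # Functional update: returns a new array; lowers to a scatter op.
    return blocks.at[i].set(row)


def set_row_dus(blocks, i, row):
    # Same update via dynamic_update_slice: write a (1, 4) slab at (i, 0).
    return lax.dynamic_update_slice(blocks, row[None, :], (i, 0))


blocks = jnp.zeros((10, 100, 4))  # hypothetical batch of 10 arrays of shape (100, 4)
rows = jnp.ones((10, 4))          # one replacement row per batch element
i = 3

# vmap over the batch axis of blocks and rows; the row index i is shared.
a = jax.jit(jax.vmap(set_row_scatter, in_axes=(0, None, 0)))(blocks, i, rows)
b = jax.jit(jax.vmap(set_row_dus, in_axes=(0, None, 0)))(blocks, i, rows)
print(bool(jnp.array_equal(a, b)))  # True
```

The two only agree for a non-negative, in-range i: dynamic_update_slice clamps an out-of-range start index, whereas .at[i].set drops out-of-bounds updates by default.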