Post | Replies | Boosts | Views | Activity

Question: Will TensorFlow-Metal and JAX-Metal code be open sourced?
Reasons why I ask: if the code were open sourced on GitHub or somewhere similar, it might be easier for people to find existing issues and file new ones when necessary. The open source community might also be able to help ;) I'd love to learn how you guys implement some of these operations :P (I know you made an Apple tutorial on how to implement a TensorFlow custom op for Metal, which was fire: https://developer.apple.com/documentation/metal/metal_sample_code_library/customizing_a_tensorflow_operation)
Replies: 0 · Boosts: 1 · Views: 430 · Activity: Oct ’23
jax.lax.conv_transpose not correctly implemented
Good evening! I tried to use Flax's nn.ConvTranspose, which calls jax.lax.conv_transpose, but it looks like it isn't implemented correctly for the METAL backend. It works fine on CPU.

  File "/Users/cemlyn/Documents/VCLless/mnist_vae/venv/lib/python3.11/site-packages/flax/linen/linear.py", line 768, in __call__
    y = lax.conv_transpose(
        ^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<1x8x8x64xf32>') doesn't match function result type ('tensor<1x14x14x64xf32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%0) : (tensor<1x8x8x64xf32>) -> ()

Versions (pip list | grep jax):
jax 0.4.11
jax-metal 0.0.4
jaxlib 0.4.11
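For anyone trying to make sense of the 8x8 versus 14x14 mismatch, here is a rough shape-arithmetic sketch in plain Python. The 7x7 input, 3x3 kernel, and stride 2 are assumptions (the post does not give them), and `conv_out_size` is a hypothetical helper, not a JAX function. A transposed convolution with SAME padding can be expressed as an ordinary convolution over an input dilated by the stride. Under these assumed sizes, dropping that dilation step happens to give 8, which matches the error message, but this is only a guess at the cause and not a confirmed diagnosis.

```python
def conv_out_size(in_size, kernel, pad_total, lhs_dilation=1):
    """Output length of a stride-1 convolution over a (possibly dilated) input."""
    # Dilation inserts (lhs_dilation - 1) zeros between neighbouring input elements.
    dilated = (in_size - 1) * lhs_dilation + 1
    return dilated + pad_total - kernel + 1


# Assumed sizes, NOT taken from the post: 7x7 feature map, 3x3 kernel, stride 2.
in_size, kernel, stride = 7, 3, 2

# Total padding that makes a SAME transposed convolution return in_size * stride.
pad_total = kernel + stride - 2

expected = conv_out_size(in_size, kernel, pad_total, lhs_dilation=stride)
without_dilation = conv_out_size(in_size, kernel, pad_total, lhs_dilation=1)

print("with input dilation:   ", expected)
print("without input dilation:", without_dilation)
```

If the real layer uses different sizes, plugging them into the same helper shows whether the numbers still line up.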
Replies: 2 · Boosts: 0 · Views: 659 · Activity: Oct ’23
JAX Metal error: failed to legalize operation 'mhlo.scatter'
I only get this error when using the JAX Metal device (CPU is fine). It seems to be a problem whenever I want to modify values of an array using .at[...].set(...).

note: see current operation: %2903 = "mhlo.scatter"(%arg3, %2902, %2893) ({
^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>):
  "mhlo.return"(%arg5) : (tensor<f32>) -> ()
}) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<10x100x4xf32>, tensor<1xsi32>, tensor<10x4xf32>) -> tensor<10x100x4xf32>

blocks = blocks.at[i].set( ...
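One possible workaround sketch, assuming the update writes a (10, 4) slice at index i along axis 1 of a (10, 100, 4) array, as the tensor types in the printed op suggest. It builds a boolean mask and uses jnp.where, so the computation is expressed with compare and select operations instead of a scatter. The helper name `set_along_axis1` is hypothetical, and this has not been verified on the Metal backend; it only avoids the op named in the error.

```python
import jax.numpy as jnp


def set_along_axis1(blocks, i, update):
    """Return a copy of `blocks` with blocks[:, i, :] replaced by `update`."""
    # Boolean mask over axis 1, shaped (1, N, 1) so it broadcasts against blocks.
    mask = (jnp.arange(blocks.shape[1]) == i)[None, :, None]
    # update has shape (B, C); add a middle axis so it broadcasts to (B, N, C).
    return jnp.where(mask, update[:, None, :], blocks)


# Shapes assumed from the tensor types printed in the error message.
blocks = jnp.zeros((10, 100, 4), dtype=jnp.float32)
update = jnp.ones((10, 4), dtype=jnp.float32)

masked = set_along_axis1(blocks, 3, update)
print(masked.shape)
```

The mask touches every element of the array, so this trades some extra work for avoiding the scatter. On CPU the result should match blocks.at[:, 3].set(update).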
Replies: 6 · Boosts: 5 · Views: 1.5k · Activity: Oct ’23