Thanks for reporting this. Several advanced-indexing bugs involving the GatherOp and ScatterOp conversion have been fixed at the tip, so the example in the post should now work. The fixes will be included in the next release of jax-metal.
This is reproducible with jax-metal 0.0.5. The lowering pattern needs to be extended to handle more than one contracting dimension (contracting_dimensions size > 1). As a workaround, could you make the changes below and give it a try:
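import numpy as np
# Merge the i and j axes (11 * 12 = 132) so only one contracting dimension remains.
a = np.random.rand(11, 12, 13, 11, 12).reshape(132, 13, 11, 12)
b = np.random.rand(11, 12, 13).reshape(132, 13)
# Original: subscripts = 'ijklm,ijk->lmk'
subscripts = 'iklm,ik->lmk'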
This produces matching results on my side.
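To see why the reshape is safe, here is a minimal sketch that checks the workaround against the original contraction on the CPU with plain NumPy. It only verifies the math, not jax-metal's lowering, and the seeded generator is an illustrative choice. In 'ijklm,ijk->lmk', both i and j are summed over (two contracting dimensions) while k is carried into the output. Reshaping both operands merges i and j in the same row-major order, so summing over the single merged axis gives the same result as summing over i and j separately.

```python
import numpy as np

# Shapes from the workaround: i=11, j=12, k=13, l=11, m=12.
rng = np.random.default_rng(0)
a = rng.random((11, 12, 13, 11, 12))
b = rng.random((11, 12, 13))

# Original contraction: i and j are both contracting dimensions.
ref = np.einsum('ijklm,ijk->lmk', a, b)

# Workaround: merge i and j into one axis of size 11 * 12 = 132 in both
# operands (same row-major order), leaving a single contracting dimension.
a2 = a.reshape(11 * 12, 13, 11, 12)
b2 = b.reshape(11 * 12, 13)
out = np.einsum('iklm,ik->lmk', a2, b2)

print(out.shape)              # (11, 12, 13)
print(np.allclose(ref, out))  # True
```

The same reshape applies unchanged when the arrays are passed to `jnp.einsum` instead of `np.einsum`.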
Could you please provide the steps and the script to reproduce it?
The issue has been fixed in jax-metal 0.0.7.