The issue has been fixed in jax-metal 0.0.7.
Can you please provide the steps and the script to reproduce it?
This is reproducible with jax-metal 0.0.5. The lowering pattern needs to be expanded to handle a contracting_dimensions size > 1. To work around it, could you make the changes below and give it a try:
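import numpy as np
a = np.random.rand(11, 12, 13, 11, 12).reshape(132, 13, 11, 12)  # merge i, j into one axis
b = np.random.rand(11, 12, 13).reshape(132, 13)                  # merge i, j into one axis
#subscripts = 'ijklm,ijk->lmk'  # original: two contracting dimensions (i, j)
subscripts = 'iklm,ik->lmk'     # workaround: a single contracting dimension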
It generates matching results on my side.
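The workaround works because merging two adjacent contracted axes with a row-major reshape leaves the contraction unchanged: summing over all (i, j) pairs is the same as summing over one flattened index i * 12 + j. A minimal NumPy sketch that checks this on the CPU, using the shapes from the snippet above, is below. The variable names are illustrative, and the thread does not show the original script; the subscripts suggest it called an einsum, and jax.numpy.einsum accepts the same subscripts and operands if you want to run the comparison on a jax-metal device instead.

```python
import numpy as np

# Shapes from the thread: i=11 and j=12 are contracted; k=13, l=11, m=12 are kept.
rng = np.random.default_rng(0)
a5 = rng.random((11, 12, 13, 11, 12))
b3 = rng.random((11, 12, 13))

# Original form: two contracting dimensions (i and j).
ref = np.einsum('ijklm,ijk->lmk', a5, b3)

# Workaround form: merge i and j into one axis of size 11 * 12 = 132.
# reshape on these C-contiguous arrays flattens (i, j) row-major, so the merged
# index is i * 12 + j in both operands and the summed pairs still line up.
a4 = a5.reshape(132, 13, 11, 12)
b2 = b3.reshape(132, 13)
merged = np.einsum('iklm,ik->lmk', a4, b2)

print(merged.shape)              # (11, 12, 13)
print(np.allclose(ref, merged))  # True
```

The results agree up to floating-point rounding, since only the order of the 132 summed products changes.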
Thanks for reporting it. Several advanced-indexing bugs involving GatherOp and ScatterOp conversion have been fixed at the tip, so the example in the post should now work. The fixes will be integrated into the next release of jax-metal.