Same problem with torch on a MacBook M2. Training converges well on CPU and also on Colab, but it fails to converge when I set the device to MPS.
Link to the code: https://colab.research.google.com/drive/1xG_R3RpmTVLCTCTeGTG8yo-e7iwIFYbt?usp=sharing
torch version 2.1.2
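
In case it helps narrow this down, here is a minimal sketch of a side-by-side check. It is not the code from the notebook above: the tiny linear model, the synthetic data, and the `train_once` helper are all made up for illustration. It trains the same model from the same seed on each available device and prints the final loss, so a large gap between `cpu` and `mps` would point at the backend rather than the model.

```python
import torch
import torch.nn as nn


def train_once(device_name, steps=200, seed=0):
    """Fit a tiny linear regression on one device and return the final loss.

    Hypothetical helper for illustration only; not from the linked notebook.
    """
    torch.manual_seed(seed)
    # Build data and model on CPU first so every device starts from
    # identical inputs and identical initial weights.
    x = torch.randn(256, 4)
    true_w = torch.tensor([[1.0], [-2.0], [0.5], [3.0]])
    y = x @ true_w
    model = nn.Linear(4, 1)

    device = torch.device(device_name)
    x, y, model = x.to(device), y.to(device), model.to(device)

    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()
    for _ in range(steps):
        opt.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        opt.step()
    return loss.item()


# Always test CPU; add MPS only when the backend is actually available.
devices = ["cpu"]
if torch.backends.mps.is_available():
    devices.append("mps")

final_losses = {name: train_once(name) for name in devices}
for name, value in final_losses.items():
    print(f"{name}: final loss = {value:.6f}")
```

If this toy case converges on both devices, the next step would be swapping in pieces of the real model one at a time until the MPS run diverges.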