Many basic functions, such as lgamma, hyperbolic trig functions, and inverse trig functions, appear to be missing/broken. After following the instructions at https://developer.apple.com/metal/jax/ (and verifying the installation), the following example throws an error:

    import jax
    print(jax.devices())
    jax.scipy.stats.t.logpdf(x=1.5, df=3.0, loc=1.0, scale=2.0)
    # All of the following calls fail in similar ways:
    # jax.lax.lgamma(2.5)
    # jax.lax.acosh(0.5)
    # jax.lax.atanh(0.5)
    # jax.lax.asin(0.5)

The above results in the following error:

    XlaRuntimeError: UNKNOWN: /var/folders/0h/1lcyrkv11ynfjq5w7bt0sl_00000gn/T/ipykernel_4332/2323994674.py:3:0: error: failed to legalize operation 'chlo.lgamma'
    /var/folders/0h/1lcyrkv11ynfjq5w7bt0sl_00000gn/T/ipykernel_4332/2323994674.py:3:0: note: see current operation: %0 = "chlo.lgamma"(%arg0) : (tensor<f32>) -> tensor<f32>

This appears to be due to the t-distribution pdf requiring the lgamma function, which itself is broken. The lax functions commented out above fail in similar ways (though this is not a complete list).

Is there a way to work around these missing functions in the meantime?
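One possible stopgap is to rebuild the failing functions from elementary operations. The sketch below has not been verified on the Metal backend. It assumes that `log`, `log1p`, `sqrt`, and `arctan2` lower without going through the unsupported `chlo` ops, which may not hold. The names `lgamma_lanczos`, `atanh_elementary`, `acosh_elementary`, `asin_elementary`, and `t_logpdf` are hypothetical helpers, not part of JAX. The lgamma replacement uses the standard Lanczos approximation (g = 7, 9 coefficients), is only meant for x > 0, and will be less accurate than the built-in.

```python
import math

import jax.numpy as jnp

# Standard Lanczos coefficients for g = 7 (9 terms).
_LANCZOS_G = 7.0
_LANCZOS_COEFFS = (
    0.99999999999980993,
    676.5203681218851,
    -1259.1392167224028,
    771.32342877765313,
    -176.61502916214059,
    12.507343278686905,
    -0.13857109526572012,
    9.9843695780195716e-6,
    1.5056327351493116e-7,
)


def lgamma_lanczos(x):
    """Approximate log(Gamma(x)) for x > 0 using only log and arithmetic."""
    x = jnp.asarray(x, dtype=jnp.float32)
    series = _LANCZOS_COEFFS[0]
    for i, c in enumerate(_LANCZOS_COEFFS[1:], start=1):
        series = series + c / (x + i)
    t = x + _LANCZOS_G + 0.5
    # Lanczos formula gives log(Gamma(x + 1)).
    log_gamma_x_plus_1 = (
        0.5 * math.log(2.0 * math.pi) + (x + 0.5) * jnp.log(t) - t + jnp.log(series)
    )
    # Gamma(x) = Gamma(x + 1) / x, so subtract log(x).
    return log_gamma_x_plus_1 - jnp.log(x)


def atanh_elementary(x):
    """atanh(x) = 0.5 * (log(1 + x) - log(1 - x)), valid for |x| < 1."""
    x = jnp.asarray(x, dtype=jnp.float32)
    return 0.5 * (jnp.log1p(x) - jnp.log1p(-x))


def acosh_elementary(x):
    """acosh(x) = log(x + sqrt(x^2 - 1)), valid for x >= 1."""
    x = jnp.asarray(x, dtype=jnp.float32)
    return jnp.log(x + jnp.sqrt((x - 1.0) * (x + 1.0)))


def asin_elementary(x):
    """asin(x) = atan2(x, sqrt(1 - x^2)), valid for |x| <= 1."""
    x = jnp.asarray(x, dtype=jnp.float32)
    return jnp.arctan2(x, jnp.sqrt((1.0 - x) * (1.0 + x)))


def t_logpdf(x, df, loc=0.0, scale=1.0):
    """Student-t log-density built on lgamma_lanczos (assumes df > 0, scale > 0)."""
    x = jnp.asarray(x, dtype=jnp.float32)
    z = (x - loc) / scale
    return (
        lgamma_lanczos((df + 1.0) / 2.0)
        - lgamma_lanczos(df / 2.0)
        - 0.5 * jnp.log(df * math.pi)
        - 0.5 * (df + 1.0) * jnp.log1p(z * z / df)
        - jnp.log(scale)
    )


# Same arguments as the failing call in the post.
print(t_logpdf(1.5, df=3.0, loc=1.0, scale=2.0))
```

If these elementary ops also fail to legalize, another option to try is computing the affected values on the CPU device (for example via `jax.devices("cpu")`), assuming the CPU backend is still available alongside the Metal plugin.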
Posted by bsidhom.