Many basic functions, such as lgamma, hyperbolic trig functions, and inverse trig functions, appear to be missing or broken.
After following the instructions at https://developer.apple.com/metal/jax/ (and verifying the installation), the following example throws an error:
import jax
print(jax.devices())
jax.scipy.stats.t.logpdf(x=1.5, df=3.0, loc=1.0, scale=2.0)
# All of the following calls fail in similar ways:
# jax.lax.lgamma(2.5)
# jax.lax.acosh(0.5)
# jax.lax.atanh(0.5)
# jax.lax.asin(0.5)
The above results in the following error:
XlaRuntimeError: UNKNOWN: /var/folders/0h/1lcyrkv11ynfjq5w7bt0sl_00000gn/T/ipykernel_4332/2323994674.py:3:0: error: failed to legalize operation 'chlo.lgamma'
/var/folders/0h/1lcyrkv11ynfjq5w7bt0sl_00000gn/T/ipykernel_4332/2323994674.py:3:0: note: see current operation: %0 = "chlo.lgamma"(%arg0) : (tensor<f32>) -> tensor<f32>
This appears to be because the t-distribution log-pdf requires the lgamma function, which itself fails to legalize ('chlo.lgamma' in the error above). The lax functions commented out above fail in similar ways, and that list is not exhaustive.
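To make the dependency explicit, here is a minimal sketch of the Student's t log-density written out by hand from the textbook formula. The helper name t_logpdf is hypothetical and this is not taken from the JAX source. It is not a workaround either: as far as I can tell, jax.scipy.special.gammaln goes through the same lgamma operation, so on the Metal device I would expect it to hit the same legalization error.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def t_logpdf(x, df, loc=0.0, scale=1.0):
    """Student's t log-density (hypothetical helper, textbook formula)."""
    z = (x - loc) / scale
    # The two gammaln terms are the only place lgamma is needed.
    log_norm = (
        gammaln((df + 1.0) / 2.0)
        - gammaln(df / 2.0)
        - 0.5 * jnp.log(df * jnp.pi)
        - jnp.log(scale)
    )
    # Kernel of the density; uses only log1p, multiplication and division.
    return log_norm - (df + 1.0) / 2.0 * jnp.log1p(z**2 / df)


# Same arguments as the failing call above.
print(t_logpdf(1.5, 3.0, loc=1.0, scale=2.0))
```

Everything except the two gammaln terms uses only elementary operations, which is consistent with the error pointing at 'chlo.lgamma' specifically.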
Is there a way to work around these missing functions in the meantime?