Probability density/mass functions from jax.scipy.stats not supported on Metal
(Copied from https://github.com/google/jax/issues/20835)

I am attempting to use JAX on Metal (on an M1 Pro chip) to model discrete (count) data. I've installed the latest version of jax-metal (0.0.6) using pip. The installation seems to have worked overall, as I can perform basic JAX array operations on the GPU. However, when I try to compute the (log-)PMFs/PDFs of random variables that are defined in terms of the (log-)Gamma function, I get errors like the one below. They seem to indicate that lax.lgamma is not supported on Metal: the underlying 'chlo.lgamma' operation fails to legalize. This is essential functionality for a wide class of probabilistic machine learning models. Note that the following functions (among others) are broken as a result:

- jax.scipy.stats.binom.logpmf
- jax.scipy.stats.nbinom.logpmf
- jax.scipy.stats.poisson.logpmf
- jax.scipy.stats.dirichlet.logpdf
- jax.scipy.stats.beta.logpdf
- jax.scipy.stats.gamma.logpdf
- ...

For example:

    >>> jax.scipy.stats.binom.logpmf(1, n=2, p=0.5)
    jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

    The above exception was the direct cause of the following exception:

    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/stats/binom.py", line 31, in logpmf
        gammaln(n + 1),
      File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/special.py", line 44, in gammaln
        return lax.lgamma(x)
      File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/lax/special.py", line 46, in lgamma
        return lgamma_p.bind(x)
      File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 422, in bind
        return self.bind_with_trace(find_top_trace(args), args, params)
      File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 425, in bind_with_trace
        out = trace.process_primitive(self, map(trace.full_raise, args), params)
      File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
        return primitive.impl(*tracers, **params)
      File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
        outs = fun(*args)
    jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'chlo.lgamma'
    <stdin>:1:0: note: see current operation: %0 = "chlo.lgamma"(%arg0) : (tensor<f32>) -> tensor<f32>

System info (python version, jaxlib version, accelerator, etc.):

    jax:    0.4.26
    jaxlib: 0.4.23
    numpy:  1.26.4
    python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]
    jax.devices (1 total, 1 local): [METAL(id=0)]
    process_count: 1
    platform: uname_result(system='Darwin', node='PHS027794', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000', machine='arm64')
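Until lgamma is lowered by the Metal plugin, one possible stopgap is to evaluate the affected log-PMFs/PDFs on the CPU backend instead. The sketch below assumes that jaxlib's CPU backend is still registered alongside jax-metal, so that `jax.devices("cpu")` returns a device; that has not been verified here for jax-metal 0.0.6. It uses the standard `jax.default_device` context manager; `binom_logpmf_on_cpu` is a hypothetical helper name.

```python
import jax
import jax.scipy.stats as stats

# Assumption: the CPU backend is available next to the Metal plugin.
cpu = jax.devices("cpu")[0]


def binom_logpmf_on_cpu(k, n, p):
    """Hypothetical helper: evaluate binom.logpmf with inputs placed on the CPU.

    Inside the context manager, uncommitted inputs (Python scalars here) land
    on the CPU device, so gammaln/lgamma is compiled for the CPU backend
    rather than for Metal.
    """
    with jax.default_device(cpu):
        return stats.binom.logpmf(k, n=n, p=p)


result = binom_logpmf_on_cpu(1, 2, 0.5)  # should be close to log(0.5)

# If later computations expect the value on the default device (e.g. Metal),
# copy it back explicitly.
result_default = jax.device_put(result, jax.devices()[0])
```

This keeps the rest of a model on the GPU and only moves the log-density evaluation, at the cost of host/device transfers.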
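Another possible workaround is to avoid `lax.lgamma` entirely and compute log Γ from plain elementwise operations (log, division, addition, `where`), which may already lower on Metal; that is an untested assumption for jax-metal 0.0.6. The sketch below swaps in the widely published Lanczos approximation (g = 7, nine coefficients), valid for x > 0 in float32, and uses it for Poisson and binomial log-PMFs. `lanczos_lgamma`, `poisson_logpmf` and `binom_logpmf` are hypothetical helpers, not part of `jax.scipy`.

```python
import jax.numpy as jnp

# Lanczos coefficients for g = 7, n = 9 (the commonly published set).
_LANCZOS_G = 7.0
_LANCZOS_COEFS = (
    0.99999999999980993,
    676.5203681218851,
    -1259.1392167224028,
    771.32342877765313,
    -176.61502916214059,
    12.507343278686905,
    -0.13857109526572012,
    9.9843695780195716e-6,
    1.5056327351493116e-7,
)
_HALF_LOG_2PI = 0.9189385332046727  # 0.5 * log(2 * pi)


def lanczos_lgamma(x):
    """Hypothetical lax.lgamma replacement, valid for x > 0 only.

    Built from log, division, multiplication and where, so it does not emit
    a chlo.lgamma op.
    """
    x = jnp.asarray(x, dtype=jnp.float32)
    small = x < 0.5
    # Lanczos is accurate for x >= 0.5; shift smaller inputs up by one and
    # undo the shift with lgamma(x) = lgamma(x + 1) - log(x).
    xs = jnp.where(small, x + 1.0, x)
    z = xs - 1.0
    acc = jnp.full_like(z, _LANCZOS_COEFS[0])
    for i, c in enumerate(_LANCZOS_COEFS[1:], start=1):
        acc = acc + c / (z + i)
    t = z + _LANCZOS_G + 0.5
    log_gamma = _HALF_LOG_2PI + (z + 0.5) * jnp.log(t) - t + jnp.log(acc)
    return jnp.where(small, log_gamma - jnp.log(x), log_gamma)


def poisson_logpmf(k, rate):
    """log P(K = k) for K ~ Poisson(rate); assumes integer k >= 0, rate > 0."""
    k = jnp.asarray(k, dtype=jnp.float32)
    rate = jnp.asarray(rate, dtype=jnp.float32)
    return k * jnp.log(rate) - rate - lanczos_lgamma(k + 1.0)


def binom_logpmf(k, n, p):
    """log P(K = k) for K ~ Binomial(n, p); assumes 0 <= k <= n, 0 < p < 1."""
    k = jnp.asarray(k, dtype=jnp.float32)
    n = jnp.asarray(n, dtype=jnp.float32)
    p = jnp.asarray(p, dtype=jnp.float32)
    log_coef = (
        lanczos_lgamma(n + 1.0)
        - lanczos_lgamma(k + 1.0)
        - lanczos_lgamma(n - k + 1.0)
    )
    return log_coef + k * jnp.log(p) + (n - k) * jnp.log(1.0 - p)


# Same call as in the report above; should be close to log(0.5).
value = binom_logpmf(1, 2, 0.5)
```

Unlike `jax.scipy.stats`, these helpers skip support checks (negative k, p outside (0, 1)), so invalid inputs return NaN or garbage rather than -inf.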
Replies: 0 | Boosts: 0 | Views: 601 | Activity: Apr ’24