Jax-Metal error: failed to legalize operation of mhlo.fft
Hi, I just got an Apple M3 Pro and have been trying it out on some JAX operations. I see that development is actively ongoing, so maybe this error report helps.

This is the environment:

Metal device set to: Apple M3 Pro
systemMemory: 18.00 GB
maxCacheSize: 6.00 GB
jax: 0.4.26
jaxlib: 0.4.23
numpy: 1.26.4
python: 3.11.8 | packaged by conda-forge | (main, Feb 16 2024, 20:49:36) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='MKFL96VR9YT', release='23.4.0', version='Darwin Kernel Version 23.4.0: Wed Feb 21 21:44:54 PST 2024; root:xnu-10063.101.15~2/RELEASE_ARM64_T6030', machine='arm64')

This is a minimal example that produces the error. I think the FFT is the cause:

from jax import numpy as np

array = np.ones((16, 16))
np.fft.fft2(array)

This is the full traceback:

Traceback (most recent call last):
  File "/Users/user/Downloads/wow.py", line 5, in <module>
    np.fft.fft2(array)
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/numpy/fft.py", line 216, in fft2
    return _fft_core_2d('fft2', xla_client.FftType.FFT, a, s=s, axes=axes,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/numpy/fft.py", line 210, in _fft_core_2d
    return _fft_core(func_name, fft_type, a, s, axes, norm)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/numpy/fft.py", line 102, in _fft_core
    transformed = lax.fft(arr, fft_type, tuple(s))
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 298, in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
                                                                ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 176, in _python_pjit_helper
    out_flat = pjit_p.bind(*args_flat, **params)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/core.py", line 2788, in bind
    return self.bind_with_trace(top_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/core.py", line 425, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1494, in _pjit_call_impl
    return xc._xla.pjit(name, f, call_impl_cache_miss, [], [], donated_argnums, # type: ignore
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1471, in call_impl_cache_miss
    out_flat, compiled = _pjit_call_impl_python(
                         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/pjit.py", line 1406, in _pjit_call_impl_python
    lowering_parameters=mlir.LoweringParameters()).compile()
                                                   ^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2369, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2908, in from_hlo
    xla_executable, compile_options = _cached_compilation(
                                      ^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2718, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/compiler.py", line 266, in compile_or_get_cached
    return backend_compile(backend, computation, compile_options,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/anaconda3/envs/jaxmetal/lib/python3.11/site-packages/jax/_src/compiler.py", line 238, in backend_compile
    return backend.compile(built_c, compile_options=options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: 'func.func' op One or more function input/output data types are not supported.
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}], function_type = (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<16x16xf32>):
  %0 = "mhlo.convert"(%arg0) : (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>
  %1 = "mhlo.fft"(%0) {fft_length = dense<16> : tensor<2xi64>, fft_type = #mhlo<fft_type FFT>} : (tensor<16x16xcomplex<f32>>) -> tensor<16x16xcomplex<f32>>
  "func.return"(%1) : (tensor<16x16xcomplex<f32>>) -> ()
}) : () -> ()
<unknown>:0: error: failed to legalize operation 'func.func'
<unknown>:0: note: see current operation:
"func.func"() <{arg_attrs = [{mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}], function_type = (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>, res_attrs = [{jax.result_info = "", mhlo.layout_mode = "default"}], sym_name = "main", sym_visibility = "public"}> ({
^bb0(%arg0: tensor<16x16xf32>):
  %0 = "mhlo.convert"(%arg0) : (tensor<16x16xf32>) -> tensor<16x16xcomplex<f32>>
  %1 = "mhlo.fft"(%0) {fft_length = dense<16> : tensor<2xi64>, fft_type = #mhlo<fft_type FFT>} : (tensor<16x16xcomplex<f32>>) -> tensor<16x16xcomplex<f32>>
  "func.return"(%1) : (tensor<16x16xcomplex<f32>>) -> ()
}) : () -> ()

I'd be happy to run more tests if you need them. I'm new to this, so I'm not sure which ones yet. Many thanks!
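If more tests would help, one way to narrow this down is a small probe script. The error text complains that the function's input/output data types are not supported, and the compiled function returns complex64 (`tensor<16x16xcomplex<f32>>`), so one plausible hypothesis, not confirmed here, is that complex dtypes rather than `mhlo.fft` itself are what the Metal backend rejects. The sketch below prints the environment with `jax.print_environment_info()` (the environment dump above appears to be in that format), then tries a float32-only baseline, a complex cast with no FFT, a real-input `rfft2`, and the original `fft2` on the default device, reporting which ones fail. As a possible stopgap it then runs `fft2` on the CPU backend via `jax.default_device`, assuming the CPU backend is still registered alongside the Metal plugin. The `probe` helper and the probe names are hypothetical, made up for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Prints jax/jaxlib/numpy/python versions, visible devices and platform info.
jax.print_environment_info()


def probe(name, fn):
    """Hypothetical helper: run fn() on the default device, return True on success."""
    try:
        out = fn()
        out.block_until_ready()  # force compilation and execution to finish
        print(f"{name}: OK, dtype={out.dtype}")
        return True
    except Exception as err:  # e.g. an XlaRuntimeError raised at compile time
        print(f"{name}: FAILED with {type(err).__name__}")
        return False


x = jnp.ones((16, 16), dtype=jnp.float32)

# Each probe isolates one ingredient of the failing fft2 call.
probes = {
    "float32 baseline": lambda: x * 2.0,
    "complex cast only": lambda: x.astype(jnp.complex64),
    "rfft2 (real input)": lambda: jnp.fft.rfft2(x),
    "fft2 (original repro)": lambda: jnp.fft.fft2(x),
}
results = {name: probe(name, fn) for name, fn in probes.items()}

# Possible workaround: create the input and run the FFT on the CPU backend.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    y_cpu = jnp.fft.fft2(jnp.ones((16, 16), dtype=jnp.float32))

# Sanity check the CPU result against NumPy's reference FFT.
ref = np.fft.fft2(np.ones((16, 16)))
print("CPU fft2 matches NumPy:", np.allclose(np.asarray(y_cpu), ref, atol=1e-5))
```

On a Metal machine, the pattern of OK/FAILED lines is the useful part to report back: if "complex cast only" also fails, the problem is likely complex dtype support in general rather than the FFT lowering. The CPU fallback moves that one computation off the GPU, so it is a workaround, not a fix.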
Apr ’24