Unsupported type in JAX metal PJRT plugin with rng_bit_generator
Hi all,

When executing an HLO program with the JAX Metal PJRT plugin, the program fails because the rng_bit_generator operation returns an unsupported data type. The generated HLO includes:

    %output_state, %output = "mhlo.rng_bit_generator"(%1) <{rng_algorithm = #mhlo.rng_algorithm<PHILOX>}> : (tensor<3xi64>) -> (tensor<3xi64>, tensor<3xui32>)

The error message says that Metal only supports MPSDataTypeFloat16, MPSDataTypeBFloat16, MPSDataTypeFloat32, MPSDataTypeInt32, and MPSDataTypeInt64. The ui32 output therefore appears to be incompatible with the types Metal allows.

I’m trying to understand whether the ui32 output is the problem or whether I’m using rng_bit_generator incorrectly. Could you clarify whether there is a workaround, or planned support for ui32 output in this context? Alternatively, guidance on configuring rng_bit_generator so that it is compatible with Metal’s supported types would be greatly appreciated.
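If the HLO is produced by JAX itself, one thing worth checking is which PRNG implementation is in use: JAX's "rbg" and "unsafe_rbg" implementations draw their bits from rng_bit_generator, while the default "threefry2x32" implementation does not. The sketch below compares the two by lowering a small sampling function on CPU and searching the StableHLO text. The helper uses_rng_bit_generator and the function sample are hypothetical names for illustration. This does not confirm that the Metal plugin handles the threefry path, and the tensor<3xi64> state in the HLO above may point to a frontend other than JAX, in which case this check does not apply.

```python
import jax
import jax.numpy as jnp


def uses_rng_bit_generator(fn, *args):
    """Hypothetical helper: True if jit(fn) lowers to StableHLO containing rng_bit_generator."""
    lowered_text = jax.jit(fn).lower(*args).as_text()
    return "rng_bit_generator" in lowered_text


def sample(key):
    # Draw float32 uniforms; where the raw bits come from depends on the key's PRNG implementation.
    return jax.random.uniform(key, (3,), dtype=jnp.float32)


# Typed keys with an explicit implementation (assumes a JAX version that provides jax.random.key).
threefry_key = jax.random.key(0, impl="threefry2x32")
rbg_key = jax.random.key(0, impl="rbg")

if __name__ == "__main__":
    print("threefry2x32 uses rng_bit_generator:", uses_rng_bit_generator(sample, threefry_key))
    print("rbg uses rng_bit_generator:", uses_rng_bit_generator(sample, rbg_key))
```

If the check shows rng_bit_generator coming from a non-default PRNG setting, switching back with jax.config.update("jax_default_prng_impl", "threefry2x32") is one option to try, with the caveat that threefry also performs its arithmetic on uint32 values internally.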