mlx 0.30.7.3 → 0.30.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (590) hide show
  1. checksums.yaml +4 -4
  2. data/ext/mlx/extconf.rb +267 -8
  3. data/ext/mlx/native.cpp +104 -56
  4. data/ext/mlx-onnx/native.cpp +1402 -0
  5. data/ext/mlx-onnx/native.hpp +19 -0
  6. data/lib/mlx/core.rb +342 -117
  7. data/lib/mlx/nn/base.rb +4 -0
  8. data/lib/mlx/nn/layers/linear.rb +2 -3
  9. data/lib/mlx/onnx.rb +250 -0
  10. data/lib/mlx/version.rb +1 -1
  11. data/lib/mlx-onnx/webgpu_harness.rb +289 -0
  12. data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
  13. data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
  14. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
  15. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
  16. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
  17. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
  18. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
  19. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
  20. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
  21. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
  22. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
  23. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
  24. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
  25. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
  26. data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
  27. data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
  28. data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
  29. data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
  30. data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
  31. data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
  32. data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
  33. data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
  34. data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
  35. data/submodules/mlx-onnx/CMakeLists.txt +159 -0
  36. data/submodules/mlx-onnx/LICENSE +21 -0
  37. data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
  38. data/submodules/mlx-onnx/src/api.cpp +81 -0
  39. data/submodules/mlx-onnx/src/compat.cpp +111 -0
  40. data/submodules/mlx-onnx/src/detail.hpp +69 -0
  41. data/submodules/mlx-onnx/src/export.cpp +653 -0
  42. data/submodules/mlx-onnx/src/io.cpp +61 -0
  43. data/submodules/mlx-onnx/src/json.hpp +25 -0
  44. data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
  45. data/submodules/mlx-onnx/src/mappings.cpp +201 -0
  46. data/submodules/mlx-onnx/src/mappings.hpp +16 -0
  47. data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
  48. data/submodules/mlx-onnx/src/shared.cpp +206 -0
  49. metadata +609 -563
  50. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
  51. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
  52. /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
  53. /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
  54. /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
  55. /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
  56. /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
  57. /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
  58. /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
  59. /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
  60. /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
  61. /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
  62. /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
  63. /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
  64. /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
  65. /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
  66. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
  67. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
  68. /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
  69. /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
  70. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
  71. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
  72. /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
  73. /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
  74. /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
  75. /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
  76. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
  77. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
  78. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
  79. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
  80. /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
  81. /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
  82. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
  83. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
  84. /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
  85. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
  86. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
  87. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
  88. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
  89. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
  90. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
  91. /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
  92. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
  93. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
  94. /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
  95. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
  96. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
  97. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
  98. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
  99. /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
  100. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
  101. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
  102. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
  103. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
  104. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
  105. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
  106. /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
  107. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
  108. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
  109. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
  110. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
  111. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
  112. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
  113. /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
  114. /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
  115. /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
  116. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
  117. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
  118. /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
  119. /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
  120. /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
  121. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
  122. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
  123. /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
  124. /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
  125. /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
  126. /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
  127. /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
  128. /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
  129. /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
  130. /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
  131. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
  132. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
  133. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
  134. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
  135. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
  136. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
  137. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
  138. /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
  139. /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
  140. /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
  141. /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
  142. /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
  143. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
  144. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
  145. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
  146. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
  147. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
  148. /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
  149. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
  150. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
  151. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
  152. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
  153. /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
  154. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
  155. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
  156. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
  157. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
  158. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
  159. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
  160. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
  161. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
  162. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
  163. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
  164. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
  165. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
  166. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
  167. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
  168. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
  169. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
  170. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
  171. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
  172. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
  173. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
  174. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
  175. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
  176. /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
  177. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
  178. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
  179. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
  180. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
  181. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
  182. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
  183. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
  184. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
  185. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
  186. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
  187. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
  188. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
  189. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
  190. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
  191. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
  192. /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
  193. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
  194. /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
  195. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
  196. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
  197. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
  198. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
  199. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
  200. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
  201. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
  202. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
  203. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
  204. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
  205. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
  206. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
  207. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
  208. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
  209. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
  210. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
  211. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
  212. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
  213. /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
  214. /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
  215. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
  216. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
  217. /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
  218. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
  219. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
  220. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
  221. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
  222. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
  223. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
  224. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
  225. /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
  226. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
  227. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
  228. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
  229. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
  230. /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
  231. /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
  232. /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
  233. /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
  234. /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
  235. /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
  236. /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
  237. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
  238. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
  239. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
  240. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
  241. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
  242. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
  243. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
  244. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
  245. /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
  246. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
  247. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
  248. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
  249. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
  250. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
  251. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
  252. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
  253. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
  254. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
  255. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
  256. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
  257. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
  258. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
  259. /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
  260. /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
  261. /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
  262. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
  263. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
  264. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
  265. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
  266. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
  267. /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
  268. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
  269. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
  270. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
  271. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
  272. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
  273. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
  274. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
  275. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
  276. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
  277. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
  278. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
  279. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
  280. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
  281. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
  282. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
  283. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
  284. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
  285. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
  286. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
  287. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
  288. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
  289. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
  290. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
  291. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
  292. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
  293. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
  294. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
  295. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
  296. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
  297. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
  298. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
  299. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
  300. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
  301. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
  302. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
  303. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
  304. /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
  305. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
  306. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
  307. /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
  308. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
  309. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
  310. /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
  311. /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
  312. /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
  313. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
  314. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
  315. /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
  316. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
  317. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
  318. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
  319. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
  320. /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
  321. /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
  322. /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
  323. /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
  324. /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
  325. /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
  326. /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
  327. /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
  328. /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
  329. /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
  330. /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
  331. /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
  332. /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
  333. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
  334. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
  335. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
  336. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
  337. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
  338. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
  339. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
  340. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
  341. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
  342. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
  343. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
  344. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
  345. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
  346. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
  347. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
  348. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
  349. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
  350. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
  351. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
  352. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
  353. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
  354. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
  355. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
  356. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
  357. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
  358. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
  359. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
  360. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
  361. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
  362. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
  363. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
  364. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
  365. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
  366. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
  367. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
  368. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
  369. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
  370. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
  371. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
  372. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
  373. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
  374. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
  375. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
  376. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
  377. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
  378. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
  379. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
  380. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
  381. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
  382. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
  383. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
  384. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
  385. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
  386. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
  387. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
  388. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
  389. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
  390. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
  391. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
  392. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
  393. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
  394. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
  395. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
  396. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
  397. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
  398. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
  399. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
  400. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
  401. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
  402. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
  403. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
  404. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
  405. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
  406. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
  407. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
  408. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
  409. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
  410. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
  411. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
  412. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
  413. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
  414. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
  415. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
  416. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
  417. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
  418. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
  419. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
  420. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
  421. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
  422. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
  423. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
  424. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
  425. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
  426. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
  427. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
  428. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
  429. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
  430. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
  431. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
  432. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
  433. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
  434. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
  435. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
  436. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
  437. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
  438. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
  439. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
  440. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
  441. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
  442. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
  443. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
  444. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
  445. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
  446. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
  447. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
  448. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
  449. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
  450. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
  451. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
  452. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
  453. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
  454. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
  455. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
  456. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
  457. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
  458. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
  459. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
  460. /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
  461. /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
  462. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
  463. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
  464. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
  465. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
  466. /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
  467. /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
  468. /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
  469. /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
  470. /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
  471. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
  472. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
  473. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
  474. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
  475. /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
  476. /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
  477. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
  478. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
  479. /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
  480. /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
  481. /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
  482. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
  483. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
  484. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
  485. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
  486. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
  487. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
  488. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
  489. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
  490. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
  491. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
  492. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
  493. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
  494. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
  495. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
  496. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
  497. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
  498. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
  499. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
  500. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
  501. /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
  502. /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
  503. /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
  504. /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
  505. /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
  506. /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
  507. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
  508. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
  509. /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
  510. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
  511. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
  512. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
  513. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
  514. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
  515. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
  516. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
  517. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
  518. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
  519. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
  520. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
  521. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
  522. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
  523. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
  524. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
  525. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
  526. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
  527. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
  528. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
  529. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
  530. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
  531. /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
  532. /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
  533. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
  534. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
  535. /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
  536. /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
  537. /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
  538. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
  539. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
  540. /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
  541. /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
  542. /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
  543. /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
  544. /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
  545. /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
  546. /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
  547. /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
  548. /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
  549. /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
  550. /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
  551. /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
  552. /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
  553. /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
  554. /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
  555. /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
  556. /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
  557. /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
  558. /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
  559. /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
  560. /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
  561. /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
  562. /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
  563. /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
  564. /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
  565. /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
  566. /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
  567. /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
  568. /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
  569. /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
  570. /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
  571. /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
  572. /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
  573. /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
  574. /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
  575. /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
  576. /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
  577. /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
  578. /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
  579. /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
  580. /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
  581. /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
  582. /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
  583. /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
  584. /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
  585. /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
  586. /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
  587. /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
  588. /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
  589. /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
  590. /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
@@ -10,6 +10,7 @@
10
10
  #include <sstream>
11
11
 
12
12
  #include "mlx/backend/cuda/cuda.h"
13
+ #include "mlx/backend/metal/metal.h"
13
14
  #include "mlx/fast_primitives.h"
14
15
  #include "mlx/ops.h"
15
16
  #include "mlx/primitives.h"
@@ -2311,6 +2312,40 @@ array argmax(
2311
2312
  return out;
2312
2313
  }
2313
2314
 
2315
+ array hanning(int M, StreamOrDevice s /* = {} */) {
2316
+ if (M < 1) {
2317
+ return array({});
2318
+ }
2319
+ if (M == 1) {
2320
+ return ones({1}, float32, s);
2321
+ }
2322
+
2323
+ auto n = arange(0, M, float32, s);
2324
+ array factor(M_PI / (M - 1), float32);
2325
+ return square(sin(multiply(factor, n, s), s), s);
2326
+ }
2327
+
2328
+ array hamming(int M, StreamOrDevice s /* = {} */) {
2329
+ if (M < 1) {
2330
+ return array({});
2331
+ }
2332
+ if (M == 1) {
2333
+ return ones({1}, float32, s);
2334
+ }
2335
+
2336
+ auto n = arange(0, M, float32, s);
2337
+ float factor_val = (2.0 * M_PI) / (M - 1);
2338
+ auto factor = array(factor_val, float32);
2339
+
2340
+ auto arg = multiply(factor, n, s);
2341
+ auto cos_vals = cos(arg, s);
2342
+
2343
+ auto left_coef = array(0.54f, float32);
2344
+ auto right_coef = array(0.46f, float32);
2345
+
2346
+ return subtract(left_coef, multiply(right_coef, cos_vals, s), s);
2347
+ }
2348
+
2314
2349
  /** Returns a sorted copy of the flattened array. */
2315
2350
  array sort(const array& a, StreamOrDevice s /* = {} */) {
2316
2351
  int size = a.size();
@@ -4209,6 +4244,34 @@ std::pair<Dtype, QuantizationMode> validate_mode_with_type(
4209
4244
  }
4210
4245
  }
4211
4246
 
4247
+ void validate_global_scale(
4248
+ std::string_view tag,
4249
+ QuantizationMode qmode,
4250
+ const std::optional<array>& global_scale) {
4251
+ if (global_scale.has_value()) {
4252
+ if (qmode != QuantizationMode::Nvfp4) {
4253
+ std::ostringstream msg;
4254
+ msg << "[" << tag << "] Global scale is only supported for 'nvfp4' "
4255
+ << "quantization mode.";
4256
+ throw std::invalid_argument(msg.str());
4257
+ } else {
4258
+ if (global_scale->size() != 1) {
4259
+ std::ostringstream msg;
4260
+ msg << "[" << tag << "] Global scale must be a scalar but got shape "
4261
+ << global_scale->shape() << ".";
4262
+ throw std::invalid_argument(msg.str());
4263
+ }
4264
+ // TODO: not sure if type should be restricted to float32
4265
+ if (global_scale->dtype() != float32) {
4266
+ std::ostringstream msg;
4267
+ msg << "[" << tag << "] Global scale must have dtype float32 but got "
4268
+ << global_scale->dtype() << ".";
4269
+ throw std::invalid_argument(msg.str());
4270
+ }
4271
+ }
4272
+ }
4273
+ }
4274
+
4212
4275
  array quantized_matmul(
4213
4276
  array x,
4214
4277
  array w,
@@ -4251,7 +4314,6 @@ array quantized_matmul(
4251
4314
  if (x.ndim() > 2 && w.ndim() > 2) {
4252
4315
  inputs = broadcast_arrays(inputs, {-2, -1}, s);
4253
4316
  }
4254
-
4255
4317
  auto out_shape = inputs[0].shape();
4256
4318
  out_shape.back() = w_outer_dims;
4257
4319
  return array(
@@ -4267,7 +4329,10 @@ void validate_qqmm_inputs(
4267
4329
  array w,
4268
4330
  std::optional<array> scales_w,
4269
4331
  int group_size,
4270
- int bits) {
4332
+ int bits,
4333
+ std::optional<array> global_scale_x,
4334
+ std::optional<array> global_scale_w,
4335
+ QuantizationMode qmode) {
4271
4336
  // check 2D (for now)
4272
4337
  if (x.ndim() > 2 || w.ndim() > 2) {
4273
4338
  std::ostringstream msg;
@@ -4304,6 +4369,19 @@ void validate_qqmm_inputs(
4304
4369
  << "first argument dtype == " << x.dtype() << ".";
4305
4370
  throw std::invalid_argument(msg.str());
4306
4371
  }
4372
+ // validate global scales
4373
+ validate_global_scale("qqmm", qmode, global_scale_x);
4374
+ validate_global_scale("qqmm", qmode, global_scale_w);
4375
+ // For nvfp4 mode, both global scales must be provided together or neither
4376
+ if (qmode == QuantizationMode::Nvfp4) {
4377
+ bool has_x = global_scale_x.has_value();
4378
+ bool has_w = global_scale_w.has_value();
4379
+ if (has_x != has_w) {
4380
+ throw std::invalid_argument(
4381
+ "[qqmm] For nvfp4 mode, either both global_scale_x and "
4382
+ "global_scale_w must be provided, or neither.");
4383
+ }
4384
+ }
4307
4385
  }
4308
4386
 
4309
4387
  std::pair<int, int> extract_qqmm_dims(
@@ -4343,6 +4421,8 @@ array qqmm(
4343
4421
  std::optional<int> group_size_ /* = std::nullopt */,
4344
4422
  std::optional<int> bits_ /* = std::nullopt */,
4345
4423
  const std::string& mode /* = "nvfp4" */,
4424
+ const std::optional<array> global_scale_x /* = std::nullopt */,
4425
+ const std::optional<array> global_scale_w /* = std::nullopt */,
4346
4426
  StreamOrDevice s /* = {} */) {
4347
4427
  auto stream = to_stream(s);
4348
4428
  auto qmode = string_to_quantization_mode(mode, "qqmm");
@@ -4369,7 +4449,8 @@ array qqmm(
4369
4449
  }
4370
4450
 
4371
4451
  // validate inputs
4372
- validate_qqmm_inputs(x, w, scales_w, group_size, bits);
4452
+ validate_qqmm_inputs(
4453
+ x, w, scales_w, group_size, bits, global_scale_x, global_scale_w, qmode);
4373
4454
  // validate and extract shapes
4374
4455
  auto [w_inner_dims, w_outer_dims] =
4375
4456
  extract_qqmm_dims(x, w, scales_w, group_size, bits);
@@ -4380,6 +4461,11 @@ array qqmm(
4380
4461
  if (scales_w.has_value()) {
4381
4462
  inputs.push_back(*scales_w);
4382
4463
  }
4464
+ if (global_scale_x.has_value() && global_scale_w.has_value()) {
4465
+ inputs.push_back(*global_scale_x);
4466
+ inputs.push_back(*global_scale_w);
4467
+ }
4468
+
4383
4469
  auto out_shape = inputs[0].shape();
4384
4470
  out_shape.back() = w_outer_dims;
4385
4471
  auto out = array(
@@ -4515,6 +4601,7 @@ std::vector<array> fp_quantize(
4515
4601
  int group_size,
4516
4602
  int bits,
4517
4603
  QuantizationMode mode,
4604
+ const std::optional<array>& global_scale /* = std::nullopt */,
4518
4605
  Stream s) {
4519
4606
  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
4520
4607
  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
@@ -4532,6 +4619,12 @@ std::vector<array> fp_quantize(
4532
4619
  << bits << ".";
4533
4620
  throw std::invalid_argument(msg.str());
4534
4621
  }
4622
+
4623
+ auto inputs = std::vector<array>{w};
4624
+ if (global_scale.has_value()) {
4625
+ inputs.push_back(global_scale.value());
4626
+ }
4627
+
4535
4628
  auto fallback = [bits = bits, group_size = group_size, s](
4536
4629
  const std::vector<array>& inputs) -> std::vector<array> {
4537
4630
  auto& w = inputs[0];
@@ -4543,8 +4636,13 @@ std::vector<array> fp_quantize(
4543
4636
  divide(max(abs(wq, s), -1, true, s), array(maxval, w.dtype()), s);
4544
4637
  if (group_size == 16) {
4545
4638
  // convert to e4m3
4639
+ auto scale_encode = inputs.size() > 1
4640
+ ? divide(array(448.0f * 6.0f, float32), inputs[1], s)
4641
+ : array(1.0f, float32);
4642
+ scales = multiply(scales, scale_encode, s);
4546
4643
  scales = to_fp8(scales, s);
4547
- wq = divide(wq, from_fp8(scales, w.dtype(), s), s);
4644
+ wq = multiply(
4645
+ divide(wq, from_fp8(scales, w.dtype(), s), s), scale_encode, s);
4548
4646
  } else {
4549
4647
  // convert to e8m0
4550
4648
  auto z = array(0, scales.dtype());
@@ -4600,9 +4698,9 @@ std::vector<array> fp_quantize(
4600
4698
  {uint32, uint8},
4601
4699
  std::make_shared<fast::Quantize>(
4602
4700
  s, fallback, group_size, bits, mode, false),
4603
- {w});
4701
+ inputs);
4604
4702
  }
4605
- return fallback({w});
4703
+ return fallback(inputs);
4606
4704
  }
4607
4705
 
4608
4706
  std::vector<array> quantize(
@@ -4610,6 +4708,7 @@ std::vector<array> quantize(
4610
4708
  std::optional<int> group_size_ /* = std::nullopt */,
4611
4709
  std::optional<int> bits_ /* = std::nullopt */,
4612
4710
  const std::string& mode /* = "affine" */,
4711
+ const std::optional<array>& global_scale /* = std::nullopt */,
4613
4712
  StreamOrDevice s /* = {} */) {
4614
4713
  auto qmode = string_to_quantization_mode(mode, "quantize");
4615
4714
  auto [group_size, bits] =
@@ -4636,11 +4735,17 @@ std::vector<array> quantize(
4636
4735
  << " matrix has shape " << w.shape();
4637
4736
  throw std::invalid_argument(msg.str());
4638
4737
  }
4639
-
4738
+ if (to_stream(s).device == Device::gpu && metal::is_available() &&
4739
+ global_scale.has_value()) {
4740
+ std::ostringstream msg;
4741
+ msg << "[quantize] Global scale is not supported on the Metal backend.";
4742
+ throw std::invalid_argument(msg.str());
4743
+ }
4744
+ validate_global_scale("quantize", qmode, global_scale);
4640
4745
  if (qmode == QuantizationMode::Affine) {
4641
4746
  return affine_quantize(w, group_size, bits, s);
4642
4747
  } else {
4643
- return fp_quantize(w, group_size, bits, qmode, to_stream(s));
4748
+ return fp_quantize(w, group_size, bits, qmode, global_scale, to_stream(s));
4644
4749
  }
4645
4750
  }
4646
4751
 
@@ -4745,6 +4850,7 @@ array fp_dequantize(
4745
4850
  int bits,
4746
4851
  Dtype out_type,
4747
4852
  QuantizationMode mode,
4853
+ const std::optional<array>& global_scale /* = std::nullopt */,
4748
4854
  Stream s) {
4749
4855
  int expected_gs = mode == QuantizationMode::Nvfp4 ? 16 : 32;
4750
4856
  int expected_bits = mode == QuantizationMode::Mxfp8 ? 8 : 4;
@@ -4789,6 +4895,11 @@ array fp_dequantize(
4789
4895
  throw std::invalid_argument(msg.str());
4790
4896
  }
4791
4897
 
4898
+ auto inputs = std::vector<array>{w, scales};
4899
+ if (global_scale.has_value()) {
4900
+ inputs.push_back(global_scale.value());
4901
+ }
4902
+
4792
4903
  auto fallback =
4793
4904
  [wshape = std::move(wshape),
4794
4905
  sshape = std::move(sshape),
@@ -4831,13 +4942,17 @@ array fp_dequantize(
4831
4942
  out = reshape(out, {-1, group_size}, s);
4832
4943
  scales = reshape(scales, {-1, 1}, s);
4833
4944
  if (group_size == 16) {
4834
- scales = from_fp8(scales, out_type, s);
4945
+ array inv_scale_enc = inputs.size() > 2
4946
+ ? divide(inputs[2], array(448.0f * 6.0f, out_type), s)
4947
+ : array(1.0f, out_type);
4948
+ scales = multiply(from_fp8(scales, out_type, s), inv_scale_enc, s);
4835
4949
  } else {
4836
4950
  scales = subtract(astype(scales, out_type, s), array(127, out_type), s);
4837
4951
  scales = power(array(2.0f, out_type), scales, s);
4838
4952
  }
4839
4953
  return {reshape(multiply(out, scales, s), wshape, s)};
4840
4954
  };
4955
+
4841
4956
  if (s.device == Device::gpu) {
4842
4957
  auto out_shape = w.shape();
4843
4958
  out_shape.back() = out_size;
@@ -4846,9 +4961,9 @@ array fp_dequantize(
4846
4961
  out_type,
4847
4962
  std::make_shared<fast::Quantize>(
4848
4963
  s, fallback, group_size, bits, mode, true),
4849
- {w, scales});
4964
+ inputs);
4850
4965
  }
4851
- return fallback({w, scales})[0];
4966
+ return fallback(inputs)[0];
4852
4967
  }
4853
4968
 
4854
4969
  array dequantize(
@@ -4858,6 +4973,7 @@ array dequantize(
4858
4973
  std::optional<int> group_size_ /* = std::nullopt */,
4859
4974
  std::optional<int> bits_ /* = std::nullopt */,
4860
4975
  const std::string& mode /* = "affine" */,
4976
+ const std::optional<array>& global_scale /* = std::nullopt */,
4861
4977
  std::optional<Dtype> dtype /* = std::nullopt */,
4862
4978
  StreamOrDevice s /* = {} */) {
4863
4979
  auto [out_type, qmode] =
@@ -4884,6 +5000,14 @@ array dequantize(
4884
5000
  << "but it has only " << w.ndim() << ".";
4885
5001
  throw std::invalid_argument(msg.str());
4886
5002
  }
5003
+ if (global_scale.has_value()) {
5004
+ if (to_stream(s).device == Device::gpu && metal::is_available()) {
5005
+ std::ostringstream msg;
5006
+ msg << "[dequantize] Global scale is not supported on the Metal backend.";
5007
+ throw std::invalid_argument(msg.str());
5008
+ }
5009
+ }
5010
+ validate_global_scale("dequantize", qmode, global_scale);
4887
5011
 
4888
5012
  if (qmode == QuantizationMode::Affine) {
4889
5013
  return astype(
@@ -4892,7 +5016,14 @@ array dequantize(
4892
5016
  s);
4893
5017
  } else {
4894
5018
  return fp_dequantize(
4895
- w, scales, group_size, bits, out_type, qmode, to_stream(s));
5019
+ w,
5020
+ scales,
5021
+ group_size,
5022
+ bits,
5023
+ out_type,
5024
+ qmode,
5025
+ global_scale,
5026
+ to_stream(s));
4896
5027
  }
4897
5028
  }
4898
5029
 
@@ -6091,4 +6222,4 @@ array contiguous(
6091
6222
  {a});
6092
6223
  }
6093
6224
 
6094
- } // namespace mlx::core
6225
+ } // namespace mlx::core
@@ -666,6 +666,12 @@ min(const array& a,
666
666
  MLX_API array
667
667
  min(const array& a, int axis, bool keepdims = false, StreamOrDevice s = {});
668
668
 
669
+ /** Returns the Hanning window of size M. */
670
+ MLX_API array hanning(int M, StreamOrDevice s = {});
671
+
672
+ /** Returns the Hamming window of size M. */
673
+ MLX_API array hamming(int M, StreamOrDevice s = {});
674
+
669
675
  /** Returns the index of the minimum value in the array. */
670
676
  MLX_API array argmin(const array& a, bool keepdims, StreamOrDevice s = {});
671
677
  inline array argmin(const array& a, StreamOrDevice s = {}) {
@@ -1391,6 +1397,7 @@ MLX_API std::vector<array> quantize(
1391
1397
  std::optional<int> group_size = std::nullopt,
1392
1398
  std::optional<int> bits = std::nullopt,
1393
1399
  const std::string& mode = "affine",
1400
+ const std::optional<array>& global_scale = std::nullopt,
1394
1401
  StreamOrDevice s = {});
1395
1402
 
1396
1403
  /** Dequantize a matrix produced by quantize() */
@@ -1401,17 +1408,20 @@ MLX_API array dequantize(
1401
1408
  std::optional<int> group_size = std::nullopt,
1402
1409
  std::optional<int> bits = std::nullopt,
1403
1410
  const std::string& mode = "affine",
1411
+ const std::optional<array>& global_scale = std::nullopt,
1404
1412
  std::optional<Dtype> dtype = std::nullopt,
1405
1413
  StreamOrDevice s = {});
1406
1414
 
1407
1415
  MLX_API array qqmm(
1408
1416
  array x, // input activations
1409
1417
  array w, // maybe quantized weights
1410
- std::optional<array> w_scales = std::nullopt, // optional scales if w is
1411
- // quantized
1418
+ const std::optional<array> w_scales = std::nullopt, // optional scales if w
1419
+ // is quantized
1412
1420
  std::optional<int> group_size = std::nullopt,
1413
1421
  std::optional<int> bits = std::nullopt,
1414
1422
  const std::string& mode = "nvfp4",
1423
+ const std::optional<array> global_scale_x = std::nullopt,
1424
+ const std::optional<array> global_scale_w = std::nullopt,
1415
1425
  StreamOrDevice s = {});
1416
1426
 
1417
1427
  /** Convert an E4M3 float8 to the given floating point dtype. */
@@ -3424,6 +3424,7 @@ std::vector<array> QuantizedMatmul::vjp(
3424
3424
  group_size_,
3425
3425
  bits_,
3426
3426
  quantization_mode_to_string(mode_),
3427
+ {}, // placeholder for amax
3427
3428
  std::nullopt,
3428
3429
  stream());
3429
3430
  wq = unflatten(wq, -1, {-1, group_size_}, stream());
@@ -3484,14 +3485,14 @@ std::vector<Shape> QQMatmul::output_shapes(const std::vector<array>& inputs) {
3484
3485
  }
3485
3486
 
3486
3487
  std::vector<array> QQMatmul::vjp(
3487
- const std::vector<array>& primals, // non quantized x, non quantized w
3488
+ const std::vector<array>& primals, // non quantized x, non quantized w, if
3489
+ // nvfp4 global_scale_x, global_scale_w
3488
3490
  const std::vector<array>& cotangents, // non quantized upstream grads
3489
3491
  const std::vector<int>& argnums,
3490
3492
  const std::vector<array>&) {
3491
- if (primals.size() != 2) {
3492
- throw std::runtime_error(
3493
- "[QQMatmul::vjp] Expected exactly 2 non-quantized primal inputs (x, w).");
3494
- }
3493
+ bool is_nvfp4 = mode_ == QuantizationMode::Nvfp4;
3494
+ assert(primals.size() == 2 || (is_nvfp4 && primals.size() == 4));
3495
+
3495
3496
  std::vector<array> vjps;
3496
3497
  auto& cotan = cotangents[0];
3497
3498
  auto& s = stream();
@@ -3499,6 +3500,15 @@ std::vector<array> QQMatmul::vjp(
3499
3500
  // primal[0] -- non quantized activations (M, K)
3500
3501
  // cotan -- non quantized grads (M, N)
3501
3502
  auto qmode = quantization_mode_to_string(mode_);
3503
+ std::optional<array> cotan_amax = (primals.size() == 4)
3504
+ ? std::make_optional(astype(max(abs(cotan, s), s), float32, s))
3505
+ : std::nullopt;
3506
+
3507
+ auto get_primal_scale = [&](int idx) {
3508
+ return (primals.size() == 4) ? std::make_optional(primals[idx])
3509
+ : std::nullopt;
3510
+ };
3511
+
3502
3512
  for (auto arg : argnums) {
3503
3513
  if (arg == 0) { // gradient wrt to x
3504
3514
  // We transpose weights -> quantize along N
@@ -3509,6 +3519,8 @@ std::vector<array> QQMatmul::vjp(
3509
3519
  group_size_,
3510
3520
  bits_,
3511
3521
  qmode,
3522
+ cotan_amax,
3523
+ get_primal_scale(3), // global_scale_w (for w.T)
3512
3524
  s));
3513
3525
  } else if (arg == 1) { // gradient wrt to weights
3514
3526
  vjps.push_back(qqmm(
@@ -3518,7 +3530,11 @@ std::vector<array> QQMatmul::vjp(
3518
3530
  group_size_,
3519
3531
  bits_,
3520
3532
  qmode,
3533
+ cotan_amax,
3534
+ get_primal_scale(2), // global_scale_x (for x.T)
3521
3535
  s));
3536
+ } else {
3537
+ vjps.push_back(zeros_like(primals[arg], s));
3522
3538
  }
3523
3539
  }
3524
3540
  return vjps;
@@ -3643,6 +3659,7 @@ std::vector<array> GatherQMM::vjp(
3643
3659
  bits_,
3644
3660
  quantization_mode_to_string(mode_),
3645
3661
  std::nullopt,
3662
+ std::nullopt, // amax placeholder
3646
3663
  stream()),
3647
3664
  -1,
3648
3665
  {-1, group_size_},
@@ -26,6 +26,10 @@ Stream get_stream(int index) {
26
26
  return scheduler::scheduler().get_stream(index);
27
27
  }
28
28
 
29
+ std::vector<Stream> get_streams() {
30
+ return scheduler::scheduler().get_streams();
31
+ }
32
+
29
33
  Stream new_stream(Device d) {
30
34
  if (!gpu::is_available() && d == Device::gpu) {
31
35
  throw std::invalid_argument(
@@ -99,6 +99,9 @@ class Scheduler {
99
99
  Stream get_stream(int index) const {
100
100
  return streams_.at(index);
101
101
  }
102
+ std::vector<Stream> get_streams() const {
103
+ return streams_;
104
+ }
102
105
 
103
106
  void set_default_stream(const Stream& s) {
104
107
  default_streams_.at(s.device.type) = s;
@@ -2,6 +2,8 @@
2
2
 
3
3
  #pragma once
4
4
 
5
+ #include <vector>
6
+
5
7
  #include "mlx/api.h"
6
8
  #include "mlx/device.h"
7
9
 
@@ -25,6 +27,9 @@ MLX_API Stream new_stream(Device d);
25
27
  /** Get the stream with the given index. */
26
28
  MLX_API Stream get_stream(int index);
27
29
 
30
+ /** Get all available streams. */
31
+ MLX_API std::vector<Stream> get_streams();
32
+
28
33
  inline bool operator==(const Stream& lhs, const Stream& rhs) {
29
34
  return lhs.index == rhs.index;
30
35
  }
@@ -0,0 +1,159 @@
1
+ cmake_minimum_required(VERSION 3.25)
2
+
3
+ project(mlx_onnx VERSION 0.30.7.1 LANGUAGES C CXX)
4
+
5
+ set(CMAKE_CXX_STANDARD 20)
6
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
7
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
8
+
9
+ option(MLX_ONNX_USE_EXTERNAL_MLX "Build against an externally provided MLX install" OFF)
10
+ option(MLX_ONNX_BUILD_PYTHON_BINDINGS "Build Python IR bindings" OFF)
11
+ option(MLX_ONNX_INSTALL_CPP_ARTIFACTS "Install C++ library and headers" ON)
12
+
13
+ include(FetchContent)
14
+
15
+ if(MLX_ONNX_USE_EXTERNAL_MLX AND MLX_ONNX_BUILD_PYTHON_BINDINGS)
16
+ message(
17
+ FATAL_ERROR
18
+ "MLX_ONNX_BUILD_PYTHON_BINDINGS requires bundled mlx sources; set MLX_ONNX_USE_EXTERNAL_MLX=OFF")
19
+ endif()
20
+
21
+ if(MLX_ONNX_USE_EXTERNAL_MLX)
22
+ set(MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR "" CACHE PATH "Path to MLX include root")
23
+ set(MLX_ONNX_EXTERNAL_MLX_LIB_DIR "" CACHE PATH "Path to MLX library directory")
24
+
25
+ if(MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR STREQUAL "")
26
+ message(FATAL_ERROR "MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR must be set when MLX_ONNX_USE_EXTERNAL_MLX=ON")
27
+ endif()
28
+ if(MLX_ONNX_EXTERNAL_MLX_LIB_DIR STREQUAL "")
29
+ message(FATAL_ERROR "MLX_ONNX_EXTERNAL_MLX_LIB_DIR must be set when MLX_ONNX_USE_EXTERNAL_MLX=ON")
30
+ endif()
31
+
32
+ find_library(
33
+ MLX_EXTERNAL_LIBRARY
34
+ NAMES mlx
35
+ PATHS ${MLX_ONNX_EXTERNAL_MLX_LIB_DIR}
36
+ NO_DEFAULT_PATH)
37
+
38
+ if(NOT MLX_EXTERNAL_LIBRARY)
39
+ message(FATAL_ERROR "Could not find libmlx in ${MLX_ONNX_EXTERNAL_MLX_LIB_DIR}")
40
+ endif()
41
+
42
+ add_library(mlx SHARED IMPORTED GLOBAL)
43
+ set_target_properties(
44
+ mlx
45
+ PROPERTIES
46
+ IMPORTED_LOCATION ${MLX_EXTERNAL_LIBRARY}
47
+ INTERFACE_INCLUDE_DIRECTORIES ${MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR})
48
+ else()
49
+ set(MLX_BUILD_TESTS OFF CACHE BOOL "" FORCE)
50
+ set(MLX_BUILD_EXAMPLES OFF CACHE BOOL "" FORCE)
51
+ set(MLX_BUILD_BENCHMARKS OFF CACHE BOOL "" FORCE)
52
+ if(MLX_ONNX_BUILD_PYTHON_BINDINGS)
53
+ set(MLX_BUILD_PYTHON_BINDINGS ON CACHE BOOL "" FORCE)
54
+ else()
55
+ set(MLX_BUILD_PYTHON_BINDINGS OFF CACHE BOOL "" FORCE)
56
+ endif()
57
+ set(MLX_BUILD_PYTHON_STUBS OFF CACHE BOOL "" FORCE)
58
+ set(MLX_BUILD_GGUF OFF CACHE BOOL "" FORCE)
59
+ set(MLX_BUILD_SAFETENSORS OFF CACHE BOOL "" FORCE)
60
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mlx)
61
+ endif()
62
+
63
+ if(NOT TARGET nlohmann_json::nlohmann_json)
64
+ FetchContent_Declare(
65
+ nlohmann_json
66
+ GIT_REPOSITORY https://github.com/nlohmann/json.git
67
+ GIT_TAG v3.11.3
68
+ EXCLUDE_FROM_ALL)
69
+ FetchContent_MakeAvailable(nlohmann_json)
70
+ endif()
71
+
72
+ add_library(
73
+ mlx_onnx
74
+ src/export.cpp
75
+ src/api.cpp
76
+ src/compat.cpp
77
+ src/io.cpp
78
+ src/lowering.cpp
79
+ src/mappings.cpp
80
+ src/onnx.cpp
81
+ src/shared.cpp)
82
+
83
+ set_target_properties(mlx_onnx PROPERTIES OUTPUT_NAME mlx_onnx)
84
+
85
+ target_include_directories(
86
+ mlx_onnx
87
+ PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
88
+ $<INSTALL_INTERFACE:include>
89
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
90
+
91
+ if(MLX_ONNX_USE_EXTERNAL_MLX)
92
+ target_include_directories(mlx_onnx PRIVATE ${MLX_ONNX_EXTERNAL_MLX_INCLUDE_DIR})
93
+ endif()
94
+
95
+ target_link_libraries(mlx_onnx PUBLIC mlx nlohmann_json::nlohmann_json)
96
+
97
+ if(MLX_ONNX_BUILD_PYTHON_BINDINGS)
98
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/python/src)
99
+ set(MLX_ONNX_PY_INIT_FILE ${CMAKE_CURRENT_SOURCE_DIR}/python/mlx_onnx/__init__.py)
100
+ if(NOT EXISTS ${MLX_ONNX_PY_INIT_FILE})
101
+ set(MLX_ONNX_PY_INIT_FILE ${CMAKE_CURRENT_BINARY_DIR}/mlx_onnx___init__.py)
102
+ file(WRITE ${MLX_ONNX_PY_INIT_FILE} "from ._core import * # noqa: F401,F403\n")
103
+ endif()
104
+ if(NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/mlx/CMakeLists.txt)
105
+ message(FATAL_ERROR "Bundled mlx sources are missing at ${CMAKE_CURRENT_SOURCE_DIR}/mlx")
106
+ endif()
107
+ if(NOT TARGET core)
108
+ message(FATAL_ERROR "Bundled mlx Python extension target `core` was not built")
109
+ endif()
110
+ install(TARGETS core LIBRARY DESTINATION mlx COMPONENT python)
111
+ if(APPLE AND MLX_BUILD_METAL)
112
+ # MLX looks for mlx.metallib next to the extension module using MLX runtime.
113
+ install(
114
+ FILES ${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib
115
+ DESTINATION mlx
116
+ COMPONENT python)
117
+ # mlx_onnx._core also links MLX and resolves the same metallib at runtime.
118
+ install(
119
+ FILES ${CMAKE_CURRENT_BINARY_DIR}/mlx/mlx/backend/metal/kernels/mlx.metallib
120
+ DESTINATION mlx_onnx
121
+ COMPONENT python)
122
+ endif()
123
+ install(
124
+ DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mlx/python/mlx/
125
+ DESTINATION mlx
126
+ COMPONENT python
127
+ PATTERN "__pycache__" EXCLUDE)
128
+ install(
129
+ FILES ${MLX_ONNX_PY_INIT_FILE}
130
+ DESTINATION mlx_onnx
131
+ RENAME __init__.py
132
+ COMPONENT python)
133
+ set(MLX_ONNX_VENDOR_MLX_ROOT mlx_onnx/_vendor/mlx)
134
+ install(
135
+ FILES ${CMAKE_CURRENT_SOURCE_DIR}/mlx/CMakeLists.txt
136
+ ${CMAKE_CURRENT_SOURCE_DIR}/mlx/mlx.pc.in
137
+ ${CMAKE_CURRENT_SOURCE_DIR}/mlx/LICENSE
138
+ ${CMAKE_CURRENT_SOURCE_DIR}/mlx/ACKNOWLEDGMENTS.md
139
+ DESTINATION ${MLX_ONNX_VENDOR_MLX_ROOT}
140
+ COMPONENT python)
141
+ install(
142
+ DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/mlx/cmake
143
+ ${CMAKE_CURRENT_SOURCE_DIR}/mlx/mlx
144
+ DESTINATION ${MLX_ONNX_VENDOR_MLX_ROOT}
145
+ COMPONENT python)
146
+ endif()
147
+
148
+ if(MLX_ONNX_INSTALL_CPP_ARTIFACTS)
149
+ include(GNUInstallDirs)
150
+ install(
151
+ TARGETS mlx_onnx
152
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
153
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
154
+ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
155
+ INCLUDES DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
156
+ COMPONENT cpp)
157
+
158
+ install(DIRECTORY include/ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR} COMPONENT cpp)
159
+ endif()
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 MLX Contributors
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.