mlx 0.30.7.3 → 0.30.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (590)
  1. checksums.yaml +4 -4
  2. data/ext/mlx/extconf.rb +267 -8
  3. data/ext/mlx/native.cpp +104 -56
  4. data/ext/mlx-onnx/native.cpp +1402 -0
  5. data/ext/mlx-onnx/native.hpp +19 -0
  6. data/lib/mlx/core.rb +342 -117
  7. data/lib/mlx/nn/base.rb +4 -0
  8. data/lib/mlx/nn/layers/linear.rb +2 -3
  9. data/lib/mlx/onnx.rb +250 -0
  10. data/lib/mlx/version.rb +1 -1
  11. data/lib/mlx-onnx/webgpu_harness.rb +289 -0
  12. data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
  13. data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
  14. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
  15. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
  16. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
  17. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
  18. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
  19. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
  20. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
  21. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
  22. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
  23. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
  24. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
  25. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
  26. data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
  27. data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
  28. data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
  29. data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
  30. data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
  31. data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
  32. data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
  33. data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
  34. data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
  35. data/submodules/mlx-onnx/CMakeLists.txt +159 -0
  36. data/submodules/mlx-onnx/LICENSE +21 -0
  37. data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
  38. data/submodules/mlx-onnx/src/api.cpp +81 -0
  39. data/submodules/mlx-onnx/src/compat.cpp +111 -0
  40. data/submodules/mlx-onnx/src/detail.hpp +69 -0
  41. data/submodules/mlx-onnx/src/export.cpp +653 -0
  42. data/submodules/mlx-onnx/src/io.cpp +61 -0
  43. data/submodules/mlx-onnx/src/json.hpp +25 -0
  44. data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
  45. data/submodules/mlx-onnx/src/mappings.cpp +201 -0
  46. data/submodules/mlx-onnx/src/mappings.hpp +16 -0
  47. data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
  48. data/submodules/mlx-onnx/src/shared.cpp +206 -0
  49. metadata +609 -563
  50. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
  51. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
  52. /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
  53. /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
  54. /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
  55. /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
  56. /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
  57. /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
  58. /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
  59. /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
  60. /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
  61. /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
  62. /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
  63. /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
  64. /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
  65. /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
  66. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
  67. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
  68. /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
  69. /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
  70. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
  71. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
  72. /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
  73. /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
  74. /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
  75. /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
  76. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
  77. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
  78. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
  79. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
  80. /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
  81. /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
  82. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
  83. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
  84. /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
  85. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
  86. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
  87. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
  88. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
  89. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
  90. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
  91. /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
  92. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
  93. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
  94. /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
  95. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
  96. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
  97. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
  98. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
  99. /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
  100. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
  101. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
  102. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
  103. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
  104. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
  105. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
  106. /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
  107. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
  108. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
  109. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
  110. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
  111. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
  112. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
  113. /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
  114. /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
  115. /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
  116. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
  117. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
  118. /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
  119. /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
  120. /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
  121. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
  122. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
  123. /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
  124. /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
  125. /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
  126. /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
  127. /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
  128. /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
  129. /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
  130. /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
  131. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
  132. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
  133. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
  134. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
  135. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
  136. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
  137. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
  138. /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
  139. /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
  140. /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
  141. /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
  142. /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
  143. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
  144. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
  145. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
  146. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
  147. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
  148. /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
  149. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
  150. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
  151. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
  152. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
  153. /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
  154. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
  155. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
  156. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
  157. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
  158. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
  159. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
  160. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
  161. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
  162. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
  163. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
  164. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
  165. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
  166. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
  167. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
  168. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
  169. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
  170. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
  171. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
  172. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
  173. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
  174. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
  175. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
  176. /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
  177. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
  178. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
  179. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
  180. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
  181. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
  182. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
  183. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
  184. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
  185. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
  186. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
  187. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
  188. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
  189. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
  190. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
  191. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
  192. /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
  193. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
  194. /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
  195. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
  196. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
  197. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
  198. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
  199. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
  200. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
  201. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
  202. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
  203. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
  204. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
  205. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
  206. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
  207. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
  208. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
  209. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
  210. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
  211. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
  212. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
  213. /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
  214. /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
  215. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
  216. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
  217. /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
  218. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
  219. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
  220. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
  221. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
  222. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
  223. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
  224. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
  225. /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
  226. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
  227. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
  228. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
  229. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
  230. /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
  231. /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
  232. /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
  233. /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
  234. /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
  235. /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
  236. /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
  237. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
  238. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
  239. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
  240. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
  241. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
  242. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
  243. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
  244. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
  245. /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
  246. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
  247. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
  248. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
  249. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
  250. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
  251. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
  252. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
  253. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
  254. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
  255. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
  256. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
  257. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
  258. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
  259. /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
  260. /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
  261. /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
  262. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
  263. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
  264. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
  265. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
  266. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
  267. /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
  268. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
  269. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
  270. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
  271. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
  272. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
  273. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
  274. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
  275. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
  276. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
  277. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
  278. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
  279. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
  280. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
  281. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
  282. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
  283. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
  284. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
  285. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
  286. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
  287. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
  288. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
  289. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
  290. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
  291. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
  292. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
  293. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
  294. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
  295. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
  296. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
  297. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
  298. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
  299. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
  300. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
  301. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
  302. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
  303. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
  304. /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
  305. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
  306. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
  307. /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
  308. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
  309. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
  310. /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
  311. /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
  312. /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
  313. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
  314. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
  315. /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
  316. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
  317. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
  318. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
  319. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
  320. /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
  321. /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
  322. /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
  323. /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
  324. /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
  325. /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
  326. /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
  327. /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
  328. /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
  329. /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
  330. /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
  331. /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
  332. /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
  333. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
  334. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
  335. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
  336. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
  337. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
  338. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
  339. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
  340. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
  341. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
  342. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
  343. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
  344. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
  345. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
  346. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
  347. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
  348. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
  349. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
  350. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
  351. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
  352. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
  353. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
  354. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
  355. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
  356. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
  357. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
  358. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
  359. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
  360. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
  361. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
  362. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
  363. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
  364. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
  365. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
  366. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
  367. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
  368. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
  369. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
  370. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
  371. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
  372. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
  373. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
  374. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
  375. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
  376. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
  377. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
  378. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
  379. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
  380. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
  381. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
  382. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
  383. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
  384. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
  385. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
  386. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
  387. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
  388. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
  389. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
  390. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
  391. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
  392. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
  393. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
  394. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
  395. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
  396. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
  397. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
  398. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
  399. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
  400. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
  401. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
  402. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
  403. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
  404. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
  405. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
  406. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
  407. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
  408. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
  409. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
  410. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
  411. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
  412. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
  413. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
  414. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
  415. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
  416. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
  417. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
  418. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
  419. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
  420. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
  421. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
  422. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
  423. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
  424. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
  425. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
  426. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
  427. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
  428. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
  429. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
  430. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
  431. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
  432. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
  433. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
  434. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
  435. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
  436. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
  437. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
  438. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
  439. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
  440. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
  441. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
  442. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
  443. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
  444. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
  445. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
  446. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
  447. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
  448. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
  449. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
  450. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
  451. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
  452. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
  453. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
  454. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
  455. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
  456. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
  457. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
  458. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
  459. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
  460. /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
  461. /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
  462. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
  463. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
  464. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
  465. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
  466. /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
  467. /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
  468. /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
  469. /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
  470. /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
  471. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
  472. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
  473. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
  474. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
  475. /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
  476. /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
  477. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
  478. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
  479. /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
  480. /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
  481. /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
  482. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
  483. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
  484. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
  485. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
  486. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
  487. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
  488. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
  489. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
  490. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
  491. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
  492. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
  493. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
  494. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
  495. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
  496. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
  497. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
  498. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
  499. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
  500. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
  501. /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
  502. /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
  503. /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
  504. /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
  505. /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
  506. /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
  507. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
  508. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
  509. /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
  510. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
  511. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
  512. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
  513. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
  514. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
  515. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
  516. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
  517. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
  518. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
  519. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
  520. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
  521. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
  522. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
  523. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
  524. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
  525. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
  526. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
  527. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
  528. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
  529. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
  530. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
  531. /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
  532. /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
  533. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
  534. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
  535. /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
  536. /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
  537. /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
  538. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
  539. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
  540. /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
  541. /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
  542. /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
  543. /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
  544. /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
  545. /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
  546. /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
  547. /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
  548. /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
  549. /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
  550. /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
  551. /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
  552. /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
  553. /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
  554. /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
  555. /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
  556. /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
  557. /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
  558. /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
  559. /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
  560. /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
  561. /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
  562. /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
  563. /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
  564. /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
  565. /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
  566. /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
  567. /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
  568. /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
  569. /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
  570. /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
  571. /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
  572. /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
  573. /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
  574. /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
  575. /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
  576. /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
  577. /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
  578. /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
  579. /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
  580. /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
  581. /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
  582. /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
  583. /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
  584. /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
  585. /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
  586. /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
  587. /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
  588. /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
  589. /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
  590. /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
@@ -13,39 +13,26 @@ namespace mlx::core {
13
13
 
14
14
  namespace {
15
15
 
16
- // Currently cublas supports only mxfp8 and nvfp4
17
- // quantization modes for block scaled quantization
18
- cudaDataType_t qmode_to_cublas_scale_dtype(std::string mode) {
19
- if (mode == "mxfp8") {
20
- return CUDA_R_8F_UE8M0;
21
- } else if (mode == "nvfp4") {
22
- return CUDA_R_8F_UE4M3;
23
- } else {
24
- throw std::runtime_error(
25
- fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
26
- }
27
- }
28
-
29
- cudaDataType_t qmode_to_cublas_dtype(std::string mode) {
30
- if (mode == "mxfp8") {
31
- return CUDA_R_8F_E4M3;
32
- } else if (mode == "nvfp4") {
33
- return CUDA_R_4F_E2M1;
34
- } else {
35
- throw std::runtime_error(
36
- fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
37
- }
38
- }
16
+ struct QuantModeConfig {
17
+ cudaDataType_t data_type;
18
+ cudaDataType_t scale_dtype;
19
+ cublasLtMatmulMatrixScale_t scale_mode;
20
+ };
39
21
 
40
- cublasLtMatmulMatrixScale_t qmode_to_cublas_scale_mode(std::string mode) {
22
+ QuantModeConfig get_quant_mode_config(const std::string& mode) {
41
23
  if (mode == "mxfp8") {
42
- return CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0;
24
+ return {
25
+ CUDA_R_8F_E4M3,
26
+ CUDA_R_8F_UE8M0,
27
+ CUBLASLT_MATMUL_MATRIX_SCALE_VEC32_UE8M0};
43
28
  } else if (mode == "nvfp4") {
44
- return CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3;
45
- } else {
46
- throw std::runtime_error(
47
- fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
29
+ return {
30
+ CUDA_R_4F_E2M1,
31
+ CUDA_R_8F_UE4M3,
32
+ CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3};
48
33
  }
34
+ throw std::runtime_error(
35
+ fmt::format("Unsupported quantization mode in CublasQQMM: {}.", mode));
49
36
  }
50
37
 
51
38
  } // namespace
@@ -64,21 +51,21 @@ CublasQQMM::CublasQQMM(
64
51
  int64_t a_batch_stride,
65
52
  int64_t b_batch_stride,
66
53
  Dtype out_dtype,
67
- std::string qmode) {
54
+ const std::string& qmode) {
55
+ auto config = get_quant_mode_config(qmode);
56
+
68
57
  // The compute type must be CUBLAS_COMPUTE_32F.
69
58
  // The scale type must be CUDA_R_32F.
70
59
  cudaDataType_t scale_type = CUDA_R_32F;
71
60
  cublasComputeType_t gemm_compute_type = CUBLAS_COMPUTE_32F;
72
61
  cudaDataType_t output_type =
73
62
  cublas_utils::dtype_to_cublas_type(out_dtype, "CublasQQMM");
74
- cudaDataType_t data_type = qmode_to_cublas_dtype(qmode);
75
- quantization_mode_ = std::string(qmode);
76
63
 
77
64
  init_base(
78
65
  device,
79
66
  scale_type,
80
67
  gemm_compute_type,
81
- data_type,
68
+ config.data_type,
82
69
  output_type,
83
70
  a_transposed,
84
71
  a_rows,
@@ -92,8 +79,8 @@ CublasQQMM::CublasQQMM(
92
79
  a_batch_stride,
93
80
  b_batch_stride);
94
81
 
95
- a_scale_mode_ = qmode_to_cublas_scale_mode(qmode);
96
- b_scale_mode_ = qmode_to_cublas_scale_mode(qmode);
82
+ a_scale_mode_ = config.scale_mode;
83
+ b_scale_mode_ = config.scale_mode;
97
84
 
98
85
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
99
86
  matmul_desc_,
@@ -123,7 +110,7 @@ CublasQQMM::CublasQQMM(
123
110
  int64_t b_batch_stride,
124
111
  int64_t c_batch_stride,
125
112
  Dtype out_dtype,
126
- std::string qmode)
113
+ const std::string& qmode)
127
114
  : CublasQQMM(
128
115
  device,
129
116
  a_transposed,
@@ -158,11 +145,14 @@ void CublasQQMM::run(
158
145
  const array& b,
159
146
  const array& a_scale,
160
147
  const array& b_scale,
161
- float alpha) {
148
+ const array& alpha,
149
+ const array& beta) {
162
150
  encoder.set_input_array(a);
163
151
  encoder.set_input_array(b);
164
152
  encoder.set_input_array(a_scale);
165
153
  encoder.set_input_array(b_scale);
154
+ encoder.set_input_array(alpha);
155
+ encoder.set_input_array(beta);
166
156
  encoder.set_output_array(out);
167
157
 
168
158
  execute(
@@ -173,19 +163,37 @@ void CublasQQMM::run(
173
163
  gpu_ptr<void>(a_scale),
174
164
  gpu_ptr<void>(b_scale),
175
165
  nullptr,
176
- alpha);
166
+ gpu_ptr<void>(alpha),
167
+ gpu_ptr<void>(beta));
177
168
  }
178
169
 
179
- void CublasQQMM::execute(
170
+ void CublasQQMM::run(
171
+ cu::CommandEncoder& encoder,
172
+ array& out,
173
+ const array& a,
174
+ const array& b,
175
+ const array& a_scale,
176
+ const array& b_scale) {
177
+ encoder.set_input_array(a);
178
+ encoder.set_input_array(b);
179
+ encoder.set_input_array(a_scale);
180
+ encoder.set_input_array(b_scale);
181
+ encoder.set_output_array(out);
182
+
183
+ execute(
184
+ encoder,
185
+ gpu_ptr<void>(out),
186
+ gpu_ptr<void>(a),
187
+ gpu_ptr<void>(b),
188
+ gpu_ptr<void>(a_scale),
189
+ gpu_ptr<void>(b_scale),
190
+ nullptr);
191
+ }
192
+
193
+ void CublasQQMM::set_scales_ptrs(
180
194
  cu::CommandEncoder& encoder,
181
- void* out,
182
- const void* a,
183
- const void* b,
184
195
  const void* a_scale,
185
- const void* b_scale,
186
- const void* c,
187
- float alpha /* = 1 */,
188
- float beta /* = 0 */) {
196
+ const void* b_scale) {
189
197
  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
190
198
  matmul_desc_,
191
199
  CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
@@ -196,6 +204,49 @@ void CublasQQMM::execute(
196
204
  CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
197
205
  &a_scale,
198
206
  sizeof(a_scale)));
207
+ }
208
+
209
+ void CublasQQMM::execute(
210
+ cu::CommandEncoder& encoder,
211
+ void* out,
212
+ const void* a,
213
+ const void* b,
214
+ const void* a_scale,
215
+ const void* b_scale,
216
+ const void* c,
217
+ const void* alpha,
218
+ const void* beta) {
219
+ set_scales_ptrs(encoder, a_scale, b_scale);
220
+ // alpha and beta are both should be device pointers for nvfp4
221
+ // by default cublas uses host pointers
222
+ // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t
223
+ cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
224
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
225
+ matmul_desc_,
226
+ CUBLASLT_MATMUL_DESC_POINTER_MODE,
227
+ &pointer_mode,
228
+ sizeof(pointer_mode)));
229
+ execute_matmul(encoder, out, a, b, c, alpha, beta);
230
+ }
231
+
232
+ void CublasQQMM::execute(
233
+ cu::CommandEncoder& encoder,
234
+ void* out,
235
+ const void* a,
236
+ const void* b,
237
+ const void* a_scale,
238
+ const void* b_scale,
239
+ const void* c,
240
+ const float alpha /* = 1 */,
241
+ const float beta /* = 0 */) {
242
+ set_scales_ptrs(encoder, a_scale, b_scale);
243
+ // alpha and beta are both should be host pointers
244
+ cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
245
+ CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
246
+ matmul_desc_,
247
+ CUBLASLT_MATMUL_DESC_POINTER_MODE,
248
+ &pointer_mode,
249
+ sizeof(pointer_mode)));
199
250
 
200
251
  const void* alpha_ptr = &alpha;
201
252
  const void* beta_ptr = &beta;
@@ -25,7 +25,7 @@ class CublasQQMM : public CublasMatmulBase {
25
25
  int64_t a_batch_stride,
26
26
  int64_t b_batch_stride,
27
27
  Dtype out_dtype,
28
- std::string quantization_mode);
28
+ const std::string& quantization_mode);
29
29
 
30
30
  CublasQQMM(
31
31
  cu::Device& device,
@@ -43,7 +43,7 @@ class CublasQQMM : public CublasMatmulBase {
43
43
  int64_t b_batch_stride,
44
44
  int64_t c_batch_stride,
45
45
  Dtype out_dtype,
46
- std::string quantization_mode);
46
+ const std::string& quantization_mode);
47
47
 
48
48
  void run(
49
49
  cu::CommandEncoder& encoder,
@@ -52,20 +52,33 @@ class CublasQQMM : public CublasMatmulBase {
52
52
  const array& b,
53
53
  const array& a_scale,
54
54
  const array& b_scale,
55
- float alpha = 1.0f);
55
+ const array& alpha,
56
+ const array& beta);
56
57
 
57
- private:
58
- void run_batched(
58
+ void run(
59
59
  cu::CommandEncoder& encoder,
60
60
  array& out,
61
61
  const array& a,
62
62
  const array& b,
63
63
  const array& a_scale,
64
- const array& b_scale,
65
- const Shape& batch_shape,
66
- const Strides& a_batch_strides,
67
- const Strides& b_batch_strides,
68
- float alpha);
64
+ const array& b_scale);
65
+
66
+ private:
67
+ void set_scales_ptrs(
68
+ cu::CommandEncoder& encoder,
69
+ const void* a_scale,
70
+ const void* b_scale);
71
+
72
+ void execute(
73
+ cu::CommandEncoder& encoder,
74
+ void* out,
75
+ const void* a,
76
+ const void* b,
77
+ const void* a_scale,
78
+ const void* b_scale,
79
+ const void* c,
80
+ const void* alpha,
81
+ const void* beta);
69
82
 
70
83
  void execute(
71
84
  cu::CommandEncoder& encoder,
@@ -75,10 +88,9 @@ class CublasQQMM : public CublasMatmulBase {
75
88
  const void* a_scale,
76
89
  const void* b_scale,
77
90
  const void* c,
78
- float alpha = 1,
79
- float beta = 0);
91
+ const float alpha = 1.0f,
92
+ const float beta = 0.0f);
80
93
 
81
- std::string quantization_mode_;
82
94
  cublasLtMatmulMatrixScale_t a_scale_mode_;
83
95
  cublasLtMatmulMatrixScale_t b_scale_mode_;
84
96
  cublasLtMatmulMatrixScale_t c_scale_mode_;
@@ -11,6 +11,11 @@
11
11
 
12
12
  #include <cooperative_groups.h>
13
13
  #include <cooperative_groups/reduce.h>
14
+ #include <cuda_fp4.h>
15
+ #include <cuda_fp8.h>
16
+
17
+ constexpr float F8E4M3_MAX = 448.0f;
18
+ constexpr float F4E2M1_MAX = 6.0f;
14
19
 
15
20
  namespace mlx::core {
16
21
  namespace cu {
@@ -29,7 +34,16 @@ struct Dequantize {
29
34
  namespace cg = cooperative_groups;
30
35
 
31
36
  template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
32
- __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) {
37
+ __global__ void fp_quantize_dequantize(
38
+ T* w,
39
+ T* out,
40
+ size_t size,
41
+ float* global_scale = nullptr) {
42
+ const bool use_global_scale = global_scale != nullptr;
43
+ const float scale_enc =
44
+ use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;
45
+ const float inv_scale_enc = use_global_scale ? 1.0f / scale_enc : 1.0f;
46
+
33
47
  using Tx2 = Vector2_t<T>;
34
48
  using Tx4 = Vector4_t<T>;
35
49
  uint32_t rbits = 0; // reserved bits for future use
@@ -48,26 +62,28 @@ __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) {
48
62
  }
49
63
 
50
64
  auto w_tile = load_vector<group_size, T>(w, thread_idx);
51
- float scale = 0.0f;
65
+ float scale_dec_b = 0.0f;
52
66
 
53
67
  Tx2 amax_2x = Tx2{0.0f, 0.0f};
54
68
 
55
69
  #pragma unroll
56
70
  for (int i = 0; i < group_size; i += 2) {
57
71
  auto pair = Tx2{w_tile[i], w_tile[i + 1]};
58
- abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
72
+ absmax_x2<Tx2>(amax_2x, amax_2x, pair);
59
73
  }
60
74
 
61
- scale = static_cast<float>(
75
+ scale_dec_b = static_cast<float>(
62
76
  max(fabsf(static_cast<float>(amax_2x.x)),
63
77
  fabsf(static_cast<float>(amax_2x.y))));
64
78
 
65
- scale /= bits == 4 ? 6.0f : 448.0f;
79
+ scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX;
80
+ scale_dec_b *= scale_enc;
66
81
  // Convert to mx scale or nv scale
67
82
  using ScaleType =
68
83
  std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
69
- auto s = ScaleType(scale);
70
- scale = float(s);
84
+ auto s = ScaleType(scale_dec_b);
85
+ float scale_enc_b = scale_enc / float(s);
86
+ float scale_dec = float(s) * inv_scale_enc;
71
87
  AlignedVector<T, group_size> w_hat;
72
88
 
73
89
  #pragma unroll
@@ -76,24 +92,36 @@ __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) {
76
92
  float4 dq;
77
93
  if constexpr (bits == 8) {
78
94
  uint32_t quantized_val =
79
- scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
95
+ scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
80
96
  dq = dequant_fp8(quantized_val);
81
97
  } else {
82
98
  uint16_t quantized_val =
83
- scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
99
+ scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
84
100
  dq = dequant_fp4(quantized_val);
85
101
  }
86
- w_hat[i * 4] = static_cast<T>(dq.x * scale);
87
- w_hat[i * 4 + 1] = static_cast<T>(dq.y * scale);
88
- w_hat[i * 4 + 2] = static_cast<T>(dq.z * scale);
89
- w_hat[i * 4 + 3] = static_cast<T>(dq.w * scale);
102
+ w_hat[i * 4] = static_cast<T>(dq.x * scale_dec);
103
+ w_hat[i * 4 + 1] = static_cast<T>(dq.y * scale_dec);
104
+ w_hat[i * 4 + 2] = static_cast<T>(dq.z * scale_dec);
105
+ w_hat[i * 4 + 3] = static_cast<T>(dq.w * scale_dec);
90
106
  }
91
107
  store_vector<group_size>(out, thread_idx, w_hat);
92
108
  }
93
109
 
94
110
  template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
95
- __global__ void
96
- fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) {
111
+ __global__ void fp_quantize_rowwise(
112
+ T* w,
113
+ uint8_t* out,
114
+ uint8_t* scales,
115
+ size_t size,
116
+ float* global_scale = nullptr) {
117
+ // NVFP4 conversion:
118
+ // Global encode scale: (448 × 6) / *global_scale
119
+ // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8
120
+ // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b
121
+ const bool use_global_scale = global_scale != nullptr;
122
+ const float scale_enc =
123
+ use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;
124
+
97
125
  using Tx2 = Vector2_t<T>;
98
126
  using Tx4 = Vector4_t<T>;
99
127
  uint32_t rbits = 0; // reserved bits for future use
@@ -112,27 +140,28 @@ fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) {
112
140
  }
113
141
 
114
142
  auto w_tile = load_vector<group_size, T>(w, thread_idx);
115
- float scale = 0.0f;
143
+ float scale_dec_b = 0.0f;
116
144
 
117
145
  Tx2 amax_2x = Tx2{0.0f, 0.0f};
118
146
 
119
147
  #pragma unroll
120
148
  for (int i = 0; i < group_size; i += 2) {
121
149
  auto pair = Tx2{w_tile[i], w_tile[i + 1]};
122
- abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
150
+ absmax_x2<Tx2>(amax_2x, amax_2x, pair);
123
151
  }
124
152
 
125
- scale = static_cast<float>(
153
+ scale_dec_b = static_cast<float>(
126
154
  max(fabsf(static_cast<float>(amax_2x.x)),
127
155
  fabsf(static_cast<float>(amax_2x.y))));
128
156
 
129
- scale /= bits == 4 ? 6.0f : 448.0f;
157
+ scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX;
158
+ scale_dec_b *= scale_enc;
130
159
  // Convert to mx scale or nv scale
131
160
  using ScaleType =
132
161
  std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
133
- auto s = ScaleType(scale);
162
+ auto s = ScaleType(scale_dec_b);
134
163
  uint8_t q_scale = s.__x;
135
- scale = float(s);
164
+ float scale_enc_b = scale_enc / float(s);
136
165
 
137
166
  scales[thread_idx] = q_scale;
138
167
  constexpr int elem_per_byte = bits == 8 ? 1 : 2;
@@ -143,11 +172,11 @@ fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) {
143
172
  Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
144
173
  if constexpr (bits == 8) {
145
174
  uint32_t quantized_val =
146
- scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
175
+ scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
147
176
  *reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
148
177
  } else {
149
178
  uint16_t quantized_val =
150
- scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
179
+ scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
151
180
  *reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
152
181
  }
153
182
  }
@@ -161,11 +190,15 @@ __global__ void fp_quantize_columnwise(
161
190
  uint8_t* scales,
162
191
  size_t size,
163
192
  int M,
164
- int K) {
193
+ int K,
194
+ float* global_scale = nullptr) {
165
195
  // Input: [M, K] with strides [1, M] (M-major)
166
196
  // Quantized output: [M, K/elem_per_byte] row-major (K-major)
167
197
  // Scales: [M, K/group_size] row-major (K-major)
168
198
  // Quantize along K (last dimension, groups of group_size elements)
199
+ const bool use_global_scale = global_scale != nullptr;
200
+ const float scale_enc =
201
+ use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;
169
202
 
170
203
  using Tx2 = Vector2_t<T>;
171
204
  using Tx4 = Vector4_t<T>;
@@ -215,16 +248,18 @@ __global__ void fp_quantize_columnwise(
215
248
  #pragma unroll
216
249
  for (int r = 0; r < group_size; r += 2) {
217
250
  auto pair = Tx2{thread_data[r], thread_data[r + 1]};
218
- abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
251
+ absmax_x2<Tx2>(amax_2x, amax_2x, pair);
219
252
  }
220
- float scale =
253
+ float scale_dec_b =
221
254
  max(fabsf(static_cast<float>(amax_2x.x)),
222
255
  fabsf(static_cast<float>(amax_2x.y)));
223
- scale /= (bits == 4) ? 6.0f : 448.0f;
256
+ scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX;
257
+ scale_dec_b *= scale_enc;
258
+ // Convert to mx scale or nv scale
224
259
  using ScaleType =
225
260
  std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
226
- auto s = ScaleType(scale);
227
- scale = float(s);
261
+ auto s = ScaleType(scale_dec_b);
262
+ float scale_enc_b = scale_enc / float(s);
228
263
  scales_smem[tidx][tidy] = s.__x;
229
264
 
230
265
  int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group;
@@ -234,12 +269,12 @@ __global__ void fp_quantize_columnwise(
234
269
  Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&thread_data[j * 4]);
235
270
  if constexpr (bits == 8) {
236
271
  uint32_t quantized_val =
237
- scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
272
+ scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
238
273
  *reinterpret_cast<uint32_t*>(&quantized_smem[shared_idx + j * 4]) =
239
274
  quantized_val;
240
275
  } else {
241
276
  uint16_t quantized_val =
242
- scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
277
+ scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
243
278
  *reinterpret_cast<uint16_t*>(&quantized_smem[shared_idx + j * 2]) =
244
279
  quantized_val;
245
280
  }
@@ -282,8 +317,12 @@ __global__ void fp_quantize_columnwise(
282
317
  }
283
318
 
284
319
  template <typename T, int group_size, int bits, bool use_mx_scale>
285
- __global__ void
286
- fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
320
+ __global__ void fp_dequantize(
321
+ const uint8_t* w,
322
+ const uint8_t* scales,
323
+ T* out,
324
+ size_t size,
325
+ float* global_scale = nullptr) {
287
326
  auto block_size = cg::this_thread_block().dim_threads();
288
327
  auto block_idx = cg::this_thread_block().group_index();
289
328
  auto idx_in_block = cg::this_thread_block().thread_index();
@@ -294,6 +333,10 @@ fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
294
333
  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
295
334
 
296
335
  constexpr int pack_factor = bits == 8 ? 1 : 2;
336
+ const bool use_global_scale = global_scale != nullptr;
337
+ const float inv_scale_enc = use_mx_scale
338
+ ? 1.0f
339
+ : (use_global_scale ? (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f);
297
340
  size_t offset = tidx + grid_dim_x * size_t(tidy);
298
341
  size_t oindex = offset * pack_factor;
299
342
 
@@ -304,7 +347,7 @@ fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
304
347
  size_t gindex = oindex / group_size;
305
348
  using ScaleType =
306
349
  std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
307
- auto scale = float(((ScaleType*)(scales))[gindex]);
350
+ auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc;
308
351
 
309
352
  out += oindex;
310
353
 
@@ -346,9 +389,13 @@ void fp_quantize_dequantize(
346
389
  array& what,
347
390
  int group_size,
348
391
  int bits,
392
+ const std::optional<array>& global_scale /* = std::nullopt */,
349
393
  cu::CommandEncoder& enc,
350
394
  const Stream& s) {
351
395
  enc.set_input_array(w);
396
+ if (global_scale.has_value()) {
397
+ enc.set_input_array(global_scale.value());
398
+ }
352
399
  enc.set_output_array(what);
353
400
  dispatch_float_types(w.dtype(), "fp_quantize_dequantize", [&](auto type_tag) {
354
401
  using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -370,7 +417,9 @@ void fp_quantize_dequantize(
370
417
  0,
371
418
  gpu_ptr<T>(w),
372
419
  gpu_ptr<T>(what),
373
- w.size());
420
+ w.size(),
421
+ global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
422
+ : nullptr);
374
423
  }
375
424
  });
376
425
  }
@@ -381,9 +430,13 @@ void fp_quantize(
381
430
  array& scales,
382
431
  int group_size,
383
432
  int bits,
433
+ const std::optional<array>& global_scale /* = std::nullopt */,
384
434
  cu::CommandEncoder& enc,
385
435
  const Stream& s) {
386
436
  enc.set_input_array(w);
437
+ if (global_scale.has_value()) {
438
+ enc.set_input_array(global_scale.value());
439
+ }
387
440
  enc.set_output_array(wq);
388
441
  enc.set_output_array(scales);
389
442
  if (w.strides().back() != 1) {
@@ -410,7 +463,9 @@ void fp_quantize(
410
463
  gpu_ptr<uint8_t>(scales),
411
464
  w.size(),
412
465
  M,
413
- K);
466
+ K,
467
+ global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
468
+ : nullptr);
414
469
  } else {
415
470
  throw std::runtime_error(
416
471
  "[Quantize::eval_gpu] Can not quantize input with type float64.");
@@ -438,7 +493,9 @@ void fp_quantize(
438
493
  gpu_ptr<T>(w),
439
494
  gpu_ptr<uint8_t>(wq),
440
495
  gpu_ptr<uint8_t>(scales),
441
- w.size());
496
+ w.size(),
497
+ global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
498
+ : nullptr);
442
499
  } else {
443
500
  throw std::runtime_error(
444
501
  "[Quantize::eval_gpu] Can not quantize input with type float64.");
@@ -453,6 +510,7 @@ void fp_dequantize(
453
510
  array& w,
454
511
  int group_size,
455
512
  int bits,
513
+ const std::optional<array>& global_scale /* = std::nullopt */,
456
514
  cu::CommandEncoder& enc,
457
515
  const Stream& s) {
458
516
  constexpr int uint8_per_uint32 = 4;
@@ -465,6 +523,9 @@ void fp_dequantize(
465
523
 
466
524
  enc.set_input_array(wq);
467
525
  enc.set_input_array(scales);
526
+ if (global_scale.has_value()) {
527
+ enc.set_input_array(global_scale.value());
528
+ }
468
529
  enc.set_output_array(w);
469
530
  dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
470
531
  using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -485,7 +546,9 @@ void fp_dequantize(
485
546
  gpu_ptr<uint8_t>(wq),
486
547
  gpu_ptr<uint8_t>(scales),
487
548
  gpu_ptr<T>(w),
488
- w.size());
549
+ w.size(),
550
+ global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
551
+ : nullptr);
489
552
  } else {
490
553
  throw std::runtime_error(
491
554
  "[Quantize::eval_gpu] Can not dequantize to output with type float64.");
@@ -17,9 +17,8 @@ void qqmm_impl(
17
17
  const array&,
18
18
  const array&,
19
19
  const array&,
20
- Dtype,
21
20
  QuantizationMode,
22
- float) {
21
+ const GemmScalars&) {
23
22
  throw std::runtime_error(
24
23
  "[QQMatmul::eval_gpu] QQMM is only supported with CUDA 12.8 or higher.");
25
24
  }