mlx 0.30.7.3 → 0.30.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (590)
  1. checksums.yaml +4 -4
  2. data/ext/mlx/extconf.rb +267 -8
  3. data/ext/mlx/native.cpp +104 -56
  4. data/ext/mlx-onnx/native.cpp +1402 -0
  5. data/ext/mlx-onnx/native.hpp +19 -0
  6. data/lib/mlx/core.rb +342 -117
  7. data/lib/mlx/nn/base.rb +4 -0
  8. data/lib/mlx/nn/layers/linear.rb +2 -3
  9. data/lib/mlx/onnx.rb +250 -0
  10. data/lib/mlx/version.rb +1 -1
  11. data/lib/mlx-onnx/webgpu_harness.rb +289 -0
  12. data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
  13. data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
  14. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
  15. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
  16. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
  17. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
  18. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
  19. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
  20. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
  21. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
  22. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
  23. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
  24. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
  25. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
  26. data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
  27. data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
  28. data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
  29. data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
  30. data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
  31. data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
  32. data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
  33. data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
  34. data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
  35. data/submodules/mlx-onnx/CMakeLists.txt +159 -0
  36. data/submodules/mlx-onnx/LICENSE +21 -0
  37. data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
  38. data/submodules/mlx-onnx/src/api.cpp +81 -0
  39. data/submodules/mlx-onnx/src/compat.cpp +111 -0
  40. data/submodules/mlx-onnx/src/detail.hpp +69 -0
  41. data/submodules/mlx-onnx/src/export.cpp +653 -0
  42. data/submodules/mlx-onnx/src/io.cpp +61 -0
  43. data/submodules/mlx-onnx/src/json.hpp +25 -0
  44. data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
  45. data/submodules/mlx-onnx/src/mappings.cpp +201 -0
  46. data/submodules/mlx-onnx/src/mappings.hpp +16 -0
  47. data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
  48. data/submodules/mlx-onnx/src/shared.cpp +206 -0
  49. metadata +609 -563
  50. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
  51. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
  52. /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
  53. /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
  54. /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
  55. /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
  56. /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
  57. /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
  58. /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
  59. /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
  60. /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
  61. /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
  62. /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
  63. /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
  64. /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
  65. /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
  66. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
  67. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
  68. /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
  69. /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
  70. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
  71. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
  72. /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
  73. /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
  74. /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
  75. /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
  76. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
  77. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
  78. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
  79. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
  80. /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
  81. /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
  82. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
  83. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
  84. /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
  85. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
  86. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
  87. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
  88. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
  89. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
  90. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
  91. /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
  92. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
  93. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
  94. /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
  95. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
  96. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
  97. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
  98. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
  99. /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
  100. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
  101. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
  102. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
  103. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
  104. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
  105. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
  106. /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
  107. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
  108. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
  109. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
  110. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
  111. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
  112. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
  113. /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
  114. /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
  115. /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
  116. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
  117. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
  118. /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
  119. /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
  120. /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
  121. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
  122. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
  123. /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
  124. /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
  125. /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
  126. /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
  127. /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
  128. /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
  129. /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
  130. /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
  131. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
  132. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
  133. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
  134. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
  135. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
  136. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
  137. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
  138. /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
  139. /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
  140. /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
  141. /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
  142. /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
  143. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
  144. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
  145. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
  146. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
  147. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
  148. /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
  149. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
  150. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
  151. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
  152. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
  153. /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
  154. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
  155. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
  156. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
  157. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
  158. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
  159. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
  160. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
  161. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
  162. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
  163. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
  164. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
  165. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
  166. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
  167. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
  168. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
  169. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
  170. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
  171. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
  172. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
  173. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
  174. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
  175. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
  176. /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
  177. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
  178. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
  179. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
  180. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
  181. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
  182. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
  183. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
  184. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
  185. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
  186. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
  187. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
  188. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
  189. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
  190. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
  191. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
  192. /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
  193. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
  194. /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
  195. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
  196. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
  197. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
  198. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
  199. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
  200. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
  201. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
  202. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
  203. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
  204. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
  205. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
  206. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
  207. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
  208. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
  209. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
  210. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
  211. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
  212. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
  213. /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
  214. /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
  215. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
  216. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
  217. /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
  218. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
  219. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
  220. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
  221. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
  222. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
  223. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
  224. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
  225. /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
  226. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
  227. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
  228. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
  229. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
  230. /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
  231. /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
  232. /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
  233. /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
  234. /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
  235. /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
  236. /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
  237. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
  238. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
  239. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
  240. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
  241. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
  242. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
  243. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
  244. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
  245. /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
  246. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
  247. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
  248. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
  249. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
  250. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
  251. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
  252. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
  253. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
  254. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
  255. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
  256. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
  257. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
  258. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
  259. /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
  260. /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
  261. /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
  262. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
  263. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
  264. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
  265. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
  266. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
  267. /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
  268. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
  269. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
  270. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
  271. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
  272. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
  273. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
  274. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
  275. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
  276. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
  277. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
  278. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
  279. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
  280. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
  281. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
  282. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
  283. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
  284. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
  285. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
  286. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
  287. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
  288. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
  289. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
  290. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
  291. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
  292. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
  293. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
  294. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
  295. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
  296. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
  297. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
  298. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
  299. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
  300. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
  301. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
  302. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
  303. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
  304. /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
  305. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
  306. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
  307. /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
  308. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
  309. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
  310. /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
  311. /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
  312. /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
  313. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
  314. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
  315. /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
  316. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
  317. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
  318. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
  319. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
  320. /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
  321. /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
  322. /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
  323. /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
  324. /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
  325. /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
  326. /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
  327. /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
  328. /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
  329. /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
  330. /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
  331. /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
  332. /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
  333. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
  334. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
  335. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
  336. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
  337. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
  338. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
  339. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
  340. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
  341. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
  342. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
  343. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
  344. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
  345. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
  346. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
  347. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
  348. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
  349. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
  350. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
  351. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
  352. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
  353. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
  354. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
  355. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
  356. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
  357. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
  358. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
  359. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
  360. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
  361. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
  362. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
  363. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
  364. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
  365. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
  366. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
  367. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
  368. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
  369. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
  370. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
  371. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
  372. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
  373. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
  374. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
  375. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
  376. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
  377. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
  378. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
  379. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
  380. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
  381. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
  382. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
  383. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
  384. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
  385. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
  386. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
  387. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
  388. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
  389. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
  390. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
  391. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
  392. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
  393. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
  394. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
  395. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
  396. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
  397. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
  398. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
  399. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
  400. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
  401. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
  402. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
  403. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
  404. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
  405. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
  406. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
  407. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
  408. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
  409. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
  410. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
  411. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
  412. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
  413. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
  414. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
  415. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
  416. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
  417. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
  418. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
  419. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
  420. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
  421. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
  422. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
  423. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
  424. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
  425. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
  426. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
  427. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
  428. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
  429. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
  430. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
  431. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
  432. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
  433. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
  434. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
  435. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
  436. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
  437. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
  438. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
  439. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
  440. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
  441. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
  442. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
  443. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
  444. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
  445. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
  446. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
  447. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
  448. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
  449. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
  450. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
  451. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
  452. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
  453. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
  454. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
  455. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
  456. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
  457. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
  458. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
  459. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
  460. /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
  461. /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
  462. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
  463. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
  464. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
  465. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
  466. /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
  467. /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
  468. /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
  469. /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
  470. /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
  471. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
  472. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
  473. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
  474. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
  475. /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
  476. /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
  477. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
  478. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
  479. /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
  480. /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
  481. /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
  482. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
  483. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
  484. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
  485. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
  486. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
  487. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
  488. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
  489. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
  490. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
  491. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
  492. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
  493. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
  494. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
  495. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
  496. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
  497. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
  498. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
  499. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
  500. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
  501. /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
  502. /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
  503. /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
  504. /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
  505. /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
  506. /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
  507. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
  508. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
  509. /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
  510. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
  511. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
  512. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
  513. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
  514. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
  515. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
  516. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
  517. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
  518. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
  519. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
  520. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
  521. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
  522. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
  523. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
  524. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
  525. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
  526. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
  527. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
  528. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
  529. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
  530. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
  531. /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
  532. /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
  533. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
  534. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
  535. /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
  536. /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
  537. /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
  538. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
  539. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
  540. /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
  541. /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
  542. /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
  543. /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
  544. /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
  545. /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
  546. /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
  547. /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
  548. /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
  549. /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
  550. /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
  551. /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
  552. /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
  553. /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
  554. /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
  555. /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
  556. /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
  557. /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
  558. /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
  559. /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
  560. /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
  561. /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
  562. /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
  563. /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
  564. /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
  565. /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
  566. /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
  567. /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
  568. /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
  569. /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
  570. /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
  571. /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
  572. /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
  573. /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
  574. /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
  575. /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
  576. /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
  577. /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
  578. /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
  579. /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
  580. /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
  581. /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
  582. /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
  583. /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
  584. /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
  585. /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
  586. /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
  587. /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
  588. /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
  589. /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
  590. /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
@@ -0,0 +1,193 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/device.h"
4
+ #include "mlx/backend/cuda/quantized/qmv.h"
5
+ #include "mlx/backend/cuda/quantized/qqmm_impl.h"
6
+ #include "mlx/backend/cuda/quantized/qqmm_utils.h"
7
+ #include "mlx/backend/cuda/quantized/quantized.h"
8
+ #include "mlx/backend/cuda/quantized/quantized_utils.h"
9
+ #include "mlx/primitives.h"
10
+
11
+ #include <nvtx3/nvtx3.hpp>
12
+
13
+ namespace mlx::core {
14
+
15
+ namespace {
16
+
17
+ std::tuple<array, array> quantize_input(
18
+ const array& input,
19
+ cu::CommandEncoder& encoder,
20
+ const Stream& s,
21
+ QuantizationMode mode,
22
+ int bits,
23
+ int group_size,
24
+ std::optional<array> global_scale = std::nullopt) {
25
+ const array x = ensure_contiguous(input, encoder, s);
26
+
27
+ // Compute output shapes
28
+ auto xq_shape = x.shape();
29
+ xq_shape.back() = x.shape(-1) * bits / 32;
30
+
31
+ const int64_t scales_inner = x.shape(-1) / group_size;
32
+ auto [pad_outer, pad_inner] =
33
+ get_padded_scale_dims(x.shape(-2), scales_inner);
34
+
35
+ auto sshape = x.shape();
36
+ sshape[x.ndim() - 2] = pad_outer;
37
+ sshape[x.ndim() - 1] = pad_inner;
38
+ sshape.back() = scales_inner;
39
+
40
+ // Allocate outputs
41
+ const int64_t xq_bytes = x.size() * bits / 8;
42
+ const int64_t batch = x.size() / (x.shape(-2) * x.shape(-1));
43
+ const int64_t scales_bytes = batch * (pad_outer * pad_inner);
44
+
45
+ array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32);
46
+ array scales_x(
47
+ cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8);
48
+ encoder.add_temporary(x_q);
49
+ encoder.add_temporary(scales_x);
50
+ // global_scale is not nullopt only for NVFP4
51
+ fp_quantize(x, x_q, scales_x, group_size, bits, global_scale, encoder, s);
52
+ return {std::move(x_q), std::move(scales_x)};
53
+ }
54
+
55
+ GemmScalars create_nvfp4_scalars(
56
+ const array& global_scale_x,
57
+ const array& global_scale_w,
58
+ cu::CommandEncoder& encoder) {
59
+ // NVFP4 requires alpha/beta as device pointers
60
+ // alpha = amax_x * amax_w / (448 * 6)^2
61
+ // beta = 0
62
+ array alpha(cu::malloc_async(sizeof(float), encoder), {}, float32);
63
+ array beta(cu::malloc_async(sizeof(float), encoder), {}, float32);
64
+ compute_qqmm_pointers(alpha, beta, global_scale_x, global_scale_w, encoder);
65
+ encoder.add_temporary(alpha);
66
+ encoder.add_temporary(beta);
67
+ return {alpha, beta};
68
+ }
69
+
70
+ } // namespace
71
+
72
+ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
73
+ nvtx3::scoped_range r("QQMatmul::eval_gpu");
74
+
75
+ auto& s = stream();
76
+ auto& encoder = cu::get_command_encoder(s);
77
+ auto& device = encoder.device();
78
+ bool w_quantized = (inputs[1].dtype() == uint32);
79
+ int base_size = w_quantized ? 3 : 2;
80
+
81
+ assert(
82
+ inputs.size() == base_size ||
83
+ (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2));
84
+
85
+ if (w_quantized && inputs[0].shape(-2) == 1) {
86
+ out.set_data(cu::malloc_async(out.nbytes(), encoder));
87
+
88
+ // For nvfp4, get global scale for x from inputs if present
89
+ bool has_global_scale =
90
+ mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;
91
+ std::optional<array> global_scale = std::nullopt;
92
+ if (has_global_scale) {
93
+ global_scale = inputs[inputs.size() - 2];
94
+ }
95
+
96
+ bool donate_x = inputs[0].is_donatable();
97
+ array x = ensure_row_contiguous(inputs[0], encoder, s);
98
+ // If x is a copy it should be donatable
99
+ donate_x |= x.is_donatable();
100
+ auto xhat = donate_x
101
+ ? x
102
+ : array(cu::malloc_async(x.nbytes(), encoder), x.shape(), x.dtype());
103
+ if (!donate_x) {
104
+ encoder.add_temporary(xhat);
105
+ }
106
+ fp_quantize_dequantize(
107
+ x, xhat, group_size_, bits_, global_scale, encoder, s);
108
+
109
+ // Make sure the last two dims of w and s are contiguous
110
+ array w = ensure_row_contiguous_matrix(inputs[1], encoder, s);
111
+ array scales = ensure_row_contiguous_matrix(inputs[2], encoder, s);
112
+
113
+ bool non_batched = w.ndim() == 2;
114
+ int K = x.shape(-1);
115
+ int M = non_batched ? x.size() / K : x.shape(-2);
116
+ int N = out.shape(-1);
117
+
118
+ fp_qmv(w, scales, xhat, out, bits_, group_size_, M, N, K, encoder);
119
+ return;
120
+ }
121
+
122
+ auto cc = device.compute_capability_major() * 100 +
123
+ device.compute_capability_minor() * 10;
124
+ if (cc < 1000) {
125
+ throw std::runtime_error(
126
+ "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher.");
127
+ }
128
+
129
+ // - 2 inputs: x, w (non-quantized w)
130
+ // - 3 inputs: x, w, scales_w (quantized w)
131
+
132
+ // For nvfp4, global scales are optional but must be both present or both
133
+ // absent If present, they add 2 more inputs (global_scale_x, global_scale_w)
134
+ bool has_global_scales =
135
+ mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;
136
+
137
+ // For nvfp4, get global scales from inputs if present
138
+ std::optional<array> global_scale_x = std::nullopt;
139
+ std::optional<array> global_scale_w = std::nullopt;
140
+ if (has_global_scales) {
141
+ global_scale_x = inputs[inputs.size() - 2];
142
+ global_scale_w = inputs[inputs.size() - 1];
143
+ }
144
+
145
+ // Quantize inputs (or use pre-quantized)
146
+ auto [x_q, scale_x_pre] = quantize_input(
147
+ inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x);
148
+ auto [w_q, scale_w_pre] = !w_quantized
149
+ ? quantize_input(
150
+ inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w)
151
+ : std::make_tuple(
152
+ ensure_contiguous(inputs[1], encoder, s),
153
+ ensure_contiguous(inputs[2], encoder, s));
154
+
155
+ out.set_data(cu::malloc_async(out.nbytes(), encoder));
156
+
157
+ int M = x_q.shape(-2);
158
+ int N = w_q.shape(-2); // transposed
159
+ int K = x_q.shape(-1) * (32 / bits_);
160
+
161
+ bool x_transposed = false;
162
+ bool w_transposed = true; // always transposed
163
+ int64_t lda = K;
164
+ int64_t ldb = K;
165
+
166
+ // Repack scales to tiled layout for tensor cores
167
+ array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s);
168
+ array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s);
169
+
170
+ GemmScalars scalars;
171
+ if (has_global_scales) {
172
+ scalars = create_nvfp4_scalars(*global_scale_x, *global_scale_w, encoder);
173
+ }
174
+
175
+ qqmm_impl(
176
+ encoder,
177
+ M,
178
+ N,
179
+ K,
180
+ x_transposed,
181
+ lda,
182
+ w_transposed,
183
+ ldb,
184
+ out,
185
+ x_q,
186
+ w_q,
187
+ scale_x,
188
+ scale_w,
189
+ mode_,
190
+ scalars);
191
+ }
192
+
193
+ } // namespace mlx::core
@@ -19,15 +19,10 @@ void qqmm_impl(
19
19
  const array& b,
20
20
  const array& a_scale,
21
21
  const array& b_scale,
22
- Dtype out_dtype,
23
22
  QuantizationMode mode,
24
- float alpha) {
25
- // Invoke CublasQQMM
23
+ const GemmScalars& scalars) {
26
24
  std::string qmode = quantization_mode_to_string(mode);
27
25
 
28
- // Currently only supports non-batched QQMM operations
29
- // that covers all use cases for training, we will just collapse (batch,
30
- // seq_len) into (tokens)
31
26
  CublasQQMM qqmm(
32
27
  encoder.device(),
33
28
  a_transposed,
@@ -41,10 +36,22 @@ void qqmm_impl(
41
36
  1, // batch_count
42
37
  0, // a_batch_stride
43
38
  0, // b_batch_stride
44
- out_dtype,
39
+ out.dtype(),
45
40
  qmode);
46
41
 
47
- qqmm.run(encoder, out, a, b, a_scale, b_scale, alpha);
42
+ if (scalars.has_values()) {
43
+ qqmm.run(
44
+ encoder,
45
+ out,
46
+ a,
47
+ b,
48
+ a_scale,
49
+ b_scale,
50
+ *scalars.alpha_device,
51
+ *scalars.beta_device);
52
+ } else {
53
+ qqmm.run(encoder, out, a, b, a_scale, b_scale);
54
+ }
48
55
  }
49
56
 
50
57
  } // namespace mlx::core
@@ -1,10 +1,22 @@
1
- // Copyright © 2026 Apple Inc.
1
+ // Copyright © 2025 Apple Inc.
2
2
  #pragma once
3
3
 
4
4
  #include "mlx/backend/cuda/device.h"
5
5
  #include "mlx/primitives.h"
6
6
 
7
+ #include <optional>
8
+
7
9
  namespace mlx::core {
10
+
11
+ struct GemmScalars {
12
+ std::optional<array> alpha_device;
13
+ std::optional<array> beta_device;
14
+
15
+ bool has_values() const {
16
+ return alpha_device.has_value();
17
+ }
18
+ };
19
+
8
20
  void qqmm_impl(
9
21
  cu::CommandEncoder& encoder,
10
22
  int M,
@@ -19,8 +31,7 @@ void qqmm_impl(
19
31
  const array& b,
20
32
  const array& a_scale,
21
33
  const array& b_scale,
22
- Dtype out_dtype,
23
34
  QuantizationMode mode,
24
- float alpha = 1.0f);
35
+ const GemmScalars& scalars = {});
25
36
 
26
37
  } // namespace mlx::core
@@ -70,6 +70,21 @@ inline std::tuple<dim3, dim3> get_swizzle_launch_args(
70
70
 
71
71
  namespace cu {
72
72
 
73
+ constexpr float F8E4M3_MAX = 448.0f;
74
+ constexpr float F4E2M1_MAX = 6.0f;
75
+
76
+ __global__ void compute_qqmm_pointers(
77
+ float* alpha_out,
78
+ float* beta_out,
79
+ const float* tensor_amax_x,
80
+ const float* tensor_amax_w) {
81
+ // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2
82
+ constexpr float inv_scale_sq =
83
+ 1.0f / (F8E4M3_MAX * F4E2M1_MAX * F8E4M3_MAX * F4E2M1_MAX);
84
+ *alpha_out = (*tensor_amax_x) * (*tensor_amax_w) * inv_scale_sq;
85
+ *beta_out = 0.0f;
86
+ }
87
+
73
88
  __global__ void swizzle_scales(
74
89
  const uint8_t* scales_linear,
75
90
  uint8_t* scales_swizzled,
@@ -224,4 +239,25 @@ void swizzle_scales(
224
239
  output_cols);
225
240
  }
226
241
 
242
+ void compute_qqmm_pointers(
243
+ array& alpha_out,
244
+ array& beta_out,
245
+ const array& tensor_amax_x,
246
+ const array& tensor_amax_w,
247
+ cu::CommandEncoder& enc) {
248
+ enc.set_input_array(tensor_amax_x);
249
+ enc.set_input_array(tensor_amax_w);
250
+ enc.set_output_array(alpha_out);
251
+ enc.set_output_array(beta_out);
252
+ enc.add_kernel_node(
253
+ cu::compute_qqmm_pointers,
254
+ dim3(1),
255
+ dim3(1),
256
+ 0,
257
+ gpu_ptr<void>(alpha_out),
258
+ gpu_ptr<void>(beta_out),
259
+ gpu_ptr<void>(tensor_amax_x),
260
+ gpu_ptr<void>(tensor_amax_w));
261
+ }
262
+
227
263
  } // namespace mlx::core
@@ -0,0 +1,62 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/array.h"
6
+ #include "mlx/backend/cuda/device.h"
7
+
8
+ namespace mlx::core {
9
+
10
+ // Compute padded dimensions for tiled layout
11
+ // Tiles are 128 rows × 4 columns, must allocate full tiles
12
+ inline std::pair<int, int> get_padded_scale_dims(int num_rows, int num_cols) {
13
+ constexpr int rows_per_tile = 128;
14
+ constexpr int cols_per_tile = 4;
15
+
16
+ int padded_rows =
17
+ ((num_rows + rows_per_tile - 1) / rows_per_tile) * rows_per_tile;
18
+ int padded_cols =
19
+ ((num_cols + cols_per_tile - 1) / cols_per_tile) * cols_per_tile;
20
+
21
+ return {padded_rows, padded_cols};
22
+ }
23
+
24
+ void swizzle_scales(
25
+ const array& scales,
26
+ array& scales_tiled,
27
+ cu::CommandEncoder& enc,
28
+ const Stream& s);
29
+
30
+ inline array pad_and_swizzle_scales(
31
+ const array& scale,
32
+ cu::CommandEncoder& encoder,
33
+ const Stream& s) {
34
+ // Compute padded dimensions for full tiles (128 rows × 4 cols)
35
+ auto [pad_outer, pad_inner] =
36
+ get_padded_scale_dims(scale.shape(-2), scale.shape(-1));
37
+ // cuBLAS requirements for scale factor layout:
38
+ // 1. Dimensions must be padded to full tiles (128 rows × 4 cols)
39
+ // 2. Out-of-bounds values must be filled with zeros
40
+ // 3. Starting addresses must be 16-byte aligned
41
+ // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
42
+ // Note: cu::malloc_async already provides 256-byte alignment
43
+ array scale_tiled(
44
+ cu::malloc_async(pad_outer * pad_inner, encoder),
45
+ Shape{pad_outer, pad_inner},
46
+ scale.dtype());
47
+ swizzle_scales(scale, scale_tiled, encoder, s);
48
+
49
+ encoder.add_temporary(scale_tiled);
50
+ return scale_tiled;
51
+ }
52
+
53
+ // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2
54
+ // Allocate beta zero on device as well
55
+ void compute_qqmm_pointers(
56
+ array& alpha_out,
57
+ array& beta_out,
58
+ const array& tensor_amax_x,
59
+ const array& tensor_amax_w,
60
+ cu::CommandEncoder& enc);
61
+
62
+ } // namespace mlx::core
@@ -51,7 +51,6 @@ void fast::Quantize::eval_gpu(
51
51
  auto& s = stream();
52
52
  auto& d = cu::device(s.device);
53
53
  auto& enc = d.get_command_encoder(s);
54
-
55
54
  if (dequantize_) {
56
55
  auto wq = ensure_row_contiguous(inputs[0], enc, s);
57
56
  auto scales = ensure_row_contiguous(inputs[1], enc, s);
@@ -63,7 +62,12 @@ void fast::Quantize::eval_gpu(
63
62
  auto biases = ensure_row_contiguous(inputs[2], enc, s);
64
63
  affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
65
64
  } else {
66
- fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
65
+ // 0 -- xq, 1 -- scales, 2 -- could be global scale for nvfp4
66
+ bool use_global_scale =
67
+ mode_ == QuantizationMode::Nvfp4 && inputs.size() > 2;
68
+ std::optional<array> global_scale =
69
+ use_global_scale ? std::make_optional(inputs[2]) : std::nullopt;
70
+ fp_dequantize(wq, scales, w, group_size_, bits_, global_scale, enc, s);
67
71
  }
68
72
  } else {
69
73
  auto w = ensure_contiguous(inputs[0], enc, s);
@@ -72,12 +76,17 @@ void fast::Quantize::eval_gpu(
72
76
 
73
77
  wq.set_data(cu::malloc_async(wq.nbytes(), enc));
74
78
  scales.set_data(cu::malloc_async(scales.nbytes(), enc));
79
+
75
80
  if (mode_ == QuantizationMode::Affine) {
76
81
  auto& biases = outputs[2];
77
82
  biases.set_data(cu::malloc_async(biases.nbytes(), enc));
78
83
  affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
79
84
  } else {
80
- fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
85
+ bool use_global_scale =
86
+ mode_ == QuantizationMode::Nvfp4 && inputs.size() > 1;
87
+ std::optional<array> global_scale =
88
+ use_global_scale ? std::make_optional(inputs[1]) : std::nullopt;
89
+ fp_quantize(w, wq, scales, group_size_, bits_, global_scale, enc, s);
81
90
  }
82
91
  }
83
92
  }
@@ -1,5 +1,6 @@
1
1
  // Copyright © 2025 Apple Inc.
2
2
 
3
+ #include <optional>
3
4
  #include "mlx/backend/cuda/device.h"
4
5
 
5
6
  namespace mlx::core {
@@ -30,6 +31,7 @@ void fp_quantize(
30
31
  array& scales,
31
32
  int group_size,
32
33
  int bits,
34
+ const std::optional<array>& global_scale,
33
35
  cu::CommandEncoder& enc,
34
36
  const Stream& s);
35
37
 
@@ -39,6 +41,7 @@ void fp_dequantize(
39
41
  array& w,
40
42
  int group_size,
41
43
  int bits,
44
+ const std::optional<array>& global_scale,
42
45
  cu::CommandEncoder& enc,
43
46
  const Stream& s);
44
47
 
@@ -47,6 +50,7 @@ void fp_quantize_dequantize(
47
50
  array& what,
48
51
  int group_size,
49
52
  int bits,
53
+ const std::optional<array>& global_scale,
50
54
  cu::CommandEncoder& enc,
51
55
  const Stream& s);
52
56
 
@@ -29,7 +29,7 @@ inline constexpr __device__ short get_bytes_per_pack() {
29
29
  }
30
30
 
31
31
  template <typename T>
32
- __device__ __forceinline__ void abs_max_x2(T& out, const T& x1, const T& x2) {
32
+ __device__ __forceinline__ void absmax_x2(T& out, const T& x1, const T& x2) {
33
33
  if constexpr (
34
34
  (std::is_same<T, __nv_bfloat162>::value) ||
35
35
  (std::is_same<T, __half2>::value)) {
@@ -247,6 +247,10 @@ void CommandEncoder::set_buffer(
247
247
  const MTL::Buffer* buf,
248
248
  int idx,
249
249
  int64_t offset /* = 0 */) {
250
+ // Record as both input and output to ensure synchronization between command
251
+ // buffers
252
+ all_inputs_.insert((void*)buf);
253
+ all_outputs_.insert((void*)buf);
250
254
  enc_->setBuffer(buf, offset, idx);
251
255
  }
252
256
 
@@ -30,7 +30,7 @@ template <typename T, int N>
30
30
  out_pixels *= params->oS[i];
31
31
 
32
32
  // Set out
33
- out += gid.z * filter_size + gid.y * (params->C);
33
+ out += (size_t)gid.z * filter_size + (size_t)gid.y * (params->C);
34
34
 
35
35
  // Coordinates in input
36
36
  int is[N] = {0};
@@ -93,7 +93,8 @@ template <typename T, int N>
93
93
  out_pixels *= params->oS[i];
94
94
 
95
95
  // Set out
96
- out += gid.z * filter_size + gid.x * (filter_size / params->C);
96
+ out +=
97
+ (size_t)gid.z * filter_size + (size_t)gid.x * (filter_size / params->C);
97
98
 
98
99
  // Coordinates in input
99
100
  int is[N] = {0};
@@ -279,6 +279,8 @@ void extract_state(const T state, std::vector<StateT>& unpacked_state) {
279
279
  unpacked_state.push_back(state);
280
280
  } else if constexpr (std::is_enum_v<T>) {
281
281
  unpacked_state.push_back(static_cast<int>(state));
282
+ } else if constexpr (std::is_same_v<T, Dtype>) {
283
+ unpacked_state.push_back(state);
282
284
  } else if constexpr (is_iterable<T>) {
283
285
  unpacked_state.push_back(state);
284
286
  } else if constexpr (is_pair<T> || is_tuple<T>) {
@@ -446,6 +448,7 @@ struct PrimitiveFactory {
446
448
  SERIALIZE_PRIMITIVE(ScaledDotProductAttention),
447
449
  SERIALIZE_PRIMITIVE(CustomKernel)};
448
450
  std::unordered_map<std::string, std::string> name_remap;
451
+ std::unordered_map<int, Stream> stream_map;
449
452
 
450
453
  PrimitiveFactory() {
451
454
  for (auto& [n, f] : factory) {
@@ -471,13 +474,25 @@ struct PrimitiveFactory {
471
474
  }
472
475
  };
473
476
 
474
- std::shared_ptr<Primitive> load(Reader& is) {
475
- auto stream = deserialize<Stream>(is);
476
- if (get_stream(stream.index) != stream) {
477
- std::ostringstream msg;
478
- msg << "[import_function] Invalid stream encountered " << stream << ".";
479
- throw std::invalid_argument(msg.str());
477
+ Stream resolve_stream(const Stream& stream) {
478
+ if (auto it = stream_map.find(stream.index); it != stream_map.end()) {
479
+ return it->second;
480
+ }
481
+ // Try to find an existing stream on the same device
482
+ for (auto& s : get_streams()) {
483
+ if (s.device == stream.device) {
484
+ stream_map.emplace(stream.index, s);
485
+ return s;
486
+ }
480
487
  }
488
+ // No stream on that device, make a new one
489
+ Stream s = new_stream(stream.device);
490
+ stream_map.emplace(stream.index, s);
491
+ return s;
492
+ }
493
+
494
+ std::shared_ptr<Primitive> load(Reader& is) {
495
+ auto stream = resolve_stream(deserialize<Stream>(is));
481
496
  auto name = deserialize<std::string>(is);
482
497
  if (auto it = factory.find(name); it != factory.end()) {
483
498
  return it->second.deserialize(is, stream);