mlx 0.30.7.2 → 0.30.7.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (605)
  1. checksums.yaml +4 -4
  2. data/ext/mlx/extconf.rb +267 -8
  3. data/ext/mlx/native.cpp +112 -58
  4. data/ext/mlx-onnx/native.cpp +1402 -0
  5. data/ext/mlx-onnx/native.hpp +19 -0
  6. data/lib/mlx/core.rb +342 -117
  7. data/lib/mlx/distributed_utils/common.rb +1 -1
  8. data/lib/mlx/distributed_utils/config.rb +7 -4
  9. data/lib/mlx/distributed_utils/launch.rb +2 -0
  10. data/lib/mlx/dsl/attention.rb +132 -0
  11. data/lib/mlx/dsl/builder.rb +8 -0
  12. data/lib/mlx/dsl/config_schema.rb +133 -0
  13. data/lib/mlx/dsl/generate.rb +193 -0
  14. data/lib/mlx/dsl/kv_cache.rb +96 -0
  15. data/lib/mlx/dsl/masks.rb +32 -0
  16. data/lib/mlx/dsl/positions.rb +35 -0
  17. data/lib/mlx/dsl/run_stack.rb +68 -0
  18. data/lib/mlx/dsl/tensor.rb +126 -0
  19. data/lib/mlx/dsl/transformer_block.rb +113 -0
  20. data/lib/mlx/dsl/weight_map.rb +140 -0
  21. data/lib/mlx/dsl.rb +10 -0
  22. data/lib/mlx/nn/base.rb +4 -0
  23. data/lib/mlx/nn/layers/linear.rb +2 -3
  24. data/lib/mlx/onnx.rb +250 -0
  25. data/lib/mlx/version.rb +1 -1
  26. data/lib/mlx-onnx/webgpu_harness.rb +289 -0
  27. data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.cpp +0 -7
  28. data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.cpp +10 -2
  29. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.cpp +97 -46
  30. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cublas_qqmm.h +25 -13
  31. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/fp_quantize.cu +101 -38
  32. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +1 -2
  33. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm.cpp +193 -0
  34. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.cpp +15 -8
  35. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_impl.h +14 -3
  36. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qqmm_utils.cu +36 -0
  37. data/submodules/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +62 -0
  38. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.cpp +12 -3
  39. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized.h +4 -0
  40. data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.cuh +1 -1
  41. data/{mlx → submodules/mlx}/mlx/backend/metal/device.cpp +4 -0
  42. data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/conv.metal +3 -2
  43. data/{mlx → submodules/mlx}/mlx/export.cpp +21 -6
  44. data/{mlx → submodules/mlx}/mlx/ops.cpp +144 -13
  45. data/{mlx → submodules/mlx}/mlx/ops.h +12 -2
  46. data/{mlx → submodules/mlx}/mlx/primitives.cpp +22 -5
  47. data/{mlx → submodules/mlx}/mlx/scheduler.cpp +4 -0
  48. data/{mlx → submodules/mlx}/mlx/scheduler.h +3 -0
  49. data/{mlx → submodules/mlx}/mlx/stream.h +5 -0
  50. data/submodules/mlx-onnx/CMakeLists.txt +159 -0
  51. data/submodules/mlx-onnx/LICENSE +21 -0
  52. data/submodules/mlx-onnx/include/mlx/ir.hpp +88 -0
  53. data/submodules/mlx-onnx/src/api.cpp +81 -0
  54. data/submodules/mlx-onnx/src/compat.cpp +111 -0
  55. data/submodules/mlx-onnx/src/detail.hpp +69 -0
  56. data/submodules/mlx-onnx/src/export.cpp +653 -0
  57. data/submodules/mlx-onnx/src/io.cpp +61 -0
  58. data/submodules/mlx-onnx/src/json.hpp +25 -0
  59. data/submodules/mlx-onnx/src/lowering.cpp +6346 -0
  60. data/submodules/mlx-onnx/src/mappings.cpp +201 -0
  61. data/submodules/mlx-onnx/src/mappings.hpp +16 -0
  62. data/submodules/mlx-onnx/src/onnx.cpp +1029 -0
  63. data/submodules/mlx-onnx/src/shared.cpp +206 -0
  64. metadata +665 -567
  65. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +0 -158
  66. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +0 -30
  67. /data/{mlx → submodules/mlx}/CMakeLists.txt +0 -0
  68. /data/{mlx → submodules/mlx}/cmake/FindCUDNN.cmake +0 -0
  69. /data/{mlx → submodules/mlx}/cmake/FindNCCL.cmake +0 -0
  70. /data/{mlx → submodules/mlx}/cmake/Findnvpl.cmake +0 -0
  71. /data/{mlx → submodules/mlx}/cmake/extension.cmake +0 -0
  72. /data/{mlx → submodules/mlx}/mlx/3rdparty/.clang-format +0 -0
  73. /data/{mlx → submodules/mlx}/mlx/3rdparty/pocketfft.h +0 -0
  74. /data/{mlx → submodules/mlx}/mlx/CMakeLists.txt +0 -0
  75. /data/{mlx → submodules/mlx}/mlx/allocator.h +0 -0
  76. /data/{mlx → submodules/mlx}/mlx/api.h +0 -0
  77. /data/{mlx → submodules/mlx}/mlx/array.cpp +0 -0
  78. /data/{mlx → submodules/mlx}/mlx/array.h +0 -0
  79. /data/{mlx → submodules/mlx}/mlx/backend/common/CMakeLists.txt +0 -0
  80. /data/{mlx → submodules/mlx}/mlx/backend/common/binary.h +0 -0
  81. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.cpp +0 -0
  82. /data/{mlx → submodules/mlx}/mlx/backend/common/broadcasting.h +0 -0
  83. /data/{mlx → submodules/mlx}/mlx/backend/common/buffer_cache.h +0 -0
  84. /data/{mlx → submodules/mlx}/mlx/backend/common/common.cpp +0 -0
  85. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.cpp +0 -0
  86. /data/{mlx → submodules/mlx}/mlx/backend/common/compiled.h +0 -0
  87. /data/{mlx → submodules/mlx}/mlx/backend/common/copy.h +0 -0
  88. /data/{mlx → submodules/mlx}/mlx/backend/common/hadamard.h +0 -0
  89. /data/{mlx → submodules/mlx}/mlx/backend/common/load.cpp +0 -0
  90. /data/{mlx → submodules/mlx}/mlx/backend/common/matmul.h +0 -0
  91. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.cpp +0 -0
  92. /data/{mlx → submodules/mlx}/mlx/backend/common/reduce.h +0 -0
  93. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.cpp +0 -0
  94. /data/{mlx → submodules/mlx}/mlx/backend/common/slicing.h +0 -0
  95. /data/{mlx → submodules/mlx}/mlx/backend/common/ternary.h +0 -0
  96. /data/{mlx → submodules/mlx}/mlx/backend/common/unary.h +0 -0
  97. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.cpp +0 -0
  98. /data/{mlx → submodules/mlx}/mlx/backend/common/utils.h +0 -0
  99. /data/{mlx → submodules/mlx}/mlx/backend/cpu/CMakeLists.txt +0 -0
  100. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arange.h +0 -0
  101. /data/{mlx → submodules/mlx}/mlx/backend/cpu/arg_reduce.cpp +0 -0
  102. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.cpp +0 -0
  103. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary.h +0 -0
  104. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_ops.h +0 -0
  105. /data/{mlx → submodules/mlx}/mlx/backend/cpu/binary_two.h +0 -0
  106. /data/{mlx → submodules/mlx}/mlx/backend/cpu/cholesky.cpp +0 -0
  107. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled.cpp +0 -0
  108. /data/{mlx → submodules/mlx}/mlx/backend/cpu/compiled_preamble.h +0 -0
  109. /data/{mlx → submodules/mlx}/mlx/backend/cpu/conv.cpp +0 -0
  110. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.cpp +0 -0
  111. /data/{mlx → submodules/mlx}/mlx/backend/cpu/copy.h +0 -0
  112. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.cpp +0 -0
  113. /data/{mlx → submodules/mlx}/mlx/backend/cpu/device_info.h +0 -0
  114. /data/{mlx → submodules/mlx}/mlx/backend/cpu/distributed.cpp +0 -0
  115. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eig.cpp +0 -0
  116. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eigh.cpp +0 -0
  117. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.cpp +0 -0
  118. /data/{mlx → submodules/mlx}/mlx/backend/cpu/encoder.h +0 -0
  119. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.cpp +0 -0
  120. /data/{mlx → submodules/mlx}/mlx/backend/cpu/eval.h +0 -0
  121. /data/{mlx → submodules/mlx}/mlx/backend/cpu/fft.cpp +0 -0
  122. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemm.h +0 -0
  123. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/bnns.cpp +0 -0
  124. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/cblas.cpp +0 -0
  125. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_bf16.cpp +0 -0
  126. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_fp16.cpp +0 -0
  127. /data/{mlx → submodules/mlx}/mlx/backend/cpu/gemms/simd_gemm.h +0 -0
  128. /data/{mlx → submodules/mlx}/mlx/backend/cpu/hadamard.cpp +0 -0
  129. /data/{mlx → submodules/mlx}/mlx/backend/cpu/indexing.cpp +0 -0
  130. /data/{mlx → submodules/mlx}/mlx/backend/cpu/inverse.cpp +0 -0
  131. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.cpp +0 -0
  132. /data/{mlx → submodules/mlx}/mlx/backend/cpu/jit_compiler.h +0 -0
  133. /data/{mlx → submodules/mlx}/mlx/backend/cpu/lapack.h +0 -0
  134. /data/{mlx → submodules/mlx}/mlx/backend/cpu/logsumexp.cpp +0 -0
  135. /data/{mlx → submodules/mlx}/mlx/backend/cpu/luf.cpp +0 -0
  136. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.ps1 +0 -0
  137. /data/{mlx → submodules/mlx}/mlx/backend/cpu/make_compiled_preamble.sh +0 -0
  138. /data/{mlx → submodules/mlx}/mlx/backend/cpu/masked_mm.cpp +0 -0
  139. /data/{mlx → submodules/mlx}/mlx/backend/cpu/matmul.cpp +0 -0
  140. /data/{mlx → submodules/mlx}/mlx/backend/cpu/primitives.cpp +0 -0
  141. /data/{mlx → submodules/mlx}/mlx/backend/cpu/qrf.cpp +0 -0
  142. /data/{mlx → submodules/mlx}/mlx/backend/cpu/quantized.cpp +0 -0
  143. /data/{mlx → submodules/mlx}/mlx/backend/cpu/reduce.cpp +0 -0
  144. /data/{mlx → submodules/mlx}/mlx/backend/cpu/scan.cpp +0 -0
  145. /data/{mlx → submodules/mlx}/mlx/backend/cpu/select.cpp +0 -0
  146. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_fp16_simd.h +0 -0
  147. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/accelerate_simd.h +0 -0
  148. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/base_simd.h +0 -0
  149. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/math.h +0 -0
  150. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/neon_fp16_simd.h +0 -0
  151. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/simd.h +0 -0
  152. /data/{mlx → submodules/mlx}/mlx/backend/cpu/simd/type.h +0 -0
  153. /data/{mlx → submodules/mlx}/mlx/backend/cpu/slicing.h +0 -0
  154. /data/{mlx → submodules/mlx}/mlx/backend/cpu/softmax.cpp +0 -0
  155. /data/{mlx → submodules/mlx}/mlx/backend/cpu/sort.cpp +0 -0
  156. /data/{mlx → submodules/mlx}/mlx/backend/cpu/svd.cpp +0 -0
  157. /data/{mlx → submodules/mlx}/mlx/backend/cpu/ternary.h +0 -0
  158. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.cpp +0 -0
  159. /data/{mlx → submodules/mlx}/mlx/backend/cpu/threefry.h +0 -0
  160. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.cpp +0 -0
  161. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary.h +0 -0
  162. /data/{mlx → submodules/mlx}/mlx/backend/cpu/unary_ops.h +0 -0
  163. /data/{mlx → submodules/mlx}/mlx/backend/cuda/CMakeLists.txt +0 -0
  164. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.cpp +0 -0
  165. /data/{mlx → submodules/mlx}/mlx/backend/cuda/allocator.h +0 -0
  166. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arange.cu +0 -0
  167. /data/{mlx → submodules/mlx}/mlx/backend/cuda/arg_reduce.cu +0 -0
  168. /data/{mlx → submodules/mlx}/mlx/backend/cuda/bin2h.cmake +0 -0
  169. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/CMakeLists.txt +0 -0
  170. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/add.cu +0 -0
  171. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/arctan2.cu +0 -0
  172. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/binary.cuh +0 -0
  173. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/bitwise_binary.cu +0 -0
  174. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/divide.cu +0 -0
  175. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/equal.cu +0 -0
  176. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater.cu +0 -0
  177. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/greater_equal.cu +0 -0
  178. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less.cu +0 -0
  179. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/less_equal.cu +0 -0
  180. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/log_add_exp.cu +0 -0
  181. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_and.cu +0 -0
  182. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/logical_or.cu +0 -0
  183. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/maximum.cu +0 -0
  184. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/minimum.cu +0 -0
  185. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/multiply.cu +0 -0
  186. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/not_equal.cu +0 -0
  187. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/power.cu +0 -0
  188. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/remainder.cu +0 -0
  189. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary/subtract.cu +0 -0
  190. /data/{mlx → submodules/mlx}/mlx/backend/cuda/binary_two.cu +0 -0
  191. /data/{mlx → submodules/mlx}/mlx/backend/cuda/compiled.cpp +0 -0
  192. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/conv.h +0 -0
  193. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_conv.cu +0 -0
  194. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv/gemm_grouped_conv.cu +0 -0
  195. /data/{mlx → submodules/mlx}/mlx/backend/cuda/conv.cpp +0 -0
  196. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy.cuh +0 -0
  197. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_contiguous.cu +0 -0
  198. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general.cu +0 -0
  199. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_dynamic.cu +0 -0
  200. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy/copy_general_input.cu +0 -0
  201. /data/{mlx → submodules/mlx}/mlx/backend/cuda/copy.cu +0 -0
  202. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cublas_utils.h +0 -0
  203. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda.h +0 -0
  204. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cuda_utils.h +0 -0
  205. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.cpp +0 -0
  206. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cudnn_utils.h +0 -0
  207. /data/{mlx → submodules/mlx}/mlx/backend/cuda/custom_kernel.cpp +0 -0
  208. /data/{mlx → submodules/mlx}/mlx/backend/cuda/cutlass_utils.cuh +0 -0
  209. /data/{mlx → submodules/mlx}/mlx/backend/cuda/delayload.cpp +0 -0
  210. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/atomic_ops.cuh +0 -0
  211. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/binary_ops.cuh +0 -0
  212. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/cast_op.cuh +0 -0
  213. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/complex.cuh +0 -0
  214. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/config.h +0 -0
  215. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/fp16_math.cuh +0 -0
  216. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather.cuh +0 -0
  217. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/gather_axis.cuh +0 -0
  218. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/indexing.cuh +0 -0
  219. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter.cuh +0 -0
  220. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_axis.cuh +0 -0
  221. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/scatter_ops.cuh +0 -0
  222. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/ternary_ops.cuh +0 -0
  223. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/unary_ops.cuh +0 -0
  224. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device/utils.cuh +0 -0
  225. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.cpp +0 -0
  226. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device.h +0 -0
  227. /data/{mlx → submodules/mlx}/mlx/backend/cuda/device_info.cpp +0 -0
  228. /data/{mlx → submodules/mlx}/mlx/backend/cuda/distributed.cu +0 -0
  229. /data/{mlx → submodules/mlx}/mlx/backend/cuda/eval.cpp +0 -0
  230. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.cu +0 -0
  231. /data/{mlx → submodules/mlx}/mlx/backend/cuda/event.h +0 -0
  232. /data/{mlx → submodules/mlx}/mlx/backend/cuda/fence.cpp +0 -0
  233. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm.h +0 -0
  234. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +0 -0
  235. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +0 -0
  236. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.cu +0 -0
  237. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/gemv.h +0 -0
  238. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm.h +0 -0
  239. /data/{mlx → submodules/mlx}/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +0 -0
  240. /data/{mlx → submodules/mlx}/mlx/backend/cuda/indexing.cpp +0 -0
  241. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.cpp +0 -0
  242. /data/{mlx → submodules/mlx}/mlx/backend/cuda/jit_module.h +0 -0
  243. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cu +0 -0
  244. /data/{mlx → submodules/mlx}/mlx/backend/cuda/kernel_utils.cuh +0 -0
  245. /data/{mlx → submodules/mlx}/mlx/backend/cuda/layer_norm.cu +0 -0
  246. /data/{mlx → submodules/mlx}/mlx/backend/cuda/load.cpp +0 -0
  247. /data/{mlx → submodules/mlx}/mlx/backend/cuda/logsumexp.cu +0 -0
  248. /data/{mlx → submodules/mlx}/mlx/backend/cuda/lru_cache.h +0 -0
  249. /data/{mlx → submodules/mlx}/mlx/backend/cuda/matmul.cpp +0 -0
  250. /data/{mlx → submodules/mlx}/mlx/backend/cuda/no_cuda.cpp +0 -0
  251. /data/{mlx → submodules/mlx}/mlx/backend/cuda/primitives.cpp +0 -0
  252. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/affine_quantize.cu +0 -0
  253. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/convert_fp8.cu +0 -0
  254. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/cuda_fp4.h +0 -0
  255. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +0 -0
  256. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +0 -0
  257. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.cu +0 -0
  258. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/qmv.h +0 -0
  259. /data/{mlx → submodules/mlx}/mlx/backend/cuda/quantized/quantized_utils.h +0 -0
  260. /data/{mlx → submodules/mlx}/mlx/backend/cuda/random.cu +0 -0
  261. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/all_reduce.cu +0 -0
  262. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/col_reduce.cu +0 -0
  263. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/init_reduce.cu +0 -0
  264. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce.cuh +0 -0
  265. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_ops.cuh +0 -0
  266. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/reduce_utils.cuh +0 -0
  267. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce/row_reduce.cu +0 -0
  268. /data/{mlx → submodules/mlx}/mlx/backend/cuda/reduce.cu +0 -0
  269. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rms_norm.cu +0 -0
  270. /data/{mlx → submodules/mlx}/mlx/backend/cuda/rope.cu +0 -0
  271. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cpp +0 -0
  272. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scaled_dot_product_attention.cu +0 -0
  273. /data/{mlx → submodules/mlx}/mlx/backend/cuda/scan.cu +0 -0
  274. /data/{mlx → submodules/mlx}/mlx/backend/cuda/slicing.cpp +0 -0
  275. /data/{mlx → submodules/mlx}/mlx/backend/cuda/softmax.cu +0 -0
  276. /data/{mlx → submodules/mlx}/mlx/backend/cuda/sort.cu +0 -0
  277. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/defines.cuh +0 -0
  278. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/gemm.cuh +0 -0
  279. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/mma.cuh +0 -0
  280. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/tiles.cuh +0 -0
  281. /data/{mlx → submodules/mlx}/mlx/backend/cuda/steel/utils.cuh +0 -0
  282. /data/{mlx → submodules/mlx}/mlx/backend/cuda/ternary.cu +0 -0
  283. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/CMakeLists.txt +0 -0
  284. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/abs.cu +0 -0
  285. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccos.cu +0 -0
  286. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arccosh.cu +0 -0
  287. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsin.cu +0 -0
  288. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arcsinh.cu +0 -0
  289. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctan.cu +0 -0
  290. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/arctanh.cu +0 -0
  291. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/bitwise_invert.cu +0 -0
  292. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/ceil.cu +0 -0
  293. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/conjugate.cu +0 -0
  294. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cos.cu +0 -0
  295. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/cosh.cu +0 -0
  296. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf.cu +0 -0
  297. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/erf_inv.cu +0 -0
  298. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/exp.cu +0 -0
  299. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/expm1.cu +0 -0
  300. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/floor.cu +0 -0
  301. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/imag.cu +0 -0
  302. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log.cu +0 -0
  303. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/log1p.cu +0 -0
  304. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/logical_not.cu +0 -0
  305. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/negative.cu +0 -0
  306. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/real.cu +0 -0
  307. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/round.cu +0 -0
  308. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sigmoid.cu +0 -0
  309. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sign.cu +0 -0
  310. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sin.cu +0 -0
  311. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sinh.cu +0 -0
  312. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/sqrt.cu +0 -0
  313. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/square.cu +0 -0
  314. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tan.cu +0 -0
  315. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/tanh.cu +0 -0
  316. /data/{mlx → submodules/mlx}/mlx/backend/cuda/unary/unary.cuh +0 -0
  317. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.cpp +0 -0
  318. /data/{mlx → submodules/mlx}/mlx/backend/cuda/utils.h +0 -0
  319. /data/{mlx → submodules/mlx}/mlx/backend/cuda/vector_types.cuh +0 -0
  320. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.cpp +0 -0
  321. /data/{mlx → submodules/mlx}/mlx/backend/cuda/worker.h +0 -0
  322. /data/{mlx → submodules/mlx}/mlx/backend/gpu/CMakeLists.txt +0 -0
  323. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.cpp +0 -0
  324. /data/{mlx → submodules/mlx}/mlx/backend/gpu/copy.h +0 -0
  325. /data/{mlx → submodules/mlx}/mlx/backend/gpu/device_info.h +0 -0
  326. /data/{mlx → submodules/mlx}/mlx/backend/gpu/eval.h +0 -0
  327. /data/{mlx → submodules/mlx}/mlx/backend/gpu/primitives.cpp +0 -0
  328. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.cpp +0 -0
  329. /data/{mlx → submodules/mlx}/mlx/backend/gpu/slicing.h +0 -0
  330. /data/{mlx → submodules/mlx}/mlx/backend/metal/CMakeLists.txt +0 -0
  331. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.cpp +0 -0
  332. /data/{mlx → submodules/mlx}/mlx/backend/metal/allocator.h +0 -0
  333. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.cpp +0 -0
  334. /data/{mlx → submodules/mlx}/mlx/backend/metal/binary.h +0 -0
  335. /data/{mlx → submodules/mlx}/mlx/backend/metal/compiled.cpp +0 -0
  336. /data/{mlx → submodules/mlx}/mlx/backend/metal/conv.cpp +0 -0
  337. /data/{mlx → submodules/mlx}/mlx/backend/metal/copy.cpp +0 -0
  338. /data/{mlx → submodules/mlx}/mlx/backend/metal/custom_kernel.cpp +0 -0
  339. /data/{mlx → submodules/mlx}/mlx/backend/metal/device.h +0 -0
  340. /data/{mlx → submodules/mlx}/mlx/backend/metal/device_info.cpp +0 -0
  341. /data/{mlx → submodules/mlx}/mlx/backend/metal/distributed.cpp +0 -0
  342. /data/{mlx → submodules/mlx}/mlx/backend/metal/eval.cpp +0 -0
  343. /data/{mlx → submodules/mlx}/mlx/backend/metal/event.cpp +0 -0
  344. /data/{mlx → submodules/mlx}/mlx/backend/metal/fence.cpp +0 -0
  345. /data/{mlx → submodules/mlx}/mlx/backend/metal/fft.cpp +0 -0
  346. /data/{mlx → submodules/mlx}/mlx/backend/metal/hadamard.cpp +0 -0
  347. /data/{mlx → submodules/mlx}/mlx/backend/metal/indexing.cpp +0 -0
  348. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/includes.h +0 -0
  349. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit/indexing.h +0 -0
  350. /data/{mlx → submodules/mlx}/mlx/backend/metal/jit_kernels.cpp +0 -0
  351. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/CMakeLists.txt +0 -0
  352. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.h +0 -0
  353. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arange.metal +0 -0
  354. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/arg_reduce.metal +0 -0
  355. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/atomic.h +0 -0
  356. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16.h +0 -0
  357. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/bf16_math.h +0 -0
  358. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.h +0 -0
  359. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary.metal +0 -0
  360. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_ops.h +0 -0
  361. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.h +0 -0
  362. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/binary_two.metal +0 -0
  363. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/cexpf.h +0 -0
  364. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/complex.h +0 -0
  365. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.h +0 -0
  366. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/copy.metal +0 -0
  367. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/defines.h +0 -0
  368. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/erf.h +0 -0
  369. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/expm1f.h +0 -0
  370. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fence.metal +0 -0
  371. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/radix.h +0 -0
  372. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft/readwrite.h +0 -0
  373. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.h +0 -0
  374. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fft.metal +0 -0
  375. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp4.h +0 -0
  376. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp8.h +0 -0
  377. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.h +0 -0
  378. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized.metal +0 -0
  379. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.h +0 -0
  380. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/fp_quantized_nax.metal +0 -0
  381. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv.metal +0 -0
  382. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.h +0 -0
  383. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/gemv_masked.metal +0 -0
  384. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/hadamard.h +0 -0
  385. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather.h +0 -0
  386. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_axis.h +0 -0
  387. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/gather_front.h +0 -0
  388. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/indexing.h +0 -0
  389. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/masked_scatter.h +0 -0
  390. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter.h +0 -0
  391. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/indexing/scatter_axis.h +0 -0
  392. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/layer_norm.metal +0 -0
  393. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logging.h +0 -0
  394. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.h +0 -0
  395. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/logsumexp.metal +0 -0
  396. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.h +0 -0
  397. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized.metal +0 -0
  398. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.h +0 -0
  399. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_nax.metal +0 -0
  400. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/quantized_utils.h +0 -0
  401. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/random.metal +0 -0
  402. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.h +0 -0
  403. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce.metal +0 -0
  404. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduce_utils.h +0 -0
  405. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/ops.h +0 -0
  406. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_all.h +0 -0
  407. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_col.h +0 -0
  408. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_init.h +0 -0
  409. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/reduction/reduce_row.h +0 -0
  410. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rms_norm.metal +0 -0
  411. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/rope.metal +0 -0
  412. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +0 -0
  413. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.h +0 -0
  414. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/scan.metal +0 -0
  415. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sdpa_vector.h +0 -0
  416. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.h +0 -0
  417. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/softmax.metal +0 -0
  418. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.h +0 -0
  419. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/sort.metal +0 -0
  420. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/attn.h +0 -0
  421. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +0 -0
  422. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +0 -0
  423. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +0 -0
  424. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +0 -0
  425. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/loader.h +0 -0
  426. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/mma.h +0 -0
  427. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/nax.h +0 -0
  428. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/params.h +0 -0
  429. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/attn/transforms.h +0 -0
  430. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/conv.h +0 -0
  431. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +0 -0
  432. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +0 -0
  433. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +0 -0
  434. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +0 -0
  435. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loader.h +0 -0
  436. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +0 -0
  437. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +0 -0
  438. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +0 -0
  439. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/conv/params.h +0 -0
  440. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/defines.h +0 -0
  441. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm.h +0 -0
  442. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +0 -0
  443. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +0 -0
  444. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +0 -0
  445. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +0 -0
  446. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +0 -0
  447. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +0 -0
  448. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +0 -0
  449. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +0 -0
  450. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +0 -0
  451. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +0 -0
  452. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +0 -0
  453. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +0 -0
  454. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +0 -0
  455. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +0 -0
  456. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +0 -0
  457. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +0 -0
  458. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +0 -0
  459. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/loader.h +0 -0
  460. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/mma.h +0 -0
  461. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/nax.h +0 -0
  462. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/params.h +0 -0
  463. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/gemm/transforms.h +0 -0
  464. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/integral_constant.h +0 -0
  465. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils/type_traits.h +0 -0
  466. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/steel/utils.h +0 -0
  467. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.h +0 -0
  468. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary.metal +0 -0
  469. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/ternary_ops.h +0 -0
  470. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.h +0 -0
  471. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary.metal +0 -0
  472. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/unary_ops.h +0 -0
  473. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels/utils.h +0 -0
  474. /data/{mlx → submodules/mlx}/mlx/backend/metal/kernels.h +0 -0
  475. /data/{mlx → submodules/mlx}/mlx/backend/metal/logsumexp.cpp +0 -0
  476. /data/{mlx → submodules/mlx}/mlx/backend/metal/make_compiled_preamble.sh +0 -0
  477. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.cpp +0 -0
  478. /data/{mlx → submodules/mlx}/mlx/backend/metal/matmul.h +0 -0
  479. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.cpp +0 -0
  480. /data/{mlx → submodules/mlx}/mlx/backend/metal/metal.h +0 -0
  481. /data/{mlx → submodules/mlx}/mlx/backend/metal/no_metal.cpp +0 -0
  482. /data/{mlx → submodules/mlx}/mlx/backend/metal/nojit_kernels.cpp +0 -0
  483. /data/{mlx → submodules/mlx}/mlx/backend/metal/normalization.cpp +0 -0
  484. /data/{mlx → submodules/mlx}/mlx/backend/metal/primitives.cpp +0 -0
  485. /data/{mlx → submodules/mlx}/mlx/backend/metal/quantized.cpp +0 -0
  486. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.cpp +0 -0
  487. /data/{mlx → submodules/mlx}/mlx/backend/metal/reduce.h +0 -0
  488. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.cpp +0 -0
  489. /data/{mlx → submodules/mlx}/mlx/backend/metal/resident.h +0 -0
  490. /data/{mlx → submodules/mlx}/mlx/backend/metal/rope.cpp +0 -0
  491. /data/{mlx → submodules/mlx}/mlx/backend/metal/scaled_dot_product_attention.cpp +0 -0
  492. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.cpp +0 -0
  493. /data/{mlx → submodules/mlx}/mlx/backend/metal/scan.h +0 -0
  494. /data/{mlx → submodules/mlx}/mlx/backend/metal/slicing.cpp +0 -0
  495. /data/{mlx → submodules/mlx}/mlx/backend/metal/softmax.cpp +0 -0
  496. /data/{mlx → submodules/mlx}/mlx/backend/metal/sort.cpp +0 -0
  497. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.cpp +0 -0
  498. /data/{mlx → submodules/mlx}/mlx/backend/metal/ternary.h +0 -0
  499. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.cpp +0 -0
  500. /data/{mlx → submodules/mlx}/mlx/backend/metal/unary.h +0 -0
  501. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.cpp +0 -0
  502. /data/{mlx → submodules/mlx}/mlx/backend/metal/utils.h +0 -0
  503. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/CMakeLists.txt +0 -0
  504. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/compiled.cpp +0 -0
  505. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/device_info.cpp +0 -0
  506. /data/{mlx → submodules/mlx}/mlx/backend/no_cpu/primitives.cpp +0 -0
  507. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/CMakeLists.txt +0 -0
  508. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/allocator.cpp +0 -0
  509. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/apple_memory.h +0 -0
  510. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/device_info.cpp +0 -0
  511. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/eval.cpp +0 -0
  512. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/event.cpp +0 -0
  513. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/fence.cpp +0 -0
  514. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/linux_memory.h +0 -0
  515. /data/{mlx → submodules/mlx}/mlx/backend/no_gpu/primitives.cpp +0 -0
  516. /data/{mlx → submodules/mlx}/mlx/compile.cpp +0 -0
  517. /data/{mlx → submodules/mlx}/mlx/compile.h +0 -0
  518. /data/{mlx → submodules/mlx}/mlx/compile_impl.h +0 -0
  519. /data/{mlx → submodules/mlx}/mlx/device.cpp +0 -0
  520. /data/{mlx → submodules/mlx}/mlx/device.h +0 -0
  521. /data/{mlx → submodules/mlx}/mlx/distributed/CMakeLists.txt +0 -0
  522. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.cpp +0 -0
  523. /data/{mlx → submodules/mlx}/mlx/distributed/distributed.h +0 -0
  524. /data/{mlx → submodules/mlx}/mlx/distributed/distributed_impl.h +0 -0
  525. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/CMakeLists.txt +0 -0
  526. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.cpp +0 -0
  527. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/jaccl.h +0 -0
  528. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.cpp +0 -0
  529. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/mesh.h +0 -0
  530. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/no_jaccl.cpp +0 -0
  531. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.cpp +0 -0
  532. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/ring.h +0 -0
  533. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.cpp +0 -0
  534. /data/{mlx → submodules/mlx}/mlx/distributed/jaccl/utils.h +0 -0
  535. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/CMakeLists.txt +0 -0
  536. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.cpp +0 -0
  537. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi.h +0 -0
  538. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/mpi_declarations.h +0 -0
  539. /data/{mlx → submodules/mlx}/mlx/distributed/mpi/no_mpi.cpp +0 -0
  540. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/CMakeLists.txt +0 -0
  541. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.cpp +0 -0
  542. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl.h +0 -0
  543. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +0 -0
  544. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +0 -0
  545. /data/{mlx → submodules/mlx}/mlx/distributed/nccl/no_nccl.cpp +0 -0
  546. /data/{mlx → submodules/mlx}/mlx/distributed/ops.cpp +0 -0
  547. /data/{mlx → submodules/mlx}/mlx/distributed/ops.h +0 -0
  548. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.cpp +0 -0
  549. /data/{mlx → submodules/mlx}/mlx/distributed/primitives.h +0 -0
  550. /data/{mlx → submodules/mlx}/mlx/distributed/reduction_ops.h +0 -0
  551. /data/{mlx → submodules/mlx}/mlx/distributed/ring/CMakeLists.txt +0 -0
  552. /data/{mlx → submodules/mlx}/mlx/distributed/ring/no_ring.cpp +0 -0
  553. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.cpp +0 -0
  554. /data/{mlx → submodules/mlx}/mlx/distributed/ring/ring.h +0 -0
  555. /data/{mlx → submodules/mlx}/mlx/distributed/utils.cpp +0 -0
  556. /data/{mlx → submodules/mlx}/mlx/distributed/utils.h +0 -0
  557. /data/{mlx → submodules/mlx}/mlx/dtype.cpp +0 -0
  558. /data/{mlx → submodules/mlx}/mlx/dtype.h +0 -0
  559. /data/{mlx → submodules/mlx}/mlx/dtype_utils.cpp +0 -0
  560. /data/{mlx → submodules/mlx}/mlx/dtype_utils.h +0 -0
  561. /data/{mlx → submodules/mlx}/mlx/einsum.cpp +0 -0
  562. /data/{mlx → submodules/mlx}/mlx/einsum.h +0 -0
  563. /data/{mlx → submodules/mlx}/mlx/event.h +0 -0
  564. /data/{mlx → submodules/mlx}/mlx/export.h +0 -0
  565. /data/{mlx → submodules/mlx}/mlx/export_impl.h +0 -0
  566. /data/{mlx → submodules/mlx}/mlx/fast.cpp +0 -0
  567. /data/{mlx → submodules/mlx}/mlx/fast.h +0 -0
  568. /data/{mlx → submodules/mlx}/mlx/fast_primitives.h +0 -0
  569. /data/{mlx → submodules/mlx}/mlx/fence.h +0 -0
  570. /data/{mlx → submodules/mlx}/mlx/fft.cpp +0 -0
  571. /data/{mlx → submodules/mlx}/mlx/fft.h +0 -0
  572. /data/{mlx → submodules/mlx}/mlx/graph_utils.cpp +0 -0
  573. /data/{mlx → submodules/mlx}/mlx/graph_utils.h +0 -0
  574. /data/{mlx → submodules/mlx}/mlx/io/CMakeLists.txt +0 -0
  575. /data/{mlx → submodules/mlx}/mlx/io/gguf.cpp +0 -0
  576. /data/{mlx → submodules/mlx}/mlx/io/gguf.h +0 -0
  577. /data/{mlx → submodules/mlx}/mlx/io/gguf_quants.cpp +0 -0
  578. /data/{mlx → submodules/mlx}/mlx/io/load.cpp +0 -0
  579. /data/{mlx → submodules/mlx}/mlx/io/load.h +0 -0
  580. /data/{mlx → submodules/mlx}/mlx/io/no_gguf.cpp +0 -0
  581. /data/{mlx → submodules/mlx}/mlx/io/no_safetensors.cpp +0 -0
  582. /data/{mlx → submodules/mlx}/mlx/io/safetensors.cpp +0 -0
  583. /data/{mlx → submodules/mlx}/mlx/io.h +0 -0
  584. /data/{mlx → submodules/mlx}/mlx/linalg.cpp +0 -0
  585. /data/{mlx → submodules/mlx}/mlx/linalg.h +0 -0
  586. /data/{mlx → submodules/mlx}/mlx/memory.h +0 -0
  587. /data/{mlx → submodules/mlx}/mlx/mlx.h +0 -0
  588. /data/{mlx → submodules/mlx}/mlx/primitives.h +0 -0
  589. /data/{mlx → submodules/mlx}/mlx/random.cpp +0 -0
  590. /data/{mlx → submodules/mlx}/mlx/random.h +0 -0
  591. /data/{mlx → submodules/mlx}/mlx/small_vector.h +0 -0
  592. /data/{mlx → submodules/mlx}/mlx/threadpool.h +0 -0
  593. /data/{mlx → submodules/mlx}/mlx/transforms.cpp +0 -0
  594. /data/{mlx → submodules/mlx}/mlx/transforms.h +0 -0
  595. /data/{mlx → submodules/mlx}/mlx/transforms_impl.h +0 -0
  596. /data/{mlx → submodules/mlx}/mlx/types/bf16.h +0 -0
  597. /data/{mlx → submodules/mlx}/mlx/types/complex.h +0 -0
  598. /data/{mlx → submodules/mlx}/mlx/types/fp16.h +0 -0
  599. /data/{mlx → submodules/mlx}/mlx/types/half_types.h +0 -0
  600. /data/{mlx → submodules/mlx}/mlx/types/limits.h +0 -0
  601. /data/{mlx → submodules/mlx}/mlx/utils.cpp +0 -0
  602. /data/{mlx → submodules/mlx}/mlx/utils.h +0 -0
  603. /data/{mlx → submodules/mlx}/mlx/version.cpp +0 -0
  604. /data/{mlx → submodules/mlx}/mlx/version.h +0 -0
  605. /data/{mlx → submodules/mlx}/mlx.pc.in +0 -0
@@ -11,6 +11,11 @@
11
11
 
12
12
  #include <cooperative_groups.h>
13
13
  #include <cooperative_groups/reduce.h>
14
+ #include <cuda_fp4.h>
15
+ #include <cuda_fp8.h>
16
+
17
+ constexpr float F8E4M3_MAX = 448.0f;
18
+ constexpr float F4E2M1_MAX = 6.0f;
14
19
 
15
20
  namespace mlx::core {
16
21
  namespace cu {
@@ -29,7 +34,16 @@ struct Dequantize {
29
34
  namespace cg = cooperative_groups;
30
35
 
31
36
  template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
32
- __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) {
37
+ __global__ void fp_quantize_dequantize(
38
+ T* w,
39
+ T* out,
40
+ size_t size,
41
+ float* global_scale = nullptr) {
42
+ const bool use_global_scale = global_scale != nullptr;
43
+ const float scale_enc =
44
+ use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;
45
+ const float inv_scale_enc = use_global_scale ? 1.0f / scale_enc : 1.0f;
46
+
33
47
  using Tx2 = Vector2_t<T>;
34
48
  using Tx4 = Vector4_t<T>;
35
49
  uint32_t rbits = 0; // reserved bits for future use
@@ -48,26 +62,28 @@ __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) {
48
62
  }
49
63
 
50
64
  auto w_tile = load_vector<group_size, T>(w, thread_idx);
51
- float scale = 0.0f;
65
+ float scale_dec_b = 0.0f;
52
66
 
53
67
  Tx2 amax_2x = Tx2{0.0f, 0.0f};
54
68
 
55
69
  #pragma unroll
56
70
  for (int i = 0; i < group_size; i += 2) {
57
71
  auto pair = Tx2{w_tile[i], w_tile[i + 1]};
58
- abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
72
+ absmax_x2<Tx2>(amax_2x, amax_2x, pair);
59
73
  }
60
74
 
61
- scale = static_cast<float>(
75
+ scale_dec_b = static_cast<float>(
62
76
  max(fabsf(static_cast<float>(amax_2x.x)),
63
77
  fabsf(static_cast<float>(amax_2x.y))));
64
78
 
65
- scale /= bits == 4 ? 6.0f : 448.0f;
79
+ scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX;
80
+ scale_dec_b *= scale_enc;
66
81
  // Convert to mx scale or nv scale
67
82
  using ScaleType =
68
83
  std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
69
- auto s = ScaleType(scale);
70
- scale = float(s);
84
+ auto s = ScaleType(scale_dec_b);
85
+ float scale_enc_b = scale_enc / float(s);
86
+ float scale_dec = float(s) * inv_scale_enc;
71
87
  AlignedVector<T, group_size> w_hat;
72
88
 
73
89
  #pragma unroll
@@ -76,24 +92,36 @@ __global__ void fp_quantize_dequantize(T* w, T* out, size_t size) {
76
92
  float4 dq;
77
93
  if constexpr (bits == 8) {
78
94
  uint32_t quantized_val =
79
- scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
95
+ scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
80
96
  dq = dequant_fp8(quantized_val);
81
97
  } else {
82
98
  uint16_t quantized_val =
83
- scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
99
+ scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
84
100
  dq = dequant_fp4(quantized_val);
85
101
  }
86
- w_hat[i * 4] = static_cast<T>(dq.x * scale);
87
- w_hat[i * 4 + 1] = static_cast<T>(dq.y * scale);
88
- w_hat[i * 4 + 2] = static_cast<T>(dq.z * scale);
89
- w_hat[i * 4 + 3] = static_cast<T>(dq.w * scale);
102
+ w_hat[i * 4] = static_cast<T>(dq.x * scale_dec);
103
+ w_hat[i * 4 + 1] = static_cast<T>(dq.y * scale_dec);
104
+ w_hat[i * 4 + 2] = static_cast<T>(dq.z * scale_dec);
105
+ w_hat[i * 4 + 3] = static_cast<T>(dq.w * scale_dec);
90
106
  }
91
107
  store_vector<group_size>(out, thread_idx, w_hat);
92
108
  }
93
109
 
94
110
  template <typename T, int group_size, int bits, bool use_mx_scale, bool USE_SR>
95
- __global__ void
96
- fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) {
111
+ __global__ void fp_quantize_rowwise(
112
+ T* w,
113
+ uint8_t* out,
114
+ uint8_t* scales,
115
+ size_t size,
116
+ float* global_scale = nullptr) {
117
+ // NVFP4 conversion:
118
+ // Global encode scale: (448 × 6) / *global_scale
119
+ // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8
120
+ // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b
121
+ const bool use_global_scale = global_scale != nullptr;
122
+ const float scale_enc =
123
+ use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;
124
+
97
125
  using Tx2 = Vector2_t<T>;
98
126
  using Tx4 = Vector4_t<T>;
99
127
  uint32_t rbits = 0; // reserved bits for future use
@@ -112,27 +140,28 @@ fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) {
112
140
  }
113
141
 
114
142
  auto w_tile = load_vector<group_size, T>(w, thread_idx);
115
- float scale = 0.0f;
143
+ float scale_dec_b = 0.0f;
116
144
 
117
145
  Tx2 amax_2x = Tx2{0.0f, 0.0f};
118
146
 
119
147
  #pragma unroll
120
148
  for (int i = 0; i < group_size; i += 2) {
121
149
  auto pair = Tx2{w_tile[i], w_tile[i + 1]};
122
- abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
150
+ absmax_x2<Tx2>(amax_2x, amax_2x, pair);
123
151
  }
124
152
 
125
- scale = static_cast<float>(
153
+ scale_dec_b = static_cast<float>(
126
154
  max(fabsf(static_cast<float>(amax_2x.x)),
127
155
  fabsf(static_cast<float>(amax_2x.y))));
128
156
 
129
- scale /= bits == 4 ? 6.0f : 448.0f;
157
+ scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX;
158
+ scale_dec_b *= scale_enc;
130
159
  // Convert to mx scale or nv scale
131
160
  using ScaleType =
132
161
  std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
133
- auto s = ScaleType(scale);
162
+ auto s = ScaleType(scale_dec_b);
134
163
  uint8_t q_scale = s.__x;
135
- scale = float(s);
164
+ float scale_enc_b = scale_enc / float(s);
136
165
 
137
166
  scales[thread_idx] = q_scale;
138
167
  constexpr int elem_per_byte = bits == 8 ? 1 : 2;
@@ -143,11 +172,11 @@ fp_quantize_rowwise(T* w, uint8_t* out, uint8_t* scales, size_t size) {
143
172
  Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&w_tile[i * 4]);
144
173
  if constexpr (bits == 8) {
145
174
  uint32_t quantized_val =
146
- scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
175
+ scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
147
176
  *reinterpret_cast<uint32_t*>(&quantized[i * 4]) = quantized_val;
148
177
  } else {
149
178
  uint16_t quantized_val =
150
- scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
179
+ scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
151
180
  *reinterpret_cast<uint16_t*>(&quantized[i * 2]) = quantized_val;
152
181
  }
153
182
  }
@@ -161,11 +190,15 @@ __global__ void fp_quantize_columnwise(
161
190
  uint8_t* scales,
162
191
  size_t size,
163
192
  int M,
164
- int K) {
193
+ int K,
194
+ float* global_scale = nullptr) {
165
195
  // Input: [M, K] with strides [1, M] (M-major)
166
196
  // Quantized output: [M, K/elem_per_byte] row-major (K-major)
167
197
  // Scales: [M, K/group_size] row-major (K-major)
168
198
  // Quantize along K (last dimension, groups of group_size elements)
199
+ const bool use_global_scale = global_scale != nullptr;
200
+ const float scale_enc =
201
+ use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f;
169
202
 
170
203
  using Tx2 = Vector2_t<T>;
171
204
  using Tx4 = Vector4_t<T>;
@@ -215,16 +248,18 @@ __global__ void fp_quantize_columnwise(
215
248
  #pragma unroll
216
249
  for (int r = 0; r < group_size; r += 2) {
217
250
  auto pair = Tx2{thread_data[r], thread_data[r + 1]};
218
- abs_max_x2<Tx2>(amax_2x, amax_2x, pair);
251
+ absmax_x2<Tx2>(amax_2x, amax_2x, pair);
219
252
  }
220
- float scale =
253
+ float scale_dec_b =
221
254
  max(fabsf(static_cast<float>(amax_2x.x)),
222
255
  fabsf(static_cast<float>(amax_2x.y)));
223
- scale /= (bits == 4) ? 6.0f : 448.0f;
256
+ scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX;
257
+ scale_dec_b *= scale_enc;
258
+ // Convert to mx scale or nv scale
224
259
  using ScaleType =
225
260
  std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
226
- auto s = ScaleType(scale);
227
- scale = float(s);
261
+ auto s = ScaleType(scale_dec_b);
262
+ float scale_enc_b = scale_enc / float(s);
228
263
  scales_smem[tidx][tidy] = s.__x;
229
264
 
230
265
  int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group;
@@ -234,12 +269,12 @@ __global__ void fp_quantize_columnwise(
234
269
  Tx4 w_Tx4 = *reinterpret_cast<Tx4*>(&thread_data[j * 4]);
235
270
  if constexpr (bits == 8) {
236
271
  uint32_t quantized_val =
237
- scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
272
+ scale_cvt_Tx4_to_fp8x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
238
273
  *reinterpret_cast<uint32_t*>(&quantized_smem[shared_idx + j * 4]) =
239
274
  quantized_val;
240
275
  } else {
241
276
  uint16_t quantized_val =
242
- scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, 1.0f / scale, rbits);
277
+ scale_cvt_Tx4_to_fp4x4<T, USE_SR>(w_Tx4, scale_enc_b, rbits);
243
278
  *reinterpret_cast<uint16_t*>(&quantized_smem[shared_idx + j * 2]) =
244
279
  quantized_val;
245
280
  }
@@ -282,8 +317,12 @@ __global__ void fp_quantize_columnwise(
282
317
  }
283
318
 
284
319
  template <typename T, int group_size, int bits, bool use_mx_scale>
285
- __global__ void
286
- fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
320
+ __global__ void fp_dequantize(
321
+ const uint8_t* w,
322
+ const uint8_t* scales,
323
+ T* out,
324
+ size_t size,
325
+ float* global_scale = nullptr) {
287
326
  auto block_size = cg::this_thread_block().dim_threads();
288
327
  auto block_idx = cg::this_thread_block().group_index();
289
328
  auto idx_in_block = cg::this_thread_block().thread_index();
@@ -294,6 +333,10 @@ fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
294
333
  auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;
295
334
 
296
335
  constexpr int pack_factor = bits == 8 ? 1 : 2;
336
+ const bool use_global_scale = global_scale != nullptr;
337
+ const float inv_scale_enc = use_mx_scale
338
+ ? 1.0f
339
+ : (use_global_scale ? (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f);
297
340
  size_t offset = tidx + grid_dim_x * size_t(tidy);
298
341
  size_t oindex = offset * pack_factor;
299
342
 
@@ -304,7 +347,7 @@ fp_dequantize(const uint8_t* w, const uint8_t* scales, T* out, size_t size) {
304
347
  size_t gindex = oindex / group_size;
305
348
  using ScaleType =
306
349
  std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
307
- auto scale = float(((ScaleType*)(scales))[gindex]);
350
+ auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc;
308
351
 
309
352
  out += oindex;
310
353
 
@@ -346,9 +389,13 @@ void fp_quantize_dequantize(
346
389
  array& what,
347
390
  int group_size,
348
391
  int bits,
392
+ const std::optional<array>& global_scale /* = std::nullopt */,
349
393
  cu::CommandEncoder& enc,
350
394
  const Stream& s) {
351
395
  enc.set_input_array(w);
396
+ if (global_scale.has_value()) {
397
+ enc.set_input_array(global_scale.value());
398
+ }
352
399
  enc.set_output_array(what);
353
400
  dispatch_float_types(w.dtype(), "fp_quantize_dequantize", [&](auto type_tag) {
354
401
  using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -370,7 +417,9 @@ void fp_quantize_dequantize(
370
417
  0,
371
418
  gpu_ptr<T>(w),
372
419
  gpu_ptr<T>(what),
373
- w.size());
420
+ w.size(),
421
+ global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
422
+ : nullptr);
374
423
  }
375
424
  });
376
425
  }
@@ -381,9 +430,13 @@ void fp_quantize(
381
430
  array& scales,
382
431
  int group_size,
383
432
  int bits,
433
+ const std::optional<array>& global_scale /* = std::nullopt */,
384
434
  cu::CommandEncoder& enc,
385
435
  const Stream& s) {
386
436
  enc.set_input_array(w);
437
+ if (global_scale.has_value()) {
438
+ enc.set_input_array(global_scale.value());
439
+ }
387
440
  enc.set_output_array(wq);
388
441
  enc.set_output_array(scales);
389
442
  if (w.strides().back() != 1) {
@@ -410,7 +463,9 @@ void fp_quantize(
410
463
  gpu_ptr<uint8_t>(scales),
411
464
  w.size(),
412
465
  M,
413
- K);
466
+ K,
467
+ global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
468
+ : nullptr);
414
469
  } else {
415
470
  throw std::runtime_error(
416
471
  "[Quantize::eval_gpu] Can not quantize input with type float64.");
@@ -438,7 +493,9 @@ void fp_quantize(
438
493
  gpu_ptr<T>(w),
439
494
  gpu_ptr<uint8_t>(wq),
440
495
  gpu_ptr<uint8_t>(scales),
441
- w.size());
496
+ w.size(),
497
+ global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
498
+ : nullptr);
442
499
  } else {
443
500
  throw std::runtime_error(
444
501
  "[Quantize::eval_gpu] Can not quantize input with type float64.");
@@ -453,6 +510,7 @@ void fp_dequantize(
453
510
  array& w,
454
511
  int group_size,
455
512
  int bits,
513
+ const std::optional<array>& global_scale /* = std::nullopt */,
456
514
  cu::CommandEncoder& enc,
457
515
  const Stream& s) {
458
516
  constexpr int uint8_per_uint32 = 4;
@@ -465,6 +523,9 @@ void fp_dequantize(
465
523
 
466
524
  enc.set_input_array(wq);
467
525
  enc.set_input_array(scales);
526
+ if (global_scale.has_value()) {
527
+ enc.set_input_array(global_scale.value());
528
+ }
468
529
  enc.set_output_array(w);
469
530
  dispatch_float_types(w.dtype(), "fp_dequantize", [&](auto type_tag) {
470
531
  using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
@@ -485,7 +546,9 @@ void fp_dequantize(
485
546
  gpu_ptr<uint8_t>(wq),
486
547
  gpu_ptr<uint8_t>(scales),
487
548
  gpu_ptr<T>(w),
488
- w.size());
549
+ w.size(),
550
+ global_scale.has_value() ? gpu_ptr<float>(global_scale.value())
551
+ : nullptr);
489
552
  } else {
490
553
  throw std::runtime_error(
491
554
  "[Quantize::eval_gpu] Can not dequantize to output with type float64.");
@@ -17,9 +17,8 @@ void qqmm_impl(
17
17
  const array&,
18
18
  const array&,
19
19
  const array&,
20
- Dtype,
21
20
  QuantizationMode,
22
- float) {
21
+ const GemmScalars&) {
23
22
  throw std::runtime_error(
24
23
  "[QQMatmul::eval_gpu] QQMM is only supported with CUDA 12.8 or higher.");
25
24
  }
@@ -0,0 +1,193 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/device.h"
4
+ #include "mlx/backend/cuda/quantized/qmv.h"
5
+ #include "mlx/backend/cuda/quantized/qqmm_impl.h"
6
+ #include "mlx/backend/cuda/quantized/qqmm_utils.h"
7
+ #include "mlx/backend/cuda/quantized/quantized.h"
8
+ #include "mlx/backend/cuda/quantized/quantized_utils.h"
9
+ #include "mlx/primitives.h"
10
+
11
+ #include <nvtx3/nvtx3.hpp>
12
+
13
+ namespace mlx::core {
14
+
15
+ namespace {
16
+
17
+ std::tuple<array, array> quantize_input(
18
+ const array& input,
19
+ cu::CommandEncoder& encoder,
20
+ const Stream& s,
21
+ QuantizationMode mode,
22
+ int bits,
23
+ int group_size,
24
+ std::optional<array> global_scale = std::nullopt) {
25
+ const array x = ensure_contiguous(input, encoder, s);
26
+
27
+ // Compute output shapes
28
+ auto xq_shape = x.shape();
29
+ xq_shape.back() = x.shape(-1) * bits / 32;
30
+
31
+ const int64_t scales_inner = x.shape(-1) / group_size;
32
+ auto [pad_outer, pad_inner] =
33
+ get_padded_scale_dims(x.shape(-2), scales_inner);
34
+
35
+ auto sshape = x.shape();
36
+ sshape[x.ndim() - 2] = pad_outer;
37
+ sshape[x.ndim() - 1] = pad_inner;
38
+ sshape.back() = scales_inner;
39
+
40
+ // Allocate outputs
41
+ const int64_t xq_bytes = x.size() * bits / 8;
42
+ const int64_t batch = x.size() / (x.shape(-2) * x.shape(-1));
43
+ const int64_t scales_bytes = batch * (pad_outer * pad_inner);
44
+
45
+ array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32);
46
+ array scales_x(
47
+ cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8);
48
+ encoder.add_temporary(x_q);
49
+ encoder.add_temporary(scales_x);
50
+ // global_scale is not nullopt only for NVFP4
51
+ fp_quantize(x, x_q, scales_x, group_size, bits, global_scale, encoder, s);
52
+ return {std::move(x_q), std::move(scales_x)};
53
+ }
54
+
55
+ GemmScalars create_nvfp4_scalars(
56
+ const array& global_scale_x,
57
+ const array& global_scale_w,
58
+ cu::CommandEncoder& encoder) {
59
+ // NVFP4 requires alpha/beta as device pointers
60
+ // alpha = amax_x * amax_w / (448 * 6)^2
61
+ // beta = 0
62
+ array alpha(cu::malloc_async(sizeof(float), encoder), {}, float32);
63
+ array beta(cu::malloc_async(sizeof(float), encoder), {}, float32);
64
+ compute_qqmm_pointers(alpha, beta, global_scale_x, global_scale_w, encoder);
65
+ encoder.add_temporary(alpha);
66
+ encoder.add_temporary(beta);
67
+ return {alpha, beta};
68
+ }
69
+
70
+ } // namespace
71
+
72
+ void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
73
+ nvtx3::scoped_range r("QQMatmul::eval_gpu");
74
+
75
+ auto& s = stream();
76
+ auto& encoder = cu::get_command_encoder(s);
77
+ auto& device = encoder.device();
78
+ bool w_quantized = (inputs[1].dtype() == uint32);
79
+ int base_size = w_quantized ? 3 : 2;
80
+
81
+ assert(
82
+ inputs.size() == base_size ||
83
+ (mode_ == QuantizationMode::Nvfp4 && inputs.size() == base_size + 2));
84
+
85
+ if (w_quantized && inputs[0].shape(-2) == 1) {
86
+ out.set_data(cu::malloc_async(out.nbytes(), encoder));
87
+
88
+ // For nvfp4, get global scale for x from inputs if present
89
+ bool has_global_scale =
90
+ mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;
91
+ std::optional<array> global_scale = std::nullopt;
92
+ if (has_global_scale) {
93
+ global_scale = inputs[inputs.size() - 2];
94
+ }
95
+
96
+ bool donate_x = inputs[0].is_donatable();
97
+ array x = ensure_row_contiguous(inputs[0], encoder, s);
98
+ // If x is a copy it should be donatable
99
+ donate_x |= x.is_donatable();
100
+ auto xhat = donate_x
101
+ ? x
102
+ : array(cu::malloc_async(x.nbytes(), encoder), x.shape(), x.dtype());
103
+ if (!donate_x) {
104
+ encoder.add_temporary(xhat);
105
+ }
106
+ fp_quantize_dequantize(
107
+ x, xhat, group_size_, bits_, global_scale, encoder, s);
108
+
109
+ // Make sure the last two dims of w and s are contiguous
110
+ array w = ensure_row_contiguous_matrix(inputs[1], encoder, s);
111
+ array scales = ensure_row_contiguous_matrix(inputs[2], encoder, s);
112
+
113
+ bool non_batched = w.ndim() == 2;
114
+ int K = x.shape(-1);
115
+ int M = non_batched ? x.size() / K : x.shape(-2);
116
+ int N = out.shape(-1);
117
+
118
+ fp_qmv(w, scales, xhat, out, bits_, group_size_, M, N, K, encoder);
119
+ return;
120
+ }
121
+
122
+ auto cc = device.compute_capability_major() * 100 +
123
+ device.compute_capability_minor() * 10;
124
+ if (cc < 1000) {
125
+ throw std::runtime_error(
126
+ "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher.");
127
+ }
128
+
129
+ // - 2 inputs: x, w (non-quantized w)
130
+ // - 3 inputs: x, w, scales_w (quantized w)
131
+
132
+ // For nvfp4, global scales are optional but must be both present or both
133
+ // absent If present, they add 2 more inputs (global_scale_x, global_scale_w)
134
+ bool has_global_scales =
135
+ mode_ == QuantizationMode::Nvfp4 && inputs.size() > base_size;
136
+
137
+ // For nvfp4, get global scales from inputs if present
138
+ std::optional<array> global_scale_x = std::nullopt;
139
+ std::optional<array> global_scale_w = std::nullopt;
140
+ if (has_global_scales) {
141
+ global_scale_x = inputs[inputs.size() - 2];
142
+ global_scale_w = inputs[inputs.size() - 1];
143
+ }
144
+
145
+ // Quantize inputs (or use pre-quantized)
146
+ auto [x_q, scale_x_pre] = quantize_input(
147
+ inputs[0], encoder, s, mode_, bits_, group_size_, global_scale_x);
148
+ auto [w_q, scale_w_pre] = !w_quantized
149
+ ? quantize_input(
150
+ inputs[1], encoder, s, mode_, bits_, group_size_, global_scale_w)
151
+ : std::make_tuple(
152
+ ensure_contiguous(inputs[1], encoder, s),
153
+ ensure_contiguous(inputs[2], encoder, s));
154
+
155
+ out.set_data(cu::malloc_async(out.nbytes(), encoder));
156
+
157
+ int M = x_q.shape(-2);
158
+ int N = w_q.shape(-2); // transposed
159
+ int K = x_q.shape(-1) * (32 / bits_);
160
+
161
+ bool x_transposed = false;
162
+ bool w_transposed = true; // always transposed
163
+ int64_t lda = K;
164
+ int64_t ldb = K;
165
+
166
+ // Repack scales to tiled layout for tensor cores
167
+ array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s);
168
+ array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s);
169
+
170
+ GemmScalars scalars;
171
+ if (has_global_scales) {
172
+ scalars = create_nvfp4_scalars(*global_scale_x, *global_scale_w, encoder);
173
+ }
174
+
175
+ qqmm_impl(
176
+ encoder,
177
+ M,
178
+ N,
179
+ K,
180
+ x_transposed,
181
+ lda,
182
+ w_transposed,
183
+ ldb,
184
+ out,
185
+ x_q,
186
+ w_q,
187
+ scale_x,
188
+ scale_w,
189
+ mode_,
190
+ scalars);
191
+ }
192
+
193
+ } // namespace mlx::core
@@ -19,15 +19,10 @@ void qqmm_impl(
19
19
  const array& b,
20
20
  const array& a_scale,
21
21
  const array& b_scale,
22
- Dtype out_dtype,
23
22
  QuantizationMode mode,
24
- float alpha) {
25
- // Invoke CublasQQMM
23
+ const GemmScalars& scalars) {
26
24
  std::string qmode = quantization_mode_to_string(mode);
27
25
 
28
- // Currently only supports non-batched QQMM operations
29
- // that covers all use cases for training, we will just collapse (batch,
30
- // seq_len) into (tokens)
31
26
  CublasQQMM qqmm(
32
27
  encoder.device(),
33
28
  a_transposed,
@@ -41,10 +36,22 @@ void qqmm_impl(
41
36
  1, // batch_count
42
37
  0, // a_batch_stride
43
38
  0, // b_batch_stride
44
- out_dtype,
39
+ out.dtype(),
45
40
  qmode);
46
41
 
47
- qqmm.run(encoder, out, a, b, a_scale, b_scale, alpha);
42
+ if (scalars.has_values()) {
43
+ qqmm.run(
44
+ encoder,
45
+ out,
46
+ a,
47
+ b,
48
+ a_scale,
49
+ b_scale,
50
+ *scalars.alpha_device,
51
+ *scalars.beta_device);
52
+ } else {
53
+ qqmm.run(encoder, out, a, b, a_scale, b_scale);
54
+ }
48
55
  }
49
56
 
50
57
  } // namespace mlx::core
@@ -1,10 +1,22 @@
1
- // Copyright © 2026 Apple Inc.
1
+ // Copyright © 2025 Apple Inc.
2
2
  #pragma once
3
3
 
4
4
  #include "mlx/backend/cuda/device.h"
5
5
  #include "mlx/primitives.h"
6
6
 
7
+ #include <optional>
8
+
7
9
  namespace mlx::core {
10
+
11
+ struct GemmScalars {
12
+ std::optional<array> alpha_device;
13
+ std::optional<array> beta_device;
14
+
15
+ bool has_values() const {
16
+ return alpha_device.has_value();
17
+ }
18
+ };
19
+
8
20
  void qqmm_impl(
9
21
  cu::CommandEncoder& encoder,
10
22
  int M,
@@ -19,8 +31,7 @@ void qqmm_impl(
19
31
  const array& b,
20
32
  const array& a_scale,
21
33
  const array& b_scale,
22
- Dtype out_dtype,
23
34
  QuantizationMode mode,
24
- float alpha = 1.0f);
35
+ const GemmScalars& scalars = {});
25
36
 
26
37
  } // namespace mlx::core
@@ -70,6 +70,21 @@ inline std::tuple<dim3, dim3> get_swizzle_launch_args(
70
70
 
71
71
  namespace cu {
72
72
 
73
+ constexpr float F8E4M3_MAX = 448.0f;
74
+ constexpr float F4E2M1_MAX = 6.0f;
75
+
76
+ __global__ void compute_qqmm_pointers(
77
+ float* alpha_out,
78
+ float* beta_out,
79
+ const float* tensor_amax_x,
80
+ const float* tensor_amax_w) {
81
+ // Compute alpha = tensor_amax_x * tensor_amax_w / (448 * 6)^2
82
+ constexpr float inv_scale_sq =
83
+ 1.0f / (F8E4M3_MAX * F4E2M1_MAX * F8E4M3_MAX * F4E2M1_MAX);
84
+ *alpha_out = (*tensor_amax_x) * (*tensor_amax_w) * inv_scale_sq;
85
+ *beta_out = 0.0f;
86
+ }
87
+
73
88
  __global__ void swizzle_scales(
74
89
  const uint8_t* scales_linear,
75
90
  uint8_t* scales_swizzled,
@@ -224,4 +239,25 @@ void swizzle_scales(
224
239
  output_cols);
225
240
  }
226
241
 
242
+ void compute_qqmm_pointers(
243
+ array& alpha_out,
244
+ array& beta_out,
245
+ const array& tensor_amax_x,
246
+ const array& tensor_amax_w,
247
+ cu::CommandEncoder& enc) {
248
+ enc.set_input_array(tensor_amax_x);
249
+ enc.set_input_array(tensor_amax_w);
250
+ enc.set_output_array(alpha_out);
251
+ enc.set_output_array(beta_out);
252
+ enc.add_kernel_node(
253
+ cu::compute_qqmm_pointers,
254
+ dim3(1),
255
+ dim3(1),
256
+ 0,
257
+ gpu_ptr<void>(alpha_out),
258
+ gpu_ptr<void>(beta_out),
259
+ gpu_ptr<void>(tensor_amax_x),
260
+ gpu_ptr<void>(tensor_amax_w));
261
+ }
262
+
227
263
  } // namespace mlx::core