mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,158 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/device.h"
4
+ #include "mlx/backend/cuda/quantized/qmv.h"
5
+ #include "mlx/backend/cuda/quantized/qqmm_impl.h"
6
+ #include "mlx/backend/cuda/quantized/qqmm_utils.h"
7
+ #include "mlx/backend/cuda/quantized/quantized.h"
8
+ #include "mlx/backend/cuda/quantized/quantized_utils.h"
9
+ #include "mlx/primitives.h"
10
+
11
+ #include <nvtx3/nvtx3.hpp>
12
+
13
+ namespace mlx::core {
14
+
15
+ namespace {
16
+
17
+ array pad_and_swizzle_scales(
18
+ const array& scale,
19
+ cu::CommandEncoder& encoder,
20
+ const Stream& s) {
21
+ // Compute padded dimensions for full tiles (128 rows × 4 cols)
22
+ auto [pad_outer, pad_inner] =
23
+ get_padded_scale_dims(scale.shape(-2), scale.shape(-1));
24
+ // cuBLAS requirements for scale factor layout:
25
+ // 1. Dimensions must be padded to full tiles (128 rows × 4 cols)
26
+ // 2. Out-of-bounds values must be filled with zeros
27
+ // 3. Starting addresses must be 16-byte aligned
28
+ //
29
+ // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
30
+ // Note: cu::malloc_async already provides 256-byte alignment
31
+ array scale_tiled(
32
+ cu::malloc_async(pad_outer * pad_inner, encoder),
33
+ Shape{pad_outer, pad_inner},
34
+ scale.dtype());
35
+ swizzle_scales(scale, scale_tiled, encoder, s);
36
+
37
+ encoder.add_temporary(scale_tiled);
38
+ return scale_tiled;
39
+ }
40
+
41
+ } // namespace
42
+
43
// Quantized-times-quantized matmul. Inputs are either
//   {x, w}                 -- w unquantized, both operands quantized here, or
//   {x, w_q, w_scales}     -- w pre-quantized (packed uint32).
void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  assert(
      (inputs.size() == 3 && inputs[1].dtype() == uint32) ||
      (inputs.size() == 2));
  nvtx3::scoped_range r("QQMatmul::eval_gpu");

  auto& s = stream();
  auto& encoder = cu::get_command_encoder(s);
  auto& device = encoder.device();

  // Fast path: w already quantized and x is a single row -> use the
  // quantized matrix-vector kernel instead of the full cuBLAS QQMM.
  bool w_quantized = (inputs[1].dtype() == uint32);
  if (w_quantized && inputs[0].shape(-2) == 1) {
    out.set_data(cu::malloc_async(out.nbytes(), encoder));

    bool donate_x = inputs[0].is_donatable();
    array x = ensure_row_contiguous(inputs[0], encoder, s);
    // If x is a copy it should be donatable
    donate_x |= x.is_donatable();
    // xhat holds the quantize->dequantize round trip of x; reuse x's buffer
    // when we are allowed to clobber it, otherwise allocate a temporary.
    auto xhat = donate_x
        ? x
        : array(cu::malloc_async(x.nbytes(), encoder), x.shape(), x.dtype());
    if (!donate_x) {
      encoder.add_temporary(xhat);
    }
    fp_quantize_dequantize(x, xhat, group_size_, bits_, encoder, s);

    // Make sure the last two dims of w and s are contiguous
    array w = ensure_row_contiguous_matrix(inputs[1], encoder, s);
    array scales = ensure_row_contiguous_matrix(inputs[2], encoder, s);

    bool non_batched = w.ndim() == 2;
    int K = x.shape(-1);
    // For 2D w, collapse all leading dims of x into M.
    int M = non_batched ? x.size() / K : x.shape(-2);
    int N = out.shape(-1);

    fp_qmv(w, scales, xhat, out, bits_, group_size_, M, N, K, encoder);
    return;
  }

  // The cuBLAS block-scaled path requires compute capability >= 10.0.
  auto cc = device.compute_capability_major() * 100 +
      device.compute_capability_minor() * 10;
  if (cc < 1000) {
    throw std::runtime_error(
        "[QQMatmul::eval_gpu] QQMM is only supported on GPUs with compute capability 10.0 or higher.");
  }
  // Quantize `input` into a packed uint32 array plus a uint8 scale array;
  // both are registered as temporaries on the encoder.
  auto quantize = [&](const array& input,
                      cu::CommandEncoder& encoder,
                      const Stream& s) -> std::pair<array, array> {
    auto x = ensure_contiguous(input, encoder, s);
    // Packed shape: `bits_` bits per element, 32 bits per output word.
    auto xq_shape = x.shape();
    xq_shape.back() = x.shape(-1) * bits_ / 32;

    auto sshape = x.shape();
    // One scale per `group_size_` elements along the last axis.
    const int64_t scales_inner = x.shape(-1) / group_size_;
    auto [pad_outer, pad_inner] =
        get_padded_scale_dims(x.shape(-2), scales_inner);
    sshape[x.ndim() - 2] = pad_outer;
    sshape[x.ndim() - 1] = pad_inner;
    // NOTE(review): this overwrites the pad_inner just assigned above, so
    // the logical shape is (.., pad_outer, scales_inner) while the byte
    // allocation below covers the fully padded extent — looks intentional
    // (padded storage, unpadded inner columns); confirm.
    sshape.back() = scales_inner;

    // Allocate outputs
    const int64_t xq_bytes = x.size() * bits_ / 8;
    const int64_t batch = x.size() / (x.shape(-2) * x.shape(-1));
    const int64_t scales_bytes = batch * (pad_outer * pad_inner);

    array x_q(cu::malloc_async(xq_bytes, encoder), std::move(xq_shape), uint32);
    array scales_x(
        cu::malloc_async(scales_bytes, encoder), std::move(sshape), uint8);

    fp_quantize(x, x_q, scales_x, group_size_, bits_, encoder, s);

    encoder.add_temporary(x_q);
    encoder.add_temporary(scales_x);
    return {x_q, scales_x};
  };
  // x is always quantized here; w only when it did not arrive pre-quantized.
  auto [x_q, scale_x_pre] = quantize(inputs[0], encoder, s);
  auto [w_q, scale_w_pre] = !w_quantized ? quantize(inputs[1], encoder, s)
                                         : std::make_pair(inputs[1], inputs[2]);

  out.set_data(cu::malloc_async(out.nbytes(), encoder));

  auto out_dtype = out.dtype();

  int M = x_q.shape(-2);
  int N = w_q.shape(-2); // always transposed
  int K_packed = x_q.shape(-1);
  // Unpacked K: each uint32 word holds 32 / bits_ elements.
  int K = K_packed * (32 / bits_);

  // Repack scales from linear to tiled layout for tensor cores
  array scale_x = pad_and_swizzle_scales(scale_x_pre, encoder, s);
  array scale_w = pad_and_swizzle_scales(scale_w_pre, encoder, s);

  bool x_transposed = false;
  bool w_transposed = true; // always transposed
  int64_t lda = K;
  int64_t ldb = K;

  qqmm_impl(
      encoder,
      M,
      N,
      K,
      x_transposed,
      lda,
      w_transposed,
      ldb,
      out,
      x_q,
      w_q,
      scale_x,
      scale_w,
      out_dtype,
      mode_);
}
157
+
158
+ } // namespace mlx::core
@@ -0,0 +1,50 @@
1
+ // Copyright © 2026 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/quantized/qqmm_impl.h"
4
+ #include "mlx/backend/cuda/quantized/cublas_qqmm.h"
5
+
6
+ namespace mlx::core {
7
+
8
+ void qqmm_impl(
9
+ cu::CommandEncoder& encoder,
10
+ int M,
11
+ int N,
12
+ int K,
13
+ bool a_transposed,
14
+ int64_t lda,
15
+ bool b_transposed,
16
+ int64_t ldb,
17
+ array& out,
18
+ const array& a,
19
+ const array& b,
20
+ const array& a_scale,
21
+ const array& b_scale,
22
+ Dtype out_dtype,
23
+ QuantizationMode mode,
24
+ float alpha) {
25
+ // Invoke CublasQQMM
26
+ std::string qmode = quantization_mode_to_string(mode);
27
+
28
+ // Currently only supports non-batched QQMM operations
29
+ // that covers all use cases for training, we will just collapse (batch,
30
+ // seq_len) into (tokens)
31
+ CublasQQMM qqmm(
32
+ encoder.device(),
33
+ a_transposed,
34
+ M,
35
+ K,
36
+ lda,
37
+ b_transposed,
38
+ K,
39
+ N,
40
+ ldb,
41
+ 1, // batch_count
42
+ 0, // a_batch_stride
43
+ 0, // b_batch_stride
44
+ out_dtype,
45
+ qmode);
46
+
47
+ qqmm.run(encoder, out, a, b, a_scale, b_scale, alpha);
48
+ }
49
+
50
+ } // namespace mlx::core
@@ -0,0 +1,26 @@
1
+ // Copyright © 2026 Apple Inc.
2
+ #pragma once
3
+
4
+ #include "mlx/backend/cuda/device.h"
5
+ #include "mlx/primitives.h"
6
+
7
+ namespace mlx::core {
8
// Run a quantized x quantized matrix multiply on the GPU:
//   out = alpha * op(a) @ op(b)
// where a and b are packed quantized matrices, and a_scale / b_scale hold
// their block scale factors. M, N, K are the logical GEMM dimensions;
// lda / ldb are the leading dimensions of a and b; `mode` selects the
// quantization format and `alpha` scales the product. Work is recorded on
// `encoder`; `out` must already have its data allocated.
void qqmm_impl(
    cu::CommandEncoder& encoder,
    int M,
    int N,
    int K,
    bool a_transposed,
    int64_t lda,
    bool b_transposed,
    int64_t ldb,
    array& out,
    const array& a,
    const array& b,
    const array& a_scale,
    const array& b_scale,
    Dtype out_dtype,
    QuantizationMode mode,
    float alpha = 1.0f);
25
+
26
+ } // namespace mlx::core
@@ -0,0 +1,227 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/device.h"
4
+ #include "mlx/backend/cuda/kernel_utils.cuh"
5
+ #include "mlx/backend/cuda/quantized/qqmm_utils.h"
6
+
7
+ #include <cooperative_groups.h>
8
+
9
+ namespace mlx::core {
10
+
11
+ namespace cg = cooperative_groups;
12
+
13
+ constexpr int TILE_ROWS = 128;
14
+ constexpr int TILE_COLS = 4;
15
+ constexpr int TILES_PER_LANE = 1;
16
+ constexpr int LANES_PER_BLOCK = 32;
17
+
18
+ // To pass scales to tensor cores, they need to be repacked into a tiled layout
19
+ // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
20
+ // Tiled layout for scale factors is very well described in CUTLASS
21
+ // documentation:
22
+ // https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts
23
+ // Conceptually, it should be like this:
24
+ // q_w = mx.zeros(shape=(M, N)) <-- zeros just for an example
25
+ // s.shape = (M, N // 16) -- packed in row contiguous order, group_size = 16
26
+ // cbg_cnt = N // 16 // 4
27
+ // rb_cnt = M // 128
28
+ // tmp = x.reshape(rb_cnt, 4, 32, cbg_cnt, 4)
29
+ // repacked_scales = tmp.transpose(0, 3, 2, 1, 4)
30
+ // example: indices of initial tile 128 x 4 of scales (packed in row major tensor
31
+ // (M, K // 16), where M = 128, K = 64): array([[0, 1, 2, 3],
32
+ // [4, 5, 6, 7],
33
+ // [8, 9, 10, 11],
34
+ // ...,
35
+ // [500, 501, 502, 503],
36
+ // [504, 505, 506, 507],
37
+ // [508, 509, 510, 511]]
38
+ // packed scales within tile 128 x 4:
39
+ // array([[[[[0, 1, 2, 3], <-- s_0,0..s_0,3 scales
40
+ // [128, 129, 130, 131], <-- s_32,0..s_32,3 scales
41
+ // [256, 257, 258, 259], <-- s_64,0..s_64,3 scales
42
+ // [384, 385, 386, 387]], <-- s_96,0..s_96,3 scales
43
+ // [[4, 5, 6, 7], <-- s_1,0..s_1,3 scales
44
+ // [132, 133, 134, 135], ...
45
+ // [260, 261, 262, 263],
46
+ // [388, 389, 390, 391]],
47
+ // [[124, 125, 126, 127],
48
+ // [252, 253, 254, 255],
49
+ // [380, 381, 382, 383],
50
+ // [508, 509, 510, 511]]]]],
51
+
52
+ inline std::tuple<dim3, dim3> get_swizzle_launch_args(
53
+ size_t M_swizzled,
54
+ size_t K_swizzled) {
55
+ constexpr int tiles_per_block = LANES_PER_BLOCK * TILES_PER_LANE;
56
+ constexpr int warps_per_block = TILE_ROWS / 4; // 128 / 4 = 32
57
+
58
+ const int num_tiles_k = K_swizzled / TILE_COLS;
59
+ const int num_tiles_m = M_swizzled / TILE_ROWS;
60
+
61
+ dim3 grid;
62
+ grid.x = cuda::ceil_div(num_tiles_k, tiles_per_block);
63
+ grid.y = num_tiles_m;
64
+ grid.z = 1;
65
+ // Block is always (32, 32) = 1024 threads
66
+ dim3 block(LANES_PER_BLOCK, warps_per_block, 1);
67
+
68
+ return std::make_tuple(grid, block);
69
+ }
70
+
71
+ namespace cu {
72
+
73
+ __global__ void swizzle_scales(
74
+ const uint8_t* scales_linear,
75
+ uint8_t* scales_swizzled,
76
+ const size_t M,
77
+ const size_t K,
78
+ const size_t M_swizzled,
79
+ const size_t K_swizzled) {
80
+ constexpr int tile_size = TILE_ROWS * TILE_COLS;
81
+ constexpr int num_tile_rows_per_thread = 4;
82
+ constexpr int max_tiles_per_block = LANES_PER_BLOCK * TILES_PER_LANE;
83
+
84
+ constexpr int tile_stride = tile_size / 16; // 32 int4s per tile
85
+
86
+ // Each thread loads 16 scales from 4 rows (stride 32) and packs them into
87
+ // int4. For example: thread (0, 0) loads scales at rows 0,32,64,96 of tile 0,
88
+ // thread (1, 0) loads rows 0,32,64,96 of of tile 1, etc.
89
+ // The store is strided within a warp (stride 32 int4s), so we first
90
+ // write to shared memory, then do a coalesced store from shared to global
91
+ auto block_size = cg::this_thread_block().dim_threads();
92
+ auto block_idx = cg::this_thread_block().group_index();
93
+ auto idx_in_block = cg::this_thread_block().thread_index();
94
+
95
+ auto tidx = idx_in_block.x;
96
+ auto tidy = idx_in_block.y;
97
+ auto linear_tid = tidy * block_size.x + tidx;
98
+
99
+ const int bid_x = block_idx.x;
100
+ const int bid_y = block_idx.y;
101
+
102
+ const int K_int = K_swizzled / 4;
103
+
104
+ const size_t output_offset = static_cast<size_t>(bid_y) * TILE_ROWS * K_int +
105
+ static_cast<size_t>(bid_x) * max_tiles_per_block * tile_size / 4;
106
+ int* output_block = reinterpret_cast<int*>(scales_swizzled) + output_offset;
107
+
108
+ const int grid_dim_x = cg::this_grid().dim_blocks().x;
109
+ const int grid_dim_y = cg::this_grid().dim_blocks().y;
110
+
111
+ int remaining = K_int - bid_x * max_tiles_per_block;
112
+ int tiles_in_block = min(remaining, max_tiles_per_block);
113
+ bool valid_tile = tidx * TILES_PER_LANE < tiles_in_block;
114
+
115
+ __shared__ int4 strided_scales_thread[max_tiles_per_block * tile_stride];
116
+
117
+ // Initialize to zero for padding
118
+ int thread_tile_rows[num_tile_rows_per_thread] = {0};
119
+
120
+ if (valid_tile) {
121
+ const size_t col_base =
122
+ static_cast<size_t>(bid_x) * max_tiles_per_block * TILE_COLS +
123
+ tidx * TILE_COLS;
124
+
125
+ const bool aligned_k = (K % 4 == 0);
126
+
127
+ if (aligned_k) {
128
+ // fast path: K is aligned, use vectorized loads with stride K/4
129
+ const int K_stride = K / 4;
130
+ const size_t block_offset =
131
+ static_cast<size_t>(bid_y) * TILE_ROWS * K_stride +
132
+ static_cast<size_t>(bid_x) * max_tiles_per_block;
133
+ const int* input_block =
134
+ reinterpret_cast<const int*>(scales_linear) + block_offset;
135
+ // load
136
+ #pragma unroll
137
+ for (int i = 0; i < num_tile_rows_per_thread; i++) {
138
+ const size_t row =
139
+ static_cast<size_t>(bid_y) * TILE_ROWS + i * block_size.x + tidy;
140
+ const int thread_offset =
141
+ (i * block_size.x + tidy) * K_stride + tidx * TILES_PER_LANE;
142
+ if (row < M && col_base + TILE_COLS <= K) {
143
+ thread_tile_rows[i] = __ldg(input_block + thread_offset);
144
+ } else if (row < M) {
145
+ // partial tile at K boundary: load byte-by-byte
146
+ #pragma unroll
147
+ for (int c = 0; c < TILE_COLS; c++) {
148
+ if (col_base + c < K) {
149
+ reinterpret_cast<uint8_t*>(&thread_tile_rows[i])[c] =
150
+ scales_linear[row * K + col_base + c];
151
+ }
152
+ }
153
+ }
154
+ }
155
+ } else {
156
+ #pragma unroll
157
+ for (int i = 0; i < num_tile_rows_per_thread; i++) {
158
+ const size_t row =
159
+ static_cast<size_t>(bid_y) * TILE_ROWS + i * block_size.x + tidy;
160
+ if (row < M) {
161
+ const size_t row_start = row * K;
162
+ #pragma unroll
163
+ for (int c = 0; c < TILE_COLS; c++) {
164
+ if (col_base + c < K) {
165
+ reinterpret_cast<uint8_t*>(&thread_tile_rows[i])[c] =
166
+ scales_linear[row_start + col_base + c];
167
+ }
168
+ }
169
+ }
170
+ }
171
+ }
172
+ // store to shared with XOR swizzle to avoid bank conflicts
173
+ int base_idx = tidx * tile_stride + tidy;
174
+ int xor_bits = (tidy >> 3) & 0x3;
175
+ int swizzled_idx = base_idx ^ xor_bits;
176
+ strided_scales_thread[swizzled_idx] =
177
+ *reinterpret_cast<int4*>(thread_tile_rows);
178
+ }
179
+
180
+ cg::thread_block block = cg::this_thread_block();
181
+ cg::sync(block);
182
+
183
+ const int total_int4s = tiles_in_block * tile_stride;
184
+ #pragma unroll
185
+ for (int i = linear_tid; i < total_int4s; i += block_size.x * block_size.y) {
186
+ int tile_idx = i / tile_stride;
187
+ int row_idx = i % tile_stride;
188
+ int base_idx = tile_idx * tile_stride + row_idx;
189
+ int xor_bits = (row_idx >> 3) & 0x3;
190
+ int swizzled_idx = base_idx ^ xor_bits;
191
+ reinterpret_cast<int4*>(output_block)[i] =
192
+ strided_scales_thread[swizzled_idx];
193
+ }
194
+ }
195
+ } // namespace cu
196
+
197
+ void swizzle_scales(
198
+ const array& scales,
199
+ array& scales_tiled,
200
+ cu::CommandEncoder& enc,
201
+ const Stream& s) {
202
+ enc.set_input_array(scales);
203
+ enc.set_output_array(scales_tiled);
204
+ // Note: scales_tiled is padded to full tiles so if num_rows or num_cols
205
+ // are not multiples of tile sizes
206
+ size_t input_rows = scales.shape(-2);
207
+ size_t input_cols = scales.shape(-1);
208
+
209
+ size_t output_rows = scales_tiled.shape(-2);
210
+ size_t output_cols = scales_tiled.shape(-1);
211
+
212
+ auto [num_blocks, block_dims] =
213
+ get_swizzle_launch_args(output_rows, output_cols);
214
+ enc.add_kernel_node(
215
+ cu::swizzle_scales,
216
+ num_blocks,
217
+ block_dims,
218
+ 0,
219
+ gpu_ptr<uint8_t>(scales),
220
+ gpu_ptr<uint8_t>(scales_tiled),
221
+ input_rows,
222
+ input_cols,
223
+ output_rows,
224
+ output_cols);
225
+ }
226
+
227
+ } // namespace mlx::core
@@ -0,0 +1,30 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/array.h"
6
+ #include "mlx/backend/cuda/device.h"
7
+
8
+ namespace mlx::core {
9
+
10
// Round (num_rows, num_cols) up to whole scale-factor tiles.
// Tiles are 128 rows × 4 columns; the tiled scale buffer must cover
// complete tiles, so each dimension is rounded up to the next multiple of
// its tile extent.
inline std::pair<int, int> get_padded_scale_dims(int num_rows, int num_cols) {
  constexpr int rows_per_tile = 128;
  constexpr int cols_per_tile = 4;

  // Round `value` up to the nearest multiple of `multiple`.
  auto round_up = [](int value, int multiple) {
    return ((value + multiple - 1) / multiple) * multiple;
  };

  return {
      round_up(num_rows, rows_per_tile), round_up(num_cols, cols_per_tile)};
}
23
+
24
// Repack `scales` (row-major in its last two dims) into the zero-padded,
// tile-swizzled layout written into the pre-allocated `scales_tiled`,
// whose dims come from get_padded_scale_dims. Work is recorded on `enc`.
void swizzle_scales(
    const array& scales,
    array& scales_tiled,
    cu::CommandEncoder& enc,
    const Stream& s);
29
+
30
+ } // namespace mlx::core
@@ -0,0 +1,85 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/quantized/quantized.h"
4
+ #include "mlx/backend/cuda/device.h"
5
+ #include "mlx/backend/cuda/quantized/qmv.h"
6
+ #include "mlx/backend/cuda/quantized/quantized_utils.h"
7
+ #include "mlx/fast_primitives.h"
8
+ #include "mlx/primitives.h"
9
+
10
+ #include <nvtx3/nvtx3.hpp>
11
+
12
+ namespace mlx::core {
13
+
14
// Matmul of a regular input x against a pre-quantized weight w with
// per-group scales (and biases in affine mode). Only the small-M,
// transposed, non-affine path is implemented on CUDA so far.
void QuantizedMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
  nvtx3::scoped_range r("QuantizedMatmul::eval_gpu");
  auto& s = stream();
  auto& d = cu::device(s.device);
  auto& enc = d.get_command_encoder(s);

  out.set_data(cu::malloc_async(out.nbytes(), enc));

  // Make sure the last two dims of x and w, s, b are contiguous. This should
  // be relaxed for x.
  array x = ensure_row_contiguous_matrix(inputs[0], enc, s);
  array w = ensure_row_contiguous_matrix(inputs[1], enc, s);
  array scales = ensure_row_contiguous_matrix(inputs[2], enc, s);
  // A fourth input carries the per-group biases (affine quantization);
  // unused below since the affine path throws NYI.
  std::optional<array> biases = std::nullopt;
  if (inputs.size() == 4) {
    biases = ensure_row_contiguous_matrix(inputs[3], enc, s);
  }

  bool non_batched = w.ndim() == 2 && x.flags().row_contiguous;
  int K = x.shape(-1);
  // Collapse all leading dims of x into M when w is 2D and x is fully
  // row contiguous.
  int M = non_batched ? x.size() / K : x.shape(-2);
  int N = out.shape(-1);

  // Anything other than a gemv-like (M <= 8), transposed, non-affine
  // multiply is not yet implemented on this backend.
  if (M > 8 || !transpose_ || mode_ == QuantizationMode::Affine) {
    throw std::runtime_error("QMM NYI");
  }

  // transpose_ is guaranteed true here by the check above.
  if (transpose_) {
    fp_qmv(w, scales, x, out, bits_, group_size_, M, N, K, enc);
    return;
  }
}
46
+
47
// Quantize or dequantize (direction chosen by dequantize_) using either the
// affine kernels (scales + biases) or the fp kernels (scales only).
void fast::Quantize::eval_gpu(
    const std::vector<array>& inputs,
    std::vector<array>& outputs) {
  nvtx3::scoped_range r("Quantize::eval_gpu");
  auto& s = stream();
  auto& d = cu::device(s.device);
  auto& enc = d.get_command_encoder(s);

  if (dequantize_) {
    // Dequantize: inputs = {wq, scales[, biases]}, outputs = {w}.
    auto wq = ensure_row_contiguous(inputs[0], enc, s);
    auto scales = ensure_row_contiguous(inputs[1], enc, s);
    auto& w = outputs[0];

    w.set_data(cu::malloc_async(w.nbytes(), enc));

    if (mode_ == QuantizationMode::Affine) {
      // Affine mode carries a per-group bias as a third input.
      auto biases = ensure_row_contiguous(inputs[2], enc, s);
      affine_dequantize(wq, scales, biases, w, group_size_, bits_, enc, s);
    } else {
      fp_dequantize(wq, scales, w, group_size_, bits_, enc, s);
    }
  } else {
    // Quantize: inputs = {w}, outputs = {wq, scales[, biases]}.
    auto w = ensure_contiguous(inputs[0], enc, s);
    auto& wq = outputs[0];
    auto& scales = outputs[1];

    wq.set_data(cu::malloc_async(wq.nbytes(), enc));
    scales.set_data(cu::malloc_async(scales.nbytes(), enc));
    if (mode_ == QuantizationMode::Affine) {
      // Affine mode additionally produces per-group biases.
      auto& biases = outputs[2];
      biases.set_data(cu::malloc_async(biases.nbytes(), enc));
      affine_quantize(w, wq, scales, biases, group_size_, bits_, enc, s);
    } else {
      fp_quantize(w, wq, scales, group_size_, bits_, enc, s);
    }
  }
}
84
+
85
+ } // namespace mlx::core
@@ -0,0 +1,53 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/device.h"
4
+
5
+ namespace mlx::core {
6
+
7
// CUDA quantization kernel entry points. All functions record their work on
// `enc`; output arrays must already have their data allocated by the caller.

// Affine-quantize w into packed wq, producing one scale and one bias per
// group of `group_size_` elements.
void affine_quantize(
    const array& w,
    array& wq,
    array& scales,
    array& biases,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s);

// Inverse of affine_quantize: reconstruct w from wq, scales and biases.
void affine_dequantize(
    const array& wq,
    const array& scales,
    const array& biases,
    array& w,
    int group_size_,
    int bits_,
    cu::CommandEncoder& enc,
    const Stream& s);

// Floating-point (scale-only, bias-free) quantization of w into packed wq
// with per-group scales.
void fp_quantize(
    const array& w,
    array& wq,
    array& scales,
    int group_size,
    int bits,
    cu::CommandEncoder& enc,
    const Stream& s);

// Inverse of fp_quantize: reconstruct w from wq and scales.
void fp_dequantize(
    const array& wq,
    const array& scales,
    array& w,
    int group_size,
    int bits,
    cu::CommandEncoder& enc,
    const Stream& s);

// Quantize then immediately dequantize w into `what` — a fake-quantization
// round trip (used e.g. for the activation path in QQMatmul).
void fp_quantize_dequantize(
    const array& w,
    array& what,
    int group_size,
    int bits,
    cu::CommandEncoder& enc,
    const Stream& s);
52
+
53
+ } // namespace mlx::core