mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh
@@ -0,0 +1,334 @@
+ #pragma once
+
+ #include <cuda.h>
+ #include <cuda_fp4.h>
+ #include <cuda_runtime.h>
+ #include "mlx/backend/cuda/vector_types.cuh"
+
+ namespace mlx::core::cu {
+
+ using bf16x4 = Vector4_t<__nv_bfloat16>;
+ using fp16x4 = Vector4_t<__half>;
+ using f32x4 = Vector4_t<float>;
+
+ template <typename T>
+ __device__ __forceinline__ uint16_t
+ scale_cvt_Tx4_to_fp4x4_fallback(const Vector4_t<T> input, const float scale) {
+   // Scalar fallback for architectures that do not support the cvt
+   // instructions, or for CUDA versions with no fp4 support (< 12.8)
+   uint16_t out_fp4x4 = 0;
+   f32x4 scaled;
+   scaled.x = static_cast<float>(input.x) * scale;
+   scaled.y = static_cast<float>(input.y) * scale;
+   scaled.z = static_cast<float>(input.z) * scale;
+   scaled.w = static_cast<float>(input.w) * scale;
+   uint8_t q0 = __nv_fp4_e2m1(scaled.x).__x;
+   uint8_t q1 = __nv_fp4_e2m1(scaled.y).__x;
+   uint8_t q2 = __nv_fp4_e2m1(scaled.z).__x;
+   uint8_t q3 = __nv_fp4_e2m1(scaled.w).__x;
+   out_fp4x4 = (static_cast<uint16_t>(q3) << 12) |
+       (static_cast<uint16_t>(q2) << 8) | (static_cast<uint16_t>(q1) << 4) |
+       static_cast<uint16_t>(q0);
+   return out_fp4x4;
+ }
+
+ #if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
+     (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
+
+ __device__ __forceinline__ uint16_t
+ scale_cvt_bf16x4_to_fp4x4_rn(const bf16x4 input_bf16x4, const float2 scale) {
+   uint16_t out_fp4x4 = 0;
+   asm volatile(
+       "{\n"
+       ".reg.b16 x0_bf16; \n\t" // first bf16
+       ".reg.b16 x1_bf16; \n\t" // second bf16
+       ".reg.b16 x2_bf16; \n\t" // third bf16
+       ".reg.b16 x3_bf16; \n\t" // fourth bf16
+       ".reg.b32 x0; \n\t" // to hold scaled first
+       ".reg.b32 x1; \n\t" // to hold scaled second
+       ".reg.b32 x2; \n\t" // to hold scaled third
+       ".reg.b32 x3; \n\t" // to hold scaled fourth
+       ".reg.b64 x01; \n\t" // to hold vector mul
+       ".reg.b64 x23; \n\t"
+       ".reg.b8 q0; \n\t" // output byte fp4x2 (first pair)
+       ".reg.b8 q1; \n\t" // output byte fp4x2 (second pair)
+       "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16}, %1; \n\t" // unpack bf16
+       "cvt.f32.bf16 x0, x0_bf16; \n\t" // convert to f32
+       "cvt.f32.bf16 x1, x1_bf16; \n\t"
+       "cvt.f32.bf16 x2, x2_bf16; \n\t"
+       "cvt.f32.bf16 x3, x3_bf16; \n\t"
+       "mov.b64 x01, {x0, x1}; \n\t"
+       "mul.f32x2 x01, x01, %2; \n\t" // scale first pair
+       "mov.b64 x23, {x2, x3}; \n\t"
+       "mul.f32x2 x23, x23, %2; \n\t" // scale second pair
+       "mov.b64 {x0, x1}, x01; \n\t"
+       "mov.b64 {x2, x3}, x23; \n\t"
+       "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t" // convert to fp4x2 first
+                                                      // pair
+       "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t" // convert to fp4x2
+                                                      // second pair
+       "mov.b16 %0, {q0, q1}; \n\t" // pack to output
+       "}"
+       : "=h"(out_fp4x4)
+       : "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
+         "l"(reinterpret_cast<const uint64_t&>(
+             scale))); // the cast is needed because an asm operand must have
+                       // scalar type
+   return out_fp4x4;
+ }
+
+ __device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4_rs(
+     const bf16x4 input_bf16x4,
+     const float2 scale,
+     uint32_t rbits) {
+   uint16_t out_fp4x4 = 0;
+   asm volatile(
+       "{\n"
+       ".reg.b16 x0_bf16; \n\t"
+       ".reg.b16 x1_bf16; \n\t"
+       ".reg.b16 x2_bf16; \n\t"
+       ".reg.b16 x3_bf16; \n\t"
+       ".reg.b32 x0; \n\t"
+       ".reg.b32 x1; \n\t"
+       ".reg.b32 x2; \n\t"
+       ".reg.b32 x3; \n\t"
+       ".reg.b64 x01; \n\t"
+       ".reg.b64 x23; \n\t"
+       ".reg.b16 q0; \n\t"
+       "mov.b64 {x0_bf16, x1_bf16, x2_bf16, x3_bf16}, %1; \n\t"
+       "cvt.f32.bf16 x0, x0_bf16; \n\t"
+       "cvt.f32.bf16 x1, x1_bf16; \n\t"
+       "cvt.f32.bf16 x2, x2_bf16; \n\t"
+       "cvt.f32.bf16 x3, x3_bf16; \n\t"
+       "mov.b64 x01, {x0, x1}; \n\t"
+       "mul.f32x2 x01, x01, %2; \n\t"
+       "mov.b64 x23, {x2, x3}; \n\t"
+       "mul.f32x2 x23, x23, %2; \n\t"
+       "mov.b64 {x0, x1}, x01; \n\t"
+       "mov.b64 {x2, x3}, x23; \n\t"
+       "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
+       "mov.b16 %0, q0; \n\t" // pack to output
+       "}"
+       : "=h"(out_fp4x4)
+       : "l"(reinterpret_cast<const uint64_t&>(input_bf16x4)),
+         "l"(reinterpret_cast<const uint64_t&>(scale)),
+         "r"(rbits));
+   return out_fp4x4;
+ }
+
+ __device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rn(
+     const float2 input_fp32x2_0,
+     const float2 input_fp32x2_1,
+     const float2 scale) {
+   uint16_t out_fp4x4 = 0;
+   asm volatile(
+       "{\n"
+       ".reg.b32 x0; \n\t"
+       ".reg.b32 x1; \n\t"
+       ".reg.b32 x2; \n\t"
+       ".reg.b32 x3; \n\t"
+       ".reg.b64 x01; \n\t"
+       ".reg.b64 x23; \n\t"
+       ".reg.b8 q0; \n\t"
+       ".reg.b8 q1; \n\t"
+       "mov.b64 x01, {%1, %2}; \n\t"
+       "mul.f32x2 x01, x01, %5; \n\t"
+       "mov.b64 x23, {%3, %4}; \n\t"
+       "mul.f32x2 x23, x23, %5; \n\t"
+       "mov.b64 {x0, x1}, x01; \n\t"
+       "mov.b64 {x2, x3}, x23; \n\t"
+       "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
+       "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
+       "mov.b16 %0, {q0, q1}; \n\t"
+       "}"
+       : "=h"(out_fp4x4)
+       : "f"(input_fp32x2_0.x),
+         "f"(input_fp32x2_0.y),
+         "f"(input_fp32x2_1.x),
+         "f"(input_fp32x2_1.y),
+         "l"(reinterpret_cast<const uint64_t&>(scale)));
+   return out_fp4x4;
+ }
+
+ __device__ __forceinline__ uint16_t scale_cvt_fp32x4_to_fp4x4_rs(
+     const float2 input_fp32x2_0,
+     const float2 input_fp32x2_1,
+     const float2 scale,
+     uint32_t rbits) {
+   uint16_t out_fp4x4 = 0;
+   asm volatile(
+       "{\n"
+       ".reg.b32 x0; \n\t"
+       ".reg.b32 x1; \n\t"
+       ".reg.b32 x2; \n\t"
+       ".reg.b32 x3; \n\t"
+       ".reg.b64 x01; \n\t"
+       ".reg.b64 x23; \n\t"
+       ".reg.b16 q0; \n\t"
+       "mov.b64 x01, {%1, %2}; \n\t"
+       "mul.f32x2 x01, x01, %5; \n\t"
+       "mov.b64 x23, {%3, %4}; \n\t"
+       "mul.f32x2 x23, x23, %5; \n\t"
+       "mov.b64 {x0, x1}, x01; \n\t"
+       "mov.b64 {x2, x3}, x23; \n\t"
+       "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %6; \n\t"
+       "mov.b16 %0, q0; \n\t" // pack to output
+       "}"
+       : "=h"(out_fp4x4)
+       : "f"(input_fp32x2_0.x),
+         "f"(input_fp32x2_0.y),
+         "f"(input_fp32x2_1.x),
+         "f"(input_fp32x2_1.y),
+         "l"(reinterpret_cast<const uint64_t&>(scale)),
+         "r"(rbits));
+   return out_fp4x4;
+ }
+
+ __device__ __forceinline__ uint16_t
+ scale_cvt_fp16x4_to_fp4x4_rn(const fp16x4 input_fp16x4, const float2 scale) {
+   uint16_t out_fp4x4 = 0;
+   asm volatile(
+       "{\n"
+       ".reg.b16 x0_fp16; \n\t"
+       ".reg.b16 x1_fp16; \n\t"
+       ".reg.b16 x2_fp16; \n\t"
+       ".reg.b16 x3_fp16; \n\t"
+       ".reg.b32 x0; \n\t"
+       ".reg.b32 x1; \n\t"
+       ".reg.b32 x2; \n\t"
+       ".reg.b32 x3; \n\t"
+       ".reg.b64 x01; \n\t"
+       ".reg.b64 x23; \n\t"
+       ".reg.b8 q0; \n\t"
+       ".reg.b8 q1; \n\t"
+       "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16}, %1; \n\t"
+       "cvt.f32.f16 x0, x0_fp16; \n\t"
+       "cvt.f32.f16 x1, x1_fp16; \n\t"
+       "cvt.f32.f16 x2, x2_fp16; \n\t"
+       "cvt.f32.f16 x3, x3_fp16; \n\t"
+       "mov.b64 x01, {x0, x1}; \n\t"
+       "mul.f32x2 x01, x01, %2; \n\t"
+       "mov.b64 x23, {x2, x3}; \n\t"
+       "mul.f32x2 x23, x23, %2; \n\t"
+       "mov.b64 {x0, x1}, x01; \n\t"
+       "mov.b64 {x2, x3}, x23; \n\t"
+       "cvt.rn.satfinite.e2m1x2.f32 q0, x1, x0; \n\t"
+       "cvt.rn.satfinite.e2m1x2.f32 q1, x3, x2; \n\t"
+       "mov.b16 %0, {q0, q1}; \n\t"
+       "}"
+       : "=h"(out_fp4x4)
+       : "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
+         "l"(reinterpret_cast<const uint64_t&>(scale)));
+   return out_fp4x4;
+ }
+
+ __device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4_rs(
+     const fp16x4 input_fp16x4,
+     const float2 scale,
+     uint32_t rbits) {
+   uint16_t out_fp4x4 = 0;
+   asm volatile(
+       "{\n"
+       ".reg.b16 x0_fp16; \n\t"
+       ".reg.b16 x1_fp16; \n\t"
+       ".reg.b16 x2_fp16; \n\t"
+       ".reg.b16 x3_fp16; \n\t"
+       ".reg.b32 x0; \n\t"
+       ".reg.b32 x1; \n\t"
+       ".reg.b32 x2; \n\t"
+       ".reg.b32 x3; \n\t"
+       ".reg.b64 x01; \n\t"
+       ".reg.b64 x23; \n\t"
+       ".reg.b16 q0; \n\t"
+       "mov.b64 {x0_fp16, x1_fp16, x2_fp16, x3_fp16}, %1; \n\t"
+       "cvt.f32.f16 x0, x0_fp16; \n\t"
+       "cvt.f32.f16 x1, x1_fp16; \n\t"
+       "cvt.f32.f16 x2, x2_fp16; \n\t"
+       "cvt.f32.f16 x3, x3_fp16; \n\t"
+       "mov.b64 x01, {x0, x1}; \n\t"
+       "mul.f32x2 x01, x01, %2; \n\t"
+       "mov.b64 x23, {x2, x3}; \n\t"
+       "mul.f32x2 x23, x23, %2; \n\t"
+       "mov.b64 {x0, x1}, x01; \n\t"
+       "mov.b64 {x2, x3}, x23; \n\t"
+       "cvt.rs.satfinite.e2m1x4.f32 q0, {x3, x2, x1, x0}, %3; \n\t"
+       "mov.b16 %0, q0; \n\t" // pack to output
+       "}"
+       : "=h"(out_fp4x4)
+       : "l"(reinterpret_cast<const uint64_t&>(input_fp16x4)),
+         "l"(reinterpret_cast<const uint64_t&>(scale)),
+         "r"(rbits));
+   return out_fp4x4;
+ }
+
+ template <bool USE_SR>
+ __device__ __forceinline__ uint16_t scale_cvt_bf16x4_to_fp4x4(
+     const bf16x4 input,
+     const float scale,
+     uint32_t rbits) {
+   float2 scale_fp32x2 = make_float2(scale, scale);
+   if constexpr (USE_SR) {
+     return scale_cvt_bf16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
+   } else {
+     return scale_cvt_bf16x4_to_fp4x4_rn(input, scale_fp32x2);
+   }
+ }
+
+ template <bool USE_SR>
+ __device__ __forceinline__ uint16_t scale_cvt_fp16x4_to_fp4x4(
+     const fp16x4 input,
+     const float scale,
+     uint32_t rbits) {
+   float2 scale_fp32x2 = make_float2(scale, scale);
+   if constexpr (USE_SR) {
+     return scale_cvt_fp16x4_to_fp4x4_rs(input, scale_fp32x2, rbits);
+   } else {
+     return scale_cvt_fp16x4_to_fp4x4_rn(input, scale_fp32x2);
+   }
+ }
+
+ template <bool USE_SR>
+ __device__ __forceinline__ uint16_t
+ scale_cvt_f32x4_to_fp4x4(const f32x4 input, const float scale, uint32_t rbits) {
+   float2 scale_fp32x2 = make_float2(scale, scale);
+   float2 input_fp32x2_0 = make_float2(input.x, input.y);
+   float2 input_fp32x2_1 = make_float2(input.z, input.w);
+
+   if constexpr (USE_SR) {
+     return scale_cvt_fp32x4_to_fp4x4_rs(
+         input_fp32x2_0, input_fp32x2_1, scale_fp32x2, rbits);
+   } else {
+     return scale_cvt_fp32x4_to_fp4x4_rn(
+         input_fp32x2_0, input_fp32x2_1, scale_fp32x2);
+   }
+ }
+
+ template <typename T, bool USE_SR>
+ __device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4_fast(
+     const Vector4_t<T> input,
+     const float scale,
+     uint32_t rbits) {
+   if constexpr (std::is_same<T, __nv_bfloat16>::value) {
+     return scale_cvt_bf16x4_to_fp4x4<USE_SR>(input, scale, rbits);
+   } else if constexpr (std::is_same<T, __half>::value) {
+     return scale_cvt_fp16x4_to_fp4x4<USE_SR>(input, scale, rbits);
+   } else {
+     return scale_cvt_f32x4_to_fp4x4<USE_SR>(input, scale, rbits);
+   }
+ }
+ #endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) &&
+        // (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
+
+ template <typename T, bool USE_SR>
+ __device__ __forceinline__ uint16_t scale_cvt_Tx4_to_fp4x4(
+     const Vector4_t<T> input,
+     const float scale,
+     uint32_t rbits) {
+ #if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
+     (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000)
+   return scale_cvt_Tx4_to_fp4x4_fast<T, USE_SR>(input, scale, rbits);
+ #else
+   static_assert(
+       !USE_SR,
+       "Stochastic rounding (USE_SR=true) requires CUDA >= 12.8 and compute capability >= 1000.");
+   return scale_cvt_Tx4_to_fp4x4_fallback(input, scale);
+ #endif
+ }
+ } // namespace mlx::core::cu
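
For readers unfamiliar with the e2m1 format used above: each fp4 value is 1 sign bit, 2 exponent bits, and 1 mantissa bit, so the only representable magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and satfinite clamps anything larger to ±6. The host-side sketch below is not part of the package; encode_e2m1 is a hypothetical helper that simplifies rounding to plain nearest-value (not round-to-nearest-even in tie cases). It mirrors what scale_cvt_Tx4_to_fp4x4_fallback does with __nv_fp4_e2m1: scale four floats, encode each to a nibble, and pack them low-nibble-first into a uint16_t.

// Host-side sketch (illustrative, not the package's CUDA path).
#include <cmath>
#include <cstdint>
#include <cstdio>

// All representable e2m1 magnitudes, indexed by the 3-bit magnitude code.
static const float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

static uint8_t encode_e2m1(float x) {
  uint8_t sign = std::signbit(x) ? 0x8 : 0x0;
  float mag = std::fabs(x);
  int best = 0;
  for (int i = 1; i < 8; ++i) { // nearest magnitude; also saturates at 6
    if (std::fabs(mag - kE2M1[i]) < std::fabs(mag - kE2M1[best])) {
      best = i;
    }
  }
  return sign | static_cast<uint8_t>(best); // low 4 bits: sign + 3-bit code
}

int main() {
  float in[4] = {0.7f, -2.2f, 5.9f, 0.1f};
  float scale = 1.0f;
  uint16_t packed = 0;
  for (int i = 0; i < 4; ++i) { // first element lands in the low nibble
    packed |= static_cast<uint16_t>(encode_e2m1(in[i] * scale)) << (4 * i);
  }
  std::printf("packed fp4x4 = 0x%04x\n", packed); // 0.7 -> 0.5, 5.9 -> 6, ...
}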
data/mlx/mlx/backend/cuda/quantized/qmv.cu
@@ -0,0 +1,304 @@
+ // Copyright © 2025 Apple Inc.
+
+ #include "mlx/backend/cuda/device/utils.cuh"
+ #include "mlx/backend/cuda/kernel_utils.cuh"
+ #include "mlx/backend/cuda/quantized/qmv.h"
+ #include "mlx/backend/cuda/quantized/quantized_utils.cuh"
+ #include "mlx/dtype_utils.h"
+
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+
+ namespace mlx::core::cu {
+
+ namespace cg = cooperative_groups;
+
+ static constexpr int rows_per_block = 8;
+
+ template <typename T>
+ __device__ void adjust_matrix_offsets(
+     const T*& x,
+     const uint32_t*& w,
+     const uint8_t*& scales,
+     T*& y,
+     int output_stride,
+     const int& x_batch_ndims,
+     const Shape x_shape,
+     const Strides x_strides,
+     const int& w_batch_ndims,
+     const Shape w_shape,
+     const Strides w_strides,
+     const Strides s_strides) {
+   uint32_t idx = cg::this_grid().block_index().z;
+   if (x_batch_ndims == 1) {
+     x += idx * x_strides[0];
+   } else {
+     x += elem_to_loc(idx, x_shape.data(), x_strides.data(), x_batch_ndims);
+   }
+   if (w_batch_ndims == 1) {
+     w += idx * w_strides[0];
+     scales += idx * s_strides[0];
+   } else {
+     auto [w_idx, s_idx] = elem_to_loc(
+         idx, w_shape.data(), w_strides.data(), s_strides.data(), w_batch_ndims);
+     w += w_idx;
+     scales += s_idx;
+   }
+   y += idx * output_stride;
+ }
+
+ template <
+     typename T,
+     int rows_per_block,
+     int n_per_thread,
+     int bits,
+     int group_size,
+     bool use_mx_scale>
+ __device__ void fp_qmv_impl(
+     const uint32_t* mat,
+     const uint8_t* scales_,
+     const T* vec,
+     T* out,
+     int rows,
+     int cols) {
+   auto block = cg::this_thread_block();
+   auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+   constexpr int vals_per_item = bits == 8 ? 4 : 8;
+   constexpr int nv_per_thread = vals_per_item * n_per_thread;
+   auto g_idx = block.group_index();
+   auto t_idx = block.thread_index();
+   int row = g_idx.y * rows_per_block + t_idx.y;
+
+   vec += g_idx.x * cols;
+   out += g_idx.x * rows;
+
+   using ScaleType =
+       std::conditional_t<use_mx_scale, __nv_fp8_e8m0, __nv_fp8_e4m3>;
+   auto scales = (ScaleType*)(scales_);
+   auto packed_cols = cols / vals_per_item;
+
+   if (row < rows) {
+     constexpr int scales_per_step = std::max(nv_per_thread / group_size, 1);
+     constexpr int scale_step = (WARP_SIZE * nv_per_thread) / group_size;
+     constexpr int n_per_step = n_per_thread / scales_per_step;
+     // Offset scales to correct row
+     scales += row * (cols / group_size) +
+         (warp.thread_rank() * nv_per_thread) / group_size;
+     float sum = 0.0f;
+     for (int col = n_per_thread * warp.thread_rank(); col < packed_cols;
+          col += (WARP_SIZE * n_per_thread)) {
+       auto local_vec =
+           unsafe_load_vector<nv_per_thread>(vec + vals_per_item * col, 0);
+       auto local_mat =
+           unsafe_load_vector<n_per_thread>(mat + row * packed_cols + col, 0);
+ #pragma unroll
+       for (int i = 0; i < scales_per_step; ++i) {
+         float2 local_sum = {0.0f, 0.0f};
+ #pragma unroll
+         for (int j = 0; j < n_per_step; ++j) {
+           int k = n_per_step * i + j;
+           if constexpr (bits == 8) {
+             auto v = dequant_fp8(local_mat[k]);
+             local_sum.x +=
+                 v.x * static_cast<float>(local_vec[vals_per_item * k]);
+             local_sum.x +=
+                 v.y * static_cast<float>(local_vec[vals_per_item * k + 1]);
+             local_sum.y +=
+                 v.z * static_cast<float>(local_vec[vals_per_item * k + 2]);
+             local_sum.y +=
+                 v.w * static_cast<float>(local_vec[vals_per_item * k + 3]);
+           } else {
+             auto v = dequant_fp4(local_mat[k]);
+             local_sum.x +=
+                 v.x * static_cast<float>(local_vec[vals_per_item * k]);
+             local_sum.y +=
+                 v.y * static_cast<float>(local_vec[vals_per_item * k + 1]);
+             local_sum.x +=
+                 v.z * static_cast<float>(local_vec[vals_per_item * k + 2]);
+             local_sum.y +=
+                 v.w * static_cast<float>(local_vec[vals_per_item * k + 3]);
+
+             v = dequant_fp4(local_mat[k] >> 16);
+             local_sum.x +=
+                 v.x * static_cast<float>(local_vec[vals_per_item * k + 4]);
+             local_sum.y +=
+                 v.y * static_cast<float>(local_vec[vals_per_item * k + 5]);
+             local_sum.x +=
+                 v.z * static_cast<float>(local_vec[vals_per_item * k + 6]);
+             local_sum.y +=
+                 v.w * static_cast<float>(local_vec[vals_per_item * k + 7]);
+           }
+         }
+         sum += (local_sum.x + local_sum.y) * float(scales[i]);
+       }
+       scales += scale_step;
+     }
+
+     sum = cg::reduce(warp, sum, cg::plus<float>{});
+     if (warp.thread_rank() == 0) {
+       out[row] = static_cast<T>(sum);
+     }
+   }
+ }
+
+ template <
+     typename T,
+     int rows_per_block,
+     int n_per_thread,
+     int bits,
+     int group_size,
+     bool use_mx_scale>
+ __global__ void fp_qmv_single(
+     const uint32_t* mat,
+     const uint8_t* scales,
+     const T* vec,
+     T* out,
+     int rows,
+     int cols) {
+   fp_qmv_impl<T, rows_per_block, n_per_thread, bits, group_size, use_mx_scale>(
+       mat, scales, vec, out, rows, cols);
+ }
+
+ template <
+     typename T,
+     int rows_per_block,
+     int n_per_thread,
+     int bits,
+     int group_size,
+     bool use_mx_scale>
+ __global__ void fp_qmv_batched(
+     const uint32_t* mat,
+     const uint8_t* scales,
+     const T* vec,
+     T* out,
+     int rows,
+     int cols,
+     int vec_batch_ndims,
+     const __grid_constant__ Shape vec_shape,
+     const __grid_constant__ Strides vec_strides,
+     int mat_batch_ndims,
+     const __grid_constant__ Shape mat_shape,
+     const __grid_constant__ Strides mat_strides,
+     const __grid_constant__ Strides scales_strides) {
+   adjust_matrix_offsets<T>(
+       vec,
+       mat,
+       scales,
+       out,
+       rows * vec_shape[vec_batch_ndims],
+       vec_batch_ndims,
+       vec_shape,
+       vec_strides,
+       mat_batch_ndims,
+       mat_shape,
+       mat_strides,
+       scales_strides);
+   fp_qmv_impl<T, rows_per_block, n_per_thread, bits, group_size, use_mx_scale>(
+       mat, scales, vec, out, rows, cols);
+ }
+
+ template <typename F>
+ void dispatch_1_2_4(int n, F&& f) {
+   switch (n) {
+     case 1:
+       f(std::integral_constant<int, 1>{});
+       break;
+     case 2:
+       f(std::integral_constant<int, 2>{});
+       break;
+     case 4:
+       f(std::integral_constant<int, 4>{});
+       break;
+   }
+ }
+
+ void fp_qmv(
+     const array& mat,
+     const array& scales,
+     const array& vec,
+     array& out,
+     int bits,
+     int group_size,
+     int M,
+     int N,
+     int K,
+     CommandEncoder& encoder) {
+   encoder.set_input_array(mat);
+   encoder.set_input_array(scales);
+   encoder.set_input_array(vec);
+   encoder.set_output_array(out);
+   dispatch_float_types(out.dtype(), "qmv", [&](auto type_tag) {
+     using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+     if constexpr (!std::is_same_v<T, double>) {
+       dim3 block_dims{WARP_SIZE, rows_per_block};
+       uint32_t B = out.size() / (M * N);
+       uint32_t blocks_y = (N + rows_per_block - 1) / rows_per_block;
+       const uint32_t* mat_ptr = gpu_ptr<uint32_t>(mat);
+       const T* vec_ptr = gpu_ptr<T>(vec);
+       int n = 1;
+       if (K % 32 == 0 && cu::is_aligned<4>(mat_ptr) &&
+           ((bits == 4 && cu::is_aligned<8>(vec_ptr)) ||
+            cu::is_aligned<4>(vec_ptr))) {
+         n = 4;
+       } else if (
+           cu::is_aligned<2>(mat_ptr) &&
+           ((bits == 4 && cu::is_aligned<4>(vec_ptr)) ||
+            cu::is_aligned<2>(vec_ptr))) {
+         n = 2;
+       }
+       dispatch_1_2_4(n, [&](auto n) {
+         dispatch_bool(B > 1, [&](auto batched) {
+           if (!batched.value) {
+             auto kernel =
+                 fp_qmv_single<T, rows_per_block, n.value, 4, 32, true>;
+             if (bits == 8) {
+               kernel = fp_qmv_single<T, rows_per_block, n.value, 8, 32, true>;
+             } else if (group_size == 16) {
+               kernel = fp_qmv_single<T, rows_per_block, n.value, 4, 16, false>;
+             }
+             encoder.add_kernel_node(
+                 kernel,
+                 {static_cast<uint32_t>(M), blocks_y},
+                 block_dims,
+                 0,
+                 mat_ptr,
+                 gpu_ptr<uint8_t>(scales),
+                 vec_ptr,
+                 gpu_ptr<T>(out),
+                 N,
+                 K);
+           } else {
+             auto kernel =
+                 fp_qmv_batched<T, rows_per_block, n.value, 4, 32, true>;
+             if (bits == 8) {
+               kernel = fp_qmv_batched<T, rows_per_block, n.value, 8, 32, true>;
+             } else if (group_size == 16) {
+               kernel = fp_qmv_batched<T, rows_per_block, n.value, 4, 16, false>;
+             }
+             encoder.add_kernel_node(
+                 kernel,
+                 {static_cast<uint32_t>(M), blocks_y, B},
+                 block_dims,
+                 0,
+                 mat_ptr,
+                 gpu_ptr<uint8_t>(scales),
+                 vec_ptr,
+                 gpu_ptr<T>(out),
+                 N,
+                 K,
+                 vec.ndim() - 2,
+                 const_param(vec.shape()),
+                 const_param(vec.strides()),
+                 mat.ndim() - 2,
+                 const_param(mat.shape()),
+                 const_param(mat.strides()),
+                 const_param(scales.strides()));
+           }
+         });
+       });
+     }
+   });
+ }
+
+ } // namespace mlx::core::cu
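
A note on the launch geometry in fp_qmv above: one warp owns one output row, a block stacks rows_per_block (8) such warps, and the grid walks the M input vectors in x, the row blocks in y, and (in the batched kernel) the batch in z. The host-side sketch below is not part of the package; M, N, K, and B are made-up example sizes and WARP_SIZE is assumed to be 32.

// Illustrative sketch of fp_qmv's grid/block arithmetic.
#include <cstdint>
#include <cstdio>

struct Dim3 {
  uint32_t x, y, z;
};

int main() {
  int M = 1, N = 4096, K = 4096; // input vectors, output rows, reduction size
  uint32_t B = 2;                // batch count, i.e. out.size() / (M * N)
  const int rows_per_block = 8;
  Dim3 block = {32, rows_per_block, 1}; // one warp per row, 8 rows per block
  Dim3 grid = {
      static_cast<uint32_t>(M),
      static_cast<uint32_t>((N + rows_per_block - 1) / rows_per_block),
      B};
  std::printf(
      "grid = {%u, %u, %u}, block = {%u, %u, %u}\n",
      grid.x, grid.y, grid.z, block.x, block.y, block.z);
  // With bits == 4, every uint32_t word of `mat` packs 8 fp4 values, so each
  // row spans K / 8 words, and a warp with n_per_thread == 4 consumes
  // 32 threads * 4 words * 8 values = 1024 values of K per loop iteration.
  std::printf("packed words per row = %d\n", K / 8);
}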
data/mlx/mlx/backend/cuda/quantized/qmv.h
@@ -0,0 +1,21 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/backend/cuda/device.h"
+
+ namespace mlx::core::cu {
+
+ void fp_qmv(
+     const array& w,
+     const array& scales,
+     const array& vec,
+     array& out,
+     int bits,
+     int group_size,
+     int M,
+     int N,
+     int K,
+     CommandEncoder& encoder);
+
+ } // namespace mlx::core::cu
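
For orientation, the declaration above is the whole public surface of this kernel: w holds the fp4/fp8 weights packed into uint32_t words, scales holds one fp8 scale (e8m0 or e4m3) per group_size weights, and M, N, K are the vector count, output row count, and reduction length. The sketch below of the implied buffer sizes is an assumption derived from how qmv.cu indexes mat and scales, not an API of the package.

// Implied buffer sizes for bits == 4, group_size == 32 (illustrative).
#include <cstdio>

int main() {
  int N = 4096, K = 4096; // output rows, reduction length
  int bits = 4, group_size = 32;
  int vals_per_word = 32 / bits;                 // fp4: 8 values per uint32_t
  long mat_words = (long)N * (K / vals_per_word);
  long scale_bytes = (long)N * (K / group_size); // one fp8 scale per group
  std::printf(
      "w: %ld uint32 words, scales: %ld bytes\n", mat_words, scale_bytes);
}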