mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
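The largest addition in this release is the CUDA scaled-dot-product-attention vector path (`data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu`, shown in the hunk below). Its kernels never materialize the full score row: each query keeps a running maximum, a running sum of exponentials, and a rescaled output accumulator. As a reading aid, here is a minimal host-side C++ sketch of that streaming-softmax recurrence; it is not MLX API, the function and parameter names are illustrative, and it assumes plain `std::vector` inputs for a single query.

// Host-side reference sketch (illustrative names, not MLX API) of the
// streaming ("online") softmax recurrence used by the kernel_sdpav_* kernels
// in the hunk below. Scores are consumed one key at a time while a running
// maximum, a running sum of exponentials, and a rescaled output accumulator
// are maintained, so the full score row is never stored. Base-2 exponentials
// are used with log2(e) folded into the scale, mirroring the kernels.
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

std::vector<float> online_softmax_attention(
    const std::vector<float>& q,                   // [D]
    const std::vector<std::vector<float>>& keys,   // [kL][D]
    const std::vector<std::vector<float>>& values, // [kL][D]
    float scale) {
  const size_t D = q.size();
  const float scale_log2 = scale * 1.44269504089f; // scale * log2(e)

  float max_score = -std::numeric_limits<float>::infinity();
  float sum_exp = 0.f;
  std::vector<float> o(D, 0.f);

  for (size_t i = 0; i < keys.size(); ++i) {
    // Scaled dot product, carried in the log2 domain.
    float score = 0.f;
    for (size_t d = 0; d < D; ++d) {
      score += scale_log2 * q[d] * keys[i][d];
    }
    // Online update: rescale the previous state to the new running maximum.
    const float new_max = std::max(max_score, score);
    const float factor = std::exp2(max_score - new_max);
    const float exp_score = std::exp2(score - new_max);
    max_score = new_max;
    sum_exp = sum_exp * factor + exp_score;
    for (size_t d = 0; d < D; ++d) {
      o[d] = o[d] * factor + exp_score * values[i][d];
    }
  }
  // Final normalization, as in the kernels' epilogue.
  for (size_t d = 0; d < D; ++d) {
    o[d] /= sum_exp;
  }
  return o;
}

The one-pass kernel runs this recurrence with one warp lane per slice of the head dimension; the two-pass variant additionally splits the keys across 32 blocks and merges the per-block partial results, as illustrated in the sketch after the hunk.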
data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu
@@ -0,0 +1,796 @@
+ // Copyright © 2025 Apple Inc.
+
+ // Required for using M_LOG2E in MSVC.
+ #define _USE_MATH_DEFINES
+
+ #include "mlx/backend/cuda/device.h"
+ #include "mlx/backend/cuda/device/config.h"
+ #include "mlx/backend/cuda/device/utils.cuh"
+ #include "mlx/backend/cuda/kernel_utils.cuh"
+ #include "mlx/backend/gpu/copy.h"
+ #include "mlx/dtype_utils.h"
+
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+
+ namespace mlx::core {
+
+ namespace cu {
+
+ namespace cg = cooperative_groups;
+
+ #define PRAGMA_LOOP_UNROLL #pragma unroll
+
+ struct AttnParams {
+   int B;
+   int H;
+   int D;
+
+   int qL;
+   int kL;
+
+   int gqa_factor;
+   float scale;
+
+   int64_t Q_strides[3];
+   int64_t K_strides[3];
+   int64_t V_strides[3];
+   int64_t O_strides[3];
+ };
+
+ template <typename T, bool do_causal, int D>
+ __global__ void kernel_sdpav_1pass(
+     const T* Q,
+     const T* K,
+     const T* V,
+     T* O,
+     const T* sinks,
+     __grid_constant__ const AttnParams params) {
+   constexpr int BN = 32;
+   constexpr int BD = 32;
+
+   constexpr int v_per_thread = D / BD;
+
+   const int inner_k_stride = BN * int(params.K_strides[2]);
+   const int inner_v_stride = BN * int(params.V_strides[2]);
+
+   typedef float U;
+
+   U q[v_per_thread];
+   U k[v_per_thread];
+   U o[v_per_thread];
+
+   __shared__ U outputs[BN][BD + 1];
+   __shared__ U max_scores[BN];
+   __shared__ U sum_exp_scores[BN];
+
+   const U scale_log2 = params.scale * M_LOG2E;
+
+   auto block = cg::this_thread_block();
+   auto warp = cg::tiled_partition<32>(block);
+
+   const int lane_idx = warp.thread_rank();
+   const int warp_idx = warp.meta_group_rank();
+
+   // Adjust to thread block and thread
+   const int batch_idx = blockIdx.z;
+   const int head_idx = blockIdx.x;
+   const int kv_head_idx = head_idx / params.gqa_factor;
+
+   const int q_seq_idx = blockIdx.y;
+   const int kv_seq_idx = warp_idx;
+
+   Q += batch_idx * params.Q_strides[0] + // Batch
+       head_idx * params.Q_strides[1] + // Head
+       q_seq_idx * params.Q_strides[2]; // Sequence
+
+   K += batch_idx * params.K_strides[0] + // Batch
+       kv_head_idx * params.K_strides[1] + // Head
+       kv_seq_idx * params.K_strides[2]; // Sequence
+
+   V += batch_idx * params.V_strides[0] + // Batch
+       kv_head_idx * params.V_strides[1] + // Head
+       kv_seq_idx * params.V_strides[2]; // Sequence
+
+   O += batch_idx * params.O_strides[0] + // Batch
+       head_idx * params.O_strides[1] + // Head
+       q_seq_idx * params.O_strides[2]; // Sequence
+
+   // Read the query and 0 the output accumulator
+   PRAGMA_LOOP_UNROLL
+   for (int i = 0; i < v_per_thread; i++) {
+     q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
+   }
+
+   PRAGMA_LOOP_UNROLL
+   for (int i = 0; i < v_per_thread; i++) {
+     o[i] = 0.f;
+   }
+
+   U max_score = Limits<U>::finite_min();
+   U sum_exp_score = 0.f;
+   if (sinks && warp_idx == 0) {
+     max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+     sum_exp_score = 1.f;
+   }
+
+   // For each key
+   for (int i = kv_seq_idx; i < params.kL; i += BN) {
+     bool use_key = true;
+     if constexpr (do_causal) {
+       use_key = i <= (params.kL - params.qL + q_seq_idx);
+     }
+
+     if (use_key) {
+       // Read the key
+       PRAGMA_LOOP_UNROLL
+       for (int j = 0; j < v_per_thread; j++) {
+         k[j] = K[v_per_thread * lane_idx + j];
+       }
+
+       // Compute the i-th score
+       U score = 0.f;
+       PRAGMA_LOOP_UNROLL
+       for (int j = 0; j < v_per_thread; j++) {
+         score += q[j] * k[j];
+       }
+
+       // Warp sum
+       score = cg::reduce(warp, score, cg::plus<U>());
+
+       // Update the accumulators
+       U new_max = max(max_score, score);
+       U factor = exp2f(max_score - new_max);
+       U exp_score = exp2f(score - new_max);
+
+       max_score = new_max;
+       sum_exp_score = sum_exp_score * factor + exp_score;
+
+       // Update the output accumulator
+       PRAGMA_LOOP_UNROLL
+       for (int j = 0; j < v_per_thread; j++) {
+         o[j] = o[j] * factor +
+             exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
+       }
+     }
+
+     // Move the pointers to the next kv
+     K += inner_k_stride;
+     V += inner_v_stride;
+   }
+
+   if (lane_idx == 0) {
+     max_scores[warp_idx] = max_score;
+     sum_exp_scores[warp_idx] = sum_exp_score;
+   }
+   block.sync();
+
+   max_score = max_scores[lane_idx];
+   U new_max = cg::reduce(warp, max_score, cg::greater<U>());
+   U factor = exp2f(max_score - new_max);
+   sum_exp_score =
+       cg::reduce(warp, sum_exp_scores[lane_idx] * factor, cg::plus<U>());
+   sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
+
+   // Now we need to aggregate all the outputs
+   PRAGMA_LOOP_UNROLL
+   for (int i = 0; i < v_per_thread; i++) {
+     outputs[lane_idx][warp_idx] = o[i];
+     block.sync();
+     U ot = outputs[warp_idx][lane_idx] * factor;
+     o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
+     block.sync();
+   }
+
+   // And write the output
+   if (lane_idx == 0) {
+     PRAGMA_LOOP_UNROLL
+     for (int i = 0; i < v_per_thread; i++) {
+       O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
+     }
+   }
+ }
+
+ template <typename T, bool do_causal, int D>
+ __global__ void kernel_sdpav_2pass_1(
+     const T* Q,
+     const T* K,
+     const T* V,
+     const T* sinks,
+     float* partials,
+     float* sums,
+     float* maxs,
+     __grid_constant__ const AttnParams params) {
+   constexpr int BN = 8;
+   constexpr int BD = 32;
+   constexpr int blocks = 32;
+
+   constexpr int v_per_thread = D / BD;
+
+   const int inner_k_stride = blocks * BN * int(params.K_strides[2]);
+   const int inner_v_stride = blocks * BN * int(params.V_strides[2]);
+
+   typedef float U;
+
+   U q[v_per_thread];
+   U k[v_per_thread];
+   U o[v_per_thread];
+
+   __shared__ U outputs[BN][BD + 1];
+   __shared__ U max_scores[BN];
+   __shared__ U sum_exp_scores[BN];
+
+   const U scale_log2 = params.scale * 1.44269504089f;
+
+   auto block = cg::this_thread_block();
+   auto warp = cg::tiled_partition<32>(block);
+
+   const int lane_idx = warp.thread_rank();
+   const int warp_idx = warp.meta_group_rank();
+
+   // Adjust to thread block and thread
+   const int batch_idx = blockIdx.z / blocks;
+   const int block_idx = blockIdx.z % blocks;
+   const int head_idx = blockIdx.x;
+   const int kv_head_idx = head_idx / params.gqa_factor;
+
+   const int q_seq_idx = blockIdx.y;
+   const int kv_seq_idx = block_idx * BN + warp_idx;
+
+   Q += batch_idx * params.Q_strides[0] + // Batch
+       head_idx * params.Q_strides[1] + // Head
+       q_seq_idx * params.Q_strides[2]; // Sequence
+
+   K += batch_idx * params.K_strides[0] + // Batch
+       kv_head_idx * params.K_strides[1] + // Head
+       kv_seq_idx * params.K_strides[2]; // Sequence
+
+   V += batch_idx * params.V_strides[0] + // Batch
+       kv_head_idx * params.V_strides[1] + // Head
+       kv_seq_idx * params.V_strides[2]; // Sequence
+
+   const int p_stride_s = blocks;
+   const int p_stride_h = params.qL * p_stride_s;
+   const int p_stride_b = params.H * p_stride_h;
+   const int p_offset = batch_idx * p_stride_b + // Batch
+       head_idx * p_stride_h + // Head
+       q_seq_idx * p_stride_s + // Sequence
+       block_idx; // Block
+
+   partials += p_offset * D;
+   sums += p_offset;
+   maxs += p_offset;
+
+   // Read the query and 0 the output accumulator
+   PRAGMA_LOOP_UNROLL
+   for (int i = 0; i < v_per_thread; i++) {
+     q[i] = scale_log2 * static_cast<U>(Q[v_per_thread * lane_idx + i]);
+   }
+
+   PRAGMA_LOOP_UNROLL
+   for (int i = 0; i < v_per_thread; i++) {
+     o[i] = 0.f;
+   }
+
+   U max_score = Limits<U>::finite_min();
+   U sum_exp_score = 0.f;
+   if (sinks && warp_idx == 0 && block_idx == 0) {
+     max_score = M_LOG2E * static_cast<U>(sinks[head_idx]);
+     sum_exp_score = 1.f;
+   }
+
+   // For each key
+   for (int i = kv_seq_idx; i < params.kL; i += blocks * BN) {
+     bool use_key = true;
+     if constexpr (do_causal) {
+       use_key = i <= (params.kL - params.qL + q_seq_idx);
+     }
+
+     if (use_key) {
+       // Read the key
+       PRAGMA_LOOP_UNROLL
+       for (int j = 0; j < v_per_thread; j++) {
+         k[j] = K[v_per_thread * lane_idx + j];
+       }
+
+       // Compute the i-th score
+       U score = 0.f;
+       PRAGMA_LOOP_UNROLL
+       for (int j = 0; j < v_per_thread; j++) {
+         score += q[j] * k[j];
+       }
+
+       // Warp sum
+       score = cg::reduce(warp, score, cg::plus<U>());
+
+       // Update the accumulators
+       U new_max = max(max_score, score);
+       U factor = exp2f(max_score - new_max);
+       U exp_score = exp2f(score - new_max);
+
+       max_score = new_max;
+       sum_exp_score = sum_exp_score * factor + exp_score;
+
+       // Update the output accumulator
+       PRAGMA_LOOP_UNROLL
+       for (int j = 0; j < v_per_thread; j++) {
+         o[j] = o[j] * factor +
+             exp_score * static_cast<U>(V[v_per_thread * lane_idx + j]);
+       }
+     }
+
+     // Move the pointers to the next kv
+     K += inner_k_stride;
+     V += inner_v_stride;
+   }
+
+   if (lane_idx == 0) {
+     max_scores[warp_idx] = max_score;
+     sum_exp_scores[warp_idx] = sum_exp_score;
+   }
+
+   block.sync();
+
+   max_score = (lane_idx < BN) ? max_scores[lane_idx] : -1e9;
+   U new_max = cg::reduce(warp, max_score, cg::greater<U>());
+   U factor = exp2f(max_score - new_max);
+   sum_exp_score = (lane_idx < BN) ? sum_exp_scores[lane_idx] : 0.f;
+   sum_exp_score = cg::reduce(warp, sum_exp_score * factor, cg::plus<U>());
+
+   // Write the sum and new max
+   if (warp_idx == 0) {
+     sums[0] = sum_exp_score;
+     maxs[0] = new_max;
+   }
+
+   // Now we need to aggregate all the outputs
+   auto ff = exp2f(max_scores[warp_idx] - new_max);
+   PRAGMA_LOOP_UNROLL
+   for (int i = 0; i < v_per_thread; i++) {
+     outputs[warp_idx][lane_idx] = o[i] * ff;
+     block.sync();
+
+     if (warp_idx == 0) {
+       U ot = outputs[0][lane_idx];
+       PRAGMA_LOOP_UNROLL
+       for (int j = 1; j < BN; j++) {
+         ot += outputs[j][lane_idx];
+         warp.sync();
+       }
+       o[i] = ot;
+     }
+     block.sync();
+   }
+
+   if (warp_idx == 0) {
+     PRAGMA_LOOP_UNROLL
+     for (int i = 0; i < v_per_thread; i++) {
+       partials[v_per_thread * lane_idx + i] = o[i];
+     }
+   }
+ }
+
+ template <typename T, bool do_causal, int D>
+ __global__ void kernel_sdpav_2pass_2(
+     const float* partials,
+     const float* sums,
+     const float* maxs,
+     T* O,
+     __grid_constant__ const AttnParams params) {
+   constexpr int BN = 32;
+   constexpr int BD = 32;
+   constexpr int blocks = 32;
+
+   constexpr int v_per_thread = D / BD;
+
+   typedef float U;
+
+   U o[v_per_thread];
+   __shared__ U outputs[BN][BD + 1];
+
+   auto block = cg::this_thread_block();
+   auto warp = cg::tiled_partition<32>(block);
+
+   const int lane_idx = warp.thread_rank();
+   const int warp_idx = warp.meta_group_rank();
+
+   // Adjust to thread block and thread
+   const int batch_idx = blockIdx.z;
+   const int head_idx = blockIdx.x;
+   const int q_seq_idx = blockIdx.y;
+
+   const int p_stride_s = blocks;
+   const int p_stride_h = params.qL * p_stride_s;
+   const int p_stride_b = params.H * p_stride_h;
+   const int p_offset = batch_idx * p_stride_b + // Batch
+       head_idx * p_stride_h + // Head
+       q_seq_idx * p_stride_s; // Sequence
+
+   partials += p_offset * D + warp_idx * D;
+   sums += p_offset;
+   maxs += p_offset;
+
+   O += batch_idx * params.O_strides[0] + // Batch
+       head_idx * params.O_strides[1] + // Head
+       q_seq_idx * params.O_strides[2]; // Sequence
+
+   U max_score = maxs[lane_idx];
+   U new_max = cg::reduce(warp, max_score, cg::greater<U>());
+   U factor = exp2f(max_score - new_max);
+   U sum_exp_score = cg::reduce(warp, sums[lane_idx] * factor, cg::plus<U>());
+   sum_exp_score = sum_exp_score == 0 ? 0 : __frcp_rn(sum_exp_score);
+
+   PRAGMA_LOOP_UNROLL
+   for (int i = 0; i < v_per_thread; i++) {
+     o[i] = partials[v_per_thread * lane_idx + i];
+   }
+
+   // Now we need to aggregate all the outputs
+   PRAGMA_LOOP_UNROLL
+   for (int i = 0; i < v_per_thread; i++) {
+     outputs[lane_idx][warp_idx] = o[i];
+     block.sync();
+     U ot = outputs[warp_idx][lane_idx] * factor;
+     o[i] = cg::reduce(warp, ot, cg::plus<U>()) * sum_exp_score;
+     block.sync();
+   }
+
+   // And write the output
+   if (lane_idx == 0) {
+     PRAGMA_LOOP_UNROLL
+     for (int i = 0; i < v_per_thread; i++) {
+       O[v_per_thread * warp_idx + i] = static_cast<T>(o[i]);
+     }
+   }
+ }
+
+ } // namespace cu
+
+ namespace {
+
+ template <typename F>
+ void dispatch_headdim(int n, F&& f) {
+   switch (n) {
+     case 64:
+       f(std::integral_constant<int, 64>{});
+       break;
+     case 96:
+       f(std::integral_constant<int, 96>{});
+       break;
+     case 128:
+       f(std::integral_constant<int, 128>{});
+       break;
+   }
+ }
+
+ void sdpa_vector_1pass_fallback(
+     const Stream& s,
+     cu::CommandEncoder& encoder,
+     const array& q,
+     const array& k,
+     const array& v,
+     const float scale,
+     array& o,
+     bool do_causal,
+     const std::optional<array>& sinks) {
+   encoder.set_input_array(q);
+   encoder.set_input_array(k);
+   encoder.set_input_array(v);
+   if (sinks) {
+     encoder.set_input_array(*sinks);
+   }
+   encoder.set_output_array(o);
+
+   cu::AttnParams params{
+       /* int B = */ q.shape(0),
+       /* int H = */ q.shape(1),
+       /* int D = */ q.shape(3),
+
+       /* int qL = */ q.shape(2),
+       /* int kL = */ k.shape(2),
+
+       /* int gqa_factor = */ q.shape(1) / k.shape(1),
+       /* float scale = */ scale,
+
+       /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
+       /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
+       /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
+       /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
+
+   dim3 grid_dim(params.H, params.qL, params.B);
+   dim3 block_dim(1024, 1, 1);
+
+   dispatch_float_types(o.dtype(), "kernel_sdpav_1pass", [&](auto type_tag) {
+     dispatch_bool(do_causal, [&](auto do_causal) {
+       dispatch_headdim(params.D, [&](auto headdim) {
+         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+
+         auto kernel =
+             cu::kernel_sdpav_1pass<DataType, do_causal.value, headdim.value>;
+         encoder.add_kernel_node(
+             kernel,
+             grid_dim,
+             block_dim,
+             0,
+             gpu_ptr<DataType>(q),
+             gpu_ptr<DataType>(k),
+             gpu_ptr<DataType>(v),
+             gpu_ptr<DataType>(o),
+             sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
+             params);
+       });
+     });
+   });
+ }
+
+ void sdpa_vector_2pass_fallback(
+     const Stream& s,
+     cu::CommandEncoder& encoder,
+     const array& q,
+     const array& k,
+     const array& v,
+     const float scale,
+     array& o,
+     bool do_causal,
+     const std::optional<array>& sinks) {
+   cu::AttnParams params{
+       /* int B = */ q.shape(0),
+       /* int H = */ q.shape(1),
+       /* int D = */ q.shape(3),
+
+       /* int qL = */ q.shape(2),
+       /* int kL = */ k.shape(2),
+
+       /* int gqa_factor = */ q.shape(1) / k.shape(1),
+       /* float scale = */ scale,
+
+       /* int64_t Q_strides[3] = */ {q.strides(0), q.strides(1), q.strides(2)},
+       /* int64_t K_strides[3] = */ {k.strides(0), k.strides(1), k.strides(2)},
+       /* int64_t V_strides[3] = */ {v.strides(0), v.strides(1), v.strides(2)},
+       /* int64_t O_strides[3] = */ {o.strides(0), o.strides(1), o.strides(2)}};
+
+   // Allocate the intermediates
+   int blocks = 32;
+
+   Shape intermediate_shape;
+   intermediate_shape.reserve(o.ndim() + 1);
+   intermediate_shape.insert(
+       intermediate_shape.end(), o.shape().begin(), o.shape().end() - 1);
+   intermediate_shape.push_back(blocks);
+   intermediate_shape.push_back(o.shape().back());
+
+   array intermediate(intermediate_shape, float32, nullptr, {});
+   intermediate_shape.pop_back();
+   array sums(intermediate_shape, float32, nullptr, {});
+   array maxs(std::move(intermediate_shape), float32, nullptr, {});
+
+   intermediate.set_data(cu::malloc_async(intermediate.nbytes(), encoder));
+   sums.set_data(cu::malloc_async(sums.nbytes(), encoder));
+   maxs.set_data(cu::malloc_async(maxs.nbytes(), encoder));
+
+   encoder.add_temporary(intermediate);
+   encoder.add_temporary(sums);
+   encoder.add_temporary(maxs);
+
+   dispatch_float_types(o.dtype(), "kernel_sdpav_2pass", [&](auto type_tag) {
+     dispatch_bool(do_causal, [&](auto do_causal) {
+       dispatch_headdim(params.D, [&](auto headdim) {
+         using DataType = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+
+         {
+           auto kernel = cu::
+               kernel_sdpav_2pass_1<DataType, do_causal.value, headdim.value>;
+
+           encoder.set_input_array(q);
+           encoder.set_input_array(k);
+           encoder.set_input_array(v);
+           if (sinks) {
+             encoder.set_input_array(*sinks);
+           }
+
+           encoder.set_output_array(intermediate);
+           encoder.set_output_array(sums);
+           encoder.set_output_array(maxs);
+
+           dim3 grid_dim(params.H, params.qL, params.B * 32);
+           dim3 block_dim(8 * 32, 1, 1);
+
+           encoder.add_kernel_node(
+               kernel,
+               grid_dim,
+               block_dim,
+               0,
+               gpu_ptr<DataType>(q),
+               gpu_ptr<DataType>(k),
+               gpu_ptr<DataType>(v),
+               sinks ? gpu_ptr<DataType>(*sinks) : nullptr,
+               gpu_ptr<float>(intermediate),
+               gpu_ptr<float>(sums),
+               gpu_ptr<float>(maxs),
+               params);
+         }
+
+         {
+           auto kernel = cu::
+               kernel_sdpav_2pass_2<DataType, do_causal.value, headdim.value>;
+
+           encoder.set_input_array(intermediate);
+           encoder.set_input_array(sums);
+           encoder.set_input_array(maxs);
+           encoder.set_output_array(o);
+
+           dim3 grid_dim(params.H, params.qL, params.B);
+           dim3 block_dim(1024, 1, 1);
+
+           encoder.add_kernel_node(
+               kernel,
+               grid_dim,
+               block_dim,
+               0,
+               gpu_ptr<float>(intermediate),
+               gpu_ptr<float>(sums),
+               gpu_ptr<float>(maxs),
+               gpu_ptr<DataType>(o),
+               params);
+         }
+       });
+     });
+   });
+ }
+
+ void sdpa_vector_fallback(
+     const Stream& s,
+     cu::CommandEncoder& encoder,
+     const array& q,
+     const array& k,
+     const array& v,
+     const float scale,
+     array& o,
+     bool do_causal,
+     const std::optional<array>& sinks) {
+   int kL = k.shape(2);
+
+   if (kL > 1024) {
+     return sdpa_vector_2pass_fallback(
+         s, encoder, q, k, v, scale, o, do_causal, sinks);
+   } else {
+     return sdpa_vector_1pass_fallback(
+         s, encoder, q, k, v, scale, o, do_causal, sinks);
+   }
+ }
+
+ } // namespace
+
+ bool supports_sdpa_vector(
+     const array& q,
+     const array& k,
+     const array& v,
+     bool has_arr_mask,
+     bool output_logsumexp) {
+   if (output_logsumexp) {
+     return false;
+   }
+
+   const int value_head_dim = v.shape(-1);
+   const int query_head_dim = q.shape(-1);
+   const int query_sequence_length = q.shape(2);
+   const int key_sequence_length = k.shape(2);
+
+   const bool sdpa_supported_head_dim = query_head_dim == value_head_dim &&
+       (query_head_dim == 64 || query_head_dim == 96 || query_head_dim == 128);
+
+   const bool supported_vector_config =
+       sdpa_supported_head_dim && query_sequence_length < 4;
+
+   return supported_vector_config && !has_arr_mask;
+ }
+
+ void sdpa_vector(
+     const array& q_pre,
+     const array& k_pre,
+     const array& v_pre,
+     float scale,
+     array& o,
+     bool do_causal,
+     const std::optional<array>& sinks_pre,
+     Stream s) {
+   auto& encoder = cu::get_command_encoder(s);
+   std::vector<array> copies;
+
+   // Define some copy functions to ensure the layout of the inputs is as
+   // expected.
+   copies.reserve(4);
+   auto copy_unless = [&copies, &s](
+                          auto predicate, const array& arr) -> const array& {
+     if (!predicate(arr)) {
+       array arr_copy = contiguous_copy_gpu(arr, s);
+       copies.push_back(std::move(arr_copy));
+       return copies.back();
+     } else {
+       return arr;
+     }
+   };
+
+   // Checks that the headdim dimension has stride 1.
+   auto is_matrix_contiguous = [](const array& arr) {
+     return arr.strides(-1) == 1;
+   };
+
+   std::optional<array> sinks = std::nullopt;
+   if (sinks_pre) {
+     sinks = copy_unless(is_matrix_contiguous, sinks_pre.value());
+   }
+
+   // We are in vector mode ie single query
+   if (q_pre.shape(2) < 4) {
+     auto q_copy_unless = [](const array& arr) {
+       if (arr.flags().row_contiguous) {
+         return true;
+       }
+       auto& strides = arr.strides();
+       auto& shape = arr.shape();
+       if (shape[0] == 1 || shape[1] == 1) {
+         // If either the batch or head dimension is a singleton, the other can
+         // be transposed with the sequence dimension
+         auto bidx = shape[0] == 1 ? 1 : 0;
+         return (strides[3] == 1) && (strides[2] == shape[3] * shape[bidx]) &&
+             (strides[bidx] == shape[3]);
+       }
+       return false;
+     };
+
+     auto kv_copy_unless = [](const array& arr) {
+       // keys and values should be copied if:
+       // - the last dimension is not contiguous
+       // - the batch and head dim are not contiguous
+       auto& strides = arr.strides();
+       auto& shape = arr.shape();
+       if (strides.back() != 1) {
+         return false;
+       }
+       if (shape[0] == 1 || shape[1] == 1) {
+         return true;
+       }
+       return (strides[0] == strides[1] * shape[1]);
+     };
+
+     const auto& q = copy_unless(q_copy_unless, q_pre);
+     const auto& k = copy_unless(kv_copy_unless, k_pre);
+     const auto& v = copy_unless(kv_copy_unless, v_pre);
+
+     // Donate the query if possible
+     if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) {
+       o.copy_shared_buffer(q);
+     } else {
+       int64_t str_oD = 1;
+       int64_t str_oH = o.shape(3);
+       int64_t str_oL = o.shape(1) * str_oH;
+       int64_t str_oB = o.shape(2) * str_oL;
+
+       array::Flags flags{
+           /* bool contiguous = */ 1,
+           /* bool row_contiguous = */ o.shape(2) == 1,
+           /* bool col_contiguous = */ o.size() == o.shape(3),
+       };
+
+       o.set_data(
+           cu::malloc_async(o.nbytes(), encoder),
+           o.size(),
+           {str_oB, str_oH, str_oL, str_oD},
+           flags);
+     }
+
+     for (const auto& cp : copies) {
+       encoder.add_temporary(cp);
+     }
+
+     sdpa_vector_fallback(s, encoder, q, k, v, scale, o, do_causal, sinks);
+   }
+
+   // Full attention mode should never reach here
+   else {
+     throw std::runtime_error("Doesn't support matrix yet.");
+   }
+ }
+
+ } // namespace mlx::core
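When the key sequence length exceeds 1024, the code above takes the two-pass route: kernel_sdpav_2pass_1 splits the keys across 32 blocks and writes, per block, an unnormalized partial output plus its running (base-2) maximum and sum of exponentials, and kernel_sdpav_2pass_2 merges them. Below is a hedged host-side C++ sketch of that merge step, with illustrative names and plain std::vector buffers rather than MLX arrays.

// Host-side sketch (illustrative names, not MLX API) of the reduction done by
// kernel_sdpav_2pass_2 above: every key block b contributes an unnormalized
// partial output, its running base-2 maximum m_b, and its sum of exponentials
// s_b. The partials are rescaled to a common maximum, summed, and normalized.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> combine_sdpa_partials(
    const std::vector<std::vector<float>>& partials, // [blocks][D]
    const std::vector<float>& maxs,                  // [blocks]
    const std::vector<float>& sums) {                // [blocks]
  const size_t blocks = partials.size();
  if (blocks == 0) {
    return {};
  }
  const size_t D = partials[0].size();

  const float global_max = *std::max_element(maxs.begin(), maxs.end());

  float total_sum = 0.f;
  std::vector<float> o(D, 0.f);
  for (size_t b = 0; b < blocks; ++b) {
    const float factor = std::exp2(maxs[b] - global_max); // rescale block b
    total_sum += sums[b] * factor;
    for (size_t d = 0; d < D; ++d) {
      o[d] += partials[b][d] * factor;
    }
  }
  const float inv = total_sum == 0.f ? 0.f : 1.f / total_sum;
  for (size_t d = 0; d < D; ++d) {
    o[d] *= inv;
  }
  return o;
}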