mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh (new file; entry 249 above)
@@ -0,0 +1,211 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/backend/cuda/device/atomic_ops.cuh"
+ #include "mlx/backend/cuda/device/cast_op.cuh"
+ #include "mlx/backend/cuda/device/utils.cuh"
+ #include "mlx/backend/cuda/reduce/reduce_utils.cuh"
+
+ namespace mlx::core::cu {
+
+ // Reduce ops.
+ struct And {
+   __device__ __forceinline__ bool operator()(bool a, bool b) {
+     return a && b;
+   }
+
+   __device__ void atomic_update(bool* x, bool y) {
+     atomic_reduce<bool, And>(x, y);
+   }
+ };
+
+ struct Or {
+   __device__ __forceinline__ bool operator()(bool a, bool b) {
+     return a || b;
+   }
+
+   __device__ void atomic_update(bool* x, bool y) {
+     atomic_reduce<bool, Or>(x, y);
+   }
+ };
+
+ struct Sum {
+   template <typename T>
+   __device__ __forceinline__ T operator()(T a, T b) {
+     return a + b;
+   }
+
+   template <typename T>
+   __device__ void atomic_update(T* x, T y) {
+     atomic_reduce<T, Sum>(x, y);
+   }
+
+   __device__ void atomic_update(__nv_bfloat16* x, __nv_bfloat16 y) {
+     atomic_add(x, y);
+   }
+
+   __device__ void atomic_update(int* x, int y) {
+     atomic_add(x, y);
+   }
+
+   __device__ void atomic_update(float* x, float y) {
+     atomic_add(x, y);
+   }
+ };
+
+ struct Prod {
+   template <typename T>
+   __device__ __forceinline__ T operator()(T a, T b) {
+     return a * b;
+   }
+
+   template <typename T>
+   __device__ void atomic_update(T* x, T y) {
+     atomic_reduce<T, Prod>(x, y);
+   }
+ };
+
+ struct Min {
+   template <typename T>
+   __device__ __forceinline__ T operator()(T a, T b) {
+     if constexpr (is_complex_v<T>) {
+       if (cuda::std::isnan(a.real()) || cuda::std::isnan(a.imag())) {
+         return a;
+       }
+       if (cuda::std::isnan(b.real()) || cuda::std::isnan(b.imag())) {
+         return b;
+       }
+     } else if constexpr (!cuda::std::is_integral_v<T>) {
+       if (cuda::std::isnan(a) || cuda::std::isnan(b)) {
+         return cuda::std::numeric_limits<float>::quiet_NaN();
+       }
+     }
+     return a < b ? a : b;
+   }
+
+   template <typename T>
+   __device__ void atomic_update(T* x, T y) {
+     atomic_reduce<T, Min>(x, y);
+   }
+ };
+
+ struct Max {
+   template <typename T>
+   __device__ __forceinline__ T operator()(T a, T b) {
+     if constexpr (is_complex_v<T>) {
+       if (cuda::std::isnan(a.real()) || cuda::std::isnan(a.imag())) {
+         return a;
+       }
+       if (cuda::std::isnan(b.real()) || cuda::std::isnan(b.imag())) {
+         return b;
+       }
+     } else if constexpr (!cuda::std::is_integral_v<T>) {
+       if (cuda::std::isnan(a) || cuda::std::isnan(b)) {
+         return cuda::std::numeric_limits<float>::quiet_NaN();
+       }
+     }
+     return a > b ? a : b;
+   }
+
+   template <typename T>
+   __device__ void atomic_update(T* x, T y) {
+     atomic_reduce<T, Max>(x, y);
+   }
+ };
+
+ // Traits to get the result type of reduce op.
+ template <typename Op, typename T>
+ struct ReduceResult;
+
+ template <typename T>
+ struct ReduceResult<And, T> {
+   using type = bool;
+ };
+
+ template <typename T>
+ struct ReduceResult<Or, T> {
+   using type = bool;
+ };
+
+ template <typename T>
+ struct ReduceResult<Sum, T> {
+   using type = cuda::std::conditional_t<
+       (cuda::std::is_integral_v<T> && sizeof(T) <= 4),
+       int32_t,
+       T>;
+ };
+
+ template <typename T>
+ struct ReduceResult<Prod, T> {
+   using type = cuda::std::conditional_t<
+       (cuda::std::is_integral_v<T> && sizeof(T) <= 4),
+       int32_t,
+       T>;
+ };
+
+ template <typename T>
+ struct ReduceResult<Min, T> {
+   using type = T;
+ };
+
+ template <typename T>
+ struct ReduceResult<Max, T> {
+   using type = T;
+ };
+
+ // Traits to get the init value of reduce op.
+ template <typename Op, typename T>
+ struct ReduceInit;
+
+ template <typename T>
+ struct ReduceInit<And, T> {
+   static constexpr __host__ __device__ bool value() {
+     return true;
+   }
+ };
+
+ template <typename T>
+ struct ReduceInit<Or, T> {
+   static constexpr __host__ __device__ bool value() {
+     return false;
+   }
+ };
+
+ template <typename T>
+ struct ReduceInit<Sum, T> {
+   static constexpr __host__ __device__ auto value() {
+     if constexpr (is_complex_v<T>) {
+       return T{0, 0};
+     } else {
+       return cast_to<typename ReduceResult<Sum, T>::type>(0);
+     }
+   }
+ };
+
+ template <typename T>
+ struct ReduceInit<Prod, T> {
+   static constexpr __host__ __device__ auto value() {
+     if constexpr (is_complex_v<T>) {
+       return T{1, 0};
+     } else {
+       return cast_to<typename ReduceResult<Prod, T>::type>(1);
+     }
+   }
+ };
+
+ template <typename T>
+ struct ReduceInit<Min, T> {
+   static constexpr __host__ __device__ T value() {
+     return Limits<T>::max();
+   }
+ };
+
+ template <typename T>
+ struct ReduceInit<Max, T> {
+   static constexpr __host__ __device__ T value() {
+     return Limits<T>::min();
+   }
+ };
+
+ } // namespace mlx::core::cu
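
The interesting part of this header is the pair of traits at the bottom: ReduceResult picks the accumulator type (sums and products over integers of at most 4 bytes widen to int32_t), and ReduceInit supplies the identity element for each op. As a quick illustration, here is a minimal host-side C++ analogue of that trait pattern. It compiles without CUDA; the struct names merely mirror the diff, and std::numeric_limits stands in for MLX's own Limits<T> trait (which, for example, uses infinities for floating-point Min/Max), so it is a sketch, not the MLX header.

// Hypothetical host-side sketch of the ReduceResult/ReduceInit pattern.
// Not MLX's header: the real code is device-side and uses cast_to/Limits.
#include <cstdint>
#include <limits>
#include <type_traits>

struct Sum {};
struct Max {};

// Default: accumulate in the input type.
template <typename Op, typename T>
struct ReduceResult {
  using type = T;
};

// Sums over integers of at most 4 bytes widen to int32_t, mirroring the diff.
template <typename T>
struct ReduceResult<Sum, T> {
  using type = std::conditional_t<
      (std::is_integral_v<T> && sizeof(T) <= 4),
      int32_t,
      T>;
};

// Identity elements: 0 for Sum, the lowest representable value for Max.
template <typename Op, typename T>
struct ReduceInit;

template <typename T>
struct ReduceInit<Sum, T> {
  static constexpr auto value() {
    return typename ReduceResult<Sum, T>::type{0};
  }
};

template <typename T>
struct ReduceInit<Max, T> {
  static constexpr T value() {
    return std::numeric_limits<T>::lowest();
  }
};

// int8_t sums accumulate in int32_t; int64_t sums stay in int64_t.
static_assert(std::is_same_v<ReduceResult<Sum, int8_t>::type, int32_t>);
static_assert(std::is_same_v<ReduceResult<Sum, int64_t>::type, int64_t>);
static_assert(ReduceInit<Sum, int8_t>::value() == 0);
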
data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh (new file; entry 250 above)
@@ -0,0 +1,145 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include <numeric>
+
+ #include "mlx/backend/common/utils.h"
+ #include "mlx/backend/cuda/device.h"
+ #include "mlx/backend/cuda/device/utils.cuh"
+
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+
+ namespace mlx::core {
+
+ namespace cu {
+
+ namespace cg = cooperative_groups;
+
+ template <size_t N>
+ struct uint_by_size;
+ template <>
+ struct uint_by_size<2> {
+   using type = uint16_t;
+ };
+ template <>
+ struct uint_by_size<4> {
+   using type = uint32_t;
+ };
+ template <>
+ struct uint_by_size<8> {
+   using type = unsigned long long int;
+ };
+
+ template <typename T, typename Op>
+ __device__ void atomic_reduce(T* x, T y) {
+   if constexpr (sizeof(T) == 1) {
+     using U = uint16_t;
+     U* x_int = (U*)((char*)x - ((size_t)x % 2));
+     int shift = ((char*)x - (char*)x_int) * 8;
+     int mask = 0xff << shift;
+     U old_val, new_val;
+     do {
+       old_val = *x_int;
+       T result = Op{}(static_cast<T>((old_val >> shift) & 0xff), y);
+       new_val = (old_val & ~mask) | (result << shift);
+     } while (atomicCAS(x_int, old_val, new_val) != old_val);
+   } else {
+     using U = typename uint_by_size<sizeof(T)>::type;
+     U* x_int = (U*)(x);
+     U old_val, new_val;
+     do {
+       old_val = *x_int;
+       T result = Op{}(*((T*)&old_val), y);
+       new_val = *((U*)&result);
+     } while (atomicCAS(x_int, old_val, new_val) != old_val);
+   }
+ }
+
+ template <typename T, int N, typename Block, typename Warp, typename Op>
+ inline __device__ void
+ block_reduce(Block block, Warp warp, T (&vals)[N], T* smem, Op op, T init) {
+   // First reduce in the current warp
+   for (int i = 0; i < N; i++) {
+     vals[i] = cg::reduce(warp, vals[i], op);
+   }
+
+   // Reduce across warps
+   if (warp.meta_group_size() > 1) {
+     if (warp.thread_rank() == 0) {
+       for (int i = 0; i < N; i++) {
+         smem[warp.meta_group_rank() * N + i] = vals[i];
+       }
+     }
+     block.sync();
+     if (warp.thread_rank() < warp.meta_group_size()) {
+       for (int i = 0; i < N; i++) {
+         vals[i] = smem[warp.thread_rank() * N + i];
+       }
+     } else {
+       for (int i = 0; i < N; i++) {
+         vals[i] = init;
+       }
+     }
+     for (int i = 0; i < N; i++) {
+       vals[i] = cg::reduce(warp, vals[i], op);
+     }
+   }
+ }
+
+ } // namespace cu
+
+ inline void allocate_same_layout(
+     array& out,
+     const array& in,
+     const std::vector<int>& axes,
+     cu::CommandEncoder& encoder) {
+   if (in.flags().row_contiguous) {
+     out.set_data(cu::malloc_async(out.nbytes(), encoder));
+     return;
+   }
+
+   if (out.ndim() < in.ndim()) {
+     throw std::runtime_error(
+         "Reduction without keepdims only supported for row-contiguous inputs");
+   }
+
+   // Calculate the transpositions applied to in in order to apply them to out.
+   std::vector<int> axis_order(in.ndim());
+   std::iota(axis_order.begin(), axis_order.end(), 0);
+   std::sort(axis_order.begin(), axis_order.end(), [&](int left, int right) {
+     return in.strides(left) > in.strides(right);
+   });
+
+   // Transpose the shape and calculate the strides
+   Shape out_shape(in.ndim());
+   Strides out_strides(in.ndim(), 1);
+   for (int i = 0; i < in.ndim(); i++) {
+     out_shape[i] = out.shape(axis_order[i]);
+   }
+   for (int i = in.ndim() - 2; i >= 0; i--) {
+     out_strides[i] = out_shape[i + 1] * out_strides[i + 1];
+   }
+
+   // Reverse the axis order to get the final strides
+   Strides final_strides(in.ndim());
+   for (int i = 0; i < in.ndim(); i++) {
+     final_strides[axis_order[i]] = out_strides[i];
+   }
+
+   // Calculate the resulting contiguity and do the memory allocation
+   auto [data_size, rc, cc] = check_contiguity(out.shape(), final_strides);
+   auto fl = in.flags();
+   fl.row_contiguous = rc;
+   fl.col_contiguous = cc;
+   fl.contiguous = true;
+   out.set_data(
+       cu::malloc_async(out.nbytes(), encoder),
+       data_size,
+       final_strides,
+       fl,
+       allocator::free);
+ }
+
+ } // namespace mlx::core
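
atomic_reduce above is a lock-free read-modify-write: it reinterprets the operand as an unsigned word of the same width (uint_by_size), applies the op, and retries with atomicCAS until no other thread intervened; 1-byte types borrow the containing 2-byte word and mask out their lane. As a standalone illustration of the same CAS loop, here is a hypothetical float-max kernel; atomic_max_float, reduce_max, and the sizes are invented for the example. atomicCAS operates on integer words, hence the __float_as_uint round-trip.

// Hypothetical standalone sketch of the atomic_reduce CAS loop above,
// specialized to float max. Not MLX code; compile with nvcc.
#include <cstdio>

__device__ void atomic_max_float(float* x, float y) {
  unsigned int* x_int = reinterpret_cast<unsigned int*>(x);
  unsigned int old_val = *x_int;
  unsigned int assumed;
  do {
    assumed = old_val;
    float current = __uint_as_float(assumed);
    float result = current > y ? current : y;
    // Retry until no other thread modified *x between our read and the CAS.
    old_val = atomicCAS(x_int, assumed, __float_as_uint(result));
  } while (old_val != assumed);
}

__global__ void reduce_max(const float* in, int n, float* out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomic_max_float(out, in[i]);
  }
}

int main() {
  float h[4] = {1.f, 7.f, 3.f, 2.f};
  float init = -1e30f, result;
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_in, h, sizeof(h), cudaMemcpyHostToDevice);
  cudaMemcpy(d_out, &init, sizeof(float), cudaMemcpyHostToDevice);
  reduce_max<<<1, 4>>>(d_in, 4, d_out);
  cudaMemcpy(&result, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  printf("max = %g\n", result); // expected: 7
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}
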
data/mlx/mlx/backend/cuda/reduce/row_reduce.cu (new file; entry 251 above)
@@ -0,0 +1,361 @@
+ // Copyright © 2025 Apple Inc.
+
+ #include <numeric>
+
+ #include "mlx/backend/cuda/device.h"
+ #include "mlx/backend/cuda/reduce/reduce.cuh"
+
+ #include <cooperative_groups.h>
+ #include <cooperative_groups/reduce.h>
+
+ namespace mlx::core {
+
+ namespace cu {
+
+ namespace cg = cooperative_groups;
+
+ struct RowReduceArgs {
+   // The size of the row being reduced, i.e. the size of last dimension.
+   int row_size;
+
+   // Input shape and strides excluding the reduction axes.
+   Shape shape;
+   Strides strides;
+   int ndim;
+
+   // Input shape and strides of the reduction axes excluding last dimension.
+   Shape reduce_shape;
+   Strides reduce_strides;
+   int reduce_ndim;
+
+   // The number of rows we are reducing. Namely prod(reduce_shape).
+   size_t non_row_reductions;
+
+   RowReduceArgs(
+       const array& in,
+       const ReductionPlan& plan,
+       const std::vector<int>& axes) {
+     assert(!plan.shape.empty());
+     row_size = plan.shape.back();
+
+     auto [shape_vec, strides_vec] = shapes_without_reduction_axes(in, axes);
+     std::tie(shape_vec, strides_vec) =
+         collapse_contiguous_dims(shape_vec, strides_vec);
+     shape = const_param(shape_vec);
+     strides = const_param(strides_vec);
+     ndim = shape_vec.size();
+
+     reduce_shape = const_param(plan.shape);
+     reduce_strides = const_param(plan.strides);
+     reduce_ndim = plan.shape.size() - 1;
+
+     non_row_reductions = 1;
+     for (int i = 0; i < reduce_ndim; i++) {
+       non_row_reductions *= reduce_shape[i];
+     }
+   }
+
+   // Convert shape and strides as if in was contiguous
+   void sort_access_pattern(const array& in, const std::vector<int>& axes) {
+     auto shape_vec = in.shape();
+     auto strides_vec = in.strides();
+     std::tie(shape_vec, strides_vec) =
+         shapes_without_reduction_axes(shape_vec, strides_vec, axes);
+     std::vector<int> indices(shape_vec.size());
+     std::iota(indices.begin(), indices.end(), 0);
+     std::sort(indices.begin(), indices.end(), [&](int left, int right) {
+       return strides_vec[left] > strides_vec[right];
+     });
+     decltype(shape_vec) sorted_shape;
+     decltype(strides_vec) sorted_strides;
+     for (auto idx : indices) {
+       sorted_shape.push_back(shape_vec[idx]);
+       sorted_strides.push_back(strides_vec[idx]);
+     }
+     std::tie(shape_vec, strides_vec) =
+         collapse_contiguous_dims(sorted_shape, sorted_strides);
+     shape = const_param(shape_vec);
+     strides = const_param(strides_vec);
+     ndim = shape_vec.size();
+   }
+ };
+
+ template <typename T, typename U, typename ReduceOp, int N = 4, int M = 1>
+ __global__ void
+ row_reduce_simple(const T* in, U* out, size_t n_rows, int size) {
+   auto grid = cg::this_grid();
+   auto block = cg::this_thread_block();
+   auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+   const U init = cu::ReduceInit<ReduceOp, T>::value();
+   ReduceOp op;
+
+   AlignedVector<T, N> vals[M];
+   AlignedVector<U, M> accs;
+   for (int i = 0; i < M; i++) {
+     accs[i] = init;
+   }
+
+   const size_t start_row =
+       min(n_rows - M, static_cast<size_t>(grid.block_rank() * M));
+   const size_t full_blocks = size / (block.size() * N);
+   const size_t final_offset = full_blocks * (block.size() * N);
+   in += start_row * size + block.thread_rank() * N;
+   out += start_row;
+
+   for (size_t r = 0; r < full_blocks; r++) {
+     for (int k = 0; k < M; k++) {
+       vals[k] = load_vector<N>(in + k * size, 0);
+     }
+     for (int k = 0; k < M; k++) {
+       for (int j = 0; j < N; j++) {
+         accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
+       }
+     }
+
+     in += block.size() * N;
+   }
+
+   if (final_offset < size) {
+     for (int k = 0; k < M; k++) {
+       for (int i = 0; i < N; i++) {
+         vals[k][i] = ((final_offset + block.thread_rank() * N + i) < size)
+             ? in[k * size + i]
+             : cast_to<T>(init);
+       }
+     }
+     for (int k = 0; k < M; k++) {
+       for (int j = 0; j < N; j++) {
+         accs[k] = op(accs[k], cast_to<U>(vals[k][j]));
+       }
+     }
+   }
+
+   __shared__ U shared_accumulators[32 * M];
+   block_reduce(block, warp, accs.val, shared_accumulators, op, init);
+
+   if (block.thread_rank() == 0) {
+     if (grid.block_rank() * M + M <= n_rows) {
+       store_vector(out, 0, accs);
+     } else {
+       short offset = grid.block_rank() * M + M - n_rows;
+       for (int i = offset; i < M; i++) {
+         out[i] = accs[i];
+       }
+     }
+   }
+ }
+
+ template <typename T, typename U, typename Op, int NDIM, int N_READS = 4>
+ __global__ void row_reduce_looped(
+     const T* in,
+     U* out,
+     const __grid_constant__ RowReduceArgs args) {
+   auto grid = cg::this_grid();
+   auto block = cg::this_thread_block();
+   auto warp = cg::tiled_partition<WARP_SIZE>(block);
+
+   size_t out_idx = grid.block_rank();
+
+   Op op;
+
+   U total[1];
+   U init = ReduceInit<Op, T>::value();
+   total[0] = init;
+   LoopedElemToLoc<NDIM, (NDIM > 2)> loop(args.reduce_ndim);
+   const size_t full_blocks = args.row_size / (block.size() * N_READS);
+   const size_t final_offset = full_blocks * (block.size() * N_READS);
+
+   in += elem_to_loc(out_idx, args.shape.data(), args.strides.data(), args.ndim);
+   in += block.thread_rank() * N_READS;
+
+   // Unaligned reduce
+   if (final_offset < args.row_size) {
+     bool mask[N_READS];
+     for (int i = 0; i < N_READS; i++) {
+       mask[i] =
+           (final_offset + block.thread_rank() * N_READS + i) < args.row_size;
+     }
+
+     for (size_t n = 0; n < args.non_row_reductions; n++) {
+       const T* inlocal = in + loop.location();
+
+       for (size_t r = 0; r < full_blocks; r++) {
+         auto vals = load_vector<N_READS>(inlocal, 0);
+         for (int i = 0; i < N_READS; i++) {
+           total[0] = op(total[0], cast_to<U>(vals[i]));
+         }
+         inlocal += block.size() * N_READS;
+       }
+
+       {
+         T vals[N_READS];
+         for (int i = 0; i < N_READS; i++) {
+           vals[i] = mask[i] ? inlocal[i] : cast_to<T>(init);
+         }
+         for (int i = 0; i < N_READS; i++) {
+           total[0] = op(total[0], cast_to<U>(vals[i]));
+         }
+       }
+
+       loop.next(args.reduce_shape.data(), args.reduce_strides.data());
+     }
+   }
+
+   // Aligned case
+   else {
+     for (size_t n = 0; n < args.non_row_reductions; n++) {
+       const T* inlocal = in + loop.location();
+
+       for (size_t r = 0; r < full_blocks; r++) {
+         auto vals = load_vector<N_READS>(inlocal, 0);
+         for (int i = 0; i < N_READS; i++) {
+           total[0] = op(total[0], cast_to<U>(vals[i]));
+         }
+         inlocal += block.size() * N_READS;
+       }
+
+       loop.next(args.reduce_shape.data(), args.reduce_strides.data());
+     }
+   }
+
+   __shared__ U shared_accumulators[32];
+   block_reduce(block, warp, total, shared_accumulators, op, init);
+
+   if (block.thread_rank() == 0) {
+     out[out_idx] = total[0];
+   }
+ }
+
+ } // namespace cu
+
+ void row_reduce_simple(
+     cu::CommandEncoder& encoder,
+     const array& in,
+     array& out,
+     Reduce::ReduceType reduce_type,
+     const std::vector<int>& axes,
+     const ReductionPlan& plan) {
+   // Allocate data for the output using in's layout to avoid elem_to_loc in the
+   // kernel.
+   allocate_same_layout(out, in, axes, encoder);
+
+   // TODO: If out.size() < 1024 which will be a common case then write this in
+   // 2 passes. Something like 32 * out.size() and then do a warp reduce.
+   encoder.set_input_array(in);
+   encoder.set_output_array(out);
+   dispatch_all_types(in.dtype(), [&](auto type_tag) {
+     dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+       using OP = MLX_GET_TYPE(reduce_type_tag);
+       using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+       using U = typename cu::ReduceResult<OP, T>::type;
+
+       constexpr int N_READS = 16 / sizeof(T);
+
+       // Calculate the grid and block dims
+       size_t reductions = (plan.shape.back() + N_READS - 1) / N_READS;
+       dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
+       int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
+       warps /= 4;
+       warps = std::max(std::min(warps, 32), 1);
+       int threads = warps * WARP_SIZE;
+       dim3 block(threads, 1, 1);
+
+       // Pick the kernel
+       auto kernel = cu::row_reduce_simple<T, U, OP, N_READS>;
+       if (grid.x >= 1024) {
+         grid.x = (grid.x + 1) / 2;
+         kernel = cu::row_reduce_simple<T, U, OP, N_READS, 2>;
+       }
+
+       T* indata = const_cast<T*>(gpu_ptr<T>(in));
+       int size = plan.shape.back();
+       encoder.add_kernel_node(
+           kernel, grid, block, 0, indata, gpu_ptr<U>(out), out.size(), size);
+     });
+   });
+ }
+
+ void row_reduce_looped(
+     cu::CommandEncoder& encoder,
+     const array& in,
+     array& out,
+     Reduce::ReduceType reduce_type,
+     const std::vector<int>& axes,
+     const ReductionPlan& plan,
+     cu::RowReduceArgs args) {
+   // Allocate data for the output using in's layout to access them as
+   // contiguously as possible.
+   allocate_same_layout(out, in, axes, encoder);
+
+   encoder.set_input_array(in);
+   encoder.set_output_array(out);
+   dispatch_all_types(in.dtype(), [&](auto type_tag) {
+     dispatch_reduce_ops(reduce_type, [&](auto reduce_type_tag) {
+       using OP = MLX_GET_TYPE(reduce_type_tag);
+       using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
+       using U = typename cu::ReduceResult<OP, T>::type;
+
+       constexpr int N_READS = 16 / sizeof(T);
+
+       // Calculate the grid and block dims
+       args.sort_access_pattern(in, axes);
+       dim3 grid = get_2d_grid_dims(out.shape(), out.strides());
+       size_t reductions = (args.row_size + N_READS - 1) / N_READS;
+       int warps = (reductions + WARP_SIZE - 1) / WARP_SIZE;
+       warps /= 4;
+       warps = std::max(std::min(warps, 32), 1);
+       int threads = warps * WARP_SIZE;
+       dim3 block(threads, 1, 1);
+
+       // Pick the kernel
+       auto kernel = cu::row_reduce_looped<T, U, OP, 1, N_READS>;
+       dispatch_reduce_ndim(args.reduce_ndim, [&](auto reduce_ndim) {
+         kernel = cu::row_reduce_looped<T, U, OP, reduce_ndim.value, N_READS>;
+       });
+
+       encoder.add_kernel_node(
+           kernel, grid, block, 0, gpu_ptr<T>(in), gpu_ptr<U>(out), args);
+     });
+   });
+ }
+
+ void row_reduce(
+     cu::CommandEncoder& encoder,
+     const array& in,
+     array& out,
+     Reduce::ReduceType reduce_type,
+     const std::vector<int>& axes,
+     const ReductionPlan& plan) {
+   // Current row reduction options
+   //
+   // - row_reduce_simple
+   //
+   //   That means that we are simply reducing across the fastest moving axis.
+   //   We are reducing 1 or 2 rows per threadblock depending on the size of
+   //   output.
+   //
+   // - row_reduce_looped
+   //
+   //   It is a general row reduction. We are computing 1 output per
+   //   threadblock. We read the fastest moving axis vectorized and loop over
+   //   the rest of the axes.
+   //
+   // Notes: We opt to read as much in order as possible and leave
+   //   transpositions as they are (contrary to our Metal backend).
+
+   // Simple row reduce means that we have 1 axis that we are reducing over and
+   // it has stride 1.
+   if (plan.shape.size() == 1) {
+     row_reduce_simple(encoder, in, out, reduce_type, axes, plan);
+     return;
+   }
+
+   // Make the args struct to help route to the best kernel
+   cu::RowReduceArgs args(in, plan, axes);
+
+   // Fallback row reduce
+   row_reduce_looped(encoder, in, out, reduce_type, axes, plan, std::move(args));
+ }
+
+ } // namespace mlx::core
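
Both kernels funnel their per-thread partials through block_reduce from reduce_utils.cuh: a warp-level cg::reduce, one partial per warp staged in shared memory, then a second warp pass. Below is a minimal self-contained sketch of that shape, assuming a plain float row sum; row_sum and the launch sizes are invented for illustration, and it omits what the MLX kernels add on top (vectorized 16-byte loads, the M-rows-per-block variant, op/type dispatch).

// Hypothetical standalone row sum in the style of row_reduce_simple +
// block_reduce. Not the MLX kernel; compile with nvcc.
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <cstdio>
#include <vector>

namespace cg = cooperative_groups;

__global__ void row_sum(const float* in, float* out, int row_size) {
  auto block = cg::this_thread_block();
  auto warp = cg::tiled_partition<32>(block);

  // Each thread accumulates a strided slice of its block's row.
  const float* row = in + blockIdx.x * row_size;
  float acc = 0.f;
  for (int i = block.thread_rank(); i < row_size; i += block.size()) {
    acc += row[i];
  }

  // Warp-level reduction, then one partial per warp into shared memory.
  acc = cg::reduce(warp, acc, cg::plus<float>());
  __shared__ float smem[32];
  if (warp.thread_rank() == 0) {
    smem[warp.meta_group_rank()] = acc;
  }
  block.sync();

  // Warp 0 reduces the per-warp partials and writes the row's result.
  if (warp.meta_group_rank() == 0) {
    acc = (warp.thread_rank() < warp.meta_group_size())
        ? smem[warp.thread_rank()]
        : 0.f;
    acc = cg::reduce(warp, acc, cg::plus<float>());
    if (warp.thread_rank() == 0) {
      out[blockIdx.x] = acc;
    }
  }
}

int main() {
  const int n_rows = 2, row_size = 1000;
  std::vector<float> h(n_rows * row_size, 1.f);
  float *d_in, *d_out;
  cudaMalloc(&d_in, h.size() * sizeof(float));
  cudaMalloc(&d_out, n_rows * sizeof(float));
  cudaMemcpy(d_in, h.data(), h.size() * sizeof(float), cudaMemcpyHostToDevice);
  row_sum<<<n_rows, 128>>>(d_in, d_out, row_size);
  float result[n_rows];
  cudaMemcpy(result, d_out, sizeof(result), cudaMemcpyDeviceToHost);
  printf("row sums: %g %g\n", result[0], result[1]); // expected: 1000 1000
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}

This prints 1000 for each row; the MLX version differs mainly in reading 16 bytes per thread per step (N_READS = 16 / sizeof(T)) and in sizing the block from the row length, as seen in the dispatch code above.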