mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599) hide show
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,1038 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include <algorithm>
4
+ #include <cassert>
5
+
6
+ #include "mlx/backend/gpu/copy.h"
7
+ #include "mlx/backend/metal/device.h"
8
+ #include "mlx/backend/metal/kernels.h"
9
+ #include "mlx/backend/metal/kernels/defines.h"
10
+ #include "mlx/backend/metal/reduce.h"
11
+ #include "mlx/backend/metal/utils.h"
12
+ #include "mlx/primitives.h"
13
+ #include "mlx/utils.h"
14
+
15
+ namespace mlx::core {
16
+
17
+ namespace {
18
+
19
+ struct RowReduceArgs {
20
+ // Input shape and strides not including the reduction axes
21
+ Shape shape;
22
+ Strides strides;
23
+ int ndim;
24
+
25
+ // Input shape and strides for the reduction axes
26
+ Shape reduce_shape;
27
+ Strides reduce_strides;
28
+ int reduce_ndim;
29
+
30
+ // The number of rows we are reducing. Namely prod(reduce_shape).
31
+ size_t non_row_reductions;
32
+
33
+ // The size of the row.
34
+ size_t row_size;
35
+
36
+ RowReduceArgs(
37
+ const array& in,
38
+ const ReductionPlan& plan,
39
+ const std::vector<int>& axes) {
40
+ row_size = plan.shape.back();
41
+
42
+ reduce_shape = plan.shape;
43
+ reduce_strides = plan.strides;
44
+ reduce_shape.pop_back();
45
+ reduce_strides.pop_back();
46
+ reduce_ndim = reduce_shape.size();
47
+
48
+ non_row_reductions = 1;
49
+ for (auto s : reduce_shape) {
50
+ non_row_reductions *= s;
51
+ }
52
+
53
+ std::tie(shape, strides) = shapes_without_reduction_axes(in, axes);
54
+ std::tie(shape, strides) = collapse_contiguous_dims(shape, strides);
55
+ ndim = shape.size();
56
+ }
57
+
58
+ void encode(CommandEncoder& compute_encoder) {
59
+ // Push 0s to avoid encoding empty vectors.
60
+ if (reduce_ndim == 0) {
61
+ reduce_shape.push_back(0);
62
+ reduce_strides.push_back(0);
63
+ }
64
+ if (ndim == 0) {
65
+ shape.push_back(0);
66
+ strides.push_back(0);
67
+ }
68
+
69
+ compute_encoder.set_bytes(row_size, 2);
70
+ compute_encoder.set_bytes(non_row_reductions, 3);
71
+ compute_encoder.set_vector_bytes(shape, 4);
72
+ compute_encoder.set_vector_bytes(strides, 5);
73
+ compute_encoder.set_bytes(ndim, 6);
74
+ compute_encoder.set_vector_bytes(reduce_shape, 7);
75
+ compute_encoder.set_vector_bytes(reduce_strides, 8);
76
+ compute_encoder.set_bytes(reduce_ndim, 9);
77
+
78
+ if (reduce_ndim == 0) {
79
+ reduce_shape.pop_back();
80
+ reduce_strides.pop_back();
81
+ }
82
+ if (ndim == 0) {
83
+ shape.pop_back();
84
+ strides.pop_back();
85
+ }
86
+ }
87
+ };
88
+
89
+ struct ColReduceArgs {
90
+ // Input shape and strides not including the reduction axes
91
+ Shape shape;
92
+ Strides strides;
93
+ int ndim;
94
+
95
+ // Input shape and strides for the reduction axes
96
+ Shape reduce_shape;
97
+ Strides reduce_strides;
98
+ int reduce_ndim;
99
+
100
+ // The number of column reductions we are doing. Namely prod(reduce_shape).
101
+ size_t non_col_reductions;
102
+
103
+ // The size of the contiguous column reduction.
104
+ size_t reduction_size;
105
+ int64_t reduction_stride;
106
+
107
+ ColReduceArgs(
108
+ const array& in,
109
+ const ReductionPlan& plan,
110
+ const std::vector<int>& axes) {
111
+ reduction_size = plan.shape.back();
112
+ reduction_stride = plan.strides.back();
113
+
114
+ reduce_shape = plan.shape;
115
+ reduce_strides = plan.strides;
116
+ reduce_shape.pop_back();
117
+ reduce_strides.pop_back();
118
+ reduce_ndim = reduce_shape.size();
119
+
120
+ non_col_reductions = 1;
121
+ for (auto s : reduce_shape) {
122
+ non_col_reductions *= s;
123
+ }
124
+
125
+ // We 'll use a stride_back variable because strides.back() could be 0 but
126
+ // yet we may have removed the appropriate amount of elements. It is safe
127
+ // to compute the stride by multiplying shapes (while < reduction_stride)
128
+ // because it is a contiguous section.
129
+ int64_t stride_back = 1;
130
+ std::tie(shape, strides) = shapes_without_reduction_axes(in, axes);
131
+ while (!shape.empty() && stride_back < reduction_stride) {
132
+ stride_back *= shape.back();
133
+ shape.pop_back();
134
+ strides.pop_back();
135
+ }
136
+ std::tie(shape, strides) = collapse_contiguous_dims(shape, strides);
137
+ ndim = shape.size();
138
+ }
139
+
140
+ /**
141
+ * Create the col reduce arguments for reducing the 1st axis of the row
142
+ * contiguous intermediate array.
143
+ */
144
+ ColReduceArgs(const array& intermediate) {
145
+ assert(intermediate.flags().row_contiguous);
146
+
147
+ reduction_size = intermediate.shape(0);
148
+ reduction_stride = intermediate.size() / reduction_size;
149
+ non_col_reductions = 1;
150
+ reduce_ndim = 0;
151
+ ndim = 0;
152
+ }
153
+
154
+ void encode(CommandEncoder& compute_encoder) {
155
+ // Push 0s to avoid encoding empty vectors.
156
+ if (reduce_ndim == 0) {
157
+ reduce_shape.push_back(0);
158
+ reduce_strides.push_back(0);
159
+ }
160
+ if (ndim == 0) {
161
+ shape.push_back(0);
162
+ strides.push_back(0);
163
+ }
164
+
165
+ compute_encoder.set_bytes(reduction_size, 2);
166
+ compute_encoder.set_bytes(reduction_stride, 3);
167
+ compute_encoder.set_vector_bytes(shape, 4);
168
+ compute_encoder.set_vector_bytes(strides, 5);
169
+ compute_encoder.set_bytes(ndim, 6);
170
+ compute_encoder.set_vector_bytes(reduce_shape, 7);
171
+ compute_encoder.set_vector_bytes(reduce_strides, 8);
172
+ compute_encoder.set_bytes(reduce_ndim, 9);
173
+ compute_encoder.set_bytes(non_col_reductions, 10);
174
+
175
+ if (reduce_ndim == 0) {
176
+ reduce_shape.pop_back();
177
+ reduce_strides.pop_back();
178
+ }
179
+ if (ndim == 0) {
180
+ shape.pop_back();
181
+ strides.pop_back();
182
+ }
183
+ }
184
+ };
185
+
186
+ } // namespace
187
+
188
// Ceiling division of n by m that yields 0 instead of dividing by zero.
inline auto safe_div(size_t n, size_t m) {
  if (m == 0) {
    return size_t{0};
  }
  return (n + m - 1) / m;
}
191
+
192
// Rounds n up to the next multiple of m; yields 0 when m is 0.
inline auto safe_divup(size_t n, size_t m) {
  if (m == 0) {
    return size_t{0};
  }
  const size_t blocks = (n + m - 1) / m;
  return blocks * m;
}
195
+
196
+ inline bool is_64b_int(Dtype dtype) {
197
+ return dtype == int64 || dtype == uint64;
198
+ }
199
+
200
+ inline bool is_64b_dtype(Dtype dtype) {
201
+ return dtype == int64 || dtype == uint64 || dtype == complex64;
202
+ }
203
+
204
// Maps the number of reduction dims onto the values used in kernel names:
// 1 for at most one dim, 2 for exactly two, 5 for anything larger.
inline int get_kernel_reduce_ndim(int reduce_ndim) {
  if (reduce_ndim > 2) {
    return 5;
  }
  return (reduce_ndim == 2) ? 2 : 1;
}
213
+
214
+ inline int threadgroup_size_from_row_size(int row_size) {
215
+ // 1 simdgroup per row smallish rows
216
+ if (row_size <= 512) {
217
+ return 32;
218
+ }
219
+
220
+ // 2 simdgroups per row for medium rows
221
+ if (row_size <= 1024) {
222
+ return 128;
223
+ }
224
+
225
+ // up to 32 simdgroups after that
226
+ int thread_group_size;
227
+ thread_group_size = (row_size + REDUCE_N_READS - 1) / REDUCE_N_READS;
228
+ thread_group_size = ((thread_group_size + 31) / 32) * 32;
229
+ thread_group_size = std::min(1024, thread_group_size);
230
+ return thread_group_size;
231
+ }
232
+
233
+ inline auto output_grid_for_col_reduce(
234
+ const array& out,
235
+ const ColReduceArgs& args) {
236
+ auto out_shape = out.shape();
237
+ auto out_strides = out.strides();
238
+ while (!out_shape.empty() && out_strides.back() < args.reduction_stride) {
239
+ out_shape.pop_back();
240
+ out_strides.pop_back();
241
+ }
242
+ return get_2d_grid_dims(out_shape, out_strides);
243
+ }
244
+
245
+ std::pair<Dtype, Dtype> remap_reduce_types(
246
+ const array& in,
247
+ const std::string& op_name) {
248
+ if (op_name == "sum" || op_name == "prod") {
249
+ if (issubdtype(in.dtype(), integer)) {
250
+ switch (in.dtype()) {
251
+ case uint8:
252
+ return {uint8, uint32};
253
+ case uint16:
254
+ return {uint16, uint32};
255
+ case uint32:
256
+ return {uint32, uint32};
257
+ case uint64:
258
+ return {uint64, uint64};
259
+ case int8:
260
+ return {int8, int32};
261
+ case int16:
262
+ return {int16, int32};
263
+ case int32:
264
+ return {int32, int32};
265
+ case int64:
266
+ return {int64, int64};
267
+ default:
268
+ throw std::runtime_error("Unsupported integer type");
269
+ }
270
+ }
271
+ if (in.dtype() == bool_) {
272
+ return {int8, int32};
273
+ }
274
+ return {in.dtype(), in.dtype()};
275
+ } else if (op_name == "and" || op_name == "or") {
276
+ if (in.dtype().size() == 1) {
277
+ return {bool_, bool_};
278
+ } else if (in.dtype().size() == 2) {
279
+ return {int16, bool_};
280
+ } else if (in.dtype().size() == 4) {
281
+ return {int32, bool_};
282
+ } else {
283
+ return {int64, bool_};
284
+ }
285
+ }
286
+ return {in.dtype(), in.dtype()};
287
+ }
288
+
289
+ void init_reduce(
290
+ array& out,
291
+ const std::string& op_name,
292
+ CommandEncoder& compute_encoder,
293
+ metal::Device& d,
294
+ const Stream& s) {
295
+ auto [_, out_type] = remap_reduce_types(out, op_name);
296
+ const std::string func_name = "init_reduce";
297
+ std::string kname = func_name;
298
+ concatenate(kname, "_", op_name, type_to_name(out_type));
299
+ auto kernel = get_reduce_init_kernel(d, kname, func_name, op_name, out_type);
300
+ size_t nthreads = out.size();
301
+ MTL::Size grid_dims = MTL::Size(nthreads, 1, 1);
302
+ NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
303
+ if (thread_group_size > nthreads) {
304
+ thread_group_size = nthreads;
305
+ }
306
+ MTL::Size group_dims = MTL::Size(thread_group_size, 1, 1);
307
+ compute_encoder.set_compute_pipeline_state(kernel);
308
+ compute_encoder.set_output_array(out, 0);
309
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
310
+ }
311
+
312
+ void all_reduce_dispatch(
313
+ const array& in,
314
+ array& out,
315
+ const std::string& op_name,
316
+ CommandEncoder& compute_encoder,
317
+ metal::Device& d,
318
+ const Stream& s) {
319
+ // Set the kernel
320
+ auto [in_type, out_type] = remap_reduce_types(in, op_name);
321
+ const std::string func_name = "all_reduce";
322
+ std::string kname = func_name;
323
+ concatenate(kname, "_", op_name, type_to_name(in_type));
324
+ auto kernel = get_reduce_kernel(
325
+ d, kname, func_name, op_name, in_type, out_type, "int64_t");
326
+ compute_encoder.set_compute_pipeline_state(kernel);
327
+
328
+ size_t in_size = in.size();
329
+
330
+ // Small array so dispatch a single threadgroup
331
+ if (in_size <= REDUCE_N_READS * 1024) {
332
+ int threadgroup_size = (in_size + REDUCE_N_READS - 1) / REDUCE_N_READS;
333
+ threadgroup_size = ((threadgroup_size + 31) / 32) * 32;
334
+ MTL::Size grid_dims(threadgroup_size, 1, 1);
335
+
336
+ compute_encoder.set_input_array(in, 0);
337
+ compute_encoder.set_output_array(out, 1);
338
+ compute_encoder.set_bytes(in_size, 2);
339
+ compute_encoder.set_bytes(in_size, 3);
340
+ compute_encoder.dispatch_threads(grid_dims, grid_dims);
341
+ }
342
+
343
+ // We need multiple threadgroups so we 'll do it in 2 passes.
344
+ else {
345
+ int n_rows, threadgroup_2nd_pass;
346
+ // Less than 2**26 bytes
347
+ if (in.nbytes() <= (1 << 26)) {
348
+ n_rows = 32 * REDUCE_N_READS;
349
+ threadgroup_2nd_pass = 32;
350
+ }
351
+
352
+ // Really large matrix so parallelize as much as possible
353
+ else {
354
+ n_rows = 1024 * REDUCE_N_READS;
355
+ threadgroup_2nd_pass = 1024;
356
+ }
357
+
358
+ // Allocate an intermediate tensor to hold results if needed
359
+ array intermediate({n_rows}, out_type, nullptr, {});
360
+ intermediate.set_data(allocator::malloc(intermediate.nbytes()));
361
+ d.add_temporary(intermediate, s.index);
362
+
363
+ // 1st pass
364
+ size_t row_size = (in_size + n_rows - 1) / n_rows;
365
+ int threadgroup_size =
366
+ std::min((row_size + REDUCE_N_READS - 1) / REDUCE_N_READS, 1024ul);
367
+ threadgroup_size = ((threadgroup_size + 31) / 32) * 32;
368
+ MTL::Size grid_dims(threadgroup_size, n_rows, 1);
369
+ MTL::Size group_dims(threadgroup_size, 1, 1);
370
+ compute_encoder.set_input_array(in, 0);
371
+ compute_encoder.set_output_array(intermediate, 1);
372
+ compute_encoder.set_bytes(in_size, 2);
373
+ compute_encoder.set_bytes(row_size, 3);
374
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
375
+
376
+ // 2nd pass
377
+ std::string kname_2nd_pass = func_name;
378
+ concatenate(kname_2nd_pass, "_", op_name, type_to_name(intermediate));
379
+ auto kernel_2nd_pass = get_reduce_kernel(
380
+ d, kname_2nd_pass, func_name, op_name, out_type, out_type, "int64_t");
381
+ compute_encoder.set_compute_pipeline_state(kernel_2nd_pass);
382
+ size_t intermediate_size = n_rows;
383
+ grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
384
+ group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
385
+ compute_encoder.set_input_array(intermediate, 0);
386
+ compute_encoder.set_output_array(out, 1);
387
+ compute_encoder.set_bytes(intermediate_size, 2);
388
+ compute_encoder.set_bytes(intermediate_size, 3);
389
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
390
+ }
391
+ }
392
+
393
+ void row_reduce_small(
394
+ const array& in,
395
+ array& out,
396
+ const std::string& op_name,
397
+ RowReduceArgs& args,
398
+ CommandEncoder& compute_encoder,
399
+ metal::Device& d,
400
+ const Stream& s) {
401
+ // Set the kernel
402
+ int n = get_kernel_reduce_ndim(args.reduce_ndim);
403
+ auto [in_type, out_type] = remap_reduce_types(in, op_name);
404
+ const std::string func_name = "row_reduce_small";
405
+ std::string kname = func_name;
406
+ bool large = in.size() > INT32_MAX;
407
+ if (large) {
408
+ kname += "_large";
409
+ }
410
+ concatenate(
411
+ kname,
412
+ "_",
413
+ std::to_string(n),
414
+ "_reduce_",
415
+ op_name,
416
+ type_to_name(in_type));
417
+ auto kernel = get_reduce_kernel(
418
+ d,
419
+ kname,
420
+ func_name,
421
+ op_name,
422
+ in_type,
423
+ out_type,
424
+ large ? "size_t" : "int",
425
+ n);
426
+ compute_encoder.set_compute_pipeline_state(kernel);
427
+
428
+ // Figure out the grid dims
429
+ MTL::Size grid_dims;
430
+ MTL::Size group_dims;
431
+ if ((args.non_row_reductions < 32 && args.row_size <= 8) ||
432
+ args.non_row_reductions <= 8) {
433
+ grid_dims = get_2d_grid_dims(out.shape(), out.strides());
434
+ group_dims =
435
+ MTL::Size((grid_dims.width < 1024) ? grid_dims.width : 1024, 1, 1);
436
+ } else {
437
+ auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());
438
+ grid_dims = MTL::Size(32, out_grid_size.width, out_grid_size.height);
439
+ group_dims = MTL::Size(32, 1, 1);
440
+ }
441
+
442
+ // Launch
443
+ compute_encoder.set_input_array(in, 0);
444
+ compute_encoder.set_output_array(out, 1);
445
+ args.encode(compute_encoder);
446
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
447
+ }
448
+
449
+ void row_reduce_simple(
450
+ const array& in,
451
+ array& out,
452
+ const std::string& op_name,
453
+ RowReduceArgs& args,
454
+ CommandEncoder& compute_encoder,
455
+ metal::Device& d,
456
+ const Stream& s) {
457
+ // Set the kernel
458
+ auto [in_type, out_type] = remap_reduce_types(in, op_name);
459
+ const std::string func_name = "row_reduce_simple";
460
+ std::string kname = func_name;
461
+ concatenate(kname, "_", op_name, type_to_name(in_type));
462
+
463
+ auto kernel = get_reduce_kernel(
464
+ d, kname, func_name, op_name, in_type, out_type, "size_t");
465
+ compute_encoder.set_compute_pipeline_state(kernel);
466
+
467
+ // Figure out the grid dims
468
+ size_t row_size = args.row_size;
469
+ size_t out_size = out.size();
470
+ auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());
471
+ out_grid_size.width =
472
+ (out_grid_size.width + REDUCE_N_WRITES - 1) / REDUCE_N_WRITES;
473
+ int threadgroup_size = threadgroup_size_from_row_size(row_size);
474
+ if (in.itemsize() == 8) {
475
+ threadgroup_size = std::min(threadgroup_size, 512);
476
+ }
477
+ MTL::Size grid_dims(
478
+ threadgroup_size, out_grid_size.width, out_grid_size.height);
479
+ MTL::Size group_dims(threadgroup_size, 1, 1);
480
+
481
+ // Launch
482
+ compute_encoder.set_input_array(in, 0);
483
+ compute_encoder.set_output_array(out, 1);
484
+ compute_encoder.set_bytes(row_size, 2);
485
+ compute_encoder.set_bytes(out_size, 3);
486
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
487
+ }
488
+
489
+ void row_reduce_looped(
490
+ const array& in,
491
+ array& out,
492
+ const std::string& op_name,
493
+ RowReduceArgs& args,
494
+ CommandEncoder& compute_encoder,
495
+ metal::Device& d,
496
+ const Stream& s) {
497
+ auto [in_type, out_type] = remap_reduce_types(in, op_name);
498
+
499
+ // Set the kernel
500
+ int n = get_kernel_reduce_ndim(args.reduce_ndim);
501
+ const std::string func_name = "row_reduce_looped";
502
+ std::string kname = func_name;
503
+ bool large = in.size() > INT32_MAX;
504
+ if (large) {
505
+ kname += "_large";
506
+ }
507
+ concatenate(
508
+ kname,
509
+ "_",
510
+ std::to_string(n),
511
+ "_reduce_",
512
+ op_name,
513
+ type_to_name(in_type));
514
+ auto kernel = get_reduce_kernel(
515
+ d,
516
+ kname,
517
+ func_name,
518
+ op_name,
519
+ in_type,
520
+ out_type,
521
+ large ? "size_t" : "int",
522
+ n);
523
+ compute_encoder.set_compute_pipeline_state(kernel);
524
+
525
+ // Figure out the grid
526
+ auto out_grid_size = get_2d_grid_dims(out.shape(), out.strides());
527
+ int threadgroup_size = threadgroup_size_from_row_size(args.row_size);
528
+ MTL::Size grid_dims(
529
+ threadgroup_size, out_grid_size.width, out_grid_size.height);
530
+ MTL::Size group_dims(threadgroup_size, 1, 1);
531
+
532
+ // Launch
533
+ compute_encoder.set_input_array(in, 0);
534
+ compute_encoder.set_output_array(out, 1);
535
+ args.encode(compute_encoder);
536
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
537
+ }
538
+
539
+ void row_reduce_general_dispatch(
540
+ const array& in,
541
+ array& out,
542
+ const std::string& op_name,
543
+ const ReductionPlan& plan,
544
+ const std::vector<int>& axes,
545
+ CommandEncoder& compute_encoder,
546
+ metal::Device& d,
547
+ const Stream& s) {
548
+ // Prepare the arguments for the kernel
549
+ RowReduceArgs args(in, plan, axes);
550
+
551
+ // Case 1: The row is small
552
+ if (args.row_size <= 64) {
553
+ return row_reduce_small(in, out, op_name, args, compute_encoder, d, s);
554
+ }
555
+
556
+ // Case 2: Contiguous reduce without non-row reductions
557
+ if (plan.type == ContiguousReduce && args.reduce_ndim == 0 &&
558
+ in.size() / args.row_size >= 32) {
559
+ return row_reduce_simple(in, out, op_name, args, compute_encoder, d, s);
560
+ }
561
+
562
+ // Case 3: General row reduce including non-row reductions
563
+ return row_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
564
+ }
565
+
566
+ void strided_reduce_small(
567
+ const array& in,
568
+ array& out,
569
+ const std::string& op_name,
570
+ ColReduceArgs& args,
571
+ CommandEncoder& compute_encoder,
572
+ metal::Device& d,
573
+ const Stream& s) {
574
+ auto [in_type, out_type] = remap_reduce_types(in, op_name);
575
+
576
+ // Figure out the grid dims
577
+ MTL::Size grid_dims, group_dims;
578
+
579
+ // Prepare the arguments for the kernel
580
+ args.reduce_shape.push_back(args.reduction_size);
581
+ args.reduce_strides.push_back(args.reduction_stride);
582
+ args.reduce_ndim++;
583
+
584
+ int n = get_kernel_reduce_ndim(args.reduce_ndim);
585
+ const std::string func_name = "col_reduce_small";
586
+ std::string kname = func_name;
587
+ bool large = in.size() > INT32_MAX;
588
+ if (large) {
589
+ kname += "_large";
590
+ }
591
+ concatenate(
592
+ kname,
593
+ "_",
594
+ std::to_string(n),
595
+ "_reduce_",
596
+ op_name,
597
+ type_to_name(in_type));
598
+ auto kernel = get_reduce_kernel(
599
+ d,
600
+ kname,
601
+ func_name,
602
+ op_name,
603
+ in_type,
604
+ out_type,
605
+ large ? "size_t" : "int",
606
+ n);
607
+ compute_encoder.set_compute_pipeline_state(kernel);
608
+
609
+ const int n_reads = 4;
610
+ size_t reduction_stride_blocks =
611
+ (args.reduction_stride + n_reads - 1) / n_reads;
612
+ size_t total = args.reduction_size * args.non_col_reductions;
613
+ size_t threadgroup_x = std::min(reduction_stride_blocks, 32ul);
614
+ size_t threadgroup_y = std::min(
615
+ 8ul,
616
+ std::min(kernel->maxTotalThreadsPerThreadgroup() / threadgroup_x, total));
617
+
618
+ group_dims = MTL::Size(threadgroup_x, threadgroup_y, 1);
619
+ grid_dims = output_grid_for_col_reduce(out, args);
620
+ grid_dims = MTL::Size(
621
+ (reduction_stride_blocks + threadgroup_x - 1) / threadgroup_x,
622
+ grid_dims.width,
623
+ grid_dims.height);
624
+
625
+ // Launch
626
+ compute_encoder.set_input_array(in, 0);
627
+ compute_encoder.set_output_array(out, 1);
628
+ args.encode(compute_encoder);
629
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
630
+ }
631
+
632
+ void strided_reduce_longcolumn(
633
+ const array& in,
634
+ array& out,
635
+ const std::string& op_name,
636
+ ColReduceArgs& args,
637
+ CommandEncoder& compute_encoder,
638
+ metal::Device& d,
639
+ const Stream& s) {
640
+ auto [in_type, out_type] = remap_reduce_types(in, op_name);
641
+ size_t total_reduction_size = args.reduction_size * args.non_col_reductions;
642
+ size_t outer_blocks = 32;
643
+ if (total_reduction_size >= 32768) {
644
+ outer_blocks = 128;
645
+ }
646
+
647
+ // Prepare the temporary accumulator
648
+ Shape intermediate_shape;
649
+ intermediate_shape.reserve(out.ndim() + 1);
650
+ intermediate_shape.push_back(outer_blocks);
651
+ intermediate_shape.insert(
652
+ intermediate_shape.end(), out.shape().begin(), out.shape().end());
653
+ array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
654
+ intermediate.set_data(allocator::malloc(intermediate.nbytes()));
655
+ d.add_temporary(intermediate, s.index);
656
+
657
+ // Prepare the arguments for the kernel
658
+ args.reduce_shape.push_back(args.reduction_size);
659
+ args.reduce_strides.push_back(args.reduction_stride);
660
+ args.reduce_ndim++;
661
+
662
+ // Figure out the grid dims
663
+ size_t out_size = out.size();
664
+ size_t threadgroup_x = args.reduction_stride;
665
+ size_t threadgroup_y =
666
+ (args.non_col_reductions * args.reduction_size + outer_blocks - 1) /
667
+ outer_blocks;
668
+ threadgroup_y = std::min(32ul, threadgroup_y);
669
+
670
+ auto out_grid_size = output_grid_for_col_reduce(out, args);
671
+ MTL::Size grid_dims(out_grid_size.width, out_grid_size.height, outer_blocks);
672
+ MTL::Size group_dims(threadgroup_x, threadgroup_y, 1);
673
+
674
+ // Set the kernel
675
+ int n = get_kernel_reduce_ndim(args.reduce_ndim);
676
+ std::string func_name = "col_reduce_longcolumn";
677
+ std::string kname = func_name;
678
+ bool large = in.size() > INT32_MAX;
679
+ if (large) {
680
+ kname += "_large";
681
+ }
682
+ concatenate(
683
+ kname,
684
+ "_",
685
+ std::to_string(n),
686
+ "_reduce_",
687
+ op_name,
688
+ type_to_name(in_type));
689
+ auto kernel = get_reduce_kernel(
690
+ d,
691
+ kname,
692
+ func_name,
693
+ op_name,
694
+ in_type,
695
+ out_type,
696
+ large ? "int64_t" : "int",
697
+ n);
698
+ compute_encoder.set_compute_pipeline_state(kernel);
699
+
700
+ // Launch
701
+ compute_encoder.set_input_array(in, 0);
702
+ compute_encoder.set_output_array(intermediate, 1);
703
+ args.encode(compute_encoder);
704
+ compute_encoder.set_bytes(out_size, 11);
705
+ compute_encoder.dispatch_threadgroups(grid_dims, group_dims);
706
+
707
+ // Make the 2nd pass arguments and grid_dims
708
+ ColReduceArgs second_args(intermediate);
709
+ second_args.reduce_shape.push_back(outer_blocks);
710
+ second_args.reduce_strides.push_back(out.size());
711
+ second_args.reduce_ndim++;
712
+ int BN = 32;
713
+ grid_dims = MTL::Size(256 * ((out.size() + BN - 1) / BN), 1, 1);
714
+ group_dims = MTL::Size(256, 1, 1);
715
+
716
+ // Set the 2nd kernel
717
+ func_name = "col_reduce_looped";
718
+ kname = func_name;
719
+ large = intermediate.size() > INT32_MAX;
720
+ if (large) {
721
+ kname += "_large";
722
+ }
723
+ concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
724
+ kernel = get_reduce_kernel(
725
+ d,
726
+ kname,
727
+ func_name,
728
+ op_name,
729
+ intermediate.dtype(),
730
+ out_type,
731
+ large ? "int64_t" : "int",
732
+ 1,
733
+ 32,
734
+ 32);
735
+ compute_encoder.set_compute_pipeline_state(kernel);
736
+
737
+ compute_encoder.set_input_array(intermediate, 0);
738
+ compute_encoder.set_output_array(out, 1);
739
+ second_args.encode(compute_encoder);
740
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
741
+ }
742
+
743
+ void strided_reduce_looped(
744
+ const array& in,
745
+ array& out,
746
+ const std::string& op_name,
747
+ ColReduceArgs& args,
748
+ CommandEncoder& compute_encoder,
749
+ metal::Device& d,
750
+ const Stream& s) {
751
+ auto [in_type, out_type] = remap_reduce_types(in, op_name);
752
+
753
+ // Prepare the arguments for the kernel
754
+ args.reduce_shape.push_back(args.reduction_size);
755
+ args.reduce_strides.push_back(args.reduction_stride);
756
+ args.reduce_ndim++;
757
+
758
+ // Figure out the grid dims
759
+ auto out_grid_size = output_grid_for_col_reduce(out, args);
760
+ int BN = 32;
761
+ int BM = 1024 / BN;
762
+ int threadgroup_size = 8 * 32;
763
+ MTL::Size grid_dims(
764
+ threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
765
+ out_grid_size.width,
766
+ out_grid_size.height);
767
+ MTL::Size group_dims(threadgroup_size, 1, 1);
768
+
769
+ // Set the kernel
770
+ int n = get_kernel_reduce_ndim(args.reduce_ndim);
771
+ std::string func_name = "col_reduce_looped";
772
+ std::string kname = func_name;
773
+ bool large = in.size() > INT32_MAX;
774
+ if (large) {
775
+ kname += "_large";
776
+ }
777
+ concatenate(
778
+ kname,
779
+ "_",
780
+ std::to_string(n),
781
+ "_",
782
+ std::to_string(BM),
783
+ "_",
784
+ std::to_string(BN),
785
+ "_reduce_",
786
+ op_name,
787
+ type_to_name(in_type));
788
+ auto kernel = get_reduce_kernel(
789
+ d,
790
+ kname,
791
+ func_name,
792
+ op_name,
793
+ in_type,
794
+ out_type,
795
+ large ? "int64_t" : "int",
796
+ n,
797
+ BM,
798
+ BN);
799
+ compute_encoder.set_compute_pipeline_state(kernel);
800
+
801
+ // Launch
802
+ compute_encoder.set_input_array(in, 0);
803
+ compute_encoder.set_output_array(out, 1);
804
+ args.encode(compute_encoder);
805
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
806
+ }
807
+
808
+ void strided_reduce_2pass(
809
+ const array& in,
810
+ array& out,
811
+ const std::string& op_name,
812
+ ColReduceArgs& args,
813
+ CommandEncoder& compute_encoder,
814
+ metal::Device& d,
815
+ const Stream& s) {
816
+ auto [in_type, out_type] = remap_reduce_types(in, op_name);
817
+
818
+ // Prepare the temporary accumulator
819
+ Shape intermediate_shape;
820
+ intermediate_shape.reserve(out.ndim() + 1);
821
+ intermediate_shape.push_back(32);
822
+ intermediate_shape.insert(
823
+ intermediate_shape.end(), out.shape().begin(), out.shape().end());
824
+ array intermediate(std::move(intermediate_shape), out_type, nullptr, {});
825
+ intermediate.set_data(allocator::malloc(intermediate.nbytes()));
826
+ d.add_temporary(intermediate, s.index);
827
+
828
+ // Prepare the arguments for the kernel
829
+ args.reduce_shape.push_back(args.reduction_size);
830
+ args.reduce_strides.push_back(args.reduction_stride);
831
+ args.reduce_ndim++;
832
+
833
+ // Figure out the grid dims
834
+ size_t out_size = out.size() / args.reduction_stride;
835
+ auto out_grid_size = output_grid_for_col_reduce(out, args);
836
+ int outer_blocks = 32;
837
+ int BN = 32;
838
+ int BM = 1024 / BN;
839
+ int threadgroup_size = 8 * 32;
840
+ MTL::Size grid_dims(
841
+ threadgroup_size * ((args.reduction_stride + BN - 1) / BN),
842
+ out_grid_size.width * outer_blocks,
843
+ out_grid_size.height);
844
+ MTL::Size group_dims(threadgroup_size, 1, 1);
845
+
846
+ // Set the kernel
847
+ int n = get_kernel_reduce_ndim(args.reduce_ndim);
848
+ std::string func_name = "col_reduce_2pass";
849
+ std::string kname = func_name;
850
+ bool large = in.size() > INT32_MAX;
851
+ if (large) {
852
+ kname += "_large";
853
+ }
854
+ concatenate(
855
+ kname,
856
+ "_",
857
+ std::to_string(n),
858
+ "_",
859
+ std::to_string(BM),
860
+ "_",
861
+ std::to_string(BN),
862
+ "_reduce_",
863
+ op_name,
864
+ type_to_name(in_type));
865
+ auto kernel = get_reduce_kernel(
866
+ d,
867
+ kname,
868
+ func_name,
869
+ op_name,
870
+ in_type,
871
+ out_type,
872
+ large ? "int64_t" : "int",
873
+ n,
874
+ BM,
875
+ BN);
876
+ compute_encoder.set_compute_pipeline_state(kernel);
877
+
878
+ // Launch
879
+ compute_encoder.set_input_array(in, 0);
880
+ compute_encoder.set_output_array(intermediate, 1);
881
+ args.encode(compute_encoder);
882
+ compute_encoder.set_bytes(out_size, 11);
883
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
884
+
885
+ // Make the 2nd pass arguments and grid_dims
886
+ ColReduceArgs second_args(intermediate);
887
+ second_args.reduce_shape.push_back(outer_blocks);
888
+ second_args.reduce_strides.push_back(out.size());
889
+ second_args.reduce_ndim++;
890
+ grid_dims = MTL::Size(threadgroup_size * ((out.size() + BN - 1) / BN), 1, 1);
891
+
892
+ // Set the 2nd kernel
893
+ func_name = "col_reduce_looped";
894
+ kname = func_name;
895
+ large = intermediate.size() > INT32_MAX;
896
+ if (large) {
897
+ kname += "_large";
898
+ }
899
+ concatenate(kname, "_1_32_32_reduce_", op_name, type_to_name(intermediate));
900
+ kernel = get_reduce_kernel(
901
+ d,
902
+ kname,
903
+ func_name,
904
+ op_name,
905
+ intermediate.dtype(),
906
+ out_type,
907
+ large ? "int64_t" : "int",
908
+ 1,
909
+ 32,
910
+ 32);
911
+ compute_encoder.set_compute_pipeline_state(kernel);
912
+
913
+ compute_encoder.set_input_array(intermediate, 0);
914
+ compute_encoder.set_output_array(out, 1);
915
+ second_args.encode(compute_encoder);
916
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
917
+ }
918
+
919
+ void strided_reduce_general_dispatch(
920
+ const array& in,
921
+ array& out,
922
+ const std::string& op_name,
923
+ const ReductionPlan& plan,
924
+ const std::vector<int>& axes,
925
+ CommandEncoder& compute_encoder,
926
+ metal::Device& d,
927
+ const Stream& s) {
928
+ // Prepare the arguments for the kernel
929
+ ColReduceArgs args(in, plan, axes);
930
+
931
+ // Small column
932
+ if (args.reduction_size * args.non_col_reductions < 32) {
933
+ return strided_reduce_small(in, out, op_name, args, compute_encoder, d, s);
934
+ }
935
+
936
+ // Long column but small row
937
+ if (args.reduction_stride < 32 &&
938
+ args.reduction_size * args.non_col_reductions >= 1024) {
939
+ return strided_reduce_longcolumn(
940
+ in, out, op_name, args, compute_encoder, d, s);
941
+ }
942
+
943
+ if (args.reduction_size * args.non_col_reductions > 256 &&
944
+ out.size() / 32 < 1024) {
945
+ return strided_reduce_2pass(in, out, op_name, args, compute_encoder, d, s);
946
+ }
947
+
948
+ return strided_reduce_looped(in, out, op_name, args, compute_encoder, d, s);
949
+ }
950
+
951
+ void Reduce::eval_gpu(const std::vector<array>& inputs, array& out) {
952
+ assert(inputs.size() == 1);
953
+ array in = inputs[0];
954
+
955
+ // Make sure no identity reductions trickle down here
956
+ assert(!axes_.empty());
957
+ assert(out.size() != in.size());
958
+
959
+ // Continue with reduction operation
960
+ // Minimum of 4 bytes since we use size 4 structs for all reduce
961
+ // and metal will complain o/w
962
+ size_t min_bytes = std::max(out.nbytes(), 4ul);
963
+ out.set_data(allocator::malloc(min_bytes));
964
+ std::string op_name;
965
+ switch (reduce_type_) {
966
+ case Reduce::And:
967
+ op_name = "and";
968
+ break;
969
+ case Reduce::Or:
970
+ op_name = "or";
971
+ break;
972
+ case Reduce::Sum:
973
+ op_name = "sum";
974
+ break;
975
+ case Reduce::Prod:
976
+ op_name = "prod";
977
+ break;
978
+ case Reduce::Min:
979
+ op_name = out.dtype() == bool_ ? "and" : "min";
980
+ break;
981
+ case Reduce::Max:
982
+ op_name = out.dtype() == bool_ ? "or" : "max";
983
+ break;
984
+ }
985
+
986
+ // Initialize output
987
+ auto& s = stream();
988
+ auto& d = metal::device(s.device);
989
+ auto& compute_encoder = d.get_command_encoder(s.index);
990
+
991
+ // Reduce
992
+ if (in.size() > 0) {
993
+ ReductionPlan plan = get_reduction_plan(in, axes_);
994
+
995
+ // If it is a general reduce then copy the input to a contiguous array and
996
+ // recompute the plan.
997
+ //
998
+ // TODO: This can be avoided by making the output have the same strides as
999
+ // input for the axes with stride smaller than the minimum reduction
1000
+ // stride.
1001
+ if (plan.type == GeneralReduce) {
1002
+ array in_copy = contiguous_copy_gpu(in, s);
1003
+ d.add_temporary(in_copy, s.index);
1004
+ in = in_copy;
1005
+ plan = get_reduction_plan(in, axes_);
1006
+ }
1007
+
1008
+ // Reducing over everything and the data is all there no broadcasting or
1009
+ // slicing etc.
1010
+ if (plan.type == ContiguousAllReduce) {
1011
+ all_reduce_dispatch(in, out, op_name, compute_encoder, d, s);
1012
+ }
1013
+
1014
+ // At least the last dimension is row contiguous and we are reducing over
1015
+ // the last dim.
1016
+ else if (
1017
+ plan.type == ContiguousReduce || plan.type == GeneralContiguousReduce) {
1018
+ row_reduce_general_dispatch(
1019
+ in, out, op_name, plan, axes_, compute_encoder, d, s);
1020
+ }
1021
+
1022
+ // At least the last two dimensions are contiguous and we are doing a
1023
+ // strided reduce over these.
1024
+ else if (
1025
+ plan.type == ContiguousStridedReduce ||
1026
+ plan.type == GeneralStridedReduce) {
1027
+ strided_reduce_general_dispatch(
1028
+ in, out, op_name, plan, axes_, compute_encoder, d, s);
1029
+ }
1030
+ }
1031
+
1032
+ // Nothing to reduce just initialize the output
1033
+ else {
1034
+ init_reduce(out, op_name, compute_encoder, d, s);
1035
+ }
1036
+ }
1037
+
1038
+ } // namespace mlx::core