mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/backend/cuda/sort.cu
@@ -0,0 +1,1076 @@
+ // Copyright © 2025 Apple Inc.
+
+ #include <algorithm>
+ #include <cassert>
+ #include <cstdint>
+
+ #include "mlx/backend/cuda/device.h"
+ #include "mlx/backend/cuda/device/fp16_math.cuh"
+ #include "mlx/backend/cuda/kernel_utils.cuh"
+ #include "mlx/backend/gpu/copy.h"
+ #include "mlx/dtype_utils.h"
+ #include "mlx/primitives.h"
+
+ #include <nvtx3/nvtx3.hpp>
+ #include <cuda/std/limits>
+ #include <cuda/std/type_traits>
+
+ namespace mlx::core {
+
+ constexpr int N_PER_THREAD = 8;
+
+ namespace cu {
+
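+ // Sentinel ("init") values pad shared memory when a row is shorter than the
+ // block's capacity. Floating-point types pad with NaN, which the LessThan
+ // comparator below orders after every ordinary value, so padding and NaNs
+ // always land at the end of the sorted output.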
+ template <typename T>
+ __device__ __forceinline__ T nan_value();
+
+ template <>
+ __device__ __forceinline__ float nan_value<float>() {
+   return cuda::std::numeric_limits<float>::quiet_NaN();
+ }
+
+ template <>
+ __device__ __forceinline__ double nan_value<double>() {
+   return cuda::std::numeric_limits<double>::quiet_NaN();
+ }
+
+ template <>
+ __device__ __forceinline__ __half nan_value<__half>() {
+   return __float2half(cuda::std::numeric_limits<float>::quiet_NaN());
+ }
+
+ template <>
+ __device__ __forceinline__ __nv_bfloat16 nan_value<__nv_bfloat16>() {
+   return __float2bfloat16(cuda::std::numeric_limits<float>::quiet_NaN());
+ }
+
+ template <typename T, typename = void>
+ struct InitValue {
+   __device__ __forceinline__ static T value() {
+     return Limits<T>::max();
+   }
+ };
+
+ template <typename T>
+ struct InitValue<T, cuda::std::enable_if_t<std::is_floating_point_v<T>>> {
+   __device__ __forceinline__ static T value() {
+     return nan_value<T>();
+   }
+ };
+
+ template <typename T>
+ __device__ __forceinline__ void thread_swap(T& a, T& b) {
+   T w = a;
+   a = b;
+   b = w;
+ }
+
+ template <typename T>
+ struct LessThan {
+   __device__ __forceinline__ static T init() {
+     return InitValue<T>::value();
+   }
+
+   __device__ __forceinline__ bool operator()(T a, T b) const {
+     if constexpr (std::is_floating_point_v<T>) {
+       bool an = cuda::std::isnan(a);
+       bool bn = cuda::std::isnan(b);
+       if (an | bn) {
+         return (!an) & bn;
+       }
+     }
+     return a < b;
+   }
+ };
+
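+ // Sorts the N_PER_THREAD values a thread holds in registers with an odd-even
+ // transposition network; index registers are swapped in lockstep when
+ // ARG_SORT is set.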
+ template <
+     typename ValT,
+     typename IdxT,
+     bool ARG_SORT,
+     int N_PER_THREAD,
+     typename CompareOp>
+ struct ThreadSort {
+   __device__ __forceinline__ static void sort(
+       ValT (&vals)[N_PER_THREAD],
+       IdxT (&idxs)[N_PER_THREAD]) {
+     CompareOp op;
+ #pragma unroll
+     for (int i = 0; i < N_PER_THREAD; ++i) {
+ #pragma unroll
+       for (int j = i & 1; j < N_PER_THREAD - 1; j += 2) {
+         if (op(vals[j + 1], vals[j])) {
+           thread_swap(vals[j + 1], vals[j]);
+           if constexpr (ARG_SORT) {
+             thread_swap(idxs[j + 1], idxs[j]);
+           }
+         }
+       }
+     }
+   }
+ };
+
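+ // Cooperative merge sort over one block's tile of BLOCK_THREADS * N_PER_THREAD
+ // elements in shared memory: each thread sorts its own run in registers, then
+ // sorted runs are merged pairwise, doubling in width each pass.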
+ template <
+     typename ValT,
+     typename IdxT,
+     bool ARG_SORT,
+     int BLOCK_THREADS,
+     int N_PER_THREAD,
+     typename CompareOp>
+ struct BlockMergeSort {
+   using thread_sort_t =
+       ThreadSort<ValT, IdxT, ARG_SORT, N_PER_THREAD, CompareOp>;
+
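+   // Merge-path diagonal search: returns how many of the first sort_md merged
+   // outputs are drawn from As, found by binary search over the As/Bs split.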
+   __device__ __forceinline__ static int merge_partition(
+       const ValT* As,
+       const ValT* Bs,
+       int A_sz,
+       int B_sz,
+       int sort_md) {
+     CompareOp op;
+
+     int A_st = max(0, sort_md - B_sz);
+     int A_ed = min(sort_md, A_sz);
+
+     while (A_st < A_ed) {
+       int md = A_st + (A_ed - A_st) / 2;
+       auto a = As[md];
+       auto b = Bs[sort_md - 1 - md];
+
+       if (op(b, a)) {
+         A_ed = md;
+       } else {
+         A_st = md + 1;
+       }
+     }
+
+     return A_ed;
+   }
+
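+   // Serial merge: emits this thread's N_PER_THREAD outputs from two sorted
+   // runs, padding an exhausted run with the comparator's init sentinel.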
+   __device__ __forceinline__ static void merge_step(
+       const ValT* As,
+       const ValT* Bs,
+       const IdxT* As_idx,
+       const IdxT* Bs_idx,
+       int A_sz,
+       int B_sz,
+       ValT (&vals)[N_PER_THREAD],
+       IdxT (&idxs)[N_PER_THREAD]) {
+     CompareOp op;
+     int a_idx = 0;
+     int b_idx = 0;
+
+ #pragma unroll
+     for (int i = 0; i < N_PER_THREAD; ++i) {
+       auto a = (a_idx < A_sz) ? As[a_idx] : ValT(CompareOp::init());
+       auto b = (b_idx < B_sz) ? Bs[b_idx] : ValT(CompareOp::init());
+       bool pred = (b_idx < B_sz) && (a_idx >= A_sz || op(b, a));
+
+       vals[i] = pred ? b : a;
+       if constexpr (ARG_SORT) {
+         if (pred) {
+           idxs[i] = Bs_idx[b_idx];
+         } else {
+           idxs[i] = (a_idx < A_sz) ? As_idx[a_idx] : IdxT(0);
+         }
+       }
+
+       b_idx += int(pred);
+       a_idx += int(!pred);
+     }
+   }
+
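+   // Block-wide sort: after the in-register sort, log2(BLOCK_THREADS) merge
+   // passes run over shared memory; each thread locates its slice of the
+   // merged output with merge_partition and produces it with merge_step.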
+   __device__ __forceinline__ static void
+   sort(ValT* tgp_vals, IdxT* tgp_idxs, int size_sorted_axis) {
+     int idx = threadIdx.x * N_PER_THREAD;
+
+     ValT thread_vals[N_PER_THREAD];
+     IdxT thread_idxs[N_PER_THREAD];
+ #pragma unroll
+     for (int i = 0; i < N_PER_THREAD; ++i) {
+       thread_vals[i] = tgp_vals[idx + i];
+       if constexpr (ARG_SORT) {
+         thread_idxs[i] = tgp_idxs[idx + i];
+       }
+     }
+
+     if (idx < size_sorted_axis) {
+       thread_sort_t::sort(thread_vals, thread_idxs);
+     }
+
+     for (int merge_threads = 2; merge_threads <= BLOCK_THREADS;
+          merge_threads *= 2) {
+       __syncthreads();
+ #pragma unroll
+       for (int i = 0; i < N_PER_THREAD; ++i) {
+         tgp_vals[idx + i] = thread_vals[i];
+         if constexpr (ARG_SORT) {
+           tgp_idxs[idx + i] = thread_idxs[i];
+         }
+       }
+       __syncthreads();
+
+       int merge_group = threadIdx.x / merge_threads;
+       int merge_lane = threadIdx.x % merge_threads;
+
+       int sort_sz = N_PER_THREAD * merge_threads;
+       int sort_st = N_PER_THREAD * merge_threads * merge_group;
+
+       int A_st = sort_st;
+       int A_ed = sort_st + sort_sz / 2;
+       int B_st = sort_st + sort_sz / 2;
+       int B_ed = sort_st + sort_sz;
+
+       const ValT* As = tgp_vals + A_st;
+       const ValT* Bs = tgp_vals + B_st;
+       int A_sz = A_ed - A_st;
+       int B_sz = B_ed - B_st;
+
+       int sort_md = N_PER_THREAD * merge_lane;
+       int partition = merge_partition(As, Bs, A_sz, B_sz, sort_md);
+
+       As += partition;
+       Bs += sort_md - partition;
+
+       A_sz -= partition;
+       B_sz -= sort_md - partition;
+
+       const IdxT* As_idx = ARG_SORT ? tgp_idxs + A_st + partition : nullptr;
+       const IdxT* Bs_idx =
+           ARG_SORT ? tgp_idxs + B_st + sort_md - partition : nullptr;
+
+       merge_step(As, Bs, As_idx, Bs_idx, A_sz, B_sz, thread_vals, thread_idxs);
+     }
+
+     __syncthreads();
+ #pragma unroll
+     for (int i = 0; i < N_PER_THREAD; ++i) {
+       tgp_vals[idx + i] = thread_vals[i];
+       if constexpr (ARG_SORT) {
+         tgp_idxs[idx + i] = thread_idxs[i];
+       }
+     }
+   }
+ };
+
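+ // Single-block sort: loads one row into shared memory, padding to N_PER_BLOCK
+ // with init sentinels, block-sorts it, and writes back either the sorted
+ // values or, for argsort, the sorted source indices.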
+ template <
+     typename T,
+     typename U,
+     bool ARG_SORT,
+     int BLOCK_THREADS,
+     int N_PER_THREAD,
+     typename CompareOp = LessThan<T>>
+ struct KernelMergeSort {
+   using ValT = T;
+   using IdxT = uint32_t;
+   using block_merge_sort_t = BlockMergeSort<
+       ValT,
+       IdxT,
+       ARG_SORT,
+       BLOCK_THREADS,
+       N_PER_THREAD,
+       CompareOp>;
+
+   static constexpr int N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
+
+   __device__ __forceinline__ static void block_sort(
+       const T* inp,
+       U* out,
+       int size_sorted_axis,
+       int64_t in_stride_sorted_axis,
+       int64_t out_stride_sorted_axis,
+       int64_t in_stride_segment_axis,
+       int64_t out_stride_segment_axis,
+       ValT* tgp_vals,
+       IdxT* tgp_idxs) {
+     inp += blockIdx.y * in_stride_segment_axis;
+     out += blockIdx.y * out_stride_segment_axis;
+
+     for (int i = threadIdx.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
+       tgp_vals[i] = i < size_sorted_axis ? inp[i * in_stride_sorted_axis]
+                                          : ValT(CompareOp::init());
+       if constexpr (ARG_SORT) {
+         tgp_idxs[i] = i;
+       }
+     }
+
+     __syncthreads();
+     block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis);
+     __syncthreads();
+
+     for (int i = threadIdx.x; i < size_sorted_axis; i += BLOCK_THREADS) {
+       if constexpr (ARG_SORT) {
+         out[i * out_stride_sorted_axis] = tgp_idxs[i];
+       } else {
+         out[i * out_stride_sorted_axis] = tgp_vals[i];
+       }
+     }
+   }
+ };
+
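+ // Contiguous launcher: blockIdx.y selects the row, addressed with a single
+ // segment stride; shared index storage is only declared when arg-sorting.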
+ template <
+     typename T,
+     typename U,
+     bool ARG_SORT,
+     int BLOCK_THREADS,
+     int N_PER_THREAD>
+ __global__ void block_sort_kernel(
+     const T* inp,
+     U* out,
+     int size_sorted_axis,
+     int64_t in_stride_sorted_axis,
+     int64_t out_stride_sorted_axis,
+     int64_t in_stride_segment_axis,
+     int64_t out_stride_segment_axis) {
+   using sort_kernel =
+       KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
+   using ValT = typename sort_kernel::ValT;
+   using IdxT = typename sort_kernel::IdxT;
+
+   if constexpr (ARG_SORT) {
+     __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];
+     __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
+     sort_kernel::block_sort(
+         inp,
+         out,
+         size_sorted_axis,
+         in_stride_sorted_axis,
+         out_stride_sorted_axis,
+         in_stride_segment_axis,
+         out_stride_segment_axis,
+         tgp_vals,
+         tgp_idxs);
+   } else {
+     __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];
+     sort_kernel::block_sort(
+         inp,
+         out,
+         size_sorted_axis,
+         in_stride_sorted_axis,
+         out_stride_sorted_axis,
+         in_stride_segment_axis,
+         out_stride_segment_axis,
+         tgp_vals,
+         nullptr);
+   }
+ }
+
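+ // Non-contiguous launcher: the row offset for blockIdx.y is recovered from the
+ // non-sorted shape and strides with elem_to_loc, so the segment strides passed
+ // to block_sort are zero.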
+ template <
+     typename T,
+     typename U,
+     bool ARG_SORT,
+     int BLOCK_THREADS,
+     int N_PER_THREAD>
+ __global__ void block_sort_nc_kernel(
+     const T* inp,
+     U* out,
+     int size_sorted_axis,
+     int64_t in_stride_sorted_axis,
+     int64_t out_stride_sorted_axis,
+     const __grid_constant__ Shape nc_shape,
+     const __grid_constant__ Strides in_nc_strides,
+     const __grid_constant__ Strides out_nc_strides,
+     int nc_dim) {
+   using sort_kernel =
+       KernelMergeSort<T, U, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
+   using ValT = typename sort_kernel::ValT;
+   using IdxT = typename sort_kernel::IdxT;
+
+   int64_t in_block_idx = elem_to_loc(
+       int64_t(blockIdx.y), nc_shape.data(), in_nc_strides.data(), nc_dim);
+   int64_t out_block_idx = elem_to_loc(
+       int64_t(blockIdx.y), nc_shape.data(), out_nc_strides.data(), nc_dim);
+
+   inp += in_block_idx;
+   out += out_block_idx;
+
+   if constexpr (ARG_SORT) {
+     __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];
+     __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
+     sort_kernel::block_sort(
+         inp,
+         out,
+         size_sorted_axis,
+         in_stride_sorted_axis,
+         out_stride_sorted_axis,
+         0,
+         0,
+         tgp_vals,
+         tgp_idxs);
+   } else {
+     __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];
+     sort_kernel::block_sort(
+         inp,
+         out,
+         size_sorted_axis,
+         in_stride_sorted_axis,
+         out_stride_sorted_axis,
+         0,
+         0,
+         tgp_vals,
+         nullptr);
+   }
+ }
+
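+ // Multi-block path for rows larger than one block's tile: every block first
+ // sorts an N_PER_BLOCK tile into global scratch, then tiles are merged across
+ // blocks over log2(#tiles) passes using precomputed merge-path partitions.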
+ template <
+     typename ValT,
+     typename IdxT,
+     bool ARG_SORT,
+     int BLOCK_THREADS,
+     int N_PER_THREAD,
+     typename CompareOp = LessThan<ValT>>
+ struct KernelMultiBlockMergeSort {
+   using block_merge_sort_t = BlockMergeSort<
+       ValT,
+       IdxT,
+       ARG_SORT,
+       BLOCK_THREADS,
+       N_PER_THREAD,
+       CompareOp>;
+
+   static constexpr int N_PER_BLOCK = BLOCK_THREADS * N_PER_THREAD;
+
+   __device__ __forceinline__ static void block_sort(
+       const ValT* inp,
+       ValT* out_vals,
+       IdxT* out_idxs,
+       int size_sorted_axis,
+       int64_t stride_sorted_axis,
+       ValT* tgp_vals,
+       IdxT* tgp_idxs) {
+     int base_idx = blockIdx.x * N_PER_BLOCK;
+
+     for (int i = threadIdx.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
+       int idx = base_idx + i;
+       tgp_vals[i] = idx < size_sorted_axis ? inp[idx * stride_sorted_axis]
+                                            : ValT(CompareOp::init());
+       tgp_idxs[i] = idx;
+     }
+
+     __syncthreads();
+     block_merge_sort_t::sort(tgp_vals, tgp_idxs, size_sorted_axis);
+     __syncthreads();
+
+     for (int i = threadIdx.x; i < N_PER_BLOCK; i += BLOCK_THREADS) {
+       int idx = base_idx + i;
+       if (idx < size_sorted_axis) {
+         out_vals[idx] = tgp_vals[i];
+         out_idxs[idx] = tgp_idxs[i];
+       }
+     }
+   }
+
+   __device__ __forceinline__ static int merge_partition(
+       const ValT* As,
+       const ValT* Bs,
+       int A_sz,
+       int B_sz,
+       int sort_md) {
+     CompareOp op;
+
+     int A_st = max(0, sort_md - B_sz);
+     int A_ed = min(sort_md, A_sz);
+
+     while (A_st < A_ed) {
+       int md = A_st + (A_ed - A_st) / 2;
+       auto a = As[md];
+       auto b = Bs[sort_md - 1 - md];
+
+       if (op(b, a)) {
+         A_ed = md;
+       } else {
+         A_st = md + 1;
+       }
+     }
+
+     return A_ed;
+   }
+ };
+
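+ // Pass 1 of the multi-block sort: gridDim.x tiles one row and blockIdx.y
+ // selects the row; each block writes its sorted tile of values and indices to
+ // dense scratch buffers.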
+ template <
+     typename ValT,
+     typename IdxT,
+     bool ARG_SORT,
+     int BLOCK_THREADS,
+     int N_PER_THREAD>
+ __global__ void mb_block_sort_kernel(
+     const ValT* inp,
+     ValT* out_vals,
+     IdxT* out_idxs,
+     int size_sorted_axis,
+     int64_t stride_sorted_axis,
+     const __grid_constant__ Shape nc_shape,
+     const __grid_constant__ Strides nc_strides,
+     int nc_dim) {
+   using sort_kernel = KernelMultiBlockMergeSort<
+       ValT,
+       IdxT,
+       ARG_SORT,
+       BLOCK_THREADS,
+       N_PER_THREAD>;
+
+   int64_t block_idx = elem_to_loc(
+       int64_t(blockIdx.y), nc_shape.data(), nc_strides.data(), nc_dim);
+
+   inp += block_idx;
+   out_vals += blockIdx.y * size_sorted_axis;
+   out_idxs += blockIdx.y * size_sorted_axis;
+
+   __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];
+   __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
+
+   sort_kernel::block_sort(
+       inp,
+       out_vals,
+       out_idxs,
+       size_sorted_axis,
+       stride_sorted_axis,
+       tgp_vals,
+       tgp_idxs);
+ }
+
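+ // Runs between merge passes: one thread per tile boundary binary-searches the
+ // pair of sorted runs it straddles, storing n_blocks + 1 split points per row
+ // so every merge block handles exactly N_PER_BLOCK outputs.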
+ template <
+     typename ValT,
+     typename IdxT,
+     bool ARG_SORT,
+     int BLOCK_THREADS,
+     int N_PER_THREAD>
+ __global__ void mb_block_partition_kernel(
+     IdxT* block_partitions,
+     const ValT* dev_vals,
+     const IdxT* dev_idxs,
+     int size_sorted_axis,
+     int merge_tiles,
+     int n_blocks) {
+   using sort_kernel = KernelMultiBlockMergeSort<
+       ValT,
+       IdxT,
+       ARG_SORT,
+       BLOCK_THREADS,
+       N_PER_THREAD>;
+
+   (void)dev_idxs;
+
+   block_partitions += blockIdx.y * blockDim.x;
+   dev_vals += blockIdx.y * size_sorted_axis;
+   dev_idxs += blockIdx.y * size_sorted_axis;
+
+   for (int i = threadIdx.x; i <= n_blocks; i += blockDim.x) {
+     int merge_group = i / merge_tiles;
+     int merge_lane = i % merge_tiles;
+
+     int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
+     int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
+
+     int A_st = min(size_sorted_axis, sort_st);
+     int A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
+     int B_st = A_ed;
+     int B_ed = min(size_sorted_axis, B_st + sort_sz / 2);
+
+     int partition_at = min(B_ed - A_st, sort_kernel::N_PER_BLOCK * merge_lane);
+     int partition = sort_kernel::merge_partition(
+         dev_vals + A_st,
+         dev_vals + B_st,
+         A_ed - A_st,
+         B_ed - B_st,
+         partition_at);
+
+     block_partitions[i] = A_st + partition;
+   }
+ }
+
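+ // Merge pass: each block loads its slice of two adjacent sorted runs (located
+ // via block_partitions) into shared memory, splits the work per thread with a
+ // second merge-path search, and writes N_PER_BLOCK merged outputs.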
+ template <
+     typename ValT,
+     typename IdxT,
+     bool ARG_SORT,
+     int BLOCK_THREADS,
+     int N_PER_THREAD,
+     typename CompareOp = LessThan<ValT>>
+ __global__ void mb_block_merge_kernel(
+     const IdxT* block_partitions,
+     const ValT* dev_vals_in,
+     const IdxT* dev_idxs_in,
+     ValT* dev_vals_out,
+     IdxT* dev_idxs_out,
+     int size_sorted_axis,
+     int merge_tiles,
+     int num_tiles) {
+   using sort_kernel = KernelMultiBlockMergeSort<
+       ValT,
+       IdxT,
+       ARG_SORT,
+       BLOCK_THREADS,
+       N_PER_THREAD,
+       CompareOp>;
+
+   using block_sort_t = typename sort_kernel::block_merge_sort_t;
+
+   block_partitions += blockIdx.y * (num_tiles + 1);
+   dev_vals_in += blockIdx.y * size_sorted_axis;
+   dev_idxs_in += blockIdx.y * size_sorted_axis;
+   dev_vals_out += blockIdx.y * size_sorted_axis;
+   dev_idxs_out += blockIdx.y * size_sorted_axis;
+
+   int block_idx = blockIdx.x;
+   int merge_group = block_idx / merge_tiles;
+   int sort_st = sort_kernel::N_PER_BLOCK * merge_tiles * merge_group;
+   int sort_sz = sort_kernel::N_PER_BLOCK * merge_tiles;
+   int sort_md = sort_kernel::N_PER_BLOCK * block_idx - sort_st;
+
+   int A_st = block_partitions[block_idx + 0];
+   int A_ed = block_partitions[block_idx + 1];
+   int B_st = min(size_sorted_axis, 2 * sort_st + sort_sz / 2 + sort_md - A_st);
+   int B_ed = min(
+       size_sorted_axis,
+       2 * sort_st + sort_sz / 2 + sort_md + sort_kernel::N_PER_BLOCK - A_ed);
+
+   if ((block_idx % merge_tiles) == merge_tiles - 1) {
+     A_ed = min(size_sorted_axis, sort_st + sort_sz / 2);
+     B_ed = min(size_sorted_axis, sort_st + sort_sz);
+   }
+
+   int A_sz = A_ed - A_st;
+   int B_sz = B_ed - B_st;
+
+   ValT thread_vals[N_PER_THREAD];
+   IdxT thread_idxs[N_PER_THREAD];
+ #pragma unroll
+   for (int i = 0; i < N_PER_THREAD; i++) {
+     int idx = BLOCK_THREADS * i + threadIdx.x;
+     if (idx < (A_sz + B_sz)) {
+       thread_vals[i] = (idx < A_sz) ? dev_vals_in[A_st + idx]
+                                     : dev_vals_in[B_st + idx - A_sz];
+       thread_idxs[i] = (idx < A_sz) ? dev_idxs_in[A_st + idx]
+                                     : dev_idxs_in[B_st + idx - A_sz];
+     } else {
+       thread_vals[i] = CompareOp::init();
+       thread_idxs[i] = 0;
+     }
+   }
+
+   __shared__ ValT tgp_vals[sort_kernel::N_PER_BLOCK];
+   __shared__ IdxT tgp_idxs[sort_kernel::N_PER_BLOCK];
+   __syncthreads();
+ #pragma unroll
+   for (int i = 0; i < N_PER_THREAD; i++) {
+     int idx = BLOCK_THREADS * i + threadIdx.x;
+     tgp_vals[idx] = thread_vals[i];
+     tgp_idxs[idx] = thread_idxs[i];
+   }
+   __syncthreads();
+
+   int sort_md_local = min(A_sz + B_sz, N_PER_THREAD * int(threadIdx.x));
+
+   int A_st_local = block_sort_t::merge_partition(
+       tgp_vals, tgp_vals + A_sz, A_sz, B_sz, sort_md_local);
+   int A_ed_local = A_sz;
+
+   int B_st_local = sort_md_local - A_st_local;
+   int B_ed_local = B_sz;
+
+   int A_sz_local = A_ed_local - A_st_local;
+   int B_sz_local = B_ed_local - B_st_local;
+
+   block_sort_t::merge_step(
+       tgp_vals + A_st_local,
+       tgp_vals + A_ed_local + B_st_local,
+       tgp_idxs + A_st_local,
+       tgp_idxs + A_ed_local + B_st_local,
+       A_sz_local,
+       B_sz_local,
+       thread_vals,
+       thread_idxs);
+
+   __syncthreads();
+ #pragma unroll
+   for (int i = 0; i < N_PER_THREAD; ++i) {
+     int idx = threadIdx.x * N_PER_THREAD;
+     tgp_vals[idx + i] = thread_vals[i];
+     tgp_idxs[idx + i] = thread_idxs[i];
+   }
+
+   __syncthreads();
+   int base_idx = blockIdx.x * sort_kernel::N_PER_BLOCK;
+   for (int i = threadIdx.x; i < sort_kernel::N_PER_BLOCK; i += BLOCK_THREADS) {
+     int idx = base_idx + i;
+     if (idx < size_sorted_axis) {
+       dev_vals_out[idx] = tgp_vals[i];
+       dev_idxs_out[idx] = tgp_idxs[i];
+     }
+   }
+ }
+
+ } // namespace cu
+
+ namespace {
+
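+ // Host launcher for rows that fit in a single block: sizes the block, decides
+ // between the contiguous and the shape/strides kernel, and records the launch
+ // on the stream's command encoder. Complex inputs are rejected below.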
+ void single_block_sort(
+     const Stream& s,
+     const array& in,
+     array& out,
+     int axis,
+     int bn,
+     bool argsort) {
+   int n_rows = in.size() / in.shape(axis);
+
+   auto in_nc_str = in.strides();
+   in_nc_str.erase(in_nc_str.begin() + axis);
+
+   auto out_nc_str = out.strides();
+   out_nc_str.erase(out_nc_str.begin() + axis);
+
+   auto nc_shape = in.shape();
+   nc_shape.erase(nc_shape.begin() + axis);
+
+   int nc_dim = nc_shape.size();
+
+   int size_sorted_axis = in.shape(axis);
+   int64_t in_stride_sorted_axis = in.strides()[axis];
+   int64_t out_stride_sorted_axis = out.strides()[axis];
+
+   bool contiguous = in.flags().contiguous;
+   auto check_strides = [](const array& x, int64_t sort_stride) {
+     int64_t min_stride =
+         *std::min_element(x.strides().begin(), x.strides().end());
+     int64_t max_stride =
+         *std::max_element(x.strides().begin(), x.strides().end());
+     return sort_stride == min_stride || sort_stride == max_stride;
+   };
+   contiguous &= check_strides(in, in_stride_sorted_axis);
+   contiguous &= check_strides(out, out_stride_sorted_axis);
+
+   auto& encoder = cu::get_command_encoder(s);
+   out.set_data(cu::malloc_async(out.nbytes(), encoder));
+   encoder.set_input_array(in);
+   encoder.set_output_array(out);
+
+   dispatch_all_types(in.dtype(), [&](auto type_tag) {
+     using CTYPE = MLX_GET_TYPE(type_tag);
+     if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
+       using ValT = cuda_type_t<CTYPE>;
+       dispatch_block_dim(bn, [&](auto block_dim) {
+         constexpr int BLOCK_THREADS = block_dim();
+         if constexpr (BLOCK_THREADS < 1024) {
+           dim3 grid(1, n_rows, 1);
+           dim3 block(BLOCK_THREADS, 1, 1);
+
+           dispatch_bool(argsort, [&](auto arg_tag) {
+             constexpr bool ARG_SORT = decltype(arg_tag)::value;
+             using OutT = std::conditional_t<ARG_SORT, uint32_t, ValT>;
+
+             if (contiguous) {
+               auto kernel = cu::block_sort_kernel<
+                   ValT,
+                   OutT,
+                   ARG_SORT,
+                   BLOCK_THREADS,
+                   N_PER_THREAD>;
+               int64_t in_stride_segment_axis = INT64_MAX;
+               int64_t out_stride_segment_axis = INT64_MAX;
+               for (int i = 0; i < nc_shape.size(); i++) {
+                 if (nc_shape[i] == 1) {
+                   continue;
+                 }
+                 if (in_nc_str[i] > INT32_MAX || out_nc_str[i] > INT32_MAX) {
+                   throw std::runtime_error(
+                       "[Sort::eval_gpu] Stride too large.");
+                 }
+                 in_stride_segment_axis =
+                     std::min(in_stride_segment_axis, in_nc_str[i]);
+                 out_stride_segment_axis =
+                     std::min(out_stride_segment_axis, out_nc_str[i]);
+               }
+               encoder.add_kernel_node(
+                   kernel,
+                   grid,
+                   block,
+                   0,
+                   gpu_ptr<ValT>(in),
+                   gpu_ptr<OutT>(out),
+                   size_sorted_axis,
+                   in_stride_sorted_axis,
+                   out_stride_sorted_axis,
+                   in_stride_segment_axis,
+                   out_stride_segment_axis);
+             } else {
+               auto kernel = cu::block_sort_nc_kernel<
+                   ValT,
+                   OutT,
+                   ARG_SORT,
+                   BLOCK_THREADS,
+                   N_PER_THREAD>;
+               auto nc_shape_param = const_param(nc_shape);
+               auto in_nc_strides_param = const_param(in_nc_str);
+               auto out_nc_strides_param = const_param(out_nc_str);
+               encoder.add_kernel_node(
+                   kernel,
+                   grid,
+                   block,
+                   0,
+                   gpu_ptr<ValT>(in),
+                   gpu_ptr<OutT>(out),
+                   size_sorted_axis,
+                   in_stride_sorted_axis,
+                   out_stride_sorted_axis,
+                   nc_shape_param,
+                   in_nc_strides_param,
+                   out_nc_strides_param,
+                   nc_dim);
+             }
+           });
+         }
+       });
+     } else {
+       throw std::runtime_error(
+           "CUDA backend does not support sorting complex numbers");
+     }
+   });
+ }
+
829
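Together with `in.flags().contiguous`, the `check_strides` test above is what selects between the two kernels: a batch row can use the simple segmented `block_sort_kernel` only when the sorted axis is the innermost or outermost dimension in memory; otherwise the general `block_sort_nc_kernel` path walks the full shape/stride description. A minimal standalone restatement of that predicate (the function name and plain vectors here are illustrative, not from this file):

#include <algorithm>
#include <cstdint>
#include <vector>

// True when `axis` has the minimum or maximum stride, i.e. it is the
// innermost or outermost dimension of the layout.
bool sort_axis_is_edge(const std::vector<int64_t>& strides, int axis) {
  auto [mn, mx] = std::minmax_element(strides.begin(), strides.end());
  return strides[axis] == *mn || strides[axis] == *mx;
}

// sort_axis_is_edge({12, 4, 1}, 2) == true   // innermost axis of a 2x3x4 array
// sort_axis_is_edge({12, 4, 1}, 1) == false  // middle axis: general kernel path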
+ void multi_block_sort(
+     const Stream& s,
+     const array& in,
+     array& out,
+     int axis,
+     int n_blocks,
+     bool argsort) {
+   int n_rows = in.size() / in.shape(axis);
+
+   auto nc_str = in.strides();
+   nc_str.erase(nc_str.begin() + axis);
+
+   auto nc_shape = in.shape();
+   nc_shape.erase(nc_shape.begin() + axis);
+
+   int nc_dim = nc_shape.size();
+
+   if (nc_dim == 0) {
+     nc_shape = {0};
+     nc_str = {1};
+   }
+
+   int size_sorted_axis = in.shape(axis);
+   int64_t stride_sorted_axis = in.strides()[axis];
+
+   array dev_vals_in({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
+   array dev_vals_out({n_rows, size_sorted_axis}, in.dtype(), nullptr, {});
+
+   array dev_idxs_in({n_rows, size_sorted_axis}, uint32, nullptr, {});
+   array dev_idxs_out({n_rows, size_sorted_axis}, uint32, nullptr, {});
+
+   array block_partitions({n_rows, n_blocks + 1}, uint32, nullptr, {});
+
+   auto& encoder = cu::get_command_encoder(s);
+
+   dev_vals_in.set_data(cu::malloc_async(dev_vals_in.nbytes(), encoder));
+   dev_vals_out.set_data(cu::malloc_async(dev_vals_out.nbytes(), encoder));
+   dev_idxs_in.set_data(cu::malloc_async(dev_idxs_in.nbytes(), encoder));
+   dev_idxs_out.set_data(cu::malloc_async(dev_idxs_out.nbytes(), encoder));
+   block_partitions.set_data(
+       cu::malloc_async(block_partitions.nbytes(), encoder));
+
+   encoder.add_temporary(block_partitions);
+
+   dispatch_all_types(in.dtype(), [&](auto type_tag) {
+     using CTYPE = MLX_GET_TYPE(type_tag);
+     if constexpr (!std::is_same_v<CTYPE, complex64_t>) {
+       using ValT = cuda_type_t<CTYPE>;
+       using IdxT = uint32_t;
+       constexpr int BLOCK_THREADS = sizeof(ValT) == 8 ? 256 : 512;
+       dim3 grid(n_blocks, n_rows, 1);
+       dim3 block(BLOCK_THREADS, 1, 1);
+
+       dispatch_bool(argsort, [&](auto arg_tag) {
+         constexpr bool ARG_SORT = decltype(arg_tag)::value;
+         auto nc_shape_param = const_param(nc_shape);
+         auto nc_strides_param = const_param(nc_str);
+
+         auto block_sort_kernel = cu::mb_block_sort_kernel<
+             ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
+         encoder.set_input_array(in);
+         encoder.set_output_array(dev_vals_in);
+         encoder.set_output_array(dev_idxs_in);
+         encoder.add_kernel_node(
+             block_sort_kernel,
+             grid,
+             block,
+             0,
+             gpu_ptr<ValT>(in),
+             gpu_ptr<ValT>(dev_vals_in),
+             gpu_ptr<IdxT>(dev_idxs_in),
+             size_sorted_axis,
+             stride_sorted_axis,
+             nc_shape_param,
+             nc_strides_param,
+             nc_dim);
+
+         int n_thr_per_group = (n_blocks + 1) < 1024 ? (n_blocks + 1) : 1024;
+
+         for (int merge_tiles = 2; (merge_tiles / 2) < n_blocks;
+              merge_tiles *= 2) {
+           auto partition_kernel = cu::mb_block_partition_kernel<
+               ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
+
+           encoder.set_input_array(dev_vals_in);
+           encoder.set_input_array(dev_idxs_in);
+           encoder.set_output_array(block_partitions);
+
+           encoder.add_kernel_node(
+               partition_kernel,
+               dim3(1, n_rows, 1),
+               dim3(n_thr_per_group, 1, 1),
+               0,
+               gpu_ptr<IdxT>(block_partitions),
+               gpu_ptr<ValT>(dev_vals_in),
+               gpu_ptr<IdxT>(dev_idxs_in),
+               size_sorted_axis,
+               merge_tiles,
+               n_blocks);
+
+           auto merge_kernel = cu::mb_block_merge_kernel<
+               ValT, IdxT, ARG_SORT, BLOCK_THREADS, N_PER_THREAD>;
+
+           encoder.set_input_array(dev_vals_in);
+           encoder.set_input_array(dev_idxs_in);
+           encoder.set_input_array(block_partitions);
+           encoder.set_output_array(dev_vals_out);
+           encoder.set_output_array(dev_idxs_out);
+
+           encoder.add_kernel_node(
+               merge_kernel,
+               dim3(n_blocks, n_rows, 1),
+               dim3(BLOCK_THREADS, 1, 1),
+               0,
+               gpu_ptr<IdxT>(block_partitions),
+               gpu_ptr<ValT>(dev_vals_in),
+               gpu_ptr<IdxT>(dev_idxs_in),
+               gpu_ptr<ValT>(dev_vals_out),
+               gpu_ptr<IdxT>(dev_idxs_out),
+               size_sorted_axis,
+               merge_tiles,
+               n_blocks);
+           std::swap(dev_vals_in, dev_vals_out);
+           std::swap(dev_idxs_in, dev_idxs_out);
+         }
+       });
+     } else {
+       throw std::runtime_error(
+           "CUDA backend does not support sorting complex numbers");
+     }
+   });
+
+   encoder.add_temporary(dev_vals_out);
+   encoder.add_temporary(dev_idxs_out);
+   encoder.add_temporary(argsort ? dev_vals_in : dev_idxs_in);
+   if (axis == in.ndim() - 1) {
+     // Copy buffer to out, no need for temporary
+     out.copy_shared_buffer(
+         argsort ? dev_idxs_in : dev_vals_in,
+         out.strides(),
+         out.flags(),
+         out.size());
+   } else {
+     encoder.add_temporary(argsort ? dev_idxs_in : dev_vals_in);
+     out.set_data(cu::malloc_async(out.nbytes(), encoder));
+     auto strides = out.strides();
+     for (int ax = axis + 1; ax < strides.size(); ax++) {
+       strides[ax] *= out.shape(axis);
+     }
+     strides[axis] = 1;
+     copy_gpu_inplace(
+         (argsort) ? dev_idxs_in : dev_vals_in,
+         out,
+         out.shape(),
+         strides,
+         out.strides(),
+         0,
+         0,
+         CopyType::General,
+         s);
+   }
+ }
+
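The loop above is a bottom-up merge: after `mb_block_sort_kernel` leaves each block's tile sorted, every pass pairs up tiles of `merge_tiles / 2` blocks, merges them into tiles twice as wide, and swaps the in/out buffers, doubling `merge_tiles` until one tile spans the whole row. A CPU analogue of that pass structure, as a sketch only (the GPU version additionally uses `mb_block_partition_kernel` to split each merge across thread blocks):

#include <algorithm>
#include <utility>
#include <vector>

// Bottom-up merge over tiles of width `tile`, assuming each tile is already
// sorted (the per-block sort's post-condition). Input and output buffers are
// swapped after every pass, mirroring dev_vals_in / dev_vals_out above.
void bottom_up_merge(std::vector<int>& a, int tile) {
  const int n = static_cast<int>(a.size());
  std::vector<int> b(a.size());
  for (; tile < n; tile *= 2) {
    for (int lo = 0; lo < n; lo += 2 * tile) {
      int mid = std::min(lo + tile, n);
      int hi = std::min(lo + 2 * tile, n);
      std::merge(a.begin() + lo, a.begin() + mid,
                 a.begin() + mid, a.begin() + hi,
                 b.begin() + lo);
    }
    std::swap(a, b); // ping-pong: this pass's output feeds the next pass
  }
}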
+ void gpu_merge_sort(
+     const Stream& s,
+     const array& in,
+     array& out,
+     int axis_,
+     bool argsort) {
+   int axis = axis_ < 0 ? axis_ + in.ndim() : axis_;
+   int size_sorted_axis = in.shape(axis);
+
+   constexpr int tn = N_PER_THREAD;
+   int potential_bn = (size_sorted_axis + tn - 1) / tn;
+
+   int bn;
+   if (potential_bn > 256) {
+     bn = 512;
+   } else if (potential_bn > 128) {
+     bn = 256;
+   } else if (potential_bn > 64) {
+     bn = 128;
+   } else if (potential_bn > 32) {
+     bn = 64;
+   } else {
+     bn = 32;
+   }
+
+   if (bn == 512 && size_of(in.dtype()) > 4) {
+     bn = 256;
+   }
+
+   int n_per_block = bn * tn;
+   int n_blocks = (size_sorted_axis + n_per_block - 1) / n_per_block;
+
+   if (n_blocks > 1) {
+     return multi_block_sort(s, in, out, axis, n_blocks, argsort);
+   }
+   return single_block_sort(s, in, out, axis, bn, argsort);
+ }
+
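For reference, the selection above picks `bn` as the smallest thread count in {32, 64, 128, 256, 512} whose `bn * N_PER_THREAD` elements cover the sorted axis, demoting 512-thread blocks to 256 for element types wider than 4 bytes (presumably to bound per-block shared-memory and register use). A standalone restatement of the same arithmetic (`pick_launch` is an illustrative name; `n_per_thread` stands in for `N_PER_THREAD`, defined earlier in this file):

#include <cstddef>
#include <utility>

// Returns {threads per block, number of blocks along the sorted axis}.
std::pair<int, int> pick_launch(int axis_size, int n_per_thread,
                                size_t dtype_size) {
  int potential_bn = (axis_size + n_per_thread - 1) / n_per_thread;
  int bn = potential_bn > 256 ? 512
         : potential_bn > 128 ? 256
         : potential_bn >  64 ? 128
         : potential_bn >  32 ?  64
                              :  32;
  if (bn == 512 && dtype_size > 4) {
    bn = 256; // wide (8-byte) types use half-size blocks
  }
  int n_per_block = bn * n_per_thread;
  int n_blocks = (axis_size + n_per_block - 1) / n_per_block;
  return {bn, n_blocks}; // n_blocks > 1 selects multi_block_sort
}

// Example: pick_launch(100000, 4, 4) -> {512, 49}; a float axis of 100000
// elements takes the multi-block path with 49 tiles of 2048 elements each.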
+ void gpu_sort(
+     const Stream& s,
+     const array& in,
+     array& out,
+     int axis,
+     bool argsort) {
+   gpu_merge_sort(s, in, out, axis, argsort);
+ }
+
+ } // namespace
+
+ void ArgSort::eval_gpu(const std::vector<array>& inputs, array& out) {
+   nvtx3::scoped_range r("ArgSort::eval_gpu");
+   assert(inputs.size() == 1);
+   gpu_sort(stream(), inputs[0], out, axis_, true);
+ }
+
+ void Sort::eval_gpu(const std::vector<array>& inputs, array& out) {
+   nvtx3::scoped_range r("Sort::eval_gpu");
+   assert(inputs.size() == 1);
+   gpu_sort(stream(), inputs[0], out, axis_, false);
+ }
+
+ void ArgPartition::eval_gpu(const std::vector<array>& inputs, array& out) {
+   nvtx3::scoped_range r("ArgPartition::eval_gpu");
+   gpu_sort(stream(), inputs[0], out, axis_, true);
+ }
+
+ void Partition::eval_gpu(const std::vector<array>& inputs, array& out) {
+   nvtx3::scoped_range r("Partition::eval_gpu");
+   gpu_sort(stream(), inputs[0], out, axis_, false);
+ }
+
+ } // namespace mlx::core
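These four `eval_gpu` overrides are the file's only entry points; note that `ArgPartition` and `Partition` currently route through the same full merge sort rather than a dedicated selection kernel. A hedged sketch of how the public C++ ops reach this code on a CUDA build (assumes the umbrella `mlx/mlx.h` header vendored elsewhere in this package; not part of this diff):

#include "mlx/mlx.h"

namespace mx = mlx::core;

int main() {
  auto a = mx::array({3.0f, 1.0f, 2.0f});
  auto sorted = mx::sort(a, 0);   // lowers to Sort::eval_gpu on a GPU stream
  auto order = mx::argsort(a, 0); // lowers to ArgSort::eval_gpu
  mx::eval({sorted, order});
  return 0;
}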