mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/backend/common/reduce.cpp
@@ -0,0 +1,154 @@
+ // Copyright © 2024 Apple Inc.
+
+ #include "mlx/backend/common/reduce.h"
+
+ namespace mlx::core {
+
+ std::pair<Shape, Strides> shapes_without_reduction_axes(
+     Shape shape,
+     Strides strides,
+     const std::vector<int>& axes) {
+   for (int i = axes.size() - 1; i >= 0; i--) {
+     int a = axes[i];
+     shape.erase(shape.begin() + a);
+     strides.erase(strides.begin() + a);
+   }
+
+   return std::make_pair(shape, strides);
+ }
+
+ std::pair<Shape, Strides> shapes_without_reduction_axes(
+     const array& x,
+     const std::vector<int>& axes) {
+   auto shape = x.shape();
+   auto strides = x.strides();
+   return shapes_without_reduction_axes(
+       std::move(shape), std::move(strides), axes);
+ }
+
+ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes) {
+   // The data is all there and we are reducing over everything
+   if (x.size() == x.data_size() && axes.size() == x.ndim() &&
+       x.flags().contiguous) {
+     return ContiguousAllReduce;
+   }
+
+   // Row contiguous input so the output is row contiguous
+   if (x.flags().row_contiguous) {
+     // Merge consecutive axes
+     Shape shape = {x.shape(axes[0])};
+     Strides strides = {x.strides()[axes[0]]};
+     for (int i = 1; i < axes.size(); i++) {
+       if (axes[i] - 1 == axes[i - 1] && x.shape(axes[i]) > 1) {
+         shape.back() *= x.shape(axes[i]);
+         strides.back() = x.strides()[axes[i]];
+       } else {
+         shape.push_back(x.shape(axes[i]));
+         strides.push_back(x.strides()[axes[i]]);
+       }
+     }
+
+     // Remove singleton axes from the plan
+     for (int i = shape.size() - 1; i >= 0; i--) {
+       if (shape[i] == 1) {
+         shape.erase(shape.begin() + i);
+         strides.erase(strides.begin() + i);
+       }
+     }
+
+     if (strides.back() == 1) {
+       return ReductionPlan(ContiguousReduce, shape, strides);
+     } else if (strides.back() > 1) {
+       return ReductionPlan(ContiguousStridedReduce, shape, strides);
+     }
+   }
+
+   // Let's check if we can optimize our access patterns
+   //
+   // 1. We have a reduction axis with stride 1. Simply call
+   //    GeneralContiguousReduce and be done with it.
+   // 2. We have transpositions and we are not reducing over the axis with
+   //    stride 1. However, we are reducing over an axis where everything is
+   //    contiguous in memory to the right of that axis. We can call strided
+   //    reduce and be done with it.
+   // 3. We have weird transpositions and expands. Copy the strides to the
+   //    output, then call strided reduce.
+
+   // Sort reduction axes by stride in order to merge them and figure out if we
+   // have a contiguous reduction.
+   std::vector<std::pair<int, int64_t>> reductions;
+   for (auto a : axes) {
+     if (x.shape(a) > 1) {
+       reductions.push_back(std::make_pair(x.shape(a), x.strides()[a]));
+     }
+   }
+   std::sort(reductions.begin(), reductions.end(), [](auto a, auto b) {
+     bool a_is_zero = a.second == 0;
+     bool b_is_zero = b.second == 0;
+     return (a_is_zero != b_is_zero) ? a.second < b.second : a.second > b.second;
+   });
+   // Extract the two smallest and try to merge them in case the contiguous
+   // reduction can be bigger than just the last axis.
+   for (int i = reductions.size() - 1; i >= 1; i--) {
+     auto a = reductions[i];
+     auto b = reductions[i - 1];
+
+     // If b.stride == a.shape * a.stride then a and b are contiguous
+     if (b.second == a.first * a.second) {
+       reductions.erase(reductions.begin() + i);
+       reductions[i - 1] = std::make_pair(a.first * b.first, a.second);
+     }
+   }
+
+   Shape shape;
+   Strides strides;
+   for (auto r : reductions) {
+     shape.push_back(r.first);
+     strides.push_back(r.second);
+   }
+
+   // We can call the contiguous reduction op for every weird way the input is
+   // structured in the rest of the axes.
+   if (strides.back() == 1) {
+     return ReductionPlan(GeneralContiguousReduce, shape, strides);
+   }
+
+   // Delegate to the general strided reduction op if the axes after
+   // strides.back() are contiguous.
+   if (strides.back() > 1) {
+     int64_t size = 1;
+     bool have_expand = false;
+     for (int i = x.ndim() - 1; i >= 0; i--) {
+       if (axes.back() == i) {
+         continue;
+       }
+
+       auto stride_i = x.strides()[i];
+       auto shape_i = x.shape(i);
+       if (stride_i == 0) {
+         if (shape_i == 1) {
+           continue;
+         }
+
+         have_expand = true;
+         break;
+       }
+
+       if (stride_i != size && shape_i != 1) {
+         break;
+       }
+       size *= shape_i;
+     }
+     // In the case of an expanded dimension we are being conservative and
+     // require the smallest reduction stride to be smaller than the maximum row
+     // contiguous size. The reason is that we can't easily know if the reduced
+     // axis is before or after an expanded dimension.
+     if (size > strides.back() || (size == strides.back() && !have_expand)) {
+       return ReductionPlan(GeneralStridedReduce, shape, strides);
+     }
+   }
+
+   return ReductionPlan(GeneralReduce, shape, strides);
+ }
+
+ } // namespace mlx::core
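The planner above turns every reduction into one of a handful of access patterns by merging reduction axes whose strides line up. As a rough, self-contained illustration of the merging step from the row-contiguous branch (this is a sketch, not the library API; Shape, Strides, and the loop below merely mirror the code above):

    // Standalone sketch of the axis-merging step in get_reduction_plan.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    using Shape = std::vector<int>;
    using Strides = std::vector<int64_t>;

    int main() {
      // A row-contiguous 4x5x6 array reduced over axes {1, 2}.
      Shape shape = {4, 5, 6};
      Strides strides = {30, 6, 1};
      std::vector<int> axes = {1, 2};

      // Consecutive reduction axes merge into one, keeping the innermost
      // stride, so the reduction sees a single contiguous run.
      Shape merged = {shape[axes[0]]};
      Strides merged_strides = {strides[axes[0]]};
      for (size_t i = 1; i < axes.size(); i++) {
        if (axes[i] - 1 == axes[i - 1]) {
          merged.back() *= shape[axes[i]];
          merged_strides.back() = strides[axes[i]];
        } else {
          merged.push_back(shape[axes[i]]);
          merged_strides.push_back(strides[axes[i]]);
        }
      }
      // Prints: plan axis: size 30, stride 1
      std::cout << "plan axis: size " << merged.back() << ", stride "
                << merged_strides.back() << "\n";
    }

Here the two reduction axes {5, 6} collapse into one contiguous run of 30 elements with stride 1, which is exactly the ContiguousReduce case returned above.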
data/mlx/mlx/backend/common/reduce.h
@@ -0,0 +1,59 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/backend/common/utils.h"
+
+ namespace mlx::core {
+
+ enum ReductionOpType {
+   // Self-explanatory. Read everything and produce 1 output.
+   ContiguousAllReduce,
+
+   // The input is contiguous and the last axis is reduced
+   // N1xR1xN2xR2x...xNnxRn
+   ContiguousReduce,
+
+   // The input is contiguous and the last axis is not reduced
+   // R1xN1xR2xN2x...xRnxNn
+   ContiguousStridedReduce,
+
+   // The input is not contiguous but the last axis is and it is reduced so we
+   // need to figure out the offsets but we can call the contiguous reduce after
+   // that.
+   // N3xR1xN1xR4x...xRn
+   GeneralContiguousReduce,
+
+   // The input is not contiguous but the last reduction axis and the last axis
+   // are so we need to figure out the offset but we can call the strided reduce
+   // after that.
+   GeneralStridedReduce,
+
+   // The input is not contiguous after the reduction axis and it may contain
+   // 0-stride axes or transpositions. We could copy the strides and produce a
+   // transposed outcome or we can read the input out of order and write the
+   // output in order.
+   GeneralReduce
+ };
+
+ struct ReductionPlan {
+   ReductionOpType type;
+   Shape shape;
+   Strides strides;
+
+   ReductionPlan(ReductionOpType type_, Shape shape_, Strides strides_)
+       : type(type_), shape(std::move(shape_)), strides(std::move(strides_)) {}
+   ReductionPlan(ReductionOpType type_) : type(type_) {}
+ };
+
+ ReductionPlan get_reduction_plan(const array& x, const std::vector<int>& axes);
+
+ std::pair<Shape, Strides> shapes_without_reduction_axes(
+     const array& x,
+     const std::vector<int>& axes);
+ std::pair<Shape, Strides> shapes_without_reduction_axes(
+     Shape shape,
+     Strides strides,
+     const std::vector<int>& axes);
+
+ } // namespace mlx::core
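To make the N/R notation in the enum concrete: for a row-contiguous array of shape (4, 1024), reducing over axis 1 is the N1xR1 case and get_reduction_plan returns ContiguousReduce; reducing over axis 0 is the R1xN1 case and yields ContiguousStridedReduce; and reducing over both axes of a fully contiguous array is ContiguousAllReduce.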
data/mlx/mlx/backend/common/slicing.cpp
@@ -0,0 +1,71 @@
+ // Copyright © 2024 Apple Inc.
+
+ #include "mlx/backend/common/utils.h"
+
+ namespace mlx::core {
+
+ std::tuple<int64_t, Strides> prepare_slice(
+     const array& in,
+     const Shape& start_indices,
+     const Shape& strides) {
+   int64_t data_offset = 0;
+   Strides inp_strides(in.ndim(), 0);
+   for (int i = 0; i < in.ndim(); ++i) {
+     data_offset += start_indices[i] * in.strides()[i];
+     inp_strides[i] = in.strides()[i] * strides[i];
+   }
+   return std::make_tuple(data_offset, inp_strides);
+ }
+
+ void shared_buffer_slice(
+     const array& in,
+     const Strides& out_strides,
+     int64_t data_offset,
+     size_t data_size,
+     array& out) {
+   // Compute row/col contiguity
+   auto [no_bsx_size, is_row_contiguous, is_col_contiguous] =
+       check_contiguity(out.shape(), out_strides);
+
+   auto flags = in.flags();
+   flags.row_contiguous = is_row_contiguous;
+   flags.col_contiguous = is_col_contiguous;
+   flags.contiguous = (no_bsx_size == data_size);
+
+   out.copy_shared_buffer(in, out_strides, flags, data_size, data_offset);
+ }
+
+ void slice(
+     const array& in,
+     array& out,
+     const Shape& start_indices,
+     const Shape& strides) {
+   if (out.size() == 0) {
+     out.set_data(allocator::malloc(0));
+     return;
+   }
+
+   // Calculate out strides, initial offset
+   auto [data_offset, inp_strides] = prepare_slice(in, start_indices, strides);
+
+   // Get the location of the end based on the inp strides and out.shape()
+   int64_t low_idx = 0;
+   int64_t high_idx = 0;
+   for (int i = 0; i < inp_strides.size(); ++i) {
+     auto delta = inp_strides[i] * (out.shape()[i] - 1);
+     if (inp_strides[i] > 0) {
+       high_idx += delta;
+     } else {
+       low_idx += delta;
+     }
+   }
+   int64_t data_size = (high_idx - low_idx) + 1;
+   if (data_size < 0) {
+     std::ostringstream msg;
+     msg << "[slice] Computed invalid data size: " << data_size << ".";
+     throw std::runtime_error(msg.str());
+   }
+   shared_buffer_slice(in, inp_strides, data_offset, data_size, out);
+ }
+
+ } // namespace mlx::core
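The key idea in prepare_slice is that a slice never copies: it is just a new starting offset plus scaled strides over the same buffer. A minimal standalone sketch of that arithmetic (the vectors below stand in for the mlx::core types; this is illustrative only):

    // Sketch of the prepare_slice offset/stride arithmetic above.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // A row-contiguous 4x6 array: strides (6, 1).
      std::vector<int64_t> in_strides = {6, 1};
      // Take a[1:4:2, 2:6:3], i.e. starts (1, 2) and step sizes (2, 3).
      std::vector<int64_t> starts = {1, 2};
      std::vector<int64_t> steps = {2, 3};

      int64_t data_offset = 0;
      std::vector<int64_t> out_strides(2);
      for (int i = 0; i < 2; ++i) {
        data_offset += starts[i] * in_strides[i];   // where the view begins
        out_strides[i] = in_strides[i] * steps[i];  // how far each index jumps
      }
      // Prints: offset 8, strides (12, 3)
      std::cout << "offset " << data_offset << ", strides (" << out_strides[0]
                << ", " << out_strides[1] << ")\n";
    }

The slice function above then only has to compute how many elements of the original buffer the view spans (data_size) so the contiguity flags can be set correctly.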
data/mlx/mlx/backend/common/slicing.h
@@ -0,0 +1,20 @@
+ // Copyright © 2024 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/array.h"
+
+ namespace mlx::core {
+
+ std::tuple<int64_t, Strides> prepare_slice(
+     const array& in,
+     const Shape& start_indices,
+     const Shape& strides);
+
+ void slice(
+     const array& in,
+     array& out,
+     const Shape& start_indices,
+     const Shape& strides);
+
+ } // namespace mlx::core
data/mlx/mlx/backend/common/ternary.h
@@ -0,0 +1,85 @@
+ // Copyright © 2023 Apple Inc.
+
+ #pragma once
+ #include "mlx/allocator.h"
+ #include "mlx/array.h"
+ #include "mlx/backend/common/utils.h"
+
+ namespace mlx::core {
+
+ // TODO: Add support for more combinations of input types.
+ enum class TernaryOpType {
+   ScalarScalarScalar,
+   VectorVectorVector,
+   VectorVectorScalar,
+   VectorScalarVector,
+   General,
+ };
+
+ inline TernaryOpType
+ get_ternary_op_type(const array& a, const array& b, const array& c) {
+   TernaryOpType topt;
+   if (a.data_size() == 1 && b.data_size() == 1 && c.data_size() == 1) {
+     topt = TernaryOpType::ScalarScalarScalar;
+   } else if (
+       (a.flags().row_contiguous && b.flags().row_contiguous &&
+        c.flags().row_contiguous) ||
+       (a.flags().col_contiguous && b.flags().col_contiguous &&
+        c.flags().col_contiguous)) {
+     topt = TernaryOpType::VectorVectorVector;
+   } else if (
+       b.data_size() == 1 && a.flags().row_contiguous &&
+       c.flags().row_contiguous) {
+     topt = TernaryOpType::VectorScalarVector;
+   } else if (
+       c.data_size() == 1 && a.flags().row_contiguous &&
+       b.flags().row_contiguous) {
+     topt = TernaryOpType::VectorVectorScalar;
+   } else {
+     topt = TernaryOpType::General;
+   }
+   return topt;
+ }
+
+ inline void set_ternary_op_output_data(
+     const array& a,
+     const array& b,
+     const array& c,
+     array& out,
+     TernaryOpType topt,
+     std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
+   auto maybe_donate = [&out](const array& x) {
+     if (is_donatable(x, out)) {
+       out.copy_shared_buffer(x);
+       return true;
+     }
+     return false;
+   };
+
+   switch (topt) {
+     case TernaryOpType::ScalarScalarScalar:
+       out.set_data(mallocfn(out.itemsize()), 1, b.strides(), b.flags());
+       break;
+     case TernaryOpType::VectorVectorVector:
+       if (!(maybe_donate(a) || maybe_donate(b) || maybe_donate(c))) {
+         out.set_data(
+             mallocfn(out.itemsize() * b.data_size()),
+             b.data_size(),
+             b.strides(),
+             b.flags());
+       }
+       break;
+     case TernaryOpType::VectorVectorScalar:
+     case TernaryOpType::VectorScalarVector:
+     case TernaryOpType::General:
+       // Try to donate an input which is row_contiguous
+       if (!((a.flags().row_contiguous && maybe_donate(a)) ||
+             (b.flags().row_contiguous && maybe_donate(b)) ||
+             (c.flags().row_contiguous && maybe_donate(c)))) {
+         out.set_data(mallocfn(out.nbytes()));
+       }
+       break;
+   }
+ }
+
+ } // namespace mlx::core
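The maybe_donate lambda above implements buffer donation: when an input is no longer needed elsewhere, the output reuses its buffer instead of allocating a new one. A minimal standalone sketch of the idea, with deliberately simplified types (FakeArray and the is_donatable below are this sketch's own; the real mlx::core check also takes the output's size and dtype into account):

    // Sketch of the buffer-donation idea used by set_ternary_op_output_data.
    #include <iostream>
    #include <memory>
    #include <vector>

    struct FakeArray {
      std::shared_ptr<std::vector<float>> data;
    };

    // A buffer is "donatable" here when this array holds the only reference.
    bool is_donatable(const FakeArray& in) {
      return in.data.use_count() == 1;
    }

    int main() {
      FakeArray a{std::make_shared<std::vector<float>>(1024)};
      FakeArray out;
      if (is_donatable(a)) {
        out.data = std::move(a.data);  // reuse a's buffer, no allocation
      } else {
        out.data = std::make_shared<std::vector<float>>(1024);
      }
      std::cout << "output reuses input buffer: " << (a.data == nullptr) << "\n";
    }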
data/mlx/mlx/backend/common/unary.h
@@ -0,0 +1,29 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/allocator.h"
+ #include "mlx/backend/common/utils.h"
+
+ namespace mlx::core {
+
+ inline void set_unary_output_data(
+     const array& in,
+     array& out,
+     std::function<allocator::Buffer(size_t)> mallocfn = allocator::malloc) {
+   if (in.flags().contiguous) {
+     if (is_donatable(in, out)) {
+       out.copy_shared_buffer(in);
+     } else {
+       out.set_data(
+           mallocfn(in.data_size() * out.itemsize()),
+           in.data_size(),
+           in.strides(),
+           in.flags());
+     }
+   } else {
+     out.set_data(mallocfn(out.nbytes()));
+   }
+ }
+
+ } // namespace mlx::core
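One detail worth noting in the contiguous branch above: the output is sized as in.data_size() * out.itemsize() and inherits the input's strides and flags, so a contiguous but compressed input (for example a broadcast view whose data_size is smaller than its logical size) yields an equally compressed output; only the non-contiguous fallback allocates the full out.nbytes().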
data/mlx/mlx/backend/common/utils.cpp
@@ -0,0 +1,231 @@
+ // Copyright © 2023-2024 Apple Inc.
+
+ #include <dlfcn.h>
+
+ #include "mlx/backend/common/utils.h"
+
+ namespace mlx::core {
+
+ std::filesystem::path current_binary_dir() {
+   static std::filesystem::path binary_dir = []() {
+     Dl_info info;
+     if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
+       throw std::runtime_error("Unable to get current binary dir.");
+     }
+     return std::filesystem::path(info.dli_fname).parent_path();
+   }();
+   return binary_dir;
+ }
+
+ std::tuple<Shape, std::vector<Strides>> collapse_contiguous_dims(
+     const Shape& shape,
+     const std::vector<Strides>& strides,
+     int64_t size_cap) {
+   // Make a vector that has axes separated with -1. Collapse all axes between
+   // -1.
+   Shape to_collapse;
+   if (shape.size() > 0) {
+     if (shape[0] != 1) {
+       to_collapse.push_back(0);
+     }
+     size_t size = shape[0];
+     for (int i = 1; i < shape.size(); i++) {
+       bool contiguous = true;
+       size *= shape[i];
+       for (const auto& st : strides) {
+         if (st[i] * shape[i] != st[i - 1] || size > size_cap) {
+           contiguous = false;
+           size = shape[i];
+           break;
+         }
+       }
+       if (!contiguous) {
+         to_collapse.push_back(-1);
+       }
+       if (shape[i] != 1) {
+         to_collapse.push_back(i);
+       }
+     }
+     to_collapse.push_back(-1);
+   }
+
+   Shape out_shape;
+   std::vector<Strides> out_strides(strides.size());
+   for (int i = 0;;) {
+     while (i < to_collapse.size() && to_collapse[i] == -1) {
+       ++i;
+     };
+     if (i == to_collapse.size()) {
+       break;
+     }
+     int current_shape = shape[to_collapse[i]];
+     int k = i;
+     while (to_collapse[++k] != -1) {
+       current_shape *= shape[to_collapse[k]];
+     }
+     out_shape.push_back(current_shape);
+     for (int j = 0; j < strides.size(); j++) {
+       const auto& st = strides[j];
+       out_strides[j].push_back(st[to_collapse[k - 1]]);
+     }
+     i = k + 1;
+   }
+
+   if (!shape.empty() && out_shape.empty()) {
+     out_shape.push_back(1);
+     for (auto& out_stride : out_strides) {
+       out_stride.push_back(0);
+     }
+   }
+   return std::make_tuple(out_shape, out_strides);
+ }
+
+ std::pair<Shape, Strides> collapse_contiguous_dims(
+     const Shape& shape,
+     const Strides& strides,
+     int64_t size_cap) {
+   Shape collapsed_shape;
+   Strides collapsed_strides;
+
+   if (shape.size() > 0) {
+     collapsed_shape.push_back(shape[0]);
+     collapsed_strides.push_back(strides[0]);
+     for (int i = 1; i < shape.size(); i++) {
+       if (shape[i] == 1) {
+         continue;
+       } else if (
+           strides[i] * shape[i] != collapsed_strides.back() ||
+           collapsed_shape.back() * static_cast<int64_t>(shape[i]) > size_cap) {
+         collapsed_shape.push_back(shape[i]);
+         collapsed_strides.push_back(strides[i]);
+       } else {
+         collapsed_shape.back() *= shape[i];
+         collapsed_strides.back() = strides[i];
+       }
+     }
+   }
+
+   return std::make_pair(collapsed_shape, collapsed_strides);
+ }
+
+ std::pair<Shape, Strides> collapse_contiguous_dims(
+     const array& a,
+     int64_t size_cap /* = std::numeric_limits<int32_t>::max()*/) {
+   return collapse_contiguous_dims(a.shape(), a.strides(), size_cap);
+ }
+
+ Dims get_block_dims_common(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
+   int pows[3] = {0, 0, 0};
+   int sum = 0;
+   while (true) {
+     int presum = sum;
+     // Check all the pows
+     if (dim0 >= (1 << (pows[0] + 1))) {
+       pows[0]++;
+       sum++;
+     }
+     if (sum == 10) {
+       break;
+     }
+     if (dim1 >= (1 << (pows[1] + 1))) {
+       pows[1]++;
+       sum++;
+     }
+     if (sum == 10) {
+       break;
+     }
+     if (dim2 >= (1 << (pows[2] + 1))) {
+       pows[2]++;
+       sum++;
+     }
+     if (sum == presum || sum == pow2) {
+       break;
+     }
+   }
+   return std::make_tuple(1ul << pows[0], 1ul << pows[1], 1ul << pows[2]);
+ }
+
+ Dims get_2d_grid_dims_common(const Shape& shape, const Strides& strides) {
+   // Dims with strides of 0 are ignored as they
+   // correspond to broadcasted dimensions
+   size_t grid_x = 1;
+   size_t grid_y = 1;
+   for (int i = 0; i < shape.size(); ++i) {
+     if (strides[i] == 0) {
+       continue;
+     }
+     if (grid_x * shape[i] < UINT32_MAX) {
+       grid_x *= shape[i];
+     } else {
+       grid_y *= shape[i];
+     }
+   }
+   if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
+     throw std::runtime_error("Unable to safely factor shape.");
+   }
+   if (grid_y > grid_x) {
+     std::swap(grid_x, grid_y);
+   }
+   return std::make_tuple(
+       static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+ }
+
+ Dims get_2d_grid_dims_common(
+     const Shape& shape,
+     const Strides& strides,
+     size_t divisor) {
+   // Compute the 2d grid dimensions such that the total size of the grid is
+   // divided by divisor.
+   size_t grid_x = 1;
+   size_t grid_y = 1;
+   for (int i = 0; i < shape.size(); ++i) {
+     if (strides[i] == 0) {
+       continue;
+     }
+
+     // No need to add this shape we can just remove it from the divisor.
+     if (divisor % shape[i] == 0) {
+       divisor /= shape[i];
+       continue;
+     }
+
+     if (grid_x * shape[i] < UINT32_MAX) {
+       grid_x *= shape[i];
+     } else {
+       grid_y *= shape[i];
+     }
+
+     if (divisor > 1) {
+       if (grid_x % divisor == 0) {
+         grid_x /= divisor;
+         divisor = 1;
+       } else if (grid_y % divisor == 0) {
+         grid_y /= divisor;
+         divisor = 1;
+       }
+     }
+   }
+   if (grid_y > UINT32_MAX || grid_x > UINT32_MAX) {
+     throw std::runtime_error("Unable to safely factor shape.");
+   }
+   if (grid_y > grid_x) {
+     std::swap(grid_x, grid_y);
+   }
+   if (divisor > 1) {
+     grid_x = ((grid_x + divisor - 1) / divisor) * divisor;
+   }
+   return std::make_tuple(
+       static_cast<uint32_t>(grid_x), static_cast<uint32_t>(grid_y), 1);
+ }
+
+ std::pair<Dims, Dims> get_grid_and_block_common(int dim0, int dim1, int dim2) {
+   auto [bx, by, bz] = get_block_dims_common(dim0, dim1, dim2);
+   auto gx = (dim0 + bx - 1) / bx;
+   auto gy = (dim1 + by - 1) / by;
+   auto gz = (dim2 + bz - 1) / bz;
+
+   return std::make_pair(
+       std::make_tuple(gx, gy, gz), std::make_tuple(bx, by, bz));
+ }
+
+ } // namespace mlx::core
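The workhorse here is collapse_contiguous_dims: adjacent dimensions whose strides line up are folded together so downstream kernels see the smallest possible number of dimensions. A minimal standalone sketch of the single-array overload (with the size_cap check omitted; the vectors stand in for the mlx::core types, illustrative only):

    // Sketch of the single-array collapse_contiguous_dims above.
    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // A row-contiguous 2x3x4 array: strides (12, 4, 1).
      std::vector<int> shape = {2, 3, 4};
      std::vector<int64_t> strides = {12, 4, 1};

      std::vector<int> out_shape = {shape[0]};
      std::vector<int64_t> out_strides = {strides[0]};
      for (size_t i = 1; i < shape.size(); i++) {
        if (shape[i] == 1) {
          continue;  // singleton axes vanish
        } else if (strides[i] * shape[i] != out_strides.back()) {
          out_shape.push_back(shape[i]);  // not mergeable: keep the axis
          out_strides.push_back(strides[i]);
        } else {
          out_shape.back() *= shape[i];  // mergeable: fold into previous axis
          out_strides.back() = strides[i];
        }
      }
      // Prints: collapsed to 1 dim(s), size 24, stride 1
      std::cout << "collapsed to " << out_shape.size() << " dim(s), size "
                << out_shape[0] << ", stride " << out_strides[0] << "\n";
    }

Because all three strides line up, the 2x3x4 array collapses to a single contiguous axis of 24 elements; a transposed input would stop the merge at the mismatching axis and keep it separate.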