mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,727 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include <fmt/format.h>
4
+
5
+ #include "mlx/backend/common/compiled.h"
6
+ #include "mlx/backend/common/utils.h"
7
+ #include "mlx/backend/gpu/copy.h"
8
+ #include "mlx/backend/metal/device.h"
9
+ #include "mlx/backend/metal/jit/includes.h"
10
+ #include "mlx/backend/metal/jit/indexing.h"
11
+ #include "mlx/backend/metal/kernels.h"
12
+ #include "mlx/backend/metal/scan.h"
13
+ #include "mlx/backend/metal/utils.h"
14
+ #include "mlx/dtype.h"
15
+ #include "mlx/primitives.h"
16
+ #include "mlx/utils.h"
17
+
18
+ namespace mlx::core {
19
+
20
// Largest number of index arrays a single Gather/Scatter dispatch accepts;
// eval_gpu throws std::runtime_error for anything above this.
constexpr int METAL_MAX_INDEX_ARRAYS{20};
21
+
22
// Produces the two source fragments that the JIT gather/scatter kernel
// templates splice in for `nidx` index arrays of Metal type `idx_type`.
//
// Returns {params, names}:
//   params - one "const device <idx_type> *idx<k> [[buffer(<20 + k>)]],"
//            declaration per index array, joined by '\n' (every declaration
//            keeps its trailing comma).
//   names  - "idx0,idx1,...,idx<nidx-1>".
// Both strings are empty when nidx <= 0. Index buffers start at slot 20,
// matching the set_input_array(inputs[i + 1], 20 + i) calls in eval_gpu.
std::pair<std::string, std::string> make_index_args(
    const std::string& idx_type,
    int nidx) {
  std::string params;
  std::string names;
  for (int k = 0; k < nidx; ++k) {
    if (k > 0) {
      params += '\n';
      names += ',';
    }
    const std::string name = "idx" + std::to_string(k);
    params += "const device " + idx_type + " *" + name + " [[buffer(" +
        std::to_string(20 + k) + ")]],";
    names += name;
  }
  return {params, names};
}
38
+
39
+ void Gather::eval_gpu(const std::vector<array>& inputs, array& out) {
40
+ auto& src = inputs[0];
41
+ int nidx = inputs.size() - 1;
42
+
43
+ if (nidx > METAL_MAX_INDEX_ARRAYS) {
44
+ std::ostringstream msg;
45
+ msg << "[Gather::eval_gpu] Gathering with more than "
46
+ << METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
47
+ throw std::runtime_error(msg.str());
48
+ }
49
+
50
+ out.set_data(allocator::malloc(out.nbytes()));
51
+ if (out.size() == 0) {
52
+ return;
53
+ }
54
+
55
+ auto& s = stream();
56
+ auto& d = metal::device(s.device);
57
+
58
+ size_t slice_size = 1;
59
+ for (auto s : slice_sizes_) {
60
+ slice_size *= s;
61
+ }
62
+
63
+ bool large_index = nidx && inputs[1].size() > INT32_MAX;
64
+ bool large_src = src.size() > INT32_MAX;
65
+ bool large_out = out.size() > INT32_MAX;
66
+ bool large = large_index || large_src || large_out;
67
+
68
+ std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
69
+
70
+ if (src.flags().row_contiguous && nidx == 1 && axes_[0] == 0 &&
71
+ inputs[1].flags().row_contiguous && slice_size == src.strides()[0]) {
72
+ int work_per_thread = (slice_size > 8 && src.dtype().size() < 4) ? 2 : 1;
73
+ auto& indices = inputs[1];
74
+ std::string kernel_name = fmt::format(
75
+ "gather_front{0}_{1}_{2}_{3}",
76
+ type_to_name(out),
77
+ idx_type_name,
78
+ large ? "int64_t" : "int",
79
+ work_per_thread);
80
+ std::string lib_name = kernel_name;
81
+
82
+ auto lib = d.get_library(lib_name, [&]() {
83
+ std::string kernel_source = metal::utils();
84
+ kernel_source += metal::gather_front();
85
+ kernel_source += get_template_definition(
86
+ kernel_name,
87
+ "gather_front",
88
+ get_type_string(out.dtype()),
89
+ get_type_string(indices.dtype()),
90
+ large ? "int64_t" : "int",
91
+ work_per_thread);
92
+
93
+ return kernel_source;
94
+ });
95
+
96
+ auto& compute_encoder = d.get_command_encoder(s.index);
97
+ auto kernel = d.get_kernel(kernel_name, lib);
98
+ compute_encoder.set_compute_pipeline_state(kernel);
99
+
100
+ size_t dim_x = (slice_size + work_per_thread - 1) / work_per_thread;
101
+ size_t dim_y = indices.size();
102
+ auto group_dims = get_block_dims(dim_x, dim_y, 1);
103
+ MTL::Size grid_dims = MTL::Size(dim_x, dim_y, 1);
104
+
105
+ compute_encoder.set_input_array(src, 0);
106
+ compute_encoder.set_input_array(indices, 1);
107
+ compute_encoder.set_output_array(out, 2);
108
+ compute_encoder.set_bytes(slice_size, 3);
109
+ compute_encoder.set_bytes(src.shape(0), 4);
110
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
111
+
112
+ return;
113
+ }
114
+
115
+ int idx_ndim = nidx ? inputs[1].ndim() : 0;
116
+ size_t ndim = src.ndim();
117
+
118
+ std::string kernel_name = fmt::format(
119
+ "gather{0}{1}_{2}_{3}_{4}",
120
+ type_to_name(out),
121
+ idx_type_name,
122
+ nidx,
123
+ idx_ndim,
124
+ large ? "int64_t" : "int");
125
+ std::string lib_name = kernel_name;
126
+
127
+ auto lib = d.get_library(lib_name, [&]() {
128
+ std::string kernel_source = metal::utils();
129
+ kernel_source += metal::gather();
130
+ std::string out_type_str = get_type_string(out.dtype());
131
+ std::string idx_type_str =
132
+ nidx ? get_type_string(inputs[1].dtype()) : "bool";
133
+ auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
134
+
135
+ // Index dimension specializations
136
+ kernel_source += fmt::format(
137
+ gather_kernels,
138
+ type_to_name(out) + idx_type_name,
139
+ out_type_str,
140
+ idx_type_str,
141
+ nidx,
142
+ idx_args,
143
+ idx_arr,
144
+ idx_ndim,
145
+ large ? "int64_t" : "int");
146
+ return kernel_source;
147
+ });
148
+
149
+ auto& compute_encoder = d.get_command_encoder(s.index);
150
+ auto kernel = d.get_kernel(kernel_name, lib);
151
+ compute_encoder.set_compute_pipeline_state(kernel);
152
+
153
+ // Launch 3D grid of threads
154
+ // First two dimensions for the indices, the last one for the slice
155
+ size_t dim0 = 1;
156
+ size_t dim1 = 1;
157
+ if (nidx) {
158
+ if (inputs[1].ndim() >= 1) {
159
+ dim0 = inputs[1].shape(0);
160
+ }
161
+ if (inputs[1].ndim() >= 2) {
162
+ dim1 = inputs[1].size() / dim0;
163
+ }
164
+ }
165
+ size_t dim2 = slice_size;
166
+ auto group_dims = get_block_dims(dim0, dim1, dim2);
167
+ MTL::Size grid_dims = MTL::Size(dim0, dim1, dim2);
168
+
169
+ // Collect all idx shapes and strides into one place
170
+ std::vector<int> idx_shapes;
171
+ std::vector<size_t> idx_strides;
172
+ std::vector<char> idx_contigs;
173
+ for (int i = 0; i < nidx; ++i) {
174
+ idx_shapes.insert(
175
+ idx_shapes.end(),
176
+ inputs[i + 1].shape().begin(),
177
+ inputs[i + 1].shape().end());
178
+ idx_strides.insert(
179
+ idx_strides.end(),
180
+ inputs[i + 1].strides().begin(),
181
+ inputs[i + 1].strides().end());
182
+ idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
183
+ }
184
+
185
+ // Set all the buffers
186
+ compute_encoder.set_input_array(src, 0);
187
+ compute_encoder.set_output_array(out, 1);
188
+
189
+ // Set source info
190
+ compute_encoder.set_vector_bytes(src.shape(), 2);
191
+ compute_encoder.set_vector_bytes(src.strides(), 3);
192
+ compute_encoder.set_bytes(ndim, 4);
193
+ compute_encoder.set_vector_bytes(slice_sizes_, 5);
194
+ compute_encoder.set_vector_bytes(axes_, 6);
195
+
196
+ // Set index info
197
+ //
198
+ // We don't need to check for empty idx_shapes because gather has a
199
+ // idx_ndim == 0 specialization
200
+ compute_encoder.set_vector_bytes(idx_shapes, 7);
201
+ compute_encoder.set_vector_bytes(idx_strides, 8);
202
+ compute_encoder.set_vector_bytes(idx_contigs, 9);
203
+ compute_encoder.set_bytes(idx_ndim, 10);
204
+
205
+ // Set index buffers
206
+ for (int i = 0; i < nidx; ++i) {
207
+ compute_encoder.set_input_array(inputs[i + 1], 20 + i);
208
+ }
209
+
210
+ // Launch grid
211
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
212
+ }
213
+
214
+ void Scatter::eval_gpu(const std::vector<array>& inputs, array& out) {
215
+ if (size_of(out.dtype()) == 8) {
216
+ std::ostringstream msg;
217
+ msg << "[Scatter::eval_gpu] Does not support " << out.dtype();
218
+ throw std::invalid_argument(msg.str());
219
+ }
220
+
221
+ int nidx = axes_.size();
222
+ if (nidx > METAL_MAX_INDEX_ARRAYS) {
223
+ std::ostringstream msg;
224
+ msg << "[Scatter::eval_gpu] Gathering with more than "
225
+ << METAL_MAX_INDEX_ARRAYS << " index arrays not yet supported.";
226
+ throw std::runtime_error(msg.str());
227
+ }
228
+
229
+ // Copy src into out
230
+ CopyType copy_type;
231
+ if (inputs[0].data_size() == 1) {
232
+ copy_type = CopyType::Scalar;
233
+ } else if (inputs[0].flags().row_contiguous) {
234
+ copy_type = CopyType::Vector;
235
+ } else {
236
+ copy_type = CopyType::General;
237
+ }
238
+ copy_gpu(inputs[0], out, copy_type);
239
+
240
+ auto& upd = inputs.back();
241
+
242
+ // Empty update
243
+ if (upd.size() == 0) {
244
+ return;
245
+ }
246
+
247
+ // Get stream
248
+ auto& s = stream();
249
+ auto& d = metal::device(s.device);
250
+
251
+ int idx_ndim = nidx ? inputs[1].ndim() : 0;
252
+ size_t idx_size = nidx ? inputs[1].size() : 1;
253
+
254
+ auto idx_to_out = idx_size / out.size();
255
+ int nwork;
256
+ if (idx_ndim <= 1 || idx_to_out < 1) {
257
+ nwork = 1;
258
+ } else if (idx_to_out <= 4) {
259
+ nwork = 4;
260
+ } else if (idx_to_out < 16) {
261
+ nwork = 8;
262
+ } else if (idx_to_out < 32) {
263
+ nwork = 16;
264
+ } else {
265
+ nwork = 32;
266
+ }
267
+
268
+ std::string idx_type_name = nidx ? type_to_name(inputs[1]) : "";
269
+ std::string op_name;
270
+ switch (reduce_type_) {
271
+ case Scatter::None:
272
+ op_name = "none";
273
+ break;
274
+ case Scatter::Sum:
275
+ op_name = "sum";
276
+ break;
277
+ case Scatter::Prod:
278
+ op_name = "prod";
279
+ break;
280
+ case Scatter::Max:
281
+ op_name = "max";
282
+ break;
283
+ case Scatter::Min:
284
+ op_name = "min";
285
+ break;
286
+ }
287
+ auto upd_contig = upd.flags().row_contiguous;
288
+ bool large_out = out.size() > INT32_MAX;
289
+ bool large_idx = nidx && (inputs[1].size() > INT32_MAX);
290
+ bool large_upd = upd.size() > INT32_MAX;
291
+ bool large = large_out || large_idx || large_upd;
292
+ std::string kernel_name = fmt::format(
293
+ "scatter{0}{1}_{2}_{3}_{4}_nwork{5}_{6}",
294
+ type_to_name(out),
295
+ idx_type_name,
296
+ op_name,
297
+ nidx,
298
+ upd_contig ? "updc_true" : "updc_false",
299
+ nwork,
300
+ large ? "int64_t" : "int");
301
+ std::string lib_name = kernel_name;
302
+
303
+ auto lib = d.get_library(lib_name, [&]() {
304
+ std::string kernel_source = metal::utils();
305
+ concatenate(kernel_source, metal::reduce_utils(), metal::scatter());
306
+
307
+ std::string out_type_str = get_type_string(out.dtype());
308
+ std::string idx_type_str =
309
+ nidx ? get_type_string(inputs[1].dtype()) : "bool";
310
+ std::string op_type;
311
+ switch (reduce_type_) {
312
+ case Scatter::None:
313
+ op_type = "None";
314
+ break;
315
+ case Scatter::Sum:
316
+ op_type = "Sum<{0}>";
317
+ break;
318
+ case Scatter::Prod:
319
+ op_type = "Prod<{0}>";
320
+ break;
321
+ case Scatter::Max:
322
+ op_type = "Max<{0}>";
323
+ break;
324
+ case Scatter::Min:
325
+ op_type = "Min<{0}>";
326
+ break;
327
+ }
328
+ if (reduce_type_ != Scatter::None) {
329
+ op_type = fmt::format(fmt::runtime(op_type), out_type_str);
330
+ }
331
+ auto [idx_args, idx_arr] = make_index_args(idx_type_str, nidx);
332
+
333
+ kernel_source += fmt::format(
334
+ scatter_kernels,
335
+ type_to_name(out) + idx_type_name + "_" + op_name,
336
+ out_type_str,
337
+ idx_type_str,
338
+ op_type,
339
+ nidx,
340
+ idx_args,
341
+ idx_arr,
342
+ upd_contig,
343
+ nwork,
344
+ large ? "int64_t" : "int");
345
+ return kernel_source;
346
+ });
347
+
348
+ auto& compute_encoder = d.get_command_encoder(s.index);
349
+ auto kernel = d.get_kernel(kernel_name, lib);
350
+
351
+ size_t nthreads = upd.size();
352
+
353
+ compute_encoder.set_compute_pipeline_state(kernel);
354
+
355
+ // Set all the buffers
356
+ compute_encoder.set_input_array(upd, 1);
357
+ compute_encoder.set_output_array(out, 2);
358
+
359
+ // Set update info
360
+ size_t upd_ndim = upd.ndim();
361
+ size_t upd_size = 1;
362
+ for (int i = idx_ndim; i < upd.ndim(); ++i) {
363
+ upd_size *= upd.shape(i);
364
+ }
365
+ // Collect all idx shapes and strides into one place
366
+ Shape idx_shapes;
367
+ Strides idx_strides;
368
+ // To access .data() use char instead of bool
369
+ // bool is 1 byte in Metal so this is safe
370
+ std::vector<char> idx_contigs;
371
+ for (int i = 0; i < nidx; ++i) {
372
+ idx_shapes.insert(
373
+ idx_shapes.end(),
374
+ inputs[i + 1].shape().begin(),
375
+ inputs[i + 1].shape().end());
376
+ idx_strides.insert(
377
+ idx_strides.end(),
378
+ inputs[i + 1].strides().begin(),
379
+ inputs[i + 1].strides().end());
380
+ idx_contigs.push_back(inputs[i + 1].flags().row_contiguous);
381
+ }
382
+
383
+ if (upd_ndim == 0) {
384
+ // Need placeholders so Metal doesn't complain
385
+ int shape_ = 0;
386
+ int64_t stride_ = 0;
387
+ compute_encoder.set_bytes(shape_, 3);
388
+ compute_encoder.set_bytes(stride_, 4);
389
+ } else {
390
+ compute_encoder.set_vector_bytes(upd.shape(), 3);
391
+ compute_encoder.set_vector_bytes(upd.strides(), 4);
392
+ }
393
+ compute_encoder.set_bytes(upd_ndim, 5);
394
+ compute_encoder.set_bytes(upd_size, 6);
395
+
396
+ // Set output info
397
+ size_t out_ndim = out.ndim();
398
+ if (out_ndim == 0) {
399
+ // Need placeholders so Metal doesn't complain
400
+ int shape_ = 0;
401
+ int64_t stride_ = 0;
402
+ compute_encoder.set_bytes(shape_, 7);
403
+ compute_encoder.set_bytes(stride_, 8);
404
+ } else {
405
+ compute_encoder.set_vector_bytes(out.shape(), 7);
406
+ compute_encoder.set_vector_bytes(out.strides(), 8);
407
+ }
408
+ compute_encoder.set_bytes(out_ndim, 9);
409
+ compute_encoder.set_vector_bytes(axes_, 10);
410
+
411
+ // Set index info
412
+ if (idx_ndim == 0) {
413
+ // Add a 0 in idx_shapes and strides to avoid the missing buffer binding
414
+ // error in the metal API.
415
+ idx_shapes.push_back(0);
416
+ idx_strides.push_back(0);
417
+ idx_contigs.push_back(false);
418
+ }
419
+ compute_encoder.set_vector_bytes(idx_shapes, 11);
420
+ compute_encoder.set_vector_bytes(idx_strides, 12);
421
+ compute_encoder.set_vector_bytes(idx_contigs, 13);
422
+ compute_encoder.set_bytes(idx_ndim, 14);
423
+ compute_encoder.set_bytes(idx_size, 15);
424
+
425
+ // Set index buffers
426
+ for (int i = 0; i < nidx; ++i) {
427
+ compute_encoder.set_input_array(inputs[i + 1], 20 + i);
428
+ }
429
+
430
+ // Launch grid
431
+ auto grid_y = (nthreads / upd_size);
432
+ grid_y = (grid_y + nwork - 1) / nwork;
433
+ MTL::Size grid_dims = MTL::Size(upd_size, grid_y, 1);
434
+ auto thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
435
+ if (thread_group_size != 1024) {
436
+ throw std::runtime_error("[Scatter::eval_gpu] Invalid number of threads");
437
+ }
438
+ MTL::Size group_dims = get_block_dims(upd_size, grid_y, 1);
439
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
440
+ }
441
+
442
+ void GatherAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
443
+ auto& src = inputs[0];
444
+ auto& idx = inputs[1];
445
+
446
+ out.set_data(allocator::malloc(out.nbytes()));
447
+ if (out.size() == 0) {
448
+ return;
449
+ }
450
+
451
+ auto& s = stream();
452
+ auto& d = metal::device(s.device);
453
+
454
+ size_t ndim = src.ndim();
455
+
456
+ bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
457
+
458
+ std::string kernel_name = fmt::format(
459
+ "gather_axis{0}{1}_{2}",
460
+ type_to_name(out),
461
+ type_to_name(idx),
462
+ large ? "int64_t" : "int");
463
+ std::string lib_name = kernel_name;
464
+ kernel_name += src.flags().row_contiguous ? "c" : "nc";
465
+ kernel_name += idx.flags().row_contiguous ? "c" : "nc";
466
+
467
+ auto lib = d.get_library(lib_name, [&]() {
468
+ std::string kernel_source = metal::utils();
469
+ kernel_source += metal::gather_axis();
470
+ std::string out_type_str = get_type_string(out.dtype());
471
+ std::string idx_type_str = get_type_string(idx.dtype());
472
+ for (int i = 0; i < 4; ++i) {
473
+ bool sc = i & 1;
474
+ bool ic = i & 2;
475
+ kernel_source += get_template_definition(
476
+ lib_name + (sc ? "c" : "nc") + (ic ? "c" : "nc"),
477
+ "gather_axis",
478
+ out_type_str,
479
+ idx_type_str,
480
+ large ? "int64_t" : "int",
481
+ sc ? "true" : "false",
482
+ ic ? "true" : "false");
483
+ }
484
+ return kernel_source;
485
+ });
486
+
487
+ auto& compute_encoder = d.get_command_encoder(s.index);
488
+ auto kernel = d.get_kernel(kernel_name, lib);
489
+ compute_encoder.set_compute_pipeline_state(kernel);
490
+
491
+ // Grid [size post, index size, size pre]
492
+ size_t size_pre = 1;
493
+ size_t size_post = 1;
494
+ for (int i = 0; i < axis_; ++i) {
495
+ size_pre *= idx.shape(i);
496
+ }
497
+ for (int i = axis_ + 1; i < idx.ndim(); ++i) {
498
+ size_post *= idx.shape(i);
499
+ }
500
+
501
+ int idx_ax_size = idx.shape(axis_);
502
+ auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre);
503
+ MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre);
504
+
505
+ // Set all the buffers
506
+ compute_encoder.set_input_array(src, 0);
507
+ compute_encoder.set_input_array(idx, 1);
508
+ compute_encoder.set_output_array(out, 2);
509
+
510
+ // Set source info
511
+ compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
512
+ compute_encoder.set_vector_bytes(remove_index(src.strides(), axis_), 4);
513
+ compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
514
+ compute_encoder.set_bytes(ndim - 1, 6);
515
+ compute_encoder.set_bytes(axis_, 7);
516
+ compute_encoder.set_bytes(src.shape(axis_), 8);
517
+ compute_encoder.set_bytes(src.strides(axis_), 9);
518
+ compute_encoder.set_bytes(idx.strides(axis_), 10);
519
+
520
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
521
+ }
522
+
523
+ void ScatterAxis::eval_gpu(const std::vector<array>& inputs, array& out) {
524
+ auto& src = inputs[0];
525
+ auto& idx = inputs[1];
526
+ auto& upd = inputs[2];
527
+
528
+ // Copy src into out
529
+ CopyType copy_type;
530
+ if (src.data_size() == 1) {
531
+ copy_type = CopyType::Scalar;
532
+ } else if (src.flags().row_contiguous) {
533
+ copy_type = CopyType::Vector;
534
+ } else {
535
+ copy_type = CopyType::General;
536
+ }
537
+ copy_gpu(src, out, copy_type);
538
+
539
+ // Empty update
540
+ if (upd.size() == 0) {
541
+ return;
542
+ }
543
+
544
+ auto& s = stream();
545
+ auto& d = metal::device(s.device);
546
+
547
+ size_t ndim = src.ndim();
548
+
549
+ bool large = idx.size() > INT32_MAX || src.size() > INT32_MAX;
550
+
551
+ std::string op_name;
552
+ switch (reduce_type_) {
553
+ case ScatterAxis::None:
554
+ op_name = "none";
555
+ break;
556
+ case ScatterAxis::Sum:
557
+ op_name = "sum";
558
+ break;
559
+ }
560
+
561
+ std::string kernel_name = fmt::format(
562
+ "scatter_axis{0}{1}_{2}_{3}",
563
+ type_to_name(out),
564
+ type_to_name(idx),
565
+ op_name,
566
+ large ? "int64_t" : "int");
567
+ std::string lib_name = kernel_name;
568
+ kernel_name += upd.flags().row_contiguous ? "c" : "nc";
569
+ kernel_name += idx.flags().row_contiguous ? "c" : "nc";
570
+
571
+ auto lib = d.get_library(lib_name, [&]() {
572
+ std::string kernel_source = metal::utils();
573
+ kernel_source += metal::reduce_utils();
574
+ kernel_source += metal::scatter_axis();
575
+ std::string out_type_str = get_type_string(out.dtype());
576
+ std::string idx_type_str = get_type_string(idx.dtype());
577
+ std::string op_type;
578
+ switch (reduce_type_) {
579
+ case ScatterAxis::None:
580
+ op_type = "None";
581
+ break;
582
+ case ScatterAxis::Sum:
583
+ op_type = "Sum<" + out_type_str + ">";
584
+ break;
585
+ }
586
+
587
+ for (int i = 0; i < 4; ++i) {
588
+ bool uc = i & 1;
589
+ bool ic = i & 2;
590
+ kernel_source += get_template_definition(
591
+ lib_name + (uc ? "c" : "nc") + (ic ? "c" : "nc"),
592
+ "scatter_axis",
593
+ out_type_str,
594
+ idx_type_str,
595
+ large ? "int64_t" : "int",
596
+ op_type,
597
+ uc ? "true" : "false",
598
+ ic ? "true" : "false");
599
+ }
600
+ return kernel_source;
601
+ });
602
+
603
+ auto& compute_encoder = d.get_command_encoder(s.index);
604
+ auto kernel = d.get_kernel(kernel_name, lib);
605
+ compute_encoder.set_compute_pipeline_state(kernel);
606
+
607
+ // Grid [size post, index size, size pre]
608
+ size_t size_pre = 1;
609
+ size_t size_post = 1;
610
+ for (int i = 0; i < axis_; ++i) {
611
+ size_pre *= idx.shape(i);
612
+ }
613
+ for (int i = axis_ + 1; i < idx.ndim(); ++i) {
614
+ size_post *= idx.shape(i);
615
+ }
616
+
617
+ int idx_ax_size = idx.shape(axis_);
618
+ auto group_dims = get_block_dims(size_post, idx_ax_size, size_pre);
619
+ MTL::Size grid_dims = MTL::Size(size_post, idx_ax_size, size_pre);
620
+
621
+ // Set all the buffers
622
+ compute_encoder.set_input_array(upd, 0);
623
+ compute_encoder.set_input_array(idx, 1);
624
+ compute_encoder.set_output_array(out, 2);
625
+
626
+ // Set source info
627
+ if (ndim > 1) {
628
+ compute_encoder.set_vector_bytes(remove_index(idx.shape(), axis_), 3);
629
+ compute_encoder.set_vector_bytes(remove_index(upd.strides(), axis_), 4);
630
+ compute_encoder.set_vector_bytes(remove_index(idx.strides(), axis_), 5);
631
+ } else {
632
+ // The following will be ignored in the kernel but we still have to set
633
+ // some value so that metal validation passes.
634
+ compute_encoder.set_vector_bytes(idx.shape(), 3);
635
+ compute_encoder.set_vector_bytes(upd.strides(), 4);
636
+ compute_encoder.set_vector_bytes(idx.strides(), 5);
637
+ }
638
+ compute_encoder.set_bytes(ndim - 1, 6);
639
+ compute_encoder.set_bytes(axis_, 7);
640
+ compute_encoder.set_bytes(out.shape(axis_), 8);
641
+ compute_encoder.set_bytes(upd.strides(axis_), 9);
642
+ compute_encoder.set_bytes(idx.strides(axis_), 10);
643
+
644
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
645
+ }
646
+
647
+ void MaskedScatter::eval_gpu(const std::vector<array>& inputs, array& out) {
648
+ const array& dst = inputs[0];
649
+ const array& mask = inputs[1];
650
+ const array& src = inputs[2];
651
+
652
+ auto& s = stream();
653
+ auto& d = metal::device(s.device);
654
+
655
+ const size_t total = mask.size();
656
+ const CopyType ct = (total == 1)
657
+ ? CopyType::Scalar
658
+ : (dst.flags().row_contiguous ? CopyType::Vector : CopyType::General);
659
+ copy_gpu(dst, out, ct, s);
660
+ if (total == 0) {
661
+ return;
662
+ }
663
+
664
+ array mask_flat = flatten_in_eval(mask, 1, -1, s);
665
+ if (mask_flat.data<void>() != mask.data<void>()) {
666
+ d.add_temporary(mask_flat, s.index);
667
+ }
668
+
669
+ if (!mask_flat.flags().row_contiguous) {
670
+ mask_flat = contiguous_copy_gpu(mask_flat, s);
671
+ d.add_temporary(mask_flat, s.index);
672
+ }
673
+
674
+ // Prefix (exclusive) of mask → scatter_offsets
675
+ array scatter_offsets(mask_flat.shape(), uint32, nullptr, {});
676
+ scatter_offsets.set_data(allocator::malloc(scatter_offsets.nbytes()));
677
+ d.add_temporary(scatter_offsets, s.index);
678
+
679
+ scan_gpu_inplace(
680
+ mask_flat,
681
+ scatter_offsets,
682
+ Scan::Sum,
683
+ /*axis=*/1,
684
+ /*reverse=*/false,
685
+ /*inclusive=*/false,
686
+ s);
687
+
688
+ // Kernel selection/build
689
+ static constexpr std::string_view kBaseName = "masked_assign";
690
+ const std::string dtype_tag = type_to_name(out.dtype());
691
+ const std::string value_type = get_type_string(out.dtype());
692
+ const std::string contiguous =
693
+ (src.flags().row_contiguous) ? "true" : "false";
694
+ const std::string kernel_name =
695
+ fmt::format("{}_{}_{}", kBaseName, dtype_tag, contiguous);
696
+
697
+ auto lib = d.get_library(kernel_name, [&]() {
698
+ std::string source = metal::utils();
699
+ source += metal::masked_scatter();
700
+ source +=
701
+ fmt::format(masked_assign_kernel, kernel_name, value_type, contiguous);
702
+ return source;
703
+ });
704
+ auto kernel = d.get_kernel(kernel_name, lib);
705
+
706
+ // Binding
707
+ int bind_idx = 0;
708
+ const int ndim = static_cast<int>(src.ndim());
709
+ auto& compute_encoder = d.get_command_encoder(s.index);
710
+ compute_encoder.set_compute_pipeline_state(kernel);
711
+ compute_encoder.set_input_array(mask_flat, bind_idx++);
712
+ compute_encoder.set_input_array(scatter_offsets, bind_idx++);
713
+ compute_encoder.set_input_array(src, bind_idx++);
714
+ compute_encoder.set_output_array(out, bind_idx++);
715
+ compute_encoder.set_vector_bytes(src.shape(), bind_idx++);
716
+ compute_encoder.set_vector_bytes(src.strides(), bind_idx++);
717
+ compute_encoder.set_bytes(ndim, bind_idx++);
718
+ compute_encoder.set_bytes(src.size() / src.shape(0), bind_idx++);
719
+ compute_encoder.set_bytes(mask_flat.size() / mask.shape(0), bind_idx++);
720
+
721
+ // Dispatch
722
+ auto group_dims = get_block_dims(total, 1, 1);
723
+ MTL::Size grid_dims(total, 1, 1);
724
+ compute_encoder.dispatch_threads(grid_dims, group_dims);
725
+ }
726
+
727
+ } // namespace mlx::core