mlx 1.0.0

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlx might be problematic. Click here for more details.

Files changed (914):
  1. checksums.yaml +7 -0
  2. data/ext/mlx/CMakeLists.txt +7 -0
  3. data/ext/mlx/Makefile +273 -0
  4. data/ext/mlx/extconf.rb +94 -0
  5. data/ext/mlx/mkmf.log +44 -0
  6. data/ext/mlx/native.bundle +0 -0
  7. data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
  8. data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
  9. data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
  10. data/ext/mlx/native.cpp +8027 -0
  11. data/ext/mlx/native.o +0 -0
  12. data/lib/mlx/core.rb +1678 -0
  13. data/lib/mlx/distributed_utils/common.rb +116 -0
  14. data/lib/mlx/distributed_utils/config.rb +600 -0
  15. data/lib/mlx/distributed_utils/launch.rb +490 -0
  16. data/lib/mlx/extension.rb +24 -0
  17. data/lib/mlx/nn/base.rb +388 -0
  18. data/lib/mlx/nn/init.rb +140 -0
  19. data/lib/mlx/nn/layers/activations.rb +336 -0
  20. data/lib/mlx/nn/layers/base.rb +6 -0
  21. data/lib/mlx/nn/layers/containers.rb +20 -0
  22. data/lib/mlx/nn/layers/convolution.rb +120 -0
  23. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  24. data/lib/mlx/nn/layers/distributed.rb +309 -0
  25. data/lib/mlx/nn/layers/dropout.rb +75 -0
  26. data/lib/mlx/nn/layers/embedding.rb +28 -0
  27. data/lib/mlx/nn/layers/linear.rb +79 -0
  28. data/lib/mlx/nn/layers/normalization.rb +216 -0
  29. data/lib/mlx/nn/layers/pooling.rb +167 -0
  30. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  31. data/lib/mlx/nn/layers/quantized.rb +215 -0
  32. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  33. data/lib/mlx/nn/layers/transformer.rb +330 -0
  34. data/lib/mlx/nn/layers/upsample.rb +97 -0
  35. data/lib/mlx/nn/layers.rb +18 -0
  36. data/lib/mlx/nn/losses.rb +251 -0
  37. data/lib/mlx/nn/utils.rb +167 -0
  38. data/lib/mlx/nn.rb +12 -0
  39. data/lib/mlx/optimizers/optimizers.rb +808 -0
  40. data/lib/mlx/optimizers/schedulers.rb +62 -0
  41. data/lib/mlx/optimizers.rb +9 -0
  42. data/lib/mlx/utils.rb +171 -0
  43. data/lib/mlx/version +1 -0
  44. data/lib/mlx/version.rb +5 -0
  45. data/lib/mlx.rb +64 -0
  46. data/mlx/.clang-format +87 -0
  47. data/mlx/.git +1 -0
  48. data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
  49. data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
  50. data/mlx/.github/actions/build-docs/action.yml +38 -0
  51. data/mlx/.github/actions/build-linux/action.yml +38 -0
  52. data/mlx/.github/actions/build-linux-release/action.yml +42 -0
  53. data/mlx/.github/actions/build-macos/action.yml +80 -0
  54. data/mlx/.github/actions/build-macos-release/action.yml +36 -0
  55. data/mlx/.github/actions/build-windows/action.yml +26 -0
  56. data/mlx/.github/actions/setup-linux/action.yml +93 -0
  57. data/mlx/.github/actions/setup-macos/action.yml +24 -0
  58. data/mlx/.github/actions/setup-windows/action.yml +42 -0
  59. data/mlx/.github/actions/test-linux/action.yml +69 -0
  60. data/mlx/.github/actions/test-windows/action.yml +20 -0
  61. data/mlx/.github/dependabot.yml +6 -0
  62. data/mlx/.github/pull_request_template.md +12 -0
  63. data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
  64. data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
  65. data/mlx/.github/workflows/build_and_test.yml +152 -0
  66. data/mlx/.github/workflows/documentation.yml +28 -0
  67. data/mlx/.github/workflows/nightly.yml +104 -0
  68. data/mlx/.github/workflows/release.yml +256 -0
  69. data/mlx/.gitignore +81 -0
  70. data/mlx/.pre-commit-config.yaml +27 -0
  71. data/mlx/ACKNOWLEDGMENTS.md +268 -0
  72. data/mlx/CITATION.cff +24 -0
  73. data/mlx/CMakeLists.txt +437 -0
  74. data/mlx/CODE_OF_CONDUCT.md +132 -0
  75. data/mlx/CONTRIBUTING.md +38 -0
  76. data/mlx/LICENSE +21 -0
  77. data/mlx/MANIFEST.in +6 -0
  78. data/mlx/README.md +121 -0
  79. data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
  80. data/mlx/benchmarks/cpp/autograd.cpp +39 -0
  81. data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
  82. data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
  83. data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
  84. data/mlx/benchmarks/cpp/time_utils.h +39 -0
  85. data/mlx/benchmarks/numpy/single_ops.py +39 -0
  86. data/mlx/benchmarks/numpy/time_utils.py +20 -0
  87. data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
  88. data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
  89. data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
  90. data/mlx/benchmarks/python/comparative/README.md +15 -0
  91. data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
  92. data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
  93. data/mlx/benchmarks/python/comparative/compare.py +284 -0
  94. data/mlx/benchmarks/python/compile_bench.py +107 -0
  95. data/mlx/benchmarks/python/conv1d_bench.py +123 -0
  96. data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
  97. data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
  98. data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
  99. data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
  100. data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
  101. data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
  102. data/mlx/benchmarks/python/conv_bench.py +135 -0
  103. data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
  104. data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
  105. data/mlx/benchmarks/python/distributed_bench.py +66 -0
  106. data/mlx/benchmarks/python/einsum_bench.py +84 -0
  107. data/mlx/benchmarks/python/fft_bench.py +118 -0
  108. data/mlx/benchmarks/python/gather_bench.py +52 -0
  109. data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
  110. data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
  111. data/mlx/benchmarks/python/hadamard_bench.py +70 -0
  112. data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
  113. data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
  114. data/mlx/benchmarks/python/masked_scatter.py +212 -0
  115. data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
  116. data/mlx/benchmarks/python/rope_bench.py +35 -0
  117. data/mlx/benchmarks/python/scatter_bench.py +96 -0
  118. data/mlx/benchmarks/python/sdpa_bench.py +223 -0
  119. data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
  120. data/mlx/benchmarks/python/single_ops.py +132 -0
  121. data/mlx/benchmarks/python/synchronize_bench.py +55 -0
  122. data/mlx/benchmarks/python/time_utils.py +38 -0
  123. data/mlx/cmake/FindCUDNN.cmake +177 -0
  124. data/mlx/cmake/FindNCCL.cmake +54 -0
  125. data/mlx/cmake/Findnvpl.cmake +3 -0
  126. data/mlx/cmake/extension.cmake +50 -0
  127. data/mlx/docs/.clang-format +2 -0
  128. data/mlx/docs/.gitignore +3 -0
  129. data/mlx/docs/.nojekyll +0 -0
  130. data/mlx/docs/Doxyfile +51 -0
  131. data/mlx/docs/Makefile +18 -0
  132. data/mlx/docs/README.md +54 -0
  133. data/mlx/docs/index.html +1 -0
  134. data/mlx/docs/requirements.txt +5 -0
  135. data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
  136. data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
  137. data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
  138. data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
  139. data/mlx/docs/src/_static/mlx_logo.png +0 -0
  140. data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
  141. data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
  142. data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
  143. data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
  144. data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
  145. data/mlx/docs/src/_templates/module-base-class.rst +33 -0
  146. data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
  147. data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
  148. data/mlx/docs/src/conf.py +99 -0
  149. data/mlx/docs/src/cpp/ops.rst +7 -0
  150. data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
  151. data/mlx/docs/src/dev/extensions.rst +811 -0
  152. data/mlx/docs/src/dev/metal_debugger.rst +68 -0
  153. data/mlx/docs/src/dev/metal_logging.rst +40 -0
  154. data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
  155. data/mlx/docs/src/examples/data_parallelism.rst +91 -0
  156. data/mlx/docs/src/examples/linear_regression.rst +77 -0
  157. data/mlx/docs/src/examples/llama-inference.rst +382 -0
  158. data/mlx/docs/src/examples/mlp.rst +134 -0
  159. data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
  160. data/mlx/docs/src/index.rst +96 -0
  161. data/mlx/docs/src/install.rst +340 -0
  162. data/mlx/docs/src/python/array.rst +65 -0
  163. data/mlx/docs/src/python/cuda.rst +9 -0
  164. data/mlx/docs/src/python/data_types.rst +78 -0
  165. data/mlx/docs/src/python/devices_and_streams.rst +21 -0
  166. data/mlx/docs/src/python/distributed.rst +22 -0
  167. data/mlx/docs/src/python/export.rst +14 -0
  168. data/mlx/docs/src/python/fast.rst +16 -0
  169. data/mlx/docs/src/python/fft.rst +24 -0
  170. data/mlx/docs/src/python/linalg.rst +27 -0
  171. data/mlx/docs/src/python/memory_management.rst +16 -0
  172. data/mlx/docs/src/python/metal.rst +12 -0
  173. data/mlx/docs/src/python/nn/distributed.rst +30 -0
  174. data/mlx/docs/src/python/nn/functions.rst +40 -0
  175. data/mlx/docs/src/python/nn/init.rst +45 -0
  176. data/mlx/docs/src/python/nn/layers.rst +74 -0
  177. data/mlx/docs/src/python/nn/losses.rst +25 -0
  178. data/mlx/docs/src/python/nn/module.rst +38 -0
  179. data/mlx/docs/src/python/nn.rst +186 -0
  180. data/mlx/docs/src/python/ops.rst +184 -0
  181. data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
  182. data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
  183. data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
  184. data/mlx/docs/src/python/optimizers.rst +78 -0
  185. data/mlx/docs/src/python/random.rst +48 -0
  186. data/mlx/docs/src/python/transforms.rst +22 -0
  187. data/mlx/docs/src/python/tree_utils.rst +23 -0
  188. data/mlx/docs/src/usage/compile.rst +516 -0
  189. data/mlx/docs/src/usage/distributed.rst +572 -0
  190. data/mlx/docs/src/usage/export.rst +288 -0
  191. data/mlx/docs/src/usage/function_transforms.rst +191 -0
  192. data/mlx/docs/src/usage/indexing.rst +194 -0
  193. data/mlx/docs/src/usage/launching_distributed.rst +234 -0
  194. data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
  195. data/mlx/docs/src/usage/numpy.rst +124 -0
  196. data/mlx/docs/src/usage/quick_start.rst +67 -0
  197. data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
  198. data/mlx/docs/src/usage/unified_memory.rst +78 -0
  199. data/mlx/docs/src/usage/using_streams.rst +18 -0
  200. data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
  201. data/mlx/examples/cmake_project/README.md +26 -0
  202. data/mlx/examples/cmake_project/example.cpp +14 -0
  203. data/mlx/examples/cpp/CMakeLists.txt +12 -0
  204. data/mlx/examples/cpp/distributed.cpp +22 -0
  205. data/mlx/examples/cpp/linear_regression.cpp +54 -0
  206. data/mlx/examples/cpp/logistic_regression.cpp +54 -0
  207. data/mlx/examples/cpp/metal_capture.cpp +31 -0
  208. data/mlx/examples/cpp/timer.h +20 -0
  209. data/mlx/examples/cpp/tutorial.cpp +99 -0
  210. data/mlx/examples/export/CMakeLists.txt +22 -0
  211. data/mlx/examples/export/README.md +49 -0
  212. data/mlx/examples/export/eval_mlp.cpp +25 -0
  213. data/mlx/examples/export/eval_mlp.py +52 -0
  214. data/mlx/examples/export/train_mlp.cpp +35 -0
  215. data/mlx/examples/export/train_mlp.py +76 -0
  216. data/mlx/examples/extensions/CMakeLists.txt +78 -0
  217. data/mlx/examples/extensions/README.md +24 -0
  218. data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
  219. data/mlx/examples/extensions/axpby/axpby.h +90 -0
  220. data/mlx/examples/extensions/axpby/axpby.metal +47 -0
  221. data/mlx/examples/extensions/bindings.cpp +39 -0
  222. data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
  223. data/mlx/examples/extensions/pyproject.toml +8 -0
  224. data/mlx/examples/extensions/requirements.txt +4 -0
  225. data/mlx/examples/extensions/setup.py +18 -0
  226. data/mlx/examples/extensions/test.py +12 -0
  227. data/mlx/examples/python/linear_regression.py +46 -0
  228. data/mlx/examples/python/logistic_regression.py +49 -0
  229. data/mlx/examples/python/qqmm.py +117 -0
  230. data/mlx/mlx/3rdparty/.clang-format +2 -0
  231. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  232. data/mlx/mlx/CMakeLists.txt +107 -0
  233. data/mlx/mlx/allocator.h +75 -0
  234. data/mlx/mlx/api.h +29 -0
  235. data/mlx/mlx/array.cpp +354 -0
  236. data/mlx/mlx/array.h +647 -0
  237. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  238. data/mlx/mlx/backend/common/binary.h +97 -0
  239. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  240. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  241. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  242. data/mlx/mlx/backend/common/common.cpp +305 -0
  243. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  244. data/mlx/mlx/backend/common/compiled.h +77 -0
  245. data/mlx/mlx/backend/common/copy.h +50 -0
  246. data/mlx/mlx/backend/common/hadamard.h +109 -0
  247. data/mlx/mlx/backend/common/load.cpp +57 -0
  248. data/mlx/mlx/backend/common/matmul.h +67 -0
  249. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  250. data/mlx/mlx/backend/common/reduce.h +59 -0
  251. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  252. data/mlx/mlx/backend/common/slicing.h +20 -0
  253. data/mlx/mlx/backend/common/ternary.h +85 -0
  254. data/mlx/mlx/backend/common/unary.h +29 -0
  255. data/mlx/mlx/backend/common/utils.cpp +231 -0
  256. data/mlx/mlx/backend/common/utils.h +205 -0
  257. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  258. data/mlx/mlx/backend/cpu/arange.h +28 -0
  259. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  260. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  261. data/mlx/mlx/backend/cpu/binary.h +517 -0
  262. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  263. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  264. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  265. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  266. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  267. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  268. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  269. data/mlx/mlx/backend/cpu/copy.h +36 -0
  270. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  271. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  272. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  273. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  274. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  275. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  276. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  277. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  278. data/mlx/mlx/backend/cpu/eval.h +12 -0
  279. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  280. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  281. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  282. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  283. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  284. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  285. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  286. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  287. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  288. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  289. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  290. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  291. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  292. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  293. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  294. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  295. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  296. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  297. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  298. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  299. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  300. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  301. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  302. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  303. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  304. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  305. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  306. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  307. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  308. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  309. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  310. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  311. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  312. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  313. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  314. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  315. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  316. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  317. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  318. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  319. data/mlx/mlx/backend/cpu/unary.h +281 -0
  320. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  321. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  322. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  323. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  324. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  325. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  326. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  327. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  328. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  329. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  330. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  331. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  332. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  333. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  334. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  335. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  336. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  337. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  338. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  339. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  340. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  341. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  342. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  343. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  344. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  345. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  346. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  347. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  348. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  349. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  350. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  351. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  352. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  353. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  354. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  355. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  356. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  357. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  358. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  359. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  360. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  361. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  362. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  363. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  364. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  365. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  366. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  367. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  368. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  369. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  370. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  371. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  372. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  373. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  374. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  375. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  376. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  377. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  378. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  379. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  380. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  381. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  382. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  383. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  384. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  385. data/mlx/mlx/backend/cuda/device.h +195 -0
  386. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  387. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  388. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  389. data/mlx/mlx/backend/cuda/event.cu +415 -0
  390. data/mlx/mlx/backend/cuda/event.h +79 -0
  391. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  392. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  393. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  394. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  395. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  396. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  397. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  398. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  399. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  400. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  401. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  402. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  403. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  404. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  405. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  406. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  407. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  408. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  409. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  410. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  411. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  412. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  413. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  414. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  415. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  416. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  417. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  418. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  419. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  420. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  421. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  422. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  423. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  424. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  425. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  426. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  427. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  428. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  429. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  430. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  431. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  432. data/mlx/mlx/backend/cuda/random.cu +202 -0
  433. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  434. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  435. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  436. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  437. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  438. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  439. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  440. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  441. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  442. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  443. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  444. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  445. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  446. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  447. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  448. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  449. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  450. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  451. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  452. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  453. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  454. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  455. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  456. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  457. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  458. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  459. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  460. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  461. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  462. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  463. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  464. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  465. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  466. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  467. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  468. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  469. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  470. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  471. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  472. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  473. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  474. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  475. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  476. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  477. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  478. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  479. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  480. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  481. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  482. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  483. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  484. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  485. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  486. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  487. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  488. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  489. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  490. data/mlx/mlx/backend/cuda/utils.h +49 -0
  491. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  492. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  493. data/mlx/mlx/backend/cuda/worker.h +55 -0
  494. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  495. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  496. data/mlx/mlx/backend/gpu/copy.h +57 -0
  497. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  498. data/mlx/mlx/backend/gpu/eval.h +18 -0
  499. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  500. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  501. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  502. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  503. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  504. data/mlx/mlx/backend/metal/allocator.h +79 -0
  505. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  506. data/mlx/mlx/backend/metal/binary.h +33 -0
  507. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  508. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  509. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  510. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  511. data/mlx/mlx/backend/metal/device.cpp +816 -0
  512. data/mlx/mlx/backend/metal/device.h +289 -0
  513. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  514. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  515. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  516. data/mlx/mlx/backend/metal/event.cpp +62 -0
  517. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  518. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  519. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  520. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  521. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  522. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  523. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  524. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  525. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  526. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  527. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  528. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  529. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  530. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  531. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  532. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  533. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  534. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  535. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  536. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  537. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  538. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  539. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  540. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  541. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  542. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  543. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  544. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  545. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  546. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  547. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  548. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  549. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  550. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  551. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  552. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  553. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  554. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  555. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  556. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  557. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  558. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  559. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  560. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  561. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  562. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  563. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  564. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  565. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  566. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  567. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  568. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  569. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  570. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  571. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  572. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  573. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  574. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  575. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  576. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  577. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  578. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  579. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  580. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  581. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  582. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  583. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  584. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  585. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  586. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  587. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  588. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  589. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  590. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  591. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  592. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  593. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  594. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  595. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  596. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  597. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  598. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  599. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  600. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  601. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  602. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  603. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  604. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  605. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  606. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  607. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  608. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  609. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  610. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  611. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  612. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  613. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  614. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  615. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  616. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  617. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  618. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  619. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  620. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  621. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  622. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  623. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  624. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  625. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  626. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  627. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  628. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  629. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  630. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  631. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  632. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  633. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  634. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  635. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  636. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  637. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  638. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  639. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  640. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  641. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  642. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  643. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  644. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  645. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  646. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  647. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  648. data/mlx/mlx/backend/metal/kernels.h +375 -0
  649. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  650. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  651. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  652. data/mlx/mlx/backend/metal/matmul.h +144 -0
  653. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  654. data/mlx/mlx/backend/metal/metal.h +25 -0
  655. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  656. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  657. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  658. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  659. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  660. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  661. data/mlx/mlx/backend/metal/reduce.h +41 -0
  662. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  663. data/mlx/mlx/backend/metal/resident.h +32 -0
  664. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  665. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  666. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  667. data/mlx/mlx/backend/metal/scan.h +17 -0
  668. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  669. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  670. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  671. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  672. data/mlx/mlx/backend/metal/ternary.h +21 -0
  673. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  674. data/mlx/mlx/backend/metal/unary.h +21 -0
  675. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  676. data/mlx/mlx/backend/metal/utils.h +99 -0
  677. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  678. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  679. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  680. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  681. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  682. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  683. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  684. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  685. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  686. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  687. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  688. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  689. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  690. data/mlx/mlx/compile.cpp +1243 -0
  691. data/mlx/mlx/compile.h +45 -0
  692. data/mlx/mlx/compile_impl.h +70 -0
  693. data/mlx/mlx/device.cpp +72 -0
  694. data/mlx/mlx/device.h +56 -0
  695. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  696. data/mlx/mlx/distributed/distributed.cpp +197 -0
  697. data/mlx/mlx/distributed/distributed.h +61 -0
  698. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  699. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  700. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  701. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  702. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  703. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  704. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  705. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  706. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  707. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  708. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  709. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  710. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  711. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  712. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  713. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  714. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  715. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  716. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  717. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  718. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  719. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  720. data/mlx/mlx/distributed/ops.cpp +186 -0
  721. data/mlx/mlx/distributed/ops.h +57 -0
  722. data/mlx/mlx/distributed/primitives.cpp +95 -0
  723. data/mlx/mlx/distributed/primitives.h +156 -0
  724. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  725. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  726. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  727. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  728. data/mlx/mlx/distributed/ring/ring.h +12 -0
  729. data/mlx/mlx/distributed/utils.cpp +206 -0
  730. data/mlx/mlx/distributed/utils.h +67 -0
  731. data/mlx/mlx/dtype.cpp +197 -0
  732. data/mlx/mlx/dtype.h +116 -0
  733. data/mlx/mlx/dtype_utils.cpp +42 -0
  734. data/mlx/mlx/dtype_utils.h +119 -0
  735. data/mlx/mlx/einsum.cpp +941 -0
  736. data/mlx/mlx/einsum.h +23 -0
  737. data/mlx/mlx/event.h +58 -0
  738. data/mlx/mlx/export.cpp +1130 -0
  739. data/mlx/mlx/export.h +137 -0
  740. data/mlx/mlx/export_impl.h +99 -0
  741. data/mlx/mlx/fast.cpp +941 -0
  742. data/mlx/mlx/fast.h +103 -0
  743. data/mlx/mlx/fast_primitives.h +427 -0
  744. data/mlx/mlx/fence.h +39 -0
  745. data/mlx/mlx/fft.cpp +262 -0
  746. data/mlx/mlx/fft.h +159 -0
  747. data/mlx/mlx/graph_utils.cpp +175 -0
  748. data/mlx/mlx/graph_utils.h +67 -0
  749. data/mlx/mlx/io/CMakeLists.txt +25 -0
  750. data/mlx/mlx/io/gguf.cpp +470 -0
  751. data/mlx/mlx/io/gguf.h +20 -0
  752. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  753. data/mlx/mlx/io/load.cpp +397 -0
  754. data/mlx/mlx/io/load.h +175 -0
  755. data/mlx/mlx/io/no_gguf.cpp +20 -0
  756. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  757. data/mlx/mlx/io/safetensors.cpp +234 -0
  758. data/mlx/mlx/io.h +61 -0
  759. data/mlx/mlx/linalg.cpp +708 -0
  760. data/mlx/mlx/linalg.h +115 -0
  761. data/mlx/mlx/memory.h +80 -0
  762. data/mlx/mlx/mlx.h +25 -0
  763. data/mlx/mlx/ops.cpp +6094 -0
  764. data/mlx/mlx/ops.h +1610 -0
  765. data/mlx/mlx/primitives.cpp +5850 -0
  766. data/mlx/mlx/primitives.h +2525 -0
  767. data/mlx/mlx/random.cpp +492 -0
  768. data/mlx/mlx/random.h +283 -0
  769. data/mlx/mlx/scheduler.cpp +73 -0
  770. data/mlx/mlx/scheduler.h +189 -0
  771. data/mlx/mlx/small_vector.h +540 -0
  772. data/mlx/mlx/stream.h +42 -0
  773. data/mlx/mlx/threadpool.h +133 -0
  774. data/mlx/mlx/transforms.cpp +1065 -0
  775. data/mlx/mlx/transforms.h +231 -0
  776. data/mlx/mlx/transforms_impl.h +88 -0
  777. data/mlx/mlx/types/bf16.h +187 -0
  778. data/mlx/mlx/types/complex.h +113 -0
  779. data/mlx/mlx/types/fp16.h +234 -0
  780. data/mlx/mlx/types/half_types.h +58 -0
  781. data/mlx/mlx/types/limits.h +70 -0
  782. data/mlx/mlx/utils.cpp +302 -0
  783. data/mlx/mlx/utils.h +174 -0
  784. data/mlx/mlx/version.cpp +11 -0
  785. data/mlx/mlx/version.h +22 -0
  786. data/mlx/mlx.pc.in +52 -0
  787. data/mlx/pyproject.toml +7 -0
  788. data/mlx/python/mlx/__main__.py +27 -0
  789. data/mlx/python/mlx/_distributed_utils/common.py +135 -0
  790. data/mlx/python/mlx/_distributed_utils/config.py +631 -0
  791. data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
  792. data/mlx/python/mlx/_reprlib_fix.py +16 -0
  793. data/mlx/python/mlx/_stub_patterns.txt +36 -0
  794. data/mlx/python/mlx/extension.py +88 -0
  795. data/mlx/python/mlx/nn/__init__.py +5 -0
  796. data/mlx/python/mlx/nn/init.py +441 -0
  797. data/mlx/python/mlx/nn/layers/__init__.py +105 -0
  798. data/mlx/python/mlx/nn/layers/activations.py +661 -0
  799. data/mlx/python/mlx/nn/layers/base.py +675 -0
  800. data/mlx/python/mlx/nn/layers/containers.py +24 -0
  801. data/mlx/python/mlx/nn/layers/convolution.py +232 -0
  802. data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
  803. data/mlx/python/mlx/nn/layers/distributed.py +601 -0
  804. data/mlx/python/mlx/nn/layers/dropout.py +137 -0
  805. data/mlx/python/mlx/nn/layers/embedding.py +53 -0
  806. data/mlx/python/mlx/nn/layers/linear.py +180 -0
  807. data/mlx/python/mlx/nn/layers/normalization.py +363 -0
  808. data/mlx/python/mlx/nn/layers/pooling.py +398 -0
  809. data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
  810. data/mlx/python/mlx/nn/layers/quantized.py +426 -0
  811. data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
  812. data/mlx/python/mlx/nn/layers/transformer.py +354 -0
  813. data/mlx/python/mlx/nn/layers/upsample.py +277 -0
  814. data/mlx/python/mlx/nn/losses.py +610 -0
  815. data/mlx/python/mlx/nn/utils.py +165 -0
  816. data/mlx/python/mlx/optimizers/__init__.py +4 -0
  817. data/mlx/python/mlx/optimizers/optimizers.py +976 -0
  818. data/mlx/python/mlx/optimizers/schedulers.py +158 -0
  819. data/mlx/python/mlx/py.typed +1 -0
  820. data/mlx/python/mlx/utils.py +325 -0
  821. data/mlx/python/src/CMakeLists.txt +96 -0
  822. data/mlx/python/src/array.cpp +1525 -0
  823. data/mlx/python/src/buffer.h +124 -0
  824. data/mlx/python/src/constants.cpp +15 -0
  825. data/mlx/python/src/convert.cpp +504 -0
  826. data/mlx/python/src/convert.h +50 -0
  827. data/mlx/python/src/cuda.cpp +19 -0
  828. data/mlx/python/src/device.cpp +98 -0
  829. data/mlx/python/src/distributed.cpp +352 -0
  830. data/mlx/python/src/export.cpp +356 -0
  831. data/mlx/python/src/fast.cpp +627 -0
  832. data/mlx/python/src/fft.cpp +514 -0
  833. data/mlx/python/src/indexing.cpp +1016 -0
  834. data/mlx/python/src/indexing.h +41 -0
  835. data/mlx/python/src/linalg.cpp +663 -0
  836. data/mlx/python/src/load.cpp +531 -0
  837. data/mlx/python/src/load.h +51 -0
  838. data/mlx/python/src/memory.cpp +125 -0
  839. data/mlx/python/src/metal.cpp +98 -0
  840. data/mlx/python/src/mlx.cpp +51 -0
  841. data/mlx/python/src/mlx_func.cpp +116 -0
  842. data/mlx/python/src/mlx_func.h +31 -0
  843. data/mlx/python/src/ops.cpp +5545 -0
  844. data/mlx/python/src/random.cpp +516 -0
  845. data/mlx/python/src/small_vector.h +76 -0
  846. data/mlx/python/src/stream.cpp +147 -0
  847. data/mlx/python/src/transforms.cpp +1542 -0
  848. data/mlx/python/src/trees.cpp +311 -0
  849. data/mlx/python/src/trees.h +62 -0
  850. data/mlx/python/src/utils.cpp +98 -0
  851. data/mlx/python/src/utils.h +78 -0
  852. data/mlx/python/tests/__main__.py +5 -0
  853. data/mlx/python/tests/cuda_skip.py +62 -0
  854. data/mlx/python/tests/mlx_distributed_tests.py +314 -0
  855. data/mlx/python/tests/mlx_tests.py +116 -0
  856. data/mlx/python/tests/mpi_test_distributed.py +142 -0
  857. data/mlx/python/tests/nccl_test_distributed.py +52 -0
  858. data/mlx/python/tests/ring_test_distributed.py +131 -0
  859. data/mlx/python/tests/test_array.py +2139 -0
  860. data/mlx/python/tests/test_autograd.py +880 -0
  861. data/mlx/python/tests/test_bf16.py +196 -0
  862. data/mlx/python/tests/test_blas.py +1429 -0
  863. data/mlx/python/tests/test_compile.py +1277 -0
  864. data/mlx/python/tests/test_constants.py +41 -0
  865. data/mlx/python/tests/test_conv.py +1198 -0
  866. data/mlx/python/tests/test_conv_transpose.py +810 -0
  867. data/mlx/python/tests/test_device.py +150 -0
  868. data/mlx/python/tests/test_double.py +306 -0
  869. data/mlx/python/tests/test_einsum.py +363 -0
  870. data/mlx/python/tests/test_eval.py +200 -0
  871. data/mlx/python/tests/test_export_import.py +614 -0
  872. data/mlx/python/tests/test_fast.py +923 -0
  873. data/mlx/python/tests/test_fast_sdpa.py +647 -0
  874. data/mlx/python/tests/test_fft.py +323 -0
  875. data/mlx/python/tests/test_graph.py +37 -0
  876. data/mlx/python/tests/test_init.py +139 -0
  877. data/mlx/python/tests/test_linalg.py +621 -0
  878. data/mlx/python/tests/test_load.py +447 -0
  879. data/mlx/python/tests/test_losses.py +427 -0
  880. data/mlx/python/tests/test_memory.py +77 -0
  881. data/mlx/python/tests/test_nn.py +1986 -0
  882. data/mlx/python/tests/test_ops.py +3261 -0
  883. data/mlx/python/tests/test_optimizers.py +584 -0
  884. data/mlx/python/tests/test_quantized.py +1160 -0
  885. data/mlx/python/tests/test_random.py +392 -0
  886. data/mlx/python/tests/test_reduce.py +223 -0
  887. data/mlx/python/tests/test_tree.py +96 -0
  888. data/mlx/python/tests/test_upsample.py +100 -0
  889. data/mlx/python/tests/test_vmap.py +860 -0
  890. data/mlx/setup.py +315 -0
  891. data/mlx/tests/CMakeLists.txt +44 -0
  892. data/mlx/tests/allocator_tests.cpp +41 -0
  893. data/mlx/tests/arg_reduce_tests.cpp +204 -0
  894. data/mlx/tests/array_tests.cpp +663 -0
  895. data/mlx/tests/autograd_tests.cpp +1399 -0
  896. data/mlx/tests/blas_tests.cpp +110 -0
  897. data/mlx/tests/compile_tests.cpp +818 -0
  898. data/mlx/tests/creations_tests.cpp +239 -0
  899. data/mlx/tests/custom_vjp_tests.cpp +55 -0
  900. data/mlx/tests/device_tests.cpp +35 -0
  901. data/mlx/tests/einsum_tests.cpp +85 -0
  902. data/mlx/tests/eval_tests.cpp +93 -0
  903. data/mlx/tests/export_import_tests.cpp +164 -0
  904. data/mlx/tests/fft_tests.cpp +366 -0
  905. data/mlx/tests/gpu_tests.cpp +523 -0
  906. data/mlx/tests/linalg_tests.cpp +639 -0
  907. data/mlx/tests/load_tests.cpp +270 -0
  908. data/mlx/tests/ops_tests.cpp +4159 -0
  909. data/mlx/tests/random_tests.cpp +716 -0
  910. data/mlx/tests/scheduler_tests.cpp +121 -0
  911. data/mlx/tests/tests.cpp +26 -0
  912. data/mlx/tests/utils_tests.cpp +67 -0
  913. data/mlx/tests/vmap_tests.cpp +547 -0
  914. metadata +958 -0
@@ -0,0 +1,1016 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+ #include <numeric>
3
+ #include <optional>
4
+ #include <sstream>
5
+
6
+ #include "python/src/convert.h"
7
+ #include "python/src/indexing.h"
8
+
9
+ #include "mlx/ops.h"
10
+
11
+ bool is_none_slice(const nb::slice& in_slice) {
12
+ return (
13
+ nb::getattr(in_slice, "start").is_none() &&
14
+ nb::getattr(in_slice, "stop").is_none() &&
15
+ nb::getattr(in_slice, "step").is_none());
16
+ }
17
+
18
+ int safe_to_int32(nb::object obj) {
19
+ auto val = nb::cast<int64_t>(nb::cast<nb::int_>(obj));
20
+ if (val > INT32_MAX || val < INT32_MIN) {
21
+ throw std::invalid_argument("Slice indices must be 32-bit integers.");
22
+ }
23
+ return static_cast<int>(val);
24
+ }
25
+
26
+ int get_slice_int(nb::object obj, int default_val) {
27
+ if (!obj.is_none()) {
28
+ if (!nb::isinstance<nb::int_>(obj)) {
29
+ throw std::invalid_argument("Slice indices must be integers or None.");
30
+ }
31
+ return safe_to_int32(obj);
32
+ }
33
+ return default_val;
34
+ }
35
+
36
+ void get_slice_params(
37
+ mx::ShapeElem& starts,
38
+ mx::ShapeElem& ends,
39
+ mx::ShapeElem& strides,
40
+ const nb::slice& in_slice,
41
+ int axis_size) {
42
+ // Following numpy's convention
43
+ // Assume n is the number of elements in the dimension being sliced.
44
+ // Then, if i is not given it defaults to 0 for k > 0 and n - 1 for
45
+ // k < 0 . If j is not given it defaults to n for k > 0 and -n-1 for
46
+ // k < 0 . If k is not given it defaults to 1
47
+
48
+ strides = get_slice_int(nb::getattr(in_slice, "step"), 1);
49
+ starts = get_slice_int(
50
+ nb::getattr(in_slice, "start"), strides < 0 ? axis_size - 1 : 0);
51
+ ends = get_slice_int(
52
+ nb::getattr(in_slice, "stop"), strides < 0 ? -axis_size - 1 : axis_size);
53
+ }
54
+
55
+ mx::array get_int_index(nb::object idx, int axis_size) {
56
+ int idx_ = safe_to_int32(idx);
57
+ idx_ = (idx_ < 0) ? idx_ + axis_size : idx_;
58
+
59
+ return mx::array(idx_, mx::uint32);
60
+ }
61
+
62
+ bool is_valid_index_type(const nb::object& obj) {
63
+ return nb::isinstance<nb::slice>(obj) || nb::isinstance<nb::int_>(obj) ||
64
+ nb::isinstance<mx::array>(obj) || obj.is_none() ||
65
+ nb::ellipsis().is(obj) || nb::isinstance<nb::list>(obj);
66
+ }
67
+
68
+ mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) {
69
+ // Check input and raise error if 0 dim for parity with np
70
+ if (src.ndim() == 0) {
71
+ throw std::invalid_argument(
72
+ "too many indices for array: array is 0-dimensional");
73
+ }
74
+
75
+ // Return a copy of the array if none slice is request
76
+ if (is_none_slice(in_slice)) {
77
+ return src;
78
+ }
79
+
80
+ mx::Shape starts(src.ndim(), 0);
81
+ auto ends = src.shape();
82
+ mx::Shape strides(src.ndim(), 1);
83
+
84
+ // Check and update slice params
85
+ get_slice_params(starts[0], ends[0], strides[0], in_slice, ends[0]);
86
+ return slice(src, starts, ends, strides);
87
+ }
88
+
89
+ mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) {
90
+ // Check input and raise error if 0 dim for parity with np
91
+ if (src.ndim() == 0) {
92
+ throw std::invalid_argument(
93
+ "too many indices for array: array is 0-dimensional");
94
+ }
95
+
96
+ if (indices.dtype() == mx::bool_) {
97
+ throw std::invalid_argument("boolean indices are not yet supported");
98
+ }
99
+
100
+ // If only one input array is mentioned, we set axis=0 in take
101
+ // for parity with np
102
+ return take(src, indices, 0);
103
+ }
104
+
105
+ mx::array mlx_get_item_int(const mx::array& src, const nb::int_& idx) {
106
+ // Check input and raise error if 0 dim for parity with np
107
+ if (src.ndim() == 0) {
108
+ throw std::invalid_argument(
109
+ "too many indices for array: array is 0-dimensional");
110
+ }
111
+
112
+ // If only one input idx is mentioned, we set axis=0 in take
113
+ // for parity with np
114
+ return take(src, get_int_index(idx, src.shape(0)), 0);
115
+ }
116
+
117
+ mx::array mlx_gather_nd(
118
+ mx::array src,
119
+ const std::vector<nb::object>& indices,
120
+ bool gather_first,
121
+ int& max_dims) {
122
+ max_dims = 0;
123
+ std::vector<mx::array> gather_indices;
124
+ std::vector<bool> is_slice(indices.size(), false);
125
+ int num_slices = 0;
126
+ // gather all the arrays
127
+ for (int i = 0; i < indices.size(); i++) {
128
+ auto& idx = indices[i];
129
+
130
+ if (nb::isinstance<nb::slice>(idx)) {
131
+ mx::ShapeElem start, end, stride;
132
+ get_slice_params(
133
+ start, end, stride, nb::cast<nb::slice>(idx), src.shape(i));
134
+
135
+ // Handle negative indices
136
+ start = (start < 0) ? start + src.shape(i) : start;
137
+ end = (end < 0) ? end + src.shape(i) : end;
138
+
139
+ gather_indices.push_back(arange(start, end, stride, mx::uint32));
140
+ num_slices++;
141
+ is_slice[i] = true;
142
+ } else if (nb::isinstance<nb::int_>(idx)) {
143
+ gather_indices.push_back(get_int_index(idx, src.shape(i)));
144
+ } else if (nb::isinstance<mx::array>(idx)) {
145
+ auto arr = nb::cast<mx::array>(idx);
146
+ max_dims = std::max(static_cast<int>(arr.ndim()), max_dims);
147
+ gather_indices.push_back(arr);
148
+ }
149
+ }
150
+
151
+ // reshape them so that the int/array indices are first
152
+ if (gather_first) {
153
+ int slice_index = 0;
154
+ for (int i = 0; i < gather_indices.size(); i++) {
155
+ if (is_slice[i]) {
156
+ mx::Shape index_shape(max_dims + num_slices, 1);
157
+ index_shape[max_dims + slice_index] = gather_indices[i].shape(0);
158
+ gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
159
+ slice_index++;
160
+ } else {
161
+ auto index_shape = gather_indices[i].shape();
162
+ index_shape.insert(index_shape.end(), num_slices, 1);
163
+ gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
164
+ }
165
+ }
166
+ } else {
167
+ // reshape them so that the int/array indices are last
168
+ for (int i = 0; i < gather_indices.size(); i++) {
169
+ if (i < num_slices) {
170
+ mx::Shape index_shape(max_dims + num_slices, 1);
171
+ index_shape[i] = gather_indices[i].shape(0);
172
+ gather_indices[i] = reshape(gather_indices[i], std::move(index_shape));
173
+ }
174
+ }
175
+ }
176
+
177
+ // Do the gather
178
+ std::vector<int> axes(indices.size());
179
+ std::iota(axes.begin(), axes.end(), 0);
180
+ auto slice_sizes = src.shape();
181
+ std::fill(slice_sizes.begin(), slice_sizes.begin() + indices.size(), 1);
182
+ src = gather(src, gather_indices, axes, slice_sizes);
183
+
184
+ // Squeeze the array index dims
185
+ for (auto& ax : axes) {
186
+ ax += max_dims + num_slices;
187
+ }
188
+ return mx::squeeze(src, axes);
189
+ }
190
+
191
+ auto mlx_expand_ellipsis(const mx::Shape& shape, const nb::tuple& entries) {
192
+ std::vector<nb::object> indices;
193
+
194
+ // Go over all entries and note the position of ellipsis
195
+ int non_none_indices_before = 0;
196
+ int non_none_indices_after = 0;
197
+ std::vector<nb::object> r_indices;
198
+ int i = 0;
199
+ bool has_ellipsis = false;
200
+
201
+ // Start from dimension 0 till we hit an ellipsis
202
+ for (; i < entries.size(); i++) {
203
+ auto idx = entries[i];
204
+ if (!is_valid_index_type(idx)) {
205
+ throw std::invalid_argument(
206
+ "Cannot index mlx array using the given type yet");
207
+ }
208
+ if (!nb::ellipsis().is(idx)) {
209
+ indices.push_back(idx);
210
+ non_none_indices_before += !idx.is_none();
211
+ } else {
212
+ has_ellipsis = true;
213
+ break;
214
+ }
215
+ }
216
+
217
+ // If we do hit an ellipsis, collect indices from the back
218
+ for (int j = entries.size() - 1; j > i; j--) {
219
+ auto idx = entries[j];
220
+ if (!is_valid_index_type(idx)) {
221
+ throw std::invalid_argument(
222
+ "Cannot index mlx array using the given type yet");
223
+ }
224
+ if (nb::ellipsis().is(idx)) {
225
+ throw std::invalid_argument(
226
+ "An index can only have a single ellipsis (...)");
227
+ }
228
+ r_indices.push_back(idx);
229
+ non_none_indices_after += !idx.is_none();
230
+ }
231
+
232
+ // Count up the number of non none indices
233
+ int non_none_indices = non_none_indices_before + non_none_indices_after;
234
+
235
+ // Expand ellipsis
236
+ if (has_ellipsis) {
237
+ for (int axis = non_none_indices_before;
238
+ axis < shape.size() - non_none_indices_after;
239
+ axis++) {
240
+ indices.push_back(
241
+ nb::slice(mx::ShapeElem{0}, shape[axis], mx::ShapeElem{1}));
242
+ non_none_indices++;
243
+ }
244
+ }
245
+
246
+ // Insert indices collected after the ellipsis
247
+ indices.insert(indices.end(), r_indices.rbegin(), r_indices.rend());
248
+
249
+ return std::make_pair(non_none_indices, indices);
250
+ }
251
+
252
+ mx::array mlx_get_item_nd(mx::array src, const nb::tuple& entries) {
253
+ // No indices make this a noop
254
+ if (entries.size() == 0) {
255
+ return src;
256
+ }
257
+
258
+ // The plan is as follows:
259
+ // 1. Replace the ellipsis with a series of slice(None)
260
+ // 2. Convert list to array
261
+ // 3. Loop over the indices and calculate the gather indices
262
+ // 4. Calculate the remaining slices and reshapes
263
+
264
+ // Ellipsis handling
265
+ auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
266
+ // List handling
267
+ for (auto& idx : indices) {
268
+ if (nb::isinstance<nb::list>(idx)) {
269
+ idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
270
+ }
271
+ }
272
+
273
+ // Check for the number of indices passed
274
+ if (non_none_indices > src.ndim()) {
275
+ std::ostringstream msg;
276
+ msg << "Too many indices for array with " << src.ndim() << " dimensions.";
277
+ throw std::invalid_argument(msg.str());
278
+ }
279
+
280
+ // Gather handling
281
+ //
282
+ // Check whether we have arrays or integer indices and delegate to gather_nd
283
+ // after removing the slices at the end and all Nones.
284
+ std::vector<nb::object> remaining_indices;
285
+ bool have_array = false;
286
+ {
287
+ // First check whether the results of gather are going to be 1st or
288
+ // normally in between.
289
+ bool have_non_array = false;
290
+ bool gather_first = false;
291
+ for (auto& idx : indices) {
292
+ if (nb::isinstance<mx::array>(idx) || (nb::isinstance<nb::int_>(idx))) {
293
+ if (have_array && have_non_array) {
294
+ gather_first = true;
295
+ break;
296
+ }
297
+ have_array = true;
298
+ } else {
299
+ have_non_array |= have_array;
300
+ }
301
+ }
302
+
303
+ int n_arr = 0;
304
+ for (auto& idx : indices) {
305
+ n_arr += nb::isinstance<mx::array>(idx);
306
+ }
307
+
308
+ have_array &= n_arr > 0;
309
+
310
+ if (have_array) {
311
+ int last_array;
312
+ // Then find the last array
313
+ for (last_array = indices.size() - 1; last_array >= 0; last_array--) {
314
+ auto& idx = indices[last_array];
315
+ if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
316
+ break;
317
+ }
318
+ }
319
+
320
+ std::vector<nb::object> gather_indices;
321
+ for (int i = 0; i <= last_array; i++) {
322
+ auto& idx = indices[i];
323
+ if (!idx.is_none()) {
324
+ gather_indices.push_back(idx);
325
+ }
326
+ }
327
+ int max_dims;
328
+ src = mlx_gather_nd(src, gather_indices, gather_first, max_dims);
329
+
330
+ // Reassemble the indices for the slicing or reshaping if there are any
331
+ if (gather_first) {
332
+ for (int i = 0; i < max_dims; i++) {
333
+ remaining_indices.push_back(
334
+ nb::slice(nb::none(), nb::none(), nb::none()));
335
+ }
336
+ for (int i = 0; i < last_array; i++) {
337
+ auto& idx = indices[i];
338
+ if (idx.is_none()) {
339
+ remaining_indices.push_back(indices[i]);
340
+ } else if (nb::isinstance<nb::slice>(idx)) {
341
+ remaining_indices.push_back(
342
+ nb::slice(nb::none(), nb::none(), nb::none()));
343
+ }
344
+ }
345
+ for (int i = last_array + 1; i < indices.size(); i++) {
346
+ remaining_indices.push_back(indices[i]);
347
+ }
348
+ } else {
349
+ for (int i = 0; i < indices.size(); i++) {
350
+ auto& idx = indices[i];
351
+ if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::int_>(idx)) {
352
+ break;
353
+ } else if (idx.is_none()) {
354
+ remaining_indices.push_back(idx);
355
+ } else {
356
+ remaining_indices.push_back(
357
+ nb::slice(nb::none(), nb::none(), nb::none()));
358
+ }
359
+ }
360
+ for (int i = 0; i < max_dims; i++) {
361
+ remaining_indices.push_back(
362
+ nb::slice(nb::none(), nb::none(), nb::none()));
363
+ }
364
+ for (int i = last_array + 1; i < indices.size(); i++) {
365
+ remaining_indices.push_back(indices[i]);
366
+ }
367
+ }
368
+ }
369
+ }
370
+ if (have_array && remaining_indices.empty()) {
371
+ return src;
372
+ }
373
+ if (remaining_indices.empty()) {
374
+ remaining_indices = indices;
375
+ }
376
+
377
+ bool squeeze_needed = false;
378
+ bool unsqueeze_needed = false;
379
+
380
+ // Slice handling
381
+ {
382
+ mx::Shape starts(src.ndim(), 0);
383
+ auto ends = src.shape();
384
+ mx::Shape strides(src.ndim(), 1);
385
+ int axis = 0;
386
+ for (auto& idx : remaining_indices) {
387
+ if (!idx.is_none()) {
388
+ if (!have_array && nb::isinstance<nb::int_>(idx)) {
389
+ int st = nb::cast<int>(idx);
390
+ st = (st < 0) ? st + src.shape(axis) : st;
391
+
392
+ starts[axis] = st;
393
+ ends[axis] = st + 1;
394
+
395
+ squeeze_needed = true;
396
+
397
+ } else {
398
+ get_slice_params(
399
+ starts[axis],
400
+ ends[axis],
401
+ strides[axis],
402
+ nb::cast<nb::slice>(idx),
403
+ ends[axis]);
404
+ }
405
+
406
+ axis++;
407
+ } else {
408
+ unsqueeze_needed = true;
409
+ }
410
+ }
411
+ src = slice(src, starts, ends, strides);
412
+ }
413
+
414
+ // Unsqueeze handling
415
+ if (unsqueeze_needed || squeeze_needed) {
416
+ std::vector<int> squeeze_axes;
417
+ std::vector<int> unsqueeze_axes;
418
+ for (int axis = 0; axis < remaining_indices.size(); ++axis) {
419
+ auto& idx = remaining_indices[axis];
420
+ if (unsqueeze_needed && idx.is_none()) {
421
+ unsqueeze_axes.push_back(axis - squeeze_axes.size());
422
+ } else if (squeeze_needed && nb::isinstance<nb::int_>(idx)) {
423
+ squeeze_axes.push_back(axis - unsqueeze_axes.size());
424
+ }
425
+ }
426
+ if (!squeeze_axes.empty()) {
427
+ src = squeeze(src, std::move(squeeze_axes));
428
+ }
429
+ if (!unsqueeze_axes.empty()) {
430
+ src = expand_dims(src, std::move(unsqueeze_axes));
431
+ }
432
+ }
433
+
434
+ return src;
435
+ }
436
+
437
+ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) {
438
+ if (nb::isinstance<nb::slice>(obj)) {
439
+ return mlx_get_item_slice(src, nb::cast<nb::slice>(obj));
440
+ } else if (nb::isinstance<mx::array>(obj)) {
441
+ return mlx_get_item_array(src, nb::cast<mx::array>(obj));
442
+ } else if (nb::isinstance<nb::int_>(obj)) {
443
+ return mlx_get_item_int(src, nb::cast<nb::int_>(obj));
444
+ } else if (nb::isinstance<nb::tuple>(obj)) {
445
+ return mlx_get_item_nd(src, nb::cast<nb::tuple>(obj));
446
+ } else if (nb::isinstance<nb::ellipsis>(obj)) {
447
+ return src;
448
+ } else if (obj.is_none()) {
449
+ return expand_dims(src, 0);
450
+ } else if (nb::isinstance<nb::list>(obj)) {
451
+ return mlx_get_item_array(
452
+ src, array_from_list(nb::cast<nb::list>(obj), {}));
453
+ }
454
+ throw std::invalid_argument("Cannot index mlx array using the given type.");
455
+ }
456
+
457
+ std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
458
+ mlx_scatter_args_int(
459
+ const mx::array& src,
460
+ const nb::int_& idx,
461
+ const mx::array& update) {
462
+ if (src.ndim() == 0) {
463
+ throw std::invalid_argument(
464
+ "too many indices for array: array is 0-dimensional");
465
+ }
466
+
467
+ // Remove any leading singleton dimensions from the update
468
+ // and then broadcast update to shape of src[0, ...]
469
+ int s = 0;
470
+ for (; s < update.ndim() && update.shape(s) == 1; s++)
471
+ ;
472
+ auto up_shape = mx::Shape(update.shape().begin() + s, update.shape().end());
473
+ auto shape = src.shape();
474
+ shape[0] = 1;
475
+
476
+ return {
477
+ {get_int_index(idx, src.shape(0))},
478
+ broadcast_to(reshape(update, up_shape), shape),
479
+ {0}};
480
+ }
481
+
482
+ mx::array squeeze_leading_singletons(const mx::array& in) {
483
+ int s = 0;
484
+ for (; s < in.ndim() && in.shape(s) == 1; s++)
485
+ ;
486
+ auto squeeze_axes = std::vector<int>(s);
487
+ std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
488
+ return mx::squeeze(in, squeeze_axes);
489
+ }
490
+
491
+ std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
492
+ mlx_scatter_args_array(
493
+ const mx::array& src,
494
+ const mx::array& indices,
495
+ const mx::array& update) {
496
+ if (src.ndim() == 0) {
497
+ throw std::invalid_argument(
498
+ "too many indices for array: array is 0-dimensional");
499
+ }
500
+
501
+ auto up = squeeze_leading_singletons(update);
502
+
503
+ // The update shape must broadcast with indices.shape + [1] + src.shape[1:]
504
+ auto up_shape = indices.shape();
505
+ up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end());
506
+ up = broadcast_to(up, up_shape);
507
+ up_shape.insert(up_shape.begin() + indices.ndim(), 1);
508
+ up = reshape(up, up_shape);
509
+
510
+ return {{indices}, up, {0}};
511
+ }
512
+
513
+ std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
514
+ mlx_scatter_args_slice(
515
+ const mx::array& src,
516
+ const nb::slice& in_slice,
517
+ const mx::array& update) {
518
+ // Check input and raise error if 0 dim for parity with np
519
+ if (src.ndim() == 0) {
520
+ throw std::invalid_argument(
521
+ "too many indices for array: array is 0-dimensional");
522
+ }
523
+
524
+ // If none slice is requested broadcast the update
525
+ // to the src size and return it.
526
+ if (is_none_slice(in_slice)) {
527
+ return {
528
+ {}, broadcast_to(squeeze_leading_singletons(update), src.shape()), {}};
529
+ }
530
+
531
+ mx::ShapeElem start = 0;
532
+ auto end = src.shape(0);
533
+ mx::ShapeElem stride = 1;
534
+
535
+ // Check and update slice params
536
+ get_slice_params(start, end, stride, in_slice, end);
537
+
538
+ // If simple stride
539
+ if (stride == 1) {
540
+ // Squeeze out singleton dims from the start of update
541
+ auto up = squeeze_leading_singletons(update);
542
+
543
+ // Build array to mark start of slice
544
+ auto idx = mx::array({start}, {1}, mx::uint32);
545
+
546
+ // Get slice size
547
+ int slice_size = (end - start);
548
+
549
+ // Broadcast update to slice size
550
+ mx::Shape up_shape_broadcast = {1, slice_size};
551
+ up_shape_broadcast.insert(
552
+ up_shape_broadcast.end(), src.shape().begin() + 1, src.shape().end());
553
+
554
+ up = broadcast_to(up, up_shape_broadcast);
555
+
556
+ auto indices = std::vector<mx::array>{idx};
557
+ auto axes = std::vector<int>{0};
558
+
559
+ return {indices, up, axes};
560
+ }
561
+
562
+ return mlx_scatter_args_array(
563
+ src, arange(start, end, stride, mx::uint32), update);
564
+ }
565
+
566
+ std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
567
+ mlx_scatter_args_nd(
568
+ const mx::array& src,
569
+ const nb::tuple& entries,
570
+ const mx::array& update) {
571
+ // Expand ellipses into a series of ':' slices
572
+ auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
573
+
574
+ // Convert List to array
575
+ for (auto& idx : indices) {
576
+ if (nb::isinstance<nb::list>(idx)) {
577
+ idx = nb::cast(array_from_list(nb::cast<nb::list>(idx), {}));
578
+ }
579
+ }
580
+
581
+ if (non_none_indices > src.ndim()) {
582
+ std::ostringstream msg;
583
+ msg << "Too many indices for array with " << src.ndim() << " dimensions.";
584
+ throw std::invalid_argument(msg.str());
585
+ }
586
+
587
+ auto up = squeeze_leading_singletons(update);
588
+
589
+ // If no non-None indices return the broadcasted update
590
+ if (non_none_indices == 0) {
591
+ return {{}, broadcast_to(up, src.shape()), {}};
592
+ }
593
+
594
+ // Analyse the types of the indices
595
+ size_t max_dim = 0;
596
+ bool arrays_first = false;
597
+ int num_none = 0;
598
+ int num_slices = 0;
599
+ int num_arrays = 0;
600
+ int num_strided_slices = 0;
601
+ int num_simple_slices_post = 0;
602
+ {
603
+ bool have_array = false;
604
+ bool have_non_array = false;
605
+ for (auto& idx : indices) {
606
+ if (idx.is_none()) {
607
+ have_non_array = have_array;
608
+ num_none++;
609
+
610
+ } else if (nb::isinstance<nb::slice>(idx)) {
611
+ have_non_array = have_array;
612
+ num_slices++;
613
+
614
+ auto slice = nb::cast<nb::slice>(idx);
615
+ int stride = get_slice_int(nb::getattr(slice, "step"), 1);
616
+ if (stride != 1) {
617
+ num_strided_slices++;
618
+ num_simple_slices_post = 0;
619
+ } else {
620
+ num_simple_slices_post++;
621
+ }
622
+
623
+ } else if (nb::isinstance<mx::array>(idx)) {
624
+ have_array = true;
625
+ if (have_array && have_non_array) {
626
+ arrays_first = true;
627
+ }
628
+ max_dim = std::max(nb::cast<mx::array>(idx).ndim(), max_dim);
629
+ num_arrays++;
630
+ num_simple_slices_post = 0;
631
+ }
632
+ }
633
+ }
634
+
635
+ // We have index dims for the arrays, strided slices (implemented as arrays),
636
+ // none
637
+ int idx_ndim = max_dim + num_none + num_slices - num_simple_slices_post;
638
+
639
+ // If we have simple non-strided slices, we also attach an index for that
640
+ idx_ndim = idx_ndim == 0 ? 1 : idx_ndim;
641
+
642
+ // Go over each index type and translate to the needed scatter args
643
+ std::vector<mx::array> arr_indices;
644
+ int slice_num = 0;
645
+ int array_num = 0;
646
+ int ax = 0;
647
+
648
+ // We collect the shapes of the slices and updates during this process
649
+ std::vector<int> update_shape(non_none_indices, 1);
650
+ std::vector<int> slice_shapes;
651
+
652
+ for (int i = 0; i < indices.size(); ++i) {
653
+ auto& pyidx = indices[i];
654
+ if (nb::isinstance<nb::slice>(pyidx)) {
655
+ mx::ShapeElem start, end, stride;
656
+ auto axis_size = src.shape(ax++);
657
+ get_slice_params(
658
+ start, end, stride, nb::cast<nb::slice>(pyidx), axis_size);
659
+
660
+ // Handle negative indices
661
+ start = (start < 0) ? start + axis_size : start;
662
+ end = (end < 0) ? end + axis_size : end;
663
+
664
+ mx::Shape idx_shape(idx_ndim, 1);
665
+
666
+ // If it's a simple slice, we only need to add the start index
667
+ if (array_num >= num_arrays && num_strided_slices <= 0 && stride == 1) {
668
+ auto idx = mx::array({start}, idx_shape, mx::uint32);
669
+ slice_shapes.push_back(end - start);
670
+ arr_indices.push_back(idx);
671
+
672
+ // Add the shape to the update
673
+ update_shape[ax - 1] = slice_shapes.back();
674
+ }
675
+ // Otherwise we expand the slice into indices using arange
676
+ else {
677
+ auto idx = arange(start, end, stride, mx::uint32);
678
+ auto loc = slice_num + (arrays_first ? max_dim : 0);
679
+ idx_shape[loc] = idx.size();
680
+ arr_indices.push_back(reshape(idx, idx_shape));
681
+
682
+ slice_num++;
683
+ num_strided_slices--;
684
+
685
+ // Add the shape to the update
686
+ update_shape[ax - 1] = 1;
687
+ }
688
+ } else if (nb::isinstance<nb::int_>(pyidx)) {
689
+ // Add index to arrays
690
+ arr_indices.push_back(get_int_index(pyidx, src.shape(ax++)));
691
+ // Add the shape to the update
692
+ update_shape[ax - 1] = 1;
693
+ } else if (pyidx.is_none()) {
694
+ // We only use the None's for bookeeping dimensions
695
+ slice_num++;
696
+ } else if (nb::isinstance<mx::array>(pyidx)) {
697
+ ax++;
698
+ auto idx = nb::cast<mx::array>(pyidx);
699
+ mx::Shape idx_shape(idx_ndim, 1);
700
+
701
+ // Place the arrays in the correct dimension
702
+ int st = (!arrays_first) * slice_num + max_dim - idx.ndim();
703
+ for (int j = 0; j < idx.ndim(); j++) {
704
+ idx_shape[st + j] = idx.shape()[j];
705
+ }
706
+ arr_indices.push_back(reshape(idx, idx_shape));
707
+ if (!arrays_first && ++array_num == num_arrays) {
708
+ slice_num += max_dim;
709
+ }
710
+
711
+ // Add the shape to the update
712
+ update_shape[ax - 1] = 1;
713
+ } else {
714
+ throw std::invalid_argument(
715
+ "Cannot index mlx array using the given type yet");
716
+ }
717
+ }
718
+
719
+ // Broadcast the update to the indices and slices
720
+ arr_indices = broadcast_arrays(arr_indices);
721
+ auto up_shape_broadcast = arr_indices[0].shape();
722
+
723
+ up_shape_broadcast.insert(
724
+ up_shape_broadcast.end(), slice_shapes.begin(), slice_shapes.end());
725
+ up_shape_broadcast.insert(
726
+ up_shape_broadcast.end(),
727
+ src.shape().begin() + non_none_indices,
728
+ src.shape().end());
729
+ up = broadcast_to(up, up_shape_broadcast);
730
+
731
+ // Reshape the update with the size-1 dims for the int and array indices
732
+ auto up_reshape = arr_indices[0].shape();
733
+ up_reshape.insert(up_reshape.end(), update_shape.begin(), update_shape.end());
734
+ up_reshape.insert(
735
+ up_reshape.end(),
736
+ src.shape().begin() + non_none_indices,
737
+ src.shape().end());
738
+
739
+ up = reshape(up, up_reshape);
740
+
741
+ // Collect axes
742
+ std::vector<int> axes(arr_indices.size(), 0);
743
+ std::iota(axes.begin(), axes.end(), 0);
744
+
745
+ return {arr_indices, up, axes};
746
+ }
747
+
748
+ std::tuple<std::vector<mx::array>, mx::array, std::vector<int>>
749
+ mlx_compute_scatter_args(
750
+ const mx::array& src,
751
+ const nb::object& obj,
752
+ const ScalarOrArray& v) {
753
+ auto vals = to_array(v, src.dtype());
754
+ if (nb::isinstance<nb::slice>(obj)) {
755
+ return mlx_scatter_args_slice(src, nb::cast<nb::slice>(obj), vals);
756
+ } else if (nb::isinstance<mx::array>(obj)) {
757
+ return mlx_scatter_args_array(src, nb::cast<mx::array>(obj), vals);
758
+ } else if (nb::isinstance<nb::int_>(obj)) {
759
+ return mlx_scatter_args_int(src, nb::cast<nb::int_>(obj), vals);
760
+ } else if (nb::isinstance<nb::tuple>(obj)) {
761
+ return mlx_scatter_args_nd(src, nb::cast<nb::tuple>(obj), vals);
762
+ } else if (obj.is_none()) {
763
+ return {{}, broadcast_to(vals, src.shape()), {}};
764
+ } else if (nb::isinstance<nb::list>(obj)) {
765
+ return mlx_scatter_args_array(
766
+ src, array_from_list(nb::cast<nb::list>(obj), {}), vals);
767
+ }
768
+
769
+ throw std::invalid_argument("Cannot index mlx array using the given type.");
770
+ }
771
+
772
+ auto mlx_slice_update(
773
+ const mx::array& src,
774
+ const nb::object& obj,
775
+ const ScalarOrArray& v) {
776
+ // Can't route to slice update if not slice, tuple, or int
777
+ if (src.ndim() == 0 || nb::isinstance<nb::bool_>(obj) ||
778
+ (!nb::isinstance<nb::slice>(obj) && !nb::isinstance<nb::tuple>(obj) &&
779
+ !nb::isinstance<nb::int_>(obj))) {
780
+ return std::make_pair(false, src);
781
+ }
782
+ if (nb::isinstance<nb::tuple>(obj)) {
783
+ // Can't route to slice update if any arrays are present
784
+ for (auto idx : nb::cast<nb::tuple>(obj)) {
785
+ if (nb::isinstance<mx::array>(idx) || nb::isinstance<nb::list>(idx)) {
786
+ return std::make_pair(false, src);
787
+ }
788
+ }
789
+ }
790
+ // Should be able to route to slice update
791
+
792
+ // Pre process tuple
793
+ auto upd = to_array(v, src.dtype());
794
+
795
+ // Remove extra leading singletons dimensions from the update
796
+ int s = 0;
797
+ for (; s < static_cast<int>(upd.ndim()) - 1 && upd.shape(s) == 1 &&
798
+ (upd.ndim() - s) > src.ndim();
799
+ s++) {
800
+ };
801
+ auto squeeze_axes = std::vector<int>(s);
802
+ std::iota(squeeze_axes.begin(), squeeze_axes.end(), 0);
803
+ auto up = mx::squeeze(upd, squeeze_axes);
804
+
805
+ // Build slice update params
806
+ mx::Shape starts(src.ndim(), 0);
807
+ mx::Shape stops = src.shape();
808
+ mx::Shape strides(src.ndim(), 1);
809
+ if (nb::isinstance<nb::int_>(obj)) {
810
+ if (src.ndim() < 1) {
811
+ std::ostringstream msg;
812
+ msg << "Too many indices for array with " << src.ndim() << " dimensions.";
813
+ throw std::invalid_argument(msg.str());
814
+ }
815
+ auto idx = nb::cast<int>(obj);
816
+ idx = idx < 0 ? idx + stops[0] : idx;
817
+ starts[0] = idx;
818
+ stops[0] = idx + 1;
819
+ auto out = slice_update(
820
+ src, up, std::move(starts), std::move(stops), std::move(strides));
821
+ return std::make_pair(true, out);
822
+ }
823
+
824
+ // If it's just a simple slice, just do a slice update and return
825
+ if (nb::isinstance<nb::slice>(obj)) {
826
+ // Read slice arguments
827
+ get_slice_params(
828
+ starts[0],
829
+ stops[0],
830
+ strides[0],
831
+ nb::cast<nb::slice>(obj),
832
+ src.shape(0));
833
+
834
+ // Do slice update
835
+ auto out = slice_update(src, up, starts, stops, strides);
836
+ return std::make_pair(true, out);
837
+ }
838
+
839
+ // It must be a tuple
840
+ auto entries = nb::cast<nb::tuple>(obj);
841
+
842
+ // Expand ellipses into a series of ':' slices
843
+ auto [non_none_indices, indices] = mlx_expand_ellipsis(src.shape(), entries);
844
+
845
+ // Dimension check
846
+ if (non_none_indices > src.ndim()) {
847
+ std::ostringstream msg;
848
+ msg << "Too many indices for array with " << src.ndim() << " dimensions.";
849
+ throw std::invalid_argument(msg.str());
850
+ }
851
+
852
+ // If no non-None indices return the broadcasted update
853
+ if (non_none_indices == 0) {
854
+ return std::make_pair(true, broadcast_to(up, src.shape()));
855
+ }
856
+
857
+ int unspecified = src.ndim() - non_none_indices;
858
+ std::vector<int> squeeze_dims;
859
+ std::vector<int> expand_dims;
860
+ for (int i = indices.size() - 1,
861
+ ax = non_none_indices - 1,
862
+ upd_ax = upd.ndim() - unspecified - 1;
863
+ i >= 0;
864
+ --i) {
865
+ auto& pyidx = indices[i];
866
+ if (nb::isinstance<nb::slice>(pyidx)) {
867
+ get_slice_params(
868
+ starts[ax],
869
+ stops[ax],
870
+ strides[ax],
871
+ nb::cast<nb::slice>(pyidx),
872
+ src.shape(ax));
873
+ ax--;
874
+ upd_ax--;
875
+ } else if (nb::isinstance<nb::int_>(pyidx)) {
876
+ int st = nb::cast<int>(pyidx);
877
+ st = (st < 0) ? st + src.shape(i) : st;
878
+ starts[ax] = st;
879
+ stops[ax] = st + 1;
880
+ if (upd_ax >= 0) {
881
+ expand_dims.push_back(i - indices.size() - unspecified);
882
+ }
883
+ ax--;
884
+ } else if (pyidx.is_none()) {
885
+ if (upd_ax-- >= 0) {
886
+ squeeze_dims.push_back(i - indices.size() - unspecified);
887
+ }
888
+ }
889
+ }
890
+
891
+ up = mx::squeeze(
892
+ mx::expand_dims(up, std::move(expand_dims)), std::move(squeeze_dims));
893
+ auto out = slice_update(src, up, starts, stops, strides);
894
+ return std::make_pair(true, out);
895
+ }
896
+
897
+ std::optional<mx::array> extract_boolean_mask(const nb::object& obj) {
898
+ using NDArray = nb::ndarray<nb::ro, nb::c_contig, nb::device::cpu>;
899
+ if (nb::isinstance<nb::bool_>(obj)) {
900
+ return mx::array(nb::cast<bool>(obj), mx::bool_);
901
+ } else if (nb::isinstance<mx::array>(obj)) {
902
+ auto mask = nb::cast<mx::array>(obj);
903
+ if (mask.dtype() == mx::bool_) {
904
+ return mask;
905
+ }
906
+ } else if (nb::isinstance<NDArray>(obj)) {
907
+ auto mask = nb::cast<NDArray>(obj);
908
+ if (mask.dtype() == nb::dtype<bool>()) {
909
+ return nd_array_to_mlx(mask, mx::bool_);
910
+ }
911
+ } else if (nb::isinstance<nb::list>(obj)) {
912
+ auto mask = array_from_list(nb::cast<nb::list>(obj), {});
913
+ if (mask.dtype() == mx::bool_) {
914
+ return mask;
915
+ }
916
+ }
917
+ return std::nullopt;
918
+ }
919
+
920
+ void mlx_set_item(
921
+ mx::array& src,
922
+ const nb::object& obj,
923
+ const ScalarOrArray& v) {
924
+ auto [success, out] = mlx_slice_update(src, obj, v);
925
+ if (success) {
926
+ src.overwrite_descriptor(out);
927
+ return;
928
+ }
929
+
930
+ if (auto mask = extract_boolean_mask(obj)) {
931
+ auto updates = to_array(v, src.dtype());
932
+ auto result = masked_scatter(src, *mask, updates);
933
+ src.overwrite_descriptor(result);
934
+ return;
935
+ }
936
+
937
+ auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
938
+ if (indices.size() > 0) {
939
+ auto out = scatter(src, indices, updates, axes);
940
+ src.overwrite_descriptor(out);
941
+ } else {
942
+ src.overwrite_descriptor(updates);
943
+ }
944
+ }
945
+
946
+ mx::array mlx_add_item(
947
+ const mx::array& src,
948
+ const nb::object& obj,
949
+ const ScalarOrArray& v) {
950
+ auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
951
+ if (indices.size() > 0) {
952
+ return scatter_add(src, indices, updates, axes);
953
+ } else {
954
+ return src + updates;
955
+ }
956
+ }
957
+
958
+ mx::array mlx_subtract_item(
959
+ const mx::array& src,
960
+ const nb::object& obj,
961
+ const ScalarOrArray& v) {
962
+ auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
963
+ if (indices.size() > 0) {
964
+ return scatter_add(src, indices, -updates, axes);
965
+ } else {
966
+ return src - updates;
967
+ }
968
+ }
969
+
970
+ mx::array mlx_multiply_item(
971
+ const mx::array& src,
972
+ const nb::object& obj,
973
+ const ScalarOrArray& v) {
974
+ auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
975
+ if (indices.size() > 0) {
976
+ return scatter_prod(src, indices, updates, axes);
977
+ } else {
978
+ return src * updates;
979
+ }
980
+ }
981
+
982
+ mx::array mlx_divide_item(
983
+ const mx::array& src,
984
+ const nb::object& obj,
985
+ const ScalarOrArray& v) {
986
+ auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
987
+ if (indices.size() > 0) {
988
+ return scatter_prod(src, indices, reciprocal(updates), axes);
989
+ } else {
990
+ return src / updates;
991
+ }
992
+ }
993
+
994
+ mx::array mlx_maximum_item(
995
+ const mx::array& src,
996
+ const nb::object& obj,
997
+ const ScalarOrArray& v) {
998
+ auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
999
+ if (indices.size() > 0) {
1000
+ return scatter_max(src, indices, updates, axes);
1001
+ } else {
1002
+ return maximum(src, updates);
1003
+ }
1004
+ }
1005
+
1006
+ mx::array mlx_minimum_item(
1007
+ const mx::array& src,
1008
+ const nb::object& obj,
1009
+ const ScalarOrArray& v) {
1010
+ auto [indices, updates, axes] = mlx_compute_scatter_args(src, obj, v);
1011
+ if (indices.size() > 0) {
1012
+ return scatter_min(src, indices, updates, axes);
1013
+ } else {
1014
+ return minimum(src, updates);
1015
+ }
1016
+ }