mlx 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlx might be problematic.

Files changed (914)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/CMakeLists.txt +7 -0
  3. data/ext/mlx/Makefile +273 -0
  4. data/ext/mlx/extconf.rb +94 -0
  5. data/ext/mlx/mkmf.log +44 -0
  6. data/ext/mlx/native.bundle +0 -0
  7. data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
  8. data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
  9. data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
  10. data/ext/mlx/native.cpp +8027 -0
  11. data/ext/mlx/native.o +0 -0
  12. data/lib/mlx/core.rb +1678 -0
  13. data/lib/mlx/distributed_utils/common.rb +116 -0
  14. data/lib/mlx/distributed_utils/config.rb +600 -0
  15. data/lib/mlx/distributed_utils/launch.rb +490 -0
  16. data/lib/mlx/extension.rb +24 -0
  17. data/lib/mlx/nn/base.rb +388 -0
  18. data/lib/mlx/nn/init.rb +140 -0
  19. data/lib/mlx/nn/layers/activations.rb +336 -0
  20. data/lib/mlx/nn/layers/base.rb +6 -0
  21. data/lib/mlx/nn/layers/containers.rb +20 -0
  22. data/lib/mlx/nn/layers/convolution.rb +120 -0
  23. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  24. data/lib/mlx/nn/layers/distributed.rb +309 -0
  25. data/lib/mlx/nn/layers/dropout.rb +75 -0
  26. data/lib/mlx/nn/layers/embedding.rb +28 -0
  27. data/lib/mlx/nn/layers/linear.rb +79 -0
  28. data/lib/mlx/nn/layers/normalization.rb +216 -0
  29. data/lib/mlx/nn/layers/pooling.rb +167 -0
  30. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  31. data/lib/mlx/nn/layers/quantized.rb +215 -0
  32. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  33. data/lib/mlx/nn/layers/transformer.rb +330 -0
  34. data/lib/mlx/nn/layers/upsample.rb +97 -0
  35. data/lib/mlx/nn/layers.rb +18 -0
  36. data/lib/mlx/nn/losses.rb +251 -0
  37. data/lib/mlx/nn/utils.rb +167 -0
  38. data/lib/mlx/nn.rb +12 -0
  39. data/lib/mlx/optimizers/optimizers.rb +808 -0
  40. data/lib/mlx/optimizers/schedulers.rb +62 -0
  41. data/lib/mlx/optimizers.rb +9 -0
  42. data/lib/mlx/utils.rb +171 -0
  43. data/lib/mlx/version +1 -0
  44. data/lib/mlx/version.rb +5 -0
  45. data/lib/mlx.rb +64 -0
  46. data/mlx/.clang-format +87 -0
  47. data/mlx/.git +1 -0
  48. data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
  49. data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
  50. data/mlx/.github/actions/build-docs/action.yml +38 -0
  51. data/mlx/.github/actions/build-linux/action.yml +38 -0
  52. data/mlx/.github/actions/build-linux-release/action.yml +42 -0
  53. data/mlx/.github/actions/build-macos/action.yml +80 -0
  54. data/mlx/.github/actions/build-macos-release/action.yml +36 -0
  55. data/mlx/.github/actions/build-windows/action.yml +26 -0
  56. data/mlx/.github/actions/setup-linux/action.yml +93 -0
  57. data/mlx/.github/actions/setup-macos/action.yml +24 -0
  58. data/mlx/.github/actions/setup-windows/action.yml +42 -0
  59. data/mlx/.github/actions/test-linux/action.yml +69 -0
  60. data/mlx/.github/actions/test-windows/action.yml +20 -0
  61. data/mlx/.github/dependabot.yml +6 -0
  62. data/mlx/.github/pull_request_template.md +12 -0
  63. data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
  64. data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
  65. data/mlx/.github/workflows/build_and_test.yml +152 -0
  66. data/mlx/.github/workflows/documentation.yml +28 -0
  67. data/mlx/.github/workflows/nightly.yml +104 -0
  68. data/mlx/.github/workflows/release.yml +256 -0
  69. data/mlx/.gitignore +81 -0
  70. data/mlx/.pre-commit-config.yaml +27 -0
  71. data/mlx/ACKNOWLEDGMENTS.md +268 -0
  72. data/mlx/CITATION.cff +24 -0
  73. data/mlx/CMakeLists.txt +437 -0
  74. data/mlx/CODE_OF_CONDUCT.md +132 -0
  75. data/mlx/CONTRIBUTING.md +38 -0
  76. data/mlx/LICENSE +21 -0
  77. data/mlx/MANIFEST.in +6 -0
  78. data/mlx/README.md +121 -0
  79. data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
  80. data/mlx/benchmarks/cpp/autograd.cpp +39 -0
  81. data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
  82. data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
  83. data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
  84. data/mlx/benchmarks/cpp/time_utils.h +39 -0
  85. data/mlx/benchmarks/numpy/single_ops.py +39 -0
  86. data/mlx/benchmarks/numpy/time_utils.py +20 -0
  87. data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
  88. data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
  89. data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
  90. data/mlx/benchmarks/python/comparative/README.md +15 -0
  91. data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
  92. data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
  93. data/mlx/benchmarks/python/comparative/compare.py +284 -0
  94. data/mlx/benchmarks/python/compile_bench.py +107 -0
  95. data/mlx/benchmarks/python/conv1d_bench.py +123 -0
  96. data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
  97. data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
  98. data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
  99. data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
  100. data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
  101. data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
  102. data/mlx/benchmarks/python/conv_bench.py +135 -0
  103. data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
  104. data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
  105. data/mlx/benchmarks/python/distributed_bench.py +66 -0
  106. data/mlx/benchmarks/python/einsum_bench.py +84 -0
  107. data/mlx/benchmarks/python/fft_bench.py +118 -0
  108. data/mlx/benchmarks/python/gather_bench.py +52 -0
  109. data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
  110. data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
  111. data/mlx/benchmarks/python/hadamard_bench.py +70 -0
  112. data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
  113. data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
  114. data/mlx/benchmarks/python/masked_scatter.py +212 -0
  115. data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
  116. data/mlx/benchmarks/python/rope_bench.py +35 -0
  117. data/mlx/benchmarks/python/scatter_bench.py +96 -0
  118. data/mlx/benchmarks/python/sdpa_bench.py +223 -0
  119. data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
  120. data/mlx/benchmarks/python/single_ops.py +132 -0
  121. data/mlx/benchmarks/python/synchronize_bench.py +55 -0
  122. data/mlx/benchmarks/python/time_utils.py +38 -0
  123. data/mlx/cmake/FindCUDNN.cmake +177 -0
  124. data/mlx/cmake/FindNCCL.cmake +54 -0
  125. data/mlx/cmake/Findnvpl.cmake +3 -0
  126. data/mlx/cmake/extension.cmake +50 -0
  127. data/mlx/docs/.clang-format +2 -0
  128. data/mlx/docs/.gitignore +3 -0
  129. data/mlx/docs/.nojekyll +0 -0
  130. data/mlx/docs/Doxyfile +51 -0
  131. data/mlx/docs/Makefile +18 -0
  132. data/mlx/docs/README.md +54 -0
  133. data/mlx/docs/index.html +1 -0
  134. data/mlx/docs/requirements.txt +5 -0
  135. data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
  136. data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
  137. data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
  138. data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
  139. data/mlx/docs/src/_static/mlx_logo.png +0 -0
  140. data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
  141. data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
  142. data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
  143. data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
  144. data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
  145. data/mlx/docs/src/_templates/module-base-class.rst +33 -0
  146. data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
  147. data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
  148. data/mlx/docs/src/conf.py +99 -0
  149. data/mlx/docs/src/cpp/ops.rst +7 -0
  150. data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
  151. data/mlx/docs/src/dev/extensions.rst +811 -0
  152. data/mlx/docs/src/dev/metal_debugger.rst +68 -0
  153. data/mlx/docs/src/dev/metal_logging.rst +40 -0
  154. data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
  155. data/mlx/docs/src/examples/data_parallelism.rst +91 -0
  156. data/mlx/docs/src/examples/linear_regression.rst +77 -0
  157. data/mlx/docs/src/examples/llama-inference.rst +382 -0
  158. data/mlx/docs/src/examples/mlp.rst +134 -0
  159. data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
  160. data/mlx/docs/src/index.rst +96 -0
  161. data/mlx/docs/src/install.rst +340 -0
  162. data/mlx/docs/src/python/array.rst +65 -0
  163. data/mlx/docs/src/python/cuda.rst +9 -0
  164. data/mlx/docs/src/python/data_types.rst +78 -0
  165. data/mlx/docs/src/python/devices_and_streams.rst +21 -0
  166. data/mlx/docs/src/python/distributed.rst +22 -0
  167. data/mlx/docs/src/python/export.rst +14 -0
  168. data/mlx/docs/src/python/fast.rst +16 -0
  169. data/mlx/docs/src/python/fft.rst +24 -0
  170. data/mlx/docs/src/python/linalg.rst +27 -0
  171. data/mlx/docs/src/python/memory_management.rst +16 -0
  172. data/mlx/docs/src/python/metal.rst +12 -0
  173. data/mlx/docs/src/python/nn/distributed.rst +30 -0
  174. data/mlx/docs/src/python/nn/functions.rst +40 -0
  175. data/mlx/docs/src/python/nn/init.rst +45 -0
  176. data/mlx/docs/src/python/nn/layers.rst +74 -0
  177. data/mlx/docs/src/python/nn/losses.rst +25 -0
  178. data/mlx/docs/src/python/nn/module.rst +38 -0
  179. data/mlx/docs/src/python/nn.rst +186 -0
  180. data/mlx/docs/src/python/ops.rst +184 -0
  181. data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
  182. data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
  183. data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
  184. data/mlx/docs/src/python/optimizers.rst +78 -0
  185. data/mlx/docs/src/python/random.rst +48 -0
  186. data/mlx/docs/src/python/transforms.rst +22 -0
  187. data/mlx/docs/src/python/tree_utils.rst +23 -0
  188. data/mlx/docs/src/usage/compile.rst +516 -0
  189. data/mlx/docs/src/usage/distributed.rst +572 -0
  190. data/mlx/docs/src/usage/export.rst +288 -0
  191. data/mlx/docs/src/usage/function_transforms.rst +191 -0
  192. data/mlx/docs/src/usage/indexing.rst +194 -0
  193. data/mlx/docs/src/usage/launching_distributed.rst +234 -0
  194. data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
  195. data/mlx/docs/src/usage/numpy.rst +124 -0
  196. data/mlx/docs/src/usage/quick_start.rst +67 -0
  197. data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
  198. data/mlx/docs/src/usage/unified_memory.rst +78 -0
  199. data/mlx/docs/src/usage/using_streams.rst +18 -0
  200. data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
  201. data/mlx/examples/cmake_project/README.md +26 -0
  202. data/mlx/examples/cmake_project/example.cpp +14 -0
  203. data/mlx/examples/cpp/CMakeLists.txt +12 -0
  204. data/mlx/examples/cpp/distributed.cpp +22 -0
  205. data/mlx/examples/cpp/linear_regression.cpp +54 -0
  206. data/mlx/examples/cpp/logistic_regression.cpp +54 -0
  207. data/mlx/examples/cpp/metal_capture.cpp +31 -0
  208. data/mlx/examples/cpp/timer.h +20 -0
  209. data/mlx/examples/cpp/tutorial.cpp +99 -0
  210. data/mlx/examples/export/CMakeLists.txt +22 -0
  211. data/mlx/examples/export/README.md +49 -0
  212. data/mlx/examples/export/eval_mlp.cpp +25 -0
  213. data/mlx/examples/export/eval_mlp.py +52 -0
  214. data/mlx/examples/export/train_mlp.cpp +35 -0
  215. data/mlx/examples/export/train_mlp.py +76 -0
  216. data/mlx/examples/extensions/CMakeLists.txt +78 -0
  217. data/mlx/examples/extensions/README.md +24 -0
  218. data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
  219. data/mlx/examples/extensions/axpby/axpby.h +90 -0
  220. data/mlx/examples/extensions/axpby/axpby.metal +47 -0
  221. data/mlx/examples/extensions/bindings.cpp +39 -0
  222. data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
  223. data/mlx/examples/extensions/pyproject.toml +8 -0
  224. data/mlx/examples/extensions/requirements.txt +4 -0
  225. data/mlx/examples/extensions/setup.py +18 -0
  226. data/mlx/examples/extensions/test.py +12 -0
  227. data/mlx/examples/python/linear_regression.py +46 -0
  228. data/mlx/examples/python/logistic_regression.py +49 -0
  229. data/mlx/examples/python/qqmm.py +117 -0
  230. data/mlx/mlx/3rdparty/.clang-format +2 -0
  231. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  232. data/mlx/mlx/CMakeLists.txt +107 -0
  233. data/mlx/mlx/allocator.h +75 -0
  234. data/mlx/mlx/api.h +29 -0
  235. data/mlx/mlx/array.cpp +354 -0
  236. data/mlx/mlx/array.h +647 -0
  237. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  238. data/mlx/mlx/backend/common/binary.h +97 -0
  239. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  240. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  241. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  242. data/mlx/mlx/backend/common/common.cpp +305 -0
  243. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  244. data/mlx/mlx/backend/common/compiled.h +77 -0
  245. data/mlx/mlx/backend/common/copy.h +50 -0
  246. data/mlx/mlx/backend/common/hadamard.h +109 -0
  247. data/mlx/mlx/backend/common/load.cpp +57 -0
  248. data/mlx/mlx/backend/common/matmul.h +67 -0
  249. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  250. data/mlx/mlx/backend/common/reduce.h +59 -0
  251. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  252. data/mlx/mlx/backend/common/slicing.h +20 -0
  253. data/mlx/mlx/backend/common/ternary.h +85 -0
  254. data/mlx/mlx/backend/common/unary.h +29 -0
  255. data/mlx/mlx/backend/common/utils.cpp +231 -0
  256. data/mlx/mlx/backend/common/utils.h +205 -0
  257. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  258. data/mlx/mlx/backend/cpu/arange.h +28 -0
  259. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  260. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  261. data/mlx/mlx/backend/cpu/binary.h +517 -0
  262. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  263. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  264. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  265. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  266. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  267. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  268. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  269. data/mlx/mlx/backend/cpu/copy.h +36 -0
  270. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  271. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  272. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  273. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  274. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  275. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  276. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  277. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  278. data/mlx/mlx/backend/cpu/eval.h +12 -0
  279. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  280. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  281. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  282. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  283. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  284. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  285. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  286. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  287. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  288. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  289. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  290. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  291. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  292. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  293. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  294. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  295. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  296. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  297. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  298. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  299. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  300. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  301. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  302. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  303. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  304. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  305. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  306. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  307. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  308. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  309. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  310. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  311. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  312. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  313. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  314. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  315. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  316. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  317. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  318. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  319. data/mlx/mlx/backend/cpu/unary.h +281 -0
  320. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  321. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  322. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  323. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  324. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  325. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  326. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  327. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  328. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  329. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  330. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  331. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  332. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  333. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  334. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  335. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  336. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  337. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  338. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  339. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  340. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  341. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  342. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  343. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  344. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  345. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  346. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  347. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  348. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  349. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  350. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  351. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  352. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  353. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  354. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  355. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  356. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  357. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  358. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  359. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  360. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  361. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  362. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  363. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  364. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  365. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  366. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  367. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  368. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  369. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  370. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  371. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  372. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  373. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  374. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  375. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  376. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  377. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  378. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  379. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  380. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  381. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  382. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  383. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  384. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  385. data/mlx/mlx/backend/cuda/device.h +195 -0
  386. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  387. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  388. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  389. data/mlx/mlx/backend/cuda/event.cu +415 -0
  390. data/mlx/mlx/backend/cuda/event.h +79 -0
  391. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  392. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  393. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  394. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  395. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  396. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  397. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  398. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  399. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  400. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  401. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  402. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  403. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  404. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  405. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  406. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  407. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  408. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  409. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  410. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  411. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  412. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  413. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  414. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  415. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  416. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  417. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  418. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  419. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  420. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  421. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  422. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  423. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  424. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  425. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  426. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  427. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  428. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  429. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  430. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  431. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  432. data/mlx/mlx/backend/cuda/random.cu +202 -0
  433. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  434. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  435. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  436. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  437. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  438. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  439. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  440. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  441. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  442. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  443. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  444. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  445. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  446. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  447. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  448. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  449. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  450. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  451. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  452. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  453. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  454. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  455. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  456. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  457. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  458. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  459. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  460. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  461. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  462. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  463. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  464. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  465. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  466. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  467. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  468. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  469. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  470. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  471. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  472. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  473. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  474. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  475. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  476. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  477. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  478. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  479. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  480. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  481. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  482. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  483. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  484. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  485. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  486. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  487. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  488. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  489. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  490. data/mlx/mlx/backend/cuda/utils.h +49 -0
  491. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  492. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  493. data/mlx/mlx/backend/cuda/worker.h +55 -0
  494. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  495. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  496. data/mlx/mlx/backend/gpu/copy.h +57 -0
  497. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  498. data/mlx/mlx/backend/gpu/eval.h +18 -0
  499. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  500. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  501. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  502. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  503. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  504. data/mlx/mlx/backend/metal/allocator.h +79 -0
  505. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  506. data/mlx/mlx/backend/metal/binary.h +33 -0
  507. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  508. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  509. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  510. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  511. data/mlx/mlx/backend/metal/device.cpp +816 -0
  512. data/mlx/mlx/backend/metal/device.h +289 -0
  513. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  514. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  515. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  516. data/mlx/mlx/backend/metal/event.cpp +62 -0
  517. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  518. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  519. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  520. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  521. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  522. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  523. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  524. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  525. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  526. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  527. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  528. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  529. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  530. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  531. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  532. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  533. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  534. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  535. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  536. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  537. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  538. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  539. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  540. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  541. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  542. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  543. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  544. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  545. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  546. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  547. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  548. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  549. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  550. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  551. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  552. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  553. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  554. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  555. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  556. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  557. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  558. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  559. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  560. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  561. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  562. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  563. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  564. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  565. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  566. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  567. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  568. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  569. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  570. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  571. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  572. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  573. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  574. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  575. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  576. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  577. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  578. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  579. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  580. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  581. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  582. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  583. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  584. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  585. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  586. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  587. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  588. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  589. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  590. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  591. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  592. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  593. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  594. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  595. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  596. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  597. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  598. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  599. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  600. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  601. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  602. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  603. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  604. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  605. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  606. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  607. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  608. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  609. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  610. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  611. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  612. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  613. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  614. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  615. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  616. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  617. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  618. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  619. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  620. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  621. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  622. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  623. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  624. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  625. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  626. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  627. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  628. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  629. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  630. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  631. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  632. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  633. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  634. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  635. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  636. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  637. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  638. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  639. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  640. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  641. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  642. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  643. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  644. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  645. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  646. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  647. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  648. data/mlx/mlx/backend/metal/kernels.h +375 -0
  649. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  650. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  651. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  652. data/mlx/mlx/backend/metal/matmul.h +144 -0
  653. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  654. data/mlx/mlx/backend/metal/metal.h +25 -0
  655. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  656. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  657. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  658. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  659. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  660. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  661. data/mlx/mlx/backend/metal/reduce.h +41 -0
  662. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  663. data/mlx/mlx/backend/metal/resident.h +32 -0
  664. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  665. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  666. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  667. data/mlx/mlx/backend/metal/scan.h +17 -0
  668. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  669. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  670. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  671. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  672. data/mlx/mlx/backend/metal/ternary.h +21 -0
  673. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  674. data/mlx/mlx/backend/metal/unary.h +21 -0
  675. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  676. data/mlx/mlx/backend/metal/utils.h +99 -0
  677. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  678. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  679. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  680. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  681. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  682. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  683. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  684. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  685. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  686. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  687. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  688. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  689. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  690. data/mlx/mlx/compile.cpp +1243 -0
  691. data/mlx/mlx/compile.h +45 -0
  692. data/mlx/mlx/compile_impl.h +70 -0
  693. data/mlx/mlx/device.cpp +72 -0
  694. data/mlx/mlx/device.h +56 -0
  695. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  696. data/mlx/mlx/distributed/distributed.cpp +197 -0
  697. data/mlx/mlx/distributed/distributed.h +61 -0
  698. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  699. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  700. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  701. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  702. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  703. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  704. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  705. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  706. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  707. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  708. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  709. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  710. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  711. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  712. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  713. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  714. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  715. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  716. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  717. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  718. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  719. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  720. data/mlx/mlx/distributed/ops.cpp +186 -0
  721. data/mlx/mlx/distributed/ops.h +57 -0
  722. data/mlx/mlx/distributed/primitives.cpp +95 -0
  723. data/mlx/mlx/distributed/primitives.h +156 -0
  724. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  725. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  726. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  727. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  728. data/mlx/mlx/distributed/ring/ring.h +12 -0
  729. data/mlx/mlx/distributed/utils.cpp +206 -0
  730. data/mlx/mlx/distributed/utils.h +67 -0
  731. data/mlx/mlx/dtype.cpp +197 -0
  732. data/mlx/mlx/dtype.h +116 -0
  733. data/mlx/mlx/dtype_utils.cpp +42 -0
  734. data/mlx/mlx/dtype_utils.h +119 -0
  735. data/mlx/mlx/einsum.cpp +941 -0
  736. data/mlx/mlx/einsum.h +23 -0
  737. data/mlx/mlx/event.h +58 -0
  738. data/mlx/mlx/export.cpp +1130 -0
  739. data/mlx/mlx/export.h +137 -0
  740. data/mlx/mlx/export_impl.h +99 -0
  741. data/mlx/mlx/fast.cpp +941 -0
  742. data/mlx/mlx/fast.h +103 -0
  743. data/mlx/mlx/fast_primitives.h +427 -0
  744. data/mlx/mlx/fence.h +39 -0
  745. data/mlx/mlx/fft.cpp +262 -0
  746. data/mlx/mlx/fft.h +159 -0
  747. data/mlx/mlx/graph_utils.cpp +175 -0
  748. data/mlx/mlx/graph_utils.h +67 -0
  749. data/mlx/mlx/io/CMakeLists.txt +25 -0
  750. data/mlx/mlx/io/gguf.cpp +470 -0
  751. data/mlx/mlx/io/gguf.h +20 -0
  752. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  753. data/mlx/mlx/io/load.cpp +397 -0
  754. data/mlx/mlx/io/load.h +175 -0
  755. data/mlx/mlx/io/no_gguf.cpp +20 -0
  756. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  757. data/mlx/mlx/io/safetensors.cpp +234 -0
  758. data/mlx/mlx/io.h +61 -0
  759. data/mlx/mlx/linalg.cpp +708 -0
  760. data/mlx/mlx/linalg.h +115 -0
  761. data/mlx/mlx/memory.h +80 -0
  762. data/mlx/mlx/mlx.h +25 -0
  763. data/mlx/mlx/ops.cpp +6094 -0
  764. data/mlx/mlx/ops.h +1610 -0
  765. data/mlx/mlx/primitives.cpp +5850 -0
  766. data/mlx/mlx/primitives.h +2525 -0
  767. data/mlx/mlx/random.cpp +492 -0
  768. data/mlx/mlx/random.h +283 -0
  769. data/mlx/mlx/scheduler.cpp +73 -0
  770. data/mlx/mlx/scheduler.h +189 -0
  771. data/mlx/mlx/small_vector.h +540 -0
  772. data/mlx/mlx/stream.h +42 -0
  773. data/mlx/mlx/threadpool.h +133 -0
  774. data/mlx/mlx/transforms.cpp +1065 -0
  775. data/mlx/mlx/transforms.h +231 -0
  776. data/mlx/mlx/transforms_impl.h +88 -0
  777. data/mlx/mlx/types/bf16.h +187 -0
  778. data/mlx/mlx/types/complex.h +113 -0
  779. data/mlx/mlx/types/fp16.h +234 -0
  780. data/mlx/mlx/types/half_types.h +58 -0
  781. data/mlx/mlx/types/limits.h +70 -0
  782. data/mlx/mlx/utils.cpp +302 -0
  783. data/mlx/mlx/utils.h +174 -0
  784. data/mlx/mlx/version.cpp +11 -0
  785. data/mlx/mlx/version.h +22 -0
  786. data/mlx/mlx.pc.in +52 -0
  787. data/mlx/pyproject.toml +7 -0
  788. data/mlx/python/mlx/__main__.py +27 -0
  789. data/mlx/python/mlx/_distributed_utils/common.py +135 -0
  790. data/mlx/python/mlx/_distributed_utils/config.py +631 -0
  791. data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
  792. data/mlx/python/mlx/_reprlib_fix.py +16 -0
  793. data/mlx/python/mlx/_stub_patterns.txt +36 -0
  794. data/mlx/python/mlx/extension.py +88 -0
  795. data/mlx/python/mlx/nn/__init__.py +5 -0
  796. data/mlx/python/mlx/nn/init.py +441 -0
  797. data/mlx/python/mlx/nn/layers/__init__.py +105 -0
  798. data/mlx/python/mlx/nn/layers/activations.py +661 -0
  799. data/mlx/python/mlx/nn/layers/base.py +675 -0
  800. data/mlx/python/mlx/nn/layers/containers.py +24 -0
  801. data/mlx/python/mlx/nn/layers/convolution.py +232 -0
  802. data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
  803. data/mlx/python/mlx/nn/layers/distributed.py +601 -0
  804. data/mlx/python/mlx/nn/layers/dropout.py +137 -0
  805. data/mlx/python/mlx/nn/layers/embedding.py +53 -0
  806. data/mlx/python/mlx/nn/layers/linear.py +180 -0
  807. data/mlx/python/mlx/nn/layers/normalization.py +363 -0
  808. data/mlx/python/mlx/nn/layers/pooling.py +398 -0
  809. data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
  810. data/mlx/python/mlx/nn/layers/quantized.py +426 -0
  811. data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
  812. data/mlx/python/mlx/nn/layers/transformer.py +354 -0
  813. data/mlx/python/mlx/nn/layers/upsample.py +277 -0
  814. data/mlx/python/mlx/nn/losses.py +610 -0
  815. data/mlx/python/mlx/nn/utils.py +165 -0
  816. data/mlx/python/mlx/optimizers/__init__.py +4 -0
  817. data/mlx/python/mlx/optimizers/optimizers.py +976 -0
  818. data/mlx/python/mlx/optimizers/schedulers.py +158 -0
  819. data/mlx/python/mlx/py.typed +1 -0
  820. data/mlx/python/mlx/utils.py +325 -0
  821. data/mlx/python/src/CMakeLists.txt +96 -0
  822. data/mlx/python/src/array.cpp +1525 -0
  823. data/mlx/python/src/buffer.h +124 -0
  824. data/mlx/python/src/constants.cpp +15 -0
  825. data/mlx/python/src/convert.cpp +504 -0
  826. data/mlx/python/src/convert.h +50 -0
  827. data/mlx/python/src/cuda.cpp +19 -0
  828. data/mlx/python/src/device.cpp +98 -0
  829. data/mlx/python/src/distributed.cpp +352 -0
  830. data/mlx/python/src/export.cpp +356 -0
  831. data/mlx/python/src/fast.cpp +627 -0
  832. data/mlx/python/src/fft.cpp +514 -0
  833. data/mlx/python/src/indexing.cpp +1016 -0
  834. data/mlx/python/src/indexing.h +41 -0
  835. data/mlx/python/src/linalg.cpp +663 -0
  836. data/mlx/python/src/load.cpp +531 -0
  837. data/mlx/python/src/load.h +51 -0
  838. data/mlx/python/src/memory.cpp +125 -0
  839. data/mlx/python/src/metal.cpp +98 -0
  840. data/mlx/python/src/mlx.cpp +51 -0
  841. data/mlx/python/src/mlx_func.cpp +116 -0
  842. data/mlx/python/src/mlx_func.h +31 -0
  843. data/mlx/python/src/ops.cpp +5545 -0
  844. data/mlx/python/src/random.cpp +516 -0
  845. data/mlx/python/src/small_vector.h +76 -0
  846. data/mlx/python/src/stream.cpp +147 -0
  847. data/mlx/python/src/transforms.cpp +1542 -0
  848. data/mlx/python/src/trees.cpp +311 -0
  849. data/mlx/python/src/trees.h +62 -0
  850. data/mlx/python/src/utils.cpp +98 -0
  851. data/mlx/python/src/utils.h +78 -0
  852. data/mlx/python/tests/__main__.py +5 -0
  853. data/mlx/python/tests/cuda_skip.py +62 -0
  854. data/mlx/python/tests/mlx_distributed_tests.py +314 -0
  855. data/mlx/python/tests/mlx_tests.py +116 -0
  856. data/mlx/python/tests/mpi_test_distributed.py +142 -0
  857. data/mlx/python/tests/nccl_test_distributed.py +52 -0
  858. data/mlx/python/tests/ring_test_distributed.py +131 -0
  859. data/mlx/python/tests/test_array.py +2139 -0
  860. data/mlx/python/tests/test_autograd.py +880 -0
  861. data/mlx/python/tests/test_bf16.py +196 -0
  862. data/mlx/python/tests/test_blas.py +1429 -0
  863. data/mlx/python/tests/test_compile.py +1277 -0
  864. data/mlx/python/tests/test_constants.py +41 -0
  865. data/mlx/python/tests/test_conv.py +1198 -0
  866. data/mlx/python/tests/test_conv_transpose.py +810 -0
  867. data/mlx/python/tests/test_device.py +150 -0
  868. data/mlx/python/tests/test_double.py +306 -0
  869. data/mlx/python/tests/test_einsum.py +363 -0
  870. data/mlx/python/tests/test_eval.py +200 -0
  871. data/mlx/python/tests/test_export_import.py +614 -0
  872. data/mlx/python/tests/test_fast.py +923 -0
  873. data/mlx/python/tests/test_fast_sdpa.py +647 -0
  874. data/mlx/python/tests/test_fft.py +323 -0
  875. data/mlx/python/tests/test_graph.py +37 -0
  876. data/mlx/python/tests/test_init.py +139 -0
  877. data/mlx/python/tests/test_linalg.py +621 -0
  878. data/mlx/python/tests/test_load.py +447 -0
  879. data/mlx/python/tests/test_losses.py +427 -0
  880. data/mlx/python/tests/test_memory.py +77 -0
  881. data/mlx/python/tests/test_nn.py +1986 -0
  882. data/mlx/python/tests/test_ops.py +3261 -0
  883. data/mlx/python/tests/test_optimizers.py +584 -0
  884. data/mlx/python/tests/test_quantized.py +1160 -0
  885. data/mlx/python/tests/test_random.py +392 -0
  886. data/mlx/python/tests/test_reduce.py +223 -0
  887. data/mlx/python/tests/test_tree.py +96 -0
  888. data/mlx/python/tests/test_upsample.py +100 -0
  889. data/mlx/python/tests/test_vmap.py +860 -0
  890. data/mlx/setup.py +315 -0
  891. data/mlx/tests/CMakeLists.txt +44 -0
  892. data/mlx/tests/allocator_tests.cpp +41 -0
  893. data/mlx/tests/arg_reduce_tests.cpp +204 -0
  894. data/mlx/tests/array_tests.cpp +663 -0
  895. data/mlx/tests/autograd_tests.cpp +1399 -0
  896. data/mlx/tests/blas_tests.cpp +110 -0
  897. data/mlx/tests/compile_tests.cpp +818 -0
  898. data/mlx/tests/creations_tests.cpp +239 -0
  899. data/mlx/tests/custom_vjp_tests.cpp +55 -0
  900. data/mlx/tests/device_tests.cpp +35 -0
  901. data/mlx/tests/einsum_tests.cpp +85 -0
  902. data/mlx/tests/eval_tests.cpp +93 -0
  903. data/mlx/tests/export_import_tests.cpp +164 -0
  904. data/mlx/tests/fft_tests.cpp +366 -0
  905. data/mlx/tests/gpu_tests.cpp +523 -0
  906. data/mlx/tests/linalg_tests.cpp +639 -0
  907. data/mlx/tests/load_tests.cpp +270 -0
  908. data/mlx/tests/ops_tests.cpp +4159 -0
  909. data/mlx/tests/random_tests.cpp +716 -0
  910. data/mlx/tests/scheduler_tests.cpp +121 -0
  911. data/mlx/tests/tests.cpp +26 -0
  912. data/mlx/tests/utils_tests.cpp +67 -0
  913. data/mlx/tests/vmap_tests.cpp +547 -0
  914. metadata +958 -0
data/lib/mlx/core.rb ADDED
@@ -0,0 +1,1678 @@
1
+ # frozen_string_literal: true
2
+
3
+ require "open3"
4
+ require "tmpdir"
5
+
6
+ module MLX
7
+ module Core
8
+ class NativeUnavailableError < StandardError; end
9
+
10
+ module DeviceType
11
+ module_function
12
+
13
+ def cpu
14
+ :cpu
15
+ end
16
+
17
+ def gpu
18
+ :gpu
19
+ end
20
+ end
21
+
22
+ class Finfo
23
+ FLOAT_INFO = {
24
+ "float16" => { min: -65_504.0, max: 65_504.0, eps: 9.765625e-4 },
25
+ "bfloat16" => { min: -3.389531389e38, max: 3.389531389e38, eps: 7.8125e-3 },
26
+ "float32" => { min: -3.4028235e38, max: 3.4028235e38, eps: 1.1920929e-7 },
27
+ "float64" => { min: -Float::MAX, max: Float::MAX, eps: Float::EPSILON },
28
+ "complex64" => { min: -3.4028235e38, max: 3.4028235e38, eps: 1.1920929e-7 }
29
+ }.freeze
30
+
31
+ attr_reader :dtype, :min, :max, :eps
32
+
33
+ def initialize(dtype)
34
+ @dtype = dtype
35
+ info = FLOAT_INFO[dtype_name(dtype)]
36
+ raise ArgumentError, "unsupported dtype for finfo: #{dtype_name(dtype)}" if info.nil?
37
+
38
+ @min = info[:min]
39
+ @max = info[:max]
40
+ @eps = info[:eps]
41
+ end
42
+
43
+ private
44
+
45
+ def dtype_name(dtype)
46
+ if dtype.respond_to?(:name)
47
+ dtype.name.to_s
48
+ else
49
+ dtype.to_s
50
+ end
51
+ end
52
+ end
53
+
54
+ class Iinfo
55
+ INT_INFO = {
56
+ "bool_" => { min: 0, max: 1 },
57
+ "uint8" => { min: 0, max: 255 },
58
+ "uint16" => { min: 0, max: 65_535 },
59
+ "uint32" => { min: 0, max: 4_294_967_295 },
60
+ "uint64" => { min: 0, max: 18_446_744_073_709_551_615 },
61
+ "int8" => { min: -128, max: 127 },
62
+ "int16" => { min: -32_768, max: 32_767 },
63
+ "int32" => { min: -2_147_483_648, max: 2_147_483_647 },
64
+ "int64" => { min: -9_223_372_036_854_775_808, max: 9_223_372_036_854_775_807 }
65
+ }.freeze
66
+
67
+ attr_reader :dtype, :min, :max
68
+
69
+ def initialize(dtype)
70
+ @dtype = dtype
71
+ info = INT_INFO[dtype_name(dtype)]
72
+ raise ArgumentError, "unsupported dtype for iinfo: #{dtype_name(dtype)}" if info.nil?
73
+
74
+ @min = info[:min]
75
+ @max = info[:max]
76
+ end
77
+
78
+ private
79
+
80
+ def dtype_name(dtype)
81
+ if dtype.respond_to?(:name)
82
+ dtype.name.to_s
83
+ else
84
+ dtype.to_s
85
+ end
86
+ end
87
+ end
88
+
89
+ class ArrayLike
90
+ attr_reader :object
91
+
92
+ def initialize(object)
93
+ unless object.respond_to?(:__mlx__array__)
94
+ raise TypeError, "ArrayLike requires an object that responds to __mlx__array__"
95
+ end
96
+ @object = object
97
+ end
98
+
99
+ def to_a
100
+ out = @object.__mlx__array__
101
+ raise TypeError, "__mlx__array__ must return MLX::Core::Array" unless out.is_a?(MLX::Core::Array)
102
+
103
+ out
104
+ end
105
+ end
106
+
107
+ class ArrayIterator
108
+ def initialize(array)
109
+ @array = array
110
+ @index = 0
111
+ end
112
+
113
+ def __iter__
114
+ self
115
+ end
116
+
117
+ def __next__
118
+ raise StopIteration if @index >= @array.__len__
119
+
120
+ out = @array.__getitem__(@index)
121
+ @index += 1
122
+ out
123
+ end
124
+
125
+ alias next __next__
126
+ end
127
+
128
+ class ArrayAt
129
+ def initialize(array)
130
+ @array = array
131
+ @indices = nil
132
+ end
133
+
134
+ def [](indices)
135
+ @indices = indices
136
+ self
137
+ end
138
+
139
+ def add(value)
140
+ apply(value) { |lhs, rhs| MLX::Core.add(lhs, rhs) }
141
+ end
142
+
143
+ def subtract(value)
144
+ apply(value) { |lhs, rhs| MLX::Core.subtract(lhs, rhs) }
145
+ end
146
+
147
+ def multiply(value)
148
+ apply(value) { |lhs, rhs| MLX::Core.multiply(lhs, rhs) }
149
+ end
150
+
151
+ def divide(value)
152
+ apply(value) { |lhs, rhs| MLX::Core.divide(lhs, rhs) }
153
+ end
154
+
155
+ def maximum(value)
156
+ apply(value) { |lhs, rhs| MLX::Core.maximum(lhs, rhs) }
157
+ end
158
+
159
+ def minimum(value)
160
+ apply(value) { |lhs, rhs| MLX::Core.minimum(lhs, rhs) }
161
+ end
162
+
163
+ private
164
+
165
+ def apply(value)
166
+ raise ArgumentError, "must provide indices to array.at first" if @indices.nil?
167
+
168
+ current = @array.__getitem__(@indices)
169
+ rhs = value.is_a?(MLX::Core::Array) ? value : MLX::Core.array(value, current.dtype)
170
+ updated = yield(current, rhs)
171
+ @array.__setitem__(@indices, updated)
172
+ end
173
+ end
174
+
175
+ class DLPackCapsule
176
+ attr_reader :array, :dtype, :shape, :device, :stream
177
+
178
+ def initialize(array, device:, stream: nil)
179
+ unless array.is_a?(MLX::Core::Array)
180
+ raise TypeError, "DLPackCapsule requires an MLX::Core::Array"
181
+ end
182
+
183
+ @array = array
184
+ @dtype = array.dtype
185
+ @shape = array.shape.dup.freeze
186
+ @device = device.dup.freeze
187
+ @stream = stream
188
+ end
189
+
190
+ def to_h
191
+ {
192
+ "dtype" => (dtype.respond_to?(:name) ? dtype.name.to_s : dtype.to_s),
193
+ "shape" => shape,
194
+ "device" => device,
195
+ "stream" => stream
196
+ }
197
+ end
198
+ end

    class CustomFunction
      def initialize(fun)
        raise TypeError, "expected callable object" unless fun.respond_to?(:call)

        @fun = fun
        @vjp = nil
        @jvp = nil
        @vmap = nil
      end

      def call(*args, **kwargs, &block)
        @fun.call(*args, **kwargs, &block)
      end

      def vjp(fun = nil, &block)
        @vjp = fun || block
        raise ArgumentError, "expected callable object" unless @vjp.respond_to?(:call)

        @vjp
      end

      def jvp(fun = nil, &block)
        @jvp = fun || block
        raise ArgumentError, "expected callable object" unless @jvp.respond_to?(:call)

        @jvp
      end

      def vmap(fun = nil, &block)
        @vmap = fun || block
        raise ArgumentError, "expected callable object" unless @vmap.respond_to?(:call)

        @vmap
      end

      def custom_vjp?
        !@vjp.nil?
      end

      def custom_jvp?
        !@jvp.nil?
      end

      def custom_vmap?
        !@vmap.nil?
      end

      def call_custom_vjp(primals, cotangents, outputs)
        raise ArgumentError, "custom vjp is not defined" unless custom_vjp?

        @vjp.call(primals, cotangents, outputs)
      end

      def call_custom_jvp(primals, tangents)
        raise ArgumentError, "custom jvp is not defined" unless custom_jvp?

        @jvp.call(primals, tangents)
      end

      def call_custom_vmap(inputs, axes)
        raise ArgumentError, "custom vmap is not defined" unless custom_vmap?

        @vmap.call(inputs, axes)
      end
    end
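
    # Editorial usage sketch (not part of core.rb): a CustomFunction wraps a
    # callable and lets the caller attach vjp/jvp/vmap overrides, which
    # MLX::Core.grad, .jvp, .vjp and .vmap pick up instead of the native
    # transforms. Assumes MLX::Core.multiply accepts a scalar operand.
    #
    #   square = MLX::Core.custom_function(->(x) { MLX::Core.multiply(x, x) })
    #   square.vjp do |primals, cotangents, _outputs|
    #     MLX::Core.multiply(MLX::Core.multiply(primals[0], 2), cotangents[0])
    #   end
    #   dydx = MLX::Core.grad(square).call(MLX::Core.array([3.0]))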

    class StreamContext
      def initialize(target)
        @target = target
        @previous_device = nil
        @previous_stream = nil
      end

      def enter
        @previous_device = MLX::Core.default_device
        @previous_stream = MLX::Core.default_stream(@previous_device)
        MLX::Core.native_stream(@target)
        self
      end

      def exit(*)
        return self if @previous_device.nil?

        MLX::Core.set_default_device(@previous_device)
        MLX::Core.set_default_stream(@previous_stream)
        @previous_device = nil
        @previous_stream = nil
        self
      end
    end
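
    # Editorial usage sketch (not part of core.rb): MLX::Core.stream returns a
    # StreamContext when no block is given, so work can be scoped to a
    # device/stream manually and the previous defaults restored afterwards.
    #
    #   ctx = MLX::Core.stream(MLX::Core.default_device).enter
    #   begin
    #     # ... build arrays here ...
    #   ensure
    #     ctx.exit
    #   end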

    # The .npz helpers shell out to the system Python: an .npz file is a zip
    # archive of .npy entries, so zipfile is enough on that side.
    PY_EXTRACT_NPZ = <<~PY.freeze
      import os, sys, zipfile
      src = sys.argv[1]
      out_dir = sys.argv[2]
      with zipfile.ZipFile(src, "r") as zf:
          zf.extractall(out_dir)
    PY

    PY_BUILD_NPZ = <<~PY.freeze
      import os, sys, zipfile
      out_path = sys.argv[1]
      in_dir = sys.argv[2]
      compressed = sys.argv[3] == "1"
      mode = zipfile.ZIP_DEFLATED if compressed else zipfile.ZIP_STORED
      with zipfile.ZipFile(out_path, "w", compression=mode, allowZip64=True) as zf:
          for name in sorted(os.listdir(in_dir)):
              zf.write(os.path.join(in_dir, name), arcname=name)
    PY

    module_function

    def ensure_native!
      return if MLX.native_available?

      raise NativeUnavailableError,
            "MLX native extension is unavailable. Build ext/mlx first."
    end

    def available?
      MLX.native_available?
    end

    class << self
      # Capture the native singleton methods before the Ruby wrappers below
      # redefine them under the same names.
      alias_method :native_load, :load if method_defined?(:load)
      alias_method :native_grad, :grad if method_defined?(:grad) && !method_defined?(:native_grad)
      alias_method :native_value_and_grad,
                   :value_and_grad if method_defined?(:value_and_grad) && !method_defined?(:native_value_and_grad)
      alias_method :native_compile, :compile if method_defined?(:compile) && !method_defined?(:native_compile)
      alias_method :native_checkpoint,
                   :checkpoint if method_defined?(:checkpoint) && !method_defined?(:native_checkpoint)
      alias_method :native_stream, :stream if method_defined?(:stream) && !method_defined?(:native_stream)
      alias_method :native_jvp, :jvp if method_defined?(:jvp) && !method_defined?(:native_jvp)
      alias_method :native_vjp, :vjp if method_defined?(:vjp) && !method_defined?(:native_vjp)
      alias_method :native_vmap, :vmap if method_defined?(:vmap) && !method_defined?(:native_vmap)
      alias_method :native_export_to_dot,
                   :export_to_dot if method_defined?(:export_to_dot) && !method_defined?(:native_export_to_dot)

      ARRAY_LEAF = :__mlx_array_leaf__

      def load(file, format = nil, return_metadata = false)
        ensure_native!
        format_name = (format || infer_format(file)).to_s
        if format_name == "npz"
          raise ArgumentError, "metadata not supported for format npz" if return_metadata

          return load_npz(file)
        end

        native_load(file, format, return_metadata)
      end

      def savez(file, *args, **kwargs)
        ensure_native!
        save_npz(file, args, kwargs, false)
      end

      def savez_compressed(file, *args, **kwargs)
        ensure_native!
        save_npz(file, args, kwargs, true)
      end
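
      # Editorial usage sketch (not part of core.rb): .npz support round-trips
      # through the Python zipfile helpers above, one .npy entry per array.
      #
      #   MLX::Core.savez("weights.npz", w, b: bias)   # arr_0.npy and b.npy inside the archive
      #   data = MLX::Core.load("weights.npz")         # => { "arr_0" => ..., "b" => ... }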

      def export_to_dot(target, *outputs)
        ensure_native!
        raise ArgumentError, "export_to_dot expects at least one output" if outputs.empty?

        if target.respond_to?(:write)
          Dir.mktmpdir do |dir|
            path = File.join(dir, "graph.dot")
            native_export_to_dot(path, *outputs)
            content = File.binread(path)
            target.write(content)
            target.rewind if target.respond_to?(:rewind)
            content
          end
        else
          native_export_to_dot(target, *outputs)
        end
      end

      def full_like(array, fill_value, dtype = nil)
        ensure_native!
        raise TypeError, "full_like expects an MLX::Core::Array" unless array.is_a?(MLX::Core::Array)

        target_dtype = dtype || array.dtype
        full(array.shape, fill_value, target_dtype)
      end

      def grad(fun, argnums = nil, argnames = nil)
        ensure_native!
        if fun.is_a?(CustomFunction) && fun.custom_vjp?
          return build_custom_vjp_grad_function(fun)
        end

        argnums_v, argnames_v = normalize_diff_targets(argnums, argnames)
        build_grad_like_function(fun, argnums_v, argnames_v, false)
      end

      def value_and_grad(fun, argnums = nil, argnames = nil)
        ensure_native!
        if fun.is_a?(CustomFunction) && fun.custom_vjp?
          return build_custom_vjp_value_and_grad_function(fun)
        end

        argnums_v, argnames_v = normalize_diff_targets(argnums, argnames)
        build_grad_like_function(fun, argnums_v, argnames_v, true)
      end
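
      # Editorial usage sketch (not part of core.rb): grad/value_and_grad take
      # positional argnums and keyword argnames, defaulting to argument 0.
      #
      #   loss = ->(w, x) { MLX::Core.sum(MLX::Core.multiply(w, x)) }
      #   dw = MLX::Core.grad(loss).call(w, x)                    # d loss / d w
      #   val, dw = MLX::Core.value_and_grad(loss, 0).call(w, x)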

      def compile(fun, inputs = nil, outputs = nil, shapeless = false)
        ensure_native!
        cache = {}

        lambda do |*args, **kwargs|
          flat_inputs = []
          input_spec = flatten_tree_spec([args, kwargs], flat_inputs, false)
          key = structure_cache_key(input_spec)

          entry = cache[key]
          unless entry
            output_spec = nil
            lifted = lambda do |*flat_vars|
              rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0)
              unless cursor == flat_vars.length
                raise RuntimeError, "internal input reconstruction mismatch"
              end

              call_args = rebuilt[0]
              call_kwargs = rebuilt[1]
              raw_output = fun.call(*call_args, **call_kwargs)

              flat_output = []
              output_spec = flatten_tree_spec(raw_output, flat_output, false)
              flat_output
            end

            compiled = native_compile(lifted, inputs, outputs, shapeless)
            entry = { fn: compiled, output_spec: -> { output_spec } }
            cache[key] = entry
          end

          flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "compiled output")
          spec = entry[:output_spec].call
          raise RuntimeError, "missing output structure from compiled function" if spec.nil?

          rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0)
          unless cursor == flat_output.length
            raise RuntimeError, "internal output reconstruction mismatch"
          end
          rebuilt
        end
      end
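
      # Editorial usage sketch (not part of core.rb): compiled callables are
      # cached per input tree structure, so arrays can change value without
      # retracing.
      #
      #   softplus = MLX::Core.compile(->(x) { MLX::Core.log1p(MLX::Core.exp(x)) })
      #   y = softplus.call(MLX::Core.array([0.5, 1.5]))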

      def checkpoint(fun)
        ensure_native!
        cache = {}

        lambda do |*args, **kwargs|
          flat_inputs = []
          input_spec = flatten_tree_spec([args, kwargs], flat_inputs, false)
          key = structure_cache_key(input_spec)

          entry = cache[key]
          unless entry
            output_spec = nil
            lifted = lambda do |*flat_vars|
              rebuilt, cursor = inflate_tree_from_arrays(input_spec, flat_vars, 0)
              unless cursor == flat_vars.length
                raise RuntimeError, "internal input reconstruction mismatch"
              end

              call_args = rebuilt[0]
              call_kwargs = rebuilt[1]
              raw_output = fun.call(*call_args, **call_kwargs)

              flat_output = []
              output_spec = flatten_tree_spec(raw_output, flat_output, false)
              flat_output
            end

            checkpointed = native_checkpoint(lifted)
            entry = { fn: checkpointed, output_spec: -> { output_spec } }
            cache[key] = entry
          end

          flat_output = normalize_array_sequence(entry[:fn].call(*flat_inputs), "checkpoint output")
          spec = entry[:output_spec].call
          raise RuntimeError, "missing output structure from checkpoint function" if spec.nil?

          rebuilt, cursor = inflate_tree_from_arrays(spec, flat_output, 0)
          unless cursor == flat_output.length
            raise RuntimeError, "internal output reconstruction mismatch"
          end
          rebuilt
        end
      end

      def stream(stream_or_device, &block)
        ensure_native!
        if block_given?
          native_stream(stream_or_device, &block)
        else
          StreamContext.new(stream_or_device)
        end
      end

      def jvp(fun, primals, tangents)
        ensure_native!
        if fun.is_a?(CustomFunction) && fun.custom_jvp?
          return custom_jvp(fun, primals, tangents)
        end
        native_jvp(fun, primals, tangents)
      end

      def vjp(fun, primals, cotangents)
        ensure_native!
        if fun.is_a?(CustomFunction) && fun.custom_vjp?
          return custom_vjp(fun, primals, cotangents)
        end
        native_vjp(fun, primals, cotangents)
      end

      def vmap(fun, in_axes = nil, out_axes = nil)
        ensure_native!
        if fun.is_a?(CustomFunction) && fun.custom_vmap?
          return custom_vmap_callable(fun, in_axes, out_axes)
        end
        native_vmap(fun, in_axes, out_axes)
      end

      def custom_function(fun = nil, &block)
        callable = fun || block
        raise ArgumentError, "custom_function requires a callable" if callable.nil?

        CustomFunction.new(callable)
      end

      def finfo(dtype)
        Finfo.new(dtype)
      end

      def iinfo(dtype)
        Iinfo.new(dtype)
      end

      def from_dlpack(dlpack_value)
        case dlpack_value
        when MLX::Core::DLPackCapsule
          dlpack_value.array
        when MLX::Core::Array
          dlpack_value
        else
          raise TypeError, "from_dlpack expects MLX::Core::DLPackCapsule or MLX::Core::Array"
        end
      end
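
      # Editorial usage sketch (not part of core.rb): the DLPack hooks here are
      # a thin capsule around the array itself rather than a cross-library
      # exchange, so from_dlpack simply unwraps what __dlpack__ produced.
      #
      #   capsule = MLX::Core.array([1, 2, 3]).__dlpack__
      #   same = MLX::Core.from_dlpack(capsule)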

      private

      def infer_format(file)
        path = file_path(file)
        ext = File.extname(path).delete_prefix(".")
        raise ArgumentError, "could not infer load format from file extension" if ext.empty?

        ext
      end

      def file_path(file)
        if file.respond_to?(:to_path)
          file.to_path.to_s
        else
          file.to_s
        end
      end

      def python_bin
        ENV.fetch("PYTHON", "python3")
      end

      def run_python!(*argv)
        stdout, stderr, status = Open3.capture3(*argv)
        return if status.success?

        raise RuntimeError, <<~MSG
          python command failed: #{argv.join(" ")}
          stdout:
          #{stdout}
          stderr:
          #{stderr}
        MSG
      end

      def load_npz(file)
        path = file_path(file)
        Dir.mktmpdir("mlx-ruby-npz-load") do |dir|
          run_python!(python_bin, "-c", PY_EXTRACT_NPZ, path, dir)
          out = {}
          Dir.glob(File.join(dir, "**", "*.npy")).sort.each do |npy_path|
            rel = npy_path.delete_prefix(dir + File::SEPARATOR)
            key = rel.end_with?(".npy") ? rel[0...-4] : rel
            out[key] = native_load(npy_path, "npy", false)
          end
          out
        end
      end

      def save_npz(file, args, kwargs, compressed)
        path = file_path(file)
        path = "#{path}.npz" unless path.end_with?(".npz")

        arrays = kwargs.transform_keys(&:to_s)
        args.each_with_index do |value, i|
          key = "arr_#{i}"
          if arrays.key?(key)
            raise ArgumentError, "Cannot use un-named variables and keyword #{key}"
          end
          arrays[key] = value
        end

        Dir.mktmpdir("mlx-ruby-npz-save") do |dir|
          arrays.each do |name, value|
            array_value = value.is_a?(MLX::Core::Array) ? value : MLX::Core.array(value)
            save(File.join(dir, "#{name}.npy"), array_value)
          end
          run_python!(python_bin, "-c", PY_BUILD_NPZ, path, dir, compressed ? "1" : "0")
        end

        nil
      end

      def normalize_diff_targets(argnums, argnames)
        argnames_v = normalize_argnames(argnames)
        argnums_v = normalize_argnums(argnums, argnames_v)
        if argnums_v.empty? && argnames_v.empty?
          raise ArgumentError, "Gradient wrt no argument requested"
        end
        [argnums_v, argnames_v]
      end

      def normalize_argnums(argnums, argnames)
        if argnums.nil?
          return argnames.empty? ? [0] : []
        end
        values = if argnums.is_a?(::Integer)
          [argnums]
        elsif argnums.is_a?(::Array)
          argnums
        else
          raise TypeError, "argnums must be an Integer, an Array of Integer, or nil"
        end
        out = values.map do |value|
          raise TypeError, "argnums entries must be Integer" unless value.is_a?(::Integer)
          raise ArgumentError, "argnums cannot contain negative values" if value.negative?
          value
        end
        raise ArgumentError, "duplicate argnums are not allowed" if out.uniq.length != out.length

        out
      end

      def normalize_argnames(argnames)
        return [] if argnames.nil?
        values = if argnames.is_a?(::String) || argnames.is_a?(::Symbol)
          [argnames]
        elsif argnames.is_a?(::Array)
          argnames
        else
          raise TypeError, "argnames must be a String, Symbol, Array, or nil"
        end
        out = values.map(&:to_s)
        raise ArgumentError, "duplicate argnames are not allowed" if out.uniq.length != out.length

        out
      end

      def build_grad_like_function(fun, argnums, argnames, with_value)
        lambda do |*args, **kwargs|
          selections, flat_inputs = build_target_selections(args, kwargs, argnums, argnames)
          native_argnums = (0...flat_inputs.length).to_a
          captured_value = nil
          lifted = lambda do |*flat_vars|
            call_args, call_kwargs = apply_flat_vars_to_targets(args, kwargs, selections, flat_vars)
            raw_value = fun.call(*call_args, **call_kwargs)
            captured_value = raw_value
            extract_loss(raw_value)
          end

          if with_value
            native_fn = native_value_and_grad(lifted, native_argnums)
            _loss, raw_grads = native_fn.call(*flat_inputs)
            value = captured_value.nil? ? fun.call(*args, **kwargs) : captured_value
            [value, rebuild_grad_result(raw_grads, selections, argnames)]
          else
            native_fn = native_grad(lifted, native_argnums)
            raw_grads = native_fn.call(*flat_inputs)
            rebuild_grad_result(raw_grads, selections, argnames)
          end
        end
      end

      def build_custom_vjp_grad_function(fun)
        lambda do |*args, **kwargs|
          unless kwargs.empty?
            raise ArgumentError, "custom-function grad currently supports positional arguments only"
          end
          outputs = normalize_array_output(fun.call(*args), "custom_function output")
          cotangents = outputs.map { |out| MLX::Core.ones_like(out) }
          output_arg = outputs.length == 1 ? outputs[0] : outputs
          grads = normalize_array_output(
            fun.call_custom_vjp(args, cotangents, output_arg),
            "custom_function vjp output"
          )
          grads.length == 1 ? grads[0] : grads
        end
      end

      def build_custom_vjp_value_and_grad_function(fun)
        grad_fn = build_custom_vjp_grad_function(fun)
        lambda do |*args, **kwargs|
          value = fun.call(*args, **kwargs)
          [value, grad_fn.call(*args, **kwargs)]
        end
      end

      def custom_jvp(fun, primals, tangents)
        primals_list = normalize_array_output(primals, "primals")
        tangents_list = normalize_array_output(tangents, "tangents")
        outputs = normalize_array_output(fun.call(*primals_list), "custom_function output")
        jvps = normalize_array_output(
          fun.call_custom_jvp(primals_list, tangents_list),
          "custom_function jvp output"
        )
        [outputs, jvps]
      end

      def custom_vjp(fun, primals, cotangents)
        primals_list = normalize_array_output(primals, "primals")
        cotangents_list = normalize_array_output(cotangents, "cotangents")
        outputs = normalize_array_output(fun.call(*primals_list), "custom_function output")
        output_arg = outputs.length == 1 ? outputs[0] : outputs
        vjps = normalize_array_output(
          fun.call_custom_vjp(primals_list, cotangents_list, output_arg),
          "custom_function vjp output"
        )
        [outputs, vjps]
      end

      def custom_vmap_callable(fun, in_axes, _out_axes)
        lambda do |*args|
          input_axes = if in_axes.nil?
            ::Array.new(args.length, 0)
          elsif in_axes.is_a?(::Integer)
            ::Array.new(args.length, in_axes)
          elsif in_axes.is_a?(::Array)
            in_axes
          else
            raise TypeError, "in_axes must be Integer, Array, or nil"
          end
          out = fun.call_custom_vmap(args, input_axes)
          if out.is_a?(::Array) && out.length == 2
            out[0]
          else
            out
          end
        end
      end

      def extract_loss(output)
        return output if output.is_a?(MLX::Core::Array)

        if output.is_a?(::Array) && !output.empty? && output[0].is_a?(MLX::Core::Array)
          return output[0]
        end

        raise ArgumentError,
              "function must return an MLX::Core::Array or an Array whose first element is an MLX::Core::Array"
      end

      def build_target_selections(args, kwargs, argnums, argnames)
        positional = []
        keyword = []
        flat_inputs = []

        argnums.each do |index|
          if index >= args.length
            raise ArgumentError,
                  "Can't compute gradient for positional argument #{index} when #{args.length} positional arguments were provided"
          end
          spec = flatten_tree_spec(args[index], flat_inputs, true)
          positional << { index: index, spec: spec }
        end

        argnames.each do |name|
          key = kwarg_key_for_name(kwargs, name)
          unless key
            raise ArgumentError,
                  "Can't compute gradient for keyword argument '#{name}' because it was not provided"
          end
          spec = flatten_tree_spec(kwargs[key], flat_inputs, true)
          keyword << { key: key, name: name, spec: spec }
        end

        [{ positional: positional, keyword: keyword }, flat_inputs]
      end

      def flatten_tree_spec(value, arrays, strict_arrays)
        if value.is_a?(MLX::Core::Array)
          arrays << value
          return ARRAY_LEAF
        end
        if value.is_a?(::Array)
          return [:array, value.map { |item| flatten_tree_spec(item, arrays, strict_arrays) }]
        end
        if value.is_a?(::Hash)
          return [:hash, value.map { |k, v| [k, flatten_tree_spec(v, arrays, strict_arrays)] }]
        end
        if strict_arrays
          raise TypeError, "[tree_flatten] The argument should contain only arrays"
        end
        if value.nil? || value.is_a?(::Numeric) || value.is_a?(::String) ||
           value.is_a?(::Symbol) || value == true || value == false
          return [:const, value]
        end
        raise TypeError,
              "[compile] Function arguments and outputs must be trees of arrays or constants (Numeric, String, Symbol, true/false, nil)"
      end

      def structure_cache_key(spec)
        return "A" if spec == ARRAY_LEAF

        tag, payload = spec
        case tag
        when :array
          "L[#{payload.map { |entry| structure_cache_key(entry) }.join(",")}]"
        when :hash
          pairs = payload.map do |key, child|
            "#{key.inspect}:#{structure_cache_key(child)}"
          end
          "H{#{pairs.join(",")}}"
        when :const
          "C(#{payload.class}:#{payload.inspect})"
        else
          raise ArgumentError, "invalid tree specification"
        end
      end
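
      # Editorial sketch (not part of core.rb): the spec produced by
      # flatten_tree_spec is either ARRAY_LEAF or an [:array|:hash|:const, ...]
      # pair, and structure_cache_key serialises it; for debugging it can be
      # reached through send, e.g.
      #
      #   leaves = []
      #   spec = MLX::Core.send(:flatten_tree_spec, [MLX::Core.array([1.0]), 2], leaves, false)
      #   MLX::Core.send(:structure_cache_key, spec)   # => "L[A,C(Integer:2)]"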

      def inflate_tree_from_arrays(spec, arrays, cursor)
        return [arrays.fetch(cursor), cursor + 1] if spec == ARRAY_LEAF

        tag, payload = spec
        case tag
        when :array
          out = []
          payload.each do |child_spec|
            item, cursor = inflate_tree_from_arrays(child_spec, arrays, cursor)
            out << item
          end
          [out, cursor]
        when :hash
          out = {}
          payload.each do |key, child_spec|
            item, cursor = inflate_tree_from_arrays(child_spec, arrays, cursor)
            out[key] = item
          end
          [out, cursor]
        when :const
          [payload, cursor]
        else
          raise ArgumentError, "invalid tree specification"
        end
      end

      def normalize_raw_grads(raw)
        normalize_array_sequence(raw, "gradient")
      end

      def normalize_array_sequence(raw, context)
        return [raw] if raw.is_a?(MLX::Core::Array)

        if raw.is_a?(::Array) && raw.all? { |item| item.is_a?(MLX::Core::Array) }
          return raw
        end
        raise TypeError, "unexpected #{context} return type"
      end

      def normalize_array_output(raw, context)
        if raw.is_a?(MLX::Core::Array)
          [raw]
        elsif raw.is_a?(::Array) && raw.all? { |item| item.is_a?(MLX::Core::Array) }
          raw
        else
          raise TypeError, "unexpected #{context} type"
        end
      end

      def rebuild_grad_result(raw_grads, selections, argnames)
        grad_arrays = normalize_raw_grads(raw_grads)
        cursor = 0

        positional_grads = selections[:positional].map do |entry|
          value, cursor = inflate_tree_from_arrays(entry[:spec], grad_arrays, cursor)
          value
        end
        keyword_grads = {}
        selections[:keyword].each do |entry|
          value, cursor = inflate_tree_from_arrays(entry[:spec], grad_arrays, cursor)
          keyword_grads[entry[:name]] = value
        end
        unless cursor == grad_arrays.length
          raise RuntimeError, "internal gradient reconstruction mismatch"
        end

        if argnames.empty?
          return positional_grads[0] if positional_grads.length == 1
          return positional_grads
        end

        positional_out = if positional_grads.empty?
          nil
        elsif positional_grads.length == 1
          positional_grads[0]
        else
          positional_grads
        end
        [positional_out, keyword_grads]
      end

      def apply_flat_vars_to_targets(args, kwargs, selections, flat_vars)
        rebuilt_args = args.dup
        rebuilt_kwargs = kwargs.dup
        cursor = 0

        selections[:positional].each do |entry|
          value, cursor = inflate_tree_from_arrays(entry[:spec], flat_vars, cursor)
          rebuilt_args[entry[:index]] = value
        end

        selections[:keyword].each do |entry|
          value, cursor = inflate_tree_from_arrays(entry[:spec], flat_vars, cursor)
          rebuilt_kwargs[entry[:key]] = value
        end

        unless cursor == flat_vars.length
          raise RuntimeError, "internal target reconstruction mismatch"
        end
        [rebuilt_args, rebuilt_kwargs]
      end

      def kwarg_key_for_name(kwargs, name)
        symbol = name.to_sym
        return symbol if kwargs.key?(symbol)
        return name if kwargs.key?(name)

        nil
      end
    end

    class Device
      alias_method :native_equal, :== if method_defined?(:==) && !method_defined?(:native_equal)

      def ==(other)
        if other.is_a?(::Symbol) || other.is_a?(::String)
          type == other.to_sym
        else
          native_equal(other)
        end
      end

      alias eql? ==
    end
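
    # Editorial usage sketch (not part of core.rb): Device#== also accepts a
    # Symbol or String, so device checks can avoid building a Device instance.
    #
    #   MLX::Core.default_device == :gpu   # typically true when Metal is available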

    class Array
      EPSILON_BY_DTYPE = {
        "float16" => 9.765625e-4,
        "bfloat16" => 7.8125e-3,
        "float32" => 1.1920929e-7,
        "float64" => Float::EPSILON,
        "complex64" => 1.1920929e-7
      }.freeze

      def T
        transpose
      end

      def at
        ArrayAt.new(self)
      end

      def real
        MLX::Core.real(self)
      end

      def imag
        MLX::Core.imag(self)
      end

      def itemsize
        dtype.size
      end

      def nbytes
        size * itemsize
      end

      def add(other)
        MLX::Core.add(self, other)
      end

      def subtract(other)
        MLX::Core.subtract(self, other)
      end

      def multiply(other)
        MLX::Core.multiply(self, other)
      end

      def divide(other)
        MLX::Core.divide(self, other)
      end

      def exp
        MLX::Core.exp(self)
      end

      def sin
        MLX::Core.sin(self)
      end

      def cos
        MLX::Core.cos(self)
      end

      def mean(axis = nil)
        MLX::Core.mean(self, axis)
      end

      def sum(axis = nil)
        MLX::Core.sum(self, axis)
      end

      def var(axis = nil, keepdims = nil, ddof = nil)
        MLX::Core.var(self, axis, keepdims, ddof)
      end

      def std(axis = nil, keepdims = nil, ddof = nil)
        MLX::Core.std(self, axis, keepdims, ddof)
      end

      def max(axis = nil, keepdims = nil)
        MLX::Core.max(self, axis, keepdims)
      end

      def min(axis = nil, keepdims = nil)
        MLX::Core.min(self, axis, keepdims)
      end

      def reshape(*shape)
        target = shape.length == 1 ? shape[0] : shape
        MLX::Core.reshape(self, target)
      end

      def transpose(axes = nil)
        MLX::Core.transpose(self, axes)
      end

      def squeeze(axis = nil)
        MLX::Core.squeeze(self, axis)
      end

      def square
        MLX::Core.square(self)
      end

      def sqrt
        MLX::Core.sqrt(self)
      end

      def rsqrt
        MLX::Core.rsqrt(self)
      end

      def reciprocal
        MLX::Core.reciprocal(self)
      end

      def abs
        MLX::Core.abs(self)
      end

      def all(axis = nil, keepdims = nil)
        MLX::Core.all(self, axis, keepdims)
      end

      def any(axis = nil, keepdims = nil)
        MLX::Core.any(self, axis, keepdims)
      end

      def argmax(axis = nil, keepdims = nil)
        MLX::Core.argmax(self, axis, keepdims)
      end

      def argmin(axis = nil, keepdims = nil)
        MLX::Core.argmin(self, axis, keepdims)
      end

      def astype(dtype, stream = nil)
        if stream.nil?
          MLX::Core.astype(self, dtype)
        else
          MLX::Core.astype(self, dtype, stream)
        end
      end

      def conj
        MLX::Core.conj(self)
      end

      def cummax(*args)
        MLX::Core.cummax(self, *args)
      end

      def cummin(*args)
        MLX::Core.cummin(self, *args)
      end

      def cumprod(*args)
        MLX::Core.cumprod(self, *args)
      end

      def cumsum(*args)
        MLX::Core.cumsum(self, *args)
      end

      def diag(*args)
        MLX::Core.diag(self, *args)
      end

      def diagonal(*args)
        MLX::Core.diagonal(self, *args)
      end

      def flatten(start_axis = 0, end_axis = -1)
        MLX::Core.flatten(self, start_axis, end_axis)
      end

      def log
        MLX::Core.log(self)
      end

      def log10
        MLX::Core.log10(self)
      end

      def log1p
        MLX::Core.log1p(self)
      end

      def log2
        MLX::Core.log2(self)
      end

      def logcumsumexp(*args)
        MLX::Core.logcumsumexp(self, *args)
      end

      def logsumexp(*args)
        MLX::Core.logsumexp(self, *args)
      end

      def maximum(other)
        MLX::Core.maximum(self, other)
      end

      def minimum(other)
        MLX::Core.minimum(self, other)
      end

      def moveaxis(source, destination)
        MLX::Core.moveaxis(self, source, destination)
      end

      def prod(axis = nil, keepdims = nil)
        MLX::Core.prod(self, axis, keepdims)
      end

      def round(decimals = 0)
        MLX::Core.round(self, decimals)
      end

      def split(indices_or_sections, axis = 0)
        MLX::Core.split(self, indices_or_sections, axis)
      end

      def swapaxes(axis1, axis2)
        MLX::Core.swapaxes(self, axis1, axis2)
      end

      def view(dtype)
        MLX::Core.view(self, dtype)
      end

      def eps
        dtype_name = if dtype.respond_to?(:name)
          dtype.name.to_s
        else
          dtype.to_s
        end
        EPSILON_BY_DTYPE.fetch(dtype_name, Float::EPSILON)
      end

      def tolist
        to_a
      end

      def __add__(other)
        add(other)
      end

      def __sub__(other)
        subtract(other)
      end

      def __mul__(other)
        multiply(other)
      end

      def __truediv__(other)
        divide(other)
      end

      def __div__(other)
        __truediv__(other)
      end

      def __matmul__(other)
        MLX::Core.matmul(self, other)
      end

      def __imatmul__(other)
        __matmul__(other)
      end

      def __len__
        shape.first || 0
      end

      def __iter__
        ArrayIterator.new(self)
      end

      def __next__
        @__mlx_array_iterator ||= __iter__
        @__mlx_array_iterator.__next__
      end

      def __init__(*_)
        self
      end

      def __repr__
        inspect
      end

      def __bool__
        raise ArgumentError, "The truth value of an array with more than one element is ambiguous" if size != 1

        !!item
      end

      def __int__
        raise ArgumentError, "only size-1 arrays can be converted to Integer" if size != 1

        Integer(item)
      end

      def __float__
        raise ArgumentError, "only size-1 arrays can be converted to Float" if size != 1

        Float(item)
      end

      def __hash__
        object_id.hash
      end

      def __array_namespace__
        MLX::Core
      end

      def __eq__(other)
        MLX::Core.equal(self, other)
      end

      def __ne__(other)
        MLX::Core.not_equal(self, other)
      end

      def __abs__
        MLX::Core.abs(self)
      end

      def __neg__
        MLX::Core.negative(self)
      end

      def __pow__(other)
        MLX::Core.power(self, other)
      end

      def __rpow__(other)
        MLX::Core.power(other, self)
      end

      def __floordiv__(other)
        MLX::Core.floor_divide(self, other)
      end

      def __mod__(other)
        MLX::Core.remainder(self, other)
      end

      def __rmod__(other)
        MLX::Core.remainder(other, self)
      end

      def __radd__(other)
        MLX::Core.add(other, self)
      end

      def __rsub__(other)
        MLX::Core.subtract(other, self)
      end

      def __rmul__(other)
        MLX::Core.multiply(other, self)
      end

      def __rtruediv__(other)
        MLX::Core.divide(other, self)
      end

      def __rdiv__(other)
        __rtruediv__(other)
      end

      def __and__(other)
        MLX::Core.bitwise_and(self, other)
      end

      def __or__(other)
        MLX::Core.bitwise_or(self, other)
      end

      def __xor__(other)
        MLX::Core.bitwise_xor(self, other)
      end

      def __invert__
        MLX::Core.bitwise_invert(self)
      end

      def __lshift__(other)
        MLX::Core.left_shift(self, other)
      end

      def __rshift__(other)
        MLX::Core.right_shift(self, other)
      end

      def __lt__(other)
        MLX::Core.less(self, other)
      end

      def __le__(other)
        MLX::Core.less_equal(self, other)
      end

      def __gt__(other)
        MLX::Core.greater(self, other)
      end

      def __ge__(other)
        MLX::Core.greater_equal(self, other)
      end

      def __iadd__(other)
        __add__(other)
      end

      def __isub__(other)
        __sub__(other)
      end

      def __imul__(other)
        __mul__(other)
      end

      def __itruediv__(other)
        __truediv__(other)
      end

      def __ifloordiv__(other)
        __floordiv__(other)
      end

      def __imod__(other)
        __mod__(other)
      end

      def __ipow__(other)
        __pow__(other)
      end

      def __iand__(other)
        __and__(other)
      end

      def __ior__(other)
        __or__(other)
      end

      def __ixor__(other)
        __xor__(other)
      end

      def __ilshift__(other)
        __lshift__(other)
      end

      def __irshift__(other)
        __rshift__(other)
      end

      def __rfloordiv__(other)
        MLX::Core.floor_divide(other, self)
      end

      def __getitem__(index)
        self[index]
      end

      # MLX arrays are immutable from Ruby, so __setitem__ returns a new array
      # with the update applied rather than modifying the receiver.
      def __setitem__(index, value)
        fast_path = __setitem_1d_device_fast_path(index, value)
        return fast_path unless fast_path.nil?

        copy = __ruby_deep_copy(to_a)
        replacement = value.is_a?(MLX::Core::Array) ? value.to_a : value
        __apply_setitem!(copy, index, replacement)
        MLX::Core.array(copy, dtype)
      end

      def __copy__
        MLX::Core.array(to_a, dtype)
      end

      def __deepcopy__(_memo = nil)
        __copy__
      end

      def __getstate__
        dtype_name = if dtype.respond_to?(:name)
          dtype.name.to_s
        else
          dtype.to_s
        end
        {
          "values" => to_a,
          "dtype" => dtype_name
        }
      end

      def __setstate__(state)
        values = state["values"] || state[:values]
        dtype_name = state["dtype"] || state[:dtype]
        if !dtype_name.nil? && MLX::Core.respond_to?(dtype_name.to_sym)
          MLX::Core.array(values, MLX::Core.public_send(dtype_name.to_sym))
        else
          MLX::Core.array(values)
        end
      end

      def __format__(format_spec = "")
        if size == 1 && !format_spec.to_s.empty?
          kernel = Kernel.format(format_spec, item)
          return kernel
        end
        to_a.to_s
      end

      def __dlpack__(stream = nil)
        unless stream.nil? || stream.is_a?(::Integer)
          raise ArgumentError, "__dlpack__ stream must be nil or Integer"
        end

        MLX::Core::DLPackCapsule.new(self, device: __dlpack_device, stream: stream)
      end

      def __dlpack_device
        device = MLX::Core.default_device
        type_id = case device.type
        when :cpu
          1
        when :gpu
          MLX::Core.metal_is_available ? 8 : 13
        else
          device.type
        end
        [type_id, device.index]
      end

      alias __dlpack_device__ __dlpack_device

      private

      def __setitem_1d_device_fast_path(index, replacement)
        return nil unless ndim == 1

        if index.is_a?(::Integer)
          normalized = __normalize_1d_index(index)
          index_array = MLX::Core.array([normalized], MLX::Core.int32)
          values_array = __coerce_setitem_values_1d(replacement, 1)
          return MLX::Core.put_along_axis(self, index_array, values_array, 0)
        end

        if index.is_a?(MLX::Core::Array) && index.ndim == 1
          case __dtype_name(index.dtype)
          when "bool_"
            return nil unless index.shape[0] == shape[0]

            replacement_array = __coerce_setitem_mask_values_1d(replacement, shape[0])
            return MLX::Core.where(index, replacement_array, self)
          when "uint8", "uint16", "uint32", "uint64", "int8", "int16", "int32", "int64"
            index_array = index.astype(MLX::Core.int32)
            values_array = __coerce_setitem_values_1d(replacement, index_array.size)
            return MLX::Core.put_along_axis(self, index_array, values_array, 0)
          end
        end

        return nil unless index.is_a?(::Array) && index.all? { |entry| entry.is_a?(::Integer) }

        normalized = index.map { |entry| __normalize_1d_index(entry) }
        index_array = MLX::Core.array(normalized, MLX::Core.int32)
        values_array = __coerce_setitem_values_1d(replacement, normalized.length)
        MLX::Core.put_along_axis(self, index_array, values_array, 0)
      rescue StandardError
        nil
      end

      def __normalize_1d_index(index)
        size_1d = shape[0]
        normalized = index
        normalized += size_1d if normalized.negative?
        if normalized.negative? || normalized >= size_1d
          raise IndexError, "index out of range"
        end

        normalized
      end

      def __coerce_setitem_values_1d(values, count)
        case values
        when MLX::Core::Array
          return MLX::Core.full([count], values, dtype) if values.size == 1 && count > 1
          raise ArgumentError, "__setitem__ replacement values must match index list length" if values.size != count

          MLX::Core.reshape(values.astype(dtype), [count])
        when ::Array
          value_array = MLX::Core.array(values, dtype)
          return MLX::Core.full([count], value_array, dtype) if value_array.size == 1 && count > 1
          raise ArgumentError, "__setitem__ replacement values must match index list length" if value_array.size != count

          MLX::Core.reshape(value_array, [count])
        else
          MLX::Core.full([count], values, dtype)
        end
      end

      def __coerce_setitem_mask_values_1d(values, count)
        case values
        when MLX::Core::Array
          return MLX::Core.full([count], values, dtype) if values.size == 1
          raise ArgumentError, "__setitem__ replacement values must match mask length" if values.size != count

          MLX::Core.reshape(values.astype(dtype), [count])
        when ::Array
          value_array = MLX::Core.array(values, dtype)
          return MLX::Core.full([count], value_array, dtype) if value_array.size == 1
          raise ArgumentError, "__setitem__ replacement values must match mask length" if value_array.size != count

          MLX::Core.reshape(value_array, [count])
        else
          MLX::Core.full([count], values, dtype)
        end
      end

      def __dtype_name(dtype_obj)
        if dtype_obj.respond_to?(:name)
          dtype_obj.name.to_s
        else
          dtype_obj.to_s
        end
      end

      def __apply_setitem!(data, index, replacement)
        if index.is_a?(::Integer)
          data[index] = replacement
          return
        end

        normalized = if index.is_a?(MLX::Core::Array)
          index.to_a
        elsif index.is_a?(::Array)
          index
        else
          raise ArgumentError, "__setitem__ supports Integer, Integer list, or boolean mask indices"
        end

        unless data.is_a?(::Array)
          raise ArgumentError, "__setitem__ list/mask indices require array values"
        end

        if normalized.all? { |v| v == true || v == false }
          __apply_boolean_mask_setitem!(data, normalized, replacement)
          return
        end

        unless normalized.all? { |v| v.is_a?(::Integer) }
          raise ArgumentError, "__setitem__ list indices must be all Integers or all booleans"
        end

        __apply_integer_list_setitem!(data, normalized, replacement)
      end

      def __apply_boolean_mask_setitem!(data, mask, replacement)
        if mask.length != data.length
          raise ArgumentError, "__setitem__ boolean mask must match array length"
        end

        replacement_values = replacement.is_a?(::Array) ? replacement.flatten : nil
        replacement_index = 0

        mask.each_with_index do |flag, i|
          next unless flag

          if replacement_values
            if replacement_index >= replacement_values.length
              raise ArgumentError, "__setitem__ replacement values shorter than mask true count"
            end
            data[i] = replacement_values[replacement_index]
            replacement_index += 1
          else
            data[i] = replacement
          end
        end
      end

      def __apply_integer_list_setitem!(data, indices, replacement)
        if replacement.is_a?(::Array)
          values = replacement.flatten
          if values.length == 1
            indices.each { |i| data[i] = values[0] }
            return
          end
          if values.length != indices.length
            raise ArgumentError, "__setitem__ replacement values must match index list length"
          end

          indices.each_with_index { |i, offset| data[i] = values[offset] }
        else
          indices.each { |i| data[i] = replacement }
        end
      end

      def __ruby_deep_copy(value)
        return value.map { |item| __ruby_deep_copy(item) } if value.is_a?(::Array)

        value
      end
    end
  end
end
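
# Editorial usage sketch (not part of core.rb): the Array mix-ins above give
# MLX::Core::Array a NumPy-flavoured surface on top of the native methods.
#
#   x = MLX::Core.array([[1.0, 2.0], [3.0, 4.0]])
#   x.T.sum(0)          # transpose, then column sums
#   x.__matmul__(x.T)   # matrix product via the dunder helper
#   x.at[0].add(1.0)    # new array with row 0 incremented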