mlx 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlx might be problematic; see the package registry's advisory page for more details.

Files changed (914)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/CMakeLists.txt +7 -0
  3. data/ext/mlx/Makefile +273 -0
  4. data/ext/mlx/extconf.rb +94 -0
  5. data/ext/mlx/mkmf.log +44 -0
  6. data/ext/mlx/native.bundle +0 -0
  7. data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
  8. data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
  9. data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
  10. data/ext/mlx/native.cpp +8027 -0
  11. data/ext/mlx/native.o +0 -0
  12. data/lib/mlx/core.rb +1678 -0
  13. data/lib/mlx/distributed_utils/common.rb +116 -0
  14. data/lib/mlx/distributed_utils/config.rb +600 -0
  15. data/lib/mlx/distributed_utils/launch.rb +490 -0
  16. data/lib/mlx/extension.rb +24 -0
  17. data/lib/mlx/nn/base.rb +388 -0
  18. data/lib/mlx/nn/init.rb +140 -0
  19. data/lib/mlx/nn/layers/activations.rb +336 -0
  20. data/lib/mlx/nn/layers/base.rb +6 -0
  21. data/lib/mlx/nn/layers/containers.rb +20 -0
  22. data/lib/mlx/nn/layers/convolution.rb +120 -0
  23. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  24. data/lib/mlx/nn/layers/distributed.rb +309 -0
  25. data/lib/mlx/nn/layers/dropout.rb +75 -0
  26. data/lib/mlx/nn/layers/embedding.rb +28 -0
  27. data/lib/mlx/nn/layers/linear.rb +79 -0
  28. data/lib/mlx/nn/layers/normalization.rb +216 -0
  29. data/lib/mlx/nn/layers/pooling.rb +167 -0
  30. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  31. data/lib/mlx/nn/layers/quantized.rb +215 -0
  32. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  33. data/lib/mlx/nn/layers/transformer.rb +330 -0
  34. data/lib/mlx/nn/layers/upsample.rb +97 -0
  35. data/lib/mlx/nn/layers.rb +18 -0
  36. data/lib/mlx/nn/losses.rb +251 -0
  37. data/lib/mlx/nn/utils.rb +167 -0
  38. data/lib/mlx/nn.rb +12 -0
  39. data/lib/mlx/optimizers/optimizers.rb +808 -0
  40. data/lib/mlx/optimizers/schedulers.rb +62 -0
  41. data/lib/mlx/optimizers.rb +9 -0
  42. data/lib/mlx/utils.rb +171 -0
  43. data/lib/mlx/version +1 -0
  44. data/lib/mlx/version.rb +5 -0
  45. data/lib/mlx.rb +64 -0
  46. data/mlx/.clang-format +87 -0
  47. data/mlx/.git +1 -0
  48. data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
  49. data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
  50. data/mlx/.github/actions/build-docs/action.yml +38 -0
  51. data/mlx/.github/actions/build-linux/action.yml +38 -0
  52. data/mlx/.github/actions/build-linux-release/action.yml +42 -0
  53. data/mlx/.github/actions/build-macos/action.yml +80 -0
  54. data/mlx/.github/actions/build-macos-release/action.yml +36 -0
  55. data/mlx/.github/actions/build-windows/action.yml +26 -0
  56. data/mlx/.github/actions/setup-linux/action.yml +93 -0
  57. data/mlx/.github/actions/setup-macos/action.yml +24 -0
  58. data/mlx/.github/actions/setup-windows/action.yml +42 -0
  59. data/mlx/.github/actions/test-linux/action.yml +69 -0
  60. data/mlx/.github/actions/test-windows/action.yml +20 -0
  61. data/mlx/.github/dependabot.yml +6 -0
  62. data/mlx/.github/pull_request_template.md +12 -0
  63. data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
  64. data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
  65. data/mlx/.github/workflows/build_and_test.yml +152 -0
  66. data/mlx/.github/workflows/documentation.yml +28 -0
  67. data/mlx/.github/workflows/nightly.yml +104 -0
  68. data/mlx/.github/workflows/release.yml +256 -0
  69. data/mlx/.gitignore +81 -0
  70. data/mlx/.pre-commit-config.yaml +27 -0
  71. data/mlx/ACKNOWLEDGMENTS.md +268 -0
  72. data/mlx/CITATION.cff +24 -0
  73. data/mlx/CMakeLists.txt +437 -0
  74. data/mlx/CODE_OF_CONDUCT.md +132 -0
  75. data/mlx/CONTRIBUTING.md +38 -0
  76. data/mlx/LICENSE +21 -0
  77. data/mlx/MANIFEST.in +6 -0
  78. data/mlx/README.md +121 -0
  79. data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
  80. data/mlx/benchmarks/cpp/autograd.cpp +39 -0
  81. data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
  82. data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
  83. data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
  84. data/mlx/benchmarks/cpp/time_utils.h +39 -0
  85. data/mlx/benchmarks/numpy/single_ops.py +39 -0
  86. data/mlx/benchmarks/numpy/time_utils.py +20 -0
  87. data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
  88. data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
  89. data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
  90. data/mlx/benchmarks/python/comparative/README.md +15 -0
  91. data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
  92. data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
  93. data/mlx/benchmarks/python/comparative/compare.py +284 -0
  94. data/mlx/benchmarks/python/compile_bench.py +107 -0
  95. data/mlx/benchmarks/python/conv1d_bench.py +123 -0
  96. data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
  97. data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
  98. data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
  99. data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
  100. data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
  101. data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
  102. data/mlx/benchmarks/python/conv_bench.py +135 -0
  103. data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
  104. data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
  105. data/mlx/benchmarks/python/distributed_bench.py +66 -0
  106. data/mlx/benchmarks/python/einsum_bench.py +84 -0
  107. data/mlx/benchmarks/python/fft_bench.py +118 -0
  108. data/mlx/benchmarks/python/gather_bench.py +52 -0
  109. data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
  110. data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
  111. data/mlx/benchmarks/python/hadamard_bench.py +70 -0
  112. data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
  113. data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
  114. data/mlx/benchmarks/python/masked_scatter.py +212 -0
  115. data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
  116. data/mlx/benchmarks/python/rope_bench.py +35 -0
  117. data/mlx/benchmarks/python/scatter_bench.py +96 -0
  118. data/mlx/benchmarks/python/sdpa_bench.py +223 -0
  119. data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
  120. data/mlx/benchmarks/python/single_ops.py +132 -0
  121. data/mlx/benchmarks/python/synchronize_bench.py +55 -0
  122. data/mlx/benchmarks/python/time_utils.py +38 -0
  123. data/mlx/cmake/FindCUDNN.cmake +177 -0
  124. data/mlx/cmake/FindNCCL.cmake +54 -0
  125. data/mlx/cmake/Findnvpl.cmake +3 -0
  126. data/mlx/cmake/extension.cmake +50 -0
  127. data/mlx/docs/.clang-format +2 -0
  128. data/mlx/docs/.gitignore +3 -0
  129. data/mlx/docs/.nojekyll +0 -0
  130. data/mlx/docs/Doxyfile +51 -0
  131. data/mlx/docs/Makefile +18 -0
  132. data/mlx/docs/README.md +54 -0
  133. data/mlx/docs/index.html +1 -0
  134. data/mlx/docs/requirements.txt +5 -0
  135. data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
  136. data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
  137. data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
  138. data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
  139. data/mlx/docs/src/_static/mlx_logo.png +0 -0
  140. data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
  141. data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
  142. data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
  143. data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
  144. data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
  145. data/mlx/docs/src/_templates/module-base-class.rst +33 -0
  146. data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
  147. data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
  148. data/mlx/docs/src/conf.py +99 -0
  149. data/mlx/docs/src/cpp/ops.rst +7 -0
  150. data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
  151. data/mlx/docs/src/dev/extensions.rst +811 -0
  152. data/mlx/docs/src/dev/metal_debugger.rst +68 -0
  153. data/mlx/docs/src/dev/metal_logging.rst +40 -0
  154. data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
  155. data/mlx/docs/src/examples/data_parallelism.rst +91 -0
  156. data/mlx/docs/src/examples/linear_regression.rst +77 -0
  157. data/mlx/docs/src/examples/llama-inference.rst +382 -0
  158. data/mlx/docs/src/examples/mlp.rst +134 -0
  159. data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
  160. data/mlx/docs/src/index.rst +96 -0
  161. data/mlx/docs/src/install.rst +340 -0
  162. data/mlx/docs/src/python/array.rst +65 -0
  163. data/mlx/docs/src/python/cuda.rst +9 -0
  164. data/mlx/docs/src/python/data_types.rst +78 -0
  165. data/mlx/docs/src/python/devices_and_streams.rst +21 -0
  166. data/mlx/docs/src/python/distributed.rst +22 -0
  167. data/mlx/docs/src/python/export.rst +14 -0
  168. data/mlx/docs/src/python/fast.rst +16 -0
  169. data/mlx/docs/src/python/fft.rst +24 -0
  170. data/mlx/docs/src/python/linalg.rst +27 -0
  171. data/mlx/docs/src/python/memory_management.rst +16 -0
  172. data/mlx/docs/src/python/metal.rst +12 -0
  173. data/mlx/docs/src/python/nn/distributed.rst +30 -0
  174. data/mlx/docs/src/python/nn/functions.rst +40 -0
  175. data/mlx/docs/src/python/nn/init.rst +45 -0
  176. data/mlx/docs/src/python/nn/layers.rst +74 -0
  177. data/mlx/docs/src/python/nn/losses.rst +25 -0
  178. data/mlx/docs/src/python/nn/module.rst +38 -0
  179. data/mlx/docs/src/python/nn.rst +186 -0
  180. data/mlx/docs/src/python/ops.rst +184 -0
  181. data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
  182. data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
  183. data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
  184. data/mlx/docs/src/python/optimizers.rst +78 -0
  185. data/mlx/docs/src/python/random.rst +48 -0
  186. data/mlx/docs/src/python/transforms.rst +22 -0
  187. data/mlx/docs/src/python/tree_utils.rst +23 -0
  188. data/mlx/docs/src/usage/compile.rst +516 -0
  189. data/mlx/docs/src/usage/distributed.rst +572 -0
  190. data/mlx/docs/src/usage/export.rst +288 -0
  191. data/mlx/docs/src/usage/function_transforms.rst +191 -0
  192. data/mlx/docs/src/usage/indexing.rst +194 -0
  193. data/mlx/docs/src/usage/launching_distributed.rst +234 -0
  194. data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
  195. data/mlx/docs/src/usage/numpy.rst +124 -0
  196. data/mlx/docs/src/usage/quick_start.rst +67 -0
  197. data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
  198. data/mlx/docs/src/usage/unified_memory.rst +78 -0
  199. data/mlx/docs/src/usage/using_streams.rst +18 -0
  200. data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
  201. data/mlx/examples/cmake_project/README.md +26 -0
  202. data/mlx/examples/cmake_project/example.cpp +14 -0
  203. data/mlx/examples/cpp/CMakeLists.txt +12 -0
  204. data/mlx/examples/cpp/distributed.cpp +22 -0
  205. data/mlx/examples/cpp/linear_regression.cpp +54 -0
  206. data/mlx/examples/cpp/logistic_regression.cpp +54 -0
  207. data/mlx/examples/cpp/metal_capture.cpp +31 -0
  208. data/mlx/examples/cpp/timer.h +20 -0
  209. data/mlx/examples/cpp/tutorial.cpp +99 -0
  210. data/mlx/examples/export/CMakeLists.txt +22 -0
  211. data/mlx/examples/export/README.md +49 -0
  212. data/mlx/examples/export/eval_mlp.cpp +25 -0
  213. data/mlx/examples/export/eval_mlp.py +52 -0
  214. data/mlx/examples/export/train_mlp.cpp +35 -0
  215. data/mlx/examples/export/train_mlp.py +76 -0
  216. data/mlx/examples/extensions/CMakeLists.txt +78 -0
  217. data/mlx/examples/extensions/README.md +24 -0
  218. data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
  219. data/mlx/examples/extensions/axpby/axpby.h +90 -0
  220. data/mlx/examples/extensions/axpby/axpby.metal +47 -0
  221. data/mlx/examples/extensions/bindings.cpp +39 -0
  222. data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
  223. data/mlx/examples/extensions/pyproject.toml +8 -0
  224. data/mlx/examples/extensions/requirements.txt +4 -0
  225. data/mlx/examples/extensions/setup.py +18 -0
  226. data/mlx/examples/extensions/test.py +12 -0
  227. data/mlx/examples/python/linear_regression.py +46 -0
  228. data/mlx/examples/python/logistic_regression.py +49 -0
  229. data/mlx/examples/python/qqmm.py +117 -0
  230. data/mlx/mlx/3rdparty/.clang-format +2 -0
  231. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  232. data/mlx/mlx/CMakeLists.txt +107 -0
  233. data/mlx/mlx/allocator.h +75 -0
  234. data/mlx/mlx/api.h +29 -0
  235. data/mlx/mlx/array.cpp +354 -0
  236. data/mlx/mlx/array.h +647 -0
  237. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  238. data/mlx/mlx/backend/common/binary.h +97 -0
  239. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  240. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  241. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  242. data/mlx/mlx/backend/common/common.cpp +305 -0
  243. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  244. data/mlx/mlx/backend/common/compiled.h +77 -0
  245. data/mlx/mlx/backend/common/copy.h +50 -0
  246. data/mlx/mlx/backend/common/hadamard.h +109 -0
  247. data/mlx/mlx/backend/common/load.cpp +57 -0
  248. data/mlx/mlx/backend/common/matmul.h +67 -0
  249. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  250. data/mlx/mlx/backend/common/reduce.h +59 -0
  251. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  252. data/mlx/mlx/backend/common/slicing.h +20 -0
  253. data/mlx/mlx/backend/common/ternary.h +85 -0
  254. data/mlx/mlx/backend/common/unary.h +29 -0
  255. data/mlx/mlx/backend/common/utils.cpp +231 -0
  256. data/mlx/mlx/backend/common/utils.h +205 -0
  257. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  258. data/mlx/mlx/backend/cpu/arange.h +28 -0
  259. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  260. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  261. data/mlx/mlx/backend/cpu/binary.h +517 -0
  262. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  263. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  264. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  265. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  266. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  267. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  268. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  269. data/mlx/mlx/backend/cpu/copy.h +36 -0
  270. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  271. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  272. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  273. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  274. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  275. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  276. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  277. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  278. data/mlx/mlx/backend/cpu/eval.h +12 -0
  279. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  280. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  281. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  282. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  283. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  284. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  285. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  286. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  287. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  288. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  289. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  290. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  291. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  292. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  293. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  294. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  295. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  296. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  297. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  298. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  299. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  300. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  301. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  302. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  303. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  304. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  305. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  306. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  307. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  308. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  309. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  310. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  311. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  312. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  313. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  314. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  315. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  316. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  317. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  318. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  319. data/mlx/mlx/backend/cpu/unary.h +281 -0
  320. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  321. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  322. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  323. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  324. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  325. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  326. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  327. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  328. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  329. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  330. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  331. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  332. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  333. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  334. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  335. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  336. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  337. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  338. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  339. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  340. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  341. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  342. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  343. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  344. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  345. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  346. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  347. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  348. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  349. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  350. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  351. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  352. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  353. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  354. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  355. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  356. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  357. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  358. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  359. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  360. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  361. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  362. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  363. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  364. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  365. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  366. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  367. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  368. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  369. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  370. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  371. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  372. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  373. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  374. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  375. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  376. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  377. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  378. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  379. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  380. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  381. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  382. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  383. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  384. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  385. data/mlx/mlx/backend/cuda/device.h +195 -0
  386. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  387. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  388. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  389. data/mlx/mlx/backend/cuda/event.cu +415 -0
  390. data/mlx/mlx/backend/cuda/event.h +79 -0
  391. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  392. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  393. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  394. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  395. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  396. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  397. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  398. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  399. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  400. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  401. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  402. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  403. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  404. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  405. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  406. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  407. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  408. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  409. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  410. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  411. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  412. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  413. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  414. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  415. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  416. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  417. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  418. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  419. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  420. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  421. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  422. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  423. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  424. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  425. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  426. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  427. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  428. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  429. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  430. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  431. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  432. data/mlx/mlx/backend/cuda/random.cu +202 -0
  433. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  434. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  435. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  436. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  437. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  438. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  439. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  440. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  441. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  442. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  443. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  444. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  445. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  446. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  447. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  448. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  449. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  450. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  451. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  452. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  453. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  454. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  455. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  456. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  457. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  458. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  459. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  460. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  461. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  462. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  463. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  464. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  465. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  466. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  467. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  468. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  469. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  470. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  471. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  472. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  473. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  474. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  475. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  476. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  477. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  478. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  479. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  480. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  481. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  482. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  483. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  484. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  485. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  486. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  487. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  488. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  489. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  490. data/mlx/mlx/backend/cuda/utils.h +49 -0
  491. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  492. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  493. data/mlx/mlx/backend/cuda/worker.h +55 -0
  494. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  495. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  496. data/mlx/mlx/backend/gpu/copy.h +57 -0
  497. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  498. data/mlx/mlx/backend/gpu/eval.h +18 -0
  499. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  500. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  501. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  502. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  503. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  504. data/mlx/mlx/backend/metal/allocator.h +79 -0
  505. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  506. data/mlx/mlx/backend/metal/binary.h +33 -0
  507. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  508. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  509. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  510. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  511. data/mlx/mlx/backend/metal/device.cpp +816 -0
  512. data/mlx/mlx/backend/metal/device.h +289 -0
  513. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  514. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  515. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  516. data/mlx/mlx/backend/metal/event.cpp +62 -0
  517. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  518. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  519. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  520. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  521. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  522. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  523. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  524. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  525. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  526. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  527. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  528. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  529. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  530. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  531. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  532. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  533. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  534. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  535. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  536. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  537. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  538. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  539. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  540. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  541. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  542. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  543. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  544. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  545. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  546. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  547. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  548. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  549. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  550. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  551. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  552. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  553. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  554. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  555. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  556. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  557. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  558. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  559. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  560. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  561. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  562. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  563. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  564. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  565. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  566. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  567. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  568. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  569. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  570. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  571. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  572. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  573. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  574. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  575. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  576. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  577. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  578. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  579. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  580. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  581. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  582. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  583. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  584. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  585. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  586. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  587. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  588. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  589. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  590. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  591. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  592. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  593. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  594. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  595. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  596. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  597. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  598. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  599. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  600. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  601. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  602. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  603. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  604. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  605. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  606. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  607. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  608. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  609. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  610. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  611. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  612. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  613. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  614. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  615. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  616. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  617. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  618. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  619. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  620. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  621. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  622. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  623. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  624. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  625. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  626. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  627. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  628. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  629. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  630. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  631. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  632. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  633. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  634. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  635. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  636. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  637. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  638. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  639. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  640. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  641. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  642. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  643. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  644. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  645. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  646. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  647. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  648. data/mlx/mlx/backend/metal/kernels.h +375 -0
  649. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  650. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  651. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  652. data/mlx/mlx/backend/metal/matmul.h +144 -0
  653. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  654. data/mlx/mlx/backend/metal/metal.h +25 -0
  655. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  656. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  657. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  658. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  659. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  660. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  661. data/mlx/mlx/backend/metal/reduce.h +41 -0
  662. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  663. data/mlx/mlx/backend/metal/resident.h +32 -0
  664. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  665. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  666. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  667. data/mlx/mlx/backend/metal/scan.h +17 -0
  668. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  669. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  670. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  671. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  672. data/mlx/mlx/backend/metal/ternary.h +21 -0
  673. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  674. data/mlx/mlx/backend/metal/unary.h +21 -0
  675. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  676. data/mlx/mlx/backend/metal/utils.h +99 -0
  677. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  678. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  679. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  680. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  681. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  682. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  683. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  684. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  685. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  686. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  687. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  688. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  689. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  690. data/mlx/mlx/compile.cpp +1243 -0
  691. data/mlx/mlx/compile.h +45 -0
  692. data/mlx/mlx/compile_impl.h +70 -0
  693. data/mlx/mlx/device.cpp +72 -0
  694. data/mlx/mlx/device.h +56 -0
  695. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  696. data/mlx/mlx/distributed/distributed.cpp +197 -0
  697. data/mlx/mlx/distributed/distributed.h +61 -0
  698. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  699. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  700. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  701. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  702. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  703. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  704. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  705. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  706. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  707. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  708. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  709. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  710. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  711. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  712. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  713. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  714. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  715. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  716. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  717. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  718. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  719. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  720. data/mlx/mlx/distributed/ops.cpp +186 -0
  721. data/mlx/mlx/distributed/ops.h +57 -0
  722. data/mlx/mlx/distributed/primitives.cpp +95 -0
  723. data/mlx/mlx/distributed/primitives.h +156 -0
  724. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  725. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  726. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  727. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  728. data/mlx/mlx/distributed/ring/ring.h +12 -0
  729. data/mlx/mlx/distributed/utils.cpp +206 -0
  730. data/mlx/mlx/distributed/utils.h +67 -0
  731. data/mlx/mlx/dtype.cpp +197 -0
  732. data/mlx/mlx/dtype.h +116 -0
  733. data/mlx/mlx/dtype_utils.cpp +42 -0
  734. data/mlx/mlx/dtype_utils.h +119 -0
  735. data/mlx/mlx/einsum.cpp +941 -0
  736. data/mlx/mlx/einsum.h +23 -0
  737. data/mlx/mlx/event.h +58 -0
  738. data/mlx/mlx/export.cpp +1130 -0
  739. data/mlx/mlx/export.h +137 -0
  740. data/mlx/mlx/export_impl.h +99 -0
  741. data/mlx/mlx/fast.cpp +941 -0
  742. data/mlx/mlx/fast.h +103 -0
  743. data/mlx/mlx/fast_primitives.h +427 -0
  744. data/mlx/mlx/fence.h +39 -0
  745. data/mlx/mlx/fft.cpp +262 -0
  746. data/mlx/mlx/fft.h +159 -0
  747. data/mlx/mlx/graph_utils.cpp +175 -0
  748. data/mlx/mlx/graph_utils.h +67 -0
  749. data/mlx/mlx/io/CMakeLists.txt +25 -0
  750. data/mlx/mlx/io/gguf.cpp +470 -0
  751. data/mlx/mlx/io/gguf.h +20 -0
  752. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  753. data/mlx/mlx/io/load.cpp +397 -0
  754. data/mlx/mlx/io/load.h +175 -0
  755. data/mlx/mlx/io/no_gguf.cpp +20 -0
  756. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  757. data/mlx/mlx/io/safetensors.cpp +234 -0
  758. data/mlx/mlx/io.h +61 -0
  759. data/mlx/mlx/linalg.cpp +708 -0
  760. data/mlx/mlx/linalg.h +115 -0
  761. data/mlx/mlx/memory.h +80 -0
  762. data/mlx/mlx/mlx.h +25 -0
  763. data/mlx/mlx/ops.cpp +6094 -0
  764. data/mlx/mlx/ops.h +1610 -0
  765. data/mlx/mlx/primitives.cpp +5850 -0
  766. data/mlx/mlx/primitives.h +2525 -0
  767. data/mlx/mlx/random.cpp +492 -0
  768. data/mlx/mlx/random.h +283 -0
  769. data/mlx/mlx/scheduler.cpp +73 -0
  770. data/mlx/mlx/scheduler.h +189 -0
  771. data/mlx/mlx/small_vector.h +540 -0
  772. data/mlx/mlx/stream.h +42 -0
  773. data/mlx/mlx/threadpool.h +133 -0
  774. data/mlx/mlx/transforms.cpp +1065 -0
  775. data/mlx/mlx/transforms.h +231 -0
  776. data/mlx/mlx/transforms_impl.h +88 -0
  777. data/mlx/mlx/types/bf16.h +187 -0
  778. data/mlx/mlx/types/complex.h +113 -0
  779. data/mlx/mlx/types/fp16.h +234 -0
  780. data/mlx/mlx/types/half_types.h +58 -0
  781. data/mlx/mlx/types/limits.h +70 -0
  782. data/mlx/mlx/utils.cpp +302 -0
  783. data/mlx/mlx/utils.h +174 -0
  784. data/mlx/mlx/version.cpp +11 -0
  785. data/mlx/mlx/version.h +22 -0
  786. data/mlx/mlx.pc.in +52 -0
  787. data/mlx/pyproject.toml +7 -0
  788. data/mlx/python/mlx/__main__.py +27 -0
  789. data/mlx/python/mlx/_distributed_utils/common.py +135 -0
  790. data/mlx/python/mlx/_distributed_utils/config.py +631 -0
  791. data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
  792. data/mlx/python/mlx/_reprlib_fix.py +16 -0
  793. data/mlx/python/mlx/_stub_patterns.txt +36 -0
  794. data/mlx/python/mlx/extension.py +88 -0
  795. data/mlx/python/mlx/nn/__init__.py +5 -0
  796. data/mlx/python/mlx/nn/init.py +441 -0
  797. data/mlx/python/mlx/nn/layers/__init__.py +105 -0
  798. data/mlx/python/mlx/nn/layers/activations.py +661 -0
  799. data/mlx/python/mlx/nn/layers/base.py +675 -0
  800. data/mlx/python/mlx/nn/layers/containers.py +24 -0
  801. data/mlx/python/mlx/nn/layers/convolution.py +232 -0
  802. data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
  803. data/mlx/python/mlx/nn/layers/distributed.py +601 -0
  804. data/mlx/python/mlx/nn/layers/dropout.py +137 -0
  805. data/mlx/python/mlx/nn/layers/embedding.py +53 -0
  806. data/mlx/python/mlx/nn/layers/linear.py +180 -0
  807. data/mlx/python/mlx/nn/layers/normalization.py +363 -0
  808. data/mlx/python/mlx/nn/layers/pooling.py +398 -0
  809. data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
  810. data/mlx/python/mlx/nn/layers/quantized.py +426 -0
  811. data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
  812. data/mlx/python/mlx/nn/layers/transformer.py +354 -0
  813. data/mlx/python/mlx/nn/layers/upsample.py +277 -0
  814. data/mlx/python/mlx/nn/losses.py +610 -0
  815. data/mlx/python/mlx/nn/utils.py +165 -0
  816. data/mlx/python/mlx/optimizers/__init__.py +4 -0
  817. data/mlx/python/mlx/optimizers/optimizers.py +976 -0
  818. data/mlx/python/mlx/optimizers/schedulers.py +158 -0
  819. data/mlx/python/mlx/py.typed +1 -0
  820. data/mlx/python/mlx/utils.py +325 -0
  821. data/mlx/python/src/CMakeLists.txt +96 -0
  822. data/mlx/python/src/array.cpp +1525 -0
  823. data/mlx/python/src/buffer.h +124 -0
  824. data/mlx/python/src/constants.cpp +15 -0
  825. data/mlx/python/src/convert.cpp +504 -0
  826. data/mlx/python/src/convert.h +50 -0
  827. data/mlx/python/src/cuda.cpp +19 -0
  828. data/mlx/python/src/device.cpp +98 -0
  829. data/mlx/python/src/distributed.cpp +352 -0
  830. data/mlx/python/src/export.cpp +356 -0
  831. data/mlx/python/src/fast.cpp +627 -0
  832. data/mlx/python/src/fft.cpp +514 -0
  833. data/mlx/python/src/indexing.cpp +1016 -0
  834. data/mlx/python/src/indexing.h +41 -0
  835. data/mlx/python/src/linalg.cpp +663 -0
  836. data/mlx/python/src/load.cpp +531 -0
  837. data/mlx/python/src/load.h +51 -0
  838. data/mlx/python/src/memory.cpp +125 -0
  839. data/mlx/python/src/metal.cpp +98 -0
  840. data/mlx/python/src/mlx.cpp +51 -0
  841. data/mlx/python/src/mlx_func.cpp +116 -0
  842. data/mlx/python/src/mlx_func.h +31 -0
  843. data/mlx/python/src/ops.cpp +5545 -0
  844. data/mlx/python/src/random.cpp +516 -0
  845. data/mlx/python/src/small_vector.h +76 -0
  846. data/mlx/python/src/stream.cpp +147 -0
  847. data/mlx/python/src/transforms.cpp +1542 -0
  848. data/mlx/python/src/trees.cpp +311 -0
  849. data/mlx/python/src/trees.h +62 -0
  850. data/mlx/python/src/utils.cpp +98 -0
  851. data/mlx/python/src/utils.h +78 -0
  852. data/mlx/python/tests/__main__.py +5 -0
  853. data/mlx/python/tests/cuda_skip.py +62 -0
  854. data/mlx/python/tests/mlx_distributed_tests.py +314 -0
  855. data/mlx/python/tests/mlx_tests.py +116 -0
  856. data/mlx/python/tests/mpi_test_distributed.py +142 -0
  857. data/mlx/python/tests/nccl_test_distributed.py +52 -0
  858. data/mlx/python/tests/ring_test_distributed.py +131 -0
  859. data/mlx/python/tests/test_array.py +2139 -0
  860. data/mlx/python/tests/test_autograd.py +880 -0
  861. data/mlx/python/tests/test_bf16.py +196 -0
  862. data/mlx/python/tests/test_blas.py +1429 -0
  863. data/mlx/python/tests/test_compile.py +1277 -0
  864. data/mlx/python/tests/test_constants.py +41 -0
  865. data/mlx/python/tests/test_conv.py +1198 -0
  866. data/mlx/python/tests/test_conv_transpose.py +810 -0
  867. data/mlx/python/tests/test_device.py +150 -0
  868. data/mlx/python/tests/test_double.py +306 -0
  869. data/mlx/python/tests/test_einsum.py +363 -0
  870. data/mlx/python/tests/test_eval.py +200 -0
  871. data/mlx/python/tests/test_export_import.py +614 -0
  872. data/mlx/python/tests/test_fast.py +923 -0
  873. data/mlx/python/tests/test_fast_sdpa.py +647 -0
  874. data/mlx/python/tests/test_fft.py +323 -0
  875. data/mlx/python/tests/test_graph.py +37 -0
  876. data/mlx/python/tests/test_init.py +139 -0
  877. data/mlx/python/tests/test_linalg.py +621 -0
  878. data/mlx/python/tests/test_load.py +447 -0
  879. data/mlx/python/tests/test_losses.py +427 -0
  880. data/mlx/python/tests/test_memory.py +77 -0
  881. data/mlx/python/tests/test_nn.py +1986 -0
  882. data/mlx/python/tests/test_ops.py +3261 -0
  883. data/mlx/python/tests/test_optimizers.py +584 -0
  884. data/mlx/python/tests/test_quantized.py +1160 -0
  885. data/mlx/python/tests/test_random.py +392 -0
  886. data/mlx/python/tests/test_reduce.py +223 -0
  887. data/mlx/python/tests/test_tree.py +96 -0
  888. data/mlx/python/tests/test_upsample.py +100 -0
  889. data/mlx/python/tests/test_vmap.py +860 -0
  890. data/mlx/setup.py +315 -0
  891. data/mlx/tests/CMakeLists.txt +44 -0
  892. data/mlx/tests/allocator_tests.cpp +41 -0
  893. data/mlx/tests/arg_reduce_tests.cpp +204 -0
  894. data/mlx/tests/array_tests.cpp +663 -0
  895. data/mlx/tests/autograd_tests.cpp +1399 -0
  896. data/mlx/tests/blas_tests.cpp +110 -0
  897. data/mlx/tests/compile_tests.cpp +818 -0
  898. data/mlx/tests/creations_tests.cpp +239 -0
  899. data/mlx/tests/custom_vjp_tests.cpp +55 -0
  900. data/mlx/tests/device_tests.cpp +35 -0
  901. data/mlx/tests/einsum_tests.cpp +85 -0
  902. data/mlx/tests/eval_tests.cpp +93 -0
  903. data/mlx/tests/export_import_tests.cpp +164 -0
  904. data/mlx/tests/fft_tests.cpp +366 -0
  905. data/mlx/tests/gpu_tests.cpp +523 -0
  906. data/mlx/tests/linalg_tests.cpp +639 -0
  907. data/mlx/tests/load_tests.cpp +270 -0
  908. data/mlx/tests/ops_tests.cpp +4159 -0
  909. data/mlx/tests/random_tests.cpp +716 -0
  910. data/mlx/tests/scheduler_tests.cpp +121 -0
  911. data/mlx/tests/tests.cpp +26 -0
  912. data/mlx/tests/utils_tests.cpp +67 -0
  913. data/mlx/tests/vmap_tests.cpp +547 -0
  914. metadata +958 -0
@@ -0,0 +1,1542 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #include <algorithm>
4
+ #include <numeric>
5
+ #include <sstream>
6
+ #include <unordered_set>
7
+
8
+ #include <nanobind/nanobind.h>
9
+ #include <nanobind/stl/optional.h>
10
+ #include <nanobind/stl/pair.h>
11
+ #include <nanobind/stl/string.h>
12
+ #include <nanobind/stl/unordered_set.h>
13
+ #include <nanobind/stl/variant.h>
14
+ #include <nanobind/stl/vector.h>
15
+
16
+ #include "mlx/array.h"
17
+ #include "mlx/compile.h"
18
+ #include "mlx/compile_impl.h"
19
+ #include "mlx/transforms.h"
20
+ #include "mlx/transforms_impl.h"
21
+ #include "mlx/utils.h"
22
+ #include "python/src/mlx_func.h"
23
+ #include "python/src/small_vector.h"
24
+ #include "python/src/trees.h"
25
+
26
+ namespace mx = mlx::core;
27
+ namespace nb = nanobind;
28
+ using namespace nb::literals;
29
+
30
+ // Needed for printing shapes and strides.
31
+ using mx::operator<<;
32
+
33
+ using IntOrVec = std::variant<int, std::vector<int>>;
34
+ using StrOrSet = std::variant<std::string, std::unordered_set<std::string>>;
35
+
36
+ inline std::string type_name_str(const nb::handle& o) {
37
+ return nb::cast<std::string>(nb::type_name(o.type()));
38
+ }
39
+
40
+ auto validate_argnums_argnames(
41
+ const std::optional<IntOrVec>& argnums,
42
+ const StrOrSet& argnames) {
43
+ std::unordered_set<std::string> setnames;
44
+ if (auto pv = std::get_if<std::string>(&argnames); pv) {
45
+ setnames = {*pv};
46
+ } else {
47
+ setnames = std::get<std::unordered_set<std::string>>(argnames);
48
+ }
49
+
50
+ if (!argnums.has_value()) {
51
+ // argnums was not provided and argnames was empty
52
+ if (setnames.empty()) {
53
+ return std::make_pair(std::vector<int>{0}, setnames);
54
+ } else {
55
+ return std::make_pair(std::vector<int>{}, setnames);
56
+ }
57
+ }
58
+
59
+ std::vector<int> vecnums;
60
+ if (auto pv = std::get_if<int>(&(*argnums)); pv) {
61
+ vecnums = {*pv};
62
+ } else {
63
+ vecnums = std::get<std::vector<int>>(*argnums);
64
+ }
65
+
66
+ return std::make_pair(vecnums, setnames);
67
+ }
68
+
69
+ auto py_value_and_grad(
70
+ const nb::callable& fun,
71
+ std::vector<int> argnums,
72
+ std::unordered_set<std::string> argnames,
73
+ const std::string& error_msg_tag,
74
+ bool scalar_func_only) {
75
+ // Sanitize argnums
76
+ if (argnums.size() == 0 && argnames.size() == 0) {
77
+ throw std::invalid_argument(
78
+ error_msg_tag + " Gradient wrt no argument requested");
79
+ }
80
+ for (auto arg : argnums) {
81
+ std::sort(argnums.begin(), argnums.end());
82
+ if (argnums[0] < 0) {
83
+ std::ostringstream msg;
84
+ msg << error_msg_tag
85
+ << " Can't compute the gradient of negative argument index "
86
+ << argnums[0];
87
+ throw std::invalid_argument(msg.str());
88
+ }
89
+ for (int i = 1; i < argnums.size(); ++i) {
90
+ if (argnums[i] == argnums[i - 1]) {
91
+ std::ostringstream msg;
92
+ msg << error_msg_tag << " Duplicate argument index " << argnums[0]
93
+ << " is not allowed.";
94
+ throw std::invalid_argument(msg.str());
95
+ }
96
+ }
97
+ }
98
+
99
+ return [fun, argnums, argnames, error_msg_tag, scalar_func_only](
100
+ nb::args& args, nb::kwargs& kwargs) {
101
+ // Sanitize the input
102
+ if (argnums.size() > 0 && argnums.back() >= args.size()) {
103
+ std::ostringstream msg;
104
+ msg << error_msg_tag << " Can't compute the gradient of argument index "
105
+ << argnums.back() << " because the function is called with only "
106
+ << args.size() << " positional arguments.";
107
+ throw std::invalid_argument(msg.str());
108
+ }
109
+
110
+ for (auto& key : argnames) {
111
+ if (!kwargs.contains(key)) {
112
+ std::ostringstream msg;
113
+ msg << error_msg_tag
114
+ << " Can't compute the gradient of keyword argument '" << key
115
+ << "' because the function is called with the "
116
+ << "following keyword arguments {";
117
+ for (auto item : kwargs) {
118
+ msg << nb::cast<std::string>(item.first) << ",";
119
+ }
120
+ msg << "}";
121
+ throw std::invalid_argument(msg.str());
122
+ }
123
+ }
124
+
125
+ // Collect the arrays
126
+ std::vector<mx::array> arrays;
127
+ std::vector<nb::object> array_objects;
128
+ auto flatten_with_objects = [&arrays, &array_objects](
129
+ auto tree, bool strict) {
130
+ tree_visit(tree, [&](nb::handle obj) {
131
+ if (nb::isinstance<mx::array>(obj)) {
132
+ arrays.push_back(nb::cast<mx::array>(obj));
133
+ array_objects.push_back(nb::borrow<nb::object>(obj));
134
+ } else if (strict) {
135
+ throw std::invalid_argument(
136
+ "[tree_flatten] The argument should contain only arrays");
137
+ }
138
+ });
139
+ };
140
+
141
+ std::vector<int> counts(1, 0);
142
+ std::vector<int> gradient_indices;
143
+ for (int i = 0, j = 0; i < args.size(); ++i) {
144
+ bool needs_grad = (j < argnums.size() && argnums[j] == i);
145
+ auto pre_size = arrays.size();
146
+ flatten_with_objects(args[i], /* strict = */ needs_grad);
147
+ if (needs_grad) {
148
+ auto old_size = gradient_indices.size();
149
+ auto delta_size = arrays.size() - pre_size;
150
+ gradient_indices.resize(old_size + delta_size);
151
+ std::iota(
152
+ gradient_indices.begin() + old_size,
153
+ gradient_indices.end(),
154
+ pre_size);
155
+ j++;
156
+ counts.push_back(delta_size);
157
+ }
158
+ }
159
+ for (auto item : kwargs) {
160
+ bool needs_grad =
161
+ (argnames.find(nb::cast<std::string>(item.first)) != argnames.end());
162
+ auto pre_size = arrays.size();
163
+ flatten_with_objects(item.second, /* strict = */ needs_grad);
164
+ if (needs_grad) {
165
+ auto old_size = gradient_indices.size();
166
+ auto delta_size = arrays.size() - pre_size;
167
+ gradient_indices.resize(old_size + delta_size);
168
+ std::iota(
169
+ gradient_indices.begin() + old_size,
170
+ gradient_indices.end(),
171
+ pre_size);
172
+ counts.push_back(delta_size);
173
+ }
174
+ }
175
+ std::partial_sum(counts.cbegin(), counts.cend(), counts.begin());
176
+
177
+ // value_out will hold the output of the python function in order to be
178
+ // able to reconstruct the python tree of extra return values
179
+ nb::object py_value_out;
180
+ auto value_and_grads = mx::value_and_grad(
181
+ [&fun,
182
+ &array_objects,
183
+ &args,
184
+ &kwargs,
185
+ &py_value_out,
186
+ &error_msg_tag,
187
+ scalar_func_only](const std::vector<mx::array>& a) {
188
+ nb::list tree;
189
+ tree.append(args);
190
+ tree.append(kwargs);
191
+ tree_fill(tree, a);
192
+
193
+ // Call the python function
194
+ py_value_out = fun(*tree[0], **tree[1]);
195
+
196
+ // Replace the tracers with the originals. Don't overwrite
197
+ // locations which were written to during the call to fun
198
+ int index = 0;
199
+ tree_visit_update(tree, [&](nb::handle node) {
200
+ auto replace_arr = nb::cast<mx::array>(node);
201
+ if (replace_arr.id() == a[index].id()) {
202
+ return array_objects[index++];
203
+ } else {
204
+ index++;
205
+ return nb::cast(replace_arr);
206
+ }
207
+ });
208
+
209
+ // Validate the return value of the python function
210
+ if (!nb::isinstance<mx::array>(py_value_out)) {
211
+ if (scalar_func_only) {
212
+ std::ostringstream msg;
213
+ msg << error_msg_tag << " The return value of the function "
214
+ << "whose gradient we want to compute should be a "
215
+ << "scalar array; but " << type_name_str(py_value_out)
216
+ << " was returned.";
217
+ throw std::invalid_argument(msg.str());
218
+ }
219
+ if (!nb::isinstance<nb::tuple>(py_value_out)) {
220
+ std::ostringstream msg;
221
+ msg << error_msg_tag << " The return value of the function "
222
+ << "whose gradient we want to compute should be either a "
223
+ << "scalar array or a tuple with the first value being a "
224
+ << "scalar array (Union[array, tuple[array, Any, ...]]); but "
225
+ << type_name_str(py_value_out) << " was returned.";
226
+ throw std::invalid_argument(msg.str());
227
+ }
228
+ nb::tuple ret = nb::cast<nb::tuple>(py_value_out);
229
+ if (ret.size() == 0) {
230
+ std::ostringstream msg;
231
+ msg << error_msg_tag << " The return value of the function "
232
+ << "whose gradient we want to compute should be either a "
233
+ << "scalar array or a non-empty tuple. The first value should be a "
234
+ << "scalar array and the rest can be anything. Instead, "
235
+ << "we got an empty tuple.";
236
+ throw std::invalid_argument(msg.str());
237
+ }
238
+ if (!nb::isinstance<mx::array>(ret[0])) {
239
+ std::ostringstream msg;
240
+ msg << error_msg_tag << " The return value of the function "
241
+ << "whose gradient we want to compute should be either a "
242
+ << "scalar array or a tuple with the first value being a "
243
+ << "scalar array (Union[array, tuple[array, Any, ...]]); but it "
244
+ << "was a tuple with the first value being of type "
245
+ << type_name_str(ret[0]) << " .";
246
+ throw std::invalid_argument(msg.str());
247
+ }
248
+ }
249
+
250
+ return tree_flatten(py_value_out, false);
251
+ },
252
+ gradient_indices)(arrays);
253
+
254
+ auto value = value_and_grads.first;
255
+ auto gradients = value_and_grads.second;
256
+
257
+ // Put the gradients back in their container.
258
+ // We have the following cases:
259
+ //
260
+ // 1. Single python positional argument has a gradient (eg argnums=[0])
261
+ // 2. Many python positional arguments have gradients (eg argnums=[0, 1])
262
+ // 3. A python keyword argument has gradients
263
+ //
264
+ // In case 1 we return the original python variable but with the gradients.
265
+ // In case 2 we return a tuple of the above.
266
+ // In case 3 we return a tuple containing a tuple and dict (sth like
267
+ // (tuple(), dict(x=mx.array(5))) ).
268
+ nb::object positional_grads;
269
+ nb::object keyword_grads;
270
+ nb::object py_grads;
271
+
272
+ // Collect the gradients for the positional arguments
273
+ if (argnums.size() == 1) {
274
+ positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]);
275
+ } else if (argnums.size() > 1) {
276
+ nb::list grads_;
277
+ for (int i = 0; i < argnums.size(); i++) {
278
+ grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i]));
279
+ }
280
+ positional_grads = nb::tuple(grads_);
281
+ } else {
282
+ positional_grads = nb::none();
283
+ }
284
+
285
+ // No keyword argument gradients so return the tuple of gradients
286
+ if (argnames.size() == 0) {
287
+ py_grads = positional_grads;
288
+ } else {
289
+ nb::dict grads_;
290
+ int i = 0;
291
+ for (auto item : kwargs) {
292
+ auto k = nb::cast<std::string>(item.first);
293
+ if (argnames.find(k) != argnames.end()) {
294
+ grads_[k.c_str()] = tree_unflatten(
295
+ nb::borrow(item.second), gradients, counts[i++ + argnums.size()]);
296
+ }
297
+ }
298
+ keyword_grads = grads_;
299
+
300
+ py_grads = nb::make_tuple(positional_grads, keyword_grads);
301
+ }
302
+
303
+ // Put the values back in the container
304
+ nb::object return_value = tree_unflatten(py_value_out, value);
305
+ return std::make_pair(return_value, py_grads);
306
+ };
307
+ }
308
+
309
+ auto py_vmap(
310
+ const nb::callable& fun,
311
+ const nb::object& in_axes,
312
+ const nb::object& out_axes) {
313
+ return [fun, in_axes, out_axes](const nb::args& args) {
314
+ auto axes_to_flat_tree = [](const nb::object& tree,
315
+ const nb::object& axes,
316
+ bool output_axes) {
317
+ std::vector<int> flat_axes;
318
+ bool encountered_tuple = false;
319
+ tree_visit(
320
+ {tree, axes},
321
+ [&flat_axes, &encountered_tuple, output_axes](
322
+ const std::vector<nb::object>& inputs) {
323
+ if (nb::isinstance<mx::array>(inputs[0])) {
324
+ if (inputs[1].is_none()) {
325
+ flat_axes.push_back(-1);
326
+ } else if (nb::isinstance<nb::int_>(inputs[1])) {
327
+ int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1]));
328
+ const mx::array& x = nb::cast<mx::array>(inputs[0]);
329
+ if (axis < 0) {
330
+ axis += x.ndim() + output_axes;
331
+ }
332
+ if (axis < 0 || axis >= (x.ndim() + output_axes)) {
333
+ std::ostringstream msg;
334
+ msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
335
+ << "vectorization axis " << axis
336
+ << " for array with shape " << x.shape();
337
+ throw std::invalid_argument(msg.str());
338
+ }
339
+ flat_axes.push_back(axis);
340
+ } else if (nb::isinstance<nb::tuple>(inputs[1])) {
341
+ encountered_tuple = true;
342
+ auto l = nb::cast<nb::tuple>(inputs[1]);
343
+ if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) {
344
+ int axis = nb::cast<int>(nb::cast<nb::int_>(l[0]));
345
+ const mx::array& x = nb::cast<mx::array>(inputs[0]);
346
+ if (axis < 0) {
347
+ axis += x.ndim() + output_axes;
348
+ }
349
+ if (axis < 0 || axis >= (x.ndim() + output_axes)) {
350
+ std::ostringstream msg;
351
+ msg << "[vmap] Invalid" << (output_axes ? " output " : " ")
352
+ << "vectorization axis " << axis
353
+ << " for array with shape " << x.shape();
354
+ throw std::invalid_argument(msg.str());
355
+ }
356
+ flat_axes.push_back(axis);
357
+ } else if (l.size() == 1 && l[0].is_none()) {
358
+ flat_axes.push_back(-1);
359
+ } else {
360
+ throw std::invalid_argument(
361
+ "[vmap] axis must be int or None.");
362
+ }
363
+ } else {
364
+ throw std::invalid_argument("[vmap] axis must be int or None.");
365
+ }
366
+ } else {
367
+ throw std::invalid_argument(
368
+ "[vmap] The arguments should contain only arrays");
369
+ }
370
+ });
371
+ if (encountered_tuple && !nb::isinstance<mx::array>(tree)) {
372
+ throw std::invalid_argument("[vmap] axis must be int or None.");
373
+ }
374
+ return flat_axes;
375
+ };
376
+
377
+ // Inputs must be array or tree of arrays
378
+ auto inputs = tree_flatten(args, true);
379
+ auto flat_in_axes =
380
+ axes_to_flat_tree((args.size() == 1) ? args[0] : args, in_axes, false);
381
+
382
+ // py_value_out will hold the output of the python function in order to be
383
+ // able to reconstruct the python tree of extra return values
384
+ nb::object py_outputs;
385
+
386
+ auto vmap_fn =
387
+ [&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) {
388
+ // Call the python function
389
+ py_outputs = fun(*tree_unflatten(args, a));
390
+
391
+ // Flatten the outputs
392
+ return tree_flatten(py_outputs, true);
393
+ };
394
+
395
+ auto [trace_inputs, trace_outputs] =
396
+ mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes);
397
+
398
+ auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true);
399
+
400
+ // Perform the vmap
401
+ auto outputs = mx::detail::vmap_replace(
402
+ inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes);
403
+
404
+ // Put the outputs back in the container
405
+ return tree_unflatten(py_outputs, outputs);
406
+ };
407
+ }
408
+
409
+ struct PyCompiledFun {
410
+ nb::callable fun;
411
+ std::uintptr_t fun_id;
412
+ nb::object captured_inputs;
413
+ nb::object captured_outputs;
414
+ bool shapeless;
415
+
416
+ // Data to attach to the compiled function that contains the python output
417
+ // structure and the number of arrays in said structure.
418
+ struct AttachedData {
419
+ nb::object output_structure;
420
+ int num_outputs;
421
+
422
+ AttachedData(nb::object output_structure_, int num_outputs_)
423
+ : output_structure(output_structure_), num_outputs(num_outputs_) {}
424
+ };
425
+
426
+ PyCompiledFun(
427
+ const nb::callable& fun,
428
+ nb::object inputs,
429
+ nb::object outputs,
430
+ bool shapeless)
431
+ : fun(fun),
432
+ fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())),
433
+ captured_inputs(inputs),
434
+ captured_outputs(outputs),
435
+ shapeless(shapeless) {}
436
+
437
+ PyCompiledFun(const PyCompiledFun&) = delete;
438
+ PyCompiledFun& operator=(const PyCompiledFun&) = delete;
439
+ PyCompiledFun& operator=(PyCompiledFun&& other) = delete;
440
+ PyCompiledFun(PyCompiledFun&& other)
441
+ : fun(std::move(other.fun)),
442
+ fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())) {
443
+ other.fun_id = 0;
444
+ captured_inputs = std::move(other.captured_inputs);
445
+ captured_outputs = std::move(other.captured_outputs);
446
+ shapeless = other.shapeless;
447
+ };
448
+
449
+ nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
450
+ // Flat array inputs
451
+ std::vector<mx::array> inputs;
452
+
453
+ // Compilation constants which includes the tree structure of the arguments
454
+ std::vector<uint64_t> constants;
455
+
456
+ // Reserve some large primes to signify the presence of an array, a list or
457
+ // a dict in order to encode the structure of the pytree. We choose primes
458
+ // to reduce slightly the chances of these numbers occurring by a
459
+ // multiplication as values in the constants list.
460
+ constexpr uint64_t array_identifier = 18446744073709551557UL;
461
+ constexpr uint64_t list_identifier = 18446744073709551533UL;
462
+ constexpr uint64_t dict_identifier = 18446744073709551521UL;
463
+ constexpr uint64_t none_identifier = 10239356951478402889UL;
464
+
465
+ // Flatten the tree with hashed constants and structure
466
+ std::function<void(nb::handle)> recurse;
467
+ recurse = [&](nb::handle obj) {
468
+ if (nb::isinstance<nb::list>(obj)) {
469
+ auto l = nb::cast<nb::list>(obj);
470
+ constants.push_back(list_identifier);
471
+ for (int i = 0; i < l.size(); ++i) {
472
+ recurse(l[i]);
473
+ }
474
+ } else if (nb::isinstance<nb::tuple>(obj)) {
475
+ auto l = nb::cast<nb::tuple>(obj);
476
+ constants.push_back(list_identifier);
477
+ for (auto item : obj) {
478
+ recurse(item);
479
+ }
480
+ } else if (nb::isinstance<nb::dict>(obj)) {
481
+ auto d = nb::cast<nb::dict>(obj);
482
+ constants.push_back(dict_identifier);
483
+ for (auto item : d) {
484
+ auto r = item.first.attr("__hash__")();
485
+ constants.push_back(nb::cast<int64_t>(r));
486
+ recurse(item.second);
487
+ }
488
+ } else if (nb::isinstance<mx::array>(obj)) {
489
+ inputs.push_back(nb::cast<mx::array>(obj));
490
+ constants.push_back(array_identifier);
491
+ } else if (nb::isinstance<nb::str>(obj)) {
492
+ auto r = obj.attr("__hash__")();
493
+ constants.push_back(nb::cast<int64_t>(r));
494
+ } else if (nb::isinstance<nb::int_>(obj)) {
495
+ constants.push_back(nb::cast<int64_t>(obj));
496
+ } else if (nb::isinstance<nb::float_>(obj)) {
497
+ auto r = nb::cast<double>(obj);
498
+ constants.push_back(*reinterpret_cast<uint64_t*>(&r));
499
+ } else if (obj.is_none()) {
500
+ constants.push_back(none_identifier);
501
+ } else {
502
+ std::ostringstream msg;
503
+ msg << "[compile] Function arguments must be trees of arrays "
504
+ << "or constants (floats, ints, strings, or None), but received "
505
+ << "type " << type_name_str(obj) << ".";
506
+ throw std::invalid_argument(msg.str());
507
+ }
508
+ };
509
+
510
+ recurse(args);
511
+ int num_args = inputs.size();
512
+ recurse(kwargs);
513
+ auto compile_fun = [this, &args, &kwargs, num_args](
514
+ const std::vector<mx::array>& a) {
515
+ // Put tracers into captured inputs
516
+ std::vector<mx::array> flat_in_captures;
517
+ std::vector<mx::array> trace_captures;
518
+ if (!captured_inputs.is_none()) {
519
+ flat_in_captures = tree_flatten(captured_inputs, false);
520
+ trace_captures.insert(
521
+ trace_captures.end(), a.end() - flat_in_captures.size(), a.end());
522
+ tree_fill(captured_inputs, trace_captures);
523
+ }
524
+
525
+ auto tree_outputs =
526
+ fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args));
527
+ auto [outputs, py_outputs] =
528
+ tree_flatten_with_structure(std::move(tree_outputs), false);
529
+
530
+ std::shared_ptr<void> extra_data =
531
+ std::make_shared<AttachedData>(py_outputs, outputs.size());
532
+
533
+ if (!captured_outputs.is_none()) {
534
+ auto flat_out_captures = tree_flatten(captured_outputs, false);
535
+ outputs.insert(
536
+ outputs.end(),
537
+ std::make_move_iterator(flat_out_captures.begin()),
538
+ std::make_move_iterator(flat_out_captures.end()));
539
+ }
540
+
541
+ // Replace tracers with originals in captured inputs
542
+ if (!captured_inputs.is_none()) {
543
+ tree_replace(captured_inputs, trace_captures, flat_in_captures);
544
+ }
545
+ return mx::detail::ArraysAndExtra{outputs, extra_data};
546
+ };
547
+
548
+ if (!captured_inputs.is_none()) {
549
+ auto flat_in_captures = tree_flatten(captured_inputs, false);
550
+ inputs.insert(
551
+ inputs.end(),
552
+ std::make_move_iterator(flat_in_captures.begin()),
553
+ std::make_move_iterator(flat_in_captures.end()));
554
+ }
555
+
556
+ // Compile and call
557
+ auto [outputs, extra_data] =
558
+ mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs);
559
+
560
+ int num_outputs =
561
+ reinterpret_cast<AttachedData*>(extra_data.get())->num_outputs;
562
+ nb::object py_outputs =
563
+ reinterpret_cast<AttachedData*>(extra_data.get())->output_structure;
564
+
565
+ if (!captured_outputs.is_none()) {
566
+ std::vector<mx::array> captures(
567
+ std::make_move_iterator(outputs.begin() + num_outputs),
568
+ std::make_move_iterator(outputs.end()));
569
+ tree_fill(captured_outputs, captures);
570
+ }
571
+
572
+ // Put the outputs back in the container
573
+ return tree_unflatten_from_structure(std::move(py_outputs), outputs);
574
+ }
575
+
576
+ nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
577
+ return const_cast<PyCompiledFun*>(this)->call_impl(args, kwargs);
578
+ };
579
+
580
+ ~PyCompiledFun() {
581
+ nb::gil_scoped_acquire gil;
582
+
583
+ mx::detail::compile_erase(fun_id);
584
+ fun.reset();
585
+ captured_inputs.reset();
586
+ captured_outputs.reset();
587
+ }
588
+ };
589
+
590
+ class PyCheckpointedFun {
591
+ public:
592
+ PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {}
593
+ ~PyCheckpointedFun() {
594
+ nb::gil_scoped_acquire gil;
595
+
596
+ fun_.reset();
597
+ }
598
+
599
+ struct InnerFunction {
600
+ nb::object fun_;
601
+ nb::object args_structure_;
602
+ std::weak_ptr<nb::object> output_structure_;
603
+
604
+ InnerFunction(
605
+ nb::object fun,
606
+ nb::object args_structure,
607
+ std::weak_ptr<nb::object> output_structure)
608
+ : fun_(std::move(fun)),
609
+ args_structure_(std::move(args_structure)),
610
+ output_structure_(output_structure) {}
611
+ ~InnerFunction() {
612
+ nb::gil_scoped_acquire gil;
613
+
614
+ fun_.reset();
615
+ args_structure_.reset();
616
+ }
617
+
618
+ std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
619
+ auto args = nb::cast<nb::tuple>(
620
+ tree_unflatten_from_structure(args_structure_, inputs));
621
+ auto [outputs, output_structure] =
622
+ tree_flatten_with_structure(fun_(*args[0], **args[1]), false);
623
+ if (auto s = output_structure_.lock()) {
624
+ *s = output_structure;
625
+ }
626
+ return outputs;
627
+ }
628
+ };
629
+
630
+ nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
631
+ auto output_structure = std::make_shared<nb::object>();
632
+ auto full_args = nb::make_tuple(args, kwargs);
633
+ auto [inputs, args_structure] =
634
+ tree_flatten_with_structure(full_args, false);
635
+
636
+ auto outputs = mx::checkpoint(
637
+ InnerFunction(fun_, args_structure, output_structure))(inputs);
638
+
639
+ return tree_unflatten_from_structure(*output_structure, outputs);
640
+ }
641
+
642
+ nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const {
643
+ return const_cast<PyCheckpointedFun*>(this)->call_impl(args, kwargs);
644
+ }
645
+
646
+ private:
647
+ nb::callable fun_;
648
+ };
649
+
650
+ int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg);
651
+
652
+ int py_custom_function_tp_clear(PyObject* self);
653
+
654
+ /**
655
+ * PyCustomFunction is the class that implements the python decorator
656
+ * `mx.custom_function`.
657
+ *
658
+ * It implements a callable that instead of simply calling `fun` it creates a
659
+ * CustomTransforms primitive via the `custom_function` C++ op which allows us
660
+ * to redefine the vjp, jvp and vmap transformations.
661
+ *
662
+ * The implementation is verbose due to explicit handling of the destruction of
663
+ * various python objects to make sure that there is no double-free and that
664
+ * all of them are deleted while under GIL.
665
+ *
666
+ * Namely, for every one of the functions passed to the C++ `custom_function`
667
+ * we create a callable struct that holds the following python objects (when
668
+ * needed).
669
+ *
670
+ * - An nb::callable which holds the passed function or transform
671
+ * - An nb::object holding input structure, namely the `(args, kwargs)`
672
+ * passed to the function in order to be able to recreate the arguments
673
+ * from the input arrays.
674
+ * - A std::shared_ptr<nb::object> holding the output structure name the
675
+ * structure of the return value of `fun`. It is a shared_ptr so that it
676
+ * can be set when the function is called and then used in the `vjp`
677
+ * transform. We delete the object only when the shared_ptr is about to be
678
+ * deleted see `output_structure_.use_count() == 1` to make sure that the
679
+ * object is deleted under GIL.
680
+ */
681
+ class PyCustomFunction {
682
+ public:
683
+ PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {}
684
+ ~PyCustomFunction() {
685
+ nb::gil_scoped_acquire gil;
686
+ reset();
687
+ }
688
+
689
+ struct InnerFunction {
690
+ nb::callable fun_;
691
+ nb::object input_structure_;
692
+ std::shared_ptr<nb::object> output_structure_;
693
+
694
+ InnerFunction(
695
+ nb::callable fun,
696
+ nb::object input_structure,
697
+ std::shared_ptr<nb::object> output_structure)
698
+ : fun_(std::move(fun)),
699
+ input_structure_(std::move(input_structure)),
700
+ output_structure_(std::move(output_structure)) {}
701
+ ~InnerFunction() {
702
+ nb::gil_scoped_acquire gil;
703
+
704
+ fun_.reset();
705
+ input_structure_.reset();
706
+ if (output_structure_.use_count() == 1) {
707
+ output_structure_->reset();
708
+ }
709
+ }
710
+
711
+ std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) {
712
+ nb::gil_scoped_acquire gil;
713
+
714
+ auto new_inputs = nb::cast<nb::tuple>(
715
+ tree_unflatten_from_structure(input_structure_, inputs));
716
+ std::vector<mx::array> outputs;
717
+ std::tie(outputs, *output_structure_) =
718
+ tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1]));
719
+ return outputs;
720
+ }
721
+ };
722
+
723
+ struct InnerVJPFunction {
724
+ nb::callable vjp_fun_;
725
+ nb::object input_structure_;
726
+ std::shared_ptr<nb::object> output_structure_;
727
+
728
+ InnerVJPFunction(
729
+ nb::callable vjp_fun,
730
+ nb::object input_structure,
731
+ std::shared_ptr<nb::object> output_structure)
732
+ : vjp_fun_(std::move(vjp_fun)),
733
+ input_structure_(std::move(input_structure)),
734
+ output_structure_(std::move(output_structure)) {}
735
+ ~InnerVJPFunction() {
736
+ nb::gil_scoped_acquire gil;
737
+
738
+ vjp_fun_.reset();
739
+ input_structure_.reset();
740
+ if (output_structure_.use_count() == 1) {
741
+ output_structure_->reset();
742
+ }
743
+ }
744
+
745
+ std::vector<mx::array> operator()(
746
+ const std::vector<mx::array>& primals,
747
+ const std::vector<mx::array>& cotangents,
748
+ const std::vector<mx::array>& outputs) {
749
+ nb::gil_scoped_acquire gil;
750
+
751
+ auto new_inputs = nb::cast<nb::tuple>(
752
+ tree_unflatten_from_structure(input_structure_, primals));
753
+ auto args = nb::cast<nb::tuple>(new_inputs[0]);
754
+ auto new_cotangents =
755
+ tree_unflatten_from_structure(*output_structure_, cotangents);
756
+ auto new_outputs =
757
+ tree_unflatten_from_structure(*output_structure_, outputs);
758
+
759
+ if (args.size() == 1) {
760
+ return tree_flatten(
761
+ vjp_fun_(args[0], new_cotangents, new_outputs, **new_inputs[1]),
762
+ false);
763
+ } else {
764
+ return tree_flatten(
765
+ vjp_fun_(args, new_cotangents, new_outputs, **new_inputs[1]),
766
+ false);
767
+ }
768
+ }
769
+ };
770
+
771
+ struct InnerJVPFunction {
772
+ nb::callable jvp_fun_;
773
+ nb::object input_structure_;
774
+
775
+ InnerJVPFunction(nb::callable jvp_fun, nb::object input_structure)
776
+ : jvp_fun_(std::move(jvp_fun)),
777
+ input_structure_(std::move(input_structure)) {}
778
+ ~InnerJVPFunction() {
779
+ nb::gil_scoped_acquire gil;
780
+
781
+ jvp_fun_.reset();
782
+ input_structure_.reset();
783
+ }
784
+
785
+ std::vector<mx::array> operator()(
786
+ const std::vector<mx::array>& primals,
787
+ const std::vector<mx::array>& tangents,
788
+ const std::vector<int>& argnums) {
789
+ nb::gil_scoped_acquire gil;
790
+
791
+ auto new_inputs = nb::cast<nb::tuple>(
792
+ tree_unflatten_from_structure(input_structure_, primals));
793
+ auto args = nb::cast<nb::tuple>(new_inputs[0]);
794
+ auto kwargs = nb::cast<nb::dict>(new_inputs[1]);
795
+ if (kwargs.size() > 0) {
796
+ throw std::invalid_argument(
797
+ "[custom jvp] Function should only accept positional arguments");
798
+ }
799
+
800
+ // Make a new pytree which has tangents or None when a tangent is not
801
+ // available.
802
+ std::vector<bool> have_tangents(primals.size(), false);
803
+ for (auto arg : argnums) {
804
+ have_tangents[arg] = true;
805
+ }
806
+ int array_index = 0;
807
+ int tangent_index = 0;
808
+ auto new_tangents =
809
+ nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
810
+ if (nb::isinstance<mx::array>(element) &&
811
+ have_tangents[array_index++]) {
812
+ return nb::cast(tangents[tangent_index++]);
813
+ } else {
814
+ return nb::none();
815
+ }
816
+ }));
817
+
818
+ if (args.size() == 1) {
819
+ return tree_flatten(jvp_fun_(args[0], new_tangents[0]), false);
820
+ } else {
821
+ return tree_flatten(jvp_fun_(args, new_tangents), false);
822
+ }
823
+ }
824
+ };
825
+
826
+ struct InnerVmapFunction {
827
+ nb::callable vmap_fun_;
828
+ nb::object input_structure_;
829
+
830
+ InnerVmapFunction(nb::callable vmap_fun, nb::object input_structure)
831
+ : vmap_fun_(std::move(vmap_fun)),
832
+ input_structure_(std::move(input_structure)) {}
833
+ ~InnerVmapFunction() {
834
+ nb::gil_scoped_acquire gil;
835
+
836
+ vmap_fun_.reset();
837
+ input_structure_.reset();
838
+ }
839
+
840
+ std::pair<std::vector<mx::array>, std::vector<int>> operator()(
841
+ const std::vector<mx::array>& inputs,
842
+ const std::vector<int>& axes) {
843
+ nb::gil_scoped_acquire gil;
844
+
845
+ auto new_inputs = nb::cast<nb::tuple>(
846
+ tree_unflatten_from_structure(input_structure_, inputs));
847
+ auto args = nb::cast<nb::tuple>(new_inputs[0]);
848
+ auto kwargs = nb::cast<nb::dict>(new_inputs[1]);
849
+ if (kwargs.size() > 0) {
850
+ throw std::invalid_argument(
851
+ "[custom vmap] Function should only accept positional arguments");
852
+ }
853
+
854
+ int arr_index = 0;
855
+ auto new_axes =
856
+ nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) {
857
+ int axis = axes[arr_index++];
858
+ if (nb::isinstance<mx::array>(element) && axis >= 0) {
859
+ return nb::cast(axis);
860
+ } else {
861
+ return nb::none();
862
+ }
863
+ }));
864
+
865
+ nb::object result;
866
+ if (args.size() == 1) {
867
+ result = vmap_fun_(args[0], new_axes[0]);
868
+ } else {
869
+ result = vmap_fun_(args, new_axes);
870
+ }
871
+
872
+ if (!nb::isinstance<nb::tuple>(result)) {
873
+ throw std::invalid_argument(
874
+ "[custom vmap] Vmap function should return a tuple with 2 items.");
875
+ }
876
+ nb::tuple result_tuple = nb::cast<nb::tuple>(result);
877
+ if (result_tuple.size() != 2) {
878
+ throw std::invalid_argument(
879
+ "[custom vmap] Vmap function should return a tuple with 2 items.");
880
+ }
881
+
882
+ std::vector<mx::array> outputs;
883
+ std::vector<int> output_axes;
884
+ tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) {
885
+ if (nb::isinstance<mx::array>(objects[0])) {
886
+ outputs.push_back(nb::cast<mx::array>(objects[0]));
887
+ output_axes.push_back(
888
+ objects[1].is_none() ? -1 : nb::cast<int>(objects[1]));
889
+ }
890
+ });
891
+
892
+ return {outputs, output_axes};
893
+ }
894
+ };
895
+
896
+ nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) {
897
+ if (!vjp_fun_.has_value() && !jvp_fun_.has_value() &&
898
+ !vmap_fun_.has_value()) {
899
+ return fun_(*args, **kwargs);
900
+ }
901
+
902
+ // Extract the inputs and their structure in capturable vars
903
+ std::vector<mx::array> input_arrays;
904
+ nb::object input_structure;
905
+ auto full_args = nb::make_tuple(args, kwargs);
906
+ std::tie(input_arrays, input_structure) =
907
+ tree_flatten_with_structure(full_args, false);
908
+
909
+ // The output structure will be stored here to be used in the custom vjp
910
+ // function
911
+ auto output_structure = std::make_shared<nb::object>();
912
+
913
+ // Make a function that calls fun_ in the forward pass and vjp_ in the
914
+ // backward pass. Then call it immediately and return the results.
915
+ auto f = mx::custom_function(
916
+ InnerFunction(fun_, input_structure, output_structure),
917
+ make_vjp_function(input_structure, output_structure),
918
+ make_jvp_function(input_structure),
919
+ make_vmap_function(input_structure));
920
+
921
+ auto outputs = f(input_arrays);
922
+ return tree_unflatten_from_structure(*output_structure, outputs);
923
+ }
924
+
925
+ PyCustomFunction& set_vjp(nb::callable vjp_fun) {
926
+ vjp_fun_ = vjp_fun;
927
+ return *this;
928
+ }
929
+
930
+ PyCustomFunction& set_jvp(nb::callable jvp_fun) {
931
+ jvp_fun_ = jvp_fun;
932
+ return *this;
933
+ }
934
+
935
+ PyCustomFunction& set_vmap(nb::callable vmap_fun) {
936
+ vmap_fun_ = vmap_fun;
937
+ return *this;
938
+ }
939
+ void reset() {
940
+ fun_.reset();
941
+ if (vjp_fun_.has_value()) {
942
+ (*vjp_fun_).reset();
943
+ }
944
+ if (jvp_fun_.has_value()) {
945
+ (*jvp_fun_).reset();
946
+ }
947
+ if (vmap_fun_.has_value()) {
948
+ (*vmap_fun_).reset();
949
+ }
950
+ }
951
+
952
+ friend int py_custom_function_tp_traverse(PyObject*, visitproc, void*);
953
+
954
+ private:
955
+ std::optional<InnerVJPFunction> make_vjp_function(
956
+ nb::object input_structure,
957
+ std::shared_ptr<nb::object> output_structure) {
958
+ if (!vjp_fun_.has_value()) {
959
+ return std::nullopt;
960
+ }
961
+
962
+ return InnerVJPFunction(*vjp_fun_, input_structure, output_structure);
963
+ }
964
+
965
+ std::optional<InnerJVPFunction> make_jvp_function(
966
+ nb::object input_structure) {
967
+ if (!jvp_fun_.has_value()) {
968
+ return std::nullopt;
969
+ }
970
+
971
+ return InnerJVPFunction(*jvp_fun_, input_structure);
972
+ }
973
+
974
+ std::optional<InnerVmapFunction> make_vmap_function(
975
+ nb::object input_structure) {
976
+ if (!vmap_fun_.has_value()) {
977
+ return std::nullopt;
978
+ }
979
+
980
+ return InnerVmapFunction(*vmap_fun_, input_structure);
981
+ }
982
+
983
+ nb::callable fun_;
984
+ std::optional<nb::callable> vjp_fun_;
985
+ std::optional<nb::callable> jvp_fun_;
986
+ std::optional<nb::callable> vmap_fun_;
987
+ };
988
+
989
+ int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) {
990
+ Py_VISIT(Py_TYPE(self));
991
+ if (!nb::inst_ready(self)) {
992
+ return 0;
993
+ }
994
+
995
+ auto* p = nb::inst_ptr<PyCustomFunction>(self);
996
+ nb::handle v = nb::find(p->fun_);
997
+ Py_VISIT(v.ptr());
998
+ if (p->vjp_fun_.has_value()) {
999
+ nb::handle v = nb::find(*(p->vjp_fun_));
1000
+ Py_VISIT(v.ptr());
1001
+ }
1002
+ if (p->jvp_fun_.has_value()) {
1003
+ nb::handle v = nb::find(*(p->jvp_fun_));
1004
+ Py_VISIT(v.ptr());
1005
+ }
1006
+ if (p->vmap_fun_.has_value()) {
1007
+ nb::handle v = nb::find(*(p->vmap_fun_));
1008
+ Py_VISIT(v.ptr());
1009
+ }
1010
+ return 0;
1011
+ }
1012
+ int py_custom_function_tp_clear(PyObject* self) {
1013
+ auto* p = nb::inst_ptr<PyCustomFunction>(self);
1014
+ p->reset();
1015
+ return 0;
1016
+ }
1017
+ PyType_Slot py_custom_function_slots[] = {
1018
+ {Py_tp_traverse, (void*)py_custom_function_tp_traverse},
1019
+ {Py_tp_clear, (void*)py_custom_function_tp_clear},
1020
+ {0, 0}};
1021
+
1022
+ void init_transforms(nb::module_& m) {
1023
+ nb::class_<PyCustomFunction>(
1024
+ m,
1025
+ "custom_function",
1026
+ nb::type_slots(py_custom_function_slots),
1027
+ R"pbdoc(
1028
+ Set up a function for custom gradient and vmap definitions.
1029
+
1030
+ This class is meant to be used as a function decorator. Instances are
1031
+ callables that behave identically to the wrapped function. However, when
1032
+ a function transformation is used (e.g. computing gradients using
1033
+ :func:`value_and_grad`) then the functions defined via
1034
+ :meth:`custom_function.vjp`, :meth:`custom_function.jvp` and
1035
+ :meth:`custom_function.vmap` are used instead of the default transformation.
1036
+
1037
+ Note, all custom transformations are optional. Undefined transformations
1038
+ fall back to the default behaviour.
1039
+
1040
+ Example:
1041
+
1042
+ .. code-block:: python
1043
+
1044
+ import mlx.core as mx
1045
+
1046
+ @mx.custom_function
1047
+ def f(x, y):
1048
+ return mx.sin(x) * y
1049
+
1050
+ @f.vjp
1051
+ def f_vjp(primals, cotangent, output):
1052
+ x, y = primals
1053
+ return cotan * mx.cos(x) * y, cotan * mx.sin(x)
1054
+
1055
+ @f.jvp
1056
+ def f_jvp(primals, tangents):
1057
+ x, y = primals
1058
+ dx, dy = tangents
1059
+ return dx * mx.cos(x) * y + dy * mx.sin(x)
1060
+
1061
+ @f.vmap
1062
+ def f_vmap(inputs, axes):
1063
+ x, y = inputs
1064
+ ax, ay = axes
1065
+ if ay != ax and ax is not None:
1066
+ y = y.swapaxes(ay, ax)
1067
+ return mx.sin(x) * y, (ax or ay)
1068
+
1069
+ All ``custom_function`` instances behave as pure functions. Namely, any
1070
+ variables captured will be treated as constants and no gradients will be
1071
+ computed with respect to the captured arrays. For instance:
1072
+
1073
+ .. code-block:: python
1074
+
1075
+ import mlx.core as mx
1076
+
1077
+ def g(x, y):
1078
+ @mx.custom_function
1079
+ def f(x):
1080
+ return x * y
1081
+
1082
+ @f.vjp
1083
+ def f_vjp(x, dx, fx):
1084
+ # Note that we have only x, dx and fx and nothing with respect to y
1085
+ raise ValueError("Abort!")
1086
+
1087
+ return f(x)
1088
+
1089
+ x = mx.array(2.0)
1090
+ y = mx.array(3.0)
1091
+ print(g(x, y)) # prints 6.0
1092
+ print(mx.grad(g)(x, y)) # Raises exception
1093
+ print(mx.grad(g, argnums=1)(x, y)) # prints 0.0
1094
+ )pbdoc")
1095
+ .def(
1096
+ nb::init<nb::callable>(),
1097
+ "f"_a,
1098
+ nb::sig("def __init__(self, f: Callable)"))
1099
+ .def("__call__", &PyCustomFunction::call_impl)
1100
+ .def(
1101
+ "vjp",
1102
+ &PyCustomFunction::set_vjp,
1103
+ "f"_a,
1104
+ nb::sig("def vjp(self, f: Callable)"),
1105
+ R"pbdoc(
1106
+ Define a custom vjp for the wrapped function.
1107
+
1108
+ The vjp function takes three arguments:
1109
+
1110
+ - *primals*: A pytree that contains all the positional arguments to
1111
+ the function. It could be a single array, a tuple of arrays or a
1112
+ full blown tuple of dicts of arrays etc.
1113
+ - *cotangents*: A pytree that matches the structure of the output
1114
+ but contains the cotangents (usually the gradients of the loss
1115
+ function with respect to the outputs).
1116
+ - *outputs*: The outputs of the function to be used to avoid
1117
+ recomputing them for the gradient computation.
1118
+
1119
+ The vjp function should return the same pytree structure as the
1120
+ primals but containing the corresponding computed cotangents.
1121
+ )pbdoc")
1122
+ .def(
1123
+ "jvp",
1124
+ &PyCustomFunction::set_jvp,
1125
+ "f"_a,
1126
+ nb::sig("def jvp(self, f: Callable)"),
1127
+ R"pbdoc(
1128
+ Define a custom jvp for the wrapped function.
1129
+
1130
+ The jvp function takes two arguments:
1131
+
1132
+ - *primals*: A pytree that contains all the positional arguments to
1133
+ the function. It could be a single array, a tuple of arrays or a
1134
+ full blown tuple of dicts of arrays etc.
1135
+ - *tangents*: A pytree that matches the structure of the inputs but
1136
+ instead contains the gradients wrt to each input. Tangents could
1137
+ be ``None`` if some inputs don't have an associated gradient.
1138
+
1139
+ The jvp function should return the same pytree structure as the
1140
+ outputs of the function but containing the tangents.
1141
+ )pbdoc")
1142
+ .def(
1143
+ "vmap",
1144
+ &PyCustomFunction::set_vmap,
1145
+ "f"_a,
1146
+ nb::sig("def vmap(self, f: Callable)"),
1147
+ R"pbdoc(
1148
+ Define a custom vectorization transformation for the wrapped function.
1149
+
1150
+ The vmap function takes two arguments:
1151
+
1152
+ - *inputs*: A pytree that contains all the positional arguments to
1153
+ the function. It could be a single array, a tuple of arrays or a
1154
+ full blown tuple of dicts of arrays etc.
1155
+ - *axes*: A pytree that matches the structure of the inputs but
1156
+ instead contains the vectorization axis for each input or
1157
+ ``None`` if an input is not vectorized.
1158
+
1159
+ The vmap function should return the outputs of the original
1160
+ function but vectorized over the provided axes. It should also
1161
+ return a pytree with the vectorization axes of each output. If some
1162
+ outputs are no longer vectorized, then their vectorization axis
1163
+ should be ``None``.
1164
+ )pbdoc");
1165
+
1166
+ m.def(
1167
+ "eval",
1168
+ [](const nb::args& args) {
1169
+ std::vector<mx::array> arrays = tree_flatten(args, false);
1170
+ {
1171
+ nb::gil_scoped_release nogil;
1172
+ eval(arrays);
1173
+ }
1174
+ },
1175
+ nb::arg(),
1176
+ nb::sig("def eval(*args) -> None"),
1177
+ R"pbdoc(
1178
+ Evaluate an :class:`array` or tree of :class:`array`.
1179
+
1180
+ Args:
1181
+ *args (arrays or trees of arrays): Each argument can be a single array
1182
+ or a tree of arrays. If a tree is given the nodes can be a Python
1183
+ :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
1184
+ arrays are ignored.
1185
+ )pbdoc");
1186
+ m.def(
1187
+ "async_eval",
1188
+ [](const nb::args& args) {
1189
+ std::vector<mx::array> arrays = tree_flatten(args, false);
1190
+ {
1191
+ nb::gil_scoped_release nogil;
1192
+ async_eval(arrays);
1193
+ }
1194
+ },
1195
+ nb::arg(),
1196
+ nb::sig("def async_eval(*args)"),
1197
+ R"pbdoc(
1198
+ Asynchronously evaluate an :class:`array` or tree of :class:`array`.
1199
+
1200
+ .. note::
1201
+
1202
+ This is an experimental API and may change in future versions.
1203
+
1204
+ Args:
1205
+ *args (arrays or trees of arrays): Each argument can be a single array
1206
+ or a tree of arrays. If a tree is given the nodes can be a Python
1207
+ :class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not
1208
+ arrays are ignored.
1209
+
1210
+ Example:
1211
+ >>> x = mx.array(1.0)
1212
+ >>> y = mx.exp(x)
1213
+ >>> mx.async_eval(y)
1214
+ >>> print(y)
1215
+ >>>
1216
+ >>> y = mx.exp(x)
1217
+ >>> mx.async_eval(y)
1218
+ >>> z = y + 3
1219
+ >>> mx.async_eval(z)
1220
+ >>> print(z)
1221
+ )pbdoc");
1222
+ m.def(
1223
+ "jvp",
1224
+ [](const nb::callable& fun,
1225
+ const std::vector<mx::array>& primals,
1226
+ const std::vector<mx::array>& tangents) {
1227
+ auto vfun = [&fun](const std::vector<mx::array>& primals) {
1228
+ auto out = fun(*nb::cast(primals));
1229
+ if (nb::isinstance<mx::array>(out)) {
1230
+ return std::vector<mx::array>{nb::cast<mx::array>(out)};
1231
+ } else {
1232
+ return nb::cast<std::vector<mx::array>>(out);
1233
+ }
1234
+ };
1235
+ return jvp(vfun, primals, tangents);
1236
+ },
1237
+ "fun"_a,
1238
+ "primals"_a,
1239
+ "tangents"_a,
1240
+ nb::sig(
1241
+ "def jvp(fun: Callable, primals: list[array], tangents: list[array]) -> tuple[list[array], list[array]]"),
1242
+ R"pbdoc(
1243
+ Compute the Jacobian-vector product.
1244
+
1245
+ This computes the product of the Jacobian of a function ``fun`` evaluated
1246
+ at ``primals`` with the ``tangents``.
1247
+
1248
+ Args:
1249
+ fun (Callable): A function which takes a variable number of :class:`array`
1250
+ and returns a single :class:`array` or list of :class:`array`.
1251
+ primals (list(array)): A list of :class:`array` at which to
1252
+ evaluate the Jacobian.
1253
+ tangents (list(array)): A list of :class:`array` which are the
1254
+ "vector" in the Jacobian-vector product. The ``tangents`` should be the
1255
+ same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``).
1256
+
1257
+ Returns:
1258
+ tuple(list(array), list(array)): A tuple with the outputs of
1259
+ ``fun`` in the first position and the Jacobian-vector products
1260
+ in the second position.
1261
+
1262
+ Example:
1263
+
1264
+ .. code-block:: python
1265
+
1266
+ import mlx.core as mx
1267
+
1268
+ outs, jvps = mx.jvp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
1269
+
1270
+ )pbdoc");
1271
+ m.def(
1272
+ "vjp",
1273
+ [](const nb::callable& fun,
1274
+ const std::vector<mx::array>& primals,
1275
+ const std::vector<mx::array>& cotangents) {
1276
+ auto vfun = [&fun](const std::vector<mx::array>& primals) {
1277
+ auto out = fun(*nb::cast(primals));
1278
+ if (nb::isinstance<mx::array>(out)) {
1279
+ return std::vector<mx::array>{nb::cast<mx::array>(out)};
1280
+ } else {
1281
+ return nb::cast<std::vector<mx::array>>(out);
1282
+ }
1283
+ };
1284
+ return vjp(vfun, primals, cotangents);
1285
+ },
1286
+ "fun"_a,
1287
+ "primals"_a,
1288
+ "cotangents"_a,
1289
+ nb::sig(
1290
+ "def vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]"),
1291
+ R"pbdoc(
1292
+ Compute the vector-Jacobian product.
1293
+
1294
+ Computes the product of the ``cotangents`` with the Jacobian of a
1295
+ function ``fun`` evaluated at ``primals``.
1296
+
1297
+ Args:
1298
+ fun (Callable): A function which takes a variable number of :class:`array`
1299
+ and returns a single :class:`array` or list of :class:`array`.
1300
+ primals (list(array)): A list of :class:`array` at which to
1301
+ evaluate the Jacobian.
1302
+ cotangents (list(array)): A list of :class:`array` which are the
1303
+ "vector" in the vector-Jacobian product. The ``cotangents`` should be the
1304
+ same in number, shape, and type as the outputs of ``fun``.
1305
+
1306
+ Returns:
1307
+ tuple(list(array), list(array)): A tuple with the outputs of
1308
+ ``fun`` in the first position and the vector-Jacobian products
1309
+ in the second position.
1310
+
1311
+ Example:
1312
+
1313
+ .. code-block:: python
1314
+
1315
+ import mlx.core as mx
1316
+
1317
+ outs, vjps = mx.vjp(mx.sin, (mx.array(1.0),), (mx.array(1.0),))
1318
+
1319
+ )pbdoc");
1320
+ m.def(
1321
+ "value_and_grad",
1322
+ [](const nb::callable& fun,
1323
+ const std::optional<IntOrVec>& argnums,
1324
+ const StrOrSet& argnames) {
1325
+ auto [argnums_vec, argnames_set] =
1326
+ validate_argnums_argnames(argnums, argnames);
1327
+ return mlx_func(
1328
+ py_value_and_grad(
1329
+ fun, argnums_vec, argnames_set, "[value_and_grad]", false),
1330
+ fun);
1331
+ },
1332
+ "fun"_a,
1333
+ "argnums"_a = nb::none(),
1334
+ "argnames"_a = std::vector<std::string>{},
1335
+ nb::sig(
1336
+ "def value_and_grad(fun: Callable[P, R], argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable[P, Tuple[R, Any]]"),
1337
+ R"pbdoc(
1338
+ Returns a function which computes the value and gradient of ``fun``.
1339
+
1340
+ The function passed to :func:`value_and_grad` should return either
1341
+ a scalar loss or a tuple in which the first element is a scalar
1342
+ loss and the remaining elements can be anything.
1343
+
1344
+ .. code-block:: python
1345
+
1346
+ import mlx.core as mx
1347
+
1348
+ def mse(params, inputs, targets):
1349
+ outputs = forward(params, inputs)
1350
+ lvalue = (outputs - targets).square().mean()
1351
+ return lvalue
1352
+
1353
+ # Returns lvalue, dlvalue/dparams
1354
+ lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets)
1355
+
1356
+ def lasso(params, inputs, targets, a=1.0, b=1.0):
1357
+ outputs = forward(params, inputs)
1358
+ mse = (outputs - targets).square().mean()
1359
+ l1 = mx.abs(outputs - targets).mean()
1360
+
1361
+ loss = a*mse + b*l1
1362
+
1363
+ return loss, mse, l1
1364
+
1365
+ (loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets)
1366
+
1367
+ Args:
1368
+ fun (Callable): A function which takes a variable number of
1369
+ :class:`array` or trees of :class:`array` and returns
1370
+ a scalar output :class:`array` or a tuple the first element
1371
+ of which should be a scalar :class:`array`.
1372
+ argnums (int or list(int), optional): Specify the index (or indices)
1373
+ of the positional arguments of ``fun`` to compute the gradient
1374
+ with respect to. If neither ``argnums`` nor ``argnames`` are
1375
+ provided ``argnums`` defaults to ``0`` indicating ``fun``'s first
1376
+ argument.
1377
+ argnames (str or list(str), optional): Specify keyword arguments of
1378
+ ``fun`` to compute gradients with respect to. It defaults to [] so
1379
+ no gradients for keyword arguments by default.
1380
+
1381
+ Returns:
1382
+ Callable: A function which returns a tuple where the first element
1383
+ is the output of `fun` and the second element is the gradients w.r.t.
1384
+ the loss.
1385
+ )pbdoc");
1386
+ m.def(
1387
+ "grad",
1388
+ [](const nb::callable& fun,
1389
+ const std::optional<IntOrVec>& argnums,
1390
+ const StrOrSet& argnames) {
1391
+ auto [argnums_vec, argnames_set] =
1392
+ validate_argnums_argnames(argnums, argnames);
1393
+ auto fn =
1394
+ py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true);
1395
+ return mlx_func(
1396
+ [fn = std::move(fn)](nb::args& args, nb::kwargs& kwargs) {
1397
+ return fn(args, kwargs).second;
1398
+ },
1399
+ fun);
1400
+ },
1401
+ "fun"_a,
1402
+ "argnums"_a = nb::none(),
1403
+ "argnames"_a = std::vector<std::string>{},
1404
+ nb::sig(
1405
+ "def grad(fun: Callable[P, R], argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable[P, Any]"),
1406
+ R"pbdoc(
1407
+ Returns a function which computes the gradient of ``fun``.
1408
+
1409
+ Args:
1410
+ fun (Callable): A function which takes a variable number of
1411
+ :class:`array` or trees of :class:`array` and returns
1412
+ a scalar output :class:`array`.
1413
+ argnums (int or list(int), optional): Specify the index (or indices)
1414
+ of the positional arguments of ``fun`` to compute the gradient
1415
+ with respect to. If neither ``argnums`` nor ``argnames`` are
1416
+ provided ``argnums`` defaults to ``0`` indicating ``fun``'s first
1417
+ argument.
1418
+ argnames (str or list(str), optional): Specify keyword arguments of
1419
+ ``fun`` to compute gradients with respect to. It defaults to [] so
1420
+ no gradients for keyword arguments by default.
1421
+
1422
+ Returns:
1423
+ Callable: A function which has the same input arguments as ``fun`` and
1424
+ returns the gradient(s).
1425
+ )pbdoc");
1426
+ m.def(
1427
+ "vmap",
1428
+ [](const nb::callable& fun,
1429
+ const nb::object& in_axes,
1430
+ const nb::object& out_axes) {
1431
+ return mlx_func(
1432
+ py_vmap(fun, in_axes, out_axes), fun, in_axes, out_axes);
1433
+ },
1434
+ "fun"_a,
1435
+ "in_axes"_a = 0,
1436
+ "out_axes"_a = 0,
1437
+ nb::sig(
1438
+ "def vmap(fun: Callable[P, R], in_axes: object = 0, out_axes: object = 0) -> Callable[P, R]"),
1439
+ R"pbdoc(
1440
+ Returns a vectorized version of ``fun``.
1441
+
1442
+ Args:
1443
+ fun (Callable): A function which takes a variable number of
1444
+ :class:`array` or a tree of :class:`array` and returns
1445
+ a variable number of :class:`array` or a tree of :class:`array`.
1446
+ in_axes (int, optional): An integer or a valid prefix tree of the
1447
+ inputs to ``fun`` where each node specifies the vmapped axis. If
1448
+ the value is ``None`` then the corresponding input(s) are not vmapped.
1449
+ Defaults to ``0``.
1450
+ out_axes (int, optional): An integer or a valid prefix tree of the
1451
+ outputs of ``fun`` where each node specifies the vmapped axis. If
1452
+ the value is ``None`` then the corresponding outputs(s) are not vmapped.
1453
+ Defaults to ``0``.
1454
+
1455
+ Returns:
1456
+ Callable: The vectorized function.
1457
+ )pbdoc");
1458
+ m.def(
1459
+ "compile",
1460
+ [](const nb::callable& fun,
1461
+ const nb::object& inputs,
1462
+ const nb::object& outputs,
1463
+ bool shapeless) {
1464
+ return mlx_func(
1465
+ nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}),
1466
+ fun,
1467
+ inputs,
1468
+ outputs);
1469
+ },
1470
+ "fun"_a,
1471
+ "inputs"_a = nb::none(),
1472
+ "outputs"_a = nb::none(),
1473
+ "shapeless"_a = false,
1474
+ nb::sig(
1475
+ "def compile(fun: Callable[P, R], inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable[P, R]"),
1476
+ R"pbdoc(
1477
+ Returns a compiled function which produces the same output as ``fun``.
1478
+
1479
+ Args:
1480
+ fun (Callable): A function which takes a variable number of
1481
+ :class:`array` or trees of :class:`array` and returns
1482
+ a variable number of :class:`array` or trees of :class:`array`.
1483
+ inputs (list or dict, optional): These inputs will be captured during
1484
+ the function compilation along with the inputs to ``fun``. The ``inputs``
1485
+ can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested
1486
+ lists, dictionaries, or arrays. Leaf nodes that are not
1487
+ :obj:`array` are ignored. Default: ``None``
1488
+ outputs (list or dict, optional): These outputs will be captured and
1489
+ updated in a compiled function. The ``outputs`` can be a
1490
+ :obj:`list` or a :obj:`dict` containing arbitrarily nested lists,
1491
+ dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored.
1492
+ Default: ``None``
1493
+ shapeless (bool, optional): A function compiled with the ``shapeless``
1494
+ option enabled will not be recompiled when the input shape changes. Not all
1495
+ functions can be compiled with ``shapeless`` enabled. Attempting to compile
1496
+ such functions with shapeless enabled will throw. Note, changing the number
1497
+ of dimensions or type of any input will result in a recompilation even with
1498
+ ``shapeless`` set to ``True``. Default: ``False``
1499
+
1500
+ Returns:
1501
+ Callable: A compiled function which has the same input arguments
1502
+ as ``fun`` and returns the the same output(s).
1503
+ )pbdoc");
1504
// Bind mx.disable_compile straight to the core toggle; no wrapper needed.
m.def(
    "disable_compile",
    &mx::disable_compile,
    R"pbdoc(
      Globally disable compilation. Setting the environment variable
      ``MLX_DISABLE_COMPILE`` can also be used to disable compilation.
    )pbdoc");
1511
// Bind mx.enable_compile straight to the core toggle; overrides the
// MLX_DISABLE_COMPILE environment variable at runtime.
m.def(
    "enable_compile",
    &mx::enable_compile,
    R"pbdoc(
      Globally enable compilation. This will override the environment
      variable ``MLX_DISABLE_COMPILE`` if set.
    )pbdoc");
1518
+ m.def(
1519
+ "checkpoint",
1520
+ [](nb::callable fun) { return mlx_func(PyCheckpointedFun{fun}, fun); },
1521
+ "fun"_a,
1522
+ nb::sig("def checkpoint(fun: Callable[P, R]) -> Callable[P, R]"),
1523
+ R"pbdoc(
1524
+ Transform the passed callable to one that performs gradient
1525
+ checkpointing with respect to the inputs of the callable.
1526
+
1527
+ Use this to reduce memory use for gradient computations at the expense of
1528
+ increased computation.
1529
+
1530
+ Args:
1531
+ fun (Callable): The function to checkpoint.
1532
+
1533
+ Returns:
1534
+ A callable that recomputes intermediate states during gradient
1535
+ computation.
1536
+ )pbdoc");
1537
+
1538
+ // Register static Python object cleanup before the interpreter exits
1539
+ auto atexit = nb::module_::import_("atexit");
1540
+ atexit.attr("register")(
1541
+ nb::cpp_function([]() { mx::detail::compile_clear_cache(); }));
1542
+ }