mlx 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of mlx might be problematic.

Files changed (914)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/CMakeLists.txt +7 -0
  3. data/ext/mlx/Makefile +273 -0
  4. data/ext/mlx/extconf.rb +94 -0
  5. data/ext/mlx/mkmf.log +44 -0
  6. data/ext/mlx/native.bundle +0 -0
  7. data/ext/mlx/native.bundle.dSYM/Contents/Info.plist +20 -0
  8. data/ext/mlx/native.bundle.dSYM/Contents/Resources/DWARF/native.bundle +0 -0
  9. data/ext/mlx/native.bundle.dSYM/Contents/Resources/Relocations/aarch64/native.bundle.yml +5 -0
  10. data/ext/mlx/native.cpp +8027 -0
  11. data/ext/mlx/native.o +0 -0
  12. data/lib/mlx/core.rb +1678 -0
  13. data/lib/mlx/distributed_utils/common.rb +116 -0
  14. data/lib/mlx/distributed_utils/config.rb +600 -0
  15. data/lib/mlx/distributed_utils/launch.rb +490 -0
  16. data/lib/mlx/extension.rb +24 -0
  17. data/lib/mlx/nn/base.rb +388 -0
  18. data/lib/mlx/nn/init.rb +140 -0
  19. data/lib/mlx/nn/layers/activations.rb +336 -0
  20. data/lib/mlx/nn/layers/base.rb +6 -0
  21. data/lib/mlx/nn/layers/containers.rb +20 -0
  22. data/lib/mlx/nn/layers/convolution.rb +120 -0
  23. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  24. data/lib/mlx/nn/layers/distributed.rb +309 -0
  25. data/lib/mlx/nn/layers/dropout.rb +75 -0
  26. data/lib/mlx/nn/layers/embedding.rb +28 -0
  27. data/lib/mlx/nn/layers/linear.rb +79 -0
  28. data/lib/mlx/nn/layers/normalization.rb +216 -0
  29. data/lib/mlx/nn/layers/pooling.rb +167 -0
  30. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  31. data/lib/mlx/nn/layers/quantized.rb +215 -0
  32. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  33. data/lib/mlx/nn/layers/transformer.rb +330 -0
  34. data/lib/mlx/nn/layers/upsample.rb +97 -0
  35. data/lib/mlx/nn/layers.rb +18 -0
  36. data/lib/mlx/nn/losses.rb +251 -0
  37. data/lib/mlx/nn/utils.rb +167 -0
  38. data/lib/mlx/nn.rb +12 -0
  39. data/lib/mlx/optimizers/optimizers.rb +808 -0
  40. data/lib/mlx/optimizers/schedulers.rb +62 -0
  41. data/lib/mlx/optimizers.rb +9 -0
  42. data/lib/mlx/utils.rb +171 -0
  43. data/lib/mlx/version +1 -0
  44. data/lib/mlx/version.rb +5 -0
  45. data/lib/mlx.rb +64 -0
  46. data/mlx/.clang-format +87 -0
  47. data/mlx/.git +1 -0
  48. data/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
  49. data/mlx/.github/actions/build-cuda-release/action.yml +31 -0
  50. data/mlx/.github/actions/build-docs/action.yml +38 -0
  51. data/mlx/.github/actions/build-linux/action.yml +38 -0
  52. data/mlx/.github/actions/build-linux-release/action.yml +42 -0
  53. data/mlx/.github/actions/build-macos/action.yml +80 -0
  54. data/mlx/.github/actions/build-macos-release/action.yml +36 -0
  55. data/mlx/.github/actions/build-windows/action.yml +26 -0
  56. data/mlx/.github/actions/setup-linux/action.yml +93 -0
  57. data/mlx/.github/actions/setup-macos/action.yml +24 -0
  58. data/mlx/.github/actions/setup-windows/action.yml +42 -0
  59. data/mlx/.github/actions/test-linux/action.yml +69 -0
  60. data/mlx/.github/actions/test-windows/action.yml +20 -0
  61. data/mlx/.github/dependabot.yml +6 -0
  62. data/mlx/.github/pull_request_template.md +12 -0
  63. data/mlx/.github/scripts/build-sanitizer-tests.sh +48 -0
  64. data/mlx/.github/scripts/setup+build-cpp-linux-fedora-container.sh +27 -0
  65. data/mlx/.github/workflows/build_and_test.yml +152 -0
  66. data/mlx/.github/workflows/documentation.yml +28 -0
  67. data/mlx/.github/workflows/nightly.yml +104 -0
  68. data/mlx/.github/workflows/release.yml +256 -0
  69. data/mlx/.gitignore +81 -0
  70. data/mlx/.pre-commit-config.yaml +27 -0
  71. data/mlx/ACKNOWLEDGMENTS.md +268 -0
  72. data/mlx/CITATION.cff +24 -0
  73. data/mlx/CMakeLists.txt +437 -0
  74. data/mlx/CODE_OF_CONDUCT.md +132 -0
  75. data/mlx/CONTRIBUTING.md +38 -0
  76. data/mlx/LICENSE +21 -0
  77. data/mlx/MANIFEST.in +6 -0
  78. data/mlx/README.md +121 -0
  79. data/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
  80. data/mlx/benchmarks/cpp/autograd.cpp +39 -0
  81. data/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
  82. data/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
  83. data/mlx/benchmarks/cpp/single_ops.cpp +288 -0
  84. data/mlx/benchmarks/cpp/time_utils.h +39 -0
  85. data/mlx/benchmarks/numpy/single_ops.py +39 -0
  86. data/mlx/benchmarks/numpy/time_utils.py +20 -0
  87. data/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
  88. data/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
  89. data/mlx/benchmarks/python/blas/bench_gemv.py +220 -0
  90. data/mlx/benchmarks/python/comparative/README.md +15 -0
  91. data/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
  92. data/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
  93. data/mlx/benchmarks/python/comparative/compare.py +284 -0
  94. data/mlx/benchmarks/python/compile_bench.py +107 -0
  95. data/mlx/benchmarks/python/conv1d_bench.py +123 -0
  96. data/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
  97. data/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
  98. data/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
  99. data/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
  100. data/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
  101. data/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
  102. data/mlx/benchmarks/python/conv_bench.py +135 -0
  103. data/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
  104. data/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
  105. data/mlx/benchmarks/python/distributed_bench.py +66 -0
  106. data/mlx/benchmarks/python/einsum_bench.py +84 -0
  107. data/mlx/benchmarks/python/fft_bench.py +118 -0
  108. data/mlx/benchmarks/python/gather_bench.py +52 -0
  109. data/mlx/benchmarks/python/gather_mm_bench.py +74 -0
  110. data/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
  111. data/mlx/benchmarks/python/hadamard_bench.py +70 -0
  112. data/mlx/benchmarks/python/large_gemm_bench.py +119 -0
  113. data/mlx/benchmarks/python/layer_norm_bench.py +82 -0
  114. data/mlx/benchmarks/python/masked_scatter.py +212 -0
  115. data/mlx/benchmarks/python/rms_norm_bench.py +63 -0
  116. data/mlx/benchmarks/python/rope_bench.py +35 -0
  117. data/mlx/benchmarks/python/scatter_bench.py +96 -0
  118. data/mlx/benchmarks/python/sdpa_bench.py +223 -0
  119. data/mlx/benchmarks/python/sdpa_vector_bench.py +95 -0
  120. data/mlx/benchmarks/python/single_ops.py +132 -0
  121. data/mlx/benchmarks/python/synchronize_bench.py +55 -0
  122. data/mlx/benchmarks/python/time_utils.py +38 -0
  123. data/mlx/cmake/FindCUDNN.cmake +177 -0
  124. data/mlx/cmake/FindNCCL.cmake +54 -0
  125. data/mlx/cmake/Findnvpl.cmake +3 -0
  126. data/mlx/cmake/extension.cmake +50 -0
  127. data/mlx/docs/.clang-format +2 -0
  128. data/mlx/docs/.gitignore +3 -0
  129. data/mlx/docs/.nojekyll +0 -0
  130. data/mlx/docs/Doxyfile +51 -0
  131. data/mlx/docs/Makefile +18 -0
  132. data/mlx/docs/README.md +54 -0
  133. data/mlx/docs/index.html +1 -0
  134. data/mlx/docs/requirements.txt +5 -0
  135. data/mlx/docs/src/_static/distributed/m3-ultra-mesh-broken.png +0 -0
  136. data/mlx/docs/src/_static/distributed/m3-ultra-mesh.png +0 -0
  137. data/mlx/docs/src/_static/metal_debugger/capture.png +0 -0
  138. data/mlx/docs/src/_static/metal_debugger/schema.png +0 -0
  139. data/mlx/docs/src/_static/mlx_logo.png +0 -0
  140. data/mlx/docs/src/_static/mlx_logo_dark.png +0 -0
  141. data/mlx/docs/src/_static/tp_inference/all-to-sharded-linear.png +0 -0
  142. data/mlx/docs/src/_static/tp_inference/column-row-tp.png +0 -0
  143. data/mlx/docs/src/_static/tp_inference/llama-transformer.png +0 -0
  144. data/mlx/docs/src/_static/tp_inference/sharded-to-all-linear.png +0 -0
  145. data/mlx/docs/src/_templates/module-base-class.rst +33 -0
  146. data/mlx/docs/src/_templates/nn-module-template.rst +20 -0
  147. data/mlx/docs/src/_templates/optimizers-template.rst +20 -0
  148. data/mlx/docs/src/conf.py +99 -0
  149. data/mlx/docs/src/cpp/ops.rst +7 -0
  150. data/mlx/docs/src/dev/custom_metal_kernels.rst +445 -0
  151. data/mlx/docs/src/dev/extensions.rst +811 -0
  152. data/mlx/docs/src/dev/metal_debugger.rst +68 -0
  153. data/mlx/docs/src/dev/metal_logging.rst +40 -0
  154. data/mlx/docs/src/dev/mlx_in_cpp.rst +121 -0
  155. data/mlx/docs/src/examples/data_parallelism.rst +91 -0
  156. data/mlx/docs/src/examples/linear_regression.rst +77 -0
  157. data/mlx/docs/src/examples/llama-inference.rst +382 -0
  158. data/mlx/docs/src/examples/mlp.rst +134 -0
  159. data/mlx/docs/src/examples/tensor_parallelism.rst +239 -0
  160. data/mlx/docs/src/index.rst +96 -0
  161. data/mlx/docs/src/install.rst +340 -0
  162. data/mlx/docs/src/python/array.rst +65 -0
  163. data/mlx/docs/src/python/cuda.rst +9 -0
  164. data/mlx/docs/src/python/data_types.rst +78 -0
  165. data/mlx/docs/src/python/devices_and_streams.rst +21 -0
  166. data/mlx/docs/src/python/distributed.rst +22 -0
  167. data/mlx/docs/src/python/export.rst +14 -0
  168. data/mlx/docs/src/python/fast.rst +16 -0
  169. data/mlx/docs/src/python/fft.rst +24 -0
  170. data/mlx/docs/src/python/linalg.rst +27 -0
  171. data/mlx/docs/src/python/memory_management.rst +16 -0
  172. data/mlx/docs/src/python/metal.rst +12 -0
  173. data/mlx/docs/src/python/nn/distributed.rst +30 -0
  174. data/mlx/docs/src/python/nn/functions.rst +40 -0
  175. data/mlx/docs/src/python/nn/init.rst +45 -0
  176. data/mlx/docs/src/python/nn/layers.rst +74 -0
  177. data/mlx/docs/src/python/nn/losses.rst +25 -0
  178. data/mlx/docs/src/python/nn/module.rst +38 -0
  179. data/mlx/docs/src/python/nn.rst +186 -0
  180. data/mlx/docs/src/python/ops.rst +184 -0
  181. data/mlx/docs/src/python/optimizers/common_optimizers.rst +22 -0
  182. data/mlx/docs/src/python/optimizers/optimizer.rst +23 -0
  183. data/mlx/docs/src/python/optimizers/schedulers.rst +15 -0
  184. data/mlx/docs/src/python/optimizers.rst +78 -0
  185. data/mlx/docs/src/python/random.rst +48 -0
  186. data/mlx/docs/src/python/transforms.rst +22 -0
  187. data/mlx/docs/src/python/tree_utils.rst +23 -0
  188. data/mlx/docs/src/usage/compile.rst +516 -0
  189. data/mlx/docs/src/usage/distributed.rst +572 -0
  190. data/mlx/docs/src/usage/export.rst +288 -0
  191. data/mlx/docs/src/usage/function_transforms.rst +191 -0
  192. data/mlx/docs/src/usage/indexing.rst +194 -0
  193. data/mlx/docs/src/usage/launching_distributed.rst +234 -0
  194. data/mlx/docs/src/usage/lazy_evaluation.rst +144 -0
  195. data/mlx/docs/src/usage/numpy.rst +124 -0
  196. data/mlx/docs/src/usage/quick_start.rst +67 -0
  197. data/mlx/docs/src/usage/saving_and_loading.rst +81 -0
  198. data/mlx/docs/src/usage/unified_memory.rst +78 -0
  199. data/mlx/docs/src/usage/using_streams.rst +18 -0
  200. data/mlx/examples/cmake_project/CMakeLists.txt +22 -0
  201. data/mlx/examples/cmake_project/README.md +26 -0
  202. data/mlx/examples/cmake_project/example.cpp +14 -0
  203. data/mlx/examples/cpp/CMakeLists.txt +12 -0
  204. data/mlx/examples/cpp/distributed.cpp +22 -0
  205. data/mlx/examples/cpp/linear_regression.cpp +54 -0
  206. data/mlx/examples/cpp/logistic_regression.cpp +54 -0
  207. data/mlx/examples/cpp/metal_capture.cpp +31 -0
  208. data/mlx/examples/cpp/timer.h +20 -0
  209. data/mlx/examples/cpp/tutorial.cpp +99 -0
  210. data/mlx/examples/export/CMakeLists.txt +22 -0
  211. data/mlx/examples/export/README.md +49 -0
  212. data/mlx/examples/export/eval_mlp.cpp +25 -0
  213. data/mlx/examples/export/eval_mlp.py +52 -0
  214. data/mlx/examples/export/train_mlp.cpp +35 -0
  215. data/mlx/examples/export/train_mlp.py +76 -0
  216. data/mlx/examples/extensions/CMakeLists.txt +78 -0
  217. data/mlx/examples/extensions/README.md +24 -0
  218. data/mlx/examples/extensions/axpby/axpby.cpp +306 -0
  219. data/mlx/examples/extensions/axpby/axpby.h +90 -0
  220. data/mlx/examples/extensions/axpby/axpby.metal +47 -0
  221. data/mlx/examples/extensions/bindings.cpp +39 -0
  222. data/mlx/examples/extensions/mlx_sample_extensions/__init__.py +5 -0
  223. data/mlx/examples/extensions/pyproject.toml +8 -0
  224. data/mlx/examples/extensions/requirements.txt +4 -0
  225. data/mlx/examples/extensions/setup.py +18 -0
  226. data/mlx/examples/extensions/test.py +12 -0
  227. data/mlx/examples/python/linear_regression.py +46 -0
  228. data/mlx/examples/python/logistic_regression.py +49 -0
  229. data/mlx/examples/python/qqmm.py +117 -0
  230. data/mlx/mlx/3rdparty/.clang-format +2 -0
  231. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  232. data/mlx/mlx/CMakeLists.txt +107 -0
  233. data/mlx/mlx/allocator.h +75 -0
  234. data/mlx/mlx/api.h +29 -0
  235. data/mlx/mlx/array.cpp +354 -0
  236. data/mlx/mlx/array.h +647 -0
  237. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  238. data/mlx/mlx/backend/common/binary.h +97 -0
  239. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  240. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  241. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  242. data/mlx/mlx/backend/common/common.cpp +305 -0
  243. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  244. data/mlx/mlx/backend/common/compiled.h +77 -0
  245. data/mlx/mlx/backend/common/copy.h +50 -0
  246. data/mlx/mlx/backend/common/hadamard.h +109 -0
  247. data/mlx/mlx/backend/common/load.cpp +57 -0
  248. data/mlx/mlx/backend/common/matmul.h +67 -0
  249. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  250. data/mlx/mlx/backend/common/reduce.h +59 -0
  251. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  252. data/mlx/mlx/backend/common/slicing.h +20 -0
  253. data/mlx/mlx/backend/common/ternary.h +85 -0
  254. data/mlx/mlx/backend/common/unary.h +29 -0
  255. data/mlx/mlx/backend/common/utils.cpp +231 -0
  256. data/mlx/mlx/backend/common/utils.h +205 -0
  257. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  258. data/mlx/mlx/backend/cpu/arange.h +28 -0
  259. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  260. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  261. data/mlx/mlx/backend/cpu/binary.h +517 -0
  262. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  263. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  264. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  265. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  266. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  267. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  268. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  269. data/mlx/mlx/backend/cpu/copy.h +36 -0
  270. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  271. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  272. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  273. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  274. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  275. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  276. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  277. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  278. data/mlx/mlx/backend/cpu/eval.h +12 -0
  279. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  280. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  281. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  282. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  283. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  284. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  285. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  286. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  287. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  288. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  289. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  290. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  291. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  292. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  293. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  294. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  295. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  296. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  297. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  298. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  299. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  300. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  301. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  302. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  303. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  304. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  305. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  306. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  307. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  308. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  309. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  310. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  311. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  312. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  313. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  314. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  315. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  316. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  317. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  318. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  319. data/mlx/mlx/backend/cpu/unary.h +281 -0
  320. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  321. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  322. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  323. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  324. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  325. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  326. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  327. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  328. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  329. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  330. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  331. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  332. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  333. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  334. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  335. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  336. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  337. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  338. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  339. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  340. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  341. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  342. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  343. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  344. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  345. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  346. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  347. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  348. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  349. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  350. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  351. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  352. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  353. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  354. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  355. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  356. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  357. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  358. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  359. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  360. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  361. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  362. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  363. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  364. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  365. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  366. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  367. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  368. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  369. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  370. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  371. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  372. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  373. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  374. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  375. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  376. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  377. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  378. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  379. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  380. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  381. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  382. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  383. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  384. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  385. data/mlx/mlx/backend/cuda/device.h +195 -0
  386. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  387. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  388. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  389. data/mlx/mlx/backend/cuda/event.cu +415 -0
  390. data/mlx/mlx/backend/cuda/event.h +79 -0
  391. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  392. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  393. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  394. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  395. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  396. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  397. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  398. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  399. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  400. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  401. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  402. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  403. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  404. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  405. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  406. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  407. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  408. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  409. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  410. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  411. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  412. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  413. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  414. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  415. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  416. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  417. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  418. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  419. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  420. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  421. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  422. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  423. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  424. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  425. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  426. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  427. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  428. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  429. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  430. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  431. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  432. data/mlx/mlx/backend/cuda/random.cu +202 -0
  433. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  434. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  435. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  436. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  437. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  438. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  439. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  440. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  441. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  442. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  443. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  444. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  445. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  446. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  447. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  448. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  449. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  450. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  451. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  452. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  453. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  454. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  455. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  456. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  457. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  458. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  459. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  460. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  461. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  462. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  463. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  464. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  465. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  466. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  467. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  468. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  469. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  470. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  471. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  472. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  473. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  474. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  475. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  476. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  477. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  478. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  479. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  480. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  481. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  482. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  483. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  484. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  485. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  486. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  487. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  488. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  489. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  490. data/mlx/mlx/backend/cuda/utils.h +49 -0
  491. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  492. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  493. data/mlx/mlx/backend/cuda/worker.h +55 -0
  494. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  495. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  496. data/mlx/mlx/backend/gpu/copy.h +57 -0
  497. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  498. data/mlx/mlx/backend/gpu/eval.h +18 -0
  499. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  500. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  501. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  502. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  503. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  504. data/mlx/mlx/backend/metal/allocator.h +79 -0
  505. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  506. data/mlx/mlx/backend/metal/binary.h +33 -0
  507. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  508. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  509. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  510. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  511. data/mlx/mlx/backend/metal/device.cpp +816 -0
  512. data/mlx/mlx/backend/metal/device.h +289 -0
  513. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  514. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  515. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  516. data/mlx/mlx/backend/metal/event.cpp +62 -0
  517. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  518. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  519. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  520. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  521. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  522. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  523. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  524. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  525. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  526. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  527. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  528. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  529. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  530. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  531. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  532. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  533. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  534. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  535. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  536. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  537. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  538. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  539. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  540. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  541. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  542. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  543. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  544. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  545. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  546. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  547. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  548. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  549. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  550. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  551. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  552. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  553. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  554. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  555. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  556. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  557. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  558. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  559. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  560. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  561. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  562. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  563. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  564. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  565. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  566. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  567. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  568. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  569. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  570. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  571. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  572. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  573. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  574. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  575. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  576. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  577. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  578. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  579. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  580. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  581. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  582. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  583. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  584. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  585. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  586. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  587. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  588. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  589. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  590. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  591. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  592. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  593. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  594. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  595. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  596. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  597. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  598. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  599. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  600. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  601. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  602. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  603. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  604. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  605. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  606. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  607. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  608. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  609. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  610. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  611. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  612. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  613. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  614. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  615. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  616. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  617. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  618. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  619. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  620. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  621. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  622. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  623. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  624. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  625. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  626. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  627. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  628. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  629. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  630. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  631. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  632. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  633. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  634. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  635. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  636. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  637. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  638. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  639. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  640. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  641. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  642. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  643. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  644. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  645. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  646. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  647. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  648. data/mlx/mlx/backend/metal/kernels.h +375 -0
  649. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  650. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  651. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  652. data/mlx/mlx/backend/metal/matmul.h +144 -0
  653. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  654. data/mlx/mlx/backend/metal/metal.h +25 -0
  655. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  656. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  657. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  658. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  659. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  660. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  661. data/mlx/mlx/backend/metal/reduce.h +41 -0
  662. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  663. data/mlx/mlx/backend/metal/resident.h +32 -0
  664. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  665. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  666. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  667. data/mlx/mlx/backend/metal/scan.h +17 -0
  668. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  669. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  670. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  671. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  672. data/mlx/mlx/backend/metal/ternary.h +21 -0
  673. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  674. data/mlx/mlx/backend/metal/unary.h +21 -0
  675. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  676. data/mlx/mlx/backend/metal/utils.h +99 -0
  677. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  678. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  679. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  680. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  681. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  682. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  683. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  684. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  685. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  686. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  687. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  688. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  689. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  690. data/mlx/mlx/compile.cpp +1243 -0
  691. data/mlx/mlx/compile.h +45 -0
  692. data/mlx/mlx/compile_impl.h +70 -0
  693. data/mlx/mlx/device.cpp +72 -0
  694. data/mlx/mlx/device.h +56 -0
  695. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  696. data/mlx/mlx/distributed/distributed.cpp +197 -0
  697. data/mlx/mlx/distributed/distributed.h +61 -0
  698. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  699. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  700. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  701. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  702. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  703. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  704. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  705. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  706. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  707. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  708. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  709. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  710. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  711. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  712. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  713. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  714. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  715. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  716. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  717. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  718. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  719. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  720. data/mlx/mlx/distributed/ops.cpp +186 -0
  721. data/mlx/mlx/distributed/ops.h +57 -0
  722. data/mlx/mlx/distributed/primitives.cpp +95 -0
  723. data/mlx/mlx/distributed/primitives.h +156 -0
  724. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  725. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  726. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  727. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  728. data/mlx/mlx/distributed/ring/ring.h +12 -0
  729. data/mlx/mlx/distributed/utils.cpp +206 -0
  730. data/mlx/mlx/distributed/utils.h +67 -0
  731. data/mlx/mlx/dtype.cpp +197 -0
  732. data/mlx/mlx/dtype.h +116 -0
  733. data/mlx/mlx/dtype_utils.cpp +42 -0
  734. data/mlx/mlx/dtype_utils.h +119 -0
  735. data/mlx/mlx/einsum.cpp +941 -0
  736. data/mlx/mlx/einsum.h +23 -0
  737. data/mlx/mlx/event.h +58 -0
  738. data/mlx/mlx/export.cpp +1130 -0
  739. data/mlx/mlx/export.h +137 -0
  740. data/mlx/mlx/export_impl.h +99 -0
  741. data/mlx/mlx/fast.cpp +941 -0
  742. data/mlx/mlx/fast.h +103 -0
  743. data/mlx/mlx/fast_primitives.h +427 -0
  744. data/mlx/mlx/fence.h +39 -0
  745. data/mlx/mlx/fft.cpp +262 -0
  746. data/mlx/mlx/fft.h +159 -0
  747. data/mlx/mlx/graph_utils.cpp +175 -0
  748. data/mlx/mlx/graph_utils.h +67 -0
  749. data/mlx/mlx/io/CMakeLists.txt +25 -0
  750. data/mlx/mlx/io/gguf.cpp +470 -0
  751. data/mlx/mlx/io/gguf.h +20 -0
  752. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  753. data/mlx/mlx/io/load.cpp +397 -0
  754. data/mlx/mlx/io/load.h +175 -0
  755. data/mlx/mlx/io/no_gguf.cpp +20 -0
  756. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  757. data/mlx/mlx/io/safetensors.cpp +234 -0
  758. data/mlx/mlx/io.h +61 -0
  759. data/mlx/mlx/linalg.cpp +708 -0
  760. data/mlx/mlx/linalg.h +115 -0
  761. data/mlx/mlx/memory.h +80 -0
  762. data/mlx/mlx/mlx.h +25 -0
  763. data/mlx/mlx/ops.cpp +6094 -0
  764. data/mlx/mlx/ops.h +1610 -0
  765. data/mlx/mlx/primitives.cpp +5850 -0
  766. data/mlx/mlx/primitives.h +2525 -0
  767. data/mlx/mlx/random.cpp +492 -0
  768. data/mlx/mlx/random.h +283 -0
  769. data/mlx/mlx/scheduler.cpp +73 -0
  770. data/mlx/mlx/scheduler.h +189 -0
  771. data/mlx/mlx/small_vector.h +540 -0
  772. data/mlx/mlx/stream.h +42 -0
  773. data/mlx/mlx/threadpool.h +133 -0
  774. data/mlx/mlx/transforms.cpp +1065 -0
  775. data/mlx/mlx/transforms.h +231 -0
  776. data/mlx/mlx/transforms_impl.h +88 -0
  777. data/mlx/mlx/types/bf16.h +187 -0
  778. data/mlx/mlx/types/complex.h +113 -0
  779. data/mlx/mlx/types/fp16.h +234 -0
  780. data/mlx/mlx/types/half_types.h +58 -0
  781. data/mlx/mlx/types/limits.h +70 -0
  782. data/mlx/mlx/utils.cpp +302 -0
  783. data/mlx/mlx/utils.h +174 -0
  784. data/mlx/mlx/version.cpp +11 -0
  785. data/mlx/mlx/version.h +22 -0
  786. data/mlx/mlx.pc.in +52 -0
  787. data/mlx/pyproject.toml +7 -0
  788. data/mlx/python/mlx/__main__.py +27 -0
  789. data/mlx/python/mlx/_distributed_utils/common.py +135 -0
  790. data/mlx/python/mlx/_distributed_utils/config.py +631 -0
  791. data/mlx/python/mlx/_distributed_utils/launch.py +570 -0
  792. data/mlx/python/mlx/_reprlib_fix.py +16 -0
  793. data/mlx/python/mlx/_stub_patterns.txt +36 -0
  794. data/mlx/python/mlx/extension.py +88 -0
  795. data/mlx/python/mlx/nn/__init__.py +5 -0
  796. data/mlx/python/mlx/nn/init.py +441 -0
  797. data/mlx/python/mlx/nn/layers/__init__.py +105 -0
  798. data/mlx/python/mlx/nn/layers/activations.py +661 -0
  799. data/mlx/python/mlx/nn/layers/base.py +675 -0
  800. data/mlx/python/mlx/nn/layers/containers.py +24 -0
  801. data/mlx/python/mlx/nn/layers/convolution.py +232 -0
  802. data/mlx/python/mlx/nn/layers/convolution_transpose.py +242 -0
  803. data/mlx/python/mlx/nn/layers/distributed.py +601 -0
  804. data/mlx/python/mlx/nn/layers/dropout.py +137 -0
  805. data/mlx/python/mlx/nn/layers/embedding.py +53 -0
  806. data/mlx/python/mlx/nn/layers/linear.py +180 -0
  807. data/mlx/python/mlx/nn/layers/normalization.py +363 -0
  808. data/mlx/python/mlx/nn/layers/pooling.py +398 -0
  809. data/mlx/python/mlx/nn/layers/positional_encoding.py +162 -0
  810. data/mlx/python/mlx/nn/layers/quantized.py +426 -0
  811. data/mlx/python/mlx/nn/layers/recurrent.py +289 -0
  812. data/mlx/python/mlx/nn/layers/transformer.py +354 -0
  813. data/mlx/python/mlx/nn/layers/upsample.py +277 -0
  814. data/mlx/python/mlx/nn/losses.py +610 -0
  815. data/mlx/python/mlx/nn/utils.py +165 -0
  816. data/mlx/python/mlx/optimizers/__init__.py +4 -0
  817. data/mlx/python/mlx/optimizers/optimizers.py +976 -0
  818. data/mlx/python/mlx/optimizers/schedulers.py +158 -0
  819. data/mlx/python/mlx/py.typed +1 -0
  820. data/mlx/python/mlx/utils.py +325 -0
  821. data/mlx/python/src/CMakeLists.txt +96 -0
  822. data/mlx/python/src/array.cpp +1525 -0
  823. data/mlx/python/src/buffer.h +124 -0
  824. data/mlx/python/src/constants.cpp +15 -0
  825. data/mlx/python/src/convert.cpp +504 -0
  826. data/mlx/python/src/convert.h +50 -0
  827. data/mlx/python/src/cuda.cpp +19 -0
  828. data/mlx/python/src/device.cpp +98 -0
  829. data/mlx/python/src/distributed.cpp +352 -0
  830. data/mlx/python/src/export.cpp +356 -0
  831. data/mlx/python/src/fast.cpp +627 -0
  832. data/mlx/python/src/fft.cpp +514 -0
  833. data/mlx/python/src/indexing.cpp +1016 -0
  834. data/mlx/python/src/indexing.h +41 -0
  835. data/mlx/python/src/linalg.cpp +663 -0
  836. data/mlx/python/src/load.cpp +531 -0
  837. data/mlx/python/src/load.h +51 -0
  838. data/mlx/python/src/memory.cpp +125 -0
  839. data/mlx/python/src/metal.cpp +98 -0
  840. data/mlx/python/src/mlx.cpp +51 -0
  841. data/mlx/python/src/mlx_func.cpp +116 -0
  842. data/mlx/python/src/mlx_func.h +31 -0
  843. data/mlx/python/src/ops.cpp +5545 -0
  844. data/mlx/python/src/random.cpp +516 -0
  845. data/mlx/python/src/small_vector.h +76 -0
  846. data/mlx/python/src/stream.cpp +147 -0
  847. data/mlx/python/src/transforms.cpp +1542 -0
  848. data/mlx/python/src/trees.cpp +311 -0
  849. data/mlx/python/src/trees.h +62 -0
  850. data/mlx/python/src/utils.cpp +98 -0
  851. data/mlx/python/src/utils.h +78 -0
  852. data/mlx/python/tests/__main__.py +5 -0
  853. data/mlx/python/tests/cuda_skip.py +62 -0
  854. data/mlx/python/tests/mlx_distributed_tests.py +314 -0
  855. data/mlx/python/tests/mlx_tests.py +116 -0
  856. data/mlx/python/tests/mpi_test_distributed.py +142 -0
  857. data/mlx/python/tests/nccl_test_distributed.py +52 -0
  858. data/mlx/python/tests/ring_test_distributed.py +131 -0
  859. data/mlx/python/tests/test_array.py +2139 -0
  860. data/mlx/python/tests/test_autograd.py +880 -0
  861. data/mlx/python/tests/test_bf16.py +196 -0
  862. data/mlx/python/tests/test_blas.py +1429 -0
  863. data/mlx/python/tests/test_compile.py +1277 -0
  864. data/mlx/python/tests/test_constants.py +41 -0
  865. data/mlx/python/tests/test_conv.py +1198 -0
  866. data/mlx/python/tests/test_conv_transpose.py +810 -0
  867. data/mlx/python/tests/test_device.py +150 -0
  868. data/mlx/python/tests/test_double.py +306 -0
  869. data/mlx/python/tests/test_einsum.py +363 -0
  870. data/mlx/python/tests/test_eval.py +200 -0
  871. data/mlx/python/tests/test_export_import.py +614 -0
  872. data/mlx/python/tests/test_fast.py +923 -0
  873. data/mlx/python/tests/test_fast_sdpa.py +647 -0
  874. data/mlx/python/tests/test_fft.py +323 -0
  875. data/mlx/python/tests/test_graph.py +37 -0
  876. data/mlx/python/tests/test_init.py +139 -0
  877. data/mlx/python/tests/test_linalg.py +621 -0
  878. data/mlx/python/tests/test_load.py +447 -0
  879. data/mlx/python/tests/test_losses.py +427 -0
  880. data/mlx/python/tests/test_memory.py +77 -0
  881. data/mlx/python/tests/test_nn.py +1986 -0
  882. data/mlx/python/tests/test_ops.py +3261 -0
  883. data/mlx/python/tests/test_optimizers.py +584 -0
  884. data/mlx/python/tests/test_quantized.py +1160 -0
  885. data/mlx/python/tests/test_random.py +392 -0
  886. data/mlx/python/tests/test_reduce.py +223 -0
  887. data/mlx/python/tests/test_tree.py +96 -0
  888. data/mlx/python/tests/test_upsample.py +100 -0
  889. data/mlx/python/tests/test_vmap.py +860 -0
  890. data/mlx/setup.py +315 -0
  891. data/mlx/tests/CMakeLists.txt +44 -0
  892. data/mlx/tests/allocator_tests.cpp +41 -0
  893. data/mlx/tests/arg_reduce_tests.cpp +204 -0
  894. data/mlx/tests/array_tests.cpp +663 -0
  895. data/mlx/tests/autograd_tests.cpp +1399 -0
  896. data/mlx/tests/blas_tests.cpp +110 -0
  897. data/mlx/tests/compile_tests.cpp +818 -0
  898. data/mlx/tests/creations_tests.cpp +239 -0
  899. data/mlx/tests/custom_vjp_tests.cpp +55 -0
  900. data/mlx/tests/device_tests.cpp +35 -0
  901. data/mlx/tests/einsum_tests.cpp +85 -0
  902. data/mlx/tests/eval_tests.cpp +93 -0
  903. data/mlx/tests/export_import_tests.cpp +164 -0
  904. data/mlx/tests/fft_tests.cpp +366 -0
  905. data/mlx/tests/gpu_tests.cpp +523 -0
  906. data/mlx/tests/linalg_tests.cpp +639 -0
  907. data/mlx/tests/load_tests.cpp +270 -0
  908. data/mlx/tests/ops_tests.cpp +4159 -0
  909. data/mlx/tests/random_tests.cpp +716 -0
  910. data/mlx/tests/scheduler_tests.cpp +121 -0
  911. data/mlx/tests/tests.cpp +26 -0
  912. data/mlx/tests/utils_tests.cpp +67 -0
  913. data/mlx/tests/vmap_tests.cpp +547 -0
  914. metadata +958 -0
data/mlx/docs/src/dev/extensions.rst
@@ -0,0 +1,811 @@
Custom Extensions in MLX
========================

You can extend MLX with custom operations on the CPU or GPU. This guide
explains how to do that with a simple example.

Introducing the Example
-----------------------

Let's say you would like an operation that takes in two arrays, ``x`` and
``y``, scales them both by coefficients ``alpha`` and ``beta`` respectively,
and then adds them together to get the result ``z = alpha * x + beta * y``.
You can do that in MLX directly:

.. code-block:: python

   import mlx.core as mx

   def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
       return alpha * x + beta * y

This function performs that operation while leaving the implementation and
function transformations to MLX.

However, you may want to customize the underlying implementation, perhaps to
make it faster. In this tutorial we will go through adding custom extensions.
It will cover:

* The structure of the MLX library.
* Implementing a CPU operation.
* Implementing a GPU operation using Metal.
* Adding the ``vjp`` and ``jvp`` function transformations.
* Building a custom extension and binding it to Python.

Operations and Primitives
-------------------------

Operations in MLX build the computation graph. Primitives provide the rules for
evaluating and transforming the graph. Let's start by discussing operations in
more detail.

Operations
^^^^^^^^^^^

Operations are the front-end functions that operate on arrays. They are defined
in the C++ API (:ref:`cpp_ops`), and the Python API (:ref:`ops`) binds them.

We would like an operation :meth:`axpby` that takes in two arrays, ``x`` and
``y``, and two scalars, ``alpha`` and ``beta``. This is how to define it in
C++:

.. code-block:: C++

   /**
    * Scale and sum two vectors element-wise
    * z = alpha * x + beta * y
    *
    * Use NumPy-style broadcasting between x and y
    * Inputs are upcasted to floats if needed
    **/
   array axpby(
       const array& x, // Input array x
       const array& y, // Input array y
       const float alpha, // Scaling factor for x
       const float beta, // Scaling factor for y
       StreamOrDevice s = {} // Stream on which to schedule the operation
   );

The simplest way to implement this is with existing operations:

.. code-block:: C++

   array axpby(
       const array& x, // Input array x
       const array& y, // Input array y
       const float alpha, // Scaling factor for x
       const float beta, // Scaling factor for y
       StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
   ) {
     // Scale x and y on the provided stream
     auto ax = multiply(array(alpha), x, s);
     auto by = multiply(array(beta), y, s);

     // Add and return
     return add(ax, by, s);
   }

The operations themselves do not contain the implementations that act on the
data, nor do they contain the rules of transformations. Rather, they are an
easy-to-use interface that uses :class:`Primitive` building blocks.
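
Even with only this composed implementation, ``axpby`` behaves like any other
MLX operation: it lazily builds a graph and computes when evaluated. Below is a
minimal sketch of calling it (the ``main`` function, shapes, and scalars are
illustrative and assume the definition above is visible and linked against MLX):

.. code-block:: C++

   #include <iostream>

   #include "mlx/mlx.h"

   namespace mx = mlx::core;

   int main() {
     auto x = mx::ones({3});        // [1, 1, 1]
     auto y = mx::full({3}, 2.0f);  // [2, 2, 2]

     // Only builds the graph; nothing is computed yet
     auto z = axpby(x, y, 4.0f, 2.0f);

     // Evaluation happens here: z = 4 * x + 2 * y
     mx::eval(z);
     std::cout << z << std::endl;  // array([8, 8, 8], dtype=float32)
   }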

Primitives
^^^^^^^^^^^

A :class:`Primitive` is part of the computation graph of an :class:`array`. It
defines how to create output arrays given input arrays. Further, a
:class:`Primitive` has methods to run on the CPU or GPU and for function
transformations such as ``vjp`` and ``jvp``. Let's go back to our example to be
more concrete:

.. code-block:: C++

   class Axpby : public Primitive {
    public:
     explicit Axpby(Stream stream, float alpha, float beta)
         : Primitive(stream), alpha_(alpha), beta_(beta) {};

     /**
      * A primitive must know how to evaluate itself on the CPU/GPU
      * for the given inputs and populate the output array.
      *
      * To avoid unnecessary allocations, the evaluation function
      * is responsible for allocating space for the array.
      */
     void eval_cpu(
         const std::vector<array>& inputs,
         std::vector<array>& outputs) override;
     void eval_gpu(
         const std::vector<array>& inputs,
         std::vector<array>& outputs) override;

     /** The Jacobian-vector product. */
     std::vector<array> jvp(
         const std::vector<array>& primals,
         const std::vector<array>& tangents,
         const std::vector<int>& argnums) override;

     /** The vector-Jacobian product. */
     std::vector<array> vjp(
         const std::vector<array>& primals,
         const std::vector<array>& cotangents,
         const std::vector<int>& argnums,
         const std::vector<array>& outputs) override;

     /**
      * The primitive must know how to vectorize itself across
      * the given axes. The output is a pair containing the array
      * representing the vectorized computation and the axis which
      * corresponds to the output vectorized dimension.
      */
     std::pair<std::vector<array>, std::vector<int>> vmap(
         const std::vector<array>& inputs,
         const std::vector<int>& axes) override;

     /** The name of the primitive. */
     const char* name() const override {
       return "Axpby";
     }

     /** Equivalence check **/
     bool is_equivalent(const Primitive& other) const override;

    private:
     float alpha_;
     float beta_;
   };

The :class:`Axpby` class derives from the base :class:`Primitive` class. The
:class:`Axpby` treats ``alpha`` and ``beta`` as parameters. It then provides
implementations of how the output array is produced given the inputs through
:meth:`Axpby::eval_cpu` and :meth:`Axpby::eval_gpu`. It also provides rules
of transformations in :meth:`Axpby::jvp`, :meth:`Axpby::vjp`, and
:meth:`Axpby::vmap`.
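
The excerpt above only declares :meth:`Axpby::is_equivalent`. As a sketch of
what its definition could look like: the check lets MLX decide whether two
primitives are interchangeable (for example when simplifying the graph), so it
should compare every parameter the primitive carries:

.. code-block:: C++

   /** A possible equivalence check: two Axpby primitives are
       interchangeable only when both of their parameters match. */
   bool Axpby::is_equivalent(const Primitive& other) const {
     const Axpby& r_other = static_cast<const Axpby&>(other);
     return alpha_ == r_other.alpha_ && beta_ == r_other.beta_;
   }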

Using the Primitive
^^^^^^^^^^^^^^^^^^^

Operations can use this :class:`Primitive` to add a new :class:`array` to the
computation graph. An :class:`array` can be constructed by providing its data
type, shape, the :class:`Primitive` that computes it, and the :class:`array`
inputs that are passed to the primitive.

Let's reimplement our operation now in terms of our :class:`Axpby` primitive.

.. code-block:: C++

   array axpby(
       const array& x, // Input array x
       const array& y, // Input array y
       const float alpha, // Scaling factor for x
       const float beta, // Scaling factor for y
       StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
   ) {
     // Promote dtypes between x and y as needed
     auto promoted_dtype = promote_types(x.dtype(), y.dtype());

     // Upcast to float32 for non-floating point inputs x and y
     auto out_dtype = issubdtype(promoted_dtype, float32)
         ? promoted_dtype
         : promote_types(promoted_dtype, float32);

     // Cast x and y up to the determined dtype (on the same stream s)
     auto x_casted = astype(x, out_dtype, s);
     auto y_casted = astype(y, out_dtype, s);

     // Broadcast the shapes of x and y (on the same stream s)
     auto broadcasted_inputs = broadcast_arrays({x_casted, y_casted}, s);
     auto out_shape = broadcasted_inputs[0].shape();

     // Construct the array as the output of the Axpby primitive
     // with the broadcasted and upcasted arrays as inputs
     return array(
         /* const std::vector<int>& shape = */ out_shape,
         /* Dtype dtype = */ out_dtype,
         /* std::unique_ptr<Primitive> primitive = */
         std::make_shared<Axpby>(to_stream(s), alpha, beta),
         /* const std::vector<array>& inputs = */ broadcasted_inputs);
   }
209
+
210
+
211
+ This operation now handles the following:
212
+
213
+ #. Upcast inputs and resolve the output data type.
214
+ #. Broadcast the inputs and resolve the output shape.
215
+ #. Construct the primitive :class:`Axpby` using the given stream, ``alpha``, and ``beta``.
216
+ #. Construct the output :class:`array` using the primitive and the inputs.
217
+
218
+ Implementing the Primitive
219
+ --------------------------
220
+
221
+ No computation happens when we call the operation alone. The operation only
222
+ builds the computation graph. When we evaluate the output array, MLX schedules
223
+ the execution of the computation graph, and calls :meth:`Axpby::eval_cpu` or
224
+ :meth:`Axpby::eval_gpu` depending on the stream/device specified by the user.
225
+
226
+ .. warning::
227
+ When :meth:`Primitive::eval_cpu` or :meth:`Primitive::eval_gpu` are called,
228
+ no memory has been allocated for the output array. It falls on the implementation
229
+ of these functions to allocate memory as needed.
230
+
+ Implementing the CPU Back-end
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ Let's start by implementing :meth:`Axpby::eval_cpu`.
+
+ The method will go over each element of the output array, find the
+ corresponding input elements of ``x`` and ``y`` and perform the operation
+ point-wise. This is captured in the templated function :meth:`axpby_impl`.
+
+ .. code-block:: C++
+
+    template <typename T>
+    void axpby_impl(
+        const mx::array& x,
+        const mx::array& y,
+        mx::array& out,
+        float alpha_,
+        float beta_,
+        mx::Stream stream) {
+      out.set_data(mx::allocator::malloc(out.nbytes()));
+
+      // Get the CPU command encoder and register input and output arrays
+      auto& encoder = mx::cpu::get_command_encoder(stream);
+      encoder.set_input_array(x);
+      encoder.set_input_array(y);
+      encoder.set_output_array(out);
+
+      // Launch the CPU kernel
+      encoder.dispatch([x_ptr = x.data<T>(),
+                        y_ptr = y.data<T>(),
+                        out_ptr = out.data<T>(),
+                        size = out.size(),
+                        shape = out.shape(),
+                        x_strides = x.strides(),
+                        y_strides = y.strides(),
+                        alpha_,
+                        beta_]() {
+        // Cast alpha and beta to the relevant types
+        T alpha = static_cast<T>(alpha_);
+        T beta = static_cast<T>(beta_);
+
+        // Do the element-wise operation for each output
+        for (size_t out_idx = 0; out_idx < size; out_idx++) {
+          // Map linear indices to offsets in x and y
+          auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides);
+          auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides);
+
+          // We allocate the output to be contiguous and regularly strided
+          // (defaults to row major) and hence it doesn't need additional mapping
+          out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset];
+        }
+      });
+    }
+
+ Our implementation should work for all incoming floating point arrays.
+ Accordingly, we add dispatches for ``float32``, ``float16``, ``bfloat16`` and
+ ``complex64``. We throw an error if we encounter an unexpected type.
+
+ .. code-block:: C++
+
+    void Axpby::eval_cpu(
+        const std::vector<mx::array>& inputs,
+        std::vector<mx::array>& outputs) {
+      auto& x = inputs[0];
+      auto& y = inputs[1];
+      auto& out = outputs[0];
+
+      // Dispatch to the correct dtype
+      if (out.dtype() == mx::float32) {
+        return axpby_impl<float>(x, y, out, alpha_, beta_, stream());
+      } else if (out.dtype() == mx::float16) {
+        return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream());
+      } else if (out.dtype() == mx::bfloat16) {
+        return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream());
+      } else if (out.dtype() == mx::complex64) {
+        return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream());
+      } else {
+        throw std::runtime_error(
+            "Axpby is only supported for floating point types.");
+      }
+    }
+
+ Just this much is enough to run the operation :meth:`axpby` on a CPU stream! If
+ you do not plan on running the operation on the GPU or using transforms on
+ computation graphs that contain :class:`Axpby`, you can stop implementing the
+ primitive here.
+
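+ For a quick sanity check at this stage, a minimal C++ driver might look like
+ the following. This is only a sketch: it assumes the operation above is in
+ scope, that the ``mx`` alias refers to ``mlx::core`` as in the snippets above,
+ and that the binary links against MLX.
+
+ .. code-block:: C++
+
+    int main() {
+      auto x = mx::ones({3, 4});
+      auto y = mx::ones({3, 4});
+
+      // Calling the operation only records it in the graph
+      auto z = axpby(x, y, 4.0f, 2.0f, mx::Device(mx::Device::cpu));
+
+      // Axpby::eval_cpu runs when the output is evaluated
+      mx::eval({z});
+      return 0;
+    }
+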
+ Implementing the GPU Back-end
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ Apple silicon devices address their GPUs using the Metal_ shading language, and
+ GPU kernels in MLX are written using Metal.
+
+ .. note::
+
+    Here are some helpful resources if you are new to Metal:
+
+    * A walkthrough of the Metal compute pipeline: `Metal Example`_
+    * Documentation for the Metal shading language: `Metal Specification`_
+    * Using Metal from C++: `Metal-cpp`_
+
+ Let's keep the GPU kernel simple. We will launch exactly as many threads as
+ there are elements in the output. Each thread will pick the element it needs
+ from ``x`` and ``y``, do the point-wise operation, and update its assigned
+ element in the output.
+
+ .. code-block:: C++
+
+    template <typename T>
+    [[kernel]] void axpby_general(
+        device const T* x [[buffer(0)]],
+        device const T* y [[buffer(1)]],
+        device T* out [[buffer(2)]],
+        constant const float& alpha [[buffer(3)]],
+        constant const float& beta [[buffer(4)]],
+        constant const int* shape [[buffer(5)]],
+        constant const int64_t* x_strides [[buffer(6)]],
+        constant const int64_t* y_strides [[buffer(7)]],
+        constant const int& ndim [[buffer(8)]],
+        uint index [[thread_position_in_grid]]) {
+      // Convert linear indices to offsets in array
+      auto x_offset = elem_to_loc(index, shape, x_strides, ndim);
+      auto y_offset = elem_to_loc(index, shape, y_strides, ndim);
+
+      // Do the operation and update the output
+      out[index] =
+          static_cast<T>(alpha) * x[x_offset] + static_cast<T>(beta) * y[y_offset];
+    }
+
+ We then need to instantiate this template for all floating point types and give
+ each instantiation a unique host name so we can identify it.
+
+ .. code-block:: C++
+
+    instantiate_kernel("axpby_general_float32", axpby_general, float)
+    instantiate_kernel("axpby_general_float16", axpby_general, float16_t)
+    instantiate_kernel("axpby_general_bfloat16", axpby_general, bfloat16_t)
+    instantiate_kernel("axpby_general_complex64", axpby_general, complex64_t)
+
+ The logic to determine the kernel, set the inputs, resolve the grid dimensions,
+ and dispatch to the GPU is contained in :meth:`Axpby::eval_gpu` as shown
+ below.
+
+ .. code-block:: C++
+
+    /** Evaluate primitive on GPU */
+    void Axpby::eval_gpu(
+        const std::vector<array>& inputs,
+        std::vector<array>& outputs) {
+      // Prepare inputs
+      assert(inputs.size() == 2);
+      auto& x = inputs[0];
+      auto& y = inputs[1];
+      auto& out = outputs[0];
+
+      // Each primitive carries the stream it should execute on
+      // and each stream carries its device identifiers
+      auto& s = stream();
+      // We get the needed metal device using the stream
+      auto& d = metal::device(s.device);
+
+      // Allocate output memory
+      out.set_data(allocator::malloc(out.nbytes()));
+
+      // Resolve name of kernel
+      std::string kname = "axpby_general_" + type_to_name(out);
+
+      // Load the metal library
+      auto lib = d.get_library("mlx_ext", current_binary_dir());
+
+      // Make a kernel from this metal library
+      auto kernel = d.get_kernel(kname, lib);
+
+      // Prepare to encode kernel
+      auto& compute_encoder = d.get_command_encoder(s.index);
+      compute_encoder.set_compute_pipeline_state(kernel);
+
+      // Kernel parameters are registered with buffer indices corresponding to
+      // those in the kernel declaration at axpby.metal
+      int ndim = out.ndim();
+      size_t nelem = out.size();
+
+      // Encode input arrays to kernel
+      compute_encoder.set_input_array(x, 0);
+      compute_encoder.set_input_array(y, 1);
+
+      // Encode output arrays to kernel
+      compute_encoder.set_output_array(out, 2);
+
+      // Encode alpha and beta
+      compute_encoder.set_bytes(alpha_, 3);
+      compute_encoder.set_bytes(beta_, 4);
+
+      // Encode shape, strides and ndim
+      compute_encoder.set_vector_bytes(x.shape(), 5);
+      compute_encoder.set_vector_bytes(x.strides(), 6);
+      compute_encoder.set_vector_bytes(y.strides(), 7);
+      compute_encoder.set_bytes(ndim, 8);
+
+      // We launch 1 thread for each input and make sure that the number of
+      // threads in any given threadgroup is not higher than the max allowed
+      size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup());
+
+      // Fix the 3D size of each threadgroup (in terms of threads)
+      MTL::Size group_dims = MTL::Size(tgp_size, 1, 1);
+
+      // Fix the 3D size of the launch grid (in terms of threads)
+      MTL::Size grid_dims = MTL::Size(nelem, 1, 1);
+
+      // Launch the grid with the given number of threads divided among
+      // the given threadgroups
+      compute_encoder.dispatch_threads(grid_dims, group_dims);
+    }
+
+ We can now call the :meth:`axpby` operation on both the CPU and the GPU!
+
+ A few things to note about MLX and Metal before moving on. MLX keeps track of
+ the active command encoder and the ``MTLCommandBuffer`` it is associated with.
+ We rely on :meth:`d.get_command_encoder` to give us the active Metal compute
+ command encoder instead of building a new one and calling
+ :meth:`compute_encoder->end_encoding` at the end. MLX adds kernels (compute
+ pipelines) to the active command buffer until some specified limit is hit or
+ the command buffer needs to be flushed for synchronization.
+
+ Primitive Transforms
+ ^^^^^^^^^^^^^^^^^^^^^
+
+ Next, let's add implementations for transformations in a :class:`Primitive`.
+ These transformations can be built on top of other operations, including the
+ one we just defined:
+
+ .. code-block:: C++
+
+    /** The Jacobian-vector product. */
+    std::vector<array> Axpby::jvp(
+        const std::vector<array>& primals,
+        const std::vector<array>& tangents,
+        const std::vector<int>& argnums) {
+      // Forward mode diff that pushes along the tangents
+      // The jvp transform on the primitive can be built with ops
+      // that are scheduled on the same stream as the primitive
+
+      // If argnums = {0}, we only push along x in which case the
+      // jvp is just the tangent scaled by alpha
+      // Similarly, if argnums = {1}, the jvp is just the tangent
+      // scaled by beta
+      if (argnums.size() == 1) {
+        auto scale = argnums[0] == 0 ? alpha_ : beta_;
+        auto scale_arr = array(scale, tangents[0].dtype());
+        return {multiply(scale_arr, tangents[0], stream())};
+      }
+      // If argnums = {0, 1}, we take contributions from both
+      // which gives us jvp = tangent_x * alpha + tangent_y * beta
+      else {
+        return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())};
+      }
+    }
+
+ .. code-block:: C++
+
+    /** The vector-Jacobian product. */
+    std::vector<array> Axpby::vjp(
+        const std::vector<array>& primals,
+        const std::vector<array>& cotangents,
+        const std::vector<int>& argnums,
+        const std::vector<array>& /* unused */) {
+      // Reverse mode diff
+      std::vector<array> vjps;
+      for (auto arg : argnums) {
+        auto scale = arg == 0 ? alpha_ : beta_;
+        auto scale_arr = array(scale, cotangents[0].dtype());
+        vjps.push_back(multiply(scale_arr, cotangents[0], stream()));
+      }
+      return vjps;
+    }
+
+ Note that a transformation does not need to be fully defined to start using
+ the :class:`Primitive`.
+
+ .. code-block:: C++
+
+    /** Vectorize primitive along given axis */
+    std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
+        const std::vector<array>& inputs,
+        const std::vector<int>& axes) {
+      throw std::runtime_error("[Axpby] vmap not implemented.");
+    }
+
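+ If we did want to support ``vmap``, the element-wise nature of the operation
+ makes the simplest case straightforward: when both inputs are batched along
+ the same axis, we can apply the operation to the full arrays and report that
+ axis as the vectorized output dimension. The following is only a sketch of
+ that idea (the sample extension leaves ``vmap`` unimplemented, and handling
+ inputs mapped along different axes would require moving axes first):
+
+ .. code-block:: C++
+
+    std::pair<std::vector<array>, std::vector<int>> Axpby::vmap(
+        const std::vector<array>& inputs,
+        const std::vector<int>& axes) {
+      // axpby is element-wise, so batching both inputs along a shared
+      // axis requires no changes to the computation itself
+      if (axes[0] == axes[1]) {
+        return {{axpby(inputs[0], inputs[1], alpha_, beta_, stream())}, {axes[0]}};
+      }
+      throw std::runtime_error("[Axpby] vmap over distinct axes not implemented.");
+    }
+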
+ Building and Binding
+ --------------------
+
+ Let's look at the overall directory structure first.
+
+ | extensions
+ | ├── axpby
+ | │   ├── axpby.cpp
+ | │   ├── axpby.h
+ | │   └── axpby.metal
+ | ├── mlx_sample_extensions
+ | │   └── __init__.py
+ | ├── bindings.cpp
+ | ├── CMakeLists.txt
+ | └── setup.py
+
+ * ``extensions/axpby/`` defines the C++ extension library
+ * ``extensions/mlx_sample_extensions`` sets out the structure for the
+   associated Python package
+ * ``extensions/bindings.cpp`` provides Python bindings for our operation
+ * ``extensions/CMakeLists.txt`` holds CMake rules to build the library and
+   Python bindings
+ * ``extensions/setup.py`` holds the ``setuptools`` rules to build and install
+   the Python package
+
+ Binding to Python
+ ^^^^^^^^^^^^^^^^^^
+
+ We use nanobind_ to build a Python API for the C++ library. Since bindings for
+ components such as :class:`mlx.core.array`, :class:`mlx.core.stream`, etc. are
+ already provided, adding our :meth:`axpby` is simple.
+
+ .. code-block:: C++
+
+    NB_MODULE(_ext, m) {
+      m.doc() = "Sample extension for MLX";
+
+      m.def(
+          "axpby",
+          &axpby,
+          "x"_a,
+          "y"_a,
+          "alpha"_a,
+          "beta"_a,
+          nb::kw_only(),
+          "stream"_a = nb::none(),
+          R"(
+            Scale and sum two vectors element-wise
+            ``z = alpha * x + beta * y``
+
+            Follows numpy style broadcasting between ``x`` and ``y``
+            Inputs are upcast to floats if needed
+
+            Args:
+                x (array): Input array.
+                y (array): Input array.
+                alpha (float): Scaling factor for ``x``.
+                beta (float): Scaling factor for ``y``.
+
+            Returns:
+                array: ``alpha * x + beta * y``
+          )");
+    }
+
+ Most of the complexity in the above example comes from additional bells and
+ whistles such as the literal names and doc-strings.
+
+ .. warning::
+
+    :mod:`mlx.core` must be imported before importing
+    :mod:`mlx_sample_extensions` as defined by the nanobind module above to
+    ensure that the casters for :mod:`mlx.core` components like
+    :class:`mlx.core.array` are available.
+
+ .. _Building with CMake:
+
+ Building with CMake
+ ^^^^^^^^^^^^^^^^^^^^
+
+ Building the C++ extension library only requires that you ``find_package(MLX
+ CONFIG)`` and then link it to your library.
+
+ .. code-block:: cmake
+
+    # Add library
+    add_library(mlx_ext)
+
+    # Add sources
+    target_sources(
+      mlx_ext
+      PUBLIC
+      ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.cpp
+    )
+
+    # Add include headers
+    target_include_directories(
+      mlx_ext PUBLIC ${CMAKE_CURRENT_LIST_DIR}
+    )
+
+    # Link to mlx
+    target_link_libraries(mlx_ext PUBLIC mlx)
+
+ We also need to build the attached Metal library. For convenience, we provide a
+ :meth:`mlx_build_metallib` function that builds a ``.metallib`` target given
+ sources, headers, destinations, etc. (defined in ``cmake/extension.cmake`` and
+ automatically imported with the MLX package).
+
+ Here is what that looks like in practice:
+
+ .. code-block:: cmake
+
+    # Build metallib
+    if(MLX_BUILD_METAL)
+      mlx_build_metallib(
+        TARGET mlx_ext_metallib
+        TITLE mlx_ext
+        SOURCES ${CMAKE_CURRENT_LIST_DIR}/axpby/axpby.metal
+        INCLUDE_DIRS ${PROJECT_SOURCE_DIR} ${MLX_INCLUDE_DIRS}
+        OUTPUT_DIRECTORY ${CMAKE_LIBRARY_OUTPUT_DIRECTORY}
+      )
+
+      add_dependencies(
+        mlx_ext
+        mlx_ext_metallib
+      )
+    endif()
+
+ Finally, we build the nanobind_ bindings:
+
+ .. code-block:: cmake
+
+    nanobind_add_module(
+      _ext
+      NB_STATIC STABLE_ABI LTO NOMINSIZE
+      NB_DOMAIN mlx
+      ${CMAKE_CURRENT_LIST_DIR}/bindings.cpp
+    )
+    target_link_libraries(_ext PRIVATE mlx_ext)
+
+    if(BUILD_SHARED_LIBS)
+      target_link_options(_ext PRIVATE -Wl,-rpath,@loader_path)
+    endif()
+
+ Building with ``setuptools``
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+ Once we have set out the CMake build rules as described above, we can use the
+ build utilities defined in :mod:`mlx.extension`:
+
+ .. code-block:: python
+
+    from mlx import extension
+    from setuptools import setup
+
+    if __name__ == "__main__":
+        setup(
+            name="mlx_sample_extensions",
+            version="0.0.0",
+            description="Sample C++ and Metal extensions for MLX primitives.",
+            ext_modules=[extension.CMakeExtension("mlx_sample_extensions._ext")],
+            cmdclass={"build_ext": extension.CMakeBuild},
+            packages=["mlx_sample_extensions"],
+            package_data={"mlx_sample_extensions": ["*.so", "*.dylib", "*.metallib"]},
+            extras_require={"dev": []},
+            zip_safe=False,
+            python_requires=">=3.8",
+        )
+
+ .. note::
+    We treat ``extensions/mlx_sample_extensions`` as the package directory
+    even though it only contains a ``__init__.py`` to ensure the following:
+
+    * :mod:`mlx.core` must be imported before importing :mod:`_ext`
+    * The C++ extension library and the Metal library are co-located with the
+      Python bindings and copied together if the package is installed
+
+ To build the package, first install the build dependencies with ``pip install
+ -r requirements.txt``. You can then build inplace for development using
+ ``python setup.py build_ext -j8 --inplace`` (in ``extensions/``).
+
+ This results in the directory structure:
+
+ | extensions
+ | ├── mlx_sample_extensions
+ | │   ├── __init__.py
+ | │   ├── libmlx_ext.dylib # C++ extension library
+ | │   ├── mlx_ext.metallib # Metal library
+ | │   └── _ext.cpython-3x-darwin.so # Python Binding
+ | ...
+
+ When you install the package with ``python -m pip install .`` (in
+ ``extensions/``), it is installed with the same structure as
+ ``extensions/mlx_sample_extensions``, and the C++ and Metal libraries are
+ copied along with the Python binding since they are specified as
+ ``package_data``.
+
+ Usage
+ -----
+
+ After installing the extension as described above, you should be able to simply
+ import the Python package and play with it as you would any other MLX operation.
+
+ Let's look at a simple script and its results:
+
+ .. code-block:: python
+
+    import mlx.core as mx
+    from mlx_sample_extensions import axpby
+
+    a = mx.ones((3, 4))
+    b = mx.ones((3, 4))
+    c = axpby(a, b, 4.0, 2.0, stream=mx.cpu)
+
+    print(f"c shape: {c.shape}")
+    print(f"c dtype: {c.dtype}")
+    print(f"c is correct: {mx.all(c == 6.0).item()}")
+
+ Output:
+
+ .. code-block::
+
+    c shape: [3, 4]
+    c dtype: float32
+    c is correct: True
+
+ Results
+ ^^^^^^^
+
+ Let's run a quick benchmark and see how our new ``axpby`` operation compares
+ with the naive :meth:`simple_axpby` we first defined.
+
+ .. code-block:: python
+
+    import mlx.core as mx
+    from mlx_sample_extensions import axpby
+    import time
+
+    def simple_axpby(x: mx.array, y: mx.array, alpha: float, beta: float) -> mx.array:
+        return alpha * x + beta * y
+
+    M = 4096
+    N = 4096
+
+    x = mx.random.normal((M, N))
+    y = mx.random.normal((M, N))
+    alpha = 4.0
+    beta = 2.0
+
+    mx.eval(x, y)
+
+    def bench(f):
+        # Warm up
+        for i in range(5):
+            z = f(x, y, alpha, beta)
+            mx.eval(z)
+
+        # Timed run
+        s = time.perf_counter()
+        for i in range(100):
+            z = f(x, y, alpha, beta)
+            mx.eval(z)
+        e = time.perf_counter()
+        return 1000 * (e - s) / 100
+
+    simple_time = bench(simple_axpby)
+    custom_time = bench(axpby)
+
+    print(f"Simple axpby: {simple_time:.3f} ms | Custom axpby: {custom_time:.3f} ms")
+
+ The results are ``Simple axpby: 1.559 ms | Custom axpby: 0.774 ms``. We see
+ modest improvements right away!
+
+ This operation can now be used to build other operations, used within
+ :class:`mlx.nn.Module` layers, and composed with graph transformations such as
+ :meth:`grad`.
+
+ Scripts
+ -------
+
+ .. admonition:: Download the code
+
+    The full example code is available in `mlx <https://github.com/ml-explore/mlx/tree/main/examples/extensions/>`_.
+
+ .. _Accelerate: https://developer.apple.com/documentation/accelerate/blas?language=objc
+ .. _Metal: https://developer.apple.com/documentation/metal?language=objc
+ .. _Metal-cpp: https://developer.apple.com/metal/cpp/
+ .. _`Metal Specification`: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+ .. _`Metal Example`: https://developer.apple.com/documentation/metal/performing_calculations_on_a_gpu?language=objc
+ .. _nanobind: https://nanobind.readthedocs.io/en/latest/