mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/compile.h ADDED
@@ -0,0 +1,45 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/api.h"
6
+ #include "mlx/array.h"
7
+
8
+ namespace mlx::core {
9
+
10
+ enum class CompileMode { disabled, no_simplify, no_fuse, enabled };
11
+
12
+ /** Compile takes a function and returns a compiled function. */
13
+ MLX_API std::function<std::vector<array>(const std::vector<array>&)> compile(
14
+ std::function<std::vector<array>(const std::vector<array>&)> fun,
15
+ bool shapeless = false);
16
+
17
+ MLX_API std::function<std::vector<array>(const std::vector<array>&)> compile(
18
+ std::vector<array> (*fun)(const std::vector<array>&),
19
+ bool shapeless = false);
20
+
21
+ // Convert capture-less lambdas to function pointers.
22
+ template <
23
+ typename F,
24
+ typename = std::enable_if_t<
25
+ std::is_convertible_v<F, decltype(+std::declval<F>())>>>
26
+ std::function<std::vector<array>(const std::vector<array>&)> compile(
27
+ F&& f,
28
+ bool shapeless = false) {
29
+ return compile(+f, shapeless);
30
+ }
31
+
32
+ /** Globally disable compilation.
33
+ * Setting the environment variable ``MLX_DISABLE_COMPILE`` can also
34
+ * be used to disable compilation.
35
+ */
36
+ MLX_API void disable_compile();
37
+
38
+ /** Globally enable compilation.
39
+ * This will override the environment variable ``MLX_DISABLE_COMPILE``.
40
+ */
41
+ MLX_API void enable_compile();
42
+
43
+ /** Set the compiler mode to the given value. */
44
+ MLX_API void set_compile_mode(CompileMode mode);
45
+ } // namespace mlx::core
@@ -0,0 +1,70 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <unordered_map>
6
+
7
+ #include "mlx/api.h"
8
+ #include "mlx/array.h"
9
+
10
+ namespace mlx::core::detail {
11
+
12
+ using ArraysAndExtra = std::pair<std::vector<array>, std::shared_ptr<void>>;
13
+ using ArrayFnWithExtra =
14
+ std::function<ArraysAndExtra(const std::vector<array>&)>;
15
+
16
+ // This is not part of the general C++ API as calling with a bad id is a bad
17
+ // idea.
18
+ MLX_API std::function<std::vector<array>(const std::vector<array>&)> compile(
19
+ std::function<std::vector<array>(const std::vector<array>&)> fun,
20
+ std::uintptr_t fun_id,
21
+ bool shapeless = false,
22
+ std::vector<uint64_t> constants = {});
23
+
24
+ MLX_API ArrayFnWithExtra compile(
25
+ ArrayFnWithExtra fun,
26
+ std::uintptr_t fun_id,
27
+ bool shapeless,
28
+ std::vector<uint64_t> constants);
29
+
30
+ // Erase cached compile functions
31
+ MLX_API void compile_erase(std::uintptr_t fun_id);
32
+
33
+ // Clear the compiler cache causing a recompilation of all compiled functions
34
+ // when called again.
35
+ MLX_API void compile_clear_cache();
36
+
37
+ bool compile_available_for_device(const Device& device);
38
+
39
+ std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
40
+ compile_trace(
41
+ const ArrayFnWithExtra& fun,
42
+ const std::vector<array>& inputs,
43
+ bool shapeless);
44
+
45
+ using ParentsMap =
46
+ std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>;
47
+
48
+ // Traverses the graph to build a tape and a map of array ids to their parents
49
+ std::pair<std::vector<array>, ParentsMap> compile_dfs(
50
+ const std::vector<array>& inputs,
51
+ std::vector<array>& outputs,
52
+ const std::vector<array>& original_inputs);
53
+
54
+ // Simplify the tape.
55
+ void compile_simplify(
56
+ std::vector<array>& tape,
57
+ ParentsMap& parents_map,
58
+ std::vector<array>& outputs,
59
+ int passes);
60
+
61
+ std::vector<array> compile_replace(
62
+ const std::vector<array>& tape,
63
+ const std::vector<array>& trace_inputs,
64
+ const std::vector<array>& trace_outputs,
65
+ const std::vector<array>& inputs,
66
+ bool shapeless);
67
+
68
+ void compile_validate_shapeless(const std::vector<array>& tape);
69
+
70
+ } // namespace mlx::core::detail
@@ -0,0 +1,72 @@
1
+ // Copyright © 2023-2026 Apple Inc.
2
+
3
+ #include <stdexcept>
4
+
5
+ #include "mlx/backend/cpu/device_info.h"
6
+ #include "mlx/backend/gpu/device_info.h"
7
+ #include "mlx/device.h"
8
+
9
+ namespace mlx::core {
10
+
11
+ Device& mutable_default_device() {
12
+ static Device default_device{gpu::is_available() ? Device::gpu : Device::cpu};
13
+ return default_device;
14
+ }
15
+
16
+ const Device& default_device() {
17
+ return mutable_default_device();
18
+ }
19
+
20
+ void set_default_device(const Device& d) {
21
+ if (!gpu::is_available() && d == Device::gpu) {
22
+ throw std::invalid_argument(
23
+ "[set_default_device] Cannot set gpu device without gpu backend.");
24
+ }
25
+ mutable_default_device() = d;
26
+ }
27
+
28
+ bool operator==(const Device& lhs, const Device& rhs) {
29
+ return lhs.type == rhs.type && lhs.index == rhs.index;
30
+ }
31
+
32
+ bool operator!=(const Device& lhs, const Device& rhs) {
33
+ return !(lhs == rhs);
34
+ }
35
+
36
+ bool is_available(const Device& d) {
37
+ switch (d.type) {
38
+ case Device::cpu:
39
+ return cpu::is_available() && (d.index < cpu::device_count());
40
+ case Device::gpu:
41
+ return gpu::is_available() && (d.index < gpu::device_count());
42
+ }
43
+ // appease compiler
44
+ return false;
45
+ }
46
+
47
+ int device_count(Device::DeviceType type) {
48
+ switch (type) {
49
+ case Device::cpu:
50
+ return cpu::device_count();
51
+ case Device::gpu:
52
+ return gpu::device_count();
53
+ }
54
+ // appease compiler
55
+ return 0;
56
+ }
57
+
58
+ const std::unordered_map<std::string, std::variant<std::string, size_t>>&
59
+ device_info(const Device& d) {
60
+ switch (d.type) {
61
+ case Device::cpu:
62
+ return cpu::device_info(d.index);
63
+ case Device::gpu:
64
+ return gpu::device_info(d.index);
65
+ }
66
+ // appease compiler
67
+ static std::unordered_map<std::string, std::variant<std::string, size_t>>
68
+ empty;
69
+ return empty;
70
+ }
71
+
72
+ } // namespace mlx::core
data/mlx/mlx/device.h ADDED
@@ -0,0 +1,56 @@
1
+ // Copyright © 2023-2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/api.h"
6
+
7
+ #include <string>
8
+ #include <unordered_map>
9
+ #include <variant>
10
+
11
+ namespace mlx::core {
12
+
13
+ struct MLX_API Device {
14
+ enum class DeviceType {
15
+ cpu,
16
+ gpu,
17
+ };
18
+
19
+ static constexpr DeviceType cpu = DeviceType::cpu;
20
+ static constexpr DeviceType gpu = DeviceType::gpu;
21
+
22
+ Device(DeviceType type, int index = 0) : type(type), index(index) {}
23
+
24
+ DeviceType type;
25
+ int index;
26
+ };
27
+
28
+ MLX_API const Device& default_device();
29
+
30
+ MLX_API void set_default_device(const Device& d);
31
+
32
+ MLX_API bool operator==(const Device& lhs, const Device& rhs);
33
+ MLX_API bool operator!=(const Device& lhs, const Device& rhs);
34
+
35
+ MLX_API bool is_available(const Device& d);
36
+
37
+ /** Get the number of available devices for the given device type. */
38
+ MLX_API int device_count(Device::DeviceType type);
39
+
40
+ /**
41
+ * Get information about a device.
42
+ *
43
+ * Returns a map of device properties. Keys vary by backend:
44
+ * - device_name (string): Device name
45
+ * - architecture (string): Architecture identifier
46
+ * - total_memory/memory_size (size_t): Total device memory
47
+ * - free_memory (size_t): Available memory (CUDA only)
48
+ * - uuid (string): Device UUID (CUDA only)
49
+ * - pci_bus_id (string): PCI bus ID (CUDA only)
50
+ * - compute_capability_major/minor (size_t): Compute capability (CUDA only)
51
+ */
52
+ MLX_API const
53
+ std::unordered_map<std::string, std::variant<std::string, size_t>>&
54
+ device_info(const Device& d = default_device());
55
+
56
+ } // namespace mlx::core
@@ -0,0 +1,14 @@
1
+ target_sources(
2
+ mlx
3
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
4
+ ${CMAKE_CURRENT_SOURCE_DIR}/ops.cpp
5
+ ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cpp)
6
+
7
+ if(MLX_BUILD_CPU AND NOT WIN32)
8
+ target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp)
9
+ endif()
10
+
11
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/mpi)
12
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ring)
13
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/nccl)
14
+ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/jaccl)
@@ -0,0 +1,197 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #include <unordered_map>
4
+
5
+ #include "mlx/backend/cuda/cuda.h"
6
+ #include "mlx/distributed/distributed.h"
7
+ #include "mlx/distributed/distributed_impl.h"
8
+ #include "mlx/distributed/jaccl/jaccl.h"
9
+ #include "mlx/distributed/mpi/mpi.h"
10
+ #include "mlx/distributed/nccl/nccl.h"
11
+ #include "mlx/distributed/ring/ring.h"
12
+
13
+ namespace mlx::core::distributed {
14
+
15
+ namespace detail {
16
+
17
+ Stream communication_stream(Group group, StreamOrDevice s /* = {} */) {
18
+ return group.raw_group()->communication_stream(s);
19
+ }
20
+
21
+ void all_sum(Group group, const array& input, array& output, Stream stream) {
22
+ group.raw_group()->all_sum(input, output, stream);
23
+ }
24
+
25
+ void all_max(Group group, const array& input, array& output, Stream stream) {
26
+ group.raw_group()->all_max(input, output, stream);
27
+ }
28
+
29
+ void all_min(Group group, const array& input, array& output, Stream stream) {
30
+ group.raw_group()->all_min(input, output, stream);
31
+ }
32
+
33
+ void all_gather(Group group, const array& input, array& output, Stream stream) {
34
+ group.raw_group()->all_gather(input, output, stream);
35
+ }
36
+
37
+ void send(Group group, const array& input, int dst, Stream stream) {
38
+ group.raw_group()->send(input, dst, stream);
39
+ }
40
+
41
+ void recv(Group group, array& out, int src, Stream stream) {
42
+ group.raw_group()->recv(out, src, stream);
43
+ }
44
+
45
+ void sum_scatter(
46
+ Group group,
47
+ const array& input,
48
+ array& output,
49
+ Stream stream) {
50
+ group.raw_group()->sum_scatter(input, output, stream);
51
+ }
52
+
53
+ class EmptyGroup : public GroupImpl {
54
+ public:
55
+ Stream communication_stream(StreamOrDevice s) override {
56
+ return to_stream(s);
57
+ }
58
+
59
+ int rank() override {
60
+ return 0;
61
+ }
62
+
63
+ int size() override {
64
+ return 1;
65
+ }
66
+
67
+ std::shared_ptr<GroupImpl> split(int color, int key = -1) override {
68
+ throw std::runtime_error("Cannot split the distributed group further.");
69
+ }
70
+
71
+ void all_sum(const array&, array&, Stream) override {
72
+ throw std::runtime_error(
73
+ "Communication not implemented in an empty distributed group.");
74
+ }
75
+ void all_gather(const array&, array&, Stream) override {
76
+ throw std::runtime_error(
77
+ "Communication not implemented in an empty distributed group.");
78
+ }
79
+ void send(const array&, int, Stream) override {
80
+ throw std::runtime_error(
81
+ "Communication not implemented in an empty distributed group.");
82
+ }
83
+ void recv(array&, int, Stream) override {
84
+ throw std::runtime_error(
85
+ "Communication not implemented in an empty distributed group.");
86
+ }
87
+
88
+ void all_max(const array&, array&, Stream) override {
89
+ throw std::runtime_error(
90
+ "Communication not implemented in an empty distributed group.");
91
+ }
92
+
93
+ void all_min(const array&, array&, Stream) override {
94
+ throw std::runtime_error(
95
+ "Communication not implemented in an empty distributed group.");
96
+ }
97
+ void sum_scatter(const array&, array&, Stream) override {
98
+ throw std::runtime_error(
99
+ "Communication not implemented in an empty distributed group.");
100
+ }
101
+ };
102
+
103
+ } // namespace detail
104
+
105
+ bool is_available() {
106
+ return mpi::is_available() || ring::is_available() || nccl::is_available() ||
107
+ jaccl::is_available();
108
+ }
109
+
110
+ bool is_available(const std::string& bk) {
111
+ if (bk == "any") {
112
+ return is_available();
113
+ }
114
+ if (bk == "mpi") {
115
+ return mpi::is_available();
116
+ }
117
+ if (bk == "ring") {
118
+ return ring::is_available();
119
+ }
120
+ if (bk == "nccl") {
121
+ return nccl::is_available();
122
+ }
123
+ if (bk == "jaccl") {
124
+ return jaccl::is_available();
125
+ }
126
+ return false;
127
+ }
128
+
129
+ int Group::rank() const {
130
+ return group_->rank();
131
+ }
132
+
133
+ int Group::size() const {
134
+ return group_->size();
135
+ }
136
+
137
+ Group Group::split(int color, int key /* = -1 */) const {
138
+ return Group(group_->split(color, key));
139
+ }
140
+
141
+ /**
142
+ * Return the communication group for backend ``bk``, initializing it on
143
+ * first use.
144
+ *
145
+ * Groups are cached by backend name in a function-local static map, so a
146
+ * repeated call with an already seen name returns the cached group. A
147
+ * successfully initialized group is also cached under "any".
148
+ *
149
+ * For "any" the backends are tried non-strictly in the order nccl (only
150
+ * when cu::is_available()), ring, mpi, jaccl. If none of them initializes
151
+ * and ``strict`` is true a std::runtime_error is thrown.
152
+ *
153
+ * If no backend could be initialized an EmptyGroup is returned. For a named
154
+ * backend it is cached under that name. For "any" it is not cached: at that
155
+ * point ``bk_`` only names the last backend that was tried, and caching
156
+ * under it would make a later init of that backend return the empty group
157
+ * without attempting to initialize it.
158
+ *
159
+ * Throws std::invalid_argument if ``bk`` is not a known backend name.
160
+ */
161
+ Group init(bool strict /* = false */, const std::string& bk /* = "any" */) {
162
+ static std::unordered_map<std::string, std::shared_ptr<detail::GroupImpl>>
163
+ backends;
164
+
165
+ // Already initialized so return the group.
166
+ if (auto g = backends.find(bk); g != backends.end()) {
167
+ return Group(g->second);
168
+ }
169
+
170
+ // Create the requested communication group
171
+ std::shared_ptr<detail::GroupImpl> group{nullptr};
172
+ std::string bk_ = bk;
173
+ if (bk == "mpi") {
174
+ group = mpi::init(strict);
175
+ } else if (bk == "ring") {
176
+ group = ring::init(strict);
177
+ } else if (bk == "nccl") {
178
+ group = nccl::init(strict);
179
+ } else if (bk == "jaccl") {
180
+ group = jaccl::init(strict);
181
+ } else if (bk == "any") {
182
+ if (mlx::core::cu::is_available()) {
183
+ group = nccl::init(false);
184
+ bk_ = "nccl";
185
+ }
186
+ if (group == nullptr) {
187
+ group = ring::init(false);
188
+ bk_ = "ring";
189
+ }
190
+ if (group == nullptr) {
191
+ group = mpi::init(false);
192
+ bk_ = "mpi";
193
+ }
194
+ if (group == nullptr) {
195
+ group = jaccl::init(false);
196
+ bk_ = "jaccl";
197
+ }
198
+ if (group == nullptr && strict) {
199
+ throw std::runtime_error("[distributed] Couldn't initialize any backend");
200
+ }
201
+ } else {
202
+ std::ostringstream msg;
203
+ msg << "[distributed] The only valid values for backend are 'any', 'mpi', 'nccl', "
204
+ << "'jaccl' and 'ring' but '" << bk << "' was provided.";
205
+ throw std::invalid_argument(msg.str());
206
+ }
207
+
208
+ if (group == nullptr) {
209
+ group = std::make_shared<detail::EmptyGroup>();
210
+ if (bk == "any") {
211
+ // Nothing could be initialized, so bk_ is merely the last backend that
212
+ // was tried; do not register the empty group under that name.
213
+ return Group(group);
214
+ }
215
+ } else {
216
+ backends.insert({"any", group});
217
+ }
218
+ backends.insert({std::move(bk_), group});
219
+ return Group(group);
220
+ }
196
+
197
+ } // namespace mlx::core::distributed
@@ -0,0 +1,61 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <memory>
6
+
7
+ #include "mlx/api.h"
8
+ #include "mlx/array.h"
9
+ #include "mlx/utils.h"
10
+
11
+ namespace mlx::core::distributed {
12
+
13
+ // Forward declaration of the base group implementation.
14
+ namespace detail {
15
+ class GroupImpl;
16
+ };
17
+
18
+ /* Check if a communication backend is available */
19
+ MLX_API bool is_available();
20
+ MLX_API bool is_available(const std::string& bk);
21
+
22
+ /**
23
+ * A distributed::Group represents a group of independent mlx processes that
24
+ * can communicate. We must also be able to create sub-groups from a group in
25
+ * order to define more granular communication.
26
+ */
27
+ struct MLX_API Group {
28
+ Group(std::shared_ptr<detail::GroupImpl> group) : group_(std::move(group)) {}
29
+
30
+ int rank() const;
31
+ int size() const;
32
+
33
+ /**
34
+ * Split the group according to the provided color. Namely processes that use
35
+ * the same color will go to the same group.
36
+ *
37
+ * The key defines the rank of the processes in the new group. The smaller
38
+ * the key the smaller the rank. If the provided key is negative, then the
39
+ * rank in the current group is used.
40
+ */
41
+ Group split(int color, int key = -1) const;
42
+
43
+ const std::shared_ptr<detail::GroupImpl>& raw_group() const {
44
+ return group_;
45
+ }
46
+
47
+ private:
48
+ std::shared_ptr<detail::GroupImpl> group_{nullptr};
49
+ };
50
+
51
+ /**
52
+ * Initialize the distributed backend and return the group containing all
53
+ * discoverable processes.
54
+ *
55
+ * If strict is true then throw an error if we couldn't initialize the
56
+ * distributed subsystem. Otherwise simply return a singleton group which will
57
+ * render communication operations as no-op.
58
+ */
59
+ MLX_API Group init(bool strict = false, const std::string& bk = "any");
60
+
61
+ } // namespace mlx::core::distributed
@@ -0,0 +1,59 @@
1
+ // Copyright © 2024 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/distributed/distributed.h"
6
+
7
+ namespace mlx::core::distributed::detail {
8
+
9
+ /**
10
+ * Abstract base class of a distributed group implementation.
11
+ */
12
+ class GroupImpl {
13
+ public:
14
+ virtual ~GroupImpl() {}
15
+
16
+ // Choose the stream this communication group can operate on
17
+ virtual Stream communication_stream(StreamOrDevice s = {}) = 0;
18
+
19
+ // Group operations
20
+ virtual int rank() = 0;
21
+ virtual int size() = 0;
22
+ virtual std::shared_ptr<GroupImpl> split(int color, int key = -1) = 0;
23
+
24
+ // Actual communication operations
25
+ virtual void all_sum(const array& input, array& output, Stream stream) = 0;
26
+ virtual void all_gather(const array& input, array& output, Stream stream) = 0;
27
+ virtual void send(const array& input, int dst, Stream stream) = 0;
28
+ virtual void recv(array& out, int src, Stream stream) = 0;
29
+ virtual void all_max(const array& input, array& output, Stream stream) = 0;
30
+ virtual void all_min(const array& input, array& output, Stream stream) = 0;
31
+ virtual void
32
+ sum_scatter(const array& input, array& output, Stream stream) = 0;
33
+ };
34
+
35
+ /* Define the MLX stream that the communication should happen in. */
36
+ Stream communication_stream(Group group, StreamOrDevice s = {});
37
+
38
+ /* Perform an all reduce sum operation */
39
+ void all_sum(Group group, const array& input, array& output, Stream stream);
40
+
41
+ /* Perform an all gather operation */
42
+ void all_gather(Group group, const array& input, array& output, Stream stream);
43
+
44
+ /** Send an array to the dst rank */
45
+ void send(Group group, const array& input, int dst, Stream stream);
46
+
47
+ /** Recv an array from the src rank */
48
+ void recv(Group group, array& out, int src, Stream stream);
49
+
50
+ /** Max reduction */
51
+ void all_max(Group group, const array& input, array& output, Stream stream);
52
+
53
+ /** Min reduction */
54
+ void all_min(Group group, const array& input, array& output, Stream stream);
55
+
56
+ /** Reduce scatter with average operation */
57
+ void sum_scatter(Group group, const array& input, array& output, Stream stream);
58
+
59
+ } // namespace mlx::core::distributed::detail
@@ -0,0 +1,12 @@
1
+ if(MLX_BUILD_CPU
2
+ AND ${CMAKE_SYSTEM_NAME} MATCHES "Darwin"
3
+ AND MACOS_SDK_VERSION GREATER_EQUAL 26.2)
4
+ target_sources(
5
+ mlx
6
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/jaccl.cpp
7
+ ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
8
+ ${CMAKE_CURRENT_SOURCE_DIR}/mesh.cpp
9
+ ${CMAKE_CURRENT_SOURCE_DIR}/ring.cpp)
10
+ else()
11
+ target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/no_jaccl.cpp)
12
+ endif()