mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,1243 @@
1
+ // Copyright © 2023-2024 Apple Inc.
2
+ #include <cstdlib>
3
+ #include <map>
4
+ #include <sstream>
5
+ #include <unordered_map>
6
+ #include <unordered_set>
7
+
8
+ #include "mlx/allocator.h"
9
+ #include "mlx/backend/common/compiled.h"
10
+ #include "mlx/compile.h"
11
+ #include "mlx/compile_impl.h"
12
+ #include "mlx/fast_primitives.h"
13
+ #include "mlx/graph_utils.h"
14
+ #include "mlx/primitives.h"
15
+ #include "mlx/transforms.h"
16
+ #include "mlx/transforms_impl.h"
17
+ #include "mlx/utils.h"
18
+
19
+ namespace mlx::core {
20
+
21
+ constexpr int max_compile_depth = 11;
22
+ constexpr int max_compile_arrays = 24;
23
+
24
+ bool is_unary(const Primitive& p) {
25
+ return (
26
+ typeid(p) == typeid(Abs) || typeid(p) == typeid(ArcCos) ||
27
+ typeid(p) == typeid(ArcCosh) || typeid(p) == typeid(ArcSin) ||
28
+ typeid(p) == typeid(ArcSinh) || typeid(p) == typeid(ArcTan) ||
29
+ typeid(p) == typeid(ArcTanh) || typeid(p) == typeid(AsType) ||
30
+ typeid(p) == typeid(Ceil) || typeid(p) == typeid(Cos) ||
31
+ typeid(p) == typeid(Conjugate) || typeid(p) == typeid(Cosh) ||
32
+ typeid(p) == typeid(Remainder) || typeid(p) == typeid(Erf) ||
33
+ typeid(p) == typeid(ErfInv) || typeid(p) == typeid(Exp) ||
34
+ typeid(p) == typeid(Floor) || typeid(p) == typeid(Log) ||
35
+ typeid(p) == typeid(Log1p) || typeid(p) == typeid(LogicalNot) ||
36
+ typeid(p) == typeid(Negative) || typeid(p) == typeid(Round) ||
37
+ typeid(p) == typeid(Sigmoid) || typeid(p) == typeid(Sign) ||
38
+ typeid(p) == typeid(Sin) || typeid(p) == typeid(Sinh) ||
39
+ typeid(p) == typeid(Square) || typeid(p) == typeid(Sqrt) ||
40
+ typeid(p) == typeid(Tan) || typeid(p) == typeid(Tanh) ||
41
+ typeid(p) == typeid(Expm1) || typeid(p) == typeid(Real) ||
42
+ typeid(p) == typeid(Imag) || typeid(p) == typeid(BitwiseInvert));
43
+ }
44
+
45
+ bool is_binary(const Primitive& p) {
46
+ return (
47
+ typeid(p) == typeid(Add) || typeid(p) == typeid(Divide) ||
48
+ typeid(p) == typeid(Equal) || typeid(p) == typeid(Greater) ||
49
+ typeid(p) == typeid(GreaterEqual) || typeid(p) == typeid(Less) ||
50
+ typeid(p) == typeid(LessEqual) || typeid(p) == typeid(LogicalNot) ||
51
+ typeid(p) == typeid(LogicalAnd) || typeid(p) == typeid(LogicalOr) ||
52
+ typeid(p) == typeid(LogAddExp) || typeid(p) == typeid(Maximum) ||
53
+ typeid(p) == typeid(Minimum) || typeid(p) == typeid(Multiply) ||
54
+ typeid(p) == typeid(NotEqual) || typeid(p) == typeid(Power) ||
55
+ typeid(p) == typeid(Subtract) || typeid(p) == typeid(BitwiseBinary) ||
56
+ typeid(p) == typeid(ArcTan2));
57
+ }
58
+
59
+ bool is_ternary(const Primitive& p) {
60
+ return typeid(p) == typeid(Select);
61
+ }
62
+
63
+ bool is_broadcast(const Primitive& p) {
64
+ return typeid(p) == typeid(Broadcast);
65
+ }
66
+
67
+ bool is_noop(const Primitive& p) {
68
+ return typeid(p) == typeid(Copy) || typeid(p) == typeid(StopGradient);
69
+ }
70
+
71
+ bool is_reduction(const Primitive& p) {
72
+ return typeid(p) == typeid(Reduce) || typeid(p) == typeid(ArgReduce);
73
+ }
74
+
75
+ bool is_fusable(const Primitive& p) {
76
+ return is_unary(p) || is_binary(p) || is_ternary(p) || is_broadcast(p);
77
+ }
78
+
79
+ Compiled::Compiled(
80
+ Stream stream,
81
+ std::vector<array> inputs,
82
+ std::vector<array> outputs,
83
+ std::vector<array> tape,
84
+ std::unordered_set<uintptr_t> constant_ids)
85
+ : Primitive(stream),
86
+ inputs_(std::move(inputs)),
87
+ outputs_(std::move(outputs)),
88
+ tape_(std::move(tape)),
89
+ constant_ids_(std::move(constant_ids)),
90
+ is_constant_([this](size_t i) {
91
+ return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();
92
+ }) {
93
+ // Build the kernel name.
94
+ NodeNamer namer;
95
+ std::ostringstream os;
96
+ std::ostringstream constant_hasher;
97
+
98
+ std::unordered_set<uintptr_t> output_ids;
99
+ for (auto& o : outputs_) {
100
+ output_ids.insert(o.id());
101
+ }
102
+
103
+ // Fill the input names. This is not really necessary, I just like having A,
104
+ // B, C, ... as the inputs.
105
+ for (const auto& x : inputs_) {
106
+ namer.get_name(x);
107
+ }
108
+
109
+ // The primitives describing the tape. For unary and binary primitives this
110
+ // must be enough to describe the full computation.
111
+ for (const auto& a : tape_) {
112
+ // name and type of output
113
+ os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
114
+ // whether or not it's an output
115
+ if (output_ids.find(a.id()) != output_ids.end()) {
116
+ os << "O";
117
+ } else {
118
+ os << "I";
119
+ }
120
+ // computation performed
121
+ os << a.primitive().name();
122
+ // name of inputs to the function
123
+ for (auto& inp : a.inputs()) {
124
+ os << namer.get_name(inp);
125
+ }
126
+ }
127
+ os << "_";
128
+
129
+ for (const auto& x : inputs_) {
130
+ if (constant_ids_.find(x.id()) != constant_ids_.end()) {
131
+ os << "C";
132
+ print_constant(constant_hasher, x);
133
+ } else {
134
+ os << (is_scalar(x) ? "S" : "V");
135
+ }
136
+ }
137
+ os << "_";
138
+ for (const auto& x : inputs) {
139
+ if (constant_ids.find(x.id()) != constant_ids.end()) {
140
+ continue;
141
+ }
142
+ os << kindof(x.dtype()) << x.itemsize();
143
+ }
144
+ os << "_" << std::hash<std::string>{}(constant_hasher.str());
145
+
146
+ kernel_lib_ = os.str();
147
+ }
148
+
149
+ std::vector<array> Compiled::vjp(
150
+ const std::vector<array>&,
151
+ const std::vector<array>&,
152
+ const std::vector<int>&,
153
+ const std::vector<array>&) {
154
+ throw std::runtime_error("[Compiled] Cannot vjp primitive.");
155
+ }
156
+
157
+ std::vector<array> Compiled::jvp(
158
+ const std::vector<array>&,
159
+ const std::vector<array>&,
160
+ const std::vector<int>&) {
161
+ throw std::runtime_error("[Compiled] Cannot jvp primitive.");
162
+ }
163
+
164
+ std::pair<std::vector<array>, std::vector<int>> Compiled::vmap(
165
+ const std::vector<array>&,
166
+ const std::vector<int>&) {
167
+ throw std::runtime_error("[Compiled] Cannot vmap primitive.");
168
+ }
169
+
170
+ bool Compiled::is_equivalent(const Primitive& other) const {
171
+ const Compiled& a_other = static_cast<const Compiled&>(other);
172
+ return std::equal(
173
+ tape_.begin(),
174
+ tape_.end(),
175
+ a_other.tape_.begin(),
176
+ a_other.tape_.end(),
177
+ [](const array& a1, const array& a2) {
178
+ auto& p1 = a1.primitive();
179
+ auto& p2 = a2.primitive();
180
+ return typeid(p1) == typeid(p2) && p1.is_equivalent(p2);
181
+ });
182
+ }
183
+
184
+ const char* Compiled::name() const {
185
+ if (name_.empty()) {
186
+ std::ostringstream os;
187
+ os << "Compiled";
188
+ for (auto& a : tape_) {
189
+ os << a.primitive().name();
190
+ }
191
+ name_ = os.str();
192
+ }
193
+ return name_.c_str();
194
+ }
195
+
196
+ std::vector<Shape> Compiled::output_shapes(const std::vector<array>& inputs) {
197
+ size_t nd = 0;
198
+ for (auto& in : inputs) {
199
+ nd = std::max(nd, in.ndim());
200
+ }
201
+ Shape out_shape(nd, 0);
202
+ for (auto& in : inputs) {
203
+ auto dd = nd - in.ndim();
204
+ for (auto i = dd; i < nd; ++i) {
205
+ out_shape[i] = std::max(out_shape[i], in.shape()[i - dd]);
206
+ }
207
+ }
208
+ // All outputs have the same shape
209
+ return std::vector<Shape>(outputs_.size(), out_shape);
210
+ }
211
+
212
+ namespace detail {
213
+
214
+ CompileMode& compile_mode() {
215
+ auto get_val = []() {
216
+ if (std::getenv("MLX_DISABLE_COMPILE")) {
217
+ return CompileMode::disabled;
218
+ } else {
219
+ return CompileMode::enabled;
220
+ }
221
+ };
222
+ static CompileMode compile_mode_ = get_val();
223
+ return compile_mode_;
224
+ }
225
+
226
+ // Helper like below but only merges the two provided arrays. If the src has
227
+ // siblings then these won't be merged to the dst.
228
+ void merge_one(array& dst, array& src, ParentsMap& parents_map) {
229
+ auto src_parents = parents_map.find(src.id());
230
+ if (src_parents == parents_map.end()) {
231
+ return;
232
+ }
233
+ auto& pairs = parents_map[dst.id()];
234
+ for (auto& parent : src_parents->second) {
235
+ parent.first.inputs()[parent.second] = dst;
236
+ pairs.push_back(parent);
237
+ }
238
+
239
+ // If src is a parent of dst, remove it from dst's parents
240
+ for (auto it = pairs.begin(); it != pairs.end();) {
241
+ if (it->first.id() == src.id()) {
242
+ it = pairs.erase(it);
243
+ } else {
244
+ it++;
245
+ }
246
+ }
247
+ // Remove the source from the map to avoid fusing with it again
248
+ parents_map.erase(src_parents);
249
+ }
250
+
251
+ // Helper that merges two arrays in the graph by setting the parents of the
252
+ // source to point to the destination. The arrays are assumed to be coming from
253
+ // equivalent primitives so their siblings are merged as well.
254
+ void merge(array& dst, array& src, ParentsMap& parents_map) {
255
+ // Canonicalize the order of the primitives outputs
256
+ auto sources = src.outputs();
257
+ auto dests = dst.outputs();
258
+ // For each src parent, point it to the corresponding dst
259
+ for (int i = 0; i < sources.size(); ++i) {
260
+ merge_one(dests[i], sources[i], parents_map);
261
+ }
262
+ }
263
+
264
+ // Any parent in the divider will continue to refer to `x` but any parent not
265
+ // in the divider will refer to a copy of the operation.
266
+ array split_one(
267
+ const array& x,
268
+ ParentsMap& parents_map,
269
+ const std::unordered_set<uintptr_t>& divider) {
270
+ array y(x.shape(), x.dtype(), x.primitive_ptr(), x.inputs());
271
+
272
+ auto& x_parents = parents_map[x.id()];
273
+ auto& y_parents = parents_map[y.id()];
274
+
275
+ for (auto it = x_parents.begin(); it != x_parents.end();) {
276
+ if (divider.find(it->first.id()) != divider.end()) {
277
+ it->first.inputs()[it->second] = y;
278
+ y_parents.emplace_back(std::move(*it));
279
+ it = x_parents.erase(it);
280
+ } else {
281
+ it++;
282
+ }
283
+ }
284
+
285
+ return y;
286
+ }
287
+
288
+ template <typename T, typename... U>
289
+ std::uintptr_t get_function_address(const std::function<T(U...)>& fun) {
290
+ using FunType = T (*)(U...);
291
+ const FunType* fun_ptr = fun.template target<FunType>();
292
+ if (fun_ptr == nullptr) {
293
+ return 0;
294
+ }
295
+ return reinterpret_cast<std::uintptr_t>(*fun_ptr);
296
+ }
297
+
298
+ class CompilerCache {
299
+ public:
300
+ struct CacheEntry {
301
+ CacheEntry(Stream stream, bool shapeless)
302
+ : stream(stream), shapeless(shapeless) {};
303
+ Stream stream;
304
+ bool shapeless;
305
+ std::vector<array> inputs;
306
+ std::vector<array> outputs;
307
+ std::vector<array> tape;
308
+ bool empty{true};
309
+ std::vector<uint64_t> constants;
310
+ std::shared_ptr<void> extra;
311
+ };
312
+
313
+ // Returns a reference to a CacheEntry which can be updated
314
+ // by the caller to avoid copying large tapes / inputs / outputs
315
+ CacheEntry& find(
316
+ std::uintptr_t fun_id,
317
+ const std::vector<array>& inputs,
318
+ bool shapeless,
319
+ const std::vector<uint64_t>& constants) {
320
+ // Find the cache entries for |fun_id|.
321
+ std::vector<CacheEntry>& entries = cache_[fun_id];
322
+
323
+ // Compare if 2 arrays have same shape and dtype.
324
+ auto has_same_shape_and_dtype = [shapeless](
325
+ const std::vector<array>& in1,
326
+ const std::vector<array>& in2) {
327
+ if (in1.size() != in2.size()) {
328
+ return false;
329
+ }
330
+ for (size_t i = 0; i < in1.size(); ++i) {
331
+ if (in1[i].ndim() != in2[i].ndim()) {
332
+ return false;
333
+ }
334
+ if (!shapeless && in1[i].shape() != in2[i].shape()) {
335
+ return false;
336
+ }
337
+ if (in1[i].dtype() != in2[i].dtype()) {
338
+ return false;
339
+ }
340
+ }
341
+ return true;
342
+ };
343
+ // Loop over entries and check:
344
+ // - Default stream and device match the entry's default stream
345
+ // - Inputs match i.e. shapes and types must be equal.
346
+ auto stream = default_stream(default_device());
347
+ for (CacheEntry& entry : entries) {
348
+ // Check that the default stream and device match
349
+ if (entry.stream != stream) {
350
+ continue;
351
+ }
352
+ if (entry.shapeless != shapeless) {
353
+ continue;
354
+ }
355
+
356
+ // Check the inputs match and return if so
357
+ if (has_same_shape_and_dtype(inputs, entry.inputs) &&
358
+ constants == entry.constants) {
359
+ return entry;
360
+ }
361
+ }
362
+ // Otherwise append a new cache entry
363
+ entries.push_back(CacheEntry{stream, shapeless});
364
+ return entries.back();
365
+ }
366
+
367
+ void erase(std::uintptr_t fun_id) {
368
+ cache_.erase(fun_id);
369
+ }
370
+
371
+ void clear() {
372
+ cache_.clear();
373
+ }
374
+
375
+ private:
376
+ CompilerCache() {
377
+ // Make sure the allocator is fully
378
+ // initialized before the compiler cache
379
+ allocator::allocator();
380
+ }
381
+
382
+ friend CompilerCache& compiler_cache();
383
+ std::unordered_map<std::uintptr_t, std::vector<CacheEntry>> cache_;
384
+ };
385
+
386
+ CompilerCache& compiler_cache() {
387
+ static CompilerCache compiler_cache_;
388
+ return compiler_cache_;
389
+ }
390
+
391
+ std::tuple<std::vector<array>, std::vector<array>, std::shared_ptr<void>>
392
+ compile_trace(
393
+ const ArrayFnWithExtra& fun,
394
+ const std::vector<array>& inputs,
395
+ bool shapeless) {
396
+ // Set the global tracing flag.
397
+ detail::InTracing in_tracing{shapeless};
398
+
399
+ // Run the function on placeholder inputs
400
+ // to get compute graph
401
+ std::vector<array> tracer_inputs;
402
+ for (int i = 0; i < inputs.size(); ++i) {
403
+ array in(inputs[i].shape(), inputs[i].dtype(), nullptr, {});
404
+ in.set_tracer(true);
405
+ tracer_inputs.push_back(std::move(in));
406
+ }
407
+
408
+ auto output = fun(tracer_inputs);
409
+ return {tracer_inputs, output.first, output.second};
410
+ }
411
+
412
+ // Traverses the graph to build a tape and a map of array ids to their parents
413
+ std::pair<std::vector<array>, ParentsMap> compile_dfs(
414
+ const std::vector<array>& inputs,
415
+ std::vector<array>& outputs,
416
+ const std::vector<array>& original_inputs) {
417
+ std::vector<array> tape;
418
+ std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
419
+ parents_map;
420
+ {
421
+ std::function<void(const array&)> recurse;
422
+ std::unordered_set<std::uintptr_t> input_set;
423
+ std::unordered_set<std::uintptr_t> original_input_set;
424
+ for (int i = 0; i < inputs.size(); ++i) {
425
+ input_set.insert(inputs[i].id());
426
+ original_input_set.insert(original_inputs[i].id());
427
+ }
428
+
429
+ // DFS the graph to build the tape, and log parents and scalars
430
+ std::unordered_set<std::uintptr_t> cache;
431
+ recurse = [&](const array& a) {
432
+ auto id = a.id();
433
+ if (original_input_set.find(id) != original_input_set.end()) {
434
+ throw std::invalid_argument(
435
+ "[compile] Attempting to compile a function with uncaptured inputs is not allowed.");
436
+ }
437
+ if (cache.find(id) != cache.end()) {
438
+ return;
439
+ }
440
+ for (int i = 0; i < a.inputs().size(); i++) {
441
+ auto& in = a.inputs()[i];
442
+ parents_map[in.id()].push_back({a, i});
443
+ for (auto& s : a.siblings()) {
444
+ parents_map[in.id()].push_back({s, i});
445
+ }
446
+ // Don't recurse on inputs (but add them to the tape for the purpose
447
+ // of future optimizations)
448
+ if (input_set.find(a.id()) == input_set.end()) {
449
+ recurse(in);
450
+ }
451
+ }
452
+ cache.insert(id);
453
+ for (auto& s : a.siblings()) {
454
+ cache.insert(s.id());
455
+ }
456
+ tape.push_back(a);
457
+ };
458
+ for (auto& a : outputs) {
459
+ recurse(a);
460
+ }
461
+ }
462
+
463
+ // Deep copy the tape and parents map while preserving inputs and outputs
464
+ std::vector<array> new_tape;
465
+ std::unordered_set<uintptr_t> io_set;
466
+ std::unordered_map<uintptr_t, array> old_to_new;
467
+ for (auto& o : outputs) {
468
+ old_to_new.insert({o.id(), o});
469
+ io_set.insert(o.id());
470
+ for (auto& s : o.siblings()) {
471
+ old_to_new.insert({s.id(), s});
472
+ io_set.insert(s.id());
473
+ }
474
+ }
475
+ for (auto& i : inputs) {
476
+ io_set.insert(i.id());
477
+ old_to_new.insert({i.id(), i});
478
+ }
479
+
480
+ new_tape.reserve(tape.size());
481
+ for (auto& arr : tape) {
482
+ if (!arr.has_primitive() || (io_set.find(arr.id()) != io_set.end())) {
483
+ old_to_new.insert({arr.id(), arr});
484
+ new_tape.push_back(arr);
485
+ continue;
486
+ }
487
+ std::vector<array> inputs;
488
+ inputs.reserve(arr.inputs().size());
489
+ for (auto& i : arr.inputs()) {
490
+ inputs.push_back(old_to_new.find(i.id())->second);
491
+ }
492
+ if (arr.siblings().size() > 0) {
493
+ std::vector<Dtype> types;
494
+ std::vector<Shape> shapes;
495
+ auto out = arr.outputs();
496
+ for (auto& o : out) {
497
+ types.push_back(o.dtype());
498
+ shapes.push_back(o.shape());
499
+ }
500
+ auto as = array::make_arrays(
501
+ std::move(shapes), types, arr.primitive_ptr(), std::move(inputs));
502
+ for (int i = 0; i < out.size(); ++i) {
503
+ old_to_new.insert({out[i].id(), as[i]});
504
+ }
505
+ new_tape.push_back(as[arr.sibling_position()]);
506
+ } else {
507
+ auto a = array(
508
+ arr.shape(), arr.dtype(), arr.primitive_ptr(), std::move(inputs));
509
+ old_to_new.insert({arr.id(), a});
510
+ new_tape.push_back(a);
511
+ }
512
+ }
513
+ io_set.clear();
514
+ for (auto& o : outputs) {
515
+ if (!(io_set.insert(o.id()).second)) {
516
+ continue;
517
+ }
518
+ for (auto& i : o.inputs()) {
519
+ i = old_to_new.find(i.id())->second;
520
+ }
521
+ for (auto& s : o.siblings()) {
522
+ io_set.insert(s.id());
523
+ for (auto& i : s.inputs()) {
524
+ i = old_to_new.find(i.id())->second;
525
+ }
526
+ }
527
+ }
528
+ tape = std::move(new_tape);
529
+
530
+ std::unordered_map<std::uintptr_t, std::vector<std::pair<array, int>>>
531
+ new_parents_map;
532
+ for (auto& [id, vec] : parents_map) {
533
+ for (auto& [a, _] : vec) {
534
+ a = old_to_new.find(a.id())->second;
535
+ }
536
+ new_parents_map[old_to_new.find(id)->second.id()] = std::move(vec);
537
+ }
538
+ parents_map = std::move(new_parents_map);
539
+ return {tape, parents_map};
540
+ }
541
+
542
+ static inline uint64_t splitmix64(uint64_t x) noexcept {
543
+ x += 0x9e3779b97f4a7c15ull;
544
+ x = (x ^ (x >> 30)) * 0xbf58476d1ce4e5b9ull;
545
+ x = (x ^ (x >> 27)) * 0x94d049bb133111ebull;
546
+ return x ^ (x >> 31);
547
+ }
548
+
549
+ struct VecU64Hash {
550
+ size_t operator()(const std::vector<uint64_t>& s) const noexcept {
551
+ uint64_t h =
552
+ 0x243f6a8885a308d3ull ^ (uint64_t)s.size() * 0x9e3779b97f4a7c15ull;
553
+ for (uint64_t x : s) {
554
+ h = splitmix64(x ^ splitmix64(h + 0x9e3779b97f4a7c15ull));
555
+ }
556
+ return (size_t)h;
557
+ }
558
+ };
559
+
560
+ // Simplify the tape. Note, this function modifies in-place both the tape,
561
+ // the parents map to remove orphaned arrays, and potentially the outputs
562
+ void compile_simplify(
563
+ std::vector<array>& tape,
564
+ ParentsMap& parents_map,
565
+ std::vector<array>& outputs,
566
+ int passes) {
567
+ // Helpers to identify identical scalars
568
+ std::map<std::pair<uint64_t, Dtype::Val>, array> scalars;
569
+ auto is_scalar = [](const array& a) {
570
+ // Condition for when it's safe to read an array
571
+ return a.is_available() && a.ndim() == 0;
572
+ };
573
+ auto get_scalar_rep = [](const array& a) {
574
+ uint64_t v = 0;
575
+ switch (a.dtype().size()) {
576
+ case 1:
577
+ v = *a.data<uint8_t>();
578
+ break;
579
+ case 2:
580
+ v = *a.data<uint16_t>();
581
+ break;
582
+ case 4:
583
+ v = *a.data<uint32_t>();
584
+ break;
585
+ case 8:
586
+ v = *a.data<uint64_t>();
587
+ break;
588
+ }
589
+ return std::make_pair(v, a.dtype().val());
590
+ };
591
+
592
+ for (auto& a : tape) {
593
+ if (is_scalar(a)) {
594
+ scalars.insert({get_scalar_rep(a), a});
595
+ }
596
+ }
597
+
598
+ // Depth-1 array equivalence check.
599
+ auto array_equivalent = [](const array& a, const array& b) {
600
+ if (!a.has_primitive() || !b.has_primitive()) {
601
+ return false;
602
+ }
603
+ if (a.primitive_id() == b.primitive_id()) {
604
+ return false;
605
+ }
606
+ const auto& pa = a.primitive();
607
+ const auto& pb = b.primitive();
608
+ if (typeid(pa) != typeid(pb)) {
609
+ return false;
610
+ }
611
+
612
+ if (a.inputs().size() != b.inputs().size()) {
613
+ return false;
614
+ }
615
+
616
+ for (int i = 0; i < a.inputs().size(); i++) {
617
+ if (a.inputs()[i].id() != b.inputs()[i].id()) {
618
+ return false;
619
+ }
620
+ }
621
+
622
+ return pa.is_equivalent(pb);
623
+ };
624
+
625
+ // Merge scalars
626
+ std::vector<array> new_tape;
627
+ for (auto& arr : tape) {
628
+ // Check if we can merge scalars
629
+ if (is_scalar(arr)) {
630
+ auto scalar = scalars.find(get_scalar_rep(arr));
631
+ if (scalar->second.id() != arr.id()) {
632
+ merge(scalar->second, arr, parents_map);
633
+ // Don't keep orphaned scalars in the tape
634
+ continue;
635
+ }
636
+ }
637
+ new_tape.push_back(std::move(arr));
638
+ }
639
+ tape = std::move(new_tape);
640
+
641
+ // Remove no-ops
642
+ {
643
+ std::unordered_map<uintptr_t, array> output_map;
644
+ for (auto& o : outputs) {
645
+ output_map.insert({o.id(), o});
646
+ }
647
+ for (auto& arr : tape) {
648
+ if (!arr.has_primitive() || !is_noop(arr.primitive())) {
649
+ new_tape.push_back(std::move(arr));
650
+ continue;
651
+ }
652
+ merge_one(arr.inputs()[0], arr, parents_map);
653
+ if (auto it = output_map.find(arr.id()); it != output_map.end()) {
654
+ it->second = arr.inputs()[0];
655
+ }
656
+ }
657
+ tape = std::move(new_tape);
658
+ for (auto& o : outputs) {
659
+ o = output_map.at(o.id());
660
+ }
661
+ }
662
+
663
+ std::unordered_map<std::uintptr_t, uint32_t> tape_order;
664
+ for (uint32_t i = 0; i < tape.size(); ++i) {
665
+ tape_order.insert({tape[i].id(), i});
666
+ }
667
+
668
+ std::unordered_set<uintptr_t> output_set;
669
+ for (auto& o : outputs) {
670
+ output_set.insert(o.id());
671
+ }
672
+
673
+ // Multi-pass merge only keeping non-orphaned arrays in the tape
674
+ for (int pass = 0; pass < passes; ++pass) {
675
+ for (auto& arr : tape) {
676
+ // Helper to check if we can merge the parents of the
677
+ // given array
678
+ auto maybe_merge_parents = [&](auto& a) {
679
+ auto parents = parents_map.find(a.id());
680
+ if (parents != parents_map.end()) {
681
+ auto N = parents->second.size();
682
+ std::vector<bool> mask(N, false);
683
+
684
+ auto try_merge = [&](int dst_idx, int src_idx) {
685
+ if (tape_order[parents->second[src_idx].first.id()] <
686
+ tape_order[parents->second[dst_idx].first.id()]) {
687
+ std::swap(src_idx, dst_idx);
688
+ }
689
+ auto& src = parents->second[src_idx].first;
690
+ auto& dst = parents->second[dst_idx].first;
691
+ if (src.id() != dst.id() && array_equivalent(src, dst) &&
692
+ output_set.find(src.id()) == output_set.end()) {
693
+ merge(dst, src, parents_map);
694
+ mask[src_idx] = true;
695
+ }
696
+ };
697
+
698
+ if (N > 100) {
699
+ std::unordered_map<
700
+ std::vector<uint64_t>,
701
+ std::vector<int>,
702
+ VecU64Hash>
703
+ dst_map;
704
+ // Find possibly mergeable groups
705
+ for (int i = 0; i < N; i++) {
706
+ // Make the hash key
707
+ std::vector<uint64_t> key;
708
+ auto& curr = parents->second[i].first;
709
+ key.reserve(curr.inputs().size() + 2);
710
+ for (auto& in : curr.inputs()) {
711
+ key.push_back(in.id());
712
+ }
713
+ auto& p = curr.primitive();
714
+ key.push_back(curr.inputs().size());
715
+ key.push_back(typeid(p).hash_code());
716
+ auto it = dst_map.find(key);
717
+ if (it == dst_map.end()) {
718
+ bool _;
719
+ std::tie(it, _) = dst_map.insert({key, std::vector<int>{}});
720
+ }
721
+ it->second.push_back(i);
722
+ }
723
+ for (auto& [_, group] : dst_map) {
724
+ for (int i = 0; i < group.size(); ++i) {
725
+ if (mask[group[i]]) {
726
+ continue;
727
+ }
728
+ for (int j = i + 1; j < group.size(); ++j) {
729
+ if (mask[group[j]]) {
730
+ continue;
731
+ }
732
+ try_merge(group[i], group[j]);
733
+ }
734
+ }
735
+ }
736
+ } else {
737
+ for (int i = 0; i < N; ++i) {
738
+ if (mask[i]) {
739
+ continue;
740
+ }
741
+ for (int j = i + 1; j < N; ++j) {
742
+ if (mask[j]) {
743
+ continue;
744
+ }
745
+ try_merge(i, j);
746
+ }
747
+ }
748
+ }
749
+
750
+ // Erase orphaned parents so we don't keep fusing with them
751
+ for (int i = N - 1; i >= 0; --i) {
752
+ if (mask[i]) {
753
+ parents->second.erase(parents->second.begin() + i);
754
+ }
755
+ }
756
+ return false;
757
+ } else {
758
+ return output_set.find(a.id()) == output_set.end();
759
+ }
760
+ };
761
+ bool discard = maybe_merge_parents(arr);
762
+ for (auto& s : arr.siblings()) {
763
+ discard &= maybe_merge_parents(s);
764
+ }
765
+ // If an array and its siblings have no parents, and none of them are
766
+ // outputs, it is safe to remove it from the tape
767
+ if (!discard) {
768
+ new_tape.push_back(std::move(arr));
769
+ }
770
+ }
771
+ tape = std::move(new_tape);
772
+ }
773
+ }
774
+
775
+ // Extract sub-graphs of the graph that can be compiled
776
+ // and replace them with a Compiled Primitive.
777
+ void compile_fuse(
778
+ std::vector<array>& tape,
779
+ ParentsMap& parents_map,
780
+ const std::vector<array>& inputs,
781
+ std::vector<array>& outputs) {
782
+ // Track outputs to replace with new compiled outputs
783
+ std::unordered_map<uintptr_t, array> output_map;
784
+ for (auto& o : outputs) {
785
+ output_map.insert({o.id(), o});
786
+ }
787
+
788
+ // Set of inputs to distinguish constants
789
+ std::unordered_set<uintptr_t> input_ids;
790
+ for (auto& in : inputs) {
791
+ input_ids.insert(in.id());
792
+ }
793
+
794
+ // Go through the tape in reverse order and check for fusable sub-graphs
795
+ std::vector<array> new_tape;
796
+ std::unordered_set<uintptr_t> global_cache;
797
+ for (int i = tape.size() - 1; i >= 0; --i) {
798
+ auto& arr = tape[i];
799
+
800
+ // Already compiled
801
+ if (global_cache.find(arr.id()) != global_cache.end()) {
802
+ continue;
803
+ }
804
+
805
+ // Two pass recursion:
806
+ // First pass:
807
+ // - Collect all the primitives which we can fuse with
808
+ // - Keeps a cache of fusable primitives which may be added out of
809
+ // DAG order. We have to determine if all of a fused primitive's
810
+ // outputs are also in the fused section, and this may not be the
811
+ // case the first time we visit it.
812
+ // Second pass:
813
+ // - Collect inputs to the new compiled primitive
814
+ // - Add fusable primitives to a tape in the correct order
815
+
816
+ std::function<void(const array&, int, const Stream&, const Shape&)> recurse;
817
+ std::unordered_set<uintptr_t> cache;
818
+ std::unordered_set<uintptr_t> input_set;
819
+ recurse = [&](const array& a,
820
+ int depth,
821
+ const Stream& s,
822
+ const Shape& shape) {
823
+ if (cache.find(a.id()) != cache.end()) {
824
+ return;
825
+ }
826
+
827
+ // Stop fusing if:
828
+ // - Depth limit exceeded
829
+ // - Constant input
830
+ // - Stream mismatch
831
+ // - Non fusable primitive
832
+ // - Is global output but has a different shape
833
+ if (depth >= max_compile_depth || !a.has_primitive() ||
834
+ a.primitive().stream() != s || !is_fusable(a.primitive()) ||
835
+ (output_map.find(a.id()) != output_map.end() && a.shape() != shape)) {
836
+ // Possible input
837
+ input_set.insert(a.id());
838
+ return;
839
+ }
840
+
841
+ bool all_parents_in = true;
842
+ if (depth > 0) {
843
+ // Guaranteed to have a parent since nested in the
844
+ // recursion.
845
+ auto& parents = parents_map.at(a.id());
846
+ for (auto& [p, idx] : parents) {
847
+ auto in_cache = cache.find(p.id()) != cache.end();
848
+ if (!in_cache) {
849
+ all_parents_in = false;
850
+ break;
851
+ }
852
+ }
853
+ }
854
+
855
+ // Arrays with a mix of parents outside the compilable section
856
+ // are not fusable except for broadcast which we can split to avoid
857
+ // stopping fusion
858
+ if (!all_parents_in) {
859
+ if (a.has_primitive() && is_broadcast(a.primitive())) {
860
+ array b = split_one(a, parents_map, cache);
861
+ recurse(b, depth, s, shape);
862
+ } else {
863
+ // Possible input
864
+ input_set.insert(a.id());
865
+ }
866
+ return;
867
+ }
868
+
869
+ if (output_map.find(a.id()) != output_map.end()) {
870
+ input_set.insert(a.id());
871
+ } else {
872
+ // Not an input anymore since fusing it
873
+ input_set.erase(a.id());
874
+ }
875
+ if (input_set.size() >= max_compile_arrays) {
876
+ return;
877
+ }
878
+ cache.insert({a.id()});
879
+
880
+ for (auto& in : a.inputs()) {
881
+ recurse(in, depth + 1, s, shape);
882
+ }
883
+ };
884
+
885
+ // This will be the result of the fused operation so it needs
886
+ // a) to not be already computed ie have a primitive
887
+ // b) that primitive to not be a broadcast since it will unnecessarily
888
+ // cast to a contiguous array potentially blowing up memory
889
+ if (arr.has_primitive() && !is_broadcast(arr.primitive())) {
890
+ Stream s = arr.primitive().stream();
891
+ recurse(arr, 0, s, arr.shape());
892
+ }
893
+
894
+ // Not worth fusing a single primitive
895
+ if (cache.size() <= 1) {
896
+ new_tape.push_back(arr);
897
+ continue;
898
+ }
899
+
900
+ // Recurse a second time to build the tape in the right
901
+ // order and collect the inputs
902
+ input_set.clear();
903
+ std::vector<array> inputs;
904
+ std::vector<array> fused_tape;
905
+ std::unordered_set<uintptr_t> tape_set;
906
+ std::function<void(const array&)> recurse_tape;
907
+ recurse_tape = [&](const array& a) {
908
+ if (cache.find(a.id()) == cache.end()) {
909
+ if (input_set.find(a.id()) == input_set.end()) {
910
+ input_set.insert(a.id());
911
+ inputs.push_back(a);
912
+ }
913
+ return;
914
+ }
915
+ if (tape_set.find(a.id()) != tape_set.end()) {
916
+ return;
917
+ }
918
+ tape_set.insert(a.id());
919
+ for (auto& in : a.inputs()) {
920
+ recurse_tape(in);
921
+ }
922
+ fused_tape.push_back(a);
923
+ };
924
+ recurse_tape(arr);
925
+
926
+ std::vector<array> old_outputs;
927
+ // Add to global cache and add any global outputs to outputs
928
+ // of new primitive
929
+ for (int j = 0; j < fused_tape.size() - 1; ++j) {
930
+ auto& f = fused_tape[j];
931
+ if (output_map.find(f.id()) != output_map.end()) {
932
+ old_outputs.push_back(f);
933
+ // Parents are now siblings, update the parent map
934
+ auto& pairs = parents_map[f.id()];
935
+ pairs.erase(
936
+ std::remove_if(
937
+ pairs.begin(),
938
+ pairs.end(),
939
+ [&](auto& p) {
940
+ return cache.find(p.first.id()) != cache.end();
941
+ }),
942
+ pairs.end());
943
+ } else {
944
+ // Remove inner fused arrays parents from the parents map
945
+ // to keep the parents map in a valid state
946
+ parents_map.erase(f.id());
947
+ }
948
+ global_cache.insert({f.id()});
949
+ }
950
+ old_outputs.push_back(arr);
951
+
952
+ std::vector<Shape> shapes;
953
+ std::vector<Dtype> types;
954
+ for (auto& o : old_outputs) {
955
+ if (o.shape() != old_outputs.back().shape()) {
956
+ throw std::runtime_error(
957
+ "[compile] Compilation failed. Tried to fuse operations with different output shapes");
958
+ }
959
+ shapes.push_back(o.shape());
960
+ types.push_back(o.dtype());
961
+ }
962
+ std::unordered_set<uintptr_t> constant_ids;
963
+ for (auto& in : inputs) {
964
+ // Scalar constant
965
+ if (in.size() == 1 && !in.has_primitive() &&
966
+ input_ids.find(in.id()) == input_ids.end()) {
967
+ constant_ids.insert(in.id());
968
+ }
969
+ }
970
+ auto compiled_outputs = array::make_arrays(
971
+ std::move(shapes),
972
+ types,
973
+ std::make_shared<Compiled>(
974
+ old_outputs.back().primitive().stream(),
975
+ inputs,
976
+ old_outputs,
977
+ std::move(fused_tape),
978
+ std::move(constant_ids)),
979
+ inputs);
980
+
981
+ // One output per primitive
982
+ new_tape.push_back(compiled_outputs.back());
983
+
984
+ // Replace inputs old parents with compiled_outputs
985
+ for (int i = 0; i < inputs.size(); ++i) {
986
+ auto& pairs = parents_map[inputs[i].id()];
987
+ pairs.erase(
988
+ std::remove_if(
989
+ pairs.begin(),
990
+ pairs.end(),
991
+ [&](auto& p) { return cache.find(p.first.id()) != cache.end(); }),
992
+ pairs.end());
993
+ for (auto& o : compiled_outputs) {
994
+ pairs.push_back({o, i});
995
+ }
996
+ }
997
+
998
+ // - Update outputs parents to point to compiled outputs
999
+ // - Update any overall graph outputs to be compiled outputs
1000
+ for (int o = 0; o < old_outputs.size(); ++o) {
1001
+ merge_one(compiled_outputs[o], old_outputs[o], parents_map);
1002
+ if (auto it = output_map.find(old_outputs[o].id());
1003
+ it != output_map.end()) {
1004
+ it->second = compiled_outputs[o];
1005
+ }
1006
+ }
1007
+ }
1008
+
1009
+ std::reverse(new_tape.begin(), new_tape.end());
1010
+ tape = std::move(new_tape);
1011
+
1012
+ // Replace output with potentially compiled output
1013
+ for (auto& o : outputs) {
1014
+ o = output_map.at(o.id());
1015
+ }
1016
+ }
1017
+
1018
+ std::vector<array> compile_replace(
1019
+ const std::vector<array>& tape,
1020
+ const std::vector<array>& trace_inputs,
1021
+ const std::vector<array>& trace_outputs,
1022
+ const std::vector<array>& inputs,
1023
+ bool shapeless) {
1024
+ std::unordered_map<uintptr_t, array> trace_to_real;
1025
+ for (int i = 0; i < inputs.size(); ++i) {
1026
+ trace_to_real.insert({trace_inputs[i].id(), inputs[i]});
1027
+ }
1028
+
1029
+ auto is_load = [](const Primitive& p) { return typeid(p) == typeid(Load); };
1030
+
1031
+ for (auto& a : tape) {
1032
+ // Arrays in the tape without primitives are either:
1033
+ // - inputs, which are already in the map
1034
+ // - constants, which can be used directly
1035
+ // - a load primitive which has no inputs and will become a constant
1036
+ // after the first eval
1037
+ if (!a.has_primitive() || is_load(a.primitive())) {
1038
+ trace_to_real.insert({a.id(), a});
1039
+ } else {
1040
+ // Find real inputs
1041
+ std::vector<array> real_inputs;
1042
+ for (auto& in : a.inputs()) {
1043
+ real_inputs.push_back(trace_to_real.at(in.id()));
1044
+ }
1045
+ if (a.siblings().empty()) {
1046
+ auto shape =
1047
+ shapeless ? a.primitive().output_shapes(real_inputs)[0] : a.shape();
1048
+ auto real_a = array(
1049
+ std::move(shape),
1050
+ a.dtype(),
1051
+ a.primitive_ptr(),
1052
+ std::move(real_inputs));
1053
+ trace_to_real.insert({a.id(), std::move(real_a)});
1054
+ } else {
1055
+ // Ensure the order is correct for multi-output primitives
1056
+ std::vector<Dtype> types;
1057
+ auto trace_out = a.outputs();
1058
+ for (auto& o : trace_out) {
1059
+ types.push_back(o.dtype());
1060
+ }
1061
+ std::vector<Shape> shapes;
1062
+ if (shapeless) {
1063
+ shapes = a.primitive().output_shapes(real_inputs);
1064
+ } else {
1065
+ for (auto& o : trace_out) {
1066
+ shapes.push_back(o.shape());
1067
+ }
1068
+ }
1069
+ auto real_out = array::make_arrays(
1070
+ std::move(shapes), types, a.primitive_ptr(), real_inputs);
1071
+ for (int i = 0; i < trace_out.size(); ++i) {
1072
+ trace_to_real.insert({trace_out[i].id(), std::move(real_out[i])});
1073
+ }
1074
+ }
1075
+ }
1076
+ }
1077
+
1078
+ std::vector<array> outputs;
1079
+ for (auto& o : trace_outputs) {
1080
+ outputs.push_back(trace_to_real.at(o.id()));
1081
+ }
1082
+ return outputs;
1083
+ }
1084
+
1085
+ bool skip_compile() {
1086
+ return compile_mode() == CompileMode::disabled ||
1087
+ !(compile_available_for_device(default_device()));
1088
+ }
1089
+
1090
+ ArrayFnWithExtra compile(
1091
+ ArrayFnWithExtra fun,
1092
+ std::uintptr_t fun_id,
1093
+ bool shapeless /* = false */,
1094
+ std::vector<uint64_t> constants /* = {} */) {
1095
+ if (skip_compile()) {
1096
+ return fun;
1097
+ }
1098
+ if (!fun) {
1099
+ throw std::invalid_argument(
1100
+ "[compile] Cannot compile a function without a target.");
1101
+ }
1102
+
1103
+ return [fun = std::move(fun),
1104
+ fun_id,
1105
+ shapeless,
1106
+ constants = std::move(constants)](const std::vector<array>& inputs) {
1107
+ // If the inputs are tracers, trace the original graph
1108
+ if (std::any_of(inputs.begin(), inputs.end(), [](auto& in) {
1109
+ return in.is_tracer();
1110
+ })) {
1111
+ return fun(inputs);
1112
+ }
1113
+
1114
+ // Find a cache entry with the correct inputs
1115
+ auto& entry = compiler_cache().find(fun_id, inputs, shapeless, constants);
1116
+
1117
+ // No matching cache entry existed, so compile
1118
+ if (entry.empty) {
1119
+ // Mark the entry as not empty since we are about to fill it
1120
+ entry.empty = false;
1121
+ // Set the constants
1122
+ entry.constants = std::move(constants);
1123
+ // Trace to build the graph
1124
+ std::tie(entry.inputs, entry.outputs, entry.extra) =
1125
+ compile_trace(fun, inputs, shapeless);
1126
+
1127
+ // DFS the graph and get a tape, and a map of array id to (parent,
1128
+ // position in parent inputs)
1129
+ std::unordered_map<uintptr_t, std::vector<std::pair<array, int>>>
1130
+ parents_map;
1131
+ std::tie(entry.tape, parents_map) =
1132
+ compile_dfs(entry.inputs, entry.outputs, inputs);
1133
+
1134
+ // Simplify the tape
1135
+ if (compile_mode() != CompileMode::no_simplify) {
1136
+ compile_simplify(
1137
+ entry.tape, parents_map, entry.outputs, /* passes */ 3);
1138
+ }
1139
+
1140
+ // Kernel fusion to generate Compiled primitives. The tape and
1141
+ // new outputs must be updated accordingly
1142
+ if (compile_mode() != CompileMode::no_fuse) {
1143
+ compile_fuse(entry.tape, parents_map, entry.inputs, entry.outputs);
1144
+ }
1145
+ }
1146
+
1147
+ // At this point we must have a tape, now replace the placeholders
1148
+ // with real arrays that can be evaluated
1149
+ return ArraysAndExtra{
1150
+ compile_replace(
1151
+ entry.tape, entry.inputs, entry.outputs, inputs, shapeless),
1152
+ entry.extra};
1153
+ };
1154
+ }
1155
+
1156
+ std::function<std::vector<array>(const std::vector<array>&)> compile(
1157
+ std::function<std::vector<array>(const std::vector<array>&)> fun,
1158
+ std::uintptr_t fun_id,
1159
+ bool shapeless /* = false */,
1160
+ std::vector<uint64_t> constants /* = {} */) {
1161
+ if (skip_compile()) {
1162
+ return fun;
1163
+ }
1164
+ if (!fun) {
1165
+ throw std::invalid_argument(
1166
+ "[compile] Cannot compile a function without a target.");
1167
+ }
1168
+
1169
+ ArrayFnWithExtra fun_with_extra =
1170
+ [fun = std::move(fun)](const std::vector<array>& inputs) {
1171
+ return ArraysAndExtra{fun(inputs), nullptr};
1172
+ };
1173
+
1174
+ auto compiled_fun = compile(
1175
+ std::move(fun_with_extra), fun_id, shapeless, std::move(constants));
1176
+
1177
+ return [compiled_fun =
1178
+ std::move(compiled_fun)](const std::vector<array>& inputs) {
1179
+ return compiled_fun(inputs).first;
1180
+ };
1181
+ }
1182
+
1183
+ void compile_erase(std::uintptr_t fun_id) {
1184
+ detail::compiler_cache().erase(fun_id);
1185
+ }
1186
+
1187
+ void compile_clear_cache() {
1188
+ detail::compiler_cache().clear();
1189
+ }
1190
+
1191
+ } // namespace detail
1192
+
1193
+ std::function<std::vector<array>(const std::vector<array>&)> compile(
1194
+ std::function<std::vector<array>(const std::vector<array>&)> fun,
1195
+ bool shapeless /* false */) {
1196
+ if (detail::skip_compile()) {
1197
+ return fun;
1198
+ }
1199
+ auto fun_id = detail::get_function_address(fun);
1200
+ if (fun_id) {
1201
+ // If the function has an addressable target then no need to manage it's
1202
+ // lifetime
1203
+ return detail::compile(std::move(fun), fun_id, shapeless);
1204
+ } else {
1205
+ auto pfun = std::shared_ptr<
1206
+ std::function<std::vector<array>(const std::vector<array>&)>>(
1207
+ new std::function<std::vector<array>(const std::vector<array>&)>{fun},
1208
+ [](auto* p) {
1209
+ detail::compile_erase(reinterpret_cast<std::uintptr_t>(p));
1210
+ delete p;
1211
+ });
1212
+ fun_id = reinterpret_cast<std::uintptr_t>(pfun.get());
1213
+ return detail::compile(
1214
+ [pfun = std::move(pfun)](const auto& inputs) {
1215
+ return (*pfun)(inputs);
1216
+ },
1217
+ fun_id,
1218
+ shapeless);
1219
+ }
1220
+ }
1221
+
1222
+ std::function<std::vector<array>(const std::vector<array>&)> compile(
1223
+ std::vector<array> (*fun)(const std::vector<array>&),
1224
+ bool shapeless /* = false */) {
1225
+ if (detail::skip_compile()) {
1226
+ return fun;
1227
+ }
1228
+ return detail::compile(fun, reinterpret_cast<std::uintptr_t>(fun), shapeless);
1229
+ }
1230
+
1231
+ void disable_compile() {
1232
+ detail::compile_mode() = CompileMode::disabled;
1233
+ }
1234
+
1235
+ void enable_compile() {
1236
+ detail::compile_mode() = CompileMode::enabled;
1237
+ }
1238
+
1239
+ void set_compile_mode(CompileMode mode) {
1240
+ detail::compile_mode() = mode;
1241
+ }
1242
+
1243
+ } // namespace mlx::core