mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/transforms.cpp (file 586 in the list above, +1065 lines)
@@ -0,0 +1,1065 @@
+ // Copyright © 2023-2024 Apple Inc.
+ #include <algorithm>
+ #include <deque>
+ #include <future>
+ #include <numeric>
+ #include <set>
+ #include <sstream>
+ #include <stack>
+ #include <unordered_map>
+ #include <unordered_set>
+
+ #include "mlx/backend/cpu/eval.h"
+ #include "mlx/backend/gpu/eval.h"
+ #include "mlx/fence.h"
+ #include "mlx/memory.h"
+ #include "mlx/ops.h"
+ #include "mlx/primitives.h"
+ #include "mlx/scheduler.h"
+ #include "mlx/transforms.h"
+ #include "mlx/transforms_impl.h"
+ #include "mlx/utils.h"
+
+ namespace mlx::core {
+
+ static constexpr int MAX_ACTIVE_TASKS = 10;
+
+ /* This class is only meant to be used in eval
+  * for synchronizing with the main thread. */
+ class Synchronizer : public Primitive {
+  public:
+   explicit Synchronizer(Stream stream) : Primitive(stream) {}
+
+   void eval_cpu(const std::vector<array>&, std::vector<array>&) override {}
+   void eval_gpu(const std::vector<array>&, std::vector<array>&) override {}
+
+   DEFINE_NAME(Synchronize);
+ };
+
+ // Initialize the static tracing members from transforms_impl.h
+ //
+ // These are used to implement the in_tracing() function that returns true if
+ // we are currently under a function transformation and the retain_graph()
+ // function which returns true if we are forced to retain the graph during
+ // evaluation.
+ std::vector<std::pair<char, char>>& detail::InTracing::trace_stack() {
+   static std::vector<std::pair<char, char>> trace_stack_;
+   return trace_stack_;
+ }
+ int detail::InTracing::grad_counter{0};
+ int detail::RetainGraph::tracing_counter{0};
+
+ array eval_impl(std::vector<array> outputs, bool async) {
+   std::deque<array> tape;
+
+   // Make an effort to choose a good output stream
+   Stream stream = default_stream(default_device());
+   for (auto& o : outputs) {
+     if (o.status() == array::Status::unscheduled && o.has_primitive()) {
+       stream = o.primitive().stream();
+       break;
+     }
+   }
+
+   // Map each array id that needs a fence to the stream it's computed on
+   std::unordered_map<uintptr_t, std::pair<uint32_t, bool>> needs_fence;
+
+   auto synchronizer = array(
+       {}, bool_, std::make_shared<Synchronizer>(stream), std::move(outputs));
+
+   // Stream fences for inter-stream synchronization
+   std::unordered_map<uint32_t, Fence> fences;
+
+   // Stream events for synchronization after eval
+   std::unordered_map<uint32_t, Event> events;
+   {
+     auto e = Event{stream};
+     e.set_value(1);
+     synchronizer.attach_event(e);
+     events.emplace(stream.index, std::move(e));
+   }
+
+   {
+     // Record the degree of each input
+     std::unordered_map<std::uintptr_t, int> cache;
+
+     std::stack<std::pair<std::reference_wrapper<array>, int>> dfs;
+     dfs.emplace(synchronizer, 0);
+     while (!dfs.empty()) {
+       auto& [a_ref, idx] = dfs.top();
+       auto& a = a_ref.get();
+
+       if (idx < a.inputs().size()) {
+         // Add an input, and continue
+         auto& in = a.inputs()[idx++];
+
+         if (in.status() == array::Status::unscheduled) {
+           if (async && in.is_tracer()) {
+             throw std::invalid_argument(
+                 "[async_eval] Not allowed inside a graph transformation.");
+           }
+           if (!in.has_primitive()) {
+             if (in.is_tracer()) {
+               throw std::invalid_argument(
+                   "[eval] Attempting to eval an array during function"
+                   " transformations like compile or vmap is not allowed.");
+             }
+             throw std::runtime_error(
+                 "[eval] Attempting to eval an array without a primitive.\n"
+                 "If you are compiling a function, make sure all the inputs "
+                 "and outputs are captured:\n"
+                 "https://ml-explore.github.io/mlx/build/html/usage/compile.html#pure-functions.\n"
+                 "If you are not using compile, this may be a bug. "
+                 "Please file an issue here:\n"
+                 "https://github.com/ml-explore/mlx/issues.");
+           }
+           if (a.primitive().stream() != in.primitive().stream()) {
+             bool device_switch =
+                 a.primitive().stream().device != in.primitive().stream().device;
+             auto [it, inserted] = needs_fence.emplace(
+                 in.id(),
+                 std::make_pair(in.primitive().stream().index, device_switch));
+             if (!inserted) {
+               it->second.second |= device_switch;
+             }
+           }
+         }
+
+         // All siblings have the same degree
+         auto cache_it = cache.find(in.id());
+         if (cache_it == cache.end()) {
+           dfs.emplace(in, 0);
+           cache.insert({in.id(), 1});
+           for (auto& s : in.siblings()) {
+             cache.insert({s.id(), 1});
+           }
+         } else {
+           cache_it->second++;
+           for (auto& s : in.siblings()) {
+             cache[s.id()]++;
+           }
+         }
+         continue;
+       }
+       if ((a.status() != array::Status::unscheduled) && !a.is_tracer() &&
+           a.has_primitive()) {
+         // If the array is evaluated and is no longer a tracer, detach it
+         a.detach();
+       }
+       dfs.pop();
+     }
+
+     // Build the tape in BFS order with a width limit
+     int max_width = env::bfs_max_width();
+     dfs = std::stack<std::pair<std::reference_wrapper<array>, int>>();
+     tape.push_back(synchronizer);
+     for (int i = 0; !cache.empty() && (i < tape.size() || !dfs.empty());) {
+       auto& a = (i >= tape.size()) ? dfs.top().first.get() : tape[i];
+       int j = 0;
+       if (i >= tape.size()) {
+         j = dfs.top().second;
+         dfs.pop();
+       } else {
+         i++;
+       }
+       for (; j < a.inputs().size(); ++j) {
+         auto& in = a.inputs()[j];
+         if (in.status() != array::Status::unscheduled) {
+           continue;
+         }
+
+         // If the width limit is exceeded, push the array on the stack
+         // and go down a level
+         if ((tape.size() - i) >= max_width) {
+           dfs.emplace(a, j);
+           break;
+         }
+
+         auto it = cache.find(in.id());
+         it->second -= 1;
+
+         if (it->second != 0) {
+           for (auto& s : in.siblings()) {
+             cache[s.id()] -= 1;
+           }
+           continue;
+         }
+
+         // Remove input and siblings from cache
+         cache.erase(it);
+         for (auto& s : in.siblings()) {
+           cache.erase(s.id());
+         }
+
+         tape.push_back(in);
+       }
+     }
+   }
+
+   std::unordered_set<int> open_streams;
+   while (!tape.empty()) {
+     auto arr = std::move(tape.back());
+     tape.pop_back();
+
+     auto stream = arr.primitive().stream();
+     open_streams.insert(stream.index);
+
+     if (async) {
+       // Lookup the corresponding event
+       auto e = events.find(stream.index);
+       if (e == events.end()) {
+         e = events.emplace(stream.index, Event{stream}).first;
+       }
+       e->second.set_value(1);
+       arr.attach_event(e->second);
+       for (auto& s : arr.siblings()) {
+         s.attach_event(e->second);
+       }
+     }
+
+     for (auto& in : arr.inputs()) {
+       if (auto it = needs_fence.find(in.id()); it != needs_fence.end()) {
+         // Use a fence to wait within a single eval:
+         // get the input array's stream fence and wait on the
+         // output array's stream
+         fences[it->second.first].wait(stream, in);
+       } else if (in.event().valid()) {
+         if (in.event().is_signaled()) {
+           in.detach_event();
+         } else if (in.event().stream() != stream) {
+           // Use an event to wait across async eval
+           in.event().wait(stream);
+         }
+       }
+     }
+
+     if (arr.primitive().device() == Device::gpu) {
+       gpu::eval(arr);
+     } else {
+       cpu::eval(arr);
+     }
+
+     if (scheduler::n_active_tasks() > MAX_ACTIVE_TASKS ||
+         (get_active_memory() > get_memory_limit() &&
+          scheduler::n_active_tasks() > 0)) {
+       // Commit any open streams
+       for (auto i : open_streams) {
+         auto s = get_stream(i);
+         if (s.device == Device::gpu) {
+           gpu::finalize(s);
+         }
+       }
+       scheduler::wait_for_one();
+       while (get_active_memory() > get_memory_limit() &&
+              scheduler::n_active_tasks() > 0) {
+         scheduler::wait_for_one();
+       }
+     }
+
+     auto maybe_update_fence = [&fences, &needs_fence, stream](const array& a) {
+       if (auto nf = needs_fence.find(a.id()); nf != needs_fence.end()) {
+         auto it = fences.find(stream.index);
+         if (it == fences.end()) {
+           it = fences.emplace(stream.index, Fence{stream}).first;
+         }
+         it->second.update(stream, a, nf->second.second);
+       }
+     };
+
+     arr.set_status(array::Status::evaluated);
+     // TODO: Maybe we always want the fence coherent kernel in the same cbuf
+     // as the other kernels?
+     maybe_update_fence(arr);
+     for (auto& sib : arr.siblings()) {
+       sib.set_status(array::Status::evaluated);
+       maybe_update_fence(sib);
+     }
+     if (!arr.is_tracer()) {
+       arr.detach();
+     }
+   }
+
+   // Signal the event in its stream
+   for (auto i : open_streams) {
+     auto s = get_stream(i);
+     if (auto e = events.find(i); e != events.end()) {
+       e->second.signal(s);
+     }
+     if (s.device == Device::gpu) {
+       gpu::finalize(s);
+     }
+   }
+
+   return synchronizer;
+ }
+
+ void async_eval(std::vector<array> outputs) {
+   if (outputs.empty()) {
+     return;
+   }
+
+   if (std::none_of(outputs.begin(), outputs.end(), [](array& x) {
+         return x.status() == array::Status::unscheduled;
+       })) {
+     return;
+   }
+
+   eval_impl(std::move(outputs), true);
+ }
+
+ void eval(std::vector<array> outputs) {
+   if (outputs.empty()) {
+     return;
+   }
+
+   if (std::none_of(outputs.begin(), outputs.end(), [](array& x) {
+         return x.status() == array::Status::unscheduled;
+       })) {
+     for (auto& x : outputs) {
+       x.wait();
+     }
+     return;
+   }
+
+   eval_impl(std::move(outputs), false).wait();
+ }
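
Editor's note: eval and async_eval above are the entry points for MLX's lazy evaluation; graphs are built eagerly, but nothing is computed until an eval forces it. A minimal usage sketch, not part of this diff, assuming the umbrella header mlx/mlx.h and the mlx::core operator overloads:

    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      array x = array(2.0f);
      array y = x * x + x;  // builds the graph; nothing is computed yet
      async_eval({y});      // schedules the graph without blocking
      eval({y});            // blocks until y and its dependencies are evaluated
      return 0;
    }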
+
+ std::pair<std::vector<array>, std::vector<array>> vjp(
+     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+     const std::vector<array>& primals,
+     const std::vector<array>& cotans,
+     const std::vector<int>& argnums) {
+   // Set the global tracing flag.
+   detail::InTracing in_tracing{false, true};
+
+   // Make tracers from given primals
+   std::vector<array> primals_;
+   for (auto& p : primals) {
+     auto s = p.has_primitive() ? p.primitive().stream()
+                                : default_stream(default_device());
+     primals_.push_back(copy(p, s)); // Does not do a deep copy
+     primals_.back().set_tracer(true);
+   }
+
+   // Pass tracer primals through the function
+   // Any variables that depend on the primals are marked as tracers
+   auto outputs = fun(primals_);
+
+   // Map outputs to passed cotans while ignoring the outputs
+   // that have stop_gradient called on them
+   int cotan_index = 0;
+   std::vector<std::pair<int, int>> output_cotan_pairs;
+   for (int i = 0; i < outputs.size(); ++i) {
+     auto& out = outputs[i];
+     if (out.has_primitive()) {
+       if (auto& p = out.primitive(); typeid(p) == typeid(StopGradient)) {
+         continue;
+       }
+     }
+     if (cotan_index >= cotans.size()) {
+       std::ostringstream msg;
+       msg << "[vjp] Number of outputs to compute gradients for ("
+           << outputs.size() << ") does not match number of cotangents ("
+           << cotans.size() << ").";
+       throw std::invalid_argument(msg.str());
+     }
+     if (out.shape() != cotans[cotan_index].shape()) {
+       std::ostringstream msg;
+       msg << "[vjp] Output shape " << out.shape()
+           << " does not match cotangent shape " << cotans[cotan_index].shape()
+           << ".";
+       if (outputs.size() == 1 && out.size() == 1) {
+         msg << " If you are using grad your function must return a scalar.";
+       }
+       throw std::invalid_argument(msg.str());
+     }
+     output_cotan_pairs.emplace_back(i, cotan_index++);
+   }
+
+   // Topologically sort the compute graph and add the graph nodes
+   // that need a gradient to the tape.
+   std::unordered_set<std::uintptr_t> cache;
+   std::unordered_set<std::uintptr_t> calc_grad;
+   for (int i = 0, j = 0; i < primals_.size(); ++i) {
+     auto& primal = primals_[i];
+     primal.set_tracer(false);
+     cache.insert(primal.id());
+     if (j < argnums.size() && argnums[j] == i) {
+       j++;
+       calc_grad.insert(primal.id());
+     }
+   }
+
+   std::vector<array> tape;
+
+   std::function<void(array&)> recurse;
+   recurse = [&](auto& a) {
+     // Check if visited and add to cache if not
+     if (auto inserted = cache.insert(a.id()); !inserted.second) {
+       return;
+     }
+     a.set_tracer(false);
+     for (auto& s : a.siblings()) {
+       s.set_tracer(false);
+       cache.insert(s.id());
+     }
+
+     for (auto& input : a.inputs()) {
+       recurse(input);
+     }
+
+     // Stop grad
+     if (a.has_primitive()) {
+       if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {
+         return;
+       }
+     }
+
+     // Calculate gradient if any inputs require gradient
+     for (auto& input : a.inputs()) {
+       if (calc_grad.find(input.id()) != calc_grad.end()) {
+         tape.push_back(a);
+         calc_grad.insert(a.id());
+         for (auto& s : a.siblings()) {
+           calc_grad.insert(s.id());
+         }
+         break;
+       }
+     }
+   };
+
+   for (auto out : outputs) {
+     recurse(out);
+   }
+
+   // Run the tape backwards, computing vector-jacobian
+   // products for each primitive
+   std::unordered_map<std::uintptr_t, array> cotan_map;
+   for (auto [out_idx, cotan_idx] : output_cotan_pairs) {
+     auto& o = outputs[out_idx];
+     auto s = o.has_primitive() ? o.primitive().stream()
+                                : default_stream(default_device());
+     cotan_map.insert({o.id(), astype(cotans[cotan_idx], o.dtype(), s)});
+   }
+   for (auto it = tape.rbegin(); it != tape.rend(); ++it) {
+     auto& a = *it;
+
+     // Get the arguments whose gradients are needed
+     std::vector<int> argnums;
+     for (int i = 0; i < a.inputs().size(); ++i) {
+       if (calc_grad.find(a.inputs()[i].id()) != calc_grad.end()) {
+         argnums.push_back(i);
+       }
+     }
+
+     // Check if any of the array or its siblings have cotangents;
+     // if not, we can skip this primitive
+     auto outputs = a.outputs();
+     bool has_cotans =
+         std::any_of(outputs.cbegin(), outputs.cend(), [&cotan_map](auto& s) {
+           return cotan_map.find(s.id()) != cotan_map.end();
+         });
+     if (!has_cotans) {
+       continue;
+     }
+
+     auto s = a.primitive().stream();
+     std::vector<array> cotangents{};
+     for (auto& o : outputs) {
+       if (auto cotan_it = cotan_map.find(o.id()); cotan_it != cotan_map.end()) {
+         cotangents.push_back(cotan_map.extract(cotan_it).mapped());
+       } else {
+         cotangents.push_back(zeros_like(o, s));
+       }
+     }
+
+     std::vector<array> vjps;
+     {
+       detail::RetainGraph retain;
+       vjps = a.primitive().vjp(a.inputs(), cotangents, argnums, outputs);
+     }
+     // Accumulate the vector-jacobian products for each input
+     for (int i = 0; i < argnums.size(); ++i) {
+       auto in_id = a.inputs()[argnums[i]].id();
+       if (auto cotan_it = cotan_map.find(in_id); cotan_it != cotan_map.end()) {
+         cotan_it->second = add(cotan_it->second, vjps[i], s);
+       } else {
+         cotan_map.insert({in_id, vjps[i]});
+       }
+     }
+   }
+   std::vector<array> vjps;
+   for (auto arg : argnums) {
+     auto& primal = primals_[arg];
+     if (auto cotan_it = cotan_map.find(primal.id());
+         cotan_it != cotan_map.end()) {
+       vjps.push_back(cotan_it->second);
+     } else {
+       auto s = primal.has_primitive() ? primal.primitive().stream()
+                                       : default_stream(default_device());
+       vjps.push_back(zeros_like(primal, s));
+     }
+   }
+   return {outputs, vjps};
+ }
+
+ std::pair<std::vector<array>, std::vector<array>> vjp(
+     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+     const std::vector<array>& primals,
+     const std::vector<array>& cotans) {
+   std::vector<int> argnums(primals.size());
+   std::iota(argnums.begin(), argnums.end(), 0);
+   return vjp(fun, primals, cotans, argnums);
+ }
+
+ std::pair<array, array> vjp(
+     const std::function<array(const array&)>& fun,
+     const array& primal,
+     const array& cotan) {
+   auto vec_fun = [fun](const std::vector<array>& inputs) {
+     return std::vector<array>{fun(inputs[0])};
+   };
+   auto [outputs, vjps] = vjp(vec_fun, {primal}, {cotan});
+   return {outputs[0], vjps[0]};
+ }
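
Editor's note: the single-array overload makes the vjp contract easy to see: it returns the function output together with the cotangent pulled back through the graph. A minimal sketch with hypothetical values, not part of this diff:

    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      auto fun = [](const array& x) { return square(x); };  // f(x) = x^2
      array x = array(3.0f);
      array cotan = array(1.0f);
      // out == f(3) == 9; grad == cotan * df/dx == 1 * (2 * 3) == 6
      auto [out, grad] = vjp(fun, x, cotan);
      eval({out, grad});
      return 0;
    }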
+
+ std::pair<std::vector<array>, std::vector<array>> jvp(
+     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+     const std::vector<array>& primals,
+     const std::vector<array>& tangents) {
+   // Set the global tracing flag.
+   detail::InTracing in_tracing{false, true};
+
+   if (primals.size() != tangents.size()) {
+     throw std::invalid_argument(
+         "[jvp] Number of inputs does not match number of tangents.");
+   }
+   for (int i = 0; i < primals.size(); ++i) {
+     if (primals[i].shape() != tangents[i].shape()) {
+       throw std::invalid_argument(
+           "[jvp] Input shape does not match shape of tangent.");
+     }
+   }
+
+   std::vector<array> primals_;
+   for (auto& p : primals) {
+     auto s = p.has_primitive() ? p.primitive().stream()
+                                : default_stream(default_device());
+     primals_.push_back(copy(p, s)); // Does not do a deep copy
+     primals_.back().set_tracer(true);
+   }
+   auto outputs = fun(primals_);
+
+   // Topologically sort the compute graph, record outputs
+   // in the tape if a gradient is needed.
+   std::unordered_set<std::uintptr_t> cache;
+   std::unordered_set<std::uintptr_t> calc_grad;
+   for (auto& primal : primals_) {
+     primal.set_tracer(false);
+     calc_grad.insert(primal.id());
+     cache.insert(primal.id());
+   }
+
+   std::vector<array> tape;
+
+   std::function<void(array&)> recurse;
+   recurse = [&](auto& a) {
+     // Check if visited and add to cache if not
+     if (auto inserted = cache.insert(a.id()); !inserted.second) {
+       return;
+     }
+     a.set_tracer(false);
+     for (auto& s : a.siblings()) {
+       s.set_tracer(false);
+       cache.insert(s.id());
+     }
+
+     for (auto input : a.inputs()) {
+       recurse(input);
+     }
+
+     // Stop grad
+     if (a.has_primitive()) {
+       if (auto& p = a.primitive(); typeid(p) == typeid(StopGradient)) {
+         return;
+       }
+     }
+
+     // Calculate gradient if any inputs require gradient
+     for (auto& input : a.inputs()) {
+       if (calc_grad.find(input.id()) != calc_grad.end()) {
+         tape.push_back(a);
+         calc_grad.insert(a.id());
+         for (auto& s : a.siblings()) {
+           calc_grad.insert(s.id());
+         }
+         break;
+       }
+     }
+   };
+
+   for (auto out : outputs) {
+     recurse(out);
+   }
+
+   std::unordered_map<std::uintptr_t, array> tan_map;
+   for (int i = 0; i < primals_.size(); ++i) {
+     tan_map.insert({primals_[i].id(), tangents[i]});
+   }
+
+   for (auto& a : tape) {
+     // Get the arguments used in the jvp
+     std::vector<int> argnums;
+     std::vector<array> tangents;
+     for (int i = 0; i < a.inputs().size(); ++i) {
+       if (auto it = tan_map.find(a.inputs()[i].id()); it != tan_map.end()) {
+         argnums.push_back(i);
+         tangents.push_back(it->second);
+       }
+     }
+
+     auto jvps = a.primitive().jvp(a.inputs(), tangents, argnums);
+     auto outputs = a.outputs();
+     for (int i = 0; i < jvps.size(); ++i) {
+       tan_map.insert({outputs[i].id(), jvps[i]});
+     }
+   }
+
+   std::vector<array> jvps;
+   for (auto& out : outputs) {
+     if (auto it = tan_map.find(out.id()); it != tan_map.end()) {
+       jvps.push_back(it->second);
+     } else {
+       auto s = out.has_primitive() ? out.primitive().stream()
+                                    : default_stream(default_device());
+       jvps.push_back(zeros_like(out, s));
+     }
+   }
+   return {outputs, jvps};
+ }
+
+ std::pair<array, array> jvp(
+     const std::function<array(const array&)>& fun,
+     const array& primal,
+     const array& tangent) {
+   auto vec_fun = [fun](const std::vector<array>& inputs) {
+     return std::vector<array>{fun(inputs[0])};
+   };
+   auto [outputs, jvps] = jvp(vec_fun, {primal}, {tangent});
+   return {outputs[0], jvps[0]};
+ }
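
Editor's note: jvp is the forward-mode counterpart of vjp; it pushes a tangent through the graph in tape order instead of pulling a cotangent back. A minimal sketch with hypothetical values, not part of this diff:

    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      auto fun = [](const array& x) { return exp(x); };
      array x = array(0.0f);
      array tangent = array(1.0f);
      // out == exp(0) == 1; dout == exp(0) * tangent == 1
      auto [out, dout] = jvp(fun, x, tangent);
      eval({out, dout});
      return 0;
    }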
+
+ ValueAndGradFn value_and_grad(
+     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+     const std::vector<int>& argnums) {
+   if (argnums.empty()) {
+     throw std::invalid_argument("[grad] Must specify at least one argument.");
+   }
+   return [fun, argnums](const std::vector<array>& inputs) {
+     std::set<int> args;
+     for (auto& arg : argnums) {
+       args.insert(arg < 0 ? arg + inputs.size() : arg);
+     }
+     if (args.size() != argnums.size()) {
+       throw std::invalid_argument(
+           "[grad] Repeat argument number not allowed in grad.");
+     }
+     if (*args.begin() < 0 || *args.rbegin() >= inputs.size()) {
+       std::ostringstream msg;
+       msg << "[grad] Invalid argument number for function with "
+           << inputs.size() << " inputs.";
+       throw std::invalid_argument(msg.str());
+     }
+     std::vector<int> sorted_argnums(args.begin(), args.end());
+
+     auto gfun = [&fun](const std::vector<array>& inputs) {
+       auto outputs = fun(inputs);
+       for (int i = 1; i < outputs.size(); i++) {
+         auto& out = outputs[i];
+         auto s = out.has_primitive() ? out.primitive().stream()
+                                      : default_stream(default_device());
+         outputs[i] = stop_gradient(out, s);
+       }
+       return outputs;
+     };
+
+     // Set the incoming gradient to float32; vjp will cast it to the output
+     // type
+     auto [outputs, grads] = vjp(gfun, inputs, {array(1.0f)}, sorted_argnums);
+     return std::make_pair(outputs, grads);
+   };
+ }
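
Editor's note: value_and_grad returns a closure that computes fun's outputs together with the gradients of the first output with respect to the selected arguments; outputs after the first are routed through stop_gradient, so they act as auxiliary values. A minimal sketch, not part of this diff:

    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      // Scalar loss over two inputs; gradients taken w.r.t. argument 0 only
      auto loss = [](const std::vector<array>& in) {
        return std::vector<array>{sum(square(in[0] - in[1]))};
      };
      auto vag = value_and_grad(loss, {0});
      auto [values, grads] = vag({array({1.0f, 2.0f}), array({1.5f, 1.5f})});
      eval(values);
      eval(grads);  // grads[0] holds d(loss)/d(in[0])
      return 0;
    }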
+
+ namespace detail {
+
+ std::pair<std::vector<array>, std::vector<array>> vmap_trace(
+     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+     const std::vector<array>& inputs,
+     const std::vector<int>& in_axes) {
+   // Set the global tracing flag.
+   detail::InTracing in_tracing;
+
+   if (in_axes.size() != inputs.size()) {
+     std::stringstream ss;
+     ss << "[vmap] The number of in axes (" << in_axes.size()
+        << ") must match the number of inputs (" << inputs.size() << ").";
+     throw std::invalid_argument(ss.str());
+   }
+
+   // Some error checking and get the vmap axis size
+   size_t vmap_ax_size;
+   for (int i = 0; i < inputs.size(); ++i) {
+     if (in_axes[i] != -1) {
+       if (inputs[i].ndim() == 0) {
+         throw std::invalid_argument(
+             "[vmap] Cannot vmap an input with zero dimensions.");
+       }
+       if (in_axes[i] > inputs[i].ndim()) {
+         std::ostringstream msg;
+         msg << "[vmap] Axis " << in_axes[i] << " invalid for input with "
+             << inputs[i].ndim() << " dimensions.";
+         throw std::invalid_argument(msg.str());
+       }
+       vmap_ax_size = inputs[i].shape(in_axes[i]);
+     }
+   }
+   // Check that all vmapped axes have the same size
+   for (int i = 0; i < inputs.size(); ++i) {
+     if (in_axes[i] != -1) {
+       if (size_t in_ax = inputs[i].shape(in_axes[i]); vmap_ax_size != in_ax) {
+         std::ostringstream msg;
+         msg << "[vmap] Inconsistent axis sizes: " << in_ax << " and "
+             << vmap_ax_size << ".";
+         throw std::invalid_argument(msg.str());
+       }
+     }
+   }
+
+   // Run the function on placeholder inputs
+   // to get the original graph
+   std::vector<array> s_inputs;
+   for (int i = 0; i < inputs.size(); ++i) {
+     if (in_axes[i] != -1) {
+       auto shape = inputs[i].shape();
+       shape.erase(shape.begin() + in_axes[i]);
+       array in(shape, inputs[i].dtype(), nullptr, {});
+       s_inputs.push_back(in);
+       s_inputs.back().set_tracer(true);
+     } else {
+       s_inputs.push_back(inputs[i]);
+     }
+   }
+   return {s_inputs, fun(s_inputs)};
+ }
753
+
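+ // Illustrative note (shapes assumed): for an input of shape {8, 3}
+ // vmapped on axis 0, the placeholder tracer input above has shape {3},
+ // so `fun` is traced as if it ran on a single element of the batch.
+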
+ std::vector<array> vmap_replace(
+     const std::vector<array>& inputs,
+     const std::vector<array>& s_inputs,
+     const std::vector<array>& s_outputs,
+     const std::vector<int>& in_axes,
+     const std::vector<int>& out_axes) {
+   if (out_axes.size() != s_outputs.size()) {
+     std::stringstream msg;
+     msg << "[vmap] The number of out axes (" << out_axes.size()
+         << ") must match the number of outputs (" << s_outputs.size() << ").";
+     throw std::invalid_argument(msg.str());
+   }
+
+   int vmap_size = -1;
+   for (int i = 0; i < inputs.size(); ++i) {
+     if (in_axes[i] >= 0) {
+       vmap_size = inputs[i].shape(in_axes[i]);
+       break;
+     }
+   }
+   if (vmap_size == -1) {
+     throw std::invalid_argument("At least one of in_axes must be non-None.");
+   }
+
+   std::unordered_map<std::uintptr_t, std::pair<array, int>> tmap;
+   std::unordered_set<std::uintptr_t> needs_vmap;
+   std::unordered_set<std::uintptr_t> cache;
+   for (int i = 0; i < s_inputs.size(); ++i) {
+     auto in = s_inputs[i];
+     if (in_axes[i] != -1) {
+       tmap.insert({in.id(), {inputs[i], in_axes[i]}});
+       needs_vmap.insert(in.id());
+       in.set_tracer(false);
+     }
+     cache.insert(in.id());
+   }
+
+   // Topologically sort the graph
+   std::vector<array> tape;
+
+   std::function<void(const array&)> recurse;
+
+   recurse = [&](const array& a) {
+     auto id = a.id();
+     if (cache.find(id) != cache.end()) {
+       return;
+     }
+     cache.insert(id);
+     for (auto& s : a.siblings()) {
+       cache.insert(s.id());
+     }
+
+     // Recurse on inputs
+     for (auto& input : a.inputs()) {
+       recurse(input);
+     }
+     // If any input needs a vmap, then the outputs also need a vmap
+     for (auto& input : a.inputs()) {
+       if (needs_vmap.find(input.id()) != needs_vmap.end()) {
+         tape.push_back(a);
+         tape.back().set_tracer(false);
+         needs_vmap.insert(a.id());
+         for (auto s : a.siblings()) {
+           needs_vmap.insert(s.id());
+           s.set_tracer(false);
+         }
+         break;
+       }
+     }
+   };
+
+   for (auto& out : s_outputs) {
+     if (out.has_primitive()) {
+       recurse(out);
+     }
+   }
+
+   // Transform each primitive in the graph with its vmap implementation
+   for (auto& a : tape) {
+     std::vector<array> v_inputs;
+     std::vector<int> v_axes;
+     for (auto& in : a.inputs()) {
+       auto map_it = tmap.find(in.id());
+       if (map_it != tmap.end()) {
+         v_inputs.push_back(map_it->second.first);
+         v_axes.push_back(map_it->second.second);
+       } else {
+         v_inputs.push_back(in);
+         v_axes.push_back(-1);
+       }
+     }
+
+     auto [v_outputs, v_out_axes] = a.primitive().vmap(v_inputs, v_axes);
+
+     // For each of the primitive's outputs, record the vmapped output and
+     // its vmap axis
+     auto outputs = a.outputs();
+     for (int i = 0; i < v_outputs.size(); ++i) {
+       tmap.insert({outputs[i].id(), {v_outputs[i], v_out_axes[i]}});
+     }
+   }
+
+   // Populate the outputs and make sure all the output axes are in the
+   // right place
+   std::vector<array> outputs;
+   for (int i = 0; i < s_outputs.size(); ++i) {
+     if (auto map_it = tmap.find(s_outputs[i].id()); map_it != tmap.end()) {
+       auto& [out, vdim] = map_it->second;
+       if (vdim != out_axes[i]) {
+         if (out_axes[i] >= out.ndim()) {
+           std::ostringstream msg;
+           msg << "[vmap] Axis " << out_axes[i] << " invalid for output with "
+               << out.ndim() << " dimensions.";
+           throw std::invalid_argument(msg.str());
+         }
+         out = moveaxis(out, vdim, out_axes[i]);
+       }
+       outputs.push_back(out);
+     } else {
+       // When the output has no input dependencies, use the size of the
+       // vmapped axis in the inputs to expand the output
+       array output = expand_dims(s_outputs[i], out_axes[i]);
+       output = repeat(output, vmap_size, out_axes[i]);
+       outputs.push_back(output);
+     }
+   }
+   return outputs;
+ }
+
+ } // namespace detail
+
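+ // How the two detail helpers compose (summary): vmap_trace records the
+ // graph of `fun` on unbatched placeholder inputs, and vmap_replace then
+ // rewrites that graph primitive-by-primitive via Primitive::vmap, finally
+ // using moveaxis to put each output's batch axis where out_axes requests.
+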
+ std::function<std::vector<array>(const std::vector<array>&)> vmap(
+     const std::function<std::vector<array>(const std::vector<array>&)>& fun,
+     const std::vector<int>& in_axes /* = {} */,
+     const std::vector<int>& out_axes /* = {} */) {
+   auto infer_axes = [](auto axes) {
+     return !axes.empty() &&
+         std::all_of(axes.begin(), axes.end(), [](int ax) { return ax < 0; });
+   };
+   if (infer_axes(in_axes) != infer_axes(out_axes)) {
+     throw std::invalid_argument(
+         "[vmap] Input (or output) axes must be "
+         "specified if output (or input) axes are.");
+   }
+   auto vfun = [fun, in_axes = in_axes, out_axes = out_axes](
+                   const std::vector<array>& inputs) mutable {
+     if (in_axes.size() == 0) {
+       in_axes.resize(inputs.size(), 0);
+     }
+
+     auto [trace_inputs, trace_outputs] =
+         detail::vmap_trace(fun, inputs, in_axes);
+
+     if (out_axes.size() == 0) {
+       out_axes.resize(trace_outputs.size(), 0);
+     }
+
+     return detail::vmap_replace(
+         inputs, trace_inputs, trace_outputs, in_axes, out_axes);
+   };
+
+   return vfun;
+ }
+
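+ // A minimal usage sketch for the vector form (illustrative shapes;
+ // `multiply` is assumed from the mlx::core ops). An in axis of -1 leaves
+ // that input unbatched:
+ //
+ //   auto f = [](const std::vector<array>& ins) {
+ //     return std::vector<array>{multiply(ins[0], ins[1])};
+ //   };
+ //   auto vf = vmap(f, /* in_axes = */ {0, -1}, /* out_axes = */ {0});
+ //   auto out = vf({xs, y}); // xs: {N, D}, y: {D} -> out: {N, D}
+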
+ std::function<array(const array&, const array&)> vmap(
+     const std::function<array(const array&, const array&)>& fun,
+     int in_axis_a /* = 0 */,
+     int in_axis_b /* = 0 */,
+     int out_axis /* = 0 */) {
+   auto vfun = vmap(
+       [fun](const std::vector<array>& inputs) {
+         return std::vector<array>{fun(inputs[0], inputs[1])};
+       },
+       {in_axis_a, in_axis_b},
+       {out_axis});
+   return [vfun](const array& a, const array& b) { return vfun({a, b})[0]; };
+ }
+
+ std::function<array(const array&)> vmap(
+     const std::function<array(const array&)>& fun,
+     int in_axis /* = 0 */,
+     int out_axis /* = 0 */) {
+   auto vfun = vmap(
+       [fun](const std::vector<array>& inputs) {
+         return std::vector<array>{fun(inputs[0])};
+       },
+       {in_axis},
+       {out_axis});
+   return [vfun](const array& a) { return vfun({a})[0]; };
+ }
+
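+ // The unary and binary overloads above just wrap the vector form, e.g.
+ // (sketch, with `exp` assumed from the mlx::core ops):
+ //
+ //   auto vexp = vmap([](const array& x) { return exp(x); });
+ //   auto ys = vexp(xs); // maps exp over axis 0 of xs
+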
+ std::function<std::vector<array>(const std::vector<array>&)> custom_function(
+     std::function<std::vector<array>(const std::vector<array>&)> fun,
+     std::optional<std::function<std::vector<array>(
+         const std::vector<array>&,
+         const std::vector<array>&,
+         const std::vector<array>&)>> fun_vjp /* = std::nullopt */,
+     std::optional<std::function<std::vector<array>(
+         const std::vector<array>&,
+         const std::vector<array>&,
+         const std::vector<int>&)>> fun_jvp /* = std::nullopt */,
+     std::optional<std::function<std::pair<std::vector<array>, std::vector<int>>(
+         const std::vector<array>&,
+         const std::vector<int>&)>> fun_vmap /* = std::nullopt */) {
+   if (!fun_vjp.has_value() && !fun_jvp.has_value() && !fun_vmap.has_value()) {
+     return fun;
+   }
+
+   return [fun = std::move(fun),
+           fun_vjp = std::move(fun_vjp),
+           fun_jvp = std::move(fun_jvp),
+           fun_vmap = std::move(fun_vmap)](const std::vector<array>& args) {
+     // Compute the outputs
+     auto outputs = fun(args);
+     for (auto& out : outputs) {
+       out = stop_gradient(out);
+     }
+
+     // Prepare the inputs to the primitive. We also add the outputs to the
+     // primitive so that it can "run" the forward pass.
+     std::vector<array> inputs = args;
+     inputs.insert(inputs.end(), outputs.begin(), outputs.end());
+
+     // Compute the stream. Maybe do it in a smarter way at some point in the
+     // future.
+     Stream s = (outputs[0].has_primitive()) ? outputs[0].primitive().stream()
+                                             : default_stream(default_device());
+
+     // Make the output info
+     std::vector<Shape> shapes;
+     std::vector<Dtype> dtypes;
+     for (const auto& out : outputs) {
+       shapes.emplace_back(out.shape());
+       dtypes.emplace_back(out.dtype());
+     }
+
+     return array::make_arrays(
+         std::move(shapes),
+         dtypes,
+         std::make_shared<CustomTransforms>(
+             to_stream(s),
+             outputs.size(),
+
+             // We use the passed vjp function or compute it from the inputs
+             // and passed cotangents. Note that this may be less efficient
+             // than using `fun` directly because we may not be able to fully
+             // reuse the outputs of the forward pass.
+             fun_vjp.value_or(
+                 [fun](auto primals, auto cotangents, auto outputs) {
+                   auto [__, vjps] = vjp(fun, primals, cotangents);
+                   return vjps;
+                 }),
+
+             // We use the passed jvp function or compute it from the primals
+             // and tangents. Similarly, we can't take full advantage of the
+             // argnums, so it is best to use `fun` directly if we don't need
+             // a custom transform.
+             //
+             // TODO: Use stop_gradient to make full use of argnums and not
+             // waste computation.
+             fun_jvp.value_or([fun](auto primals, auto tangents, auto argnums) {
+               std::vector<array> all_tangents;
+               for (int i = 0, j = 0; i < primals.size(); i++) {
+                 if (j < argnums.size() && i == argnums[j]) {
+                   all_tangents.emplace_back(tangents[j++]);
+                 } else {
+                   all_tangents.emplace_back(zeros_like(primals[i]));
+                 }
+               }
+               auto [__, jvps] = jvp(fun, primals, all_tangents);
+               return jvps;
+             }),
+
+             // Same as above, we use the passed vmap function or compute it
+             // from `fun`. The output axes are all selected to be 0, which
+             // again may be suboptimal but is the only choice we can make
+             // without any information about `fun`.
+             fun_vmap.value_or(
+                 [fun, out_size = outputs.size()](auto inputs, auto in_axes)
+                     -> std::pair<std::vector<array>, std::vector<int>> {
+                   std::vector<int> out_axes(out_size, 0);
+                   return {vmap(fun, in_axes, out_axes)(inputs), out_axes};
+                 })),
+         inputs);
+   };
+ }
+
+ std::function<std::vector<array>(const std::vector<array>&)> custom_vjp(
+     std::function<std::vector<array>(const std::vector<array>&)> fun,
+     std::function<std::vector<array>(
+         const std::vector<array>&,
+         const std::vector<array>&,
+         const std::vector<array>&)> fun_vjp) {
+   return custom_function(fun, fun_vjp, std::nullopt, std::nullopt);
+ }
+
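+ // A usage sketch for custom_vjp (illustrative; `square` and `multiply`
+ // are assumed from the mlx::core ops, and the gradient shown is just the
+ // analytic derivative of the example function):
+ //
+ //   auto f = [](const std::vector<array>& ins) {
+ //     return std::vector<array>{square(ins[0])};
+ //   };
+ //   auto f_vjp = [](const std::vector<array>& primals,
+ //                   const std::vector<array>& cotans,
+ //                   const std::vector<array>& outputs) {
+ //     // d(x^2)/dx = 2x, chained with the incoming cotangent
+ //     return std::vector<array>{multiply(
+ //         multiply(array(2.0f), primals[0]), cotans[0])};
+ //   };
+ //   auto g = custom_vjp(f, f_vjp);
+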
+ std::function<std::vector<array>(const std::vector<array>&)> checkpoint(
+     std::function<std::vector<array>(const std::vector<array>&)> fun) {
+   auto vjp_fun = [fun](
+                      const std::vector<array>& primals,
+                      const std::vector<array>& cotangents,
+                      const std::vector<array>& outputs) -> std::vector<array> {
+     auto [__, vjps] = vjp(fun, depends(primals, outputs), cotangents);
+     return vjps;
+   };
+
+   return custom_vjp(fun, vjp_fun);
+ }
+
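+ // Usage sketch for checkpoint (illustrative): gradients of the wrapped
+ // function recompute the forward pass of `fun` during the backward pass
+ // instead of holding on to its intermediates:
+ //
+ //   auto g = checkpoint(fun);
+ //   auto [outs, grads] = vjp(g, primals, cotangents);
+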
+ } // namespace mlx::core