mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,265 @@
1
# Filename rules in cuda backend:
#
# * Use .cu/.cuh if code contains device code, and .cpp/.h if not.
# * Device-only code should be put in device/ subdir.
# * Files in device/ subdir should not include files outside.
target_sources(
  mlx
  PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/allocator.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/arange.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/arg_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/binary_two.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/compiled.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/copy.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_contiguous.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_dynamic.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/copy/copy_general_input.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_conv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/conv/gemm_grouped_conv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/cublas_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/cudnn_utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device_info.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/custom_kernel.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/distributed.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/gemv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/gemms/grouped_gemm_unaligned.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/jit_module.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/indexing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/matmul.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/load.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/layer_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/logsumexp.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/random.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/all_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/col_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/init_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/reduce/row_reduce.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rms_norm.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/rope.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/scaled_dot_product_attention.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/scan.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/softmax.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/sort.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/ternary.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/affine_quantize.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/fp_quantize.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qmv.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/quantized.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm.cpp
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_utils.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/quantized/convert_fp8.cu
          ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)

# Per-op binary/unary kernel sources are added by their own CMakeLists.
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/binary)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/unary)

# fp4 is not available on < 12.8
if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 12.8.0)
  target_include_directories(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/)
  target_sources(mlx
                 PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/no_qqmm_impl.cpp)
else()
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/quantized/qqmm_impl.cpp
                ${CMAKE_CURRENT_SOURCE_DIR}/quantized/cublas_qqmm.cpp)
endif()

# Pick the batched cuBLAS GEMM implementation matching the toolkit version.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.9.0)
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_9.cu)
else()
  target_sources(
    mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/gemms/cublas_gemm_batched_12_0.cpp)
endif()
87
+
88
# Embed kernel sources in binary for JIT compilation.
file(
  GLOB MLX_JIT_SOURCES
  RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.h"
  "${CMAKE_CURRENT_SOURCE_DIR}/device/*.cuh")
# Join with ':' so the whole list survives as a single -D argument below.
string(JOIN ":" MLX_JIT_SOURCES_ARG ${MLX_JIT_SOURCES})
# bin2h.cmake generates a header that embeds the device sources as strings.
add_custom_command(
  OUTPUT gen/cuda_jit_sources.h
  COMMAND
    ${CMAKE_COMMAND} -DMLX_SOURCE_ROOT=${CMAKE_CURRENT_SOURCE_DIR}
    -DMLX_JIT_SOURCES=${MLX_JIT_SOURCES_ARG} -P
    "${CMAKE_CURRENT_SOURCE_DIR}/bin2h.cmake"
  DEPENDS bin2h.cmake ${MLX_JIT_SOURCES})
add_custom_target(cuda_jit_sources DEPENDS gen/cuda_jit_sources.h)
add_dependencies(mlx cuda_jit_sources)
target_include_directories(mlx PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/gen")
105
+
106
# ------------------------ Compilation configs ------------------------

target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)

# Enable defining device lambda functions.
target_compile_options(mlx
                       PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

# Enable calling host constexpr functions from device. This is needed because
# the constexpr version of isnan is host only.
target_compile_options(
  mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--expt-relaxed-constexpr>")

# Suppress nvcc warnings on C++ headers.
target_compile_options(
  mlx
  PRIVATE
    $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe="--diag_suppress=27,997,1394,20011,20208">
)

# Ignore some valid nvcc warnings, we might want to fix them in future.
target_compile_options(
  mlx PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcudafe="--diag_suppress=177,550">)

# Use stronger binaries compression. This feature was introduced in CUDA 12.8
# and requires drivers released after CUDA 12.4.
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8.0)
  target_compile_options(
    mlx PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--compress-mode=size>")
endif()

# Use native CUDA arch by default.
if(NOT DEFINED MLX_CUDA_ARCHITECTURES)
  # __nvcc_device_query prints the compute capability of the local GPU.
  execute_process(
    COMMAND __nvcc_device_query
    OUTPUT_VARIABLE MLX_CUDA_ARCHITECTURES
    OUTPUT_STRIP_TRAILING_WHITESPACE)
  set(UPGRADABLE_ARCHITECTURES "90;100;121")
  if(MLX_CUDA_ARCHITECTURES STREQUAL "")
    message(
      FATAL_ERROR
        "Can not get native CUDA arch, must set MLX_CUDA_ARCHITECTURES")
  elseif(MLX_CUDA_ARCHITECTURES IN_LIST UPGRADABLE_ARCHITECTURES)
    # Use arch-specific compute capability whenever possible.
    set(MLX_CUDA_ARCHITECTURES "${MLX_CUDA_ARCHITECTURES}a")
  endif()
endif()
message(STATUS "CUDA architectures: ${MLX_CUDA_ARCHITECTURES}")
set_target_properties(mlx PROPERTIES CUDA_ARCHITECTURES
                                     "${MLX_CUDA_ARCHITECTURES}")
156
+
157
# Search CUDA libs from installed python packages.
if(WIN32)
  # Resolve paths of unfound DLL at runtime.
  if(BUILD_SHARED_LIBS)
    target_link_libraries(mlx PRIVATE "delayimp.lib")
    target_sources(mlx PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/delayload.cpp)
  else()
    # For static library the delayload must be compiled into final executables.
    target_link_libraries(mlx PUBLIC "delayimp.lib")
    target_sources(
      mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/delayload.cpp>)
  endif()
  # Get all the CUDA DLLs we could link with.
  file(
    GLOB CUDA_DLL_NAMES
    RELATIVE "${CUDAToolkit_BIN_DIR}/x64"
    "${CUDAToolkit_BIN_DIR}/x64/*.dll")
  # Delay load CUDA and cuDNN libs.
  foreach(CUDA_DLL ${CUDA_DLL_NAMES} ${CUDNN_DLL_NAMES})
    target_link_options(mlx PUBLIC "/DELAYLOAD:${CUDA_DLL}")
  endforeach()
  # Pass the locations where CUDA DLLs are placed.
  if(NOT MLX_LOAD_CUDA_LIBS_FROM_PYTHON)
    target_compile_definitions(
      mlx PUBLIC MLX_CUDA_BIN_DIR="${CUDAToolkit_BIN_DIR}/x64"
                 MLX_CUDNN_BIN_DIR="${CUDNN_BIN_DIR}")
  endif()
else()
  # For POSIX we rely on RPATH to search for CUDA libs.
  if(MLX_LOAD_CUDA_LIBS_FROM_PYTHON)
    set_property(
      TARGET mlx
      APPEND
      PROPERTY INSTALL_RPATH
               # The paths here should match the install_requires in setup.py.
               "$ORIGIN/../../nvidia/cublas/lib"
               "$ORIGIN/../../nvidia/cuda_nvrtc/lib"
               "$ORIGIN/../../nvidia/cudnn/lib"
               "$ORIGIN/../../nvidia/nccl/lib")
  endif()
endif()
198
+
199
# ------------------------ Dependencies ------------------------

# Use fixed version of CCCL.
FetchContent_Declare(
  cccl
  URL "https://github.com/NVIDIA/cccl/releases/download/v3.1.3/cccl-v3.1.3.zip")
FetchContent_MakeAvailable(cccl)
# BEFORE prepends the pinned CCCL so it takes precedence on the include path.
target_include_directories(mlx BEFORE PRIVATE "${cccl_SOURCE_DIR}/include")

# Install CCCL headers for JIT.
install(DIRECTORY ${cccl_SOURCE_DIR}/include/cuda
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)
install(DIRECTORY ${cccl_SOURCE_DIR}/include/nv
        DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/cccl)

# The binary of C++ tests will not be installed so it can not find the CCCL
# headers, and we have to hard-code the path.
if(MLX_BUILD_TESTS)
  target_compile_definitions(mlx
                             PRIVATE MLX_CCCL_DIR="${cccl_SOURCE_DIR}/include")
endif()

# Use fixed version of NVTX.
FetchContent_Declare(
  nvtx3
  GIT_REPOSITORY https://github.com/NVIDIA/NVTX.git
  GIT_TAG v3.1.1
  GIT_SHALLOW TRUE
  SOURCE_SUBDIR c EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(nvtx3)
target_link_libraries(mlx PUBLIC $<BUILD_INTERFACE:nvtx3-cpp>)

# Make cuda runtime APIs available in non-cuda files.
target_include_directories(mlx PRIVATE ${CUDAToolkit_INCLUDE_DIRS})

# Use cublasLt.
target_link_libraries(mlx PRIVATE CUDA::cublasLt)

# Use NVRTC and driver APIs.
target_link_libraries(mlx PRIVATE CUDA::nvrtc CUDA::cuda_driver)

# Use the frontend APIs of cuDNN.
FetchContent_Declare(
  cudnn
  GIT_REPOSITORY https://github.com/NVIDIA/cudnn-frontend.git
  GIT_TAG v1.16.0
  GIT_SHALLOW TRUE
  EXCLUDE_FROM_ALL)
# Trim optional cudnn-frontend components; only the headers/library are used.
set(CUDNN_FRONTEND_SKIP_JSON_LIB ON)
set(CUDNN_FRONTEND_BUILD_SAMPLES OFF)
set(CUDNN_FRONTEND_BUILD_TESTS OFF)
set(CUDNN_FRONTEND_BUILD_PYTHON_BINDINGS OFF)
FetchContent_MakeAvailable(cudnn)
target_link_libraries(mlx PRIVATE cudnn_frontend)
# Link with the actual cuDNN libraries.
target_link_libraries(mlx PRIVATE CUDNN::cudnn_all)

# Use header-only CUTLASS.
FetchContent_Declare(
  cutlass
  GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
  GIT_TAG v4.3.5
  GIT_SHALLOW TRUE
  SOURCE_SUBDIR include EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(cutlass)
target_include_directories(
  mlx SYSTEM PRIVATE $<BUILD_INTERFACE:${cutlass_SOURCE_DIR}/include>)
@@ -0,0 +1,451 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/allocator.h"
4
+ #include "mlx/backend/cuda/device.h"
5
+ #include "mlx/backend/cuda/utils.h"
6
+ #include "mlx/backend/gpu/device_info.h"
7
+ #include "mlx/memory.h"
8
+ #include "mlx/scheduler.h"
9
+ #include "mlx/utils.h"
10
+
11
+ #include <cuda_runtime.h>
12
+ #include <fmt/format.h>
13
+
14
+ #include <cassert>
15
+ #include <fstream>
16
+ #include <string>
17
+
18
+ namespace mlx::core {
19
+
20
+ namespace cu {
21
+
22
// Allocation granularity: requests >= page_size are rounded up to a multiple
// of it in malloc_async, and it is the bucket granularity given to the cache.
constexpr int page_size = 16384;

// Any allocations smaller than this will try to use the small pool
constexpr int small_block_size = 8;

// The small pool size in bytes. This should be a multiple of the host page
// size and small_block_size.
constexpr int small_pool_size = 4 * page_size;
30
+
31
// Check if running on Windows proper or Windows Subsystem for Linux.
bool is_windows() {
#if defined(_WIN32)
  return true;
#elif defined(__linux__)
  // WSL kernels contain "microsoft" or "WSL" in /proc/version; probe once
  // and cache the answer for the process lifetime.
  static const bool runs_under_wsl = []() {
    std::ifstream proc_version("/proc/version");
    if (!proc_version.is_open()) {
      return false;
    }
    std::string banner;
    std::getline(proc_version, banner);
    for (const char* needle : {"microsoft", "Microsoft", "WSL"}) {
      if (banner.find(needle) != std::string::npos) {
        return true;
      }
    }
    return false;
  }();
  return runs_under_wsl;
#else
  return false;
#endif
}
53
+
54
+ bool supports_managed_memory() {
55
+ static bool managed_memory = []() {
56
+ int device_count = gpu::device_count();
57
+ for (int i = 0; i < device_count; ++i) {
58
+ auto& d = cu::device(i);
59
+ if (!d.managed_memory()) {
60
+ return false;
61
+ }
62
+ // Empirically on Windows (and WSL) if there is no concurrentManagedAccess
63
+ // the managed memory also does not work.
64
+ if (is_windows() && !d.concurrent_managed_access()) {
65
+ return false;
66
+ }
67
+ }
68
+ return true;
69
+ }();
70
+ return managed_memory;
71
+ }
72
+
73
+ inline void* unified_malloc(size_t size) {
74
+ void* data = nullptr;
75
+ if (supports_managed_memory()) {
76
+ CHECK_CUDA_ERROR(cudaMallocManaged(&data, size));
77
+ } else {
78
+ CHECK_CUDA_ERROR(cudaMallocHost(&data, size));
79
+ }
80
+ return data;
81
+ }
82
+
83
+ inline void unified_free(void* data) {
84
+ if (supports_managed_memory()) {
85
+ CHECK_CUDA_ERROR(cudaFree(data));
86
+ } else {
87
+ CHECK_CUDA_ERROR(cudaFreeHost(data));
88
+ }
89
+ }
90
+
91
#if CUDART_VERSION >= 13000
// CUDA 13 APIs (e.g. cudaMemAdvise below) take a cudaMemLocation; wrap the
// device index so call sites stay toolkit-version agnostic.
inline cudaMemLocation cuda_mem_loc(int i) {
  cudaMemLocation loc;
  loc.type = cudaMemLocationTypeDevice;
  loc.id = i;
  return loc;
}
#else
// Pre-13 toolkits take the raw device index.
inline int cuda_mem_loc(int i) {
  return i;
}
#endif // CUDART_VERSION >= 13000
103
+
104
// Build the small ("scalar") allocation pool: one unified-memory slab of
// small_pool_size bytes, carved into small_block_size-byte blocks tracked by
// an intrusive singly linked free list over the Block array.
SmallSizePool::SmallSizePool() {
  auto num_blocks = small_pool_size / small_block_size;
  buffer_ = new Block[num_blocks];
  next_free_ = buffer_;

  data_ = unified_malloc(small_pool_size);
  if (supports_managed_memory()) {
    // Advise every device with concurrent managed access that it will access
    // the slab, so tiny buffers are reachable from any GPU.
    int device_count = gpu::device_count();
    for (int i = 0; i < device_count; ++i) {
      if (device(i).concurrent_managed_access()) {
        auto loc = cuda_mem_loc(i);
        CHECK_CUDA_ERROR(cudaMemAdvise(
            data_, small_pool_size, cudaMemAdviseSetAccessedBy, loc));
      }
    }
  }

  // Thread all blocks into the initial free list, terminated with nullptr.
  auto curr = next_free_;
  for (size_t i = 1; i < num_blocks; ++i) {
    curr->next = buffer_ + i;
    curr = curr->next;
  }
  curr->next = nullptr;
}
128
+
129
// Release the backing slab and the free-list node storage. Any block still
// handed out by malloc() is invalidated here.
SmallSizePool::~SmallSizePool() {
  unified_free(data_);
  delete[] buffer_;
}
133
+
134
+ CudaBuffer* SmallSizePool::malloc() {
135
+ if (next_free_ == nullptr) {
136
+ return nullptr;
137
+ }
138
+ Block* b = next_free_;
139
+ uint64_t i = next_free_ - buffer_;
140
+ next_free_ = next_free_->next;
141
+ b->buf.data = static_cast<char*>(data_) + i * small_block_size;
142
+ b->buf.size = small_block_size;
143
+ b->buf.device = -1;
144
+ return &b->buf;
145
+ }
146
+
147
// Return a block to the pool by pushing its node back onto the free list.
// The cast assumes buf aliases the start of its Block node (presumably
// CudaBuffer is Block's first member — see the header to confirm).
void SmallSizePool::free(CudaBuffer* buf) {
  auto b = reinterpret_cast<Block*>(buf);
  b->next = next_free_;
  next_free_ = b;
}
152
+
153
// Whether buf addresses a node inside this pool's buffer_ array.
// NOTE(review): for a buffer not allocated from the pool, the subtraction
// below is between pointers into unrelated allocations — works on mainstream
// platforms but is technically unspecified behavior.
bool SmallSizePool::in_pool(CudaBuffer* buf) {
  constexpr int num_blocks = (small_pool_size / small_block_size);
  auto b = reinterpret_cast<Block*>(buf);
  int64_t block_num = b - buffer_;
  return block_num >= 0 && block_num < num_blocks;
}
159
+
160
// Set default memory limits from the current device's total memory, and set
// up one dedicated free stream plus the default CUDA memory pool for each
// device that supports memory pools.
CudaAllocator::CudaAllocator()
    : buffer_cache_(
          page_size,
          [](CudaBuffer* buf) { return buf->size; },
          [this](CudaBuffer* buf) { free_cuda_buffer(buf); }) {
  size_t free;
  CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total_memory_));
  // Defaults: cap usage at 95% of total; free_limit_ (the remaining 5%) is
  // the headroom threshold used when trimming the cache in malloc_async.
  memory_limit_ = total_memory_ * 0.95;
  free_limit_ = total_memory_ - memory_limit_;
  max_pool_size_ = memory_limit_;

  int device_count = gpu::device_count();
  free_streams_.resize(device_count);
  mem_pools_.resize(device_count);
  for (int i = 0; i < device_count; ++i) {
    auto& d = device(i);
    if (d.memory_pools()) {
      free_streams_[i] = CudaStream(d);
      CHECK_CUDA_ERROR(cudaDeviceGetDefaultMemPool(&mem_pools_[i], i));
    }
  }
}
182
+
183
// Allocate `size` bytes for `device`, served (in order of preference) from
// the reuse cache, the small scalar pool (requests <= small_block_size), or
// a fresh CUDA allocation. device == -1 (or no stream) means unified memory.
// Returns a Buffer owning a heap-allocated CudaBuffer record.
Buffer
CudaAllocator::malloc_async(size_t size, int device, cudaStream_t stream) {
  if (size == 0) {
    // Zero-sized requests get a data-less record; free() just deletes it.
    return Buffer{new CudaBuffer{nullptr, 0, -1}};
  }

  // Round the request up to a bucket size: the small-block size, the next
  // power of two below a page, or a whole number of pages.
  if (size <= small_block_size) {
    size = 8;
  } else if (size < page_size) {
    size = next_power_of_2(size);
  } else {
    size = page_size * ((size + page_size - 1) / page_size);
  }

  // Without a stream a stream-ordered device allocation is impossible, and
  // tiny buffers always live in unified memory, so retarget to device -1.
  if (size <= small_block_size || stream == nullptr) {
    device = -1;
  }

  // Find available buffer from cache.
  std::unique_lock lock(mutex_);
  CudaBuffer* buf = buffer_cache_.reuse_from_cache(size);
  if (!buf) {
    // If we have a lot of memory pressure try to reclaim memory from the cache.
    int64_t mem_to_free =
        get_active_memory() + get_cache_memory() + size - memory_limit_;
    if (mem_to_free > 0) {
      buffer_cache_.release_cached_buffers(mem_to_free);
    }

    // Try the scalar pool first
    if (size <= small_block_size) {
      buf = scalar_pool_.malloc();
    }
    // Drop the lock around the (potentially slow) CUDA allocation calls.
    lock.unlock();
    if (!buf) {
      void* data = nullptr;
      if (device == -1) {
        data = unified_malloc(size);
      } else {
        cu::device(device).make_current();
        if (mem_pools_[device]) { // supports memory pools
          CHECK_CUDA_ERROR(cudaMallocAsync(&data, size, stream));
        } else {
          CHECK_CUDA_ERROR(cudaMalloc(&data, size));
        }
      }
      if (!data) {
        std::ostringstream msg;
        msg << "[malloc] Unable to allocate " << size << " bytes.";
        throw std::runtime_error(msg.str());
      }
      buf = new CudaBuffer{data, size, device};
    }
    lock.lock();

    // If any cuda memory pool has too much reserved memory, clear some
    // memory from the cache. This prevents graph / kernel execution failing
    // from OOM
    if (get_cache_memory() > 0) {
      for (auto p : mem_pools_) {
        if (p) {
          size_t used = 0;
          CHECK_CUDA_ERROR(cudaMemPoolGetAttribute(
              p, cudaMemPoolAttrReservedMemCurrent, &used));
          if (used > (total_memory_ - free_limit_)) {
            buffer_cache_.release_cached_buffers(free_limit_);
            break;
          }
        }
      }
    }
  }
  active_memory_ += buf->size;
  peak_memory_ = std::max(active_memory_, peak_memory_);

  // Maintain the cache below the requested limit.
  if (get_cache_memory() > max_pool_size_) {
    buffer_cache_.release_cached_buffers(get_cache_memory() - max_pool_size_);
  }
  lock.unlock();
  // Copy to unified memory here if the buffer is not on the right device.
  // (A cached buffer may have been allocated for a different device.)
  if (buf->device >= 0 && buf->device != device) {
    move_to_unified_memory(*buf, stream);
  }
  return Buffer{buf};
}
269
+
270
// Synchronous entry point: no target device and no stream, so the request is
// served from the cache, the scalar pool, or unified memory.
Buffer CudaAllocator::malloc(size_t size) {
  return malloc_async(size, -1, nullptr);
}
273
+
274
+ void CudaAllocator::free(Buffer buffer) {
275
+ auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
276
+ if (!buf) {
277
+ return;
278
+ }
279
+ if (buf->size == 0) {
280
+ delete buf;
281
+ return;
282
+ }
283
+
284
+ std::unique_lock lock(mutex_);
285
+ active_memory_ -= buf->size;
286
+ if (get_cache_memory() < max_pool_size_) {
287
+ buffer_cache_.recycle_to_cache(buf);
288
+ } else {
289
+ free_cuda_buffer(buf);
290
+ }
291
+ }
292
+
293
+ size_t CudaAllocator::size(Buffer buffer) const {
294
+ auto* buf = static_cast<CudaBuffer*>(buffer.ptr());
295
+ if (!buf) {
296
+ return 0;
297
+ }
298
+ return buf->size;
299
+ }
300
+
301
// Migrate a device-resident buffer into unified (managed or pinned host)
// memory and release its old device allocation. No-op when the buffer is
// already in unified memory (device == -1).
void CudaAllocator::move_to_unified_memory(
    CudaBuffer& buf,
    cudaStream_t stream) {
  if (buf.device == -1) {
    return;
  }
  void* data = unified_malloc(buf.size);
  // Managed memory can infer the direction; pinned host memory needs an
  // explicit device-to-host copy.
  cudaMemcpyKind kind =
      supports_managed_memory() ? cudaMemcpyDefault : cudaMemcpyDeviceToHost;
  if (stream && mem_pools_[buf.device]) {
    // Stream-ordered path: copy and free are both enqueued on the stream.
    CHECK_CUDA_ERROR(cudaMemcpyAsync(data, buf.data, buf.size, kind, stream));
    free_async(buf, stream);
  } else {
    // Synchronous fallback: cudaMemcpy completes before we continue.
    CHECK_CUDA_ERROR(cudaMemcpy(data, buf.data, buf.size, kind));
    free_async(buf);
  }
  buf.data = data;
  buf.device = -1;
}
320
+
321
// Route a buffer back to its source: scalar-pool blocks go back on the
// pool's free list, everything else is released to CUDA and the record
// deleted. This must be called with mutex_ acquired.
void CudaAllocator::free_cuda_buffer(CudaBuffer* buf) {
  if (scalar_pool_.in_pool(buf)) {
    scalar_pool_.free(buf);
  } else {
    free_async(*buf);
    delete buf;
  }
}
330
+
331
// Release a buffer's memory. Unified buffers (device == -1) use the matching
// unified free; device buffers are freed stream-ordered when the device has
// memory pools, otherwise with a blocking cudaFree.
void CudaAllocator::free_async(CudaBuffer& buf, cudaStream_t stream) {
  if (buf.device == -1) {
    unified_free(buf.data);
  } else {
    // Free asynchronously when memory pools are supported.
    if (mem_pools_[buf.device]) {
      if (!stream) {
        // No caller stream: fall back to the per-device free stream.
        stream = free_streams_[buf.device];
      }
      CHECK_CUDA_ERROR(cudaFreeAsync(buf.data, stream));
    } else {
      CHECK_CUDA_ERROR(cudaFree(buf.data));
    }
  }
}
346
+
347
// Bytes currently handed out to live buffers (cached buffers excluded).
size_t CudaAllocator::get_active_memory() const {
  return active_memory_;
}

// High-water mark of active memory since start or the last reset.
size_t CudaAllocator::get_peak_memory() const {
  return peak_memory_;
}

void CudaAllocator::reset_peak_memory() {
  std::lock_guard lock(mutex_);
  peak_memory_ = 0;
}

size_t CudaAllocator::get_memory_limit() {
  return memory_limit_;
}

// Returns the previous limit.
size_t CudaAllocator::set_memory_limit(size_t limit) {
  std::lock_guard lock(mutex_);
  std::swap(limit, memory_limit_);
  return limit;
}

// Bytes currently held by the reuse cache.
size_t CudaAllocator::get_cache_memory() const {
  return buffer_cache_.cache_size();
}

// Returns the previous cache limit (max_pool_size_).
size_t CudaAllocator::set_cache_limit(size_t limit) {
  std::lock_guard lk(mutex_);
  std::swap(limit, max_pool_size_);
  return limit;
}

// Drop every cached buffer, returning their memory to CUDA.
void CudaAllocator::clear_cache() {
  std::lock_guard lk(mutex_);
  buffer_cache_.clear();
}
384
+
385
// Process-wide CUDA allocator singleton.
CudaAllocator& allocator() {
  static auto* allocator_ = []() {
    // Ensure scheduler is created before allocator.
    scheduler::scheduler();
    // By creating the |allocator_| on heap, the destructor of CudaAllocator
    // will not be called on exit and buffers in the cache will be leaked. This
    // can save some time at program exit.
    return new CudaAllocator();
  }();
  return *allocator_;
}
396
+
397
// Stream-ordered allocation helper targeting the encoder's device and stream.
Buffer malloc_async(size_t size, CommandEncoder& encoder) {
  return allocator().malloc_async(
      size, encoder.device().cuda_device(), encoder.stream());
}
401
+
402
+ } // namespace cu
403
+
404
+ namespace allocator {
405
+
406
// Expose the CUDA allocator through the generic allocator interface.
Allocator& allocator() {
  return cu::allocator();
}

// Host-visible pointer to the buffer's data. Side effect: a device-resident
// buffer is migrated into unified memory (via move_to_unified_memory) the
// first time the host touches it.
void* Buffer::raw_ptr() {
  if (!ptr_) {
    return nullptr;
  }
  auto& cbuf = *static_cast<cu::CudaBuffer*>(ptr_);
  cu::allocator().move_to_unified_memory(cbuf);
  return cbuf.data;
}
418
+
419
+ } // namespace allocator
420
+
421
// Public memory API, forwarded to the CUDA allocator singleton.
size_t get_active_memory() {
  return cu::allocator().get_active_memory();
}
size_t get_peak_memory() {
  return cu::allocator().get_peak_memory();
}
void reset_peak_memory() {
  return cu::allocator().reset_peak_memory();
}
size_t set_memory_limit(size_t limit) {
  return cu::allocator().set_memory_limit(limit);
}
size_t get_memory_limit() {
  return cu::allocator().get_memory_limit();
}
size_t get_cache_memory() {
  return cu::allocator().get_cache_memory();
}
size_t set_cache_limit(size_t limit) {
  return cu::allocator().set_cache_limit(limit);
}
void clear_cache() {
  cu::allocator().clear_cache();
}

// Not supported in CUDA.
size_t set_wired_limit(size_t) {
  return 0;
}
450
+
451
+ } // namespace mlx::core