mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599) hide show
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,94 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/allocator.h"
6
+ #include "mlx/backend/common/buffer_cache.h"
7
+ #include "mlx/backend/cuda/cuda_utils.h"
8
+
9
+ #include <cuda_runtime.h>
10
+ #include <mutex>
11
+ #include <set>
12
+ #include <utility>
13
+
14
+ namespace mlx::core::cu {
15
+
16
+ class CommandEncoder;
17
+
18
+ using allocator::Buffer;
19
+
20
+ // Stores cuda-managed unified memory.
21
+ struct CudaBuffer {
22
+ void* data;
23
+ size_t size;
24
+ int device; // -1 for managed
25
+ };
26
+
27
+ class SmallSizePool {
28
+ private:
29
+ union Block {
30
+ Block* next;
31
+ CudaBuffer buf;
32
+ };
33
+
34
+ Block* buffer_{nullptr};
35
+ void* data_{nullptr};
36
+ Block* next_free_{nullptr};
37
+
38
+ public:
39
+ SmallSizePool();
40
+ ~SmallSizePool();
41
+
42
+ SmallSizePool(const SmallSizePool&) = delete;
43
+ SmallSizePool& operator=(const SmallSizePool&) = delete;
44
+
45
+ CudaBuffer* malloc();
46
+ void free(CudaBuffer* buf);
47
+ bool in_pool(CudaBuffer* buf);
48
+ };
49
+
50
+ class CudaAllocator : public allocator::Allocator {
51
+ public:
52
+ Buffer malloc(size_t size) override;
53
+ Buffer malloc_async(size_t size, int device, cudaStream_t stream);
54
+ void free(Buffer buffer) override;
55
+ size_t size(Buffer buffer) const override;
56
+
57
+ // Replace the memory of |buf| with unified memory (managed memory or pinned
58
+ // host memory), and copy the data over. Pass |stream| to copy asynchronously.
59
+ void move_to_unified_memory(CudaBuffer& buf, cudaStream_t stream = nullptr);
60
+
61
+ size_t get_active_memory() const;
62
+ size_t get_peak_memory() const;
63
+ void reset_peak_memory();
64
+ size_t get_memory_limit();
65
+ size_t set_memory_limit(size_t limit);
66
+ size_t get_cache_memory() const;
67
+ size_t set_cache_limit(size_t limit);
68
+ void clear_cache();
69
+
70
+ private:
71
+ void free_cuda_buffer(CudaBuffer* buf);
72
+ void free_async(CudaBuffer& buf, cudaStream_t stream = nullptr);
73
+
74
+ CudaAllocator();
75
+ friend CudaAllocator& allocator();
76
+
77
+ std::mutex mutex_;
78
+ size_t memory_limit_;
79
+ size_t free_limit_;
80
+ size_t total_memory_;
81
+ size_t max_pool_size_;
82
+ BufferCache<CudaBuffer> buffer_cache_;
83
+ size_t active_memory_{0};
84
+ size_t peak_memory_{0};
85
+ std::vector<CudaStream> free_streams_;
86
+ std::vector<cudaMemPool_t> mem_pools_;
87
+ SmallSizePool scalar_pool_;
88
+ };
89
+
90
+ CudaAllocator& allocator();
91
+
92
+ Buffer malloc_async(size_t size, CommandEncoder& encoder);
93
+
94
+ } // namespace mlx::core::cu
@@ -0,0 +1,68 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/device.h"
4
+ #include "mlx/backend/cuda/device/fp16_math.cuh"
5
+ #include "mlx/backend/cuda/kernel_utils.cuh"
6
+ #include "mlx/dtype_utils.h"
7
+ #include "mlx/primitives.h"
8
+
9
+ #include <cooperative_groups.h>
10
+ #include <nvtx3/nvtx3.hpp>
11
+
12
+ namespace mlx::core {
13
+
14
+ namespace cu {
15
+
16
+ namespace cg = cooperative_groups;
17
+
18
+ template <typename T, typename IdxT, int N_WRITES>
19
+ __global__ void arange(T* out, IdxT size, T start, T step) {
20
+ IdxT index = cg::this_grid().thread_rank();
21
+
22
+ if ((index + 1) * N_WRITES > size) {
23
+ for (IdxT i = index * N_WRITES; i < size; ++i) {
24
+ out[i] = start + i * step;
25
+ }
26
+ } else {
27
+ AlignedVector<T, N_WRITES> out_vec;
28
+ #pragma unroll
29
+ for (int i = 0; i < N_WRITES; ++i) {
30
+ out_vec[i] = start + (index * N_WRITES + i) * step;
31
+ }
32
+
33
+ store_vector<N_WRITES>(out, index, out_vec);
34
+ }
35
+ }
36
+
37
+ } // namespace cu
38
+
39
+ void Arange::eval_gpu(const std::vector<array>& inputs, array& out) {
40
+ nvtx3::scoped_range r("Arange::eval_gpu");
41
+ if (out.size() == 0) {
42
+ return;
43
+ }
44
+ auto& encoder = cu::get_command_encoder(stream());
45
+ out.set_data(cu::malloc_async(out.nbytes(), encoder));
46
+ encoder.set_output_array(out);
47
+
48
+ dispatch_int_float_types(out.dtype(), "Arange", [&](auto type_tag) {
49
+ using CTYPE = MLX_GET_TYPE(type_tag);
50
+ using OutType = cuda_type_t<CTYPE>;
51
+ constexpr int N_WRITES = 16 / sizeof(OutType);
52
+ dispatch_bool(out.data_size() > INT32_MAX, [&](auto large) {
53
+ using IdxT = std::conditional_t<large(), int64_t, int32_t>;
54
+ auto [num_blocks, block_dims] = get_launch_args(out, large(), N_WRITES);
55
+ encoder.add_kernel_node(
56
+ cu::arange<OutType, IdxT, N_WRITES>,
57
+ num_blocks,
58
+ block_dims,
59
+ 0,
60
+ gpu_ptr<OutType>(out),
61
+ out.data_size(),
62
+ static_cast<CTYPE>(start_),
63
+ static_cast<CTYPE>(start_ + step_) - static_cast<CTYPE>(start_));
64
+ });
65
+ });
66
+ }
67
+
68
+ } // namespace mlx::core
@@ -0,0 +1,189 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/common/utils.h"
4
+ #include "mlx/backend/cuda/device.h"
5
+ #include "mlx/backend/cuda/device/fp16_math.cuh"
6
+ #include "mlx/backend/cuda/kernel_utils.cuh"
7
+ #include "mlx/dtype_utils.h"
8
+ #include "mlx/primitives.h"
9
+
10
+ #include <cooperative_groups.h>
11
+ #include <nvtx3/nvtx3.hpp>
12
+ #include <cub/block/block_load.cuh>
13
+ #include <cub/block/block_reduce.cuh>
14
+
15
+ #include <cassert>
16
+
17
+ namespace mlx::core {
18
+
19
+ namespace cu {
20
+
21
+ namespace cg = cooperative_groups;
22
+
23
+ template <typename T>
24
+ struct IndexValPair {
25
+ uint32_t index;
26
+ T val;
27
+ };
28
+
29
+ template <typename T>
30
+ struct ArgMin {
31
+ constexpr __device__ T init() {
32
+ return Limits<T>::max();
33
+ }
34
+
35
+ __device__ IndexValPair<T> operator()(
36
+ const IndexValPair<T>& best,
37
+ const IndexValPair<T>& current) {
38
+ if (best.val > current.val ||
39
+ (best.val == current.val && best.index > current.index)) {
40
+ return current;
41
+ } else {
42
+ return best;
43
+ }
44
+ }
45
+
46
+ template <int N>
47
+ __device__ IndexValPair<T> reduce_many(
48
+ IndexValPair<T> best,
49
+ const AlignedVector<T, N>& vals,
50
+ uint32_t offset) {
51
+ #pragma unroll
52
+ for (int i = 0; i < N; i++) {
53
+ if (vals[i] < best.val) {
54
+ best.val = vals[i];
55
+ best.index = offset + i;
56
+ }
57
+ }
58
+ return best;
59
+ }
60
+ };
61
+
62
+ template <typename T>
63
+ struct ArgMax {
64
+ constexpr __device__ T init() {
65
+ return Limits<T>::min();
66
+ }
67
+
68
+ __device__ IndexValPair<T> operator()(
69
+ const IndexValPair<T>& best,
70
+ const IndexValPair<T>& current) {
71
+ if (best.val < current.val ||
72
+ (best.val == current.val && best.index > current.index)) {
73
+ return current;
74
+ } else {
75
+ return best;
76
+ }
77
+ }
78
+
79
+ template <int N>
80
+ __device__ IndexValPair<T> reduce_many(
81
+ IndexValPair<T> best,
82
+ const AlignedVector<T, N>& vals,
83
+ uint32_t offset) {
84
+ #pragma unroll
85
+ for (int i = 0; i < N; i++) {
86
+ if (vals[i] > best.val) {
87
+ best.val = vals[i];
88
+ best.index = offset + i;
89
+ }
90
+ }
91
+ return best;
92
+ }
93
+ };
94
+
95
+ template <typename T, typename Op, int BLOCK_DIM, int N_READS = 4>
96
+ __global__ void arg_reduce_general(
97
+ const T* in,
98
+ uint32_t* out,
99
+ size_t size,
100
+ const __grid_constant__ Shape shape,
101
+ const __grid_constant__ Strides in_strides,
102
+ const __grid_constant__ Strides out_strides,
103
+ int32_t ndim,
104
+ int64_t axis_stride,
105
+ int32_t axis_size) {
106
+ auto block = cg::this_thread_block();
107
+
108
+ int64_t index = cg::this_grid().block_rank();
109
+ if (index >= size) {
110
+ return;
111
+ }
112
+
113
+ int64_t in_idx = elem_to_loc(index, shape.data(), in_strides.data(), ndim);
114
+ int64_t out_idx = elem_to_loc(index, shape.data(), out_strides.data(), ndim);
115
+ in += in_idx;
116
+
117
+ Op op;
118
+ T init = op.init();
119
+ IndexValPair<T> best{0, init};
120
+
121
+ for (int r = 0; r < cuda::ceil_div(axis_size, BLOCK_DIM * N_READS); ++r) {
122
+ auto tid = r * BLOCK_DIM + block.thread_index().x;
123
+ auto vals = load_vector<N_READS>(in, tid, axis_size, axis_stride, init);
124
+ best = op.reduce_many(best, vals, tid * N_READS);
125
+ }
126
+
127
+ typedef cub::BlockReduce<IndexValPair<T>, BLOCK_DIM> BlockReduceT;
128
+ __shared__ typename BlockReduceT::TempStorage temp;
129
+
130
+ best = BlockReduceT(temp).Reduce(best, op);
131
+
132
+ if (block.thread_rank() == 0) {
133
+ out[out_idx] = best.index;
134
+ }
135
+ }
136
+
137
+ } // namespace cu
138
+
139
+ void ArgReduce::eval_gpu(const std::vector<array>& inputs, array& out) {
140
+ nvtx3::scoped_range r("ArgReduce::eval_gpu");
141
+ assert(inputs.size() == 1);
142
+ auto& in = inputs[0];
143
+
144
+ auto& s = stream();
145
+ auto& encoder = cu::get_command_encoder(s);
146
+ out.set_data(cu::malloc_async(out.nbytes(), encoder));
147
+
148
+ // Prepare the shapes, strides and axis arguments.
149
+ Shape shape = remove_index(in.shape(), axis_);
150
+ Strides in_strides = remove_index(in.strides(), axis_);
151
+ Strides out_strides = out.ndim() == in.ndim()
152
+ ? remove_index(out.strides(), axis_)
153
+ : out.strides();
154
+ int64_t axis_stride = in.strides()[axis_];
155
+ int32_t axis_size = in.shape()[axis_];
156
+ int32_t ndim = shape.size();
157
+
158
+ // ArgReduce.
159
+ encoder.set_input_array(in);
160
+ encoder.set_output_array(out);
161
+ dispatch_real_types(in.dtype(), "ArgReduce", [&](auto type_tag) {
162
+ using T = cuda_type_t<MLX_GET_TYPE(type_tag)>;
163
+ constexpr uint32_t N_READS = 4;
164
+ dispatch_block_dim(cuda::ceil_div(axis_size, N_READS), [&](auto block_dim) {
165
+ dim3 num_blocks = get_2d_grid_dims(out.shape(), out.strides());
166
+ auto kernel =
167
+ cu::arg_reduce_general<T, cu::ArgMax<T>, block_dim(), N_READS>;
168
+ if (reduce_type_ == ArgReduce::ArgMin) {
169
+ kernel = cu::arg_reduce_general<T, cu::ArgMin<T>, block_dim(), N_READS>;
170
+ }
171
+ encoder.add_kernel_node(
172
+ kernel,
173
+ num_blocks,
174
+ block_dim(),
175
+ 0,
176
+ gpu_ptr<T>(in),
177
+ gpu_ptr<uint32_t>(out),
178
+ out.size(),
179
+ const_param(shape),
180
+ const_param(in_strides),
181
+ const_param(out_strides),
182
+ ndim,
183
+ axis_stride,
184
+ axis_size);
185
+ });
186
+ });
187
+ }
188
+
189
+ } // namespace mlx::core
@@ -0,0 +1,150 @@
1
+ # Based on: https://github.com/sivachandran/cmake-bin2h
2
+ #
3
+ # Copyright 2020 Sivachandran Paramasivam
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+
23
+ include(CMakeParseArguments)
24
+
25
+ # Function to wrap a given string into multiple lines at the given column
26
+ # position.
27
+ #
28
+ # Parameters:
29
+ #
30
+ # * VARIABLE - The name of the CMake variable holding the string.
31
+ # * AT_COLUMN - The column position at which string will be wrapped.
32
+ function(WRAP_STRING)
33
+ set(oneValueArgs VARIABLE AT_COLUMN)
34
+ cmake_parse_arguments(WRAP_STRING "${options}" "${oneValueArgs}" "" ${ARGN})
35
+
36
+ string(LENGTH ${${WRAP_STRING_VARIABLE}} stringLength)
37
+ math(EXPR offset "0")
38
+
39
+ while(stringLength GREATER 0)
40
+ if(stringLength GREATER ${WRAP_STRING_AT_COLUMN})
41
+ math(EXPR length "${WRAP_STRING_AT_COLUMN}")
42
+ else()
43
+ math(EXPR length "${stringLength}")
44
+ endif()
45
+
46
+ string(SUBSTRING ${${WRAP_STRING_VARIABLE}} ${offset} ${length} line)
47
+ set(lines "${lines}\n ${line}")
48
+
49
+ math(EXPR stringLength "${stringLength} - ${length}")
50
+ math(EXPR offset "${offset} + ${length}")
51
+ endwhile()
52
+
53
+ set(${WRAP_STRING_VARIABLE}
54
+ "${lines}"
55
+ PARENT_SCOPE)
56
+ endfunction()
57
+
58
+ # Function to embed contents of a file as byte array in C/C++ header file(.h).
59
+ # The header file will contain a byte array and integer variable holding the
60
+ # size of the array.
61
+ #
62
+ # Parameters:
63
+ #
64
+ # * SOURCE_FILES - The paths of source files whose contents will be embedded in
65
+ # the header file.
66
+ # * VARIABLE_NAME - The name of the variable for the byte array. The string
67
+ # "_SIZE" will be append to this name and will be used a variable name for
68
+ # size variable.
69
+ # * HEADER_FILE - The path of header file.
70
+ # * APPEND - If specified appends to the header file instead of overwriting it
71
+ # * HEADER_NAMESPACE - The namespace, where the array should be located in.
72
+ # * NULL_TERMINATE - If specified a null byte(zero) will be append to the byte
73
+ # array.
74
+ #
75
+ # Usage:
76
+ #
77
+ # bin2h(SOURCE_FILE "Logo.png" HEADER_FILE "Logo.h" VARIABLE_NAME "LOGO_PNG")
78
+ function(BIN2H)
79
+ set(options APPEND NULL_TERMINATE)
80
+ set(oneValueArgs VARIABLE_NAME HEADER_FILE HEADER_NAMESPACE)
81
+ set(multiValueArgs SOURCE_FILES)
82
+ cmake_parse_arguments(BIN2H "${options}" "${oneValueArgs}"
83
+ "${multiValueArgs}" ${ARGN})
84
+
85
+ set(arrayDefinition "")
86
+ foreach(SOURCE_FILE IN LISTS BIN2H_SOURCE_FILES)
87
+ # get filename without extension
88
+ get_filename_component(FILE_NAME_WE ${SOURCE_FILE} NAME_WE)
89
+ # convert the filename to a valid C identifier
90
+ string(MAKE_C_IDENTIFIER "${FILE_NAME_WE}" VALID_FILE_NAME)
91
+
92
+ # reads source file contents as hex string
93
+ file(READ ${SOURCE_FILE} hexString HEX)
94
+
95
+ # append null
96
+ if(BIN2H_NULL_TERMINATE)
97
+ string(APPEND hexString "00")
98
+ endif()
99
+
100
+ # wraps the hex string into multiple lines
101
+ wrap_string(VARIABLE hexString AT_COLUMN 24)
102
+
103
+ # strip the © in source code
104
+ string(REGEX REPLACE "c2a9" "2020" arrayValues ${hexString})
105
+
106
+ string(REGEX REPLACE "([0-9a-f][0-9a-f])" " 0x\\1," arrayValues
107
+ ${arrayValues})
108
+
109
+ # make a full variable name for the array
110
+ set(FULL_VARIABLE_NAME "${BIN2H_VARIABLE_NAME}_${VALID_FILE_NAME}")
111
+
112
+ # declares byte array and the length variables
113
+ string(APPEND arrayDefinition
114
+ "constexpr char ${FULL_VARIABLE_NAME}[] = {${arrayValues}\n};\n\n")
115
+ endforeach()
116
+
117
+ # add namespace wrapper if defined
118
+ if(DEFINED BIN2H_HEADER_NAMESPACE)
119
+ set(namespaceStart "namespace ${BIN2H_HEADER_NAMESPACE} {")
120
+ set(namespaceEnd "} // namespace ${BIN2H_HEADER_NAMESPACE}")
121
+ set(declarations "${namespaceStart}\n\n${arrayDefinition}${namespaceEnd}\n")
122
+ endif()
123
+
124
+ set(arrayIncludes "#pragma once")
125
+ string(PREPEND declarations "${arrayIncludes}\n\n")
126
+
127
+ if(BIN2H_APPEND)
128
+ file(APPEND ${BIN2H_HEADER_FILE} "${declarations}")
129
+ else()
130
+ file(WRITE ${BIN2H_HEADER_FILE} "${declarations}")
131
+ endif()
132
+ endfunction()
133
+
134
+ # ----------------------------- CLI args -----------------------------
135
+
136
+ string(REPLACE ":" ";" MLX_JIT_SOURCES_LIST ${MLX_JIT_SOURCES})
137
+ foreach(source ${MLX_JIT_SOURCES_LIST})
138
+ list(APPEND MLX_JIT_SOURCES_ABS "${MLX_SOURCE_ROOT}/${source}")
139
+ endforeach()
140
+
141
+ bin2h(
142
+ SOURCE_FILES
143
+ ${MLX_JIT_SOURCES_ABS}
144
+ NULL_TERMINATE
145
+ VARIABLE_NAME
146
+ "jit_source"
147
+ HEADER_NAMESPACE
148
+ "mlx::core"
149
+ HEADER_FILE
150
+ "${CMAKE_CURRENT_BINARY_DIR}/gen/cuda_jit_sources.h")
@@ -0,0 +1,21 @@
1
+ target_sources(
2
+ mlx
3
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/add.cu
4
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/arctan2.cu
5
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/bitwise_binary.cu
6
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/divide.cu
7
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/equal.cu
8
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater.cu
9
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/greater_equal.cu
10
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less.cu
11
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/less_equal.cu
12
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_and.cu
13
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/logical_or.cu
14
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/log_add_exp.cu
15
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/minimum.cu
16
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/maximum.cu
17
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/multiply.cu
18
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/power.cu
19
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/remainder.cu
20
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/not_equal.cu
21
+ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/subtract.cu)
@@ -0,0 +1,7 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/binary/binary.cuh"
4
+
5
+ namespace mlx::core {
6
+ BINARY_GPU(Add)
7
+ } // namespace mlx::core
@@ -0,0 +1,7 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #include "mlx/backend/cuda/binary/binary.cuh"
4
+
5
+ namespace mlx::core {
6
+ BINARY_GPU(ArcTan2)
7
+ } // namespace mlx::core