mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/backend/cuda/copy.cu
@@ -0,0 +1,132 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/common/utils.h"
+#include "mlx/backend/cuda/copy/copy.cuh"
+
+namespace mlx::core {
+
+void copy_gpu(const array& in, array& out, CopyType ctype, const Stream& s) {
+  auto& encoder = cu::get_command_encoder(s);
+  bool donated = set_copy_output_data(
+      in, out, ctype, [&](auto n) { return cu::malloc_async(n, encoder); });
+  if (donated && in.dtype() == out.dtype()) {
+    // If the output has the same type as the input then there is nothing to
+    // copy, just use the buffer.
+    return;
+  }
+  if (ctype == CopyType::GeneralGeneral) {
+    ctype = CopyType::General;
+  }
+  copy_gpu_inplace(in, out, ctype, s);
+}
+
+void copy_gpu_inplace(
+    const array& in,
+    array& out,
+    const Shape& shape,
+    const Strides& strides_in,
+    const Strides& strides_out,
+    int64_t offset_in,
+    int64_t offset_out,
+    CopyType ctype,
+    const Stream& s,
+    std::optional<array> dynamic_offset_in,
+    std::optional<array> dynamic_offset_out) {
+  if (out.size() == 0) {
+    return;
+  }
+
+  auto& encoder = cu::get_command_encoder(s);
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  if (ctype == CopyType::Scalar || ctype == CopyType::Vector) {
+    copy_contiguous(encoder, ctype, in, out, offset_in, offset_out);
+    return;
+  }
+
+  if (ctype == CopyType::General || ctype == CopyType::GeneralGeneral) {
+    auto [shape_collapsed, strides_vec] = collapse_contiguous_dims(
+        shape, std::vector{strides_in, strides_out}, INT32_MAX);
+    if (ctype == CopyType::General) {
+      copy_general_input(
+          encoder,
+          ctype,
+          in,
+          out,
+          offset_in,
+          offset_out,
+          shape_collapsed,
+          strides_vec[0]);
+    } else {
+      if (dynamic_offset_in || dynamic_offset_out) {
+        if (!dynamic_offset_in) {
+          dynamic_offset_in = array(0, int64);
+          encoder.add_temporary(*dynamic_offset_in);
+        }
+        if (!dynamic_offset_out) {
+          dynamic_offset_out = array(0, int64);
+          encoder.add_temporary(*dynamic_offset_out);
+        }
+        encoder.set_input_array(*dynamic_offset_in);
+        encoder.set_input_array(*dynamic_offset_out);
+        copy_general_dynamic(
+            encoder,
+            ctype,
+            in,
+            out,
+            offset_in,
+            offset_out,
+            shape_collapsed,
+            strides_vec[0],
+            strides_vec[1],
+            *dynamic_offset_in,
+            *dynamic_offset_out);
+      } else {
+        copy_general(
+            encoder,
+            ctype,
+            in,
+            out,
+            offset_in,
+            offset_out,
+            shape_collapsed,
+            strides_vec[0],
+            strides_vec[1]);
+      }
+    }
+    return;
+  }
+}
+
+void fill_gpu(const array& in, array& out, const Stream& s) {
+  if (out.size() == 0) {
+    return;
+  }
+  auto& encoder = cu::get_command_encoder(s);
+  out.set_data(cu::malloc_async(out.nbytes(), encoder));
+  encoder.set_input_array(in);
+  encoder.set_output_array(out);
+  copy_contiguous(encoder, CopyType::Scalar, in, out, 0, 0);
+}
+
+void reshape_gpu(const array& in, array& out, Stream s) {
+  auto [copy_necessary, out_strides] = prepare_reshape(in, out);
+  if (copy_necessary) {
+    auto& encoder = cu::get_command_encoder(s);
+    out.set_data(cu::malloc_async(out.nbytes(), encoder));
+    copy_gpu_inplace(
+        in,
+        out,
+        in.shape(),
+        in.strides(),
+        make_contiguous_strides(in.shape()),
+        0,
+        0,
+        CopyType::General,
+        s);
+  } else {
+    shared_buffer_reshape(in, out_strides, out);
+  }
+}
+
+} // namespace mlx::core
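
For reference, the CopyType cases dispatched above are declared in data/mlx/mlx/backend/common/copy.h (file 57 in the list above). The sketch below is a paraphrase, not the upstream declaration; the comments describe how copy_gpu_inplace in this file routes each case:

    enum class CopyType {
      Scalar,         // one input element broadcast into a contiguous output
      Vector,         // contiguous to contiguous; like Scalar, handled by copy_contiguous
      General,        // strided input into a contiguous output (copy_general_input)
      GeneralGeneral, // both sides strided (copy_general / copy_general_dynamic)
    };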
data/mlx/mlx/backend/cuda/cublas_utils.cpp
@@ -0,0 +1,222 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cublas_utils.h"
+#include "mlx/backend/cuda/cuda.h"
+#include "mlx/utils.h"
+
+namespace mlx::core {
+namespace cublas_utils {
+
+namespace {
+
+struct CublasPreference {
+  CublasPreference(cu::Device& device) {
+    // The recommended cublas workspace size is 4 MiB for pre-Hopper and 32 MiB
+    // for Hopper+:
+    // https://docs.nvidia.com/cuda/cublas/#cublassetworkspace
+    uint64_t MiB = 1024 * 1024;
+    uint64_t workspace_size =
+        device.compute_capability_major() >= 9 ? 32 * MiB : 4 * MiB;
+
+    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceCreate(&pref_));
+    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceSetAttribute(
+        pref_,
+        CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+        &workspace_size,
+        sizeof(uint64_t)));
+  }
+
+  ~CublasPreference() {
+    CHECK_CUBLAS_ERROR(cublasLtMatmulPreferenceDestroy(pref_));
+  }
+
+  cublasLtMatmulPreference_t pref_{nullptr};
+};
+
+} // namespace
+
+cublasLtMatmulPreference_t get_preference(cu::Device& device) {
+  static CublasPreference pref(device);
+  return pref.pref_;
+}
+
+cublasLtMatrixLayout_t create_matrix_layout(
+    cudaDataType_t type,
+    uint64_t rows,
+    uint64_t cols,
+    bool transposed,
+    int64_t ld,
+    int32_t batch_count,
+    int64_t batch_stride) {
+  cublasLtMatrixLayout_t desc;
+  if (transposed) {
+    std::swap(rows, cols);
+  }
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutCreate(&desc, type, rows, cols, ld));
+  if (batch_count > 1) {
+    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+        desc,
+        CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
+        &batch_count,
+        sizeof(int32_t)));
+    CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutSetAttribute(
+        desc,
+        CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
+        &batch_stride,
+        sizeof(int64_t)));
+  }
+  return desc;
+}
+
+} // namespace cublas_utils
+
+CublasMatmulBase::~CublasMatmulBase() {
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(a_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(b_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(c_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatrixLayoutDestroy(out_desc_));
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescDestroy(matmul_desc_));
+}
+
+void CublasMatmulBase::init_base(
+    cu::Device& device,
+    cudaDataType_t scale_type,
+    cublasComputeType_t compute_type,
+    cudaDataType_t data_type,
+    cudaDataType_t output_type,
+    bool a_transposed,
+    uint64_t a_rows,
+    uint64_t a_cols,
+    int64_t lda,
+    bool b_transposed,
+    uint64_t b_rows,
+    uint64_t b_cols,
+    int64_t ldb,
+    int32_t batch_count,
+    int64_t a_batch_stride,
+    int64_t b_batch_stride) {
+  M_ = a_rows;
+  N_ = b_cols;
+  scale_type_ = scale_type;
+  handle_ = device.get_cublaslt_handle();
+  pref_ = cublas_utils::get_preference(device);
+  heuristic_.state = CUBLAS_STATUS_NOT_INITIALIZED;
+
+  CHECK_CUBLAS_ERROR(
+      cublasLtMatmulDescCreate(&matmul_desc_, compute_type, scale_type));
+
+  int32_t pointer_mode = CUBLASLT_POINTER_MODE_HOST;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_POINTER_MODE,
+      &pointer_mode,
+      sizeof(int32_t)));
+
+  // In cublasLt, matrices use column-major layout. While the
+  // CUBLASLT_ORDER_ROW option can switch to row-major layout, the bias
+  // epilogue does not work with that option. So instead we swap A and B to
+  // make cublasLt return the row-major result, which works because:
+  // - the data of a matrix in row-major layout is identical to its transpose
+  //   in column-major layout
+  // - C^T = (A @ B)^T = B^T @ A^T
+  cublasOperation_t a_op = b_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_TRANSA,
+      &a_op,
+      sizeof(cublasOperation_t)));
+  cublasOperation_t b_op = a_transposed ? CUBLAS_OP_T : CUBLAS_OP_N;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_TRANSB,
+      &b_op,
+      sizeof(cublasOperation_t)));
+
+  a_desc_ = cublas_utils::create_matrix_layout(
+      data_type,
+      b_cols,
+      b_rows,
+      b_transposed,
+      ldb,
+      batch_count,
+      b_batch_stride);
+  b_desc_ = cublas_utils::create_matrix_layout(
+      data_type,
+      a_cols,
+      a_rows,
+      a_transposed,
+      lda,
+      batch_count,
+      a_batch_stride);
+  out_desc_ = cublas_utils::create_matrix_layout(
+      output_type, b_cols, a_rows, false, b_cols, batch_count, b_cols * a_rows);
+}
+
+void CublasMatmulBase::execute_matmul(
+    cu::CommandEncoder& encoder,
+    void* out,
+    const void* a,
+    const void* b,
+    const void* c,
+    const void* alpha_ptr,
+    const void* beta_ptr) {
+  if (heuristic_.state != CUBLAS_STATUS_SUCCESS) {
+    int ret = 0;
+    CHECK_CUBLAS_ERROR(cublasLtMatmulAlgoGetHeuristic(
+        handle_,
+        matmul_desc_,
+        a_desc_,
+        b_desc_,
+        c ? c_desc_ : out_desc_,
+        out_desc_,
+        pref_,
+        1,
+        &heuristic_,
+        &ret));
+    if (ret == 0) {
+      throw std::runtime_error("Can not find algorithm for matmul.");
+    }
+  }
+
+  void* workspace_ptr = allocate_workspace(encoder, heuristic_.workspaceSize);
+
+  // Execute matmul
+  auto capture = encoder.capture_context();
+  CHECK_CUBLAS_ERROR(cublasLtMatmul(
+      handle_,
+      matmul_desc_,
+      alpha_ptr,
+      b, // a and b are swapped for row-major layout
+      a_desc_,
+      a,
+      b_desc_,
+      beta_ptr,
+      c ? c : out,
+      c ? c_desc_ : out_desc_,
+      out,
+      out_desc_,
+      &heuristic_.algo,
+      workspace_ptr,
+      heuristic_.workspaceSize,
+      encoder.stream()));
+}
+
+void CublasMatmulBase::set_bias(
+    cu::CommandEncoder& encoder,
+    const array& bias) {
+  encoder.set_input_array(bias);
+  cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_EPILOGUE,
+      &epilogue,
+      sizeof(epilogue)));
+  auto* bias_ptr = gpu_ptr<void>(bias);
+  CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(
+      matmul_desc_,
+      CUBLASLT_MATMUL_DESC_BIAS_POINTER,
+      &bias_ptr,
+      sizeof(bias_ptr)));
+}
+
+} // namespace mlx::core
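
The a/b swap in execute_matmul above relies on the identity stated in init_base: a row-major M x N buffer is byte-for-byte identical to a column-major N x M buffer holding the transpose, and (A @ B)^T = B^T @ A^T. Below is a minimal standalone check of that identity in plain C++ (no CUDA required; the nested loop stands in for a column-major GEMM, and all names are illustrative):

    #include <cassert>

    int main() {
      // A and B are 2x2 row-major: A = [[1, 2], [3, 4]], B = [[5, 6], [7, 8]].
      double A[4] = {1, 2, 3, 4};
      double B[4] = {5, 6, 7, 8};
      double C[4] = {};

      // A column-major GEMM computing B^T @ A^T; element (i, j) of a
      // column-major result lives at C[j * 2 + i].
      for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
          double acc = 0;
          for (int k = 0; k < 2; ++k) {
            // B^T(i, k) = B(k, i) and A^T(k, j) = A(j, k),
            // read straight from the row-major buffers.
            acc += B[k * 2 + i] * A[j * 2 + k];
          }
          C[j * 2 + i] = acc;
        }
      }

      // Reading C as a row-major 2x2 matrix yields A @ B = [[19, 22], [43, 50]].
      assert(C[0] == 19 && C[1] == 22 && C[2] == 43 && C[3] == 50);
      return 0;
    }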
data/mlx/mlx/backend/cuda/cublas_utils.h
@@ -0,0 +1,95 @@
+// Copyright © 2025 Apple Inc.
+#pragma once
+
+#include <cublasLt.h>
+#include "mlx/array.h"
+#include "mlx/backend/cuda/device.h"
+#include "mlx/dtype_utils.h"
+
+namespace mlx::core {
+namespace cublas_utils {
+
+// Get the shared cublas preference for a device
+cublasLtMatmulPreference_t get_preference(cu::Device& device);
+
+cublasLtMatrixLayout_t create_matrix_layout(
+    cudaDataType_t type,
+    uint64_t rows,
+    uint64_t cols,
+    bool transposed,
+    int64_t ld,
+    int32_t batch_count,
+    int64_t batch_stride);
+
+inline cudaDataType_t dtype_to_cublas_type(Dtype dtype, std::string_view tag) {
+  switch (dtype) {
+    case float16:
+      return CUDA_R_16F;
+    case bfloat16:
+      return CUDA_R_16BF;
+    case float32:
+      return CUDA_R_32F;
+    case float64:
+      return CUDA_R_64F;
+    case complex64:
+      return CUDA_C_32F;
+    default:
+      throw std::runtime_error(
+          fmt::format(
+              "Unsupported dtype in {}: {}.", tag, dtype_to_string(dtype)));
+  }
+}
+
+} // namespace cublas_utils
+
+class CublasMatmulBase {
+ public:
+  virtual ~CublasMatmulBase();
+
+  void set_bias(cu::CommandEncoder& encoder, const array& bias);
+
+ protected:
+  CublasMatmulBase() = default;
+
+  // Common member variables shared by all matmul types
+  uint64_t M_;
+  uint64_t N_;
+  cudaDataType_t scale_type_;
+  cublasLtMatmulPreference_t pref_{nullptr};
+  cublasLtHandle_t handle_{nullptr};
+  cublasLtMatmulDesc_t matmul_desc_{nullptr};
+  cublasLtMatrixLayout_t a_desc_{nullptr};
+  cublasLtMatrixLayout_t b_desc_{nullptr};
+  cublasLtMatrixLayout_t c_desc_{nullptr};
+  cublasLtMatrixLayout_t out_desc_{nullptr};
+  cublasLtMatmulHeuristicResult_t heuristic_;
+
+  void init_base(
+      cu::Device& device,
+      cudaDataType_t scale_type,
+      cublasComputeType_t compute_type,
+      cudaDataType_t data_type,
+      cudaDataType_t output_type,
+      bool a_transposed,
+      uint64_t a_rows,
+      uint64_t a_cols,
+      int64_t lda,
+      bool b_transposed,
+      uint64_t b_rows,
+      uint64_t b_cols,
+      int64_t ldb,
+      int32_t batch_count,
+      int64_t a_batch_stride,
+      int64_t b_batch_stride);
+
+  void execute_matmul(
+      cu::CommandEncoder& encoder,
+      void* out,
+      const void* a,
+      const void* b,
+      const void* c,
+      const void* alpha_ptr,
+      const void* beta_ptr);
+};
+
+} // namespace mlx::core
data/mlx/mlx/backend/cuda/cuda.h
@@ -0,0 +1,21 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <string>
+#include <unordered_map>
+#include <variant>
+
+#include "mlx/api.h"
+
+namespace mlx::core::cu {
+
+/* Check if the CUDA backend is available. */
+MLX_API bool is_available();
+
+/* Get information about a CUDA device. */
+MLX_API const
+    std::unordered_map<std::string, std::variant<std::string, size_t>>&
+    device_info(int device_index = 0);
+
+} // namespace mlx::core::cu
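
A minimal sketch of consuming the device_info map declared above. It assumes only what the header shows (string keys mapping to variant<string, size_t> values); the actual keys are populated in data/mlx/mlx/backend/cuda/device_info.cpp (file 198 in the list):

    #include <cstdio>
    #include <string>
    #include <variant>

    #include "mlx/backend/cuda/cuda.h"

    int main() {
      if (!mlx::core::cu::is_available()) {
        return 0; // built without the CUDA backend
      }
      const auto& info = mlx::core::cu::device_info(0);
      for (const auto& [key, value] : info) {
        if (const auto* s = std::get_if<std::string>(&value)) {
          std::printf("%s: %s\n", key.c_str(), s->c_str());
        } else if (const auto* n = std::get_if<size_t>(&value)) {
          std::printf("%s: %zu\n", key.c_str(), *n);
        }
      }
      return 0;
    }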
data/mlx/mlx/backend/cuda/cuda_utils.h
@@ -0,0 +1,90 @@
+// Copyright © 2025 Apple Inc.
+
+#pragma once
+
+#include <cublasLt.h>
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include <cudnn.h>
+
+namespace mlx::core {
+
+// Throw exception if the cuda API does not succeed.
+void check_cublas_error(const char* name, cublasStatus_t err);
+void check_cuda_error(const char* name, cudaError_t err);
+void check_cuda_error(const char* name, CUresult err);
+void check_cudnn_error(const char* name, cudnnStatus_t err);
+
+// The macro version that prints the command that failed.
+#define CHECK_CUBLAS_ERROR(cmd) check_cublas_error(#cmd, (cmd))
+#define CHECK_CUDA_ERROR(cmd) check_cuda_error(#cmd, (cmd))
+#define CHECK_CUDNN_ERROR(cmd) check_cudnn_error(#cmd, (cmd))
+
+// Base class for RAII managed CUDA resources.
+template <typename Handle, cudaError_t (*Destroy)(Handle)>
+class CudaHandle {
+ public:
+  CudaHandle(Handle handle = nullptr) : handle_(handle) {}
+
+  CudaHandle(CudaHandle&& other) : handle_(other.handle_) {
+    assert(this != &other);
+    other.handle_ = nullptr;
+  }
+
+  ~CudaHandle() {
+    // Skip if there was an error to avoid throwing in the destructors
+    if (cudaPeekAtLastError() != cudaSuccess) {
+      return;
+    }
+    reset();
+  }
+
+  CudaHandle(const CudaHandle&) = delete;
+  CudaHandle& operator=(const CudaHandle&) = delete;
+
+  CudaHandle& operator=(CudaHandle&& other) {
+    assert(this != &other);
+    reset();
+    std::swap(handle_, other.handle_);
+    return *this;
+  }
+
+  void reset() {
+    if (handle_ != nullptr) {
+      CHECK_CUDA_ERROR(Destroy(handle_));
+      handle_ = nullptr;
+    }
+  }
+
+  operator Handle() const {
+    return handle_;
+  }
+
+ protected:
+  Handle handle_;
+};
+
+namespace cu {
+class Device;
+}; // namespace cu
+
+// Wrappers of CUDA resources.
+class CudaGraph : public CudaHandle<cudaGraph_t, cudaGraphDestroy> {
+ public:
+  using CudaHandle::CudaHandle;
+  explicit CudaGraph(cu::Device& device);
+  void end_capture(cudaStream_t stream);
+};
+
+class CudaGraphExec : public CudaHandle<cudaGraphExec_t, cudaGraphExecDestroy> {
+ public:
+  void instantiate(cudaGraph_t graph);
+};
+
+class CudaStream : public CudaHandle<cudaStream_t, cudaStreamDestroy> {
+ public:
+  using CudaHandle::CudaHandle;
+  explicit CudaStream(cu::Device& device);
+};
+
+} // namespace mlx::core
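
A usage sketch for the RAII wrappers above (assumes a CUDA toolchain; the function name is ours). CudaStream inherits CudaHandle's converting constructor, so it can adopt a runtime-created stream, and the implicit conversion back to cudaStream_t lets it pass straight into CUDA runtime calls:

    #include <cuda_runtime.h>

    #include "mlx/backend/cuda/cuda_utils.h"

    void raii_stream_example() {
      cudaStream_t raw = nullptr;
      CHECK_CUDA_ERROR(cudaStreamCreate(&raw));
      mlx::core::CudaStream stream(raw); // takes ownership of `raw`
      // Implicit conversion back to cudaStream_t:
      CHECK_CUDA_ERROR(cudaStreamSynchronize(stream));
    } // ~CudaHandle calls cudaStreamDestroy here, unless a CUDA error is pending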
data/mlx/mlx/backend/cuda/cudnn_utils.cpp
@@ -0,0 +1,133 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/backend/cuda/cudnn_utils.h"
+#include "mlx/backend/cuda/device.h"
+
+namespace mlx::core {
+
+namespace {
+
+#define RETURN_IF_ERROR(cmd)            \
+  if (auto ret = cmd; ret.is_bad()) {   \
+    return ret;                         \
+  }
+
+// In MLX a singleton dim (shape[dim] == 1) can have any stride, but in cuDNN
+// whether a tensor is contiguous is determined with:
+// strides[dim] == shape[dim + 1] * strides[dim + 1]
+// So a contiguous array with singleton dims in MLX may be mistakenly treated
+// as strided in cuDNN, and we work around it by normalizing the strides.
+std::vector<int64_t> normalized_strides(const array& x) {
+  std::vector<int64_t> strides(x.strides().begin(), x.strides().end());
+  if (std::all_of(
+          strides.begin(), strides.end(), [](int64_t s) { return s == 0; })) {
+    strides.back() = 1;
+    return strides;
+  }
+  if (!x.flags().row_contiguous || x.ndim() < 2) {
+    return strides;
+  }
+  for (int i = x.ndim() - 2; i >= 0; --i) {
+    if (x.shape(i) == 1) {
+      strides[i] = x.shape(i + 1) * strides[i + 1];
+    }
+  }
+  return strides;
+}
+
+// Return the shape and strides after transposing from NHWC to NCHW.
+inline auto nhwc_to_nchw(const array& x) {
+  auto shape = convert_vector<int64_t>(x.shape());
+  auto strides = normalized_strides(x);
+  assert(shape.size() >= 3);
+  shape.insert(shape.begin() + 1, shape.back());
+  shape.erase(shape.end() - 1);
+  strides.insert(strides.begin() + 1, strides.back());
+  strides.erase(strides.end() - 1);
+  return std::make_tuple(std::move(shape), std::move(strides));
+}
+
+} // namespace
+
+fe::error_t DnnGraph::prepare() {
+  RETURN_IF_ERROR(validate());
+  try {
+    RETURN_IF_ERROR(build_operation_graph(handle_));
+  } catch (cudnn_frontend::cudnnException& error) {
+    // cuDNN bug: they did not catch all exceptions in the API.
+    return {fe::error_code_t::CUDNN_BACKEND_API_FAILED, error.what()};
+  }
+  RETURN_IF_ERROR(create_execution_plans({fe::HeurMode_t::A}));
+  return {};
+}
+
+fe::error_t DnnGraph::build() {
+  RETURN_IF_ERROR(check_support(handle_));
+  RETURN_IF_ERROR(build_plans(handle_));
+  return {};
+}
+
+fe::error_t DnnGraph::encode_graph(
+    cu::CommandEncoder& encoder,
+    std::unordered_map<int64_t, void*> variant_pack) {
+  cudnnSetStream(handle_, encoder.stream());
+  CudaGraph cuda_graph(encoder.device());
+  RETURN_IF_ERROR(populate_cuda_graph(
+      handle_, variant_pack, prepare_workspace(encoder), cuda_graph));
+  encoder.add_graph_node(cuda_graph);
+  return {};
+}
+
+fe::error_t DnnGraph::encode_capturing(
+    cu::CommandEncoder& encoder,
+    std::unordered_map<int64_t, void*> variant_pack) {
+  auto* workspace_ptr = prepare_workspace(encoder);
+  auto capture = encoder.capture_context();
+  cudnnSetStream(handle_, encoder.stream());
+  auto ret = execute(handle_, variant_pack, workspace_ptr);
+  if (ret.is_bad()) {
+    capture.discard = true;
+  }
+  return ret;
+}
+
+void* DnnGraph::prepare_workspace(cu::CommandEncoder& encoder) {
+  int64_t workspace_size = 0;
+  CHECK_CUDNN_FE_ERROR(get_workspace_size(workspace_size));
+  return allocate_workspace(encoder, workspace_size);
+}
+
+void DnnGraph::set_tensor_attrs(
+    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+    int64_t uid,
+    const array& x,
+    const std::vector<int64_t>& shape,
+    const std::vector<int64_t>& strides) {
+  tensor->set_uid(uid)
+      .set_alignment(get_alignment(x))
+      .set_data_type(dtype_to_cudnn_type(x.dtype()))
+      .set_dim(shape)
+      .set_stride(strides);
+}
+
+void DnnGraph::set_tensor_attrs(
+    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+    int64_t uid,
+    const array& x) {
+  set_tensor_attrs(
+      tensor,
+      uid,
+      x,
+      convert_vector<int64_t>(x.shape()),
+      normalized_strides(x));
+}
+
+void DnnGraph::set_tensor_attrs_nchw(
+    std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+    int64_t uid,
+    const array& x) {
+  auto [shape, strides] = nhwc_to_nchw(x);
+  set_tensor_attrs(tensor, uid, x, shape, strides);
+}
+
+} // namespace mlx::core
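
A worked example of the stride normalization above, as a standalone sketch; `normalize` re-implements the singleton-dim loop from normalized_strides for illustration only:

    #include <cassert>
    #include <vector>

    // Mirror of the singleton-dim loop in normalized_strides above.
    std::vector<long long> normalize(
        std::vector<long long> shape, std::vector<long long> strides) {
      for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
        if (shape[i] == 1) {
          strides[i] = shape[i + 1] * strides[i + 1];
        }
      }
      return strides;
    }

    int main() {
      // A row-contiguous (2, 1, 3) array: MLX lets the singleton axis carry
      // any stride (7 here), but cuDNN only sees the tensor as contiguous when
      // strides[i] == shape[i + 1] * strides[i + 1], i.e. 3 * 1 = 3.
      auto s = normalize({2, 1, 3}, {3, 7, 1});
      assert(s[0] == 3 && s[1] == 3 && s[2] == 1);
      return 0;
    }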