mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/backend/cuda/cudnn_utils.h
@@ -0,0 +1,187 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/backend/cuda/device/config.h"
+ #include "mlx/backend/cuda/utils.h"
+ #include "mlx/dtype_utils.h"
+
+ #include <cudnn_frontend.h>
+ #include <fmt/format.h>
+
+ namespace mlx::core {
+
+ namespace cu {
+ class CommandEncoder;
+ }
+
+ namespace fe = cudnn_frontend;
+
+ #define CHECK_CUDNN_FE_ERROR(cmd) \
+   do { \
+     auto error = cmd; \
+     if (!error.is_good()) { \
+       throw std::runtime_error( \
+           fmt::format("{} failed: {}.", #cmd, error.get_message())); \
+     } \
+   } while (0)
+
+ // Return pointer alignment of |x|'s data.
+ inline uint8_t get_alignment(const array& x) {
+   uint8_t alignment = 1;
+   uintptr_t address = reinterpret_cast<uintptr_t>(gpu_ptr<void>(x));
+   for (; alignment < 32; alignment *= 2) {
+     if (address % (alignment * 2)) {
+       return alignment;
+     }
+   }
+   return alignment;
+ }
+
+ // Convert the type of elements in |vec| to |T|.
+ template <typename T, typename Vec>
+ inline std::vector<T> convert_vector(const Vec& vec) {
+   return std::vector<T>(vec.begin(), vec.end());
+ }
+
+ // Map dtype to cudnn data type.
+ inline fe::DataType_t dtype_to_cudnn_type(Dtype dtype) {
+   switch (dtype) {
+     case int8:
+       return fe::DataType_t::INT8;
+     case int32:
+       return fe::DataType_t::INT32;
+     case uint8:
+       return fe::DataType_t::UINT8;
+     case float16:
+       return fe::DataType_t::HALF;
+     case bfloat16:
+       return fe::DataType_t::BFLOAT16;
+     case float32:
+       return fe::DataType_t::FLOAT;
+     case float64:
+       return fe::DataType_t::DOUBLE;
+     default:
+       throw std::runtime_error(
+           fmt::format(
+               "Unsupported dtype in cuDNN: {}.", dtype_to_string(dtype)));
+   }
+ }
+
+ // Return an array that can be used as map key for |vec| with size <= MAX_NDIM.
+ //
+ // There are 2 differences from the const_param util from kernel_utils.cuh:
+ // 1. The rest of array is filled with 0.
+ // 2. This util can be used in .cpp files.
+ template <int NDIM = MAX_NDIM, typename Vec>
+ inline std::array<typename Vec::value_type, NDIM> vector_key(const Vec& vec) {
+   if (vec.size() > NDIM) {
+     throw std::runtime_error(
+         fmt::format("ndim can not be larger than {}.", NDIM));
+   }
+   std::array<typename Vec::value_type, NDIM> result = {};
+   std::copy_n(vec.begin(), vec.size(), result.begin());
+   return result;
+ }
+
+ // Extends cuDNN graph with helpers.
+ class DnnGraph : public fe::graph::Graph {
+  public:
+   DnnGraph(cudnnHandle_t handle, Dtype io_dtype, Dtype compute_dtype = float32)
+       : handle_(handle) {
+     set_io_data_type(dtype_to_cudnn_type(io_dtype));
+     set_intermediate_data_type(dtype_to_cudnn_type(compute_dtype));
+     set_compute_data_type(dtype_to_cudnn_type(compute_dtype));
+   }
+
+   // Create a cuDNN tensor description from MLX array |x|.
+   auto& tensor(
+       std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
+       int64_t uid,
+       const array& x) {
+     set_tensor_attrs(attrs, uid, x);
+     return attrs;
+   }
+   auto tensor(const char* name, int64_t uid, const array& x) {
+     auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
+     tensor(attrs, uid, x);
+     return attrs;
+   }
+
+   // Create a cuDNN tensor description from MLX array |x|, and transpose it
+   // from NHWC layout to NCHW.
+   auto& tensor_nchw(
+       std::shared_ptr<fe::graph::Tensor_attributes>& attrs,
+       int64_t uid,
+       const array& x) {
+     set_tensor_attrs_nchw(attrs, uid, x);
+     return attrs;
+   }
+   auto tensor_nchw(const char* name, int64_t uid, const array& x) {
+     auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
+     tensor_nchw(attrs, uid, x);
+     return attrs;
+   }
+
+   // Create a 4D cuDNN tensor from 1D array, with |axis| being contiguous dim.
+   auto tensor_4d(const char* name, int64_t uid, const array& x, int axis) {
+     assert(x.ndim() == 1);
+     auto attrs = Graph::tensor(fe::graph::Tensor_attributes().set_name(name));
+     std::vector<int64_t> shape(4, 1);
+     std::vector<int64_t> strides(4, 1);
+     shape.at(axis) = x.size();
+     if (axis > 0) {
+       strides.at(axis - 1) = x.size();
+     }
+     set_tensor_attrs(attrs, uid, x, shape, strides);
+     return attrs;
+   }
+
+   // Create a cuDNN tensor for scalar.
+   auto scalar(const char* name, int64_t uid, Dtype dtype) {
+     return Graph::tensor(
+         fe::graph::Tensor_attributes()
+             .set_name(name)
+             .set_uid(uid)
+             .set_dim({1, 1, 1, 1})
+             .set_stride({1, 1, 1, 1})
+             .set_is_pass_by_value(true)
+             .set_data_type(dtype_to_cudnn_type(dtype)));
+   }
+
+   // Call this before setting notes.
+   fe::error_t prepare();
+   // Call this after setting notes.
+   fe::error_t build();
+
+   // Add cuDNN graph to CUDA graph, using native CUDA graph API.
+   fe::error_t encode_graph(
+       cu::CommandEncoder& encoder,
+       std::unordered_map<int64_t, void*> variant_pack);
+   // Add cuDNN graph to CUDA graph, using stream capture.
+   fe::error_t encode_capturing(
+       cu::CommandEncoder& encoder,
+       std::unordered_map<int64_t, void*> variant_pack);
+
+  private:
+   void* prepare_workspace(cu::CommandEncoder& encoder);
+
+   void set_tensor_attrs(
+       std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+       int64_t uid,
+       const array& x,
+       const std::vector<int64_t>& shape,
+       const std::vector<int64_t>& strides);
+   void set_tensor_attrs(
+       std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+       int64_t uid,
+       const array& x);
+   void set_tensor_attrs_nchw(
+       std::shared_ptr<fe::graph::Tensor_attributes>& tensor,
+       int64_t uid,
+       const array& x);
+
+   cudnnHandle_t handle_;
+ };
+
+ } // namespace mlx::core
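For orientation, here is a minimal sketch (not part of the package) of how the DnnGraph helpers above compose. It assumes a cuDNN handle, a cu::CommandEncoder, and same-shape float16 arrays; the function name, the pointwise-add graph, and the UID values are illustrative, and gpu_ptr is used the same way get_alignment uses it above.

// Hypothetical usage sketch of DnnGraph; not part of the released sources.
#include "mlx/backend/cuda/cudnn_utils.h"

namespace mlx::core {

void pointwise_add_example(
    cudnnHandle_t handle,
    cu::CommandEncoder& encoder,
    array& x,
    array& y,
    array& out) {
  // I/O tensors are float16; intermediate/compute dtypes default to float32.
  DnnGraph graph(handle, float16);

  // UIDs (1, 2, 3) are arbitrary but must match the variant pack below.
  auto x_attrs = graph.tensor("x", 1, x);
  auto y_attrs = graph.tensor("y", 2, y);
  auto out_attrs = graph.pointwise(
      x_attrs,
      y_attrs,
      fe::graph::Pointwise_attributes().set_mode(fe::PointwiseMode_t::ADD));
  out_attrs->set_output(true);
  graph.tensor(out_attrs, 3, out); // fill dims/strides/uid from the array

  CHECK_CUDNN_FE_ERROR(graph.prepare());
  CHECK_CUDNN_FE_ERROR(graph.build());

  // Bind device pointers to the UIDs and record into the command encoder.
  CHECK_CUDNN_FE_ERROR(graph.encode_graph(
      encoder,
      {{1, gpu_ptr<void>(x)},
       {2, gpu_ptr<void>(y)},
       {3, gpu_ptr<void>(out)}}));
}

} // namespace mlx::core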
data/mlx/mlx/backend/cuda/custom_kernel.cpp
@@ -0,0 +1,379 @@
+ // Copyright © 2025 Apple Inc.
+
+ #include <iostream>
+
+ #include "mlx/backend/common/compiled.h"
+ #include "mlx/backend/cuda/jit_module.h"
+ #include "mlx/backend/cuda/utils.h"
+ #include "mlx/backend/gpu/copy.h"
+ #include "mlx/fast.h"
+ #include "mlx/fast_primitives.h"
+
+ #include <fmt/format.h>
+ #include <nvtx3/nvtx3.hpp>
+
+ namespace mlx::core::fast {
+
+ namespace {
+
+ constexpr const char* default_header = R"(
+ #include "mlx/backend/cuda/device/utils.cuh"
+
+ #include <cooperative_groups.h>
+
+ #define inf cuda::std::numeric_limits<float>::infinity()
+
+ )";
+
+ std::string template_arguments_hash(
+     const std::vector<std::pair<std::string, TemplateArg>>& template_args) {
+   if (template_args.empty()) {
+     return "";
+   }
+
+   std::string hash;
+   hash.reserve(512);
+
+   for (const auto& [name, arg] : template_args) {
+     if (std::holds_alternative<int>(arg)) {
+       hash += fmt::format("_{}", std::get<int>(arg));
+     } else if (std::holds_alternative<bool>(arg)) {
+       hash += (std::get<bool>(arg)) ? "_t" : "_f";
+     } else if (std::holds_alternative<Dtype>(arg)) {
+       hash += "_";
+       hash += get_type_string(std::get<Dtype>(arg));
+     }
+   }
+
+   return hash;
+ }
+
+ std::string build_kernel(
+     const std::string& func_name,
+     const std::string& header,
+     const std::string& source,
+     const std::vector<std::string>& input_names,
+     const std::vector<array>& inputs,
+     const std::vector<std::string>& output_names,
+     const std::vector<Dtype>& output_dtypes,
+     const std::vector<std::pair<std::string, TemplateArg>>& template_args,
+     const std::vector<std::tuple<bool, bool, bool>>& shape_infos) {
+   std::string kernel_source;
+   kernel_source.reserve(header.size() + source.size() + 8192);
+   kernel_source += default_header;
+   kernel_source += header;
+   kernel_source +=
+       "namespace mlx::core::cu {\n\n"
+       "namespace cg = cooperative_groups;\n\n";
+
+   kernel_source += "__global__ void ";
+   kernel_source += func_name;
+   kernel_source += "(\n";
+
+   // Add inputs
+   for (int i = 0; i < inputs.size(); ++i) {
+     const auto& name = input_names[i];
+     const auto& arr = inputs[i];
+     kernel_source += " const ";
+     kernel_source += dtype_to_cuda_type(arr.dtype());
+     kernel_source += "* ";
+     kernel_source += name;
+     kernel_source += ",\n";
+     // Add input shape, strides and ndim if present in the source
+     if (arr.ndim() > 0) {
+       if (std::get<0>(shape_infos[i])) {
+         kernel_source += " const __grid_constant__ Shape ";
+         kernel_source += name;
+         kernel_source += "_shape,\n";
+       }
+       if (std::get<1>(shape_infos[i])) {
+         kernel_source += " const __grid_constant__ Strides ";
+         kernel_source += name;
+         kernel_source += "_strides,\n";
+       }
+       if (std::get<2>(shape_infos[i])) {
+         kernel_source += " const __grid_constant__ int ";
+         kernel_source += name;
+         kernel_source += "_ndim,\n";
+       }
+     }
+   }
+
+   // Add outputs
+   for (int i = 0; i < output_names.size(); ++i) {
+     const auto& name = output_names[i];
+     const auto& dtype = output_dtypes[i];
+     kernel_source += " ";
+     kernel_source += dtype_to_cuda_type(dtype);
+     kernel_source += "* ";
+     kernel_source += name;
+     if (i < output_names.size() - 1) {
+       kernel_source += ",\n";
+     } else {
+       kernel_source += ") {\n";
+     }
+   }
+
+   // Set compile time constants
+   if (!template_args.empty()) {
+     for (const auto& [name, arg] : template_args) {
+       if (std::holds_alternative<int>(arg)) {
+         kernel_source +=
+             fmt::format(" constexpr int {} = {};\n", name, std::get<int>(arg));
+       } else if (std::holds_alternative<bool>(arg)) {
+         kernel_source += fmt::format(
+             " constexpr bool {} = {};\n", name, std::get<bool>(arg));
+       } else {
+         kernel_source += fmt::format(
+             " using {} = {};\n",
+             name,
+             dtype_to_cuda_type(std::get<Dtype>(arg)));
+       }
+     }
+     kernel_source += "\n";
+   }
+
+   kernel_source += source;
+   kernel_source += "\n}\n\n} // namespace mlx::core::cu\n";
+
+   return kernel_source;
+ }
+
+ } // namespace
+
+ CustomKernelFunction cuda_kernel(
+     const std::string& name,
+     const std::vector<std::string>& input_names,
+     const std::vector<std::string>& output_names,
+     const std::string& source,
+     const std::string& header,
+     bool ensure_row_contiguous,
+     int shared_memory) {
+   if (output_names.empty()) {
+     throw std::invalid_argument(
+         "[custom_kernel] Must specify at least one output.");
+   }
+
+   std::vector<std::tuple<bool, bool, bool>> shape_infos;
+   for (auto& n : input_names) {
+     std::tuple<bool, bool, bool> shape_info;
+     std::get<0>(shape_info) = source.find(n + "_shape") != std::string::npos;
+     std::get<1>(shape_info) = source.find(n + "_strides") != std::string::npos;
+     std::get<2>(shape_info) = source.find(n + "_ndim") != std::string::npos;
+     shape_infos.push_back(shape_info);
+   }
+
+   return [=, shape_infos = std::move(shape_infos)](
+              const std::vector<array>& inputs,
+              const std::vector<Shape>& output_shapes,
+              const std::vector<Dtype>& output_dtypes,
+              std::tuple<int, int, int> grid,
+              std::tuple<int, int, int> threadgroup,
+              const std::vector<std::pair<std::string, TemplateArg>>&
+                  template_args = {},
+              std::optional<float> init_value = std::nullopt,
+              bool verbose = false,
+              StreamOrDevice s_ = {}) {
+     if (inputs.size() != input_names.size()) {
+       std::ostringstream msg;
+       msg << "[custom_kernel] Expected `inputs` to have size "
+           << input_names.size() << " but got size " << inputs.size() << "."
+           << std::endl;
+       throw std::invalid_argument(msg.str());
+     }
+     if (output_shapes.size() != output_names.size()) {
+       std::ostringstream msg;
+       msg << "[custom_kernel] Expected `output_shapes` to have size "
+           << output_names.size() << " but got size " << output_shapes.size()
+           << "." << std::endl;
+       throw std::invalid_argument(msg.str());
+     }
+     if (output_dtypes.size() != output_names.size()) {
+       std::ostringstream msg;
+       msg << "[custom_kernel] Expected `output_dtypes` to have size "
+           << output_names.size() << " but got size " << output_dtypes.size()
+           << "." << std::endl;
+       throw std::invalid_argument(msg.str());
+     }
+
+     auto s = to_stream(s_);
+     if (s.device != Device::gpu) {
+       throw std::invalid_argument("[custom_kernel] Only supports the GPU.");
+     }
+
+     std::string kernel_name =
+         "custom_kernel_" + name + template_arguments_hash(template_args);
+     std::string kernel_source = build_kernel(
+         kernel_name,
+         header,
+         source,
+         input_names,
+         inputs,
+         output_names,
+         output_dtypes,
+         template_args,
+         shape_infos);
+
+     if (verbose) {
+       std::cout << "Generated source code for `" << kernel_name
+                 << "`:" << std::endl
+                 << "```" << std::endl
+                 << kernel_source << std::endl
+                 << "```" << std::endl;
+     }
+
+     return array::make_arrays(
+         std::move(output_shapes),
+         std::move(output_dtypes),
+         std::make_shared<CustomKernel>(
+             s,
+             std::move(kernel_name),
+             std::move(kernel_source),
+             grid,
+             threadgroup,
+             shape_infos,
+             ensure_row_contiguous,
+             init_value,
+             std::vector<ScalarArg>{},
+             false,
+             shared_memory),
+         std::move(inputs));
+   };
+ }
+
+ std::vector<array> precompiled_cuda_kernel(
+     const std::string& name,
+     const std::string& compiled_source,
+     const std::vector<array>& inputs,
+     const std::vector<Shape>& output_shapes,
+     const std::vector<Dtype>& output_dtypes,
+     const std::vector<ScalarArg>& scalars,
+     std::tuple<int, int, int> grid,
+     std::tuple<int, int, int> threadgroup,
+     int shared_memory,
+     std::optional<float> init_value,
+     bool ensure_row_contiguous,
+     StreamOrDevice s) {
+   std::vector<std::tuple<bool, bool, bool>> shape_infos(
+       inputs.size(), {false, false, false});
+   return array::make_arrays(
+       output_shapes,
+       output_dtypes,
+       std::make_shared<CustomKernel>(
+           to_stream(s),
+           name,
+           compiled_source,
+           grid,
+           threadgroup,
+           shape_infos,
+           ensure_row_contiguous,
+           init_value,
+           scalars,
+           true,
+           shared_memory),
+       inputs);
+ }
+
+ void CustomKernel::eval_gpu(
+     const std::vector<array>& inputs,
+     std::vector<array>& outputs) {
+   nvtx3::scoped_range r("CustomKernel::eval_gpu");
+   auto& s = stream();
+   auto& encoder = cu::get_command_encoder(s);
+
+   std::vector<array> copies;
+
+   // Allocate and initialize the output arrays
+   for (auto& out : outputs) {
+     if (init_value_) {
+       copies.emplace_back(init_value_.value(), out.dtype());
+       fill_gpu(copies.back(), out, s);
+     } else {
+       out.set_data(cu::malloc_async(out.nbytes(), encoder));
+     }
+   }
+
+   // Create the input arrays and copy if needed
+   auto check_input = [&copies, &s, this](const array& x) -> const array {
+     bool no_copy = x.flags().row_contiguous;
+     if (!ensure_row_contiguous_ || no_copy) {
+       return x;
+     } else {
+       copies.push_back(array(x.shape(), x.dtype(), nullptr, {}));
+       copy_gpu(x, copies.back(), CopyType::General, s);
+       return copies.back();
+     }
+   };
+   std::vector<array> checked_inputs;
+   for (const array& in : inputs) {
+     checked_inputs.push_back(check_input(in));
+   }
+
+   // Compile the custom kernel
+   std::string kernel_name =
+       (is_precompiled_) ? name_ : "mlx::core::cu::" + name_;
+   cu::JitModule& mod = cu::get_jit_module(
+       s.device,
+       name_,
+       [&]() {
+         return std::make_tuple(
+             is_precompiled_, source_, std::vector{kernel_name});
+       },
+       false);
+
+   // Make the arguments
+   cu::KernelArgs args;
+   for (int i = 0; i < checked_inputs.size(); i++) {
+     const array& in = checked_inputs[i];
+     auto& shape_info = shape_infos_[i];
+     args.append(in);
+     if (std::get<0>(shape_info)) {
+       args.append_ndim(in.shape());
+     }
+     if (std::get<1>(shape_info)) {
+       args.append_ndim(in.strides());
+     }
+     if (std::get<2>(shape_info)) {
+       args.append<int32_t>(in.ndim());
+     }
+   }
+   for (auto& out : outputs) {
+     args.append(out);
+   }
+   for (auto& s : scalar_arguments_) {
+     if (std::holds_alternative<bool>(s)) {
+       args.append(std::get<bool>(s));
+     } else if (std::holds_alternative<int>(s)) {
+       args.append(std::get<int>(s));
+     } else if (std::holds_alternative<float>(s)) {
+       args.append(std::get<float>(s));
+     }
+   }
+
+   // Make the grid
+   const auto [tx, ty, tz] = threadgroup_;
+   const auto [gx, gy, gz] = grid_;
+   dim3 block(std::min(tx, gx), std::min(ty, gy), std::min(tz, gz));
+   dim3 grid((gx + tx - 1) / tx, (gy + ty - 1) / ty, (gz + tz - 1) / tz);
+
+   // Call the kernel
+   for (const auto& in : checked_inputs) {
+     encoder.set_input_array(in);
+   }
+   for (const auto& out : outputs) {
+     encoder.set_output_array(out);
+   }
+   for (const auto& t : copies) {
+     encoder.add_temporary(t);
+   }
+   auto kernel =
+       mod.get_kernel(kernel_name, [smem = shared_memory_](CUfunction kernel) {
+         if (smem > 0 && smem > 48000) {
+           cuFuncSetAttribute(
+               kernel, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, smem);
+         }
+       });
+   encoder.add_kernel_node(kernel, grid, block, shared_memory_, args.args());
+ }
+
+ } // namespace mlx::core::fast
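A minimal sketch (not part of the package) of driving the cuda_kernel() wrapper above from C++, assuming it is declared in mlx/fast.h as defined here and assuming a float32 input. The function name, kernel body, and launch geometry are made up. Note that per eval_gpu above, the grid is given in threads (MLX convention) and divided by the threadgroup size at launch.

// Hypothetical usage sketch of fast::cuda_kernel(); not part of the sources.
#include <optional>
#include <string>

#include "mlx/mlx.h"

namespace mx = mlx::core;

mx::array double_it(const mx::array& inp) {
  // The body is spliced into a generated __global__ function whose pointer
  // parameters are named after input_names/output_names ("inp"/"out" here),
  // and `size` becomes a constexpr int via template_args.
  std::string source = R"(
    auto idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size) {
      out[idx] = inp[idx] * 2.0f;
    }
  )";
  auto kernel = mx::fast::cuda_kernel(
      "double_it",
      /* input_names */ {"inp"},
      /* output_names */ {"out"},
      source,
      /* header */ "",
      /* ensure_row_contiguous */ true,
      /* shared_memory */ 0);
  int size = static_cast<int>(inp.size());
  auto outputs = kernel(
      {inp},
      /* output_shapes */ {inp.shape()},
      /* output_dtypes */ {inp.dtype()},
      /* grid, in threads */ std::make_tuple(size, 1, 1),
      /* threadgroup */ std::make_tuple(256, 1, 1),
      /* template_args */ {{"size", size}},
      /* init_value */ std::nullopt,
      /* verbose */ false,
      /* stream */ {});
  return outputs[0];
}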
data/mlx/mlx/backend/cuda/cutlass_utils.cuh
@@ -0,0 +1,46 @@
+ // Copyright © 2025 Apple Inc.
+
+ #pragma once
+
+ #include "mlx/dtype.h"
+
+ #include <cutlass/bfloat16.h>
+ #include <cutlass/half.h>
+ #include <fmt/format.h>
+
+ namespace mlx::core {
+
+ // Throw exception if the cutlass API does not succeed.
+ inline void check_cutlass_error(const char* name, cutlass::Status status) {
+   if (status != cutlass::Status::kSuccess) {
+     throw std::runtime_error(
+         fmt::format(
+             "{} failed with code: {}.",
+             name,
+             cutlass::cutlassGetStatusString(status)));
+   }
+ }
+
+ // The macro version that prints the command that failed.
+ #define CHECK_CUTLASS_ERROR(cmd) check_cutlass_error(#cmd, (cmd))
+
+ // Maps CPU types to CUTLASS types.
+ template <typename T>
+ struct CTypeToCutlassType {
+   using type = T;
+ };
+
+ template <>
+ struct CTypeToCutlassType<float16_t> {
+   using type = cutlass::half_t;
+ };
+
+ template <>
+ struct CTypeToCutlassType<bfloat16_t> {
+   using type = cutlass::bfloat16_t;
+ };
+
+ template <typename T>
+ using cutlass_type_t = typename CTypeToCutlassType<T>::type;
+
+ } // namespace mlx::core
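A minimal sketch (not part of the package) of how these two helpers compose; fake_cutlass_call and demo are made-up stand-ins for a status-returning CUTLASS API call.

// Hypothetical sketch of cutlass_type_t and CHECK_CUTLASS_ERROR in use.
#include <type_traits>

#include "mlx/backend/cuda/cutlass_utils.cuh"

namespace mlx::core {

// The trait swaps MLX's half-precision C types for CUTLASS's numeric types
// and leaves everything else (e.g. float) untouched.
static_assert(std::is_same_v<cutlass_type_t<float16_t>, cutlass::half_t>);
static_assert(std::is_same_v<cutlass_type_t<bfloat16_t>, cutlass::bfloat16_t>);
static_assert(std::is_same_v<cutlass_type_t<float>, float>);

// Stand-in for a real status-returning CUTLASS call (e.g. a GEMM's run()).
inline cutlass::Status fake_cutlass_call() {
  return cutlass::Status::kErrorNotSupported;
}

inline void demo() {
  // Throws std::runtime_error: "fake_cutlass_call() failed with code: ...".
  CHECK_CUTLASS_ERROR(fake_cutlass_call());
}

} // namespace mlx::core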
data/mlx/mlx/backend/cuda/delayload.cpp
@@ -0,0 +1,80 @@
+ // Copyright © 2026 Apple Inc.
+
+ #include "mlx/backend/common/utils.h"
+
+ // clang-format off
+ #include <windows.h> // must be included first
+ #include <delayimp.h>
+ // clang-format on
+
+ namespace mlx::core {
+
+ namespace fs = std::filesystem;
+
+ inline fs::path relative_to_current_binary(const char* relative) {
+   return fs::absolute(current_binary_dir() / relative);
+ }
+
+ inline fs::path cublas_bin_dir() {
+ #if defined(MLX_CUDA_BIN_DIR)
+   return MLX_CUDA_BIN_DIR;
+ #else
+   return relative_to_current_binary("../nvidia/cublas/bin");
+ #endif
+ }
+
+ fs::path load_nvrtc() {
+ #if defined(MLX_CUDA_BIN_DIR)
+   fs::path nvrtc_bin_dir = MLX_CUDA_BIN_DIR;
+ #else
+   fs::path nvrtc_bin_dir =
+       relative_to_current_binary("../nvidia/cuda_nvrtc/bin");
+ #endif
+   // Internally nvrtc loads some libs dynamically, add to search dirs.
+   ::AddDllDirectory(nvrtc_bin_dir.c_str());
+   return nvrtc_bin_dir;
+ }
+
+ fs::path load_cudnn() {
+ #if defined(MLX_CUDNN_BIN_DIR)
+   fs::path cudnn_bin_dir = MLX_CUDNN_BIN_DIR;
+ #else
+   fs::path cudnn_bin_dir = relative_to_current_binary("../nvidia/cudnn/bin");
+ #endif
+   // Must load cudnn_graph64_9.dll before locating symbols, otherwise we would
+   // get errors like "Invalid handle. Cannot load symbol cudnnCreate".
+   for (const auto& dll : fs::directory_iterator(cudnn_bin_dir)) {
+     if (dll.path().filename().string().starts_with("cudnn_graph") &&
+         dll.path().extension() == ".dll") {
+       ::LoadLibraryW(dll.path().c_str());
+       break;
+     }
+   }
+   // Internally cuDNN loads some libs dynamically, add to search dirs.
+   load_nvrtc();
+   ::AddDllDirectory(cudnn_bin_dir.c_str());
+   ::AddDllDirectory(cublas_bin_dir().c_str());
+   return cudnn_bin_dir;
+ }
+
+ // Called by the delay-load helper before it tries to load a delayed DLL;
+ // returning a non-null module overrides the default search.
+ FARPROC WINAPI delayload_helper(unsigned dliNotify, PDelayLoadInfo pdli) {
+   HMODULE mod = NULL;
+   if (dliNotify == dliNotePreLoadLibrary) {
+     std::string dll = pdli->szDll;
+     if (dll.starts_with("cudnn")) {
+       static auto cudnn_bin_dir = load_cudnn();
+       mod = ::LoadLibraryW((cudnn_bin_dir / dll).c_str());
+     } else if (dll.starts_with("cublas")) {
+       mod = ::LoadLibraryW((cublas_bin_dir() / dll).c_str());
+     } else if (dll.starts_with("nvrtc")) {
+       static auto nvrtc_bin_dir = load_nvrtc();
+       mod = ::LoadLibraryW((nvrtc_bin_dir / dll).c_str());
+     }
+   }
+   return reinterpret_cast<FARPROC>(mod);
+ }
+
+ } // namespace mlx::core
+
+ extern "C" const PfnDliHook __pfnDliNotifyHook2 = mlx::core::delayload_helper;