mlx 0.30.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
@@ -0,0 +1,9 @@
1
// Copyright © 2025 Apple Inc.

#pragma once

// Loop-unrolling hint. Uses the C99/C++11 `_Pragma` operator so it can be
// emitted from macro/templated contexts where the `#pragma unroll` directive
// form is not usable.
#define MLX_UNROLL _Pragma("unroll")

// Feature gate for Ampere-and-newer instructions (e.g. the bf16 mma.sync in
// mma.cuh is compiled only behind this macro). Only defined during device
// compilation, since __CUDA_ARCH__ is undefined for host passes.
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
#define MLX_CUDA_SM_80_ENABLED
#endif
@@ -0,0 +1,101 @@
1
// Copyright © 2025 Apple Inc.

#include "mlx/backend/cuda/steel/mma.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"

namespace mlx::core::cu {

/**
 * An example gemm written with the utils.
 *
 * Computes Y = A @ B.T when A and B are aligned with the block sizes: no
 * bounds checks are performed, so the caller must ensure the problem
 * dimensions are multiples of BM, BN and BK.
 *
 * Launch expectations (implied by the indexing below):
 *   - blockIdx.y selects the BM-row tile of A/Y; blockIdx.x selects the
 *     BN-row tile of B, i.e. the BN-column tile of Y.
 *   - 4 warps per block (2x2 warp grid, 128 threads).
 *   - Dynamic shared memory holds the double-buffered A and B tiles:
 *     sizeof(T) * 2 * (BM + BN) * BK bytes.
 */
template <typename T, int BM, int BN, int BK>
__global__ void ab_t_aligned(const T* a, const T* b, T* y, int N, int K) {
  // Fixed 2x2 warp tiling; each warp owns a (BM/2) x (BN/2) output sub-tile.
  constexpr int WARPS_M = 2;
  constexpr int WARPS_N = 2;
  constexpr int NUM_WARPS = WARPS_M * WARPS_N;
  constexpr int WARP_STEP_M = BM / WARPS_M;
  constexpr int WARP_STEP_N = BN / WARPS_N;

  // Precompute some offsets for each thread
  const int warpid = threadIdx.x / 32;
  const int laneid = threadIdx.x % 32;
  const int wm = warpid / WARPS_N;
  const int wn = warpid % WARPS_N;
  const int offset_m = wm * WARP_STEP_M;
  const int offset_n = wn * WARP_STEP_N;

  // Allocate shared memory: two buffers per operand for the double-buffered
  // (ping-pong) pipeline; the B buffers start right after the two A buffers.
  extern __shared__ char shmem[];
  SharedTile<T, BM, BK>(&as)[2] = *(SharedTile<T, BM, BK>(*)[2])(&shmem[0]);
  SharedTile<T, BN, BK>(&bs)[2] =
      *(SharedTile<T, BN, BK>(*)[2])(&shmem[sizeof(T) * 2 * BM * BK]);

  // Allocate registers for the MMA. Accumulation is in float regardless of T.
  RegisterTile<float, BM / WARPS_M, BN / WARPS_N> C;
  RegisterTile<T, BM / WARPS_M, 16> A;
  RegisterTile<T, BN / WARPS_N, 16> B;

  // Move the global pointers to this block's tile. A and B are row-major
  // with leading dimension K; Y has leading dimension N.
  a += blockIdx.y * BM * K;
  b += blockIdx.x * BN * K;
  y += blockIdx.y * BM * N + blockIdx.x * BN;

  // Zero the accumulators
  C.fill(0);

  // Start the SM pipeline: prefetch the first K-tile of both operands.
  load_async<NUM_WARPS>(as[0], as[0].base_addr(), a, K);
  load_async<NUM_WARPS>(bs[0], bs[0].base_addr(), b, K);
  cp_async_commit();

  // Main loop: prefetch the next K-tile into the inactive buffer (tic ^ 1)
  // while the MMAs consume the active one (tic).
  int tic = 0;
  for (int k_block = BK; k_block < K; k_block += BK) {
    load_async<NUM_WARPS>(as[tic ^ 1], as[tic ^ 1].base_addr(), a + k_block, K);
    load_async<NUM_WARPS>(bs[tic ^ 1], bs[tic ^ 1].base_addr(), b + k_block, K);
    cp_async_commit();
    // Wait until at most one commit group (the prefetch issued above) is
    // still in flight, i.e. the active buffer has fully arrived.
    cp_async_wait<1>();
    __syncthreads();

    MLX_UNROLL
    for (int k = 0; k < BK / 16; k++) {
      // laneid % 16 picks the row and laneid / 16 * 8 the 8-column half for
      // the ldmatrix addresses (see Tile16x16::load for the address scheme).
      A.load(
          as[tic],
          as[tic].base_addr(),
          offset_m + laneid % 16,
          k * 16 + laneid / 16 * 8);
      B.load(
          bs[tic],
          bs[tic].base_addr(),
          offset_n + laneid % 16,
          k * 16 + laneid / 16 * 8);

      mma_t(C, A, B);
    }

    // Swap the active and prefetch buffers.
    tic ^= 1;
  }

  // Empty the pipeline: the final K-tile was prefetched but not yet consumed.
  cp_async_wait_all();
  __syncthreads();
  MLX_UNROLL
  for (int k = 0; k < BK / 16; k++) {
    A.load(
        as[tic],
        as[tic].base_addr(),
        offset_m + laneid % 16,
        k * 16 + laneid / 16 * 8);
    B.load(
        bs[tic],
        bs[tic].base_addr(),
        offset_n + laneid % 16,
        k * 16 + laneid / 16 * 8);

    mma_t(C, A, B);
  }

  // Write this warp's accumulator back to global memory.
  C.store_global(y, N, offset_m, offset_n);
}

} // namespace mlx::core::cu
@@ -0,0 +1,117 @@
1
// Copyright © 2025 Apple Inc.

#pragma once

#include "mlx/backend/cuda/steel/defines.cuh"
#include "mlx/backend/cuda/steel/tiles.cuh"

namespace mlx::core::cu {

/**
 * Fallback mma.
 *
 * NOTE(review): this catch-all overload is a silent no-op for any type
 * combination without a dedicated overload below. We should probably either
 * a) implement a real fallback, or b) complain about it at compile time so
 * the missing specialization is visible rather than silently producing
 * zeros.
 */
template <typename U, typename T>
__device__ inline void
mma_t(Tile16x16<U>& C, Tile16x16<T>& A, Tile16x16<T>& B) {}
19
+
20
/**
 * Multiply the 16x16 bfloat16 tiles and accumulate the result in one 16x16
 * float tile.
 *
 * We actually perform C += A @ B.T
 *
 * Implemented as two m16n8k16 mma.sync instructions, each producing an
 * 8-column half of the 16x16 output. Compiles to a no-op unless
 * MLX_CUDA_SM_80_ENABLED is defined (i.e. requires SM80+).
 */
__device__ __forceinline__ void mma_t(
    Tile16x16<float>& C,
    Tile16x16<__nv_bfloat16>& A,
    Tile16x16<__nv_bfloat16>& B) {
#if defined(MLX_CUDA_SM_80_ENABLED)
  // First 8-column half: B.values[0]/[2] accumulate into C.values[0]/[1].
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0, %1, %2, %3}, "
      "{%4, %5, %6, %7}, "
      "{%8, %9}, "
      "{%10, %11, %12, %13};"

      // D matrix
      : "+f"(C.values[0].x),
        "+f"(C.values[0].y),
        "+f"(C.values[1].x),
        "+f"(C.values[1].y)

      // A matrix
      : "r"(*(uint32_t*)(&A.values[0])),
        "r"(*(uint32_t*)(&A.values[1])),
        "r"(*(uint32_t*)(&A.values[2])),
        "r"(*(uint32_t*)(&A.values[3])),

        // B matrix
        "r"(*(uint32_t*)(&B.values[0])),
        "r"(*(uint32_t*)(&B.values[2])),

        // C matrix
        "f"(C.values[0].x),
        "f"(C.values[0].y),
        "f"(C.values[1].x),
        "f"(C.values[1].y));
  // Second 8-column half: B.values[1]/[3] accumulate into C.values[2]/[3].
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
      "{%0, %1, %2, %3}, "
      "{%4, %5, %6, %7}, "
      "{%8, %9}, "
      "{%10, %11, %12, %13};"

      // D matrix
      : "+f"(C.values[2].x),
        "+f"(C.values[2].y),
        "+f"(C.values[3].x),
        "+f"(C.values[3].y)

      // A matrix
      : "r"(*(uint32_t*)(&A.values[0])),
        "r"(*(uint32_t*)(&A.values[1])),
        "r"(*(uint32_t*)(&A.values[2])),
        "r"(*(uint32_t*)(&A.values[3])),

        // B matrix
        "r"(*(uint32_t*)(&B.values[1])),
        "r"(*(uint32_t*)(&B.values[3])),

        // C matrix
        "f"(C.values[2].x),
        "f"(C.values[2].y),
        "f"(C.values[3].x),
        "f"(C.values[3].y));
#endif
}
89
+
90
/**
 * Multiply larger register tiles by delegating each 16x16 sub-tile product
 * to the Tile16x16 mma_t overloads.
 *
 * For every output sub-tile (mm, nn) the K sub-tiles are accumulated in
 * increasing k order.
 */
template <typename U, typename T, int M, int N, int K>
__device__ __forceinline__ void mma_t(
    RegisterTile<U, M, N>& C,
    RegisterTile<T, M, K>& A,
    RegisterTile<T, N, K>& B) {
  constexpr int TILES_M = RegisterTile<T, M, K>::TILES_Y;
  constexpr int TILES_K = RegisterTile<T, M, K>::TILES_X;
  constexpr int TILES_N = RegisterTile<T, N, K>::TILES_Y;

  MLX_UNROLL
  for (int kk = 0; kk < TILES_K; kk++) {
    MLX_UNROLL
    for (int mm = 0; mm < TILES_M; mm++) {
      MLX_UNROLL
      for (int nn = 0; nn < TILES_N; nn++) {
        Tile16x16<U>& c_tile = C.data[mm * TILES_N + nn];
        Tile16x16<T>& a_tile = A.data[mm * TILES_K + kk];
        Tile16x16<T>& b_tile = B.data[nn * TILES_K + kk];
        mma_t(c_tile, a_tile, b_tile);
      }
    }
  }
}
+ }
116
+
117
+ } // namespace mlx::core::cu
@@ -0,0 +1,450 @@
1
+ // Copyright © 2025 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include "mlx/backend/cuda/steel/utils.cuh"
6
+ #include "mlx/backend/cuda/vector_types.cuh"
7
+
8
+ namespace mlx::core::cu {
9
+
10
/**
 * The basic building block for Ampere mmas. A 16x16 tile distributed across
 * the warp.
 *
 * Each thread holds 8 values. They are distributed according to
 * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
 *
 * For use instructions see the individual methods eg load().
 */
template <typename T>
struct Tile16x16 {
  using T2 = Vector2_t<T>;

  // 4 packed pairs = 8 scalar values per thread.
  T2 values[4];

  // Set all 8 values held by this thread to v.
  __device__ inline void fill(T v) {
    T2 v2 = {v, v};
    for (int i = 0; i < 4; i++) {
      values[i] = v2;
    }
  }

  /**
   * Load a 16x16 tile from shared memory.
   *
   * The instruction is a bit weird in the sense that the address provided by
   * each thread and the elements loaded are not the same.
   *
   * We load 4 8x8 tiles. The tile rows are stored contiguously in memory. As a
   * result the warp provides 4*8 = 32 addresses one per row.
   *
   * Threads 0-7 provide the addresses for the first tile, 8-15 for the second
   * and so on. For instance to load a non swizzled tile we would do
   *
   *   base_addr + (laneid % 16) * BK + (laneid / 2) * 8
   *
   * See
   * https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-instructions-ldmatrix
   *
   * NOTE: silently does nothing for element types other than 16-bit
   * half/bfloat16 (no else branch below).
   */
  __device__ __forceinline__ void load(uint32_t row_address) {
    if constexpr (
        std::is_same_v<T2, __nv_bfloat162> || std::is_same_v<T2, __half2>) {
      asm volatile(
          "ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n"
          : "=r"(*(uint32_t*)&(values[0])),
            "=r"(*(uint32_t*)&(values[1])),
            "=r"(*(uint32_t*)&(values[2])),
            "=r"(*(uint32_t*)&(values[3]))
          : "r"(row_address));
    }
  }

  /**
   * Store the tile to the address pointed to by `x`.
   *
   * The provided pointer is a generic pointer but this is meant to be used to
   * store to global memory. For storing to shared memory we should use
   * `stmatrix`.
   *
   * This also showcases the format of the tile quite nicely. Each register is
   * holding two adjacent values. The indices are
   *
   *   row + 0, col + 0
   *   row + 8, col + 0
   *   row + 0, col + 8
   *   row + 8, col + 8
   *
   * Given that we are dealing with Vector2_t<U> the column offsets are 4
   * instead of 8.
   *
   * NOTE: only the same-type path and the float -> bfloat16 conversion path
   * are implemented; any other (T, U) combination silently stores nothing
   * (no trailing else below).
   */
  template <typename U>
  __device__ inline void store_global(U* x, int N) {
    using U2 = Vector2_t<U>;
    U2* x2 = reinterpret_cast<U2*>(x);
    const int laneid = threadIdx.x % 32;
    const int row = laneid / 4;
    const int col = laneid % 4;
    if constexpr (std::is_same_v<U2, T2>) {
      x2[(row + 0) * (N / 2) + col + 0] = values[0];
      x2[(row + 0) * (N / 2) + col + 4] = values[2];
      x2[(row + 8) * (N / 2) + col + 0] = values[1];
      x2[(row + 8) * (N / 2) + col + 4] = values[3];
    } else if constexpr (
        std::is_same_v<T2, float2> && std::is_same_v<U, __nv_bfloat16>) {
      x2[(row + 0) * (N / 2) + col + 0] =
          __floats2bfloat162_rn(values[0].x, values[0].y);
      x2[(row + 0) * (N / 2) + col + 4] =
          __floats2bfloat162_rn(values[2].x, values[2].y);
      x2[(row + 8) * (N / 2) + col + 0] =
          __floats2bfloat162_rn(values[1].x, values[1].y);
      x2[(row + 8) * (N / 2) + col + 4] =
          __floats2bfloat162_rn(values[3].x, values[3].y);
    }
  }

  /**
   * Row-bounds-checked variant of store_global.
   *
   * Rows at or beyond max_rows are skipped; columns are always stored
   * (scalar stores here, so no alignment requirement on x beyond U).
   */
  template <typename U>
  __device__ inline void store_global_safe(U* x, int N, int max_rows) {
    const int laneid = threadIdx.x % 32;
    const int row = laneid / 4;
    const int col = laneid % 4;
    if (row < max_rows) {
      x[(row + 0) * N + 2 * col + 0] = static_cast<U>(values[0].x);
      x[(row + 0) * N + 2 * col + 1] = static_cast<U>(values[0].y);
      x[(row + 0) * N + 2 * col + 8] = static_cast<U>(values[2].x);
      x[(row + 0) * N + 2 * col + 9] = static_cast<U>(values[2].y);
    }
    if (row + 8 < max_rows) {
      x[(row + 8) * N + 2 * col + 0] = static_cast<U>(values[1].x);
      x[(row + 8) * N + 2 * col + 1] = static_cast<U>(values[1].y);
      x[(row + 8) * N + 2 * col + 8] = static_cast<U>(values[3].x);
      x[(row + 8) * N + 2 * col + 9] = static_cast<U>(values[3].y);
    }
  }
};
124
+
125
/**
 * A container of TILES_Y x TILES_X basic 16x16 register tiles.
 *
 * Provides fill, shared-memory loads and global-memory stores that simply
 * delegate to every Tile16x16 sub-tile in turn.
 */
template <typename T, int ROWS_, int COLS_>
struct RegisterTile {
  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  static constexpr int TILES_X = COLS / 16;
  static constexpr int TILES_Y = ROWS / 16;

  Tile16x16<T> data[TILES_X * TILES_Y];

  // Set every value of every sub-tile to v.
  __device__ inline void fill(T v) {
    MLX_UNROLL
    for (int ty = 0; ty < TILES_Y; ty++) {
      MLX_UNROLL
      for (int tx = 0; tx < TILES_X; tx++) {
        data[ty * TILES_X + tx].fill(v);
      }
    }
  }

  // Load the whole register tile from the shared tile, starting at
  // (row, col). Address computation goes through Tile::loc so swizzled
  // layouts are handled transparently.
  template <typename Tile>
  __device__ __forceinline__ void
  load(Tile& tile, uint32_t base_address, int row, int col) {
    MLX_UNROLL
    for (int ty = 0; ty < TILES_Y; ty++) {
      MLX_UNROLL
      for (int tx = 0; tx < TILES_X; tx++) {
        const uint32_t addr =
            tile.loc(base_address, row + ty * 16, col + tx * 16);
        data[ty * TILES_X + tx].load(addr);
      }
    }
  }

  // Same as load but the per-sub-tile load is performed by the callable
  // f(sub_tile, tile, base_address, row, col).
  template <typename Tile, typename F>
  __device__ __forceinline__ void
  load(Tile& tile, F f, uint32_t base_address, int row, int col) {
    MLX_UNROLL
    for (int ty = 0; ty < TILES_Y; ty++) {
      MLX_UNROLL
      for (int tx = 0; tx < TILES_X; tx++) {
        f(data[ty * TILES_X + tx],
          tile,
          base_address,
          row + ty * 16,
          col + tx * 16);
      }
    }
  }

  // Store the register tile to global memory starting at (row, col) of x,
  // where x has leading dimension N.
  template <typename U>
  __device__ inline void store_global(U* x, int N, int row, int col) {
    MLX_UNROLL
    for (int ty = 0; ty < TILES_Y; ty++) {
      MLX_UNROLL
      for (int tx = 0; tx < TILES_X; tx++) {
        U* dst = x + (row + ty * 16) * N + col + tx * 16;
        data[ty * TILES_X + tx].store_global(dst, N);
      }
    }
  }

  // Bounds-checked variant of store_global: rows at or beyond max_rows are
  // skipped by the sub-tiles.
  template <typename U>
  __device__ inline void
  store_global_safe(U* x, int N, int row, int col, int max_rows) {
    MLX_UNROLL
    for (int ty = 0; ty < TILES_Y; ty++) {
      MLX_UNROLL
      for (int tx = 0; tx < TILES_X; tx++) {
        U* dst = x + (row + ty * 16) * N + col + tx * 16;
        data[ty * TILES_X + tx].store_global_safe(
            dst, N, max_rows - row - ty * 16);
      }
    }
  }
};
204
+
205
// NOTE(review): a second definition of
// `template <typename T, int ROWS_, int COLS_> struct RegisterTile`
// appeared here. It redeclared the exact same primary template as the
// RegisterTile defined earlier in this file — a C++ redefinition error —
// while providing only a subset of its methods (no functor-based load and
// no store_global_safe). The duplicate has been removed; all users resolve
// to the full definition above.
255
+
256
// A ROWS x COLS tile resident in shared memory, with a swizzled layout to
// avoid bank conflicts for the ldmatrix/cp.async access patterns used here.
template <typename T, int ROWS_, int COLS_>
struct SharedTile {
  static constexpr int ROWS = ROWS_;
  static constexpr int COLS = COLS_;
  static constexpr int TILES_X = COLS / 16;
  static constexpr int TILES_Y = ROWS / 16;
  static constexpr int NUMEL = ROWS * COLS;

  // Swizzle taken from ThunderKittens. Should be changed when we switch to
  // cute Layouts.
  //
  // See includes/types/shared/st.cuh
  //
  // I do feel that it is too math heavy and can be improved. Also the math is
  // done every time although the addresses don't change from load to load. I
  // guess we are expecting the compiler to figure that out.
  //
  // Swizzle width in bytes, chosen from the element size and the tile width;
  // 0 disables swizzling (plain row-major indexing below).
  static constexpr int swizzle_bytes =
      (sizeof(T) == 2 ? (TILES_X % 4 == 0 ? 128 : (TILES_X % 2 == 0 ? 64 : 32))
                      : (sizeof(T) == 4 ? (TILES_X % 2 == 0 ? 128 : 64) : 0));

  T data[ROWS * COLS];

  // Shared-memory address of the tile start, for use with the `loc` method
  // and PTX instructions that take .shared addresses.
  __device__ inline uint32_t base_addr() const {
    return __cvta_generic_to_shared(&data[0]);
  }

  // Return a pointer to the element at (row, col) using the swizzle.
  __device__ static inline T* ptr(T* ptr, int row, int col) {
    if constexpr (swizzle_bytes > 0) {
      static constexpr int swizzle_repeat = swizzle_bytes * 8;
      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
      // Column sub-tiles of width subtile_cols are stored one after the
      // other; within one, the address is XOR-perturbed by a row-dependent
      // pattern.
      const int outer_idx = col / subtile_cols;
      const uint64_t addr =
          (uint64_t)(&ptr
                         [outer_idx * ROWS * subtile_cols + row * subtile_cols +
                          col % subtile_cols]);
      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
      return (T*)(addr ^ swizzle);
    } else {
      return ptr + row * COLS + col;
    }
  }

  // Return the location of the element at (row, col) using the swizzle.
  // Same math as `ptr` but operating on a .shared byte address (as produced
  // by base_addr()).
  __device__ static inline uint32_t loc(uint32_t ptr, int row, int col) {
    if constexpr (swizzle_bytes > 0) {
      static constexpr int swizzle_repeat = swizzle_bytes * 8;
      static constexpr int subtile_cols = swizzle_bytes / sizeof(T);
      const int outer_idx = col / subtile_cols;
      const uint32_t addr = ptr +
          sizeof(T) *
              (outer_idx * ROWS * subtile_cols + row * subtile_cols +
               col % subtile_cols);
      const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
      return (addr ^ swizzle);
    } else {
      return ptr + sizeof(T) * (row * COLS + col);
    }
  }

  // Convenience functions to edit elements going through the swizzle.
  __device__ inline T& operator()(int row, int col) {
    return *ptr(data, row, col);
  }
  // Vectorized stores (16 / 8 / 4 bytes). The caller must ensure (row, col)
  // maps to a suitably aligned location for the chosen width.
  __device__ inline void store(float4& v, int row, int col) {
    *(reinterpret_cast<float4*>(ptr(data, row, col))) = v;
  }
  __device__ inline void store(float2& v, int row, int col) {
    *(reinterpret_cast<float2*>(ptr(data, row, col))) = v;
  }
  __device__ inline void store(float& v, int row, int col) {
    *(reinterpret_cast<float*>(ptr(data, row, col))) = v;
  }
  // Store N elements of T, picking the widest matching vector store;
  // falls back to element-wise stores for odd sizes.
  template <int N>
  __device__ inline void store(T (&v)[N], int row, int col) {
    if constexpr (sizeof(T) * N == 4) {
      store(*(reinterpret_cast<float*>(&v[0])), row, col);
    } else if constexpr (sizeof(T) * N == 8) {
      store(*(reinterpret_cast<float2*>(&v[0])), row, col);
    } else if constexpr (sizeof(T) * N == 16) {
      store(*(reinterpret_cast<float4*>(&v[0])), row, col);
    } else {
      MLX_UNROLL
      for (int i = 0; i < N; i++) {
        *ptr(data, row, col + i) = v[i];
      }
    }
  }
};
345
+
346
/**
 * Load the tile from global memory by loading 16 bytes at a time and storing
 * them immediately to shared memory.
 *
 * Can also be used as a fallback for architectures before sm_80.
 *
 * Each thread owns a fixed (row, column-chunk) lane of the tile and walks
 * down the rows in steps of STEP_ROWS.
 */
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load(Tile& tile, const T* x, int N) {
  constexpr int NUM_THREADS = NUM_WARPS * 32;
  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;

  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
  const int col = threadIdx.x % NUM_LOADS_PER_ROW;

  const T* src = x + row * N + col * ELEMENTS_PER_LOAD;

  MLX_UNROLL
  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
    float4 v = *(reinterpret_cast<const float4*>(src + i * STEP_ROWS * N));
    tile.store(v, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
  }
}
373
+
374
/**
 * The asynchronous equivalent of load.
 *
 * Loads the tile from global memory by submitting a bunch of async copy
 * instructions. The copy won't start until commit is called and we don't have
 * a guarantee it will finish until wait is called.
 *
 * It should be used as follows
 *
 *   load(...)
 *   load(...)
 *   cp_async_commit()
 *   do_other_stuff()
 *   cp_async_wait_all()
 *   do_stuff_with_shmem()
 */
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void
load_async(Tile& tile, uint32_t base_address, const T* x, int N) {
  constexpr int NUM_THREADS = NUM_WARPS * 32;
  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;

  // Fixed (row, column-chunk) lane per thread; rows advance by STEP_ROWS.
  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
  const int col = threadIdx.x % NUM_LOADS_PER_ROW;

  const T* src = x + row * N + col * ELEMENTS_PER_LOAD;

  MLX_UNROLL
  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
    const int r = row + i * STEP_ROWS;
    cp_async<16>(
        tile.loc(base_address, r, col * ELEMENTS_PER_LOAD),
        src + i * STEP_ROWS * N);
  }
}
412
+
413
/**
 * Same as load_async but checks if we can load the row.
 *
 * Out-of-bounds rows are zero-filled with a synchronous shared-memory store
 * instead of an async copy; since in-flight cp.async copies target different
 * addresses, mixing the two here is safe.
 *
 * NOTE: It should be changed to use a predicated cp async instead.
 */
template <int NUM_WARPS, typename T, typename Tile>
__device__ inline void load_async_safe(
    Tile& tile,
    uint32_t base_address,
    const T* x,
    int N,
    int max_rows) {
  constexpr int NUM_THREADS = NUM_WARPS * 32;
  constexpr int ELEMENTS_PER_LOAD = sizeof(float4) / sizeof(T);
  constexpr int NUM_LOADS = Tile::NUMEL / ELEMENTS_PER_LOAD;
  constexpr int NUM_LOADS_PER_THREAD = NUM_LOADS / NUM_THREADS;
  constexpr int NUM_LOADS_PER_ROW = Tile::COLS / ELEMENTS_PER_LOAD;
  constexpr int STEP_ROWS = NUM_THREADS / NUM_LOADS_PER_ROW;

  // Fixed (row, column-chunk) lane per thread; rows advance by STEP_ROWS.
  const int row = threadIdx.x / NUM_LOADS_PER_ROW;
  const int col = threadIdx.x % NUM_LOADS_PER_ROW;

  x += row * N + col * ELEMENTS_PER_LOAD;

  MLX_UNROLL
  for (int i = 0; i < NUM_LOADS_PER_THREAD; i++) {
    if (row + i * STEP_ROWS < max_rows) {
      cp_async<16>(
          tile.loc(base_address, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD),
          x + i * STEP_ROWS * N);
    } else {
      // Zero the shared-memory destination so later reads see zeros.
      float4 tmp = {0, 0, 0, 0};
      tile.store(tmp, row + i * STEP_ROWS, col * ELEMENTS_PER_LOAD);
    }
  }
}
449
+
450
+ } // namespace mlx::core::cu