mlx 0.30.7

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (599)
  1. checksums.yaml +7 -0
  2. data/ext/mlx/extconf.rb +94 -0
  3. data/ext/mlx/native.cpp +8027 -0
  4. data/lib/mlx/core.rb +1678 -0
  5. data/lib/mlx/distributed_utils/common.rb +116 -0
  6. data/lib/mlx/distributed_utils/config.rb +600 -0
  7. data/lib/mlx/distributed_utils/launch.rb +490 -0
  8. data/lib/mlx/extension.rb +24 -0
  9. data/lib/mlx/nn/base.rb +388 -0
  10. data/lib/mlx/nn/init.rb +140 -0
  11. data/lib/mlx/nn/layers/activations.rb +336 -0
  12. data/lib/mlx/nn/layers/base.rb +6 -0
  13. data/lib/mlx/nn/layers/containers.rb +20 -0
  14. data/lib/mlx/nn/layers/convolution.rb +120 -0
  15. data/lib/mlx/nn/layers/convolution_transpose.rb +114 -0
  16. data/lib/mlx/nn/layers/distributed.rb +309 -0
  17. data/lib/mlx/nn/layers/dropout.rb +75 -0
  18. data/lib/mlx/nn/layers/embedding.rb +28 -0
  19. data/lib/mlx/nn/layers/linear.rb +79 -0
  20. data/lib/mlx/nn/layers/normalization.rb +216 -0
  21. data/lib/mlx/nn/layers/pooling.rb +167 -0
  22. data/lib/mlx/nn/layers/positional_encoding.rb +126 -0
  23. data/lib/mlx/nn/layers/quantized.rb +215 -0
  24. data/lib/mlx/nn/layers/recurrent.rb +135 -0
  25. data/lib/mlx/nn/layers/transformer.rb +330 -0
  26. data/lib/mlx/nn/layers/upsample.rb +97 -0
  27. data/lib/mlx/nn/layers.rb +18 -0
  28. data/lib/mlx/nn/losses.rb +251 -0
  29. data/lib/mlx/nn/utils.rb +167 -0
  30. data/lib/mlx/nn.rb +12 -0
  31. data/lib/mlx/optimizers/optimizers.rb +808 -0
  32. data/lib/mlx/optimizers/schedulers.rb +62 -0
  33. data/lib/mlx/optimizers.rb +9 -0
  34. data/lib/mlx/utils.rb +171 -0
  35. data/lib/mlx/version.rb +5 -0
  36. data/lib/mlx.rb +64 -0
  37. data/mlx/CMakeLists.txt +449 -0
  38. data/mlx/cmake/FindCUDNN.cmake +177 -0
  39. data/mlx/cmake/FindNCCL.cmake +54 -0
  40. data/mlx/cmake/Findnvpl.cmake +3 -0
  41. data/mlx/cmake/extension.cmake +50 -0
  42. data/mlx/mlx/3rdparty/.clang-format +2 -0
  43. data/mlx/mlx/3rdparty/pocketfft.h +3581 -0
  44. data/mlx/mlx/CMakeLists.txt +107 -0
  45. data/mlx/mlx/allocator.h +75 -0
  46. data/mlx/mlx/api.h +29 -0
  47. data/mlx/mlx/array.cpp +354 -0
  48. data/mlx/mlx/array.h +647 -0
  49. data/mlx/mlx/backend/common/CMakeLists.txt +9 -0
  50. data/mlx/mlx/backend/common/binary.h +97 -0
  51. data/mlx/mlx/backend/common/broadcasting.cpp +24 -0
  52. data/mlx/mlx/backend/common/broadcasting.h +11 -0
  53. data/mlx/mlx/backend/common/buffer_cache.h +158 -0
  54. data/mlx/mlx/backend/common/common.cpp +305 -0
  55. data/mlx/mlx/backend/common/compiled.cpp +243 -0
  56. data/mlx/mlx/backend/common/compiled.h +77 -0
  57. data/mlx/mlx/backend/common/copy.h +50 -0
  58. data/mlx/mlx/backend/common/hadamard.h +109 -0
  59. data/mlx/mlx/backend/common/load.cpp +57 -0
  60. data/mlx/mlx/backend/common/matmul.h +67 -0
  61. data/mlx/mlx/backend/common/reduce.cpp +154 -0
  62. data/mlx/mlx/backend/common/reduce.h +59 -0
  63. data/mlx/mlx/backend/common/slicing.cpp +71 -0
  64. data/mlx/mlx/backend/common/slicing.h +20 -0
  65. data/mlx/mlx/backend/common/ternary.h +85 -0
  66. data/mlx/mlx/backend/common/unary.h +29 -0
  67. data/mlx/mlx/backend/common/utils.cpp +231 -0
  68. data/mlx/mlx/backend/common/utils.h +205 -0
  69. data/mlx/mlx/backend/cpu/CMakeLists.txt +88 -0
  70. data/mlx/mlx/backend/cpu/arange.h +28 -0
  71. data/mlx/mlx/backend/cpu/arg_reduce.cpp +124 -0
  72. data/mlx/mlx/backend/cpu/binary.cpp +269 -0
  73. data/mlx/mlx/backend/cpu/binary.h +517 -0
  74. data/mlx/mlx/backend/cpu/binary_ops.h +98 -0
  75. data/mlx/mlx/backend/cpu/binary_two.h +166 -0
  76. data/mlx/mlx/backend/cpu/cholesky.cpp +85 -0
  77. data/mlx/mlx/backend/cpu/compiled.cpp +357 -0
  78. data/mlx/mlx/backend/cpu/compiled_preamble.h +12 -0
  79. data/mlx/mlx/backend/cpu/conv.cpp +1351 -0
  80. data/mlx/mlx/backend/cpu/copy.cpp +386 -0
  81. data/mlx/mlx/backend/cpu/copy.h +36 -0
  82. data/mlx/mlx/backend/cpu/device_info.cpp +113 -0
  83. data/mlx/mlx/backend/cpu/device_info.h +28 -0
  84. data/mlx/mlx/backend/cpu/distributed.cpp +103 -0
  85. data/mlx/mlx/backend/cpu/eig.cpp +281 -0
  86. data/mlx/mlx/backend/cpu/eigh.cpp +241 -0
  87. data/mlx/mlx/backend/cpu/encoder.cpp +16 -0
  88. data/mlx/mlx/backend/cpu/encoder.h +67 -0
  89. data/mlx/mlx/backend/cpu/eval.cpp +40 -0
  90. data/mlx/mlx/backend/cpu/eval.h +12 -0
  91. data/mlx/mlx/backend/cpu/fft.cpp +120 -0
  92. data/mlx/mlx/backend/cpu/gemm.h +26 -0
  93. data/mlx/mlx/backend/cpu/gemms/bnns.cpp +214 -0
  94. data/mlx/mlx/backend/cpu/gemms/cblas.cpp +134 -0
  95. data/mlx/mlx/backend/cpu/gemms/simd_bf16.cpp +45 -0
  96. data/mlx/mlx/backend/cpu/gemms/simd_fp16.cpp +45 -0
  97. data/mlx/mlx/backend/cpu/gemms/simd_gemm.h +139 -0
  98. data/mlx/mlx/backend/cpu/hadamard.cpp +121 -0
  99. data/mlx/mlx/backend/cpu/indexing.cpp +854 -0
  100. data/mlx/mlx/backend/cpu/inverse.cpp +160 -0
  101. data/mlx/mlx/backend/cpu/jit_compiler.cpp +166 -0
  102. data/mlx/mlx/backend/cpu/jit_compiler.h +20 -0
  103. data/mlx/mlx/backend/cpu/lapack.h +80 -0
  104. data/mlx/mlx/backend/cpu/logsumexp.cpp +139 -0
  105. data/mlx/mlx/backend/cpu/luf.cpp +120 -0
  106. data/mlx/mlx/backend/cpu/make_compiled_preamble.ps1 +38 -0
  107. data/mlx/mlx/backend/cpu/make_compiled_preamble.sh +41 -0
  108. data/mlx/mlx/backend/cpu/masked_mm.cpp +608 -0
  109. data/mlx/mlx/backend/cpu/matmul.cpp +166 -0
  110. data/mlx/mlx/backend/cpu/primitives.cpp +478 -0
  111. data/mlx/mlx/backend/cpu/qrf.cpp +147 -0
  112. data/mlx/mlx/backend/cpu/quantized.cpp +1370 -0
  113. data/mlx/mlx/backend/cpu/reduce.cpp +587 -0
  114. data/mlx/mlx/backend/cpu/scan.cpp +338 -0
  115. data/mlx/mlx/backend/cpu/select.cpp +95 -0
  116. data/mlx/mlx/backend/cpu/simd/accelerate_fp16_simd.h +56 -0
  117. data/mlx/mlx/backend/cpu/simd/accelerate_simd.h +329 -0
  118. data/mlx/mlx/backend/cpu/simd/base_simd.h +319 -0
  119. data/mlx/mlx/backend/cpu/simd/math.h +193 -0
  120. data/mlx/mlx/backend/cpu/simd/neon_fp16_simd.h +212 -0
  121. data/mlx/mlx/backend/cpu/simd/simd.h +4 -0
  122. data/mlx/mlx/backend/cpu/simd/type.h +11 -0
  123. data/mlx/mlx/backend/cpu/slicing.h +21 -0
  124. data/mlx/mlx/backend/cpu/softmax.cpp +170 -0
  125. data/mlx/mlx/backend/cpu/sort.cpp +481 -0
  126. data/mlx/mlx/backend/cpu/svd.cpp +289 -0
  127. data/mlx/mlx/backend/cpu/ternary.h +154 -0
  128. data/mlx/mlx/backend/cpu/threefry.cpp +31 -0
  129. data/mlx/mlx/backend/cpu/threefry.h +21 -0
  130. data/mlx/mlx/backend/cpu/unary.cpp +238 -0
  131. data/mlx/mlx/backend/cpu/unary.h +281 -0
  132. data/mlx/mlx/backend/cpu/unary_ops.h +175 -0
  133. data/mlx/mlx/backend/cuda/CMakeLists.txt +265 -0
  134. data/mlx/mlx/backend/cuda/allocator.cpp +451 -0
  135. data/mlx/mlx/backend/cuda/allocator.h +94 -0
  136. data/mlx/mlx/backend/cuda/arange.cu +68 -0
  137. data/mlx/mlx/backend/cuda/arg_reduce.cu +189 -0
  138. data/mlx/mlx/backend/cuda/bin2h.cmake +150 -0
  139. data/mlx/mlx/backend/cuda/binary/CMakeLists.txt +21 -0
  140. data/mlx/mlx/backend/cuda/binary/add.cu +7 -0
  141. data/mlx/mlx/backend/cuda/binary/arctan2.cu +7 -0
  142. data/mlx/mlx/backend/cuda/binary/binary.cuh +383 -0
  143. data/mlx/mlx/backend/cuda/binary/bitwise_binary.cu +27 -0
  144. data/mlx/mlx/backend/cuda/binary/divide.cu +7 -0
  145. data/mlx/mlx/backend/cuda/binary/equal.cu +15 -0
  146. data/mlx/mlx/backend/cuda/binary/greater.cu +7 -0
  147. data/mlx/mlx/backend/cuda/binary/greater_equal.cu +7 -0
  148. data/mlx/mlx/backend/cuda/binary/less.cu +7 -0
  149. data/mlx/mlx/backend/cuda/binary/less_equal.cu +7 -0
  150. data/mlx/mlx/backend/cuda/binary/log_add_exp.cu +7 -0
  151. data/mlx/mlx/backend/cuda/binary/logical_and.cu +7 -0
  152. data/mlx/mlx/backend/cuda/binary/logical_or.cu +7 -0
  153. data/mlx/mlx/backend/cuda/binary/maximum.cu +7 -0
  154. data/mlx/mlx/backend/cuda/binary/minimum.cu +7 -0
  155. data/mlx/mlx/backend/cuda/binary/multiply.cu +7 -0
  156. data/mlx/mlx/backend/cuda/binary/not_equal.cu +7 -0
  157. data/mlx/mlx/backend/cuda/binary/power.cu +7 -0
  158. data/mlx/mlx/backend/cuda/binary/remainder.cu +7 -0
  159. data/mlx/mlx/backend/cuda/binary/subtract.cu +7 -0
  160. data/mlx/mlx/backend/cuda/binary_two.cu +412 -0
  161. data/mlx/mlx/backend/cuda/compiled.cpp +357 -0
  162. data/mlx/mlx/backend/cuda/conv/conv.h +126 -0
  163. data/mlx/mlx/backend/cuda/conv/gemm_conv.cu +217 -0
  164. data/mlx/mlx/backend/cuda/conv/gemm_grouped_conv.cu +231 -0
  165. data/mlx/mlx/backend/cuda/conv.cpp +403 -0
  166. data/mlx/mlx/backend/cuda/copy/copy.cuh +55 -0
  167. data/mlx/mlx/backend/cuda/copy/copy_contiguous.cu +88 -0
  168. data/mlx/mlx/backend/cuda/copy/copy_general.cu +171 -0
  169. data/mlx/mlx/backend/cuda/copy/copy_general_dynamic.cu +118 -0
  170. data/mlx/mlx/backend/cuda/copy/copy_general_input.cu +229 -0
  171. data/mlx/mlx/backend/cuda/copy.cu +132 -0
  172. data/mlx/mlx/backend/cuda/cublas_utils.cpp +222 -0
  173. data/mlx/mlx/backend/cuda/cublas_utils.h +95 -0
  174. data/mlx/mlx/backend/cuda/cuda.h +21 -0
  175. data/mlx/mlx/backend/cuda/cuda_utils.h +90 -0
  176. data/mlx/mlx/backend/cuda/cudnn_utils.cpp +133 -0
  177. data/mlx/mlx/backend/cuda/cudnn_utils.h +187 -0
  178. data/mlx/mlx/backend/cuda/custom_kernel.cpp +379 -0
  179. data/mlx/mlx/backend/cuda/cutlass_utils.cuh +46 -0
  180. data/mlx/mlx/backend/cuda/delayload.cpp +80 -0
  181. data/mlx/mlx/backend/cuda/device/atomic_ops.cuh +63 -0
  182. data/mlx/mlx/backend/cuda/device/binary_ops.cuh +300 -0
  183. data/mlx/mlx/backend/cuda/device/cast_op.cuh +118 -0
  184. data/mlx/mlx/backend/cuda/device/complex.cuh +60 -0
  185. data/mlx/mlx/backend/cuda/device/config.h +12 -0
  186. data/mlx/mlx/backend/cuda/device/fp16_math.cuh +96 -0
  187. data/mlx/mlx/backend/cuda/device/gather.cuh +53 -0
  188. data/mlx/mlx/backend/cuda/device/gather_axis.cuh +65 -0
  189. data/mlx/mlx/backend/cuda/device/indexing.cuh +30 -0
  190. data/mlx/mlx/backend/cuda/device/scatter.cuh +68 -0
  191. data/mlx/mlx/backend/cuda/device/scatter_axis.cuh +67 -0
  192. data/mlx/mlx/backend/cuda/device/scatter_ops.cuh +44 -0
  193. data/mlx/mlx/backend/cuda/device/ternary_ops.cuh +13 -0
  194. data/mlx/mlx/backend/cuda/device/unary_ops.cuh +350 -0
  195. data/mlx/mlx/backend/cuda/device/utils.cuh +464 -0
  196. data/mlx/mlx/backend/cuda/device.cpp +522 -0
  197. data/mlx/mlx/backend/cuda/device.h +195 -0
  198. data/mlx/mlx/backend/cuda/device_info.cpp +232 -0
  199. data/mlx/mlx/backend/cuda/distributed.cu +121 -0
  200. data/mlx/mlx/backend/cuda/eval.cpp +66 -0
  201. data/mlx/mlx/backend/cuda/event.cu +415 -0
  202. data/mlx/mlx/backend/cuda/event.h +79 -0
  203. data/mlx/mlx/backend/cuda/fence.cpp +42 -0
  204. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.cpp +233 -0
  205. data/mlx/mlx/backend/cuda/gemms/cublas_gemm.h +114 -0
  206. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_0.cpp +77 -0
  207. data/mlx/mlx/backend/cuda/gemms/cublas_gemm_batched_12_9.cu +329 -0
  208. data/mlx/mlx/backend/cuda/gemms/gemv.cu +327 -0
  209. data/mlx/mlx/backend/cuda/gemms/gemv.h +34 -0
  210. data/mlx/mlx/backend/cuda/gemms/grouped_gemm.h +25 -0
  211. data/mlx/mlx/backend/cuda/gemms/grouped_gemm_unaligned.cu +358 -0
  212. data/mlx/mlx/backend/cuda/indexing.cpp +434 -0
  213. data/mlx/mlx/backend/cuda/jit_module.cpp +443 -0
  214. data/mlx/mlx/backend/cuda/jit_module.h +120 -0
  215. data/mlx/mlx/backend/cuda/kernel_utils.cu +52 -0
  216. data/mlx/mlx/backend/cuda/kernel_utils.cuh +148 -0
  217. data/mlx/mlx/backend/cuda/layer_norm.cu +417 -0
  218. data/mlx/mlx/backend/cuda/load.cpp +60 -0
  219. data/mlx/mlx/backend/cuda/logsumexp.cu +161 -0
  220. data/mlx/mlx/backend/cuda/lru_cache.h +190 -0
  221. data/mlx/mlx/backend/cuda/matmul.cpp +373 -0
  222. data/mlx/mlx/backend/cuda/no_cuda.cpp +47 -0
  223. data/mlx/mlx/backend/cuda/primitives.cpp +46 -0
  224. data/mlx/mlx/backend/cuda/quantized/affine_quantize.cu +329 -0
  225. data/mlx/mlx/backend/cuda/quantized/convert_fp8.cu +19 -0
  226. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.cpp +206 -0
  227. data/mlx/mlx/backend/cuda/quantized/cublas_qqmm.h +88 -0
  228. data/mlx/mlx/backend/cuda/quantized/cuda_fp4.h +100 -0
  229. data/mlx/mlx/backend/cuda/quantized/fp_quantize.cu +496 -0
  230. data/mlx/mlx/backend/cuda/quantized/mxfp8_quantize.cuh +32 -0
  231. data/mlx/mlx/backend/cuda/quantized/no_qqmm_impl.cpp +26 -0
  232. data/mlx/mlx/backend/cuda/quantized/nvfp4_quantize.cuh +334 -0
  233. data/mlx/mlx/backend/cuda/quantized/qmv.cu +304 -0
  234. data/mlx/mlx/backend/cuda/quantized/qmv.h +21 -0
  235. data/mlx/mlx/backend/cuda/quantized/qqmm.cpp +158 -0
  236. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.cpp +50 -0
  237. data/mlx/mlx/backend/cuda/quantized/qqmm_impl.h +26 -0
  238. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.cu +227 -0
  239. data/mlx/mlx/backend/cuda/quantized/qqmm_utils.h +30 -0
  240. data/mlx/mlx/backend/cuda/quantized/quantized.cpp +85 -0
  241. data/mlx/mlx/backend/cuda/quantized/quantized.h +53 -0
  242. data/mlx/mlx/backend/cuda/quantized/quantized_utils.cuh +88 -0
  243. data/mlx/mlx/backend/cuda/quantized/quantized_utils.h +50 -0
  244. data/mlx/mlx/backend/cuda/random.cu +202 -0
  245. data/mlx/mlx/backend/cuda/reduce/all_reduce.cu +159 -0
  246. data/mlx/mlx/backend/cuda/reduce/col_reduce.cu +510 -0
  247. data/mlx/mlx/backend/cuda/reduce/init_reduce.cu +50 -0
  248. data/mlx/mlx/backend/cuda/reduce/reduce.cuh +71 -0
  249. data/mlx/mlx/backend/cuda/reduce/reduce_ops.cuh +211 -0
  250. data/mlx/mlx/backend/cuda/reduce/reduce_utils.cuh +145 -0
  251. data/mlx/mlx/backend/cuda/reduce/row_reduce.cu +361 -0
  252. data/mlx/mlx/backend/cuda/reduce.cu +73 -0
  253. data/mlx/mlx/backend/cuda/rms_norm.cu +536 -0
  254. data/mlx/mlx/backend/cuda/rope.cu +429 -0
  255. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cpp +681 -0
  256. data/mlx/mlx/backend/cuda/scaled_dot_product_attention.cu +796 -0
  257. data/mlx/mlx/backend/cuda/scan.cu +468 -0
  258. data/mlx/mlx/backend/cuda/slicing.cpp +111 -0
  259. data/mlx/mlx/backend/cuda/softmax.cu +162 -0
  260. data/mlx/mlx/backend/cuda/sort.cu +1076 -0
  261. data/mlx/mlx/backend/cuda/steel/defines.cuh +9 -0
  262. data/mlx/mlx/backend/cuda/steel/gemm.cuh +101 -0
  263. data/mlx/mlx/backend/cuda/steel/mma.cuh +117 -0
  264. data/mlx/mlx/backend/cuda/steel/tiles.cuh +450 -0
  265. data/mlx/mlx/backend/cuda/steel/utils.cuh +89 -0
  266. data/mlx/mlx/backend/cuda/ternary.cu +271 -0
  267. data/mlx/mlx/backend/cuda/unary/CMakeLists.txt +34 -0
  268. data/mlx/mlx/backend/cuda/unary/abs.cu +7 -0
  269. data/mlx/mlx/backend/cuda/unary/arccos.cu +7 -0
  270. data/mlx/mlx/backend/cuda/unary/arccosh.cu +7 -0
  271. data/mlx/mlx/backend/cuda/unary/arcsin.cu +7 -0
  272. data/mlx/mlx/backend/cuda/unary/arcsinh.cu +7 -0
  273. data/mlx/mlx/backend/cuda/unary/arctan.cu +7 -0
  274. data/mlx/mlx/backend/cuda/unary/arctanh.cu +7 -0
  275. data/mlx/mlx/backend/cuda/unary/bitwise_invert.cu +7 -0
  276. data/mlx/mlx/backend/cuda/unary/ceil.cu +7 -0
  277. data/mlx/mlx/backend/cuda/unary/conjugate.cu +7 -0
  278. data/mlx/mlx/backend/cuda/unary/cos.cu +7 -0
  279. data/mlx/mlx/backend/cuda/unary/cosh.cu +7 -0
  280. data/mlx/mlx/backend/cuda/unary/erf.cu +7 -0
  281. data/mlx/mlx/backend/cuda/unary/erf_inv.cu +7 -0
  282. data/mlx/mlx/backend/cuda/unary/exp.cu +7 -0
  283. data/mlx/mlx/backend/cuda/unary/expm1.cu +7 -0
  284. data/mlx/mlx/backend/cuda/unary/floor.cu +7 -0
  285. data/mlx/mlx/backend/cuda/unary/imag.cu +7 -0
  286. data/mlx/mlx/backend/cuda/unary/log.cu +21 -0
  287. data/mlx/mlx/backend/cuda/unary/log1p.cu +7 -0
  288. data/mlx/mlx/backend/cuda/unary/logical_not.cu +7 -0
  289. data/mlx/mlx/backend/cuda/unary/negative.cu +7 -0
  290. data/mlx/mlx/backend/cuda/unary/real.cu +7 -0
  291. data/mlx/mlx/backend/cuda/unary/round.cu +18 -0
  292. data/mlx/mlx/backend/cuda/unary/sigmoid.cu +7 -0
  293. data/mlx/mlx/backend/cuda/unary/sign.cu +7 -0
  294. data/mlx/mlx/backend/cuda/unary/sin.cu +7 -0
  295. data/mlx/mlx/backend/cuda/unary/sinh.cu +7 -0
  296. data/mlx/mlx/backend/cuda/unary/sqrt.cu +15 -0
  297. data/mlx/mlx/backend/cuda/unary/square.cu +7 -0
  298. data/mlx/mlx/backend/cuda/unary/tan.cu +7 -0
  299. data/mlx/mlx/backend/cuda/unary/tanh.cu +7 -0
  300. data/mlx/mlx/backend/cuda/unary/unary.cuh +224 -0
  301. data/mlx/mlx/backend/cuda/utils.cpp +116 -0
  302. data/mlx/mlx/backend/cuda/utils.h +49 -0
  303. data/mlx/mlx/backend/cuda/vector_types.cuh +48 -0
  304. data/mlx/mlx/backend/cuda/worker.cpp +79 -0
  305. data/mlx/mlx/backend/cuda/worker.h +55 -0
  306. data/mlx/mlx/backend/gpu/CMakeLists.txt +5 -0
  307. data/mlx/mlx/backend/gpu/copy.cpp +89 -0
  308. data/mlx/mlx/backend/gpu/copy.h +57 -0
  309. data/mlx/mlx/backend/gpu/device_info.h +36 -0
  310. data/mlx/mlx/backend/gpu/eval.h +18 -0
  311. data/mlx/mlx/backend/gpu/primitives.cpp +307 -0
  312. data/mlx/mlx/backend/gpu/slicing.cpp +44 -0
  313. data/mlx/mlx/backend/gpu/slicing.h +36 -0
  314. data/mlx/mlx/backend/metal/CMakeLists.txt +144 -0
  315. data/mlx/mlx/backend/metal/allocator.cpp +279 -0
  316. data/mlx/mlx/backend/metal/allocator.h +79 -0
  317. data/mlx/mlx/backend/metal/binary.cpp +257 -0
  318. data/mlx/mlx/backend/metal/binary.h +33 -0
  319. data/mlx/mlx/backend/metal/compiled.cpp +471 -0
  320. data/mlx/mlx/backend/metal/conv.cpp +1118 -0
  321. data/mlx/mlx/backend/metal/copy.cpp +235 -0
  322. data/mlx/mlx/backend/metal/custom_kernel.cpp +430 -0
  323. data/mlx/mlx/backend/metal/device.cpp +816 -0
  324. data/mlx/mlx/backend/metal/device.h +289 -0
  325. data/mlx/mlx/backend/metal/device_info.cpp +58 -0
  326. data/mlx/mlx/backend/metal/distributed.cpp +38 -0
  327. data/mlx/mlx/backend/metal/eval.cpp +97 -0
  328. data/mlx/mlx/backend/metal/event.cpp +62 -0
  329. data/mlx/mlx/backend/metal/fence.cpp +162 -0
  330. data/mlx/mlx/backend/metal/fft.cpp +807 -0
  331. data/mlx/mlx/backend/metal/hadamard.cpp +198 -0
  332. data/mlx/mlx/backend/metal/indexing.cpp +727 -0
  333. data/mlx/mlx/backend/metal/jit/includes.h +58 -0
  334. data/mlx/mlx/backend/metal/jit/indexing.h +76 -0
  335. data/mlx/mlx/backend/metal/jit_kernels.cpp +1118 -0
  336. data/mlx/mlx/backend/metal/kernels/CMakeLists.txt +193 -0
  337. data/mlx/mlx/backend/metal/kernels/arange.h +9 -0
  338. data/mlx/mlx/backend/metal/kernels/arange.metal +20 -0
  339. data/mlx/mlx/backend/metal/kernels/arg_reduce.metal +182 -0
  340. data/mlx/mlx/backend/metal/kernels/atomic.h +345 -0
  341. data/mlx/mlx/backend/metal/kernels/bf16.h +16 -0
  342. data/mlx/mlx/backend/metal/kernels/bf16_math.h +380 -0
  343. data/mlx/mlx/backend/metal/kernels/binary.h +199 -0
  344. data/mlx/mlx/backend/metal/kernels/binary.metal +109 -0
  345. data/mlx/mlx/backend/metal/kernels/binary_ops.h +330 -0
  346. data/mlx/mlx/backend/metal/kernels/binary_two.h +244 -0
  347. data/mlx/mlx/backend/metal/kernels/binary_two.metal +54 -0
  348. data/mlx/mlx/backend/metal/kernels/cexpf.h +134 -0
  349. data/mlx/mlx/backend/metal/kernels/complex.h +173 -0
  350. data/mlx/mlx/backend/metal/kernels/conv.metal +701 -0
  351. data/mlx/mlx/backend/metal/kernels/copy.h +276 -0
  352. data/mlx/mlx/backend/metal/kernels/copy.metal +75 -0
  353. data/mlx/mlx/backend/metal/kernels/defines.h +24 -0
  354. data/mlx/mlx/backend/metal/kernels/erf.h +69 -0
  355. data/mlx/mlx/backend/metal/kernels/expm1f.h +90 -0
  356. data/mlx/mlx/backend/metal/kernels/fence.metal +52 -0
  357. data/mlx/mlx/backend/metal/kernels/fft/radix.h +328 -0
  358. data/mlx/mlx/backend/metal/kernels/fft/readwrite.h +624 -0
  359. data/mlx/mlx/backend/metal/kernels/fft.h +486 -0
  360. data/mlx/mlx/backend/metal/kernels/fft.metal +67 -0
  361. data/mlx/mlx/backend/metal/kernels/fp4.h +48 -0
  362. data/mlx/mlx/backend/metal/kernels/fp8.h +80 -0
  363. data/mlx/mlx/backend/metal/kernels/fp_quantized.h +1850 -0
  364. data/mlx/mlx/backend/metal/kernels/fp_quantized.metal +153 -0
  365. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.h +1044 -0
  366. data/mlx/mlx/backend/metal/kernels/fp_quantized_nax.metal +79 -0
  367. data/mlx/mlx/backend/metal/kernels/gemv.metal +868 -0
  368. data/mlx/mlx/backend/metal/kernels/gemv_masked.h +827 -0
  369. data/mlx/mlx/backend/metal/kernels/gemv_masked.metal +76 -0
  370. data/mlx/mlx/backend/metal/kernels/hadamard.h +182 -0
  371. data/mlx/mlx/backend/metal/kernels/indexing/gather.h +51 -0
  372. data/mlx/mlx/backend/metal/kernels/indexing/gather_axis.h +44 -0
  373. data/mlx/mlx/backend/metal/kernels/indexing/gather_front.h +24 -0
  374. data/mlx/mlx/backend/metal/kernels/indexing/indexing.h +23 -0
  375. data/mlx/mlx/backend/metal/kernels/indexing/masked_scatter.h +41 -0
  376. data/mlx/mlx/backend/metal/kernels/indexing/scatter.h +59 -0
  377. data/mlx/mlx/backend/metal/kernels/indexing/scatter_axis.h +52 -0
  378. data/mlx/mlx/backend/metal/kernels/layer_norm.metal +433 -0
  379. data/mlx/mlx/backend/metal/kernels/logging.h +26 -0
  380. data/mlx/mlx/backend/metal/kernels/logsumexp.h +140 -0
  381. data/mlx/mlx/backend/metal/kernels/logsumexp.metal +18 -0
  382. data/mlx/mlx/backend/metal/kernels/quantized.h +2508 -0
  383. data/mlx/mlx/backend/metal/kernels/quantized.metal +144 -0
  384. data/mlx/mlx/backend/metal/kernels/quantized_nax.h +1705 -0
  385. data/mlx/mlx/backend/metal/kernels/quantized_nax.metal +106 -0
  386. data/mlx/mlx/backend/metal/kernels/quantized_utils.h +90 -0
  387. data/mlx/mlx/backend/metal/kernels/random.metal +103 -0
  388. data/mlx/mlx/backend/metal/kernels/reduce.h +5 -0
  389. data/mlx/mlx/backend/metal/kernels/reduce.metal +169 -0
  390. data/mlx/mlx/backend/metal/kernels/reduce_utils.h +6 -0
  391. data/mlx/mlx/backend/metal/kernels/reduction/ops.h +275 -0
  392. data/mlx/mlx/backend/metal/kernels/reduction/reduce_all.h +66 -0
  393. data/mlx/mlx/backend/metal/kernels/reduction/reduce_col.h +398 -0
  394. data/mlx/mlx/backend/metal/kernels/reduction/reduce_init.h +8 -0
  395. data/mlx/mlx/backend/metal/kernels/reduction/reduce_row.h +369 -0
  396. data/mlx/mlx/backend/metal/kernels/rms_norm.metal +391 -0
  397. data/mlx/mlx/backend/metal/kernels/rope.metal +229 -0
  398. data/mlx/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +44 -0
  399. data/mlx/mlx/backend/metal/kernels/scan.h +514 -0
  400. data/mlx/mlx/backend/metal/kernels/scan.metal +109 -0
  401. data/mlx/mlx/backend/metal/kernels/sdpa_vector.h +394 -0
  402. data/mlx/mlx/backend/metal/kernels/softmax.h +190 -0
  403. data/mlx/mlx/backend/metal/kernels/softmax.metal +24 -0
  404. data/mlx/mlx/backend/metal/kernels/sort.h +719 -0
  405. data/mlx/mlx/backend/metal/kernels/sort.metal +80 -0
  406. data/mlx/mlx/backend/metal/kernels/steel/attn/attn.h +296 -0
  407. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.h +471 -0
  408. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention.metal +27 -0
  409. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.h +481 -0
  410. data/mlx/mlx/backend/metal/kernels/steel/attn/kernels/steel_attention_nax.metal +28 -0
  411. data/mlx/mlx/backend/metal/kernels/steel/attn/loader.h +264 -0
  412. data/mlx/mlx/backend/metal/kernels/steel/attn/mma.h +750 -0
  413. data/mlx/mlx/backend/metal/kernels/steel/attn/nax.h +1076 -0
  414. data/mlx/mlx/backend/metal/kernels/steel/attn/params.h +44 -0
  415. data/mlx/mlx/backend/metal/kernels/steel/attn/transforms.h +71 -0
  416. data/mlx/mlx/backend/metal/kernels/steel/conv/conv.h +13 -0
  417. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.h +176 -0
  418. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv.metal +56 -0
  419. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h +225 -0
  420. data/mlx/mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.metal +47 -0
  421. data/mlx/mlx/backend/metal/kernels/steel/conv/loader.h +6 -0
  422. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_l.h +451 -0
  423. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_channel_n.h +319 -0
  424. data/mlx/mlx/backend/metal/kernels/steel/conv/loaders/loader_general.h +381 -0
  425. data/mlx/mlx/backend/metal/kernels/steel/conv/params.h +62 -0
  426. data/mlx/mlx/backend/metal/kernels/steel/defines.h +7 -0
  427. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm.h +295 -0
  428. data/mlx/mlx/backend/metal/kernels/steel/gemm/gemm_nax.h +157 -0
  429. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.h +346 -0
  430. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused.metal +34 -0
  431. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.h +219 -0
  432. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_fused_nax.metal +30 -0
  433. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.h +459 -0
  434. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather.metal +59 -0
  435. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.h +143 -0
  436. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_gather_nax.metal +37 -0
  437. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.h +719 -0
  438. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_masked.metal +76 -0
  439. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.h +266 -0
  440. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_segmented.metal +43 -0
  441. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.h +227 -0
  442. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk.metal +76 -0
  443. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.h +152 -0
  444. data/mlx/mlx/backend/metal/kernels/steel/gemm/kernels/steel_gemm_splitk_nax.metal +30 -0
  445. data/mlx/mlx/backend/metal/kernels/steel/gemm/loader.h +137 -0
  446. data/mlx/mlx/backend/metal/kernels/steel/gemm/mma.h +1146 -0
  447. data/mlx/mlx/backend/metal/kernels/steel/gemm/nax.h +1084 -0
  448. data/mlx/mlx/backend/metal/kernels/steel/gemm/params.h +65 -0
  449. data/mlx/mlx/backend/metal/kernels/steel/gemm/transforms.h +72 -0
  450. data/mlx/mlx/backend/metal/kernels/steel/utils/integral_constant.h +134 -0
  451. data/mlx/mlx/backend/metal/kernels/steel/utils/type_traits.h +55 -0
  452. data/mlx/mlx/backend/metal/kernels/steel/utils.h +42 -0
  453. data/mlx/mlx/backend/metal/kernels/ternary.h +145 -0
  454. data/mlx/mlx/backend/metal/kernels/ternary.metal +48 -0
  455. data/mlx/mlx/backend/metal/kernels/ternary_ops.h +10 -0
  456. data/mlx/mlx/backend/metal/kernels/unary.h +63 -0
  457. data/mlx/mlx/backend/metal/kernels/unary.metal +115 -0
  458. data/mlx/mlx/backend/metal/kernels/unary_ops.h +454 -0
  459. data/mlx/mlx/backend/metal/kernels/utils.h +445 -0
  460. data/mlx/mlx/backend/metal/kernels.h +375 -0
  461. data/mlx/mlx/backend/metal/logsumexp.cpp +95 -0
  462. data/mlx/mlx/backend/metal/make_compiled_preamble.sh +120 -0
  463. data/mlx/mlx/backend/metal/matmul.cpp +2572 -0
  464. data/mlx/mlx/backend/metal/matmul.h +144 -0
  465. data/mlx/mlx/backend/metal/metal.cpp +50 -0
  466. data/mlx/mlx/backend/metal/metal.h +25 -0
  467. data/mlx/mlx/backend/metal/no_metal.cpp +42 -0
  468. data/mlx/mlx/backend/metal/nojit_kernels.cpp +414 -0
  469. data/mlx/mlx/backend/metal/normalization.cpp +433 -0
  470. data/mlx/mlx/backend/metal/primitives.cpp +242 -0
  471. data/mlx/mlx/backend/metal/quantized.cpp +1651 -0
  472. data/mlx/mlx/backend/metal/reduce.cpp +1038 -0
  473. data/mlx/mlx/backend/metal/reduce.h +41 -0
  474. data/mlx/mlx/backend/metal/resident.cpp +100 -0
  475. data/mlx/mlx/backend/metal/resident.h +32 -0
  476. data/mlx/mlx/backend/metal/rope.cpp +165 -0
  477. data/mlx/mlx/backend/metal/scaled_dot_product_attention.cpp +798 -0
  478. data/mlx/mlx/backend/metal/scan.cpp +145 -0
  479. data/mlx/mlx/backend/metal/scan.h +17 -0
  480. data/mlx/mlx/backend/metal/slicing.cpp +99 -0
  481. data/mlx/mlx/backend/metal/softmax.cpp +87 -0
  482. data/mlx/mlx/backend/metal/sort.cpp +368 -0
  483. data/mlx/mlx/backend/metal/ternary.cpp +160 -0
  484. data/mlx/mlx/backend/metal/ternary.h +21 -0
  485. data/mlx/mlx/backend/metal/unary.cpp +161 -0
  486. data/mlx/mlx/backend/metal/unary.h +21 -0
  487. data/mlx/mlx/backend/metal/utils.cpp +77 -0
  488. data/mlx/mlx/backend/metal/utils.h +99 -0
  489. data/mlx/mlx/backend/no_cpu/CMakeLists.txt +7 -0
  490. data/mlx/mlx/backend/no_cpu/compiled.cpp +24 -0
  491. data/mlx/mlx/backend/no_cpu/device_info.cpp +22 -0
  492. data/mlx/mlx/backend/no_cpu/primitives.cpp +146 -0
  493. data/mlx/mlx/backend/no_gpu/CMakeLists.txt +8 -0
  494. data/mlx/mlx/backend/no_gpu/allocator.cpp +134 -0
  495. data/mlx/mlx/backend/no_gpu/apple_memory.h +16 -0
  496. data/mlx/mlx/backend/no_gpu/device_info.cpp +22 -0
  497. data/mlx/mlx/backend/no_gpu/eval.cpp +24 -0
  498. data/mlx/mlx/backend/no_gpu/event.cpp +53 -0
  499. data/mlx/mlx/backend/no_gpu/fence.cpp +54 -0
  500. data/mlx/mlx/backend/no_gpu/linux_memory.h +22 -0
  501. data/mlx/mlx/backend/no_gpu/primitives.cpp +185 -0
  502. data/mlx/mlx/compile.cpp +1243 -0
  503. data/mlx/mlx/compile.h +45 -0
  504. data/mlx/mlx/compile_impl.h +70 -0
  505. data/mlx/mlx/device.cpp +72 -0
  506. data/mlx/mlx/device.h +56 -0
  507. data/mlx/mlx/distributed/CMakeLists.txt +14 -0
  508. data/mlx/mlx/distributed/distributed.cpp +197 -0
  509. data/mlx/mlx/distributed/distributed.h +61 -0
  510. data/mlx/mlx/distributed/distributed_impl.h +59 -0
  511. data/mlx/mlx/distributed/jaccl/CMakeLists.txt +12 -0
  512. data/mlx/mlx/distributed/jaccl/jaccl.cpp +178 -0
  513. data/mlx/mlx/distributed/jaccl/jaccl.h +12 -0
  514. data/mlx/mlx/distributed/jaccl/mesh.cpp +451 -0
  515. data/mlx/mlx/distributed/jaccl/mesh.h +122 -0
  516. data/mlx/mlx/distributed/jaccl/no_jaccl.cpp +20 -0
  517. data/mlx/mlx/distributed/jaccl/ring.cpp +692 -0
  518. data/mlx/mlx/distributed/jaccl/ring.h +178 -0
  519. data/mlx/mlx/distributed/jaccl/utils.cpp +329 -0
  520. data/mlx/mlx/distributed/jaccl/utils.h +342 -0
  521. data/mlx/mlx/distributed/mpi/CMakeLists.txt +5 -0
  522. data/mlx/mlx/distributed/mpi/mpi.cpp +501 -0
  523. data/mlx/mlx/distributed/mpi/mpi.h +12 -0
  524. data/mlx/mlx/distributed/mpi/mpi_declarations.h +28 -0
  525. data/mlx/mlx/distributed/mpi/no_mpi.cpp +20 -0
  526. data/mlx/mlx/distributed/nccl/CMakeLists.txt +26 -0
  527. data/mlx/mlx/distributed/nccl/nccl.cpp +443 -0
  528. data/mlx/mlx/distributed/nccl/nccl.h +12 -0
  529. data/mlx/mlx/distributed/nccl/nccl_stub/CMakeLists.txt +1 -0
  530. data/mlx/mlx/distributed/nccl/nccl_stub/nccl_stubs.cpp +54 -0
  531. data/mlx/mlx/distributed/nccl/no_nccl.cpp +20 -0
  532. data/mlx/mlx/distributed/ops.cpp +186 -0
  533. data/mlx/mlx/distributed/ops.h +57 -0
  534. data/mlx/mlx/distributed/primitives.cpp +95 -0
  535. data/mlx/mlx/distributed/primitives.h +156 -0
  536. data/mlx/mlx/distributed/reduction_ops.h +38 -0
  537. data/mlx/mlx/distributed/ring/CMakeLists.txt +5 -0
  538. data/mlx/mlx/distributed/ring/no_ring.cpp +20 -0
  539. data/mlx/mlx/distributed/ring/ring.cpp +870 -0
  540. data/mlx/mlx/distributed/ring/ring.h +12 -0
  541. data/mlx/mlx/distributed/utils.cpp +206 -0
  542. data/mlx/mlx/distributed/utils.h +67 -0
  543. data/mlx/mlx/dtype.cpp +197 -0
  544. data/mlx/mlx/dtype.h +116 -0
  545. data/mlx/mlx/dtype_utils.cpp +42 -0
  546. data/mlx/mlx/dtype_utils.h +119 -0
  547. data/mlx/mlx/einsum.cpp +941 -0
  548. data/mlx/mlx/einsum.h +23 -0
  549. data/mlx/mlx/event.h +58 -0
  550. data/mlx/mlx/export.cpp +1130 -0
  551. data/mlx/mlx/export.h +137 -0
  552. data/mlx/mlx/export_impl.h +99 -0
  553. data/mlx/mlx/fast.cpp +941 -0
  554. data/mlx/mlx/fast.h +103 -0
  555. data/mlx/mlx/fast_primitives.h +427 -0
  556. data/mlx/mlx/fence.h +39 -0
  557. data/mlx/mlx/fft.cpp +262 -0
  558. data/mlx/mlx/fft.h +159 -0
  559. data/mlx/mlx/graph_utils.cpp +175 -0
  560. data/mlx/mlx/graph_utils.h +67 -0
  561. data/mlx/mlx/io/CMakeLists.txt +25 -0
  562. data/mlx/mlx/io/gguf.cpp +470 -0
  563. data/mlx/mlx/io/gguf.h +20 -0
  564. data/mlx/mlx/io/gguf_quants.cpp +164 -0
  565. data/mlx/mlx/io/load.cpp +397 -0
  566. data/mlx/mlx/io/load.h +175 -0
  567. data/mlx/mlx/io/no_gguf.cpp +20 -0
  568. data/mlx/mlx/io/no_safetensors.cpp +37 -0
  569. data/mlx/mlx/io/safetensors.cpp +234 -0
  570. data/mlx/mlx/io.h +61 -0
  571. data/mlx/mlx/linalg.cpp +708 -0
  572. data/mlx/mlx/linalg.h +115 -0
  573. data/mlx/mlx/memory.h +80 -0
  574. data/mlx/mlx/mlx.h +25 -0
  575. data/mlx/mlx/ops.cpp +6094 -0
  576. data/mlx/mlx/ops.h +1610 -0
  577. data/mlx/mlx/primitives.cpp +5850 -0
  578. data/mlx/mlx/primitives.h +2525 -0
  579. data/mlx/mlx/random.cpp +492 -0
  580. data/mlx/mlx/random.h +283 -0
  581. data/mlx/mlx/scheduler.cpp +73 -0
  582. data/mlx/mlx/scheduler.h +189 -0
  583. data/mlx/mlx/small_vector.h +540 -0
  584. data/mlx/mlx/stream.h +42 -0
  585. data/mlx/mlx/threadpool.h +133 -0
  586. data/mlx/mlx/transforms.cpp +1065 -0
  587. data/mlx/mlx/transforms.h +231 -0
  588. data/mlx/mlx/transforms_impl.h +88 -0
  589. data/mlx/mlx/types/bf16.h +187 -0
  590. data/mlx/mlx/types/complex.h +113 -0
  591. data/mlx/mlx/types/fp16.h +234 -0
  592. data/mlx/mlx/types/half_types.h +58 -0
  593. data/mlx/mlx/types/limits.h +70 -0
  594. data/mlx/mlx/utils.cpp +302 -0
  595. data/mlx/mlx/utils.h +174 -0
  596. data/mlx/mlx/version.cpp +11 -0
  597. data/mlx/mlx/version.h +22 -0
  598. data/mlx/mlx.pc.in +52 -0
  599. metadata +643 -0
data/mlx/mlx/backend/metal/kernels/gemv.metal
@@ -0,0 +1,868 @@
+ // Copyright © 2023-2024 Apple Inc.
+
+ #include <metal_simdgroup>
+ #include <metal_stdlib>
+
+ #include "mlx/backend/metal/kernels/utils.h"
+
+ #include "mlx/backend/metal/kernels/steel/utils.h"
+
+ using namespace metal;
+
+ ///////////////////////////////////////////////////////////////////////////////
+ /// Matrix vector multiplication
+ ///////////////////////////////////////////////////////////////////////////////
+
+ #define MLX_MTL_CONST static constant constexpr const
+
+ template <typename U>
+ struct DefaultAccT {
+   using type = float;
+ };
+ template <>
+ struct DefaultAccT<complex64_t> {
+   using type = complex64_t;
+ };
+
+ template <
+     typename T,
+     const int BM, /* Threadgroup rows (in simdgroups) */
+     const int BN, /* Threadgroup cols (in simdgroups) */
+     const int SM, /* Simdgroup rows (in threads) */
+     const int SN, /* Simdgroup cols (in threads) */
+     const int TM, /* Thread rows (in elements) */
+     const int TN, /* Thread cols (in elements) */
+     const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
+     typename AccT = typename DefaultAccT<T>::type>
+ struct GEMVKernel {
+   using acc_type = AccT;
+
+   MLX_MTL_CONST int threadsM = BM * SM;
+   MLX_MTL_CONST int threadsN = BN * SN;
+
+   MLX_MTL_CONST int blockM = threadsM * TM;
+   MLX_MTL_CONST int blockN = threadsN * TN;
+
+   static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
+
+   static_assert(
+       SN == 4 || SN == 8 || SN == 16 || SN == 32,
+       "gemv block must have a width of 4, 8, 16, or 32");
+
+   // - The matrix of size (M = out_vec_size, K = in_vec_size) is divided up
+   //   into blocks of (blockM, blockN) divided among threadgroups
+   // - Every thread works on a block of (TM, TN)
+   // - We assume each threadgroup has (threadsN, threadsM, 1) threads
+   //
+   // 1. A thread loads TN elements each from mat along TM rows
+   //    and the corresponding scalar from the vector
+   // 2. The thread then multiplies and adds to accumulate its local result for
+   //    the block
+   // 3. At the end, each thread has accumulated results over all blocks across
+   //    the rows. These are then summed up across the threadgroup
+   // 4. Each threadgroup writes its accumulated blockM outputs
+   //
+   // Edge case handling:
+   // - The threadgroup with the largest tid has blocks that exceed the matrix
+   //   * The blocks that start outside the matrix are never read (thread results
+   //     remain zero)
+   //   * The last thread that partially overlaps with the matrix is shifted
+   //     inwards such that the thread block fits exactly in the matrix
+
+   MLX_MTL_CONST short tgp_mem_size = BN > 1 ? BN * (blockM + TM) : 0;
+   MLX_MTL_CONST bool needs_tgp_reduction = BN > 1;
+
+   template <typename U = T>
+   static METAL_FUNC void
+   load_unsafe(const device T* src, thread U dst[TN], const int src_offset = 0) {
+     MLX_MTL_PRAGMA_UNROLL
+     for (int tn = 0; tn < TN; tn++) {
+       dst[tn] = static_cast<U>(src[src_offset + tn]);
+     }
+   }
+
+   template <typename U = T>
+   static METAL_FUNC void load_safe(
+       const device T* src,
+       thread U dst[TN],
+       const int src_offset = 0,
+       const int src_size = TN) {
+     if (src_offset + TN <= src_size) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (int tn = 0; tn < TN; tn++) {
+         dst[tn] = static_cast<U>(src[src_offset + tn]);
+       }
+     } else { // Edge case
+       MLX_MTL_PRAGMA_UNROLL
+       for (int tn = 0; tn < TN; tn++) {
+         dst[tn] = src_offset + tn < src_size
+             ? static_cast<U>(src[src_offset + tn])
+             : U(0);
+       }
+     }
+   }
+
+   static METAL_FUNC void run(
+       const device T* mat [[buffer(0)]],
+       const device T* in_vec [[buffer(1)]],
+       const device T* bias [[buffer(2)]],
+       device T* out_vec [[buffer(3)]],
+       const constant int& in_vec_size [[buffer(4)]],
+       const constant int& out_vec_size [[buffer(5)]],
+       const constant int& matrix_ld [[buffer(6)]],
+       const constant float& alpha [[buffer(7)]],
+       const constant float& beta [[buffer(8)]],
+       const constant int& bias_stride [[buffer(14)]],
+       threadgroup AccT* tgp_memory [[threadgroup(0)]],
+       uint3 tid [[threadgroup_position_in_grid]],
+       uint3 lid [[thread_position_in_threadgroup]],
+       uint simd_gid [[simdgroup_index_in_threadgroup]],
+       uint simd_lid [[thread_index_in_simdgroup]]) {
+     // Appease compiler
+     (void)lid;
+
+     // Thread local accumulation results
+     thread AccT result[TM] = {0};
+     thread T inter[TN];
+     thread AccT v_coeff[TN];
+
+     const int thrM = SN != 32 ? simd_lid / SN : 0;
+     const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
+
+     const int sgN = BN != 1 ? (simd_gid % BN) : 0;
+
+     const int simdM = BN != 1 ? SM * (simd_gid / BN) : int(SM * simd_gid);
+     const int simdN = BN != 1 ? SN * (simd_gid % BN) : 0;
+
+     int bm = (simdM + thrM) * TM;
+     int bn = (simdN + thrN) * TN;
+
+     // Block position
+     int out_row = tid.x * blockM + bm;
+
+     // Exit simdgroup if rows out of bound
+     if (out_row >= out_vec_size)
+       return;
+
+     // Adjust tail simdgroup to ensure in bound reads
+     out_row = out_row + TM <= out_vec_size ? out_row : out_vec_size - TM;
+
+     // Advance matrix
+     mat += out_row * matrix_ld;
+
+     constexpr const uniform<int> loop_stride = make_uniform(blockN);
+     const uniform<int> in_size = make_uniform(in_vec_size);
+     const uniform<int> n_iter = in_size / loop_stride;
+     const uniform<int> last_iter = loop_stride * n_iter;
+     const uniform<int> leftover = in_size - last_iter;
+
+     // Loop over in_vec in blocks of blockN
+     for (int i = 0; i < n_iter; ++i) {
+       load_unsafe<AccT>(in_vec, v_coeff, bn);
+
+       // Per thread work loop
+       int mat_offset = 0;
+       MLX_MTL_PRAGMA_UNROLL
+       for (int tm = 0; tm < TM; tm++) {
+         // Load for the row
+         load_unsafe(mat, inter, mat_offset + bn);
+
+         // Accumulate results
+         MLX_MTL_PRAGMA_UNROLL
+         for (int tn = 0; tn < TN; tn++) {
+           result[tm] += inter[tn] * v_coeff[tn];
+         }
+
+         mat_offset += matrix_ld;
+       }
+
+       bn += blockN;
+     }
+
+     if (leftover > 0) {
+       load_safe<AccT>(in_vec, v_coeff, bn, in_size);
+
+       // Per thread work loop
+       MLX_MTL_PRAGMA_UNROLL
+       for (int tm = 0; tm < TM; tm++) {
+         // Load for the row
+         load_safe(&mat[tm * matrix_ld], inter, bn, in_size);
+
+         // Accumulate results
+         MLX_MTL_PRAGMA_UNROLL
+         for (int tn = 0; tn < TN; tn++) {
+           result[tm] += inter[tn] * v_coeff[tn];
+         }
+       }
+     }
+
+     // Simdgroup accumulations
+     MLX_MTL_PRAGMA_UNROLL
+     for (int tm = 0; tm < TM; tm++) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (ushort sn = (SN / 2); sn >= 1; sn >>= 1) {
+         result[tm] += simd_shuffle_down(result[tm], sn);
+       }
+     }
+
+     // Threadgroup accumulation results
+     if (needs_tgp_reduction) {
+       threadgroup AccT* tgp_results = tgp_memory + sgN * (blockM + TM) + bm;
+       if (thrN == 0) {
+         MLX_MTL_PRAGMA_UNROLL
+         for (int tm = 0; tm < TM; tm++) {
+           tgp_results[tm] = result[tm];
+         }
+
+         threadgroup_barrier(mem_flags::mem_none);
+
+         if (sgN == 0) {
+           MLX_MTL_PRAGMA_UNROLL
+           for (int sgn = 1; sgn < BN; sgn++) {
+             MLX_MTL_PRAGMA_UNROLL
+             for (int tm = 0; tm < TM; tm++) {
+               result[tm] += tgp_results[sgn * (blockM + TM) + tm];
+             }
+           }
+         }
+       }
+     }
+
+     // Write outputs
+     if (simdN == 0 && thrN == 0) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (int tm = 0; tm < TM; tm++) {
+         if (kDoAxpby) {
+           out_vec[out_row + tm] =
+               static_cast<T>(alpha) * static_cast<T>(result[tm]) +
+               static_cast<T>(beta) * bias[(out_row + tm) * bias_stride];
+         } else {
+           out_vec[out_row + tm] = static_cast<T>(result[tm]);
+         }
+       }
+     }
+   }
+ };
+
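For a concrete reference of what GEMVKernel computes per batch element: with kDoAxpby enabled, each output row is out[m] = alpha * dot(mat[m, :], in_vec) + beta * bias[m * bias_stride]. A minimal host-side C++ sketch of those semantics, with all tiling and simdgroup machinery stripped out (gemv_ref and its parameter names are illustrative, not part of the package):

#include <vector>

// Scalar reference for GEMVKernel::run with kDoAxpby = true.
// mat is row-major with leading dimension ld (ld >= in_size),
// matching the kernel's matrix_ld argument.
std::vector<float> gemv_ref(
    const std::vector<float>& mat, const std::vector<float>& vec,
    const std::vector<float>& bias, int out_size, int in_size, int ld,
    float alpha, float beta, int bias_stride) {
  std::vector<float> out(out_size);
  for (int m = 0; m < out_size; m++) {
    float acc = 0.0f; // accumulate in float, as AccT does for real types
    for (int k = 0; k < in_size; k++) {
      acc += mat[m * ld + k] * vec[k];
    }
    out[m] = alpha * acc + beta * bias[m * bias_stride];
  }
  return out;
}

The kernel arrives at the same values by splitting the k loop into blockN-wide chunks across threads and then reducing the partial sums within each simdgroup and, when BN > 1, across the threadgroup.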
+ ///////////////////////////////////////////////////////////////////////////////
+ /// Vector matrix multiplication
+ ///////////////////////////////////////////////////////////////////////////////
+
+ template <
+     typename T,
+     const int BM, /* Threadgroup rows (in simdgroups) */
+     const int BN, /* Threadgroup cols (in simdgroups) */
+     const int SM, /* Simdgroup rows (in threads) */
+     const int SN, /* Simdgroup cols (in threads) */
+     const int TM, /* Thread rows (in elements) */
+     const int TN, /* Thread cols (in elements) */
+     const bool kDoAxpby, /* Do out = alpha * out + beta * bias */
+     typename AccT = typename DefaultAccT<T>::type>
+ struct GEMVTKernel {
+   using acc_type = AccT;
+
+   MLX_MTL_CONST int threadsM = BM * SM;
+   MLX_MTL_CONST int threadsN = BN * SN;
+
+   MLX_MTL_CONST int blockM = threadsM * TM;
+   MLX_MTL_CONST int blockN = threadsN * TN;
+
+   static_assert(SM * SN == 32, "simdgroup can only have 32 threads");
+
+   // - The matrix of size (M = in_vec_size, N = out_vec_size) is divided up
+   //   into blocks of (blockM, blockN) divided among threadgroups
+   // - Every thread works on a block of (TM, TN)
+   // - We assume each threadgroup has (threadsN, threadsM, 1) threads
+   //
+   // 1. A thread loads TN elements each from mat along TM contiguous rows
+   //    and the corresponding scalar from the vector
+   // 2. The thread then accumulates its local result for the block
+   // 3. At the end, each thread has accumulated results over all blocks across
+   //    the rows. These are then summed up across the threadgroup
+   // 4. Each threadgroup writes its accumulated BN * TN outputs
+   //
+   // Edge case handling:
+   // - The threadgroup with the largest tid has blocks that exceed the matrix
+   //   * The blocks that start outside the matrix are never read (thread results
+   //     remain zero)
+   //   * The last thread that partially overlaps with the matrix is shifted
+   //     inwards such that the thread block fits exactly in the matrix
+
+   MLX_MTL_CONST short tgp_mem_size = BM > 1 ? BM * (blockN + TN) : 0;
+   MLX_MTL_CONST bool needs_tgp_reduction = BM > 1;
+
+   static METAL_FUNC void run(
+       const device T* mat [[buffer(0)]],
+       const device T* in_vec [[buffer(1)]],
+       const device T* bias [[buffer(2)]],
+       device T* out_vec [[buffer(3)]],
+       const constant int& in_vec_size [[buffer(4)]],
+       const constant int& out_vec_size [[buffer(5)]],
+       const constant int& matrix_ld [[buffer(6)]],
+       const constant float& alpha [[buffer(7)]],
+       const constant float& beta [[buffer(8)]],
+       const constant int& bias_stride [[buffer(14)]],
+       threadgroup AccT* tgp_memory [[threadgroup(0)]],
+       uint3 tid [[threadgroup_position_in_grid]],
+       uint3 lid [[thread_position_in_threadgroup]],
+       uint simd_gid [[simdgroup_index_in_threadgroup]],
+       uint simd_lid [[thread_index_in_simdgroup]]) {
+     // Appease compiler
+     (void)lid;
+
+     // Thread local accumulation results
+     AccT result[TN] = {0};
+     T inter[TN];
+     AccT v_coeff[TM];
+     const int thrM = SN != 32 ? simd_lid / SN : 0;
+     const int thrN = SN != 32 ? simd_lid % SN : int(simd_lid);
+
+     const int sgM = BN != 1 ? (simd_gid / BN) : int(simd_gid);
+     const int sgN = BN != 1 ? (simd_gid % BN) : 0;
+
+     const int simdM = SM * sgM;
+     const int simdN = SN * sgN;
+
+     int cm = (simdM + thrM);
+     int cn = (simdN + thrN);
+
+     int bm = cm * TM;
+     int bn = cn * TN;
+
+     int out_col = tid.x * blockN + bn;
+
+     constexpr const uniform<int> loop_stride = make_uniform(blockM);
+     const uniform<int> in_size = make_uniform(in_vec_size);
+     const uniform<int> n_iter = in_size / loop_stride;
+     const uniform<int> last_iter = loop_stride * n_iter;
+     const uniform<int> leftover = in_size - last_iter;
+
+     // Edge case handling
+     if (out_col < out_vec_size) {
+       out_col = out_col + TN < out_vec_size ? out_col : out_vec_size - TN;
+
+       // Per thread accumulation main loop
+       for (int i = 0; i < n_iter; ++i) {
+         // Adding a threadgroup_barrier improves performance slightly,
+         // possibly because it helps the loads exploit the cache better
+         threadgroup_barrier(mem_flags::mem_none);
+
+         MLX_MTL_PRAGMA_UNROLL
+         for (int tm = 0; tm < TM; tm++) {
+           v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
+         }
+
+         MLX_MTL_PRAGMA_UNROLL
+         for (int tm = 0; tm < TM; tm++) {
+           auto vc = static_cast<AccT>(v_coeff[tm]);
+           for (int tn = 0; tn < TN; tn++) {
+             inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
+           }
+           for (int tn = 0; tn < TN; tn++) {
+             result[tn] += vc * inter[tn];
+           }
+         }
+
+         bm += blockM;
+       }
+
+       if (leftover > 0) {
+         for (int tm = 0; tm < TM && bm + tm < in_vec_size; tm++) {
+           v_coeff[tm] = static_cast<AccT>(in_vec[bm + tm]);
+
+           MLX_MTL_PRAGMA_UNROLL
+           for (int tn = 0; tn < TN; tn++) {
+             inter[tn] = mat[(bm + tm) * matrix_ld + out_col + tn];
+           }
+
+           MLX_MTL_PRAGMA_UNROLL
+           for (int tn = 0; tn < TN; tn++) {
+             result[tn] += v_coeff[tm] * inter[tn];
+           }
+         }
+       }
+     }
+
+     // Simdgroup accumulations
+     MLX_MTL_PRAGMA_UNROLL
+     for (int tn = 0; tn < TN; tn++) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (ushort sm = (SM / 2); sm >= 1; sm >>= 1) {
+         result[tn] += simd_shuffle_down(result[tn], SN * sm);
+       }
+     }
+
+     // Threadgroup accumulation results
+     if (needs_tgp_reduction) {
+       threadgroup AccT* tgp_results = tgp_memory + sgM * (blockN + TN) + bn;
+       if (thrM == 0) {
+         MLX_MTL_PRAGMA_UNROLL
+         for (int tn = 0; tn < TN; tn++) {
+           tgp_results[tn] = result[tn];
+         }
+
+         threadgroup_barrier(mem_flags::mem_none);
+
+         if (sgM == 0) {
+           MLX_MTL_PRAGMA_UNROLL
+           for (int sgm = 1; sgm < BM; sgm++) {
+             MLX_MTL_PRAGMA_UNROLL
+             for (int tn = 0; tn < TN; tn++) {
+               result[tn] += tgp_results[sgm * (blockN + TN) + tn];
+             }
+           }
+         }
+       }
+     }
+
+     // Threadgroup accumulation and writing out results
+     if (cm == 0 && out_col < out_vec_size) {
+       MLX_MTL_PRAGMA_UNROLL
+       for (int j = 0; j < TN; j++) {
+         if (kDoAxpby) {
+           out_vec[out_col + j] =
+               static_cast<T>(alpha) * static_cast<T>(result[j]) +
+               static_cast<T>(beta) * bias[(out_col + j) * bias_stride];
+         } else {
+           out_vec[out_col + j] = static_cast<T>(result[j]);
+         }
+       }
+     }
+   }
+ };
+
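A note on the simdgroup reduction in GEMVTKernel::run: the 32 lanes form an SM x SN grid (lane index = row * SN + col), so shuffling down by SN * sm folds pairs of rows together, and after log2(SM) steps row 0 holds the per-column sums. A small C++ model of that folding, with an array standing in for the simdgroup lanes (reduce_rows is an illustrative name; Metal's simd_shuffle_down is modeled here, not called):

#include <array>

// Model the reduction: lanes[row * SN + col] holds one partial sum.
// Each step adds the value SN * sm lanes ahead, i.e. the same column
// sm rows below; the number of live rows halves each step.
template <int SM, int SN>
void reduce_rows(std::array<float, SM * SN>& lanes) {
  static_assert(SM * SN == 32, "a simdgroup has 32 threads");
  for (int sm = SM / 2; sm >= 1; sm >>= 1) {
    for (int lane = 0; lane < SM * SN; lane++) {
      int src = lane + SN * sm; // source lane of simd_shuffle_down
      if (src < SM * SN) {
        lanes[lane] += lanes[src]; // ascending order: src not yet updated
      }
    }
  }
  // lanes[0] .. lanes[SN - 1] now hold the column totals.
}

GEMVKernel's reduction above is the transpose of this picture: it shuffles by sn = SN/2, SN/4, ... to sum across the columns of a row instead.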
+ ///////////////////////////////////////////////////////////////////////////////
+ /// Matrix vector multiplication
+ ///////////////////////////////////////////////////////////////////////////////
+
+ template <
+     typename T,
+     const int BM, /* Threadgroup rows (in simdgroups) */
+     const int BN, /* Threadgroup cols (in simdgroups) */
+     const int SM, /* Simdgroup rows (in threads) */
+     const int SN, /* Simdgroup cols (in threads) */
+     const int TM, /* Thread rows (in elements) */
+     const int TN, /* Thread cols (in elements) */
+     const bool kDoNCBatch, /* Batch ndim > 1 */
+     const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
+ [[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv(
+     const device T* mat [[buffer(0)]],
+     const device T* in_vec [[buffer(1)]],
+     const device T* bias [[buffer(2)]],
+     device T* out_vec [[buffer(3)]],
+     const constant int& in_vec_size [[buffer(4)]],
+     const constant int& out_vec_size [[buffer(5)]],
+     const constant int& matrix_ld [[buffer(6)]],
+     const constant float& alpha [[buffer(7)]],
+     const constant float& beta [[buffer(8)]],
+     const constant int& batch_ndim [[buffer(9)]],
+     const constant int* batch_shape [[buffer(10)]],
+     const constant int64_t* vector_batch_stride [[buffer(11)]],
+     const constant int64_t* matrix_batch_stride [[buffer(12)]],
+     const constant int64_t* bias_batch_stride [[buffer(13)]],
+     const constant int& bias_stride [[buffer(14)]],
+     uint3 tid [[threadgroup_position_in_grid]],
+     uint3 lid [[thread_position_in_threadgroup]],
+     uint simd_gid [[simdgroup_index_in_threadgroup]],
+     uint simd_lid [[thread_index_in_simdgroup]]) {
+   using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
+   threadgroup typename gemv_kernel::acc_type tgp_memory
+       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
+
+   // Update batch offsets
+   if (kDoNCBatch) {
+     in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
+     mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
+
+     if (kDoAxpby) {
+       bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
+     }
+
+   } else {
+     in_vec += tid.z * vector_batch_stride[0];
+     mat += tid.z * matrix_batch_stride[0];
+
+     if (kDoAxpby) {
+       bias += tid.z * bias_batch_stride[0];
+     }
+   }
+
+   out_vec += tid.z * out_vec_size;
+
+   gemv_kernel::run(
+       mat,
+       in_vec,
+       bias,
+       out_vec,
+       in_vec_size,
+       out_vec_size,
+       matrix_ld,
+       alpha,
+       beta,
+       bias_stride,
+       gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
+       tid,
+       lid,
+       simd_gid,
+       simd_lid);
+ }
+
+ #define instantiate_gemv_helper( \
+     name, itype, bm, bn, sm, sn, tm, tn, nc, axpby) \
+   instantiate_kernel( \
+       "gemv_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn "_tm" #tm \
+       "_tn" #tn "_nc" #nc "_axpby" #axpby, \
+       gemv, \
+       itype, \
+       bm, \
+       bn, \
+       sm, \
+       sn, \
+       tm, \
+       tn, \
+       nc, \
+       axpby)
+
+ // clang-format off
+ #define instantiate_gemv(name, itype, bm, bn, sm, sn, tm, tn) \
+   instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
+   instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
+   instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
+   instantiate_gemv_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on
+
+ // clang-format off
+ #define instantiate_gemv_blocks(name, itype) \
+   instantiate_gemv(name, itype, 1, 8, 1, 32, 4, 4) \
+   instantiate_gemv(name, itype, 1, 8, 1, 32, 1, 4) \
+   instantiate_gemv(name, itype, 1, 1, 8, 4, 4, 4) \
+   instantiate_gemv(name, itype, 1, 1, 8, 4, 1, 4) \
+   instantiate_gemv(name, itype, 4, 1, 1, 32, 1, 4) \
+   instantiate_gemv(name, itype, 4, 1, 1, 32, 4, 4) \
+   instantiate_gemv(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
+
+ instantiate_gemv_blocks(float32, float);
+ instantiate_gemv_blocks(float16, half);
+ instantiate_gemv_blocks(bfloat16, bfloat16_t);
+ instantiate_gemv_blocks(complex64, complex64_t);
+
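Each instantiate_gemv expansion above bakes one tile shape into a kernel name such as gemv_float32_bm4_bn1_sm1_sn32_tm4_tn4_nc0_axpby0, which the host looks up at dispatch time. A rough C++ sketch of how a host side might assemble that name and size the grid (the package's real selection logic lives in data/mlx/mlx/backend/metal/matmul.cpp; these helpers are illustrative assumptions):

#include <string>

// Build the instantiated kernel name for one tile configuration,
// mirroring the string pasting in instantiate_gemv_helper.
std::string gemv_kernel_name(
    const std::string& type, int bm, int bn, int sm, int sn, int tm, int tn,
    bool nc_batch, bool axpby) {
  return "gemv_" + type + "_bm" + std::to_string(bm) + "_bn" +
      std::to_string(bn) + "_sm" + std::to_string(sm) + "_sn" +
      std::to_string(sn) + "_tm" + std::to_string(tm) + "_tn" +
      std::to_string(tn) + "_nc" + std::to_string(nc_batch ? 1 : 0) +
      "_axpby" + std::to_string(axpby ? 1 : 0);
}

// One threadgroup covers blockM = bm * sm * tm output rows, so the x
// dimension of the grid needs ceil(out_size / blockM) threadgroups.
int gemv_grid_dim_x(int out_size, int bm, int sm, int tm) {
  int block_m = bm * sm * tm;
  return (out_size + block_m - 1) / block_m;
}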
+ template <
+     typename T,
+     const int BM, /* Threadgroup rows (in simdgroups) */
+     const int BN, /* Threadgroup cols (in simdgroups) */
+     const int SM, /* Simdgroup rows (in threads) */
+     const int SN, /* Simdgroup cols (in threads) */
+     const int TM, /* Thread rows (in elements) */
+     const int TN> /* Thread cols (in elements) */
+ [[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_gather(
+     const device T* mat [[buffer(0)]],
+     const device T* in_vec [[buffer(1)]],
+     const device T* bias [[buffer(2)]],
+     device T* out_vec [[buffer(3)]],
+     const constant int& in_vec_size [[buffer(4)]],
+     const constant int& out_vec_size [[buffer(5)]],
+     const constant int& matrix_ld [[buffer(6)]],
+     const constant float& alpha [[buffer(7)]],
+     const constant float& beta [[buffer(8)]],
+     const constant int& batch_ndim [[buffer(9)]],
+     const constant int* batch_shape [[buffer(10)]],
+     const constant int64_t* index_batch_strides [[buffer(11)]],
+     const constant int& vector_batch_ndim [[buffer(12)]],
+     const constant int* vector_batch_shape [[buffer(13)]],
+     const constant int64_t* vector_batch_stride [[buffer(14)]],
+     const constant int& matrix_batch_ndim [[buffer(15)]],
+     const constant int* matrix_batch_shape [[buffer(16)]],
+     const constant int64_t* matrix_batch_stride [[buffer(17)]],
+     const constant uint32_t* vec_indices [[buffer(18)]],
+     const constant uint32_t* mat_indices [[buffer(19)]],
+     uint3 tid [[threadgroup_position_in_grid]],
+     uint3 lid [[thread_position_in_threadgroup]],
+     uint simd_gid [[simdgroup_index_in_threadgroup]],
+     uint simd_lid [[thread_index_in_simdgroup]]) {
+   using gemv_kernel = GEMVKernel<T, BM, BN, SM, SN, TM, TN, false>;
+   threadgroup typename gemv_kernel::acc_type tgp_memory
+       [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
+
+   uint32_t indx_vec;
+   uint32_t indx_mat;
+
+   // Update batch offsets
+   if (batch_ndim > 1) {
+     const constant auto* veci_bstrides = index_batch_strides;
+     const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
+
+     ulong2 batch_offsets = elem_to_loc_broadcast(
+         tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
+
+     indx_vec = vec_indices[batch_offsets.x];
+     indx_mat = mat_indices[batch_offsets.y];
+
+   } else {
+     indx_vec = vec_indices[index_batch_strides[0] * tid.z];
+     indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
+   }
+
+   if (vector_batch_ndim > 1) {
+     in_vec += elem_to_loc(
+         indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
+   } else {
+     in_vec += indx_vec * vector_batch_stride[0];
+   }
+
+   if (matrix_batch_ndim > 1) {
+     mat += elem_to_loc(
+         indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
+   } else {
+     mat += indx_mat * matrix_batch_stride[0];
+   }
+
+   out_vec += tid.z * out_vec_size;
+
+   gemv_kernel::run(
+       mat,
+       in_vec,
+       bias,
+       out_vec,
+       in_vec_size,
+       out_vec_size,
+       matrix_ld,
+       alpha,
+       beta,
+       batch_ndim, // Not used
+       gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
+       tid,
+       lid,
+       simd_gid,
+       simd_lid);
+ }
+
+ // clang-format off
+ #define instantiate_gemv_bs_helper(nm, itype, bm, bn, sm, sn, tm, tn) \
+   instantiate_kernel( \
+       "gemv_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
+       "_sn" #sn "_tm" #tm "_tn" #tn, \
+       gemv_gather, itype, bm, bn, sm, sn, tm, tn)
+
+ #define instantiate_gemv_bs_blocks(name, itype) \
+   instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 1, 4) \
+   instantiate_gemv_bs_helper(name, itype, 4, 1, 1, 32, 4, 4) \
+   instantiate_gemv_bs_helper(name, itype, 8, 1, 1, 32, 4, 4) // clang-format on
+
+ instantiate_gemv_bs_blocks(float32, float);
+ instantiate_gemv_bs_blocks(float16, half);
+ instantiate_gemv_bs_blocks(bfloat16, bfloat16_t);
+ instantiate_gemv_bs_blocks(complex64, complex64_t);
+
655
+ ///////////////////////////////////////////////////////////////////////////////
656
+ /// Vector matrix multiplication
657
+ ///////////////////////////////////////////////////////////////////////////////
658
+
659
+ template <
660
+ typename T,
661
+ const int BM, /* Threadgroup rows (in simdgroups) */
662
+ const int BN, /* Threadgroup cols (in simdgroups) */
663
+ const int SM, /* Simdgroup rows (in threads) */
664
+ const int SN, /* Simdgroup cols (in threads) */
665
+ const int TM, /* Thread rows (in elements) */
666
+ const int TN, /* Thread cols (in elements) */
667
+ const bool kDoNCBatch, /* Batch ndim > 1 */
668
+ const bool kDoAxpby> /* Do out = alpha * out + beta * bias */
669
+ [[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t(
670
+ const device T* mat [[buffer(0)]],
671
+ const device T* in_vec [[buffer(1)]],
672
+ const device T* bias [[buffer(2)]],
673
+ device T* out_vec [[buffer(3)]],
674
+ const constant int& in_vec_size [[buffer(4)]],
675
+ const constant int& out_vec_size [[buffer(5)]],
676
+ const constant int& marix_ld [[buffer(6)]],
677
+ const constant float& alpha [[buffer(7)]],
678
+ const constant float& beta [[buffer(8)]],
679
+ const constant int& batch_ndim [[buffer(9)]],
680
+ const constant int* batch_shape [[buffer(10)]],
681
+ const constant int64_t* vector_batch_stride [[buffer(11)]],
682
+ const constant int64_t* matrix_batch_stride [[buffer(12)]],
683
+ const constant int64_t* bias_batch_stride [[buffer(13)]],
684
+ const constant int& bias_stride [[buffer(14)]],
685
+ uint3 tid [[threadgroup_position_in_grid]],
686
+ uint3 lid [[thread_position_in_threadgroup]],
687
+ uint simd_gid [[simdgroup_index_in_threadgroup]],
688
+ uint simd_lid [[thread_index_in_simdgroup]]) {
689
+ using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, kDoAxpby>;
690
+ threadgroup typename gemv_kernel::acc_type tgp_memory
691
+ [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
692
+
693
+ // Update batch offsets
694
+ if (kDoNCBatch) {
695
+ in_vec += elem_to_loc(tid.z, batch_shape, vector_batch_stride, batch_ndim);
696
+ mat += elem_to_loc(tid.z, batch_shape, matrix_batch_stride, batch_ndim);
697
+
698
+ if (kDoAxpby) {
699
+ bias += elem_to_loc(tid.z, batch_shape, bias_batch_stride, batch_ndim);
700
+ }
701
+
702
+ } else {
703
+ in_vec += tid.z * vector_batch_stride[0];
704
+ mat += tid.z * matrix_batch_stride[0];
705
+
706
+ if (kDoAxpby) {
707
+ bias += tid.z * bias_batch_stride[0];
708
+ }
709
+ }
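+  // When kDoNCBatch is false the batch is known to be one-dimensional (see
+  // the template comment above), so a single stride per operand is enough to
+  // locate the tid.z-th batch element.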
+
+  out_vec += tid.z * out_vec_size;
+
+  gemv_kernel::run(
+      mat,
+      in_vec,
+      bias,
+      out_vec,
+      in_vec_size,
+      out_vec_size,
+      marix_ld,
+      alpha,
+      beta,
+      bias_stride,
+      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
+      tid,
+      lid,
+      simd_gid,
+      simd_lid);
+}
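+
+// As with gemv_gather, the macros below generate named variants, here with
+// two extra boolean suffixes: _nc (kDoNCBatch, non-contiguous batch) and
+// _axpby (kDoAxpby). For example,
+// instantiate_gemv_t_helper(float32, float, 1, 2, 8, 4, 4, 1, 0, 1)
+// registers "gemv_t_float32_bm1_bn2_sm8_sn4_tm4_tn1_nc0_axpby1".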
+
+// clang-format off
+#define instantiate_gemv_t_helper(                            \
+    name, itype, bm, bn, sm, sn, tm, tn, nc, axpby)           \
+  instantiate_kernel(                                         \
+      "gemv_t_" #name "_bm" #bm "_bn" #bn "_sm" #sm "_sn" #sn \
+      "_tm" #tm "_tn" #tn "_nc" #nc "_axpby" #axpby,          \
+      gemv_t, itype, bm, bn, sm, sn, tm, tn, nc, axpby)
+
+#define instantiate_gemv_t(name, itype, bm, bn, sm, sn, tm, tn)        \
+  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 0) \
+  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 0, 1) \
+  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 0) \
+  instantiate_gemv_t_helper(name, itype, bm, bn, sm, sn, tm, tn, 1, 1) // clang-format on
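+
+// Each block shape therefore yields four kernels, crossing nc in {0, 1} with
+// axpby in {0, 1}; the five block shapes and four dtypes below give
+// 5 * 4 * 4 = 80 gemv_t variants in total.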
+
+// clang-format off
+#define instantiate_gemv_t_blocks(name, itype)      \
+  instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 1) \
+  instantiate_gemv_t(name, itype, 1, 2, 8, 4, 4, 4) \
+  instantiate_gemv_t(name, itype, 1, 4, 8, 4, 4, 4) \
+  instantiate_gemv_t(name, itype, 1, 16, 8, 4, 4, 4) \
+  instantiate_gemv_t(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on
+
+// clang-format off
+instantiate_gemv_t_blocks(float32, float);
+instantiate_gemv_t_blocks(float16, half);
+instantiate_gemv_t_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_t_blocks(complex64, complex64_t); // clang-format on
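+
+// Thread geometry: each variant launches BM * BN simdgroups of 32 threads
+// (hence max_total_threads_per_threadgroup(BM * BN * 32)). In every
+// configuration instantiated here SM * SN == 32, i.e. one simdgroup tiles an
+// SM x SN grid of threads, and each thread covers a TM x TN tile of elements.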
+
+template <
+    typename T,
+    const int BM, /* Threadgroup rows (in simdgroups) */
+    const int BN, /* Threadgroup cols (in simdgroups) */
+    const int SM, /* Simdgroup rows (in threads) */
+    const int SN, /* Simdgroup cols (in threads) */
+    const int TM, /* Thread rows (in elements) */
+    const int TN> /* Thread cols (in elements) */
+[[kernel, max_total_threads_per_threadgroup(BM * BN * 32)]] void gemv_t_gather(
+    const device T* mat [[buffer(0)]],
+    const device T* in_vec [[buffer(1)]],
+    const device T* bias [[buffer(2)]],
+    device T* out_vec [[buffer(3)]],
+    const constant int& in_vec_size [[buffer(4)]],
+    const constant int& out_vec_size [[buffer(5)]],
+    const constant int& marix_ld [[buffer(6)]],
+    const constant float& alpha [[buffer(7)]],
+    const constant float& beta [[buffer(8)]],
+    const constant int& batch_ndim [[buffer(9)]],
+    const constant int* batch_shape [[buffer(10)]],
+    const constant int64_t* index_batch_strides [[buffer(11)]],
+    const constant int& vector_batch_ndim [[buffer(12)]],
+    const constant int* vector_batch_shape [[buffer(13)]],
+    const constant int64_t* vector_batch_stride [[buffer(14)]],
+    const constant int& matrix_batch_ndim [[buffer(15)]],
+    const constant int* matrix_batch_shape [[buffer(16)]],
+    const constant int64_t* matrix_batch_stride [[buffer(17)]],
+    const constant uint32_t* vec_indices [[buffer(18)]],
+    const constant uint32_t* mat_indices [[buffer(19)]],
+    uint3 tid [[threadgroup_position_in_grid]],
+    uint3 lid [[thread_position_in_threadgroup]],
+    uint simd_gid [[simdgroup_index_in_threadgroup]],
+    uint simd_lid [[thread_index_in_simdgroup]]) {
+  using gemv_kernel = GEMVTKernel<T, BM, BN, SM, SN, TM, TN, false>;
+  threadgroup typename gemv_kernel::acc_type tgp_memory
+      [gemv_kernel::tgp_mem_size == 0 ? 1 : gemv_kernel::tgp_mem_size];
+
+  uint32_t indx_vec;
+  uint32_t indx_mat;
+
+  // Update batch offsets
+  if (batch_ndim > 1) {
+    const constant auto* veci_bstrides = index_batch_strides;
+    const constant auto* mati_bstrides = index_batch_strides + batch_ndim;
+
+    ulong2 batch_offsets = elem_to_loc_broadcast(
+        tid.z, batch_shape, veci_bstrides, mati_bstrides, batch_ndim);
+
+    indx_vec = vec_indices[batch_offsets.x];
+    indx_mat = mat_indices[batch_offsets.y];
+
+  } else {
+    indx_vec = vec_indices[index_batch_strides[0] * tid.z];
+    indx_mat = mat_indices[index_batch_strides[batch_ndim] * tid.z];
+  }
+
+  if (vector_batch_ndim > 1) {
+    in_vec += elem_to_loc(
+        indx_vec, vector_batch_shape, vector_batch_stride, vector_batch_ndim);
+  } else {
+    in_vec += indx_vec * vector_batch_stride[0];
+  }
+
+  if (matrix_batch_ndim > 1) {
+    mat += elem_to_loc(
+        indx_mat, matrix_batch_shape, matrix_batch_stride, matrix_batch_ndim);
+  } else {
+    mat += indx_mat * matrix_batch_stride[0];
+  }
+
+  out_vec += tid.z * out_vec_size;
+
+  gemv_kernel::run(
+      mat,
+      in_vec,
+      bias,
+      out_vec,
+      in_vec_size,
+      out_vec_size,
+      marix_ld,
+      alpha,
+      beta,
+      batch_ndim, // Not used
+      gemv_kernel::tgp_mem_size == 0 ? nullptr : tgp_memory,
+      tid,
+      lid,
+      simd_gid,
+      simd_lid);
+}
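+
+// gemv_t_gather repeats the batch-index resolution of gemv_gather above; the
+// only substantive difference is that it builds on GEMVTKernel (the
+// vector-matrix, i.e. transposed, traversal) rather than GEMVKernel.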
+
+// clang-format off
+#define instantiate_gemv_t_bs_helper(                    \
+    nm, itype, bm, bn, sm, sn, tm, tn)                   \
+  instantiate_kernel(                                    \
+      "gemv_t_gather_" #nm "_bm" #bm "_bn" #bn "_sm" #sm \
+      "_sn" #sn "_tm" #tm "_tn" #tn,                     \
+      gemv_t_gather, itype, bm, bn, sm, sn, tm, tn)
+
+#define instantiate_gemv_t_bs_blocks(name, itype)             \
+  instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 1) \
+  instantiate_gemv_t_bs_helper(name, itype, 1, 2, 8, 4, 4, 4) \
+  instantiate_gemv_t_bs_helper(name, itype, 1, 4, 8, 4, 4, 4) \
+  instantiate_gemv_t_bs_helper(name, itype, 1, 16, 8, 4, 4, 4) \
+  instantiate_gemv_t_bs_helper(name, itype, 1, 16, 4, 8, 4, 4) // clang-format on
+
+// clang-format off
+instantiate_gemv_t_bs_blocks(float32, float);
+instantiate_gemv_t_bs_blocks(float16, half);
+instantiate_gemv_t_bs_blocks(bfloat16, bfloat16_t);
+instantiate_gemv_t_bs_blocks(complex64, complex64_t); // clang-format on
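+
+// Host-side sketch (illustrative only; the actual dispatch lives in the C++
+// backend, not in this shader). A caller would reassemble the instantiated
+// name from its chosen tile parameters before looking up the pipeline, e.g.:
+//
+//   std::ostringstream kname;
+//   kname << "gemv_t_gather_float32" << "_bm" << bm << "_bn" << bn
+//         << "_sm" << sm << "_sn" << sn << "_tm" << tm << "_tn" << tn;
+//
+// Names not produced by the instantiations above will not resolve.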