whispercpp 1.3.5 → 1.3.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1017)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/LICENSE +1 -1
  5. data/README.md +133 -3
  6. data/Rakefile +18 -3
  7. data/ext/dependencies.rb +10 -4
  8. data/ext/dependencies_for_windows.rb +17 -0
  9. data/ext/extconf.rb +20 -7
  10. data/ext/options.rb +54 -14
  11. data/ext/options_for_windows.rb +51 -0
  12. data/ext/ruby_whisper.c +56 -46
  13. data/ext/ruby_whisper.h +165 -2
  14. data/ext/ruby_whisper_context.c +297 -126
  15. data/ext/ruby_whisper_context_params.c +163 -0
  16. data/ext/ruby_whisper_log_queue.c +180 -0
  17. data/ext/ruby_whisper_log_settable.h +47 -0
  18. data/ext/ruby_whisper_model.c +0 -1
  19. data/ext/ruby_whisper_parakeet.c +49 -0
  20. data/ext/ruby_whisper_parakeet_context.c +304 -0
  21. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  22. data/ext/ruby_whisper_parakeet_model.c +84 -0
  23. data/ext/ruby_whisper_parakeet_params.c +548 -0
  24. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  25. data/ext/ruby_whisper_parakeet_token.c +188 -0
  26. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  27. data/ext/ruby_whisper_params.c +256 -66
  28. data/ext/ruby_whisper_segment.c +6 -7
  29. data/ext/ruby_whisper_token.c +29 -9
  30. data/ext/ruby_whisper_transcribe.cpp +46 -16
  31. data/ext/ruby_whisper_vad_context.c +48 -1
  32. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  33. data/ext/ruby_whisper_vad_params.c +0 -1
  34. data/ext/ruby_whisper_vad_segment.c +0 -1
  35. data/ext/ruby_whisper_vad_segments.c +0 -1
  36. data/ext/sources/CMakeLists.txt +41 -3
  37. data/ext/sources/CMakePresets.json +95 -0
  38. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  39. data/ext/sources/cmake/parakeet.pc.in +10 -0
  40. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  41. data/ext/sources/cmake/whisper.pc.in +1 -1
  42. data/ext/sources/examples/CMakeLists.txt +4 -2
  43. data/ext/sources/examples/bench/bench.cpp +24 -19
  44. data/ext/sources/examples/cli/cli.cpp +51 -9
  45. data/ext/sources/examples/common-ggml.cpp +4 -0
  46. data/ext/sources/examples/common-whisper.cpp +139 -67
  47. data/ext/sources/examples/common-whisper.h +11 -0
  48. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  49. data/ext/sources/examples/miniaudio.h +4507 -2131
  50. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  51. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  52. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  53. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  54. data/ext/sources/examples/server/server.cpp +213 -163
  55. data/ext/sources/ggml/CMakeLists.txt +29 -15
  56. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  57. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  58. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  59. data/ext/sources/ggml/include/ggml-backend.h +73 -11
  60. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  61. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  62. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  63. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  64. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  65. data/ext/sources/ggml/include/ggml-rpc.h +8 -3
  66. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  67. data/ext/sources/ggml/include/ggml.h +155 -16
  68. data/ext/sources/ggml/include/gguf.h +10 -2
  69. data/ext/sources/ggml/src/CMakeLists.txt +25 -5
  70. data/ext/sources/ggml/src/ggml-alloc.c +9 -10
  71. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  72. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  73. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  74. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  75. data/ext/sources/ggml/src/ggml-backend-reg.cpp +40 -86
  76. data/ext/sources/ggml/src/ggml-backend.cpp +114 -10
  77. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  78. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +10 -2
  79. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  80. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  81. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +1016 -442
  82. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +111 -85
  83. data/ext/sources/ggml/src/ggml-cann/common.h +23 -14
  84. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +255 -92
  85. data/ext/sources/ggml/src/ggml-common.h +22 -0
  86. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +68 -34
  87. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +44 -19
  88. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  89. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +101 -101
  90. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +194 -1
  91. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2874 -613
  92. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  93. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  94. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +5480 -840
  95. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1361 -0
  96. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -11
  97. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  98. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +186 -36
  99. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +119 -19
  100. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +112 -26
  101. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  102. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  103. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  104. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +13 -0
  105. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +153 -16
  106. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +17 -0
  107. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  108. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +976 -251
  109. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +671 -266
  110. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +1277 -263
  111. data/ext/sources/ggml/src/ggml-cpu/ops.h +4 -0
  112. data/ext/sources/ggml/src/ggml-cpu/quants.c +95 -0
  113. data/ext/sources/ggml/src/ggml-cpu/quants.h +6 -0
  114. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2893 -679
  115. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  116. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +226 -0
  117. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +114 -19
  118. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  119. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  120. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  121. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  122. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  123. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  124. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  125. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  126. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  127. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  128. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  129. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  130. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  131. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  132. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  133. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  134. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +54 -53
  135. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  136. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +18 -8
  137. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  138. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  139. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +73 -28
  140. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +69 -41
  141. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  142. data/ext/sources/ggml/src/ggml-cuda/common.cuh +359 -29
  143. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  144. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  145. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  146. data/ext/sources/ggml/src/ggml-cuda/convert.cu +94 -27
  147. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  148. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +20 -9
  149. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  150. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +333 -85
  151. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +632 -190
  152. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  153. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +162 -49
  154. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +43 -18
  155. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +44 -14
  156. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  157. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +241 -23
  158. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  159. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  160. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  161. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +312 -0
  162. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  163. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  164. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1454 -599
  165. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  166. data/ext/sources/ggml/src/ggml-cuda/mean.cu +13 -10
  167. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +397 -183
  168. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  169. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +161 -88
  170. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  171. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +522 -431
  172. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +139 -72
  173. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  174. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +608 -88
  175. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -0
  176. data/ext/sources/ggml/src/ggml-cuda/norm.cu +47 -79
  177. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  178. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  179. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +134 -27
  180. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  181. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +7 -17
  182. data/ext/sources/ggml/src/ggml-cuda/rope.cu +244 -137
  183. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  184. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  185. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  186. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  187. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  188. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  189. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  190. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +96 -40
  191. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  192. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  193. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  194. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  195. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +6 -0
  196. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  197. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +2 -0
  198. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  199. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +6 -0
  200. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +2 -0
  201. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  202. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  203. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +2 -0
  204. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  205. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +2 -0
  206. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  207. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  208. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  209. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  210. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  211. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  212. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  213. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  214. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  215. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  216. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  217. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  218. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  219. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  220. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  221. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  222. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  223. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  224. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  225. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -5
  226. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +202 -135
  227. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  228. data/ext/sources/ggml/src/ggml-cuda/unary.cu +86 -2
  229. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +4 -0
  230. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +111 -17
  231. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  232. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +30 -2
  233. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +84 -46
  235. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1612 -753
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +51 -11
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +361 -261
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +294 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +753 -241
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +295 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +471 -296
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +159 -53
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +3 -3
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +372 -0
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +86 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +137 -0
  254. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  255. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  256. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  257. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  258. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  259. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  260. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  261. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  262. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +97 -14
  263. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +163 -67
  264. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  265. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  266. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +308 -0
  267. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +262 -0
  268. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +291 -0
  269. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  270. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +216 -0
  271. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  272. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  273. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  274. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  275. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  276. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  278. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +142 -0
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +18 -1348
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +547 -635
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +3556 -1101
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +475 -269
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +94 -72
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +222 -217
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +432 -0
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +886 -117
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  297. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +40 -0
  302. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +28 -9
  303. data/ext/sources/ggml/src/ggml-impl.h +68 -1
  304. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  305. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  306. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  307. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  308. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +409 -83
  309. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +54 -5
  310. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +254 -52
  311. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +254 -23
  312. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +756 -285
  313. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +7 -4
  314. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +359 -133
  315. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1867 -1123
  316. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  317. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +71 -4
  318. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +14127 -5314
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +97 -88
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +104 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  322. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1978 -67
  323. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  324. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  325. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  326. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  327. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  328. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  329. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  330. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  331. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  332. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  333. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  334. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  335. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  336. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q8_0_f32.cl +129 -0
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +195 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  367. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  368. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  369. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  370. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  371. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  372. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  373. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  374. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  375. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  376. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  377. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  378. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  379. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  380. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  381. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  382. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  383. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  384. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  385. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  386. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  387. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  388. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  389. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +178 -0
  390. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  391. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  392. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  393. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  394. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  395. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  396. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  397. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  398. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  399. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  400. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  401. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +985 -0
  402. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  403. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +380 -0
  404. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  405. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1132 -0
  406. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +956 -0
  407. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  412. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  413. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  414. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  415. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  416. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  417. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  418. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  419. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  420. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  421. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  422. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  423. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  424. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +149 -0
  425. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  426. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  427. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  428. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  429. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  430. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  431. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  432. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +47 -0
  433. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +40 -0
  434. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  435. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  436. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  437. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  438. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  439. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  440. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +317 -0
  441. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  442. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +257 -0
  443. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +86 -0
  444. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +880 -0
  445. data/ext/sources/ggml/src/ggml-openvino/utils.h +143 -0
  446. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  447. data/ext/sources/ggml/src/ggml-quants.c +385 -119
  448. data/ext/sources/ggml/src/ggml-quants.h +6 -0
  449. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  450. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  451. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  452. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  453. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +64 -91
  454. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  455. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +4 -1
  456. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  457. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  458. data/ext/sources/ggml/src/ggml-sycl/common.hpp +356 -11
  459. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +184 -14
  460. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +31 -1
  461. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  462. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  463. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  464. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  465. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  466. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  467. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  468. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  469. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +77 -156
  470. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -2
  471. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  472. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  473. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1181 -0
  474. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +59 -0
  475. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1246 -0
  476. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +674 -0
  477. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +227 -0
  478. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  479. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  480. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  481. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +347 -0
  482. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +9 -0
  483. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  484. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  485. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +1134 -236
  486. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  487. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  488. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  489. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  490. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  491. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  492. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  493. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  494. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +72 -1
  495. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  496. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  497. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  498. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  499. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  500. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  501. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  502. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  503. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  504. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  505. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  506. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  507. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  508. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  509. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  510. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  511. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  512. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  513. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  514. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  515. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  516. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +8 -0
  517. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +8 -0
  518. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +8 -0
  519. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +8 -0
  520. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +8 -0
  521. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +8 -0
  522. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +8 -0
  523. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +8 -0
  524. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +8 -0
  525. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +8 -0
  526. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +8 -0
  527. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +8 -0
  528. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +8 -0
  529. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +8 -0
  530. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +8 -0
  531. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +8 -0
  532. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +8 -0
  533. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +8 -0
  534. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +8 -0
  535. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +8 -0
  536. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +8 -0
  537. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +8 -0
  538. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +8 -0
  539. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +8 -0
  540. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +8 -0
  541. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +8 -0
  542. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +8 -0
  543. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +8 -0
  544. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +8 -0
  545. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +8 -0
  546. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +8 -0
  547. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +8 -0
  548. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +8 -0
  549. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +8 -0
  550. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +8 -0
  551. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +8 -0
  552. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  553. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  554. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  555. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +228 -53
  556. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  557. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  558. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  559. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  560. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  561. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  562. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  563. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  564. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  565. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  566. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  567. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  568. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  569. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  570. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  571. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  572. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  573. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  574. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  575. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  576. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  577. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  578. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +123 -0
  579. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +160 -0
  580. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  581. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +71 -0
  582. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  583. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  584. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  585. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  586. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  587. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  588. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  589. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  590. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  591. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  592. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +99 -0
  593. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  594. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  595. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  596. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +545 -0
  597. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +115 -0
  598. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +12 -1
  599. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +3250 -940
  600. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  601. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  602. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  603. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  604. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  605. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  606. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  607. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  608. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  609. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  610. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  611. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  612. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  613. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  614. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  615. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +533 -180
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +113 -68
  618. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +412 -222
  619. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +222 -83
  620. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  621. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  622. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  623. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  624. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  625. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +189 -0
  626. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  627. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  628. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  629. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  630. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  631. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  632. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  633. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  634. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  635. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  636. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  637. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  638. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  639. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  640. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  641. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +22 -20
  642. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +51 -14
  643. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  644. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  645. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  646. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  647. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  648. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  649. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  650. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +39 -63
  651. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  652. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  653. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  654. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  655. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +13 -7
  656. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  657. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  658. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  659. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +27 -11
  660. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  661. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  662. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +193 -149
  663. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  664. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +3221 -97
  665. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +3493 -1997
  666. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  667. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  668. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  669. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  670. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  671. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +142 -0
  672. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +115 -141
  673. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +93 -0
  674. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  675. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -44
  676. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  677. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +198 -230
  678. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  679. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  680. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  681. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  682. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  683. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  684. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +234 -335
  685. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  686. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  687. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +871 -42
  688. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  689. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  690. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  691. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +149 -0
  692. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +36 -138
  693. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +151 -0
  694. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  695. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  696. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  697. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  698. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  699. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  700. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  701. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  702. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  703. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +15 -40
  704. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  705. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +39 -12
  706. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  707. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  708. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  709. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  710. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  711. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  712. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +213 -0
  713. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  714. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +24 -15
  715. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  716. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +253 -16
  717. data/ext/sources/ggml/src/ggml.c +268 -52
  718. data/ext/sources/ggml/src/gguf.cpp +377 -47
  719. data/ext/sources/include/parakeet.h +342 -0
  720. data/ext/sources/include/whisper.h +10 -0
  721. data/ext/sources/media/matmul.png +0 -0
  722. data/ext/sources/src/CMakeLists.txt +23 -0
  723. data/ext/sources/src/parakeet-arch.h +188 -0
  724. data/ext/sources/src/parakeet.cpp +3838 -0
  725. data/ext/sources/src/whisper.cpp +62 -40
  726. data/extsources.rb +26 -10
  727. data/lib/whisper/log_settable.rb +36 -0
  728. data/lib/whisper/model/uri.rb +13 -1
  729. data/lib/whisper/output.rb +74 -0
  730. data/sig/whisper.rbs +445 -55
  731. data/test/helper.rb +2 -0
  732. data/test/jfk_reader/jfk_reader.c +50 -7
  733. data/test/test_callback.rb +1 -0
  734. data/test/test_context_params.rb +82 -0
  735. data/test/test_package.rb +6 -5
  736. data/test/test_parakeet.rb +28 -0
  737. data/test/test_parakeet_callback.rb +107 -0
  738. data/test/test_parakeet_context.rb +116 -0
  739. data/test/test_parakeet_context_params.rb +24 -0
  740. data/test/test_parakeet_model.rb +21 -0
  741. data/test/test_parakeet_params.rb +78 -0
  742. data/test/test_parakeet_segment.rb +42 -0
  743. data/test/test_parakeet_token.rb +73 -0
  744. data/test/test_params.rb +2 -0
  745. data/test/test_token.rb +11 -0
  746. data/test/test_vad_context.rb +58 -8
  747. data/test/test_vad_segment.rb +1 -1
  748. data/test/test_whisper.rb +44 -6
  749. data/whispercpp.gemspec +2 -2
  750. metadata +426 -280
  751. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  752. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  753. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  754. data/ext/sources/bindings/javascript/package.json +0 -26
  755. data/ext/sources/bindings/javascript/whisper.js +0 -19
  756. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  757. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  758. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  759. data/ext/sources/examples/addon.node/index.js +0 -59
  760. data/ext/sources/examples/addon.node/package.json +0 -16
  761. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  762. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  763. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  764. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  765. data/ext/sources/examples/coi-serviceworker.js +0 -146
  766. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  767. data/ext/sources/examples/command/command.cpp +0 -802
  768. data/ext/sources/examples/command/commands.txt +0 -9
  769. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  770. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  771. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  772. data/ext/sources/examples/generate-karaoke.sh +0 -57
  773. data/ext/sources/examples/helpers.js +0 -191
  774. data/ext/sources/examples/livestream.sh +0 -112
  775. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  776. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  777. data/ext/sources/examples/lsp/whisper.vim +0 -362
  778. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  779. data/ext/sources/examples/python/whisper_processor.py +0 -54
  780. data/ext/sources/examples/server/bench.js +0 -29
  781. data/ext/sources/examples/server.py +0 -120
  782. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  783. data/ext/sources/examples/stream/stream.cpp +0 -437
  784. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  785. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  786. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  787. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  788. data/ext/sources/examples/sycl/build.sh +0 -22
  789. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  790. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  791. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -47
  792. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  793. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -494
  794. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -88
  795. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2559
  796. data/ext/sources/examples/talk-llama/llama-arch.h +0 -586
  797. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -917
  798. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  799. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -876
  800. data/ext/sources/examples/talk-llama/llama-chat.h +0 -70
  801. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3645
  802. data/ext/sources/examples/talk-llama/llama-context.h +0 -360
  803. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  804. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -42
  805. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  806. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  807. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2282
  808. data/ext/sources/examples/talk-llama/llama-graph.h +0 -910
  809. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -241
  810. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -284
  811. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  812. data/ext/sources/examples/talk-llama/llama-impl.h +0 -63
  813. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  814. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  815. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -328
  816. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  817. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2100
  818. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -390
  819. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  820. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  821. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  822. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1167
  823. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  824. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  825. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  826. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -735
  827. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  828. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1247
  829. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -176
  830. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -285
  831. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -37
  832. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -8338
  833. data/ext/sources/examples/talk-llama/llama-model.h +0 -544
  834. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1072
  835. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  836. data/ext/sources/examples/talk-llama/llama-sampling.cpp +0 -3771
  837. data/ext/sources/examples/talk-llama/llama-sampling.h +0 -44
  838. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3900
  839. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -182
  840. data/ext/sources/examples/talk-llama/llama.cpp +0 -1140
  841. data/ext/sources/examples/talk-llama/llama.h +0 -1540
  842. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -191
  843. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  844. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  845. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -138
  846. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  847. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -122
  848. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -144
  849. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -135
  850. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -178
  851. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -160
  852. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  853. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  854. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  855. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  856. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  857. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  858. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  859. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -123
  860. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  861. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -144
  862. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -259
  863. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -134
  864. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  865. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -150
  866. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  867. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  868. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  869. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -113
  870. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  871. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  872. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  873. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  874. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  875. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  876. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  877. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -150
  878. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  879. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  880. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -196
  881. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -211
  882. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +0 -283
  883. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  884. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -141
  885. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  886. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -154
  887. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  888. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  889. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  890. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -175
  891. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  892. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  893. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  894. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -168
  895. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  896. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -55
  897. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -123
  898. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -199
  899. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -124
  900. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  901. data/ext/sources/examples/talk-llama/models/models.h +0 -569
  902. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -116
  903. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  904. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -150
  905. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  906. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  907. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  908. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  909. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  910. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  911. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  912. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  913. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  914. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  915. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  916. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  917. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -316
  918. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  919. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -168
  920. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  921. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  922. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  923. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  924. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -117
  925. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -124
  926. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -873
  927. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -149
  928. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -141
  929. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  930. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  931. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -162
  932. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  933. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  934. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -135
  935. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  936. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  937. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  938. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  939. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  940. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  941. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  942. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  943. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  944. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  945. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  946. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  947. data/ext/sources/examples/talk-llama/speak +0 -40
  948. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  949. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  950. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  951. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  952. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  953. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1147
  954. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  955. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  956. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  957. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  958. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  959. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  960. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  961. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  962. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  963. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  964. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  965. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  966. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  967. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  968. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  969. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  970. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +0 -157
  971. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -165
  972. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  973. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  974. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  975. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  976. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  977. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  978. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  979. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  980. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  981. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  982. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  983. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  984. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -147
  985. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  986. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +0 -907
  987. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +0 -247
  988. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  989. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  990. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  991. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
  992. data/ext/sources/tests/CMakeLists.txt +0 -112
  993. data/ext/sources/tests/earnings21/eval.mk +0 -58
  994. data/ext/sources/tests/earnings21/eval.py +0 -68
  995. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  996. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  997. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  998. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  999. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  1000. data/ext/sources/tests/en-0-ref.txt +0 -1
  1001. data/ext/sources/tests/en-1-ref.txt +0 -1
  1002. data/ext/sources/tests/en-2-ref.txt +0 -1
  1003. data/ext/sources/tests/es-0-ref.txt +0 -1
  1004. data/ext/sources/tests/librispeech/eval.mk +0 -39
  1005. data/ext/sources/tests/librispeech/eval.py +0 -47
  1006. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  1007. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  1008. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  1009. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  1010. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  1011. data/ext/sources/tests/run-tests.sh +0 -130
  1012. data/ext/sources/tests/test-c.c +0 -3
  1013. data/ext/sources/tests/test-vad-full.cpp +0 -56
  1014. data/ext/sources/tests/test-vad.cpp +0 -83
  1015. data/ext/sources/tests/test-whisper.js +0 -58
  1016. data/lib/whisper/context.rb +0 -15
  1017. data/lib/whisper/segment.rb +0 -58
@@ -61,11 +61,24 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
61
61
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
62
62
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);
63
63
 
64
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 4, 64, 96, 64, 64, 2, true);
65
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 4, 32, 96, 64, 64, 2, true);
66
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 32, 96, 64, 64, 2, true);
67
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 32, 96, 64, 64, 2, true);
68
+
64
69
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
65
70
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
66
71
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
67
72
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
68
73
 
74
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
75
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
76
+
77
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 128, 1, false);
78
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 128, 1, false);
79
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
80
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
81
+
69
82
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
70
83
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
71
84
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
@@ -80,6 +93,14 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
80
93
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
81
94
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
82
95
 
96
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 128, 128, 128, 1, false);
97
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 32, 128, 128, 128, 1, false);
98
+
99
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
100
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
101
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, false);
102
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 128, 1, false);
103
+
83
104
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
84
105
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
85
106
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
@@ -89,6 +110,11 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
89
110
  }
90
111
 
91
112
  static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
113
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 64, 4, 32, 256, 256, 64, 1, false);
114
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 64, 4, 32, 256, 256, 64, 1, false);
115
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 64, 1, false);
116
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 32, 128, 128, 64, 1, false);
117
+
92
118
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
93
119
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
94
120
  GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
@@ -98,6 +124,110 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
98
124
  return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
99
125
  }
100
126
 
127
+ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
128
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 64, 32, 32, 32, 1, true);
129
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
130
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
131
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 1, true);
132
+
133
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 64, 2, 32, 40, 40, 40, 1, true);
134
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 64, 2, 32, 40, 40, 40, 1, true);
135
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
136
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 1, true);
137
+
138
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 64, 2, 32, 48, 48, 48, 1, true);
139
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 64, 2, 32, 48, 48, 48, 1, true);
140
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
141
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 1, true);
142
+
143
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 64, 2, 32, 56, 56, 56, 1, true);
144
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 64, 2, 32, 56, 56, 56, 1, true);
145
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
146
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 1, true);
147
+
148
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 64, 2, 32, 64, 64, 64, 1, true);
149
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 64, 2, 32, 64, 64, 64, 1, true);
150
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
151
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 1, true);
152
+
153
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 64, 2, 32, 96, 64, 64, 1, true);
154
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 64, 2, 32, 96, 64, 64, 1, true);
155
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 128, 2, 64, 96, 64, 64, 1, true);
156
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 128, 2, 64, 96, 64, 64, 1, true);
157
+
158
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 2, 32, 128, 128, 128, 1, true);
159
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 2, 32, 128, 128, 128, 1, true);
160
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 1, true);
161
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 1, true);
162
+
163
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 128, 2, 32, 160, 128, 128, 1, true);
164
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 128, 2, 32, 160, 128, 128, 1, true);
165
+
166
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 128, 3, 64, 96, 64, 128, 1, true);
167
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 128, 3, 64, 96, 64, 128, 1, true);
168
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 128, 2, 32, 128, 128, 128, 1, true);
169
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 128, 2, 32, 128, 128, 128, 1, true);
170
+
171
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 128, 3, 64, 96, 64, 128, 1, true);
172
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 128, 3, 64, 96, 64, 128, 1, true);
173
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, true);
174
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 128, 2, 32, 160, 128, 128, 1, true);
175
+
176
+ return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
177
+ }
178
+
179
+ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
180
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 1, 64, 32, 32, 32, 1, true);
181
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 32, 32, 32, 1, true);
182
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 32, 32, 32, 1, true);
183
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 4, 64, 32, 32, 32, 1, true);
184
+
185
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40, 40, 40, 1, true);
186
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40, 40, 40, 1, true);
187
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40, 40, 40, 1, true);
188
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true);
189
+
190
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48, 48, 48, 1, true);
191
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48, 48, 48, 1, true);
192
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48, 48, 48, 1, true);
193
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true);
194
+
195
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56, 56, 56, 1, true);
196
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56, 56, 56, 1, true);
197
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56, 56, 56, 1, true);
198
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true);
199
+
200
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64, 64, 64, 1, true);
201
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64, 64, 64, 1, true);
202
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64, 64, 64, 1, true);
203
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true);
204
+
205
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 8, 256, 1, 64, 64, 64, 64, 1, true);
206
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 16, 256, 1, 64, 64, 64, 64, 1, true);
207
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 32, 256, 1, 64, 64, 64, 64, 1, true);
208
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(192, 128, 64, 512, 1, 64, 64, 64, 64, 1, true);
209
+
210
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 256, 1, 64, 128, 128, 128, 1, true);
211
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 256, 1, 64, 128, 128, 128, 1, true);
212
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 256, 1, 64, 128, 128, 128, 1, true);
213
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 512, 1, 64, 128, 128, 64, 1, true);
214
+
215
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 32, 256, 1, 64, 160, 128, 128, 1, true);
216
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(320, 256, 64, 256, 1, 64, 160, 128, 128, 1, true);
217
+
218
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 8, 256, 1, 64, 128, 128, 128, 1, true);
219
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 16, 256, 1, 64, 128, 128, 128, 1, true);
220
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 32, 256, 1, 64, 128, 128, 128, 1, true);
221
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(512, 512, 64, 256, 1, 64, 128, 128, 128, 1, true);
222
+
223
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 256, 1, 64, 128, 128, 128, 1, true);
224
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 256, 1, 64, 128, 128, 128, 1, true);
225
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 256, 1, 64, 160, 128, 128, 1, true);
226
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 64, 160, 128, 128, 1, true);
227
+
228
+ return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
229
+ }
230
+
101
231
  static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
102
232
  if (ampere_mma_available(cc)) {
103
233
  return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
@@ -105,6 +235,12 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
105
235
  if (turing_mma_available(cc)) {
106
236
  return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
107
237
  }
238
+ if (amd_mfma_available(cc)) {
239
+ return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
240
+ }
241
+ if (amd_wmma_available(cc)) {
242
+ return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
243
+ }
108
244
  GGML_ASSERT(volta_mma_available(cc));
109
245
  return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
110
246
  }
@@ -114,8 +250,12 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
114
250
  return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
115
251
  #elif defined(TURING_MMA_AVAILABLE)
116
252
  return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
253
+ #elif defined(AMD_MFMA_AVAILABLE)
254
+ return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
117
255
  #elif defined(VOLTA_MMA_AVAILABLE)
118
256
  return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
257
+ #elif defined(AMD_WMMA_AVAILABLE)
258
+ return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
119
259
  #else
120
260
  GGML_UNUSED_VARS(DKQ, DV, ncols);
121
261
  return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
@@ -186,6 +326,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
186
326
  return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
187
327
  }
188
328
 
329
+ static constexpr __device__ int get_cols_per_thread() {
330
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
331
+ return 1; // AMD has a single column per thread.
332
+ #else
333
+ return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
334
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
335
+ }
336
+
337
+ static __host__ int get_cols_per_warp(const int cc) {
338
+ if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
339
+ return 16;
340
+ } else {
341
+ // Volta
342
+ return 32;
343
+ }
344
+ }
345
+
189
346
  // ------------------------------------------------------------------------------------------------------------------
190
347
 
191
348
  static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
@@ -206,21 +363,23 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c
206
363
  template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
207
364
  static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
208
365
  const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
366
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
209
367
  // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
210
- // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
368
+ // The minimum granularity is 16 bytes.
369
+ constexpr int h2_per_chunk = 16/sizeof(half2);
370
+ const int chunks_per_row = D2 / h2_per_chunk;
211
371
  if constexpr (use_cp_async) {
372
+ static_assert(warp_size == 32, "bad warp_size");
212
373
  static_assert(!oob_check, "OOB check not compatible with cp_async");
213
374
  constexpr int preload = 64;
214
- constexpr int h2_per_chunk = 16/sizeof(half2);
215
- const int chunks_per_row = D2 / h2_per_chunk;
216
375
 
217
376
  const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
218
377
 
219
378
  auto load = [&] __device__ (auto n) {
220
- const int stride_k = WARP_SIZE >> n;
221
- const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
379
+ const int stride_k = warp_size >> n;
380
+ const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
222
381
  const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
223
- const int stride_i = WARP_SIZE / stride_k;
382
+ const int stride_i = warp_size / stride_k;
224
383
 
225
384
  if (k0_start == k0_stop) {
226
385
  return;
@@ -228,7 +387,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
228
387
 
229
388
  #pragma unroll
230
389
  for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
231
- const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
390
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
232
391
 
233
392
  if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
234
393
  break;
@@ -236,7 +395,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
236
395
 
237
396
  #pragma unroll
238
397
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
239
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
398
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
240
399
 
241
400
  cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
242
401
  }
@@ -250,12 +409,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
250
409
  // 6: max 1*16= 16 bytes, 8 half
251
410
  ggml_cuda_unroll<6>{}(load);
252
411
  } else {
253
- // TODO use ggml_cuda_memcpy_1
412
+ const half2 zero[4] = {{0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}, {0.0f, 0.0f}};
254
413
  auto load = [&] __device__ (const int n) {
255
- const int stride_k = WARP_SIZE >> n;
256
- const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
257
- const int k0_stop = D2 - D2 % (1*stride_k);
258
- const int stride_i = WARP_SIZE / stride_k;
414
+ const int stride_k = 32 >> n;
415
+ const int k0_start = stride_k == 32 ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
416
+ const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
417
+ const int stride_i = warp_size / stride_k;
259
418
 
260
419
  if (k0_start == k0_stop) {
261
420
  return;
@@ -263,7 +422,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
263
422
 
264
423
  #pragma unroll
265
424
  for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
266
- const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
425
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
267
426
 
268
427
  if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
269
428
  break;
@@ -271,17 +430,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
271
430
 
272
431
  #pragma unroll
273
432
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
274
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
433
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
275
434
 
276
- tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
435
+ ggml_cuda_memcpy_1<16>(tile_KV + i*stride_tile + k*4,
436
+ !oob_check || i < i_sup ? KV + i*stride_KV + k*h2_per_chunk : zero);
277
437
  }
278
438
  }
279
439
  };
280
- // 1: max 32* 4=128 bytes, 64 half
281
- // 2: max 16* 4= 64 bytes, 32 half
282
- // 3: max 8* 4= 32 bytes, 16 half
283
- // 4: max 4* 4= 16 bytes, 8 half
284
- ggml_cuda_unroll<4>{}(load);
440
+ // 1: max 32*16=512 bytes, 256 half
441
+ // 2: max 16*16=256 bytes, 128 half
442
+ // 3: max 8*16=128 bytes, 64 half
443
+ // 4: max 4*16= 64 bytes, 32 half
444
+ // 5: max 2*16= 32 bytes, 16 half
445
+ // 6: max 1*16= 16 bytes, 8 half
446
+ ggml_cuda_unroll<6>{}(load);
285
447
  }
286
448
  }
287
449
 
@@ -289,18 +451,19 @@ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_chec
289
451
  static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
290
452
  const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
291
453
  const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
454
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
292
455
  if constexpr (use_cp_async) {
293
- static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
456
+ static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa");
294
457
  static_assert(!oob_check, "OOB check incompatible with cp_async");
295
458
  constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
296
- constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
459
+ constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
297
460
  constexpr int stride_j = nwarps * cols_per_warp;
298
461
 
299
462
  const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
300
463
 
301
464
  #pragma unroll
302
465
  for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
303
- const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
466
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
304
467
  const int j_vram = fastmodulo(j0 + j_sram, ne01);
305
468
 
306
469
  if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
@@ -309,7 +472,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
309
472
 
310
473
  const int i = 8 * (threadIdx.x % (nbatch_fa/8));
311
474
 
312
- cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
475
+ cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + int64_t(j_vram)*stride_mask + i);
313
476
  }
314
477
  } else if constexpr (oob_check) {
315
478
  #pragma unroll
@@ -322,27 +485,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
322
485
  }
323
486
 
324
487
  #pragma unroll
325
- for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
488
+ for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
326
489
  const int i = i0 + threadIdx.x;
327
490
 
328
- tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
491
+ tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[int64_t(j_vram)*stride_mask + i] : half(0.0f);
329
492
  }
330
493
  }
331
- } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
332
- constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
494
+ } else if constexpr (nbatch_fa < 2*warp_size) {
495
+ constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
333
496
  constexpr int stride_j = nwarps * cols_per_warp;
334
497
  #pragma unroll
335
498
  for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
336
- const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
499
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
337
500
  const int j_vram = fastmodulo(j0 + j_sram, ne01);
338
501
 
339
502
  if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
340
503
  break;
341
504
  }
342
505
 
343
- const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
506
+ const int i = threadIdx.x % (warp_size/cols_per_warp);
344
507
 
345
- ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
508
+ ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + int64_t(j_vram)*stride_mask + 2*i);
346
509
  }
347
510
  } else {
348
511
  #pragma unroll
@@ -355,17 +518,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
355
518
  }
356
519
 
357
520
  #pragma unroll
358
- for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
521
+ for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
359
522
  const int i = i0 + 2*threadIdx.x;
360
523
 
361
- ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
524
+ ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + int64_t(j_vram)*stride_mask + i);
362
525
  }
363
526
  }
364
527
  }
365
528
  }
366
529
 
367
530
  template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
368
- bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
531
+ bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
369
532
  typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
370
533
  static __device__ __forceinline__ void flash_attn_ext_f16_iter(
371
534
  const float2 * const __restrict__ Q_f2,
@@ -393,33 +556,34 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
393
556
  const int jt,
394
557
  const int kb0,
395
558
  const int k_VKQ_sup) {
396
- #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
559
+ #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
560
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
397
561
  constexpr int ncols = ncols1 * ncols2;
398
562
  constexpr int cols_per_warp = T_B_KQ::I;
399
- constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
400
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
563
+ constexpr int cols_per_thread = get_cols_per_thread();
564
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
401
565
  constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
402
566
  constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
403
567
  constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
404
568
  constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
405
569
  constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
406
570
 
407
- constexpr int stride_tile_Q = DKQ/2 + 4;
408
571
  constexpr int stride_tile_K = nbatch_K2 + 4;
409
572
 
410
- static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
411
- constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
573
+ constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
412
574
 
413
575
  const int k_VKQ_0 = kb0 * nbatch_fa;
414
576
  #if defined(TURING_MMA_AVAILABLE)
415
577
  T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
578
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
579
+ T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
416
580
  #else // Volta
417
581
  T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
418
582
  #endif // defined(TURING_MMA_AVAILABLE)
419
583
 
420
584
  if constexpr (nstages > 1) {
421
585
  static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
422
- static_assert(!mla, "multi-stage loading not implemented for MLA");
586
+ static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
423
587
  static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
424
588
  constexpr bool use_cp_async = true;
425
589
  cp_async_wait_all();
@@ -434,12 +598,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
434
598
  }
435
599
  }
436
600
 
601
+ // For MLA K and V have the same data.
602
+ // Therefore, iterate over K in reverse and later re-use the data if possible.
437
603
  #pragma unroll
438
- for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
604
+ for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
439
605
  const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
440
- const int k0_diff = k0_stop - k0_start;
441
606
 
442
607
  if constexpr (nstages <= 1) {
608
+ const int k0_diff = k0_stop - k0_start;
443
609
  constexpr bool use_cp_async = nstages == 1;
444
610
  flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
445
611
  (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);
@@ -461,13 +627,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
461
627
  if constexpr (cols_per_warp == 8) {
462
628
  mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
463
629
  } else {
464
- // Wide version of KQ_C is column-major => swap A and B.
630
+ // Wide version of KQ_C is column-major
631
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
632
+ // AMD matrix C is column-major.
633
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
634
+ #else
635
+ // swap A and B for CUDA.
465
636
  mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
637
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
466
638
  }
467
639
  }
468
640
  }
469
641
  } else {
470
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
642
+ constexpr int stride_tile_Q = DKQ/2 + 4;
471
643
  #pragma unroll
472
644
  for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
473
645
  load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -479,8 +651,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
479
651
  T_A_KQ K_A;
480
652
  load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
481
653
 
482
- // Wide version of KQ_C is column-major => swap A and B.
483
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
654
+ if constexpr (cols_per_warp == 8) {
655
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
656
+ } else {
657
+ // Wide version of KQ_C is column-major
658
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
659
+ // AMD matrix C is column-major.
660
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
661
+ #else
662
+ // swap A and B for CUDA.
663
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
664
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
665
+ }
484
666
  }
485
667
  }
486
668
  }
@@ -532,7 +714,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
532
714
  #pragma unroll
533
715
  for (int l = 0; l < T_C_KQ::ne; ++l) {
534
716
  if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
535
- KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
717
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
718
+ constexpr int KQ_idx = 0;
719
+ #else
720
+ // Turing + Volta:
721
+ const int KQ_idx = l % 2;
722
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
723
+ KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
536
724
  }
537
725
  }
538
726
  }
@@ -542,7 +730,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
542
730
  for (int col = 0; col < cols_per_thread; ++col) {
543
731
  #pragma unroll
544
732
  for (int offset = 16; offset >= 4; offset >>= 1) {
545
- KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
733
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
546
734
  }
547
735
  }
548
736
 
@@ -552,8 +740,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
552
740
  #pragma unroll
553
741
  for (int l = 0; l < T_C_KQ::ne; ++l) {
554
742
  if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
555
- KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
556
- KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
743
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
744
+ constexpr int KQ_idx = 0;
745
+ #else
746
+ // Turing + Volta:
747
+ const int KQ_idx = l % 2;
748
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
749
+ KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
750
+ KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
557
751
  } else {
558
752
  KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
559
753
  }
@@ -564,6 +758,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
564
758
  #pragma unroll
565
759
  for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
566
760
  const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
761
+
762
+ // The mask is stored as 16 bit half values, loading them as 32 bit half2 values is preferred in terms of speed.
763
+ // However, this is not possible for RDNA3 where 2 consecutive l indices are not consecutive in the mask memory layout.
764
+ #ifdef RDNA3
765
+ #pragma unroll
766
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
767
+ const int i = i0 + T_C_KQ::get_j(l);
768
+ const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l)) / ncols2;
769
+
770
+ KQ_C[i00/(np*T_C_KQ::J)].x[l] += __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
771
+ }
772
+ #else
567
773
  #pragma unroll
568
774
  for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
569
775
  const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
@@ -573,6 +779,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
573
779
  KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
574
780
  KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
575
781
  }
782
+ #endif // RDNA3
576
783
  }
577
784
  }
578
785
 
@@ -584,8 +791,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
584
791
  #pragma unroll
585
792
  for (int l = 0; l < T_C_KQ::ne; ++l) {
586
793
  if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
794
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
795
+ constexpr int KQ_idx = 0;
796
+ #else
587
797
  // Turing + Volta:
588
- KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
798
+ const int KQ_idx = (l/2) % 2;
799
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
800
+ KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
589
801
  }
590
802
  }
591
803
  }
@@ -596,14 +808,22 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
596
808
  // Values per KQ column are spread across 4 threads:
597
809
  constexpr int offset_first = 2;
598
810
  constexpr int offset_last = 1;
599
- #else
811
+ #elif defined(AMD_MFMA_AVAILABLE)
812
+ // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
813
+ constexpr int offset_first = 32;
814
+ constexpr int offset_last = 16;
815
+ #elif defined(AMD_WMMA_AVAILABLE)
816
+ // Values per KQ column are spread across 2 threads:
817
+ constexpr int offset_first = 16;
818
+ constexpr int offset_last = 16;
819
+ #else // Volta
600
820
  // Values per KQ column are spread across 2 threads:
601
821
  constexpr int offset_first = 2;
602
822
  constexpr int offset_last = 2;
603
823
  #endif // defined(TURING_MMA_AVAILABLE)
604
824
  #pragma unroll
605
825
  for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
606
- KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
826
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
607
827
  }
608
828
  }
609
829
 
@@ -612,10 +832,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
612
832
  for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
613
833
  #pragma unroll
614
834
  for (int l = 0; l < T_C_KQ::ne; ++l) {
615
- // Turing + Volta:
616
835
  if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
617
- KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
618
- KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
836
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
837
+ constexpr int KQ_idx = 0;
838
+ #else
839
+ // Turing + Volta:
840
+ const int KQ_idx = (l/2) % 2;
841
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
842
+ KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
843
+ KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
619
844
  } else {
620
845
  KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
621
846
  }
@@ -639,7 +864,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
639
864
 
640
865
  #if defined(TURING_MMA_AVAILABLE)
641
866
  if constexpr (cols_per_warp == 8) {
642
- const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
867
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
643
868
  #pragma unroll
644
869
  for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
645
870
  #pragma unroll
@@ -660,6 +885,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
660
885
  }
661
886
  }
662
887
  }
888
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
889
+ if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) {
890
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
891
+ #pragma unroll
892
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
893
+ #pragma unroll
894
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
895
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
896
+ }
897
+ }
898
+ } else {
899
+ static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type");
900
+ #pragma unroll
901
+ for (int i = 0; i < DV/T_C_VKQ::J; ++i) {
902
+ #pragma unroll
903
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
904
+ VKQ_C[i].x[l] *= KQ_max_scale[0];
905
+ }
906
+ }
907
+ }
663
908
  #else // Volta
664
909
  const half2 KQ_max_scale_h2 = make_half2(
665
910
  KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
@@ -688,6 +933,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
688
933
  }
689
934
 
690
935
  if constexpr (nstages > 1) {
936
+ static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
691
937
  // Preload K tile for next iteration:
692
938
  constexpr bool use_cp_async = true;
693
939
  cp_async_wait_all();
@@ -703,19 +949,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
703
949
  }
704
950
 
705
951
 
706
- // For MLA K and V have the same data.
707
- // Therefore, iterate over V in reverse and re-use the data if possible.
708
- static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
709
- constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
710
-
711
952
  // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
712
953
  #pragma unroll
713
- for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
714
- const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
715
- const int i0_diff = i0_stop - i0_start;
954
+ for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
955
+ static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
956
+ const int i0_stop = i0_start + 2*nbatch_V2;
716
957
 
717
958
  if constexpr (nstages <= 1) {
718
- if (i0_start < reusable_cutoff) {
959
+ const int i0_diff = i0_stop - i0_start;
960
+ if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
719
961
  constexpr bool use_cp_async = nstages == 1;
720
962
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
721
963
  (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
@@ -725,12 +967,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
725
967
  __syncthreads();
726
968
  }
727
969
  }
728
- const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
970
+ const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
729
971
 
730
- #if defined(TURING_MMA_AVAILABLE)
731
- constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
972
+ #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
732
973
  #pragma unroll
733
- for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
974
+ for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += T_A_VKQ::I) {
734
975
  static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
735
976
  #pragma unroll
736
977
  for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
@@ -739,10 +980,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
739
980
  T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
740
981
  load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
741
982
  if constexpr (T_B_KQ::I == 8) {
742
- mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
983
+ mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]);
743
984
  } else {
744
- // Wide version of VKQ_C is column-major => swap A and B.
745
- mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
985
+ // Wide version of VKQ_C is column-major.
986
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
987
+ // AMD matrix C is column-major.
988
+ mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], A, B[k00/(np*T_A_VKQ::J)]);
989
+ #else
990
+ // swap A and B for CUDA.
991
+ mma(VKQ_C[i_VKQ_0/T_A_VKQ::I], B[k00/(np*T_A_VKQ::J)], A);
992
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
746
993
  }
747
994
  }
748
995
  }
@@ -761,7 +1008,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
761
1008
  mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
762
1009
  }
763
1010
  }
764
- #endif // defined(TURING_MMA_AVAILABLE)
1011
+ #endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
765
1012
 
766
1013
  if constexpr (nstages <= 1) {
767
1014
  __syncthreads(); // Only needed if tile_K == tile_V.
@@ -774,11 +1021,11 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
774
1021
  tile_Q, tile_K, tile_V, tile_mask,
775
1022
  Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
776
1023
  NO_DEVICE_CODE;
777
- #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1024
+ #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
778
1025
  }
779
1026
 
780
1027
  #if defined(TURING_MMA_AVAILABLE)
781
- template<int ncols> struct mma_tile_sizes {
1028
+ template<int DV, int ncols> struct mma_tile_sizes {
782
1029
  using T_A_KQ = tile<16, 8, half2>; // row-major
783
1030
  using T_B_KQ = tile<16, 8, half2>; // column-major
784
1031
  using T_C_KQ = tile<16, 16, float>; // column-major
@@ -786,7 +1033,7 @@ template<int ncols> struct mma_tile_sizes {
786
1033
  using T_B_VKQ = tile<16, 8, half2>; // column-major
787
1034
  using T_C_VKQ = tile<16, 8, half2>; // column-major
788
1035
  };
789
- template<> struct mma_tile_sizes<8> {
1036
+ template<int DV> struct mma_tile_sizes<DV, 8> {
790
1037
  using T_A_KQ = tile<16, 8, half2>; // row-major
791
1038
  using T_B_KQ = tile< 8, 8, half2>; // column-major
792
1039
  using T_C_KQ = tile<16, 8, float>; // row-major
@@ -794,8 +1041,69 @@ template<> struct mma_tile_sizes<8> {
794
1041
  using T_B_VKQ = tile< 8, 8, half2>; // column-major
795
1042
  using T_C_VKQ = tile<16, 4, half2>; // row-major
796
1043
  };
1044
+ #elif defined(AMD_WMMA_AVAILABLE)
1045
+ #ifdef RDNA3
1046
+ template<int DV, int ncols> struct mma_tile_sizes {
1047
+ using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
1048
+ using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
1049
+ using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
1050
+ using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
1051
+ using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
1052
+ using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR>; // column-major
1053
+ };
1054
+ template<int ncols> struct mma_tile_sizes<80, ncols> {
1055
+ using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
1056
+ using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
1057
+ using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
1058
+ using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
1059
+ using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
1060
+ using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
1061
+ };
1062
+ template<int ncols> struct mma_tile_sizes<112, ncols> {
1063
+ using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
1064
+ using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
1065
+ using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
1066
+ using T_A_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
1067
+ using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // column-major
1068
+ using T_C_VKQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
1069
+ };
1070
+ #else
1071
+ template<int DV, int ncols> struct mma_tile_sizes {
1072
+ using T_A_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major
1073
+ using T_B_KQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major
1074
+ using T_C_KQ = tile<16, 16, float, DATA_LAYOUT_I_MAJOR>; // column-major
1075
+ using T_A_VKQ = tile<32, 8, half2, DATA_LAYOUT_I_MAJOR>; // row-major
1076
+ using T_B_VKQ = tile<16, 8, half2, DATA_LAYOUT_I_MAJOR>; // column-major
1077
+ using T_C_VKQ = tile<16, 16, half2, DATA_LAYOUT_I_MAJOR_SCRAMBLED>; // column-major
1078
+ };
1079
+ template<int ncols> struct mma_tile_sizes<80, ncols> {
1080
+ using T_A_KQ = tile<16, 8, half2>; // row-major
1081
+ using T_B_KQ = tile<16, 8, half2>; // column-major
1082
+ using T_C_KQ = tile<16, 16, float>; // column-major
1083
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
1084
+ using T_B_VKQ = tile<16, 8, half2>; // column-major
1085
+ using T_C_VKQ = tile<16, 8, half2>; // column-major
1086
+ };
1087
+ template<int ncols> struct mma_tile_sizes<112, ncols> {
1088
+ using T_A_KQ = tile<16, 8, half2>; // row-major
1089
+ using T_B_KQ = tile<16, 8, half2>; // column-major
1090
+ using T_C_KQ = tile<16, 16, float>; // column-major
1091
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
1092
+ using T_B_VKQ = tile<16, 8, half2>; // column-major
1093
+ using T_C_VKQ = tile<16, 8, half2>; // column-major
1094
+ };
1095
+ #endif // RDNA3
1096
+ #elif defined(AMD_MFMA_AVAILABLE)
1097
+ template<int DV, int ncols> struct mma_tile_sizes {
1098
+ using T_A_KQ = tile<16, 8, half2>; // row-major
1099
+ using T_B_KQ = tile<16, 8, half2>; // column-major
1100
+ using T_C_KQ = tile<16, 16, float>; // column-major
1101
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
1102
+ using T_B_VKQ = tile<16, 8, half2>; // column-major
1103
+ using T_C_VKQ = tile<16, 8, half2>; // column-major
1104
+ };
797
1105
  #else // Volta
798
- template<int ncols> struct mma_tile_sizes {
1106
+ template<int DV, int ncols> struct mma_tile_sizes {
799
1107
  using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
800
1108
  using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
801
1109
  using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
@@ -805,7 +1113,7 @@ template<int ncols> struct mma_tile_sizes {
805
1113
  };
806
1114
  #endif // defined(TURING_MMA_AVAILABLE)
807
1115
 
808
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
1116
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
809
1117
  static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
810
1118
  const float2 * const __restrict__ Q_f2,
811
1119
  const half2 * const __restrict__ K_h2,
@@ -819,6 +1127,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
819
1127
  const float logit_softcap,
820
1128
  const uint3 ne01,
821
1129
  const int ne02,
1130
+ const int gqa_ratio,
822
1131
  const int ne11,
823
1132
  const int stride_Q1,
824
1133
  const int stride_Q2,
@@ -826,22 +1135,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
826
1135
  const int stride_V,
827
1136
  const int stride_mask,
828
1137
  const int jt,
1138
+ const int zt_gqa,
829
1139
  const int kb0_start,
830
1140
  const int kb0_stop) {
831
- #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1141
+ #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
832
1142
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
833
1143
 
1144
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
834
1145
  constexpr int ncols = ncols1 * ncols2;
835
- using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
836
- using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
837
- using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ;
838
- using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ;
839
- using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ;
840
- using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
1146
+ using T_A_KQ = typename mma_tile_sizes<DV, ncols>::T_A_KQ;
1147
+ using T_B_KQ = typename mma_tile_sizes<DV, ncols>::T_B_KQ;
1148
+ using T_C_KQ = typename mma_tile_sizes<DV, ncols>::T_C_KQ;
1149
+ using T_A_VKQ = typename mma_tile_sizes<DV, ncols>::T_A_VKQ;
1150
+ using T_B_VKQ = typename mma_tile_sizes<DV, ncols>::T_B_VKQ;
1151
+ using T_C_VKQ = typename mma_tile_sizes<DV, ncols>::T_C_VKQ;
841
1152
 
842
1153
  constexpr int cols_per_warp = T_B_KQ::I;
843
- constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
844
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
1154
+ constexpr int cols_per_thread = get_cols_per_thread();
1155
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
845
1156
  constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
846
1157
  constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
847
1158
  constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -859,8 +1170,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
859
1170
  constexpr int stride_tile_Q = DKQ/2 + 4;
860
1171
  constexpr int stride_tile_K = nbatch_K2 + 4;
861
1172
 
862
- static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
863
- constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
1173
+ constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
864
1174
  constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
865
1175
 
866
1176
  extern __shared__ half2 tile_Q[];
@@ -871,6 +1181,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
871
1181
  T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
872
1182
  #if defined(TURING_MMA_AVAILABLE)
873
1183
  T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
1184
+ #elif defined(AMD_WMMA_AVAILABLE) && defined(RDNA3)
1185
+ T_C_VKQ VKQ_C[DV % 32 != 0 ? DV/T_C_VKQ::J : DV/(2*T_C_VKQ::J)];
1186
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
1187
+ T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
874
1188
  #else // Volta
875
1189
  T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
876
1190
  #endif // defined(TURING_MMA_AVAILABLE)
@@ -887,10 +1201,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
887
1201
  // The loading is done with decreasing granularity for D for better memory bandwidth.
888
1202
  const half2 scale_h2 = make_half2(scale, scale);
889
1203
  #pragma unroll
890
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
891
- const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
1204
+ for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
1205
+ const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
892
1206
  const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
893
- const int stride_jc = WARP_SIZE / stride_k;
1207
+ const int stride_jc = warp_size / stride_k;
894
1208
 
895
1209
  if (k0_start == k0_stop) {
896
1210
  continue;
@@ -898,7 +1212,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
898
1212
 
899
1213
  #pragma unroll
900
1214
  for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
901
- const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
1215
+ const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
902
1216
 
903
1217
  if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
904
1218
  break;
@@ -907,10 +1221,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
907
1221
  const int j = jc / ncols2;
908
1222
  const int c = jc % ncols2;
909
1223
 
910
- if (jt*ncols1 + j < int(ne01.z)) {
1224
+ if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
911
1225
  #pragma unroll
912
1226
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
913
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1227
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
914
1228
 
915
1229
  const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
916
1230
  tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
@@ -918,7 +1232,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
918
1232
  } else {
919
1233
  #pragma unroll
920
1234
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
921
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1235
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
922
1236
 
923
1237
  tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
924
1238
  }
@@ -962,7 +1276,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
962
1276
  constexpr bool last_iter = false;
963
1277
  constexpr int k_VKQ_sup = nbatch_fa;
964
1278
  flash_attn_ext_f16_iter
965
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
1279
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
966
1280
  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
967
1281
  (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
968
1282
  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -971,7 +1285,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
971
1285
  constexpr bool last_iter = true;
972
1286
  const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
973
1287
  flash_attn_ext_f16_iter
974
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
1288
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
975
1289
  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
976
1290
  (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
977
1291
  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -982,7 +1296,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
982
1296
  constexpr bool last_iter = false;
983
1297
  constexpr int k_VKQ_sup = nbatch_fa;
984
1298
  flash_attn_ext_f16_iter
985
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
1299
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
986
1300
  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
987
1301
  (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
988
1302
  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -991,7 +1305,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
991
1305
  constexpr bool last_iter = true;
992
1306
  constexpr int k_VKQ_sup = nbatch_fa;
993
1307
  flash_attn_ext_f16_iter
994
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
1308
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
995
1309
  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
996
1310
  (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
997
1311
  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1010,6 +1324,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1010
1324
  // The partial sums are spread across 8/4 threads.
1011
1325
  constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
1012
1326
  constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
1327
+ #elif defined(AMD_MFMA_AVAILABLE)
1328
+ // The partial sums are spread across 4 threads (wavefront64, 16 cols).
1329
+ constexpr int offset_first = 32;
1330
+ constexpr int offset_last = 16;
1331
+ #elif defined(AMD_WMMA_AVAILABLE)
1332
+ // The partial sums are spread across 2 threads.
1333
+ constexpr int offset_first = 16;
1334
+ constexpr int offset_last = 16;
1013
1335
  #else // Volta
1014
1336
  // The partial sums are spread across 2 threads.
1015
1337
  constexpr int offset_first = 2;
@@ -1019,19 +1341,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1019
1341
  for (int col = 0; col < cols_per_thread; ++col) {
1020
1342
  #pragma unroll
1021
1343
  for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
1022
- KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
1344
+ KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
1023
1345
  }
1024
1346
  }
1025
1347
  }
1026
1348
 
1027
1349
  // If attention sinks are used, potentially re-scale if KQ_max is small.
1028
- // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
1350
+ // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
1029
1351
  // so it's being done unconditionally for every thread.
1030
1352
  if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
1031
1353
  float KQ_max_scale[cols_per_thread];
1032
1354
  #pragma unroll
1033
1355
  for (int col = 0; col < cols_per_thread; ++col) {
1034
- const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
1356
+ const int jc = (threadIdx.y/np)*cols_per_warp + (cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col));
1035
1357
  const float sink = sinks_f[jc % ncols2];
1036
1358
 
1037
1359
  const float KQ_max_new = fmaxf(KQ_max[col], sink);
@@ -1047,7 +1369,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1047
1369
 
1048
1370
  #if defined(TURING_MMA_AVAILABLE)
1049
1371
  if constexpr (cols_per_warp == 8) {
1050
- const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
1372
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
1051
1373
  #pragma unroll
1052
1374
  for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
1053
1375
  #pragma unroll
@@ -1068,6 +1390,26 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1068
1390
  }
1069
1391
  }
1070
1392
  }
1393
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
1394
+ if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) {
1395
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
1396
+ #pragma unroll
1397
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
1398
+ #pragma unroll
1399
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
1400
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
1401
+ }
1402
+ }
1403
+ } else {
1404
+ static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type");
1405
+ #pragma unroll
1406
+ for (int i = 0; i < DV/T_C_VKQ::J; ++i) {
1407
+ #pragma unroll
1408
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
1409
+ VKQ_C[i].x[l] *= KQ_max_scale[0];
1410
+ }
1411
+ }
1412
+ }
1071
1413
  #else // Volta
1072
1414
  const int col = (threadIdx.x / 2) % 2;
1073
1415
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
@@ -1119,6 +1461,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1119
1461
  const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
1120
1462
  const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
1121
1463
  const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
1464
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
1465
+ const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
1466
+ const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
1467
+ const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
1122
1468
  #else // Volta
1123
1469
  const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
1124
1470
  const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
@@ -1149,14 +1495,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1149
1495
  // Warps with threadIdx.y % np != 0 must NOT return early.
1150
1496
  // All threads must return simultaneously to avoid race conditions with work on the next tile.
1151
1497
 
1152
- constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
1498
+ constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
1153
1499
 
1154
- const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
1500
+ const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
1155
1501
  float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
1156
1502
  float2 meta[nmeta];
1157
1503
  #pragma unroll
1158
1504
  for (int imeta = 0; imeta < nmeta; ++imeta) {
1159
- meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
1505
+ meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
1160
1506
  }
1161
1507
 
1162
1508
  float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
@@ -1166,8 +1512,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1166
1512
  }
1167
1513
  #pragma unroll
1168
1514
  for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
1169
- if (offset < WARP_SIZE) {
1170
- KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
1515
+ if (offset < warp_size) {
1516
+ KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
1171
1517
  }
1172
1518
  }
1173
1519
 
@@ -1184,8 +1530,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1184
1530
  }
1185
1531
  #pragma unroll
1186
1532
  for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
1187
- if (offset < WARP_SIZE) {
1188
- KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
1533
+ if (offset < warp_size) {
1534
+ KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
1189
1535
  }
1190
1536
  }
1191
1537
 
@@ -1194,19 +1540,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1194
1540
  // Write back combined meta data:
1195
1541
  #pragma unroll
1196
1542
  for (int imeta = 0; imeta < nmeta; ++imeta) {
1197
- if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
1543
+ if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
1198
1544
  // Combined KQ max scale + rowsum.
1199
- meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
1545
+ meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
1200
1546
  }
1201
1547
  }
1202
1548
 
1203
1549
  // Combined KQ max + rowsum.
1204
- static_assert(cols_per_warp <= WARP_SIZE);
1205
- if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
1550
+ static_assert(cols_per_warp <= warp_size);
1551
+ if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
1206
1552
  float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1207
1553
  dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
1208
1554
  }
1209
- if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
1555
+ if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
1210
1556
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1211
1557
  dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
1212
1558
  }
@@ -1220,6 +1566,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1220
1566
  #pragma unroll
1221
1567
  for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
1222
1568
  if constexpr (cols_per_warp == 8) {
1569
+ static_assert(std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>, "bad VKQ type");
1223
1570
  const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
1224
1571
  #pragma unroll
1225
1572
  for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
@@ -1234,14 +1581,45 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1234
1581
  }
1235
1582
  } else {
1236
1583
  const int j0 = threadIdx.y*cols_per_warp;
1584
+ if constexpr (std::is_same_v<decltype(T_C_VKQ::x), half2[T_C_VKQ::ne]>) {
1585
+ if constexpr (T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR) {
1237
1586
  #pragma unroll
1238
- for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
1587
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
1239
1588
  #pragma unroll
1240
- for (int l = 0; l < T_C_VKQ::ne; ++l) {
1241
- const int j = j0 + T_C_VKQ::get_i(l);
1242
- const int k = k1 + T_C_VKQ::get_j(l);
1589
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
1590
+ const int j = j0 + T_C_VKQ::get_i(l);
1591
+ const int k = k1 + T_C_VKQ::get_j(l);
1592
+
1593
+ tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
1594
+ }
1595
+ }
1596
+ } else {
1597
+ static_assert(T_C_VKQ::dl == DATA_LAYOUT_I_MAJOR_SCRAMBLED, "bad T_C_VKQ data layout");
1598
+ using T_C_VKQ_us = tile<T_C_VKQ::I, T_C_VKQ::J, half2, DATA_LAYOUT_I_MAJOR>; // us == unscrambled
1599
+ #pragma unroll
1600
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
1601
+ const T_C_VKQ_us VKQ_C_us = unscramble(VKQ_C[(k00 + k1)/T_C_VKQ::J]);
1602
+ #pragma unroll
1603
+ for (int l = 0; l < T_C_VKQ_us::ne; ++l) {
1604
+ const int j = j0 + T_C_VKQ_us::get_i(l);
1605
+ const int k = k1 + T_C_VKQ_us::get_j(l);
1243
1606
 
1244
- tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
1607
+ tile_Q[j*tile_stride + k] = VKQ_C_us.x[l];
1608
+ }
1609
+ }
1610
+ }
1611
+ } else {
1612
+ static_assert(std::is_same_v<decltype(T_C_VKQ::x), float[T_C_VKQ::ne]>, "bad VKQ type");
1613
+ half * tile_Q_h = (half *) tile_Q;
1614
+ #pragma unroll
1615
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J/2) {
1616
+ #pragma unroll
1617
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
1618
+ const int j = j0 + T_C_VKQ::get_i(l);
1619
+ const int k = 2*k1 + T_C_VKQ::get_j(l);
1620
+
1621
+ tile_Q_h[j*(2*tile_stride) + k] = VKQ_C[(k00 + k1)/(T_C_VKQ::J/2)].x[l];
1622
+ }
1245
1623
  }
1246
1624
  }
1247
1625
  }
@@ -1254,10 +1632,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1254
1632
  float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
1255
1633
 
1256
1634
  #pragma unroll
1257
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
1258
- const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
1635
+ for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
1636
+ const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
1259
1637
  const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
1260
- const int stride_jc = WARP_SIZE / stride_k;
1638
+ const int stride_jc = warp_size / stride_k;
1261
1639
 
1262
1640
  if (k0_start == k0_stop) {
1263
1641
  continue;
@@ -1265,7 +1643,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1265
1643
 
1266
1644
  #pragma unroll
1267
1645
  for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
1268
- const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
1646
+ const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
1269
1647
 
1270
1648
  if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
1271
1649
  break;
@@ -1276,14 +1654,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1276
1654
  const int j_dst = jc_dst / ncols2;
1277
1655
  const int c_dst = jc_dst % ncols2;
1278
1656
 
1279
- if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
1657
+ if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
1280
1658
  continue;
1281
1659
  }
1282
1660
 
1283
1661
  const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
1284
1662
  #pragma unroll
1285
1663
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
1286
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1664
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
1287
1665
 
1288
1666
  float2 dstk_val = make_float2(0.0f, 0.0f);
1289
1667
  #pragma unroll
@@ -1315,24 +1693,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1315
1693
  }
1316
1694
  #else
1317
1695
  GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
1318
- scale, slope, logit_softcap, ne01, ne02,
1696
+ scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
1319
1697
  stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
1320
1698
  jt, kb0_start, kb0_stop);
1321
1699
  NO_DEVICE_CODE;
1322
- #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1700
+ #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
1323
1701
  }
1324
1702
 
1325
- template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
1703
+ template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
1326
1704
  __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
1327
1705
  static __global__ void flash_attn_ext_f16(
1328
- const char * __restrict__ Q,
1329
- const char * __restrict__ K,
1330
- const char * __restrict__ V,
1331
- const char * __restrict__ mask,
1332
- const char * __restrict__ sinks,
1333
- const int * __restrict__ KV_max,
1334
- float * __restrict__ dst,
1335
- float2 * __restrict__ dst_meta,
1706
+ const char * Q_ptr,
1707
+ const char * K_ptr,
1708
+ const char * V_ptr,
1709
+ const char * mask_ptr,
1710
+ const char * sinks_ptr,
1711
+ const int * KV_max_ptr,
1712
+ float * dst_ptr,
1713
+ float2 * dst_meta_ptr,
1336
1714
  const float scale,
1337
1715
  const float max_bias,
1338
1716
  const float m0,
@@ -1346,13 +1724,33 @@ static __global__ void flash_attn_ext_f16(
1346
1724
  const int32_t nb21, const int32_t nb22, const int64_t nb23,
1347
1725
  const int32_t ne31, const int32_t ne32, const int32_t ne33,
1348
1726
  const int32_t nb31, const int32_t nb32, const int64_t nb33) {
1349
- #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1727
+ ggml_cuda_pdl_sync(); // TODO optimize placement
1728
+ #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE))
1729
+ const char * GGML_CUDA_RESTRICT Q = Q_ptr;
1730
+ const char * GGML_CUDA_RESTRICT K = K_ptr;
1731
+ const char * GGML_CUDA_RESTRICT V = V_ptr;
1732
+ const char * GGML_CUDA_RESTRICT mask = mask_ptr;
1733
+ const char * GGML_CUDA_RESTRICT sinks = sinks_ptr;
1734
+ const int * GGML_CUDA_RESTRICT KV_max = KV_max_ptr;
1735
+ float * GGML_CUDA_RESTRICT dst = dst_ptr;
1736
+ float2 * GGML_CUDA_RESTRICT dst_meta = dst_meta_ptr;
1350
1737
 
1351
1738
  // Skip unused kernel variants for faster compilation:
1352
- if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
1739
+ if (use_logit_softcap && !(DKQ == 128 || DKQ == 256 || DKQ == 512)) {
1353
1740
  NO_DEVICE_CODE;
1354
1741
  return;
1355
1742
  }
1743
+ if (DKQ == 192 && ncols2 != 8 && ncols2 != 16) {
1744
+ NO_DEVICE_CODE;
1745
+ return;
1746
+ }
1747
+ #ifdef VOLTA_MMA_AVAILABLE
1748
+ if (ncols1*ncols2 < 32) {
1749
+ NO_DEVICE_CODE;
1750
+ return;
1751
+ }
1752
+ #endif // VOLTA_MMA_AVAILABLE
1753
+
1356
1754
  #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1357
1755
  if (ncols1*ncols2 > 32) {
1358
1756
  NO_DEVICE_CODE;
@@ -1360,12 +1758,25 @@ static __global__ void flash_attn_ext_f16(
1360
1758
  }
1361
1759
  #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1362
1760
 
1363
- static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
1761
+ #if defined(AMD_WMMA_AVAILABLE)
1762
+ if (ncols1*ncols2 < 16 || ncols2 == 1 || DKQ > 128) {
1763
+ NO_DEVICE_CODE;
1764
+ return;
1765
+ }
1766
+ #endif // defined(AMD_WMMA_AVAILABLE)
1767
+
1768
+ #if defined(AMD_MFMA_AVAILABLE)
1769
+ if (ncols1*ncols2 < 16 || DKQ > 256) {
1770
+ NO_DEVICE_CODE;
1771
+ return;
1772
+ }
1773
+ #endif // defined(AMD_MFMA_AVAILABLE)
1364
1774
 
1775
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1365
1776
  constexpr int ncols = ncols1 * ncols2;
1366
1777
  constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
1367
1778
  constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
1368
- constexpr int nwarps = nthreads / WARP_SIZE;
1779
+ constexpr int nwarps = nthreads / warp_size;
1369
1780
 
1370
1781
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
1371
1782
 
@@ -1374,14 +1785,15 @@ static __global__ void flash_attn_ext_f16(
1374
1785
  const int stride_K = nb11 / sizeof(half2);
1375
1786
  const int stride_mask = nb31 / sizeof(half);
1376
1787
 
1377
- const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
1788
+ const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
1378
1789
 
1379
- const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
1380
- const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
1790
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
1791
+ const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
1792
+ const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
1381
1793
 
1382
1794
  // kbc == k block continuous, current index in continuous ijk space.
1383
- int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1384
- const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1795
+ int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
1796
+ const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
1385
1797
 
1386
1798
  // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
1387
1799
  // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1392,22 +1804,24 @@ static __global__ void flash_attn_ext_f16(
1392
1804
  int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
1393
1805
 
1394
1806
  while (kbc < kbc_stop && kb0_stop == iter_k) {
1395
- const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1396
- const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
1397
- const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1807
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
1808
+ const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
1809
+ const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
1810
+ const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
1811
+ const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
1398
1812
 
1399
- const int head0 = zt * ncols2;
1813
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
1400
1814
 
1401
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1402
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1815
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
1816
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
1403
1817
  const half * mask_h = ncols2 == 1 && !mask ? nullptr :
1404
1818
  (const half *) (mask + nb33*(sequence % ne33));
1405
- float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
1819
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
1406
1820
 
1407
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1408
- const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
1821
+ const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
1822
+ const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
1409
1823
 
1410
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
1824
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
1411
1825
 
1412
1826
  if (KV_max) {
1413
1827
  kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1415,14 +1829,14 @@ static __global__ void flash_attn_ext_f16(
1415
1829
  constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
1416
1830
  if (kb0_start == 0) {
1417
1831
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
1418
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
1832
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1419
1833
  (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1420
- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1834
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1421
1835
  } else {
1422
1836
  constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
1423
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
1837
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1424
1838
  (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1425
- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1839
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1426
1840
  }
1427
1841
 
1428
1842
  kbc += iter_k;
@@ -1436,22 +1850,24 @@ static __global__ void flash_attn_ext_f16(
1436
1850
  return;
1437
1851
  }
1438
1852
 
1439
- const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1440
- const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
1441
- const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1853
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
1854
+ const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
1855
+ const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
1856
+ const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
1857
+ const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
1442
1858
 
1443
- const int head0 = zt * ncols2;
1859
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
1444
1860
 
1445
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1446
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1861
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
1862
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
1447
1863
  const half * mask_h = ncols2 == 1 && !mask ? nullptr :
1448
1864
  (const half *) (mask + nb33*(sequence % ne33));
1449
- float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
1865
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
1450
1866
 
1451
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1452
- const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
1867
+ const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
1868
+ const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
1453
1869
 
1454
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
1870
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
1455
1871
 
1456
1872
  if (KV_max) {
1457
1873
  kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1459,11 +1875,11 @@ static __global__ void flash_attn_ext_f16(
1459
1875
 
1460
1876
  constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
1461
1877
  constexpr bool needs_fixup = false;
1462
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
1878
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1463
1879
  (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1464
- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1880
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1465
1881
  #else
1466
- GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
1882
+ GGML_UNUSED_VARS(Q_ptr, K_ptr, V_ptr, mask_ptr, sinks_ptr, KV_max_ptr, dst_ptr, dst_meta_ptr, scale,
1467
1883
  max_bias, m0, m1, n_head_log2, logit_softcap,
1468
1884
  ne00, ne01, ne02, ne03,
1469
1885
  nb01, nb02, nb03,
@@ -1473,7 +1889,7 @@ static __global__ void flash_attn_ext_f16(
1473
1889
  ne31, ne32, ne33,
1474
1890
  nb31, nb32, nb33);
1475
1891
  NO_DEVICE_CODE;
1476
- #endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1892
+ #endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE))
1477
1893
  }
1478
1894
 
1479
1895
  template <int DKQ, int DV, int ncols1, int ncols2>
@@ -1492,10 +1908,11 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1492
1908
  const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
1493
1909
  const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
1494
1910
 
1495
- const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
1496
- const int nwarps = nthreads / WARP_SIZE;
1911
+ const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
1912
+ const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
1913
+ const int nwarps = nthreads / warp_size_host;
1497
1914
 
1498
- constexpr bool mla = DKQ == 576;
1915
+ constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
1499
1916
 
1500
1917
  const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
1501
1918
  const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
@@ -1512,33 +1929,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1512
1929
  float logit_softcap;
1513
1930
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
1514
1931
 
1932
+ #if defined(GGML_USE_HIP)
1933
+ using fattn_kernel_ptr_t = const void*;
1934
+ #else
1935
+ using fattn_kernel_ptr_t = fattn_kernel_t;
1936
+ #endif // defined(GGML_USE_HIP)
1515
1937
  fattn_kernel_t fattn_kernel;
1516
1938
  if (logit_softcap == 0.0f) {
1517
1939
  constexpr bool use_logit_softcap = false;
1518
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
1940
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
1519
1941
 
1520
- #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1942
+ #if !defined(GGML_USE_MUSA)
1521
1943
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1522
1944
  if (!shared_memory_limit_raised[id]) {
1523
- CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1945
+ CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1524
1946
  shared_memory_limit_raised[id] = true;
1525
1947
  }
1526
- #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1948
+ #endif // !defined(GGML_USE_MUSA)
1527
1949
  } else {
1528
1950
  constexpr bool use_logit_softcap = true;
1529
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
1951
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
1530
1952
 
1531
- #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1953
+ #if !defined(GGML_USE_MUSA)
1532
1954
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1533
1955
  if (!shared_memory_limit_raised[id]) {
1534
- CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1956
+ CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1535
1957
  shared_memory_limit_raised[id] = true;
1536
1958
  }
1537
- #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1959
+ #endif // !defined(GGML_USE_MUSA)
1538
1960
  }
1539
1961
 
1540
1962
  launch_fattn<DV, ncols1, ncols2>
1541
- (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
1963
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
1542
1964
  }
1543
1965
 
1544
1966
 
@@ -1581,7 +2003,27 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 112, 64)
1581
2003
  DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128, 64)
1582
2004
  DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
1583
2005
 
2006
+ extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 4);
2007
+ extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 4);
2008
+ extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 4);
2009
+ extern DECL_FATTN_MMA_F16_CASE(512, 512, 16, 4);
2010
+ extern DECL_FATTN_MMA_F16_CASE(512, 512, 1, 8);
2011
+ extern DECL_FATTN_MMA_F16_CASE(512, 512, 2, 8);
2012
+ extern DECL_FATTN_MMA_F16_CASE(512, 512, 4, 8);
2013
+ extern DECL_FATTN_MMA_F16_CASE(512, 512, 8, 8);
2014
+
1584
2015
  // The number of viable configurations for Deepseek is very limited:
1585
2016
  extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
1586
2017
  extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
1587
2018
  extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
2019
+
2020
+ // Mistral Small 4 (DKQ=320, DV=256), GQA=32-only build:
2021
+ extern DECL_FATTN_MMA_F16_CASE(320, 256, 1, 32);
2022
+ extern DECL_FATTN_MMA_F16_CASE(320, 256, 2, 32);
2023
+
2024
+ // For GLM 4.7 Flash
2025
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
2026
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
2027
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
2028
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
2029
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);