whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -98,6 +98,57 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co
98
98
  return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
99
99
  }
100
100
 
101
+ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_rdna(const int DKQ, const int DV, const int ncols) {
102
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
103
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
104
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
105
+
106
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
107
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
108
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
109
+
110
+ // TODO tune specifically for RDNA
111
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
112
+ }
113
+
114
+ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) {
115
+ // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async).
116
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true);
117
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true);
118
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true);
119
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true);
120
+
121
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true);
122
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true);
123
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true);
124
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true);
125
+
126
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true);
127
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true);
128
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true);
129
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true);
130
+
131
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true);
132
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true);
133
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true);
134
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true);
135
+
136
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true);
137
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true);
138
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true);
139
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true);
140
+
141
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true);
142
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true);
143
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true);
144
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true);
145
+
146
+ // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy
147
+ // compile-time static_asserts even though the kernel guard prevents runtime execution.
148
+ // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility.
149
+ return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false);
150
+ }
151
+
101
152
  static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
102
153
  if (ampere_mma_available(cc)) {
103
154
  return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
@@ -105,6 +156,12 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c
105
156
  if (turing_mma_available(cc)) {
106
157
  return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
107
158
  }
159
+ if (amd_mfma_available(cc)) {
160
+ return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
161
+ }
162
+ if (amd_wmma_available(cc)) {
163
+ return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
164
+ }
108
165
  GGML_ASSERT(volta_mma_available(cc));
109
166
  return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
110
167
  }
@@ -114,8 +171,12 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons
114
171
  return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
115
172
  #elif defined(TURING_MMA_AVAILABLE)
116
173
  return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
174
+ #elif defined(AMD_MFMA_AVAILABLE)
175
+ return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols);
117
176
  #elif defined(VOLTA_MMA_AVAILABLE)
118
177
  return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
178
+ #elif defined(AMD_WMMA_AVAILABLE)
179
+ return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols);
119
180
  #else
120
181
  GGML_UNUSED_VARS(DKQ, DV, ncols);
121
182
  return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
@@ -186,6 +247,23 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ,
186
247
  return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
187
248
  }
188
249
 
250
+ static constexpr __device__ int get_cols_per_thread() {
251
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
252
+ return 1; // AMD has a single column per thread.
253
+ #else
254
+ return 2; // This is specifically KQ columns, Volta only has a single VKQ column.
255
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
256
+ }
257
+
258
+ static __host__ int get_cols_per_warp(const int cc) {
259
+ if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) {
260
+ return 16;
261
+ } else {
262
+ // Volta
263
+ return 32;
264
+ }
265
+ }
266
+
189
267
  // ------------------------------------------------------------------------------------------------------------------
190
268
 
191
269
  static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
@@ -206,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c
206
284
  template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
207
285
  static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
208
286
  const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
287
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
209
288
  // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
210
289
  // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
211
290
  if constexpr (use_cp_async) {
@@ -217,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
217
296
  const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV);
218
297
 
219
298
  auto load = [&] __device__ (auto n) {
220
- const int stride_k = WARP_SIZE >> n;
221
- const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
299
+ const int stride_k = warp_size >> n;
300
+ const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k);
222
301
  const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k);
223
- const int stride_i = WARP_SIZE / stride_k;
302
+ const int stride_i = warp_size / stride_k;
224
303
 
225
304
  if (k0_start == k0_stop) {
226
305
  return;
@@ -228,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
228
307
 
229
308
  #pragma unroll
230
309
  for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
231
- const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
310
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
232
311
 
233
312
  if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
234
313
  break;
@@ -236,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
236
315
 
237
316
  #pragma unroll
238
317
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
239
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
318
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
240
319
 
241
320
  cp_async_cg_16<preload>(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk);
242
321
  }
@@ -252,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
252
331
  } else {
253
332
  // TODO use ggml_cuda_memcpy_1
254
333
  auto load = [&] __device__ (const int n) {
255
- const int stride_k = WARP_SIZE >> n;
256
- const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
334
+ const int stride_k = warp_size >> n;
335
+ const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k);
257
336
  const int k0_stop = D2 - D2 % (1*stride_k);
258
- const int stride_i = WARP_SIZE / stride_k;
337
+ const int stride_i = warp_size / stride_k;
259
338
 
260
339
  if (k0_start == k0_stop) {
261
340
  return;
@@ -263,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
263
342
 
264
343
  #pragma unroll
265
344
  for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) {
266
- const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
345
+ const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
267
346
 
268
347
  if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) {
269
348
  break;
@@ -271,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
271
350
 
272
351
  #pragma unroll
273
352
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
274
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
353
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
275
354
 
276
355
  tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
277
356
  }
@@ -289,18 +368,19 @@ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_chec
289
368
  static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
290
369
  const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
291
370
  const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
371
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
292
372
  if constexpr (use_cp_async) {
293
- static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
373
+ static_assert(nbatch_fa <= 8*warp_size && nbatch_fa % 8 == 0, "bad nbatch_fa");
294
374
  static_assert(!oob_check, "OOB check incompatible with cp_async");
295
375
  constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
296
- constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
376
+ constexpr int cols_per_warp = 8*warp_size/nbatch_fa;
297
377
  constexpr int stride_j = nwarps * cols_per_warp;
298
378
 
299
379
  const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
300
380
 
301
381
  #pragma unroll
302
382
  for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
303
- const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
383
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
304
384
  const int j_vram = fastmodulo(j0 + j_sram, ne01);
305
385
 
306
386
  if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
@@ -322,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
322
402
  }
323
403
 
324
404
  #pragma unroll
325
- for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
405
+ for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) {
326
406
  const int i = i0 + threadIdx.x;
327
407
 
328
408
  tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
329
409
  }
330
410
  }
331
- } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
332
- constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
411
+ } else if constexpr (nbatch_fa < 2*warp_size) {
412
+ constexpr int cols_per_warp = 2*warp_size/nbatch_fa;
333
413
  constexpr int stride_j = nwarps * cols_per_warp;
334
414
  #pragma unroll
335
415
  for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
336
- const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
416
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp);
337
417
  const int j_vram = fastmodulo(j0 + j_sram, ne01);
338
418
 
339
419
  if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
340
420
  break;
341
421
  }
342
422
 
343
- const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
423
+ const int i = threadIdx.x % (warp_size/cols_per_warp);
344
424
 
345
425
  ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
346
426
  }
@@ -355,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
355
435
  }
356
436
 
357
437
  #pragma unroll
358
- for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
438
+ for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) {
359
439
  const int i = i0 + 2*threadIdx.x;
360
440
 
361
441
  ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
@@ -365,7 +445,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
365
445
  }
366
446
 
367
447
  template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
368
- bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
448
+ bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
369
449
  typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
370
450
  static __device__ __forceinline__ void flash_attn_ext_f16_iter(
371
451
  const float2 * const __restrict__ Q_f2,
@@ -393,11 +473,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
393
473
  const int jt,
394
474
  const int kb0,
395
475
  const int k_VKQ_sup) {
396
- #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
476
+ #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
477
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
397
478
  constexpr int ncols = ncols1 * ncols2;
398
479
  constexpr int cols_per_warp = T_B_KQ::I;
399
- constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
400
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
480
+ constexpr int cols_per_thread = get_cols_per_thread();
481
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
401
482
  constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
402
483
  constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
403
484
  constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
@@ -407,19 +488,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
407
488
  constexpr int stride_tile_Q = DKQ/2 + 4;
408
489
  constexpr int stride_tile_K = nbatch_K2 + 4;
409
490
 
410
- static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
411
- constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
491
+ constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
412
492
 
413
493
  const int k_VKQ_0 = kb0 * nbatch_fa;
414
494
  #if defined(TURING_MMA_AVAILABLE)
415
495
  T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
496
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
497
+ T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
416
498
  #else // Volta
417
499
  T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
418
500
  #endif // defined(TURING_MMA_AVAILABLE)
419
501
 
420
502
  if constexpr (nstages > 1) {
421
503
  static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
422
- static_assert(!mla, "multi-stage loading not implemented for MLA");
504
+ static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
423
505
  static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
424
506
  constexpr bool use_cp_async = true;
425
507
  cp_async_wait_all();
@@ -434,8 +516,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
434
516
  }
435
517
  }
436
518
 
519
+ // For MLA K and V have the same data.
520
+ // Therefore, iterate over K in reverse and later re-use the data if possible.
437
521
  #pragma unroll
438
- for (int k0_start = 0; k0_start < DKQ/2; k0_start += nbatch_K2) {
522
+ for (int k0_start = (DKQ/2-1) - (DKQ/2-1) % nbatch_K2; k0_start >= 0; k0_start -= nbatch_K2) {
439
523
  const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
440
524
  const int k0_diff = k0_stop - k0_start;
441
525
 
@@ -461,13 +545,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
461
545
  if constexpr (cols_per_warp == 8) {
462
546
  mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
463
547
  } else {
464
- // Wide version of KQ_C is column-major => swap A and B.
548
+ // Wide version of KQ_C is column-major
549
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
550
+ // AMD matrix C is column-major.
551
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
552
+ #else
553
+ // swap A and B for CUDA.
465
554
  mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
555
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
466
556
  }
467
557
  }
468
558
  }
469
559
  } else {
470
- static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
471
560
  #pragma unroll
472
561
  for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
473
562
  load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
@@ -479,8 +568,18 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
479
568
  T_A_KQ K_A;
480
569
  load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
481
570
 
482
- // Wide version of KQ_C is column-major => swap A and B.
483
- mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
571
+ if constexpr (cols_per_warp == 8) {
572
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
573
+ } else {
574
+ // Wide version of KQ_C is column-major
575
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
576
+ // AMD matrix C is column-major.
577
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]);
578
+ #else
579
+ // swap A and B for CUDA.
580
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
581
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
582
+ }
484
583
  }
485
584
  }
486
585
  }
@@ -532,7 +631,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
532
631
  #pragma unroll
533
632
  for (int l = 0; l < T_C_KQ::ne; ++l) {
534
633
  if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
535
- KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
634
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
635
+ constexpr int KQ_idx = 0;
636
+ #else
637
+ // Turing + Volta:
638
+ const int KQ_idx = l % 2;
639
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
640
+ KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
536
641
  }
537
642
  }
538
643
  }
@@ -542,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
542
647
  for (int col = 0; col < cols_per_thread; ++col) {
543
648
  #pragma unroll
544
649
  for (int offset = 16; offset >= 4; offset >>= 1) {
545
- KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
650
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
546
651
  }
547
652
  }
548
653
 
@@ -552,8 +657,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
552
657
  #pragma unroll
553
658
  for (int l = 0; l < T_C_KQ::ne; ++l) {
554
659
  if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
555
- KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
556
- KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
660
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
661
+ constexpr int KQ_idx = 0;
662
+ #else
663
+ // Turing + Volta:
664
+ const int KQ_idx = l % 2;
665
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
666
+ KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]);
667
+ KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
557
668
  } else {
558
669
  KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
559
670
  }
@@ -584,8 +695,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
584
695
  #pragma unroll
585
696
  for (int l = 0; l < T_C_KQ::ne; ++l) {
586
697
  if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
698
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
699
+ constexpr int KQ_idx = 0;
700
+ #else
587
701
  // Turing + Volta:
588
- KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
702
+ const int KQ_idx = (l/2) % 2;
703
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
704
+ KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
589
705
  }
590
706
  }
591
707
  }
@@ -596,14 +712,22 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
596
712
  // Values per KQ column are spread across 4 threads:
597
713
  constexpr int offset_first = 2;
598
714
  constexpr int offset_last = 1;
599
- #else
715
+ #elif defined(AMD_MFMA_AVAILABLE)
716
+ // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16).
717
+ constexpr int offset_first = 32;
718
+ constexpr int offset_last = 16;
719
+ #elif defined(AMD_WMMA_AVAILABLE)
720
+ // Values per KQ column are spread across 2 threads:
721
+ constexpr int offset_first = 16;
722
+ constexpr int offset_last = 16;
723
+ #else // Volta
600
724
  // Values per KQ column are spread across 2 threads:
601
725
  constexpr int offset_first = 2;
602
726
  constexpr int offset_last = 2;
603
727
  #endif // defined(TURING_MMA_AVAILABLE)
604
728
  #pragma unroll
605
729
  for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
606
- KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
730
+ KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size));
607
731
  }
608
732
  }
609
733
 
@@ -612,10 +736,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
612
736
  for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
613
737
  #pragma unroll
614
738
  for (int l = 0; l < T_C_KQ::ne; ++l) {
615
- // Turing + Volta:
616
739
  if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
617
- KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
618
- KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
740
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
741
+ constexpr int KQ_idx = 0;
742
+ #else
743
+ // Turing + Volta:
744
+ const int KQ_idx = (l/2) % 2;
745
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
746
+ KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]);
747
+ KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
619
748
  } else {
620
749
  KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
621
750
  }
@@ -639,7 +768,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
639
768
 
640
769
  #if defined(TURING_MMA_AVAILABLE)
641
770
  if constexpr (cols_per_warp == 8) {
642
- const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
771
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
643
772
  #pragma unroll
644
773
  for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
645
774
  #pragma unroll
@@ -660,6 +789,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
660
789
  }
661
790
  }
662
791
  }
792
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
793
+ const half2 KQ_max_scale_h2 = make_half2(
794
+ KQ_max_scale[0], KQ_max_scale[0]);
795
+ #pragma unroll
796
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
797
+ #pragma unroll
798
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
799
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
800
+ }
801
+ }
663
802
  #else // Volta
664
803
  const half2 KQ_max_scale_h2 = make_half2(
665
804
  KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
@@ -688,6 +827,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
688
827
  }
689
828
 
690
829
  if constexpr (nstages > 1) {
830
+ static_assert(!V_is_K_view, "K data reuse not implemented multi-stage loading");
691
831
  // Preload K tile for next iteration:
692
832
  constexpr bool use_cp_async = true;
693
833
  cp_async_wait_all();
@@ -703,19 +843,20 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
703
843
  }
704
844
 
705
845
 
706
- // For MLA K and V have the same data.
707
- // Therefore, iterate over V in reverse and re-use the data if possible.
708
- static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
709
- constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
846
+ #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
847
+ T_A_VKQ A_identity;
848
+ make_identity_mat(A_identity);
849
+ #endif // defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
710
850
 
711
851
  // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
712
852
  #pragma unroll
713
- for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
714
- const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
715
- const int i0_diff = i0_stop - i0_start;
853
+ for (int i0_start = 0; i0_start < DV; i0_start += 2*nbatch_V2) {
854
+ static_assert(DV % (2*nbatch_V2) == 0, "bad loop size");
855
+ const int i0_stop = i0_start + 2*nbatch_V2;
856
+ const int i0_diff = i0_stop - i0_start;
716
857
 
717
858
  if constexpr (nstages <= 1) {
718
- if (i0_start < reusable_cutoff) {
859
+ if (!V_is_K_view || i0_stop > 2*nbatch_K2) {
719
860
  constexpr bool use_cp_async = nstages == 1;
720
861
  flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
721
862
  (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
@@ -725,9 +866,9 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
725
866
  __syncthreads();
726
867
  }
727
868
  }
728
- const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
869
+ const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2;
729
870
 
730
- #if defined(TURING_MMA_AVAILABLE)
871
+ #if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
731
872
  constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
732
873
  #pragma unroll
733
874
  for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
@@ -737,12 +878,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
737
878
  const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
738
879
 
739
880
  T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
881
+ #if defined(LDMATRIX_TRANS_AVAILABLE)
740
882
  load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
883
+ #elif defined(AMD_MFMA_AVAILABLE)
884
+ // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg].
885
+ // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T.
886
+ // Load with transposed addressing: 4 strided half loads.
887
+ {
888
+ const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2;
889
+ const half * xs0_h = (const half *) xs0;
890
+ const int stride_h = stride_tile_V * 2; // stride in half units
891
+ half * A_h = (half *) A.x;
892
+ #pragma unroll
893
+ for (int l = 0; l < 4; ++l) {
894
+ A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16];
895
+ }
896
+ }
897
+ #else
898
+ // TODO: Try to transpose tile_V when loading gmem to smem.
899
+ // Use mma to transpose T_A_VKQ for RDNA.
900
+ T_A_VKQ A_trans;
901
+ load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
902
+ mma(A, A_trans, A_identity);
903
+ #endif // defined(LDMATRIX_TRANS_AVAILABLE)
741
904
  if constexpr (T_B_KQ::I == 8) {
742
905
  mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
743
906
  } else {
744
- // Wide version of VKQ_C is column-major => swap A and B.
907
+ // Wide version of VKQ_C is column-major.
908
+ #if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
909
+ // AMD matrix C is column-major.
910
+ mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
911
+ #else
912
+ // swap A and B for CUDA.
745
913
  mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
914
+ #endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
746
915
  }
747
916
  }
748
917
  }
@@ -761,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
761
930
  mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
762
931
  }
763
932
  }
764
- #endif // defined(TURING_MMA_AVAILABLE)
933
+ #endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
765
934
 
766
935
  if constexpr (nstages <= 1) {
767
936
  __syncthreads(); // Only needed if tile_K == tile_V.
@@ -774,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
774
943
  tile_Q, tile_K, tile_V, tile_mask,
775
944
  Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
776
945
  NO_DEVICE_CODE;
777
- #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
946
+ #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
778
947
  }
779
948
 
780
949
  #if defined(TURING_MMA_AVAILABLE)
@@ -794,6 +963,15 @@ template<> struct mma_tile_sizes<8> {
794
963
  using T_B_VKQ = tile< 8, 8, half2>; // column-major
795
964
  using T_C_VKQ = tile<16, 4, half2>; // row-major
796
965
  };
966
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
967
+ template<int ncols> struct mma_tile_sizes {
968
+ using T_A_KQ = tile<16, 8, half2>; // row-major
969
+ using T_B_KQ = tile<16, 8, half2>; // column-major
970
+ using T_C_KQ = tile<16, 16, float>; // column-major
971
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
972
+ using T_B_VKQ = tile<16, 8, half2>; // column-major
973
+ using T_C_VKQ = tile<16, 8, half2>; // column-major
974
+ };
797
975
  #else // Volta
798
976
  template<int ncols> struct mma_tile_sizes {
799
977
  using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
@@ -805,7 +983,7 @@ template<int ncols> struct mma_tile_sizes {
805
983
  };
806
984
  #endif // defined(TURING_MMA_AVAILABLE)
807
985
 
808
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
986
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool V_is_K_view, bool needs_fixup, bool is_fixup>
809
987
  static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
810
988
  const float2 * const __restrict__ Q_f2,
811
989
  const half2 * const __restrict__ K_h2,
@@ -819,6 +997,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
819
997
  const float logit_softcap,
820
998
  const uint3 ne01,
821
999
  const int ne02,
1000
+ const int gqa_ratio,
822
1001
  const int ne11,
823
1002
  const int stride_Q1,
824
1003
  const int stride_Q2,
@@ -826,11 +1005,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
826
1005
  const int stride_V,
827
1006
  const int stride_mask,
828
1007
  const int jt,
1008
+ const int zt_gqa,
829
1009
  const int kb0_start,
830
1010
  const int kb0_stop) {
831
- #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1011
+ #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
832
1012
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
833
1013
 
1014
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
834
1015
  constexpr int ncols = ncols1 * ncols2;
835
1016
  using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
836
1017
  using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
@@ -840,8 +1021,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
840
1021
  using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
841
1022
 
842
1023
  constexpr int cols_per_warp = T_B_KQ::I;
843
- constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
844
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
1024
+ constexpr int cols_per_thread = get_cols_per_thread();
1025
+ constexpr int np = cols_per_warp > ncols ? nwarps : nwarps * cols_per_warp/ncols; // Number of parallel CUDA warps per Q column.
845
1026
  constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
846
1027
  constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
847
1028
  constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
@@ -859,8 +1040,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
859
1040
  constexpr int stride_tile_Q = DKQ/2 + 4;
860
1041
  constexpr int stride_tile_K = nbatch_K2 + 4;
861
1042
 
862
- static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
863
- constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
1043
+ constexpr int stride_tile_V = V_is_K_view ? stride_tile_K : nbatch_V2 + 4;
864
1044
  constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
865
1045
 
866
1046
  extern __shared__ half2 tile_Q[];
@@ -871,6 +1051,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
871
1051
  T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
872
1052
  #if defined(TURING_MMA_AVAILABLE)
873
1053
  T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
1054
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
1055
+ T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
874
1056
  #else // Volta
875
1057
  T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
876
1058
  #endif // defined(TURING_MMA_AVAILABLE)
@@ -887,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
887
1069
  // The loading is done with decreasing granularity for D for better memory bandwidth.
888
1070
  const half2 scale_h2 = make_half2(scale, scale);
889
1071
  #pragma unroll
890
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
891
- const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
1072
+ for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
1073
+ const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k);
892
1074
  const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k);
893
- const int stride_jc = WARP_SIZE / stride_k;
1075
+ const int stride_jc = warp_size / stride_k;
894
1076
 
895
1077
  if (k0_start == k0_stop) {
896
1078
  continue;
@@ -898,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
898
1080
 
899
1081
  #pragma unroll
900
1082
  for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) {
901
- const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
1083
+ const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
902
1084
 
903
1085
  if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) {
904
1086
  break;
@@ -907,10 +1089,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
907
1089
  const int j = jc / ncols2;
908
1090
  const int c = jc % ncols2;
909
1091
 
910
- if (jt*ncols1 + j < int(ne01.z)) {
1092
+ if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) {
911
1093
  #pragma unroll
912
1094
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
913
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1095
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
914
1096
 
915
1097
  const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k];
916
1098
  tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y);
@@ -918,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
918
1100
  } else {
919
1101
  #pragma unroll
920
1102
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
921
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1103
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
922
1104
 
923
1105
  tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f);
924
1106
  }
@@ -962,7 +1144,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
962
1144
  constexpr bool last_iter = false;
963
1145
  constexpr int k_VKQ_sup = nbatch_fa;
964
1146
  flash_attn_ext_f16_iter
965
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
1147
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
966
1148
  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
967
1149
  (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
968
1150
  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -971,7 +1153,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
971
1153
  constexpr bool last_iter = true;
972
1154
  const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
973
1155
  flash_attn_ext_f16_iter
974
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
1156
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
975
1157
  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
976
1158
  (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
977
1159
  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -982,7 +1164,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
982
1164
  constexpr bool last_iter = false;
983
1165
  constexpr int k_VKQ_sup = nbatch_fa;
984
1166
  flash_attn_ext_f16_iter
985
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
1167
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
986
1168
  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
987
1169
  (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
988
1170
  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -991,7 +1173,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
991
1173
  constexpr bool last_iter = true;
992
1174
  constexpr int k_VKQ_sup = nbatch_fa;
993
1175
  flash_attn_ext_f16_iter
994
- <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
1176
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup, last_iter, oob_check,
995
1177
  T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
996
1178
  (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
997
1179
  ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
@@ -1010,6 +1192,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1010
1192
  // The partial sums are spread across 8/4 threads.
1011
1193
  constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
1012
1194
  constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
1195
+ #elif defined(AMD_MFMA_AVAILABLE)
1196
+ // The partial sums are spread across 4 threads (wavefront64, 16 cols).
1197
+ constexpr int offset_first = 32;
1198
+ constexpr int offset_last = 16;
1199
+ #elif defined(AMD_WMMA_AVAILABLE)
1200
+ // The partial sums are spread across 2 threads.
1201
+ constexpr int offset_first = 16;
1202
+ constexpr int offset_last = 16;
1013
1203
  #else // Volta
1014
1204
  // The partial sums are spread across 2 threads.
1015
1205
  constexpr int offset_first = 2;
@@ -1019,13 +1209,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1019
1209
  for (int col = 0; col < cols_per_thread; ++col) {
1020
1210
  #pragma unroll
1021
1211
  for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
1022
- KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE);
1212
+ KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size);
1023
1213
  }
1024
1214
  }
1025
1215
  }
1026
1216
 
1027
1217
  // If attention sinks are used, potentially re-scale if KQ_max is small.
1028
- // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
1218
+ // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum
1029
1219
  // so it's being done unconditionally for every thread.
1030
1220
  if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
1031
1221
  float KQ_max_scale[cols_per_thread];
@@ -1047,7 +1237,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1047
1237
 
1048
1238
  #if defined(TURING_MMA_AVAILABLE)
1049
1239
  if constexpr (cols_per_warp == 8) {
1050
- const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
1240
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[cols_per_thread - 1]);
1051
1241
  #pragma unroll
1052
1242
  for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
1053
1243
  #pragma unroll
@@ -1068,6 +1258,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1068
1258
  }
1069
1259
  }
1070
1260
  }
1261
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
1262
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]);
1263
+ #pragma unroll
1264
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
1265
+ #pragma unroll
1266
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
1267
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
1268
+ }
1269
+ }
1071
1270
  #else // Volta
1072
1271
  const int col = (threadIdx.x / 2) % 2;
1073
1272
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
@@ -1119,6 +1318,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1119
1318
  const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
1120
1319
  const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
1121
1320
  const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
1321
+ #elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
1322
+ const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0);
1323
+ const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]);
1324
+ const bool thread_should_write = threadIdx.x / 16 < cols_per_thread;
1122
1325
  #else // Volta
1123
1326
  const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
1124
1327
  const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
@@ -1149,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1149
1352
  // Warps with threadIdx.y % np != 0 must NOT return early.
1150
1353
  // All threads must return simultaneously to avoid race conditions with work on the next tile.
1151
1354
 
1152
- constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1;
1355
+ constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1;
1153
1356
 
1154
- const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
1357
+ const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x);
1155
1358
  float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2;
1156
1359
  float2 meta[nmeta];
1157
1360
  #pragma unroll
1158
1361
  for (int imeta = 0; imeta < nmeta; ++imeta) {
1159
- meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2];
1362
+ meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2];
1160
1363
  }
1161
1364
 
1162
1365
  float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps.
@@ -1166,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1166
1369
  }
1167
1370
  #pragma unroll
1168
1371
  for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
1169
- if (offset < WARP_SIZE) {
1170
- KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
1372
+ if (offset < warp_size) {
1373
+ KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size));
1171
1374
  }
1172
1375
  }
1173
1376
 
@@ -1184,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1184
1387
  }
1185
1388
  #pragma unroll
1186
1389
  for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) {
1187
- if (offset < WARP_SIZE) {
1188
- KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
1390
+ if (offset < warp_size) {
1391
+ KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size);
1189
1392
  }
1190
1393
  }
1191
1394
 
@@ -1194,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1194
1397
  // Write back combined meta data:
1195
1398
  #pragma unroll
1196
1399
  for (int imeta = 0; imeta < nmeta; ++imeta) {
1197
- if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) {
1400
+ if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) {
1198
1401
  // Combined KQ max scale + rowsum.
1199
- meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
1402
+ meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs);
1200
1403
  }
1201
1404
  }
1202
1405
 
1203
1406
  // Combined KQ max + rowsum.
1204
- static_assert(cols_per_warp <= WARP_SIZE);
1205
- if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
1407
+ static_assert(cols_per_warp <= warp_size);
1408
+ if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
1206
1409
  float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1207
1410
  dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
1208
1411
  }
1209
- if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) {
1412
+ if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) {
1210
1413
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1211
1414
  dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
1212
1415
  }
@@ -1254,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1254
1457
  float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2));
1255
1458
 
1256
1459
  #pragma unroll
1257
- for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
1258
- const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
1460
+ for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) {
1461
+ const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k);
1259
1462
  const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k);
1260
- const int stride_jc = WARP_SIZE / stride_k;
1463
+ const int stride_jc = warp_size / stride_k;
1261
1464
 
1262
1465
  if (k0_start == k0_stop) {
1263
1466
  continue;
@@ -1265,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1265
1468
 
1266
1469
  #pragma unroll
1267
1470
  for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) {
1268
- const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
1471
+ const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k);
1269
1472
 
1270
1473
  if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) {
1271
1474
  break;
@@ -1276,14 +1479,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1276
1479
  const int j_dst = jc_dst / ncols2;
1277
1480
  const int c_dst = jc_dst % ncols2;
1278
1481
 
1279
- if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
1482
+ if (!is_fixup && ((ncols1 > 1 && jt*ncols1 + j_dst >= int(ne01.z)) || (ncols2 > 1 && zt_gqa*ncols2 + c_dst >= gqa_ratio))) {
1280
1483
  continue;
1281
1484
  }
1282
1485
 
1283
1486
  const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine;
1284
1487
  #pragma unroll
1285
1488
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
1286
- const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
1489
+ const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k);
1287
1490
 
1288
1491
  float2 dstk_val = make_float2(0.0f, 0.0f);
1289
1492
  #pragma unroll
@@ -1315,14 +1518,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1315
1518
  }
1316
1519
  #else
1317
1520
  GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
1318
- scale, slope, logit_softcap, ne01, ne02,
1521
+ scale, slope, logit_softcap, ne01, ne02, gqa_ratio,
1319
1522
  stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
1320
1523
  jt, kb0_start, kb0_stop);
1321
1524
  NO_DEVICE_CODE;
1322
- #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1525
+ #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)
1323
1526
  }
1324
1527
 
1325
- template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
1528
+ template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool V_is_K_view>
1326
1529
  __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
1327
1530
  static __global__ void flash_attn_ext_f16(
1328
1531
  const char * __restrict__ Q,
@@ -1346,13 +1549,20 @@ static __global__ void flash_attn_ext_f16(
1346
1549
  const int32_t nb21, const int32_t nb22, const int64_t nb23,
1347
1550
  const int32_t ne31, const int32_t ne32, const int32_t ne33,
1348
1551
  const int32_t nb31, const int32_t nb32, const int64_t nb33) {
1349
- #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1552
+ #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
1350
1553
 
1351
1554
  // Skip unused kernel variants for faster compilation:
1352
1555
  if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
1353
1556
  NO_DEVICE_CODE;
1354
1557
  return;
1355
1558
  }
1559
+ #ifdef VOLTA_MMA_AVAILABLE
1560
+ if (ncols1*ncols2 < 32) {
1561
+ NO_DEVICE_CODE;
1562
+ return;
1563
+ }
1564
+ #endif // VOLTA_MMA_AVAILABLE
1565
+
1356
1566
  #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1357
1567
  if (ncols1*ncols2 > 32) {
1358
1568
  NO_DEVICE_CODE;
@@ -1360,12 +1570,25 @@ static __global__ void flash_attn_ext_f16(
1360
1570
  }
1361
1571
  #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
1362
1572
 
1363
- static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
1573
+ #if defined(AMD_WMMA_AVAILABLE)
1574
+ if (ncols1*ncols2 > 32 || ncols1*ncols2 < 16 || DKQ > 128 || ncols2 == 1) {
1575
+ NO_DEVICE_CODE;
1576
+ return;
1577
+ }
1578
+ #endif // defined(AMD_WMMA_AVAILABLE)
1579
+
1580
+ #if defined(AMD_MFMA_AVAILABLE)
1581
+ if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) {
1582
+ NO_DEVICE_CODE;
1583
+ return;
1584
+ }
1585
+ #endif // defined(AMD_MFMA_AVAILABLE)
1364
1586
 
1587
+ constexpr int warp_size = ggml_cuda_get_physical_warp_size();
1365
1588
  constexpr int ncols = ncols1 * ncols2;
1366
1589
  constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
1367
1590
  constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
1368
- constexpr int nwarps = nthreads / WARP_SIZE;
1591
+ constexpr int nwarps = nthreads / warp_size;
1369
1592
 
1370
1593
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
1371
1594
 
@@ -1374,14 +1597,15 @@ static __global__ void flash_attn_ext_f16(
1374
1597
  const int stride_K = nb11 / sizeof(half2);
1375
1598
  const int stride_mask = nb31 / sizeof(half);
1376
1599
 
1377
- const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
1600
+ const int stride_V = V_is_K_view ? stride_K : nb21 / sizeof(half2);
1378
1601
 
1379
- const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
1380
- const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
1602
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
1603
+ const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
1604
+ const int iter_z_gqa = (gqa_ratio + (ncols2 - 1)) / ncols2;
1381
1605
 
1382
1606
  // kbc == k block continuous, current index in continuous ijk space.
1383
- int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1384
- const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1607
+ int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
1608
+ const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*iter_z_gqa*ne12*ne03) / gridDim.x;
1385
1609
 
1386
1610
  // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
1387
1611
  // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1392,22 +1616,24 @@ static __global__ void flash_attn_ext_f16(
1392
1616
  int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
1393
1617
 
1394
1618
  while (kbc < kbc_stop && kb0_stop == iter_k) {
1395
- const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1396
- const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
1397
- const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1619
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index
1620
+ const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
1621
+ const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
1622
+ const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
1623
+ const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
1398
1624
 
1399
- const int head0 = zt * ncols2;
1625
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
1400
1626
 
1401
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1402
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1627
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
1628
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
1403
1629
  const half * mask_h = ncols2 == 1 && !mask ? nullptr :
1404
1630
  (const half *) (mask + nb33*(sequence % ne33));
1405
- float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
1631
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
1406
1632
 
1407
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1408
- const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
1633
+ const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
1634
+ const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
1409
1635
 
1410
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
1636
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
1411
1637
 
1412
1638
  if (KV_max) {
1413
1639
  kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1415,14 +1641,14 @@ static __global__ void flash_attn_ext_f16(
1415
1641
  constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
1416
1642
  if (kb0_start == 0) {
1417
1643
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
1418
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
1644
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1419
1645
  (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1420
- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1646
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1421
1647
  } else {
1422
1648
  constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
1423
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
1649
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1424
1650
  (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1425
- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1651
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1426
1652
  }
1427
1653
 
1428
1654
  kbc += iter_k;
@@ -1436,22 +1662,24 @@ static __global__ void flash_attn_ext_f16(
1436
1662
  return;
1437
1663
  }
1438
1664
 
1439
- const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
1440
- const int zt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j); // head in units of ncols2
1441
- const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*zt) / iter_k; // j index of current tile.
1665
+ // z_KV == K/V head index, zt_gqa = Q head start index per K/V head, jt = token position start index.
1666
+ const int sequence = kbc /(iter_k*iter_j*iter_z_gqa*ne12);
1667
+ const int z_KV = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence)/(iter_k*iter_j*iter_z_gqa);
1668
+ const int zt_gqa = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV)/(iter_k*iter_j);
1669
+ const int jt = (kbc - iter_k*iter_j*iter_z_gqa*ne12 * sequence - iter_k*iter_j*iter_z_gqa * z_KV - iter_k*iter_j * zt_gqa) / iter_k;
1442
1670
 
1443
- const int head0 = zt * ncols2;
1671
+ const int zt_Q = z_KV*gqa_ratio + zt_gqa*ncols2; // Global Q head start index.
1444
1672
 
1445
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1446
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1673
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*zt_Q);
1674
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*z_KV);
1447
1675
  const half * mask_h = ncols2 == 1 && !mask ? nullptr :
1448
1676
  (const half *) (mask + nb33*(sequence % ne33));
1449
- float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
1677
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + zt_Q) * (DV/2);
1450
1678
 
1451
- const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1452
- const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
1679
+ const half2 * V_h2 = V_is_K_view ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*z_KV);
1680
+ const float * sinks_f = sinks ? (const float *) sinks + zt_Q : nullptr;
1453
1681
 
1454
- const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
1682
+ const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, zt_Q, n_head_log2, m0, m1) : 1.0f;
1455
1683
 
1456
1684
  if (KV_max) {
1457
1685
  kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
@@ -1459,9 +1687,9 @@ static __global__ void flash_attn_ext_f16(
1459
1687
 
1460
1688
  constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
1461
1689
  constexpr bool needs_fixup = false;
1462
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
1690
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, V_is_K_view, needs_fixup, is_fixup>
1463
1691
  (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1464
- ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1692
+ ne01, ne02, gqa_ratio, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, zt_gqa, kb0_start, kb0_stop);
1465
1693
  #else
1466
1694
  GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
1467
1695
  max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1473,7 +1701,7 @@ static __global__ void flash_attn_ext_f16(
1473
1701
  ne31, ne32, ne33,
1474
1702
  nb31, nb32, nb33);
1475
1703
  NO_DEVICE_CODE;
1476
- #endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1704
+ #endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE))
1477
1705
  }
1478
1706
 
1479
1707
  template <int DKQ, int DV, int ncols1, int ncols2>
@@ -1492,10 +1720,11 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1492
1720
  const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
1493
1721
  const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
1494
1722
 
1495
- const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
1496
- const int nwarps = nthreads / WARP_SIZE;
1723
+ const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc));
1724
+ const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size;
1725
+ const int nwarps = nthreads / warp_size_host;
1497
1726
 
1498
- constexpr bool mla = DKQ == 576;
1727
+ constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu
1499
1728
 
1500
1729
  const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
1501
1730
  const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
@@ -1512,33 +1741,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1512
1741
  float logit_softcap;
1513
1742
  memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
1514
1743
 
1744
+ #if defined(GGML_USE_HIP)
1745
+ using fattn_kernel_ptr_t = const void*;
1746
+ #else
1747
+ using fattn_kernel_ptr_t = fattn_kernel_t;
1748
+ #endif // defined(GGML_USE_HIP)
1515
1749
  fattn_kernel_t fattn_kernel;
1516
1750
  if (logit_softcap == 0.0f) {
1517
1751
  constexpr bool use_logit_softcap = false;
1518
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
1752
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
1519
1753
 
1520
- #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1754
+ #if !defined(GGML_USE_MUSA)
1521
1755
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1522
1756
  if (!shared_memory_limit_raised[id]) {
1523
- CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1757
+ CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1524
1758
  shared_memory_limit_raised[id] = true;
1525
1759
  }
1526
- #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1760
+ #endif // !defined(GGML_USE_MUSA)
1527
1761
  } else {
1528
1762
  constexpr bool use_logit_softcap = true;
1529
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
1763
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, V_is_K_view>;
1530
1764
 
1531
- #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1765
+ #if !defined(GGML_USE_MUSA)
1532
1766
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
1533
1767
  if (!shared_memory_limit_raised[id]) {
1534
- CUDA_CHECK(cudaFuncSetAttribute(fattn_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1768
+ CUDA_CHECK(cudaFuncSetAttribute(reinterpret_cast<fattn_kernel_ptr_t>(fattn_kernel), cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared_total));
1535
1769
  shared_memory_limit_raised[id] = true;
1536
1770
  }
1537
- #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1771
+ #endif // !defined(GGML_USE_MUSA)
1538
1772
  }
1539
1773
 
1540
1774
  launch_fattn<DV, ncols1, ncols2>
1541
- (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
1775
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host);
1542
1776
  }
1543
1777
 
1544
1778
 
@@ -1585,3 +1819,10 @@ DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 256, 64)
1585
1819
  extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 16);
1586
1820
  extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 16);
1587
1821
  extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 16);
1822
+
1823
+ // For GLM 4.7 Flash
1824
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 4, 4);
1825
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 8, 4);
1826
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 16, 4);
1827
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 1, 32);
1828
+ extern DECL_FATTN_MMA_F16_CASE(576, 512, 2, 32);