whispercpp 1.3.4 → 1.3.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (630):
  1. checksums.yaml +4 -4
  2. data/README.md +60 -43
  3. data/ext/extconf.rb +2 -2
  4. data/ext/ruby_whisper.c +14 -2
  5. data/ext/ruby_whisper.h +39 -0
  6. data/ext/ruby_whisper_context.c +22 -22
  7. data/ext/ruby_whisper_model.c +12 -12
  8. data/ext/ruby_whisper_params.c +47 -23
  9. data/ext/ruby_whisper_segment.c +84 -19
  10. data/ext/ruby_whisper_token.c +351 -0
  11. data/ext/ruby_whisper_transcribe.cpp +1 -1
  12. data/ext/ruby_whisper_vad_context.c +75 -0
  13. data/ext/ruby_whisper_vad_context_detect.cpp +50 -0
  14. data/ext/ruby_whisper_vad_segment.c +139 -0
  15. data/ext/ruby_whisper_vad_segments.c +106 -0
  16. data/ext/sources/CMakeLists.txt +4 -1
  17. data/ext/sources/bindings/javascript/package.json +1 -1
  18. data/ext/sources/cmake/arm64-apple-clang.cmake +16 -0
  19. data/ext/sources/cmake/arm64-windows-llvm.cmake +16 -0
  20. data/ext/sources/cmake/riscv64-spacemit-linux-gnu-gcc.cmake +29 -0
  21. data/ext/sources/cmake/x64-windows-llvm.cmake +5 -0
  22. data/ext/sources/examples/addon.node/vad-example.js +2 -2
  23. data/ext/sources/examples/cli/cli.cpp +121 -112
  24. data/ext/sources/examples/lsp/CMakeLists.txt +2 -1
  25. data/ext/sources/examples/quantize/CMakeLists.txt +2 -1
  26. data/ext/sources/examples/server/server.cpp +10 -11
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +5 -1
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +12 -3
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +7 -1
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +2046 -1974
  31. data/ext/sources/examples/talk-llama/llama-arch.h +67 -2
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +75 -33
  33. data/ext/sources/examples/talk-llama/llama-batch.h +17 -4
  34. data/ext/sources/examples/talk-llama/llama-chat.cpp +79 -3
  35. data/ext/sources/examples/talk-llama/llama-chat.h +4 -0
  36. data/ext/sources/examples/talk-llama/llama-context.cpp +775 -78
  37. data/ext/sources/examples/talk-llama/llama-context.h +57 -9
  38. data/ext/sources/examples/talk-llama/llama-cparams.h +1 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +288 -53
  40. data/ext/sources/examples/talk-llama/llama-grammar.h +22 -1
  41. data/ext/sources/examples/talk-llama/llama-graph.cpp +381 -64
  42. data/ext/sources/examples/talk-llama/llama-graph.h +103 -13
  43. data/ext/sources/examples/talk-llama/llama-hparams.cpp +26 -2
  44. data/ext/sources/examples/talk-llama/llama-hparams.h +41 -10
  45. data/ext/sources/examples/talk-llama/llama-impl.cpp +7 -3
  46. data/ext/sources/examples/talk-llama/llama-impl.h +1 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +5 -3
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +145 -65
  49. data/ext/sources/examples/talk-llama/llama-kv-cache.h +22 -7
  50. data/ext/sources/examples/talk-llama/llama-kv-cells.h +44 -2
  51. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +12 -10
  52. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +32 -19
  53. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +2 -2
  54. data/ext/sources/examples/talk-llama/llama-mmap.cpp +172 -37
  55. data/ext/sources/examples/talk-llama/llama-mmap.h +8 -3
  56. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +91 -9
  57. data/ext/sources/examples/talk-llama/llama-model-loader.h +6 -0
  58. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +3 -0
  59. data/ext/sources/examples/talk-llama/llama-model.cpp +1529 -13134
  60. data/ext/sources/examples/talk-llama/llama-model.h +44 -3
  61. data/ext/sources/examples/talk-llama/llama-quant.cpp +8 -23
  62. data/ext/sources/examples/talk-llama/llama-sampling.cpp +1294 -198
  63. data/ext/sources/examples/talk-llama/llama-sampling.h +19 -7
  64. data/ext/sources/examples/talk-llama/llama-vocab.cpp +133 -37
  65. data/ext/sources/examples/talk-llama/llama-vocab.h +45 -40
  66. data/ext/sources/examples/talk-llama/llama.cpp +729 -2
  67. data/ext/sources/examples/talk-llama/llama.h +152 -14
  68. data/ext/sources/examples/talk-llama/models/afmoe.cpp +191 -0
  69. data/ext/sources/examples/talk-llama/models/apertus.cpp +125 -0
  70. data/ext/sources/examples/talk-llama/models/arcee.cpp +135 -0
  71. data/ext/sources/examples/talk-llama/models/arctic.cpp +138 -0
  72. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +86 -0
  73. data/ext/sources/examples/talk-llama/models/baichuan.cpp +122 -0
  74. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +144 -0
  75. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +135 -0
  76. data/ext/sources/examples/talk-llama/models/bert.cpp +178 -0
  77. data/ext/sources/examples/talk-llama/models/bitnet.cpp +160 -0
  78. data/ext/sources/examples/talk-llama/models/bloom.cpp +101 -0
  79. data/ext/sources/examples/talk-llama/models/chameleon.cpp +178 -0
  80. data/ext/sources/examples/talk-llama/models/chatglm.cpp +132 -0
  81. data/ext/sources/examples/talk-llama/models/codeshell.cpp +111 -0
  82. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +102 -0
  83. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +134 -0
  84. data/ext/sources/examples/talk-llama/models/command-r.cpp +122 -0
  85. data/ext/sources/examples/talk-llama/models/dbrx.cpp +123 -0
  86. data/ext/sources/examples/talk-llama/models/deci.cpp +135 -0
  87. data/ext/sources/examples/talk-llama/models/deepseek.cpp +144 -0
  88. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +259 -0
  89. data/ext/sources/examples/talk-llama/models/dots1.cpp +134 -0
  90. data/ext/sources/examples/talk-llama/models/dream.cpp +105 -0
  91. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +150 -0
  92. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +110 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +114 -0
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +123 -0
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +113 -0
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +120 -0
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +116 -0
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +112 -0
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +128 -0
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +155 -0
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +384 -0
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +170 -0
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +150 -0
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +105 -0
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +144 -0
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +196 -0
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +211 -0
  108. data/ext/sources/examples/talk-llama/models/graph-context-mamba.cpp +283 -0
  109. data/ext/sources/examples/talk-llama/models/grok.cpp +159 -0
  110. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +141 -0
  111. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +132 -0
  112. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +154 -0
  113. data/ext/sources/examples/talk-llama/models/internlm2.cpp +120 -0
  114. data/ext/sources/examples/talk-llama/models/jais.cpp +86 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +106 -0
  116. data/ext/sources/examples/talk-llama/models/lfm2.cpp +175 -0
  117. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +122 -0
  118. data/ext/sources/examples/talk-llama/models/llada.cpp +99 -0
  119. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +178 -0
  120. data/ext/sources/examples/talk-llama/models/llama.cpp +168 -0
  121. data/ext/sources/examples/talk-llama/models/maincoder.cpp +117 -0
  122. data/ext/sources/examples/talk-llama/models/mamba.cpp +55 -0
  123. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +123 -0
  124. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +199 -0
  125. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +124 -0
  126. data/ext/sources/examples/talk-llama/models/mistral3.cpp +160 -0
  127. data/ext/sources/examples/talk-llama/models/models.h +569 -0
  128. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +116 -0
  129. data/ext/sources/examples/talk-llama/models/mpt.cpp +126 -0
  130. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +150 -0
  131. data/ext/sources/examples/talk-llama/models/nemotron.cpp +122 -0
  132. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +104 -0
  133. data/ext/sources/examples/talk-llama/models/olmo.cpp +121 -0
  134. data/ext/sources/examples/talk-llama/models/olmo2.cpp +150 -0
  135. data/ext/sources/examples/talk-llama/models/olmoe.cpp +124 -0
  136. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +127 -0
  137. data/ext/sources/examples/talk-llama/models/openelm.cpp +124 -0
  138. data/ext/sources/examples/talk-llama/models/orion.cpp +123 -0
  139. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +121 -0
  140. data/ext/sources/examples/talk-llama/models/phi2.cpp +121 -0
  141. data/ext/sources/examples/talk-llama/models/phi3.cpp +152 -0
  142. data/ext/sources/examples/talk-llama/models/plamo.cpp +110 -0
  143. data/ext/sources/examples/talk-llama/models/plamo2.cpp +316 -0
  144. data/ext/sources/examples/talk-llama/models/plamo3.cpp +128 -0
  145. data/ext/sources/examples/talk-llama/models/plm.cpp +168 -0
  146. data/ext/sources/examples/talk-llama/models/qwen.cpp +108 -0
  147. data/ext/sources/examples/talk-llama/models/qwen2.cpp +126 -0
  148. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +151 -0
  149. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +117 -0
  150. data/ext/sources/examples/talk-llama/models/qwen3.cpp +117 -0
  151. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +124 -0
  152. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +873 -0
  153. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +149 -0
  154. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +141 -0
  155. data/ext/sources/examples/talk-llama/models/refact.cpp +94 -0
  156. data/ext/sources/examples/talk-llama/models/rnd1.cpp +126 -0
  157. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +162 -0
  158. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +94 -0
  159. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +86 -0
  160. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +135 -0
  161. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +90 -0
  162. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +124 -0
  163. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +126 -0
  164. data/ext/sources/examples/talk-llama/models/smollm3.cpp +128 -0
  165. data/ext/sources/examples/talk-llama/models/stablelm.cpp +146 -0
  166. data/ext/sources/examples/talk-llama/models/starcoder.cpp +100 -0
  167. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +121 -0
  168. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +166 -0
  169. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +96 -0
  170. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +149 -0
  171. data/ext/sources/examples/talk-llama/models/xverse.cpp +108 -0
  172. data/ext/sources/examples/talk-llama/unicode.cpp +102 -16
  173. data/ext/sources/examples/vad-speech-segments/CMakeLists.txt +1 -1
  174. data/ext/sources/examples/whisper.wasm/index-tmpl.html +1 -1
  175. data/ext/sources/ggml/CMakeLists.txt +82 -54
  176. data/ext/sources/ggml/include/ggml-alloc.h +9 -0
  177. data/ext/sources/ggml/include/ggml-backend.h +4 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +1 -0
  179. data/ext/sources/ggml/include/ggml-hexagon.h +19 -0
  180. data/ext/sources/ggml/include/ggml-rpc.h +8 -11
  181. data/ext/sources/ggml/include/ggml-zendnn.h +22 -0
  182. data/ext/sources/ggml/include/ggml.h +190 -12
  183. data/ext/sources/ggml/src/CMakeLists.txt +82 -11
  184. data/ext/sources/ggml/src/ggml-alloc.c +124 -41
  185. data/ext/sources/ggml/src/ggml-backend-impl.h +1 -4
  186. data/ext/sources/ggml/src/ggml-backend-reg.cpp +27 -3
  187. data/ext/sources/ggml/src/ggml-backend.cpp +71 -21
  188. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +17 -3
  189. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +5 -9
  190. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +57 -45
  191. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +138 -47
  192. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +2179 -1696
  193. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +238 -317
  194. data/ext/sources/ggml/src/ggml-cann/common.h +283 -208
  195. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +626 -776
  196. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +156 -86
  197. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +1 -0
  198. data/ext/sources/ggml/src/ggml-cpu/arch/arm/cpu-feats.cpp +4 -0
  199. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +428 -26
  200. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +1004 -0
  201. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +4 -5
  202. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/cpu-feats.cpp +38 -0
  203. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +108 -49
  204. data/ext/sources/ggml/src/ggml-cpu/arch/s390/cpu-feats.cpp +50 -0
  205. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +6 -6
  206. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +50 -2
  207. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +5 -3
  208. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +195 -71
  209. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +4 -0
  210. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +573 -106
  211. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.h +33 -44
  212. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +298 -112
  213. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +333 -0
  214. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +819 -125
  215. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.h +6 -0
  216. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +708 -431
  217. data/ext/sources/ggml/src/ggml-cpu/ops.h +5 -4
  218. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +671 -31
  219. data/ext/sources/ggml/src/ggml-cpu/repack.h +14 -0
  220. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +41 -43
  221. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +3 -2
  222. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +151 -0
  223. data/ext/sources/ggml/src/ggml-cpu/unary-ops.h +7 -0
  224. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +124 -1
  225. data/ext/sources/ggml/src/ggml-cpu/vec.h +261 -146
  226. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +72 -1
  227. data/ext/sources/ggml/src/ggml-cuda/argmax.cu +2 -2
  228. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +123 -6
  229. data/ext/sources/ggml/src/ggml-cuda/argsort.cuh +16 -0
  230. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +1 -1
  231. data/ext/sources/ggml/src/ggml-cuda/common.cuh +353 -80
  232. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  233. data/ext/sources/ggml/src/ggml-cuda/cpy-utils.cuh +1 -1
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +339 -246
  235. data/ext/sources/ggml/src/ggml-cuda/cpy.cuh +1 -5
  236. data/ext/sources/ggml/src/ggml-cuda/cumsum.cu +307 -0
  237. data/ext/sources/ggml/src/ggml-cuda/cumsum.cuh +5 -0
  238. data/ext/sources/ggml/src/ggml-cuda/diag.cu +77 -0
  239. data/ext/sources/ggml/src/ggml-cuda/diag.cuh +5 -0
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +31 -21
  241. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +663 -596
  242. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +35 -741
  243. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +1241 -0
  244. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +30 -37
  245. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +14 -13
  246. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +48 -0
  247. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +83 -37
  248. data/ext/sources/ggml/src/ggml-cuda/fill.cu +37 -0
  249. data/ext/sources/ggml/src/ggml-cuda/fill.cuh +3 -0
  250. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1155 -164
  251. data/ext/sources/ggml/src/ggml-cuda/mean.cu +5 -4
  252. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +741 -48
  253. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +60 -12
  254. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +381 -42
  255. data/ext/sources/ggml/src/ggml-cuda/mmid.cu +164 -0
  256. data/ext/sources/ggml/src/ggml-cuda/mmid.cuh +5 -0
  257. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +69 -176
  258. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +498 -171
  259. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +375 -79
  260. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +3 -2
  261. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +241 -95
  262. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/pad.cu +64 -33
  264. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +151 -0
  265. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +14 -0
  266. data/ext/sources/ggml/src/ggml-cuda/rope.cu +192 -77
  267. data/ext/sources/ggml/src/ggml-cuda/rope.cuh +2 -0
  268. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +101 -47
  269. data/ext/sources/ggml/src/ggml-cuda/set.cu +39 -0
  270. data/ext/sources/ggml/src/ggml-cuda/set.cuh +7 -0
  271. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +203 -6
  272. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +275 -0
  273. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cuh +3 -0
  274. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +14 -20
  275. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +49 -84
  276. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu +5 -0
  277. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu +5 -0
  278. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu +5 -0
  279. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu +5 -0
  280. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu +5 -0
  281. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu +5 -0
  282. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq72-dv72.cu +5 -0
  283. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu +5 -0
  284. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu +5 -0
  285. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +19 -1
  286. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +96 -0
  287. data/ext/sources/ggml/src/ggml-cuda/top-k.cuh +3 -0
  288. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +168 -76
  289. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +11 -4
  290. data/ext/sources/ggml/src/ggml-cuda/tri.cu +136 -0
  291. data/ext/sources/ggml/src/ggml-cuda/tri.cuh +5 -0
  292. data/ext/sources/ggml/src/ggml-cuda/unary.cu +105 -11
  293. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +36 -0
  294. data/ext/sources/ggml/src/ggml-cuda/upscale.cu +163 -7
  295. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +4 -0
  296. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +12 -1
  297. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +6 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +80 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +3151 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +44 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +682 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +360 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +157 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +566 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +112 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +35 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.c +63 -0
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-dma.h +157 -0
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +165 -0
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +92 -0
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +16 -0
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +94 -0
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +72 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +49 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +1020 -0
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +1353 -0
  317. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +1001 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2503 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +149 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +487 -0
  321. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +168 -0
  322. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +402 -0
  323. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +287 -0
  324. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +297 -0
  325. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.h +57 -0
  326. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +454 -0
  327. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +221 -0
  328. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +153 -0
  329. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +8 -13
  330. data/ext/sources/ggml/src/ggml-impl.h +67 -6
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +2 -2
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +29 -20
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +652 -285
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +103 -56
  335. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +496 -118
  336. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +231 -9
  337. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +1227 -224
  338. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +12 -0
  339. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +14 -8
  340. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1972 -704
  341. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +3 -1
  342. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +11 -0
  343. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +1430 -120
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +63 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +82 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/fill.cl +17 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/flash_attn_f32.cl +4 -3
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32.cl +162 -0
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32.cl +156 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +36 -12
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +39 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_kq_kqv.cl +273 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f16_f32_l4_lm.cl +24 -10
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_f32_f32_l4_lm.cl +24 -10
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_l4_lm.cl +154 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/pad.cl +29 -20
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/rms_norm.cl +25 -10
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/rope.cl +50 -24
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/set_rows.cl +35 -16
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -0
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/sqr.cl +53 -0
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/sqrt.cl +53 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/ssm_conv.cl +77 -0
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +13 -0
  365. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +438 -156
  366. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +48 -3
  367. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +77 -0
  368. data/ext/sources/ggml/src/ggml-sycl/add-id.hpp +8 -0
  369. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +6 -0
  370. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +0 -9
  371. data/ext/sources/ggml/src/ggml-sycl/binbcast.hpp +0 -6
  372. data/ext/sources/ggml/src/ggml-sycl/common.hpp +117 -15
  373. data/ext/sources/ggml/src/ggml-sycl/concat.cpp +55 -44
  374. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +34 -0
  375. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +79 -0
  376. data/ext/sources/ggml/src/ggml-sycl/count-equal.hpp +9 -0
  377. data/ext/sources/ggml/src/ggml-sycl/cpy.cpp +0 -3
  378. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +18 -0
  379. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +76 -3
  380. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +333 -300
  381. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +10 -2
  382. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +335 -110
  383. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +22 -0
  384. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +156 -0
  385. data/ext/sources/ggml/src/ggml-sycl/norm.hpp +2 -0
  386. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +97 -0
  387. data/ext/sources/ggml/src/ggml-sycl/pad.hpp +24 -0
  388. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.cpp +100 -0
  389. data/ext/sources/ggml/src/ggml-sycl/pad_reflect_1d.hpp +10 -0
  390. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +2 -0
  391. data/ext/sources/ggml/src/ggml-sycl/repeat_back.cpp +76 -0
  392. data/ext/sources/ggml/src/ggml-sycl/repeat_back.hpp +8 -0
  393. data/ext/sources/ggml/src/ggml-sycl/roll.cpp +122 -0
  394. data/ext/sources/ggml/src/ggml-sycl/roll.hpp +20 -0
  395. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +30 -17
  396. data/ext/sources/ggml/src/ggml-sycl/set.cpp +73 -0
  397. data/ext/sources/ggml/src/ggml-sycl/set.hpp +5 -0
  398. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +327 -162
  399. data/ext/sources/ggml/src/ggml-sycl/softmax.hpp +4 -0
  400. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +127 -0
  401. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.hpp +5 -0
  402. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +58 -0
  403. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +38 -18
  404. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +5013 -2859
  405. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/abs.comp +21 -0
  406. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +2 -2
  407. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add.comp +2 -2
  408. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add1.comp +28 -0
  409. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp +1 -1
  410. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/arange.comp +20 -0
  411. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp +2 -2
  412. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp +33 -26
  413. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/argsort_large.comp +114 -0
  414. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ceil.comp +22 -0
  415. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +2 -2
  416. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp +2 -2
  417. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +2 -2
  418. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp +1 -1
  419. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +47 -49
  420. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp +1 -1
  421. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +2 -2
  422. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +3 -3
  423. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +4 -4
  424. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_transpose.comp +67 -0
  425. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +2 -2
  426. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp +2 -2
  427. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/count_experts.comp +51 -0
  428. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +83 -0
  429. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp +60 -0
  430. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp +66 -0
  431. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp +1 -1
  432. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs.comp → dequant_funcs.glsl} +9 -21
  433. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_funcs_cm2.comp → dequant_funcs_cm2.glsl} +18 -4
  434. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{dequant_head.comp → dequant_head.glsl} +1 -1
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp +1 -1
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp +1 -1
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +1 -1
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +1 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +1 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +1 -1
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +1 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp +1 -1
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp +3 -3
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp +3 -3
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp +1 -1
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp +1 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp +1 -1
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp +3 -3
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp +1 -1
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp +1 -1
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp +3 -3
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp +1 -1
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp +1 -1
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +29 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp +1 -1
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/div.comp +2 -2
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +3 -3
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fill.comp +19 -0
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +39 -17
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{flash_attn_base.comp → flash_attn_base.glsl} +19 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +45 -7
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +50 -12
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +1 -1
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/floor.comp +22 -0
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp +2 -2
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp +2 -2
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp +2 -2
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp +2 -2
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp +2 -2
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp +2 -2
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_binary_head.comp → generic_binary_head.glsl} +17 -2
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_head.comp → generic_head.glsl} +2 -0
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{generic_unary_head.comp → generic_unary_head.glsl} +7 -0
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp +4 -4
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +3 -3
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_head.comp → glu_head.glsl} +1 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp +2 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp +2 -2
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp +2 -2
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +19 -7
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +2 -3
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +2 -2
  484. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +2 -2
  485. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +18 -0
  486. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp +2 -2
  487. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +2 -2
  488. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mat_vec_base.comp → mul_mat_vec_base.glsl} +70 -25
  489. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +35 -0
  490. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp +71 -21
  491. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp +41 -25
  492. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp +2 -2
  493. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp +44 -26
  494. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp +2 -2
  495. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp +2 -2
  496. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp +2 -2
  497. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp +9 -7
  498. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp +9 -7
  499. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +4 -6
  500. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp +2 -2
  501. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +4 -6
  502. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +4 -6
  503. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp +2 -2
  504. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +39 -36
  505. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +494 -0
  506. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +78 -103
  507. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +34 -23
  508. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{mul_mm_funcs.comp → mul_mm_funcs.glsl} +69 -59
  509. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +72 -0
  510. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +88 -228
  511. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +454 -0
  512. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +78 -0
  513. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +97 -13
  514. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/neg.comp +20 -0
  515. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +2 -2
  516. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp +2 -2
  517. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp +1 -1
  518. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp +21 -6
  519. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp +1 -1
  520. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +10 -10
  521. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp +2 -2
  522. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp +2 -2
  523. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp +2 -2
  524. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp +2 -2
  525. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +50 -4
  526. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp +2 -2
  527. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp +2 -2
  528. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp +2 -2
  529. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +234 -0
  530. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +20 -0
  531. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +6 -50
  532. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +6 -33
  533. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +6 -33
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +28 -0
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +6 -39
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/round.comp +29 -0
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp +2 -2
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp +2 -2
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp +2 -2
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp +2 -2
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +2 -2
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp +1 -1
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp +2 -2
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large1.comp +62 -0
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large2.comp +79 -0
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large3.comp +65 -0
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_large_common.glsl +53 -0
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/softplus.comp +23 -0
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/solve_tri.comp +81 -0
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +2 -2
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +2 -2
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +44 -0
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +124 -0
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/step.comp +22 -0
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp +2 -2
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp +2 -25
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.glsl +25 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp +2 -2
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp +2 -2
  560. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp +2 -2
  561. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp +1 -1
  562. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_argsort.comp +118 -0
  563. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +213 -0
  564. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/topk_nary_search.comp +246 -0
  565. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +43 -0
  566. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/trunc.comp +22 -0
  567. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{types.comp → types.glsl} +345 -26
  568. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp +90 -12
  569. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +335 -151
  570. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/xielu.comp +35 -0
  571. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +28 -2
  572. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +169 -0
  573. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1964 -435
  574. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +778 -0
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +188 -0
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +101 -0
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +33 -10
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +591 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +1 -1
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +323 -0
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +6 -6
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +97 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +247 -0
  584. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +302 -0
  585. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +267 -0
  586. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +83 -17
  587. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl +295 -0
  588. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +90 -0
  589. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +112 -0
  590. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/soft_max.tmpl.wgsl +345 -0
  591. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +483 -0
  592. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +92 -0
  593. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +466 -0
  594. data/ext/sources/ggml/src/ggml.c +425 -33
  595. data/ext/sources/include/whisper.h +1 -0
  596. data/ext/sources/src/CMakeLists.txt +3 -1
  597. data/ext/sources/src/whisper.cpp +101 -35
  598. data/ext/sources/tests/CMakeLists.txt +2 -2
  599. data/ext/sources/tests/test-vad-full.cpp +4 -2
  600. data/ext/sources/tests/test-vad.cpp +1 -1
  601. data/extsources.rb +1 -0
  602. data/lib/whisper/model/uri.rb +17 -18
  603. data/sig/whisper.rbs +119 -2
  604. data/test/test_params.rb +16 -8
  605. data/test/test_segment.rb +0 -1
  606. data/test/test_token.rb +70 -0
  607. data/test/test_vad.rb +1 -1
  608. data/test/test_vad_context.rb +50 -0
  609. data/test/test_vad_segment.rb +19 -0
  610. data/test/test_vad_segments.rb +16 -0
  611. data/test/test_whisper.rb +7 -0
  612. data/whispercpp.gemspec +1 -1
  613. metadata +287 -34
  614. data/ext/sources/build-xcframework.sh +0 -571
  615. data/ext/sources/ggml/src/ggml-cann/Doxyfile +0 -2579
  616. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.comp +0 -105
  617. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.comp +0 -55
  618. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +0 -44
  619. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +0 -41
  620. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +0 -60
  621. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +0 -44
  622. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +0 -41
  623. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +0 -48
  624. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_bfloat16_support.comp → feature-tests/bfloat16.comp} +0 -0
  625. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat_support.comp → feature-tests/coopmat.comp} +0 -0
  626. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_coopmat2_support.comp → feature-tests/coopmat2.comp} +0 -0
  627. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{test_integer_dot_support.comp → feature-tests/integer_dot.comp} +0 -0
  628. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{glu_main.comp → glu_main.glsl} +0 -0
  629. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{rte.comp → rte.glsl} +0 -0
  630. /data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/{utils.comp → utils.glsl} +0 -0
@@ -5,284 +5,211 @@
5
5
 
6
6
  using namespace ggml_cuda_mma;
7
7
 
8
- typedef tile<16, 8, half2> tile_A;
9
- typedef tile< 8, 8, half2> tile_B;
10
- typedef tile<16, 8, half2> tile_B_16;
11
- typedef tile<16, 8, float> tile_C_KQ;
12
- typedef tile<16, 16, float> tile_C_KQ_16;
13
- typedef tile<16, 4, half2> tile_C_VKQ;
14
- typedef tile<16, 8, half2> tile_C_VKQ_16;
15
-
16
- // Config options for specific head sizes.
8
+ // Config options for the MMA kernel.
17
9
  // Should not affect results, only speed/register pressure/shared memory use.
18
- //
19
- // nbatch_fa: number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
20
- // nwarps_max: maximum number of warps per CUDA block, up to 8 warps in total can run per SM (given enough shared memory).
21
- // Q_in_reg: whether the Q values should be kept permanently in registers.
22
- // nstages_target: targeted number of pipeline stages for cp_async (if available), 0 means synchronous data loading.
23
- // nbatch_K2: number of K half2 values in direction of DKQ to load in parallel.
24
- // nbatch_V2: number of V half2 values in direction of DV to load in parallel.
25
- // nbatch_combine: number of VKQ half2 values in direction of DV to combine in parallel.
26
-
27
- template <int DKQ, int DV>
28
- struct fattn_mma_f16_config;
29
-
30
- template <>
31
- struct fattn_mma_f16_config< 64, 64> {
32
- static constexpr int nbatch_fa = 64;
33
- static constexpr int nwarps_max = 4;
34
- static constexpr bool Q_in_reg = true;
35
- static constexpr int nstages_target = 2;
36
-
37
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
38
- return 32;
39
- }
40
-
41
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
42
- return 32;
43
- }
44
-
45
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
46
- return 32;
47
- }
48
-
49
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
50
- return 32;
51
- }
52
-
53
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
54
- return 32;
55
- }
56
-
57
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
58
- return 32;
59
- }
10
+ struct fattn_mma_config {
11
+ int nthreads; // Number of threads per CUDA block.
12
+ int occupancy; // Targeted occupancy for the MMA kernel.
13
+ int nbatch_fa; // Number of KV rows per softmax rescaling of KQ rowsums and VKQ accumulators.
14
+ int nbatch_K2; // Number of K half2 values in direction of DKQ to load in parallel.
15
+ int nbatch_V2; // Number of V half2 values in direction of DV to load in parallel.
16
+ int nbatch_combine; // Number of VKQ half2 values in direction of DV to combine in parallel.
17
+ int nstages_target; // Number of pipeline stages to use ideally, 1 == always load data synchronously, 2 == preload data if there is hardware support.
18
+ bool Q_in_reg; // Whether the Q values should be kept permanently in registers.
19
+
20
+ constexpr __host__ __device__ fattn_mma_config(
21
+ int nthreads, int occupancy, int nbatch_fa, int nbatch_K2, int nbatch_V2, int nbatch_combine, int nstages_target, bool Q_in_reg) :
22
+ nthreads(nthreads), occupancy(occupancy), nbatch_fa(nbatch_fa), nbatch_K2(nbatch_K2), nbatch_V2(nbatch_V2), nbatch_combine(nbatch_combine),
23
+ nstages_target(nstages_target), Q_in_reg(Q_in_reg) {}
60
24
  };
61
25
 
62
- template <>
63
- struct fattn_mma_f16_config< 80, 80> {
64
- static constexpr int nbatch_fa = 64;
65
- static constexpr int nwarps_max = 4;
66
- static constexpr bool Q_in_reg = true;
67
- static constexpr int nstages_target = 2;
68
-
69
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
70
- return 40;
71
- }
72
-
73
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
74
- return 40;
75
- }
76
-
77
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
78
- return 40;
79
- }
80
-
81
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
82
- return 40;
83
- }
84
-
85
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
86
- return 40;
87
- }
88
-
89
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
90
- return 40;
91
- }
92
- };
93
-
94
- template <>
95
- struct fattn_mma_f16_config< 96, 96> {
96
- static constexpr int nbatch_fa = 64;
97
- static constexpr int nwarps_max = 4;
98
- static constexpr bool Q_in_reg = true;
99
- static constexpr int nstages_target = 2;
100
-
101
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
102
- return 48;
103
- }
104
-
105
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
106
- return 48;
107
- }
108
-
109
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
110
- return 48;
111
- }
112
-
113
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
114
- return 48;
115
- }
116
-
117
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
118
- return 48;
119
- }
26
+ #define GGML_CUDA_FATTN_MMA_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads_, occupancy_, nbatch_fa_, nbatch_K2_, nbatch_V2_, nbatch_combine_, nstages_target_, Q_in_reg_) \
27
+ if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \
28
+ static_assert((nthreads_) % 32 == 0 && (nthreads_) <= 512, "bad nthreads"); \
29
+ static_assert( (occupancy_) <= 8, "bad occupancy"); \
30
+ static_assert((nbatch_fa_) % 32 == 0 && (nbatch_fa_) <= 256, "bad nbatch_fa"); \
31
+ static_assert((nbatch_K2_) % 4 == 0 && (nbatch_K2_) <= 512, "bad nbatch_K2"); \
32
+ static_assert((nbatch_V2_) % 4 == 0 && (nbatch_V2_) <= 256, "bad nbatch_V2"); \
33
+ static_assert((nbatch_combine_) % 4 == 0 && (nbatch_combine_) <= 128, "bad nbatch_combine"); \
34
+ static_assert((nstages_target_) >= 1 && (nstages_target_) <= 2, "bad nstages_target"); \
35
+ return fattn_mma_config{(nthreads_), (occupancy_), (nbatch_fa_), (nbatch_K2_), (nbatch_V2_), (nbatch_combine_), (nstages_target_), (Q_in_reg_)}; \
36
+ } \
37
+
38
+ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_ampere(const int DKQ, const int DV, const int ncols) {
39
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 2, true);
40
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 2, true);
41
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 2, true);
42
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 128, 2, 64, 32, 32, 32, 2, true);
43
+
44
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 2, true);
45
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 2, true);
46
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 2, true);
47
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 128, 2, 64, 40, 40, 40, 2, true);
48
+
49
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 2, true);
50
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 2, true);
51
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 2, true);
52
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 128, 2, 64, 48, 48, 48, 2, true);
53
+
54
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 2, true);
55
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 2, true);
56
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 2, true);
57
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 128, 2, 64, 56, 56, 56, 2, true);
58
+
59
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 2, true);
60
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 2, true);
61
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 2, true);
62
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 128, 2, 64, 64, 64, 64, 2, true);
63
+
64
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 2, true);
65
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 2, true);
66
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 2, true);
67
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 32, 128, 128, 128, 2, true);
68
+
69
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 128, 1, false);
70
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 128, 1, false);
71
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
72
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
73
+
74
+ return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
75
+ }
120
76
 
121
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
122
- return 48;
123
- }
124
- };
77
+ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_turing(const int DKQ, const int DV, const int ncols) {
78
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 128, 2, 64, 128, 128, 128, 2, true);
79
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 128, 2, 64, 128, 128, 128, 2, true);
80
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 64, 128, 128, 64, 2, true);
81
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 128, 2, 64, 128, 128, 64, 2, true);
125
82
 
126
- template <>
127
- struct fattn_mma_f16_config<112, 112> {
128
- static constexpr int nbatch_fa = 64;
129
- static constexpr int nwarps_max = 4;
130
- static constexpr bool Q_in_reg = true;
131
- static constexpr int nstages_target = 2;
83
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 96, 64, 128, 1, false);
84
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 96, 64, 128, 1, false);
85
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 128, 1, false);
86
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 128, 1, false);
132
87
 
133
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
134
- return 56;
135
- }
88
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
89
+ }
136
90
 
137
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
138
- return 56;
139
- }
91
+ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_volta(const int DKQ, const int DV, const int ncols) {
92
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 8, 64, 4, 32, 288, 256, 64, 1, false);
93
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 16, 64, 4, 32, 288, 256, 64, 1, false);
94
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 32, 128, 2, 32, 160, 128, 64, 1, false);
95
+ GGML_CUDA_FATTN_MMA_CONFIG_CASE(576, 512, 64, 256, 1, 32, 160, 128, 64, 1, false);
140
96
 
141
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
142
- return 56;
143
- }
97
+ // TODO tune specifically for Volta
98
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
99
+ }
144
100
 
145
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
146
- return 56;
101
+ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) {
102
+ if (ampere_mma_available(cc)) {
103
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
147
104
  }
148
-
149
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
150
- return 56;
105
+ if (turing_mma_available(cc)) {
106
+ return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
151
107
  }
108
+ GGML_ASSERT(volta_mma_available(cc));
109
+ return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
110
+ }
152
111
 
153
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
154
- return 56;
155
- }
156
- };
112
+ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols) {
113
+ #if defined(AMPERE_MMA_AVAILABLE)
114
+ return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols);
115
+ #elif defined(TURING_MMA_AVAILABLE)
116
+ return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols);
117
+ #elif defined(VOLTA_MMA_AVAILABLE)
118
+ return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols);
119
+ #else
120
+ GGML_UNUSED_VARS(DKQ, DV, ncols);
121
+ return fattn_mma_config(32, 1, 0, 0, 0, 0, 0, false);
122
+ #endif // defined(AMPERE_MMA_AVAILABLE)
123
+ }
157
124
 
158
- template <>
159
- struct fattn_mma_f16_config<128, 128> {
160
- static constexpr int nbatch_fa = 64;
161
- static constexpr int nwarps_max = 4;
162
- static constexpr bool Q_in_reg = true;
163
- static constexpr int nstages_target = 2;
125
+ static __host__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) {
126
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nthreads;
127
+ }
164
128
 
165
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
166
- return 64;
167
- }
129
+ static constexpr __device__ int ggml_cuda_fattn_mma_get_nthreads(const int DKQ, const int DV, const int ncols) {
130
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nthreads;
131
+ }
168
132
 
169
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
170
- return 64;
171
- }
133
+ static __host__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) {
134
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).occupancy;
135
+ }
172
136
 
173
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
174
- return 64;
175
- }
137
+ static constexpr __device__ int ggml_cuda_fattn_mma_get_occupancy(const int DKQ, const int DV, const int ncols) {
138
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).occupancy;
139
+ }
176
140
 
177
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
178
- return 64;
179
- }
141
+ static __host__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) {
142
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_fa;
143
+ }
180
144
 
181
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
182
- return 64;
183
- }
145
+ static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_fa(const int DKQ, const int DV, const int ncols) {
146
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_fa;
147
+ }
184
148
 
185
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
186
- return 64;
187
- }
188
- };
149
+ static __host__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols, const int cc) {
150
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_K2;
151
+ }
189
152
 
190
- template <>
191
- struct fattn_mma_f16_config<256, 256> {
192
- static constexpr int nbatch_fa = 32;
193
- static constexpr int nwarps_max = 4;
194
- static constexpr bool Q_in_reg = true;
195
- static constexpr int nstages_target = 2;
153
+ static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_K2(const int DKQ, const int DV, const int ncols) {
154
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_K2;
155
+ }
196
156
 
197
- static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
198
- return 128;
199
- }
157
+ static __host__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols, const int cc) {
158
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_V2;
159
+ }
200
160
 
201
- static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
202
- return 128;
203
- }
161
+ static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_V2(const int DKQ, const int DV, const int ncols) {
162
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_V2;
163
+ }
204
164
 
205
- static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
206
- return 128;
207
- }
165
+ static __host__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols, const int cc) {
166
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nbatch_combine;
167
+ }
208
168
 
209
- static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
210
- return 128;
211
- }
169
+ static constexpr __device__ int ggml_cuda_fattn_mma_get_nbatch_combine(const int DKQ, const int DV, const int ncols) {
170
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nbatch_combine;
171
+ }
212
172
 
213
- static int get_nbatch_combine_host(const int cc, const int ncols) {
214
- if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
215
- return ncols <= 16 ? 128 : 64;
216
- }
217
- return 64;
218
- }
173
+ static __host__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols, const int cc) {
174
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).nstages_target;
175
+ }
219
176
 
220
- static constexpr __device__ int get_nbatch_combine_device(int ncols) {
221
- #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
222
- return ncols <= 16 ? 128 : 64;
223
- #else
224
- GGML_UNUSED(ncols);
225
- return 128;
226
- #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
227
- }
228
- };
177
+ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages_target(const int DKQ, const int DV, const int ncols) {
178
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).nstages_target;
179
+ }
229
180
 
230
- template <>
231
- struct fattn_mma_f16_config<576, 512> {
232
- static constexpr int nbatch_fa = 32;
233
- static constexpr int nwarps_max = 8;
234
- static constexpr bool Q_in_reg = false;
235
- static constexpr int nstages_target = 1;
181
+ static __host__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols, const int cc) {
182
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols, cc).Q_in_reg;
183
+ }
236
184
 
237
- static int get_nbatch_K2_host(const int cc, const int ncols) {
238
- if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
239
- return ncols <= 16 ? 96 : 160;
240
- }
241
- return ncols <= 16 ? 288 : 160;
242
- }
185
+ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, const int DV, const int ncols) {
186
+ return ggml_cuda_fattn_mma_get_config(DKQ, DV, ncols).Q_in_reg;
187
+ }
243
188
 
244
- static constexpr __device__ int get_nbatch_K2_device(int ncols) {
245
- #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
246
- return ncols <= 16 ? 96 : 160;
247
- #else
248
- return ncols <= 16 ? 288 : 160;
249
- #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
250
- }
189
+ // ------------------------------------------------------------------------------------------------------------------
251
190
 
252
- static int get_nbatch_V2_host(const int cc, const int ncols) {
253
- if (ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_TURING) {
254
- return ncols <= 16 ? 64 : 128;
255
- }
256
- return ncols <= 16 ? 256 : 128;
257
- }
191
+ static __host__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2, const int cc) {
192
+ return cp_async_available(cc) && ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2, cc) : 0;
193
+ }
258
194
 
259
- static constexpr __device__ int get_nbatch_V2_device(int ncols) {
260
- #if __CUDA_ARCH__ == GGML_CUDA_CC_TURING
261
- return ncols <= 16 ? 64 : 128;
195
+ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, const int DV, const int ncols1, const int ncols2) {
196
+ #ifdef CP_ASYNC_AVAILABLE
197
+ return ncols2 >= 2 ? ggml_cuda_fattn_mma_get_nstages_target(DKQ, DV, ncols1*ncols2) : 0;
262
198
  #else
263
- return ncols <= 16 ? 256 : 128;
264
- #endif // __CUDA_ARCH__ == GGML_CUDA_CC_TURING
265
- }
266
-
267
- static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
268
- return 128;
269
- }
270
-
271
- static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
272
- return 128;
273
- }
274
- };
199
+ GGML_UNUSED_VARS(DKQ, DV, ncols1, ncols2);
200
+ return 0;
201
+ #endif // CP_ASYNC_AVAILABLE
202
+ }
275
203
 
276
204
  // ------------------------------------------------------------------------------------------------------------------
277
205
 
278
- template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async>
206
+ template<int stride_tile, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
279
207
  static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
280
- const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV) {
281
-
208
+ const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) {
282
209
  // K/V data is loaded with decreasing granularity for D for better memory bandwidth.
283
210
  // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes.
284
-
285
- if (use_cp_async) {
211
+ if constexpr (use_cp_async) {
212
+ static_assert(!oob_check, "OOB check not compatible with cp_async");
286
213
  constexpr int preload = 64;
287
214
  constexpr int h2_per_chunk = 16/sizeof(half2);
288
215
  const int chunks_per_row = D2 / h2_per_chunk;
@@ -315,9 +242,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
315
242
  }
316
243
  }
317
244
  };
318
- ggml_cuda_unroll<5>{}(load);
245
+ // 1: max 32*16=512 bytes, 256 half
246
+ // 2: max 16*16=256 bytes, 128 half
247
+ // 3: max 8*16=128 bytes, 64 half
248
+ // 4: max 4*16= 64 bytes, 32 half
249
+ // 5: max 2*16= 32 bytes, 16 half
250
+ // 6: max 1*16= 16 bytes, 8 half
251
+ ggml_cuda_unroll<6>{}(load);
319
252
  } else {
320
- static_assert(nbatch_fa % (4*nwarps) == 0, "out of bounds");
253
+ // TODO use ggml_cuda_memcpy_1
321
254
  auto load = [&] __device__ (const int n) {
322
255
  const int stride_k = WARP_SIZE >> n;
323
256
  const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k);
@@ -340,20 +273,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile(
340
273
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
341
274
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
342
275
 
343
- tile_KV[i*stride_tile + k] = KV[i*stride_KV + k];
276
+ tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f);
344
277
  }
345
278
  }
346
279
  };
347
- ggml_cuda_unroll<3>{}(load);
280
+ // 1: max 32* 4=128 bytes, 64 half
281
+ // 2: max 16* 4= 64 bytes, 32 half
282
+ // 3: max 8* 4= 32 bytes, 16 half
283
+ // 4: max 4* 4= 16 bytes, 8 half
284
+ ggml_cuda_unroll<4>{}(load);
348
285
  }
349
286
  }
350
287
 
351
- template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async>
288
+ template<int ncols1, int nwarps, int nbatch_fa, bool use_cp_async, bool oob_check>
352
289
  static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
353
- const half2 * const __restrict__ mask_h2, half2 * const __restrict__ tile_mask, const int stride_mask) {
354
- static_assert(nbatch_fa == 2*WARP_SIZE || WARP_SIZE % nbatch_fa == 0, "bad KQ_per_iter");
355
-
356
- if (use_cp_async) {
290
+ const half * const __restrict__ mask_h, half * const __restrict__ tile_mask,
291
+ const int stride_mask, const int i_sup, const int j0, const uint3 ne01) {
292
+ if constexpr (use_cp_async) {
293
+ static_assert(nbatch_fa <= 8*WARP_SIZE && nbatch_fa % 8 == 0, "bad nbatch_fa");
294
+ static_assert(!oob_check, "OOB check incompatible with cp_async");
357
295
  constexpr int preload = nbatch_fa >= 32 ? nbatch_fa * sizeof(half) : 64;
358
296
  constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa;
359
297
  constexpr int stride_j = nwarps * cols_per_warp;
@@ -361,50 +299,85 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
361
299
  const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask);
362
300
 
363
301
  #pragma unroll
364
- for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
365
- const int j = j0 + threadIdx.y*cols_per_warp +
366
- (nbatch_fa == 2*WARP_SIZE ? threadIdx.x / (WARP_SIZE/4) : threadIdx.x / (WARP_SIZE/cols_per_warp));
302
+ for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
303
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
304
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
367
305
 
368
- if (j0 + stride_j > ncols1 && j >= ncols1) {
306
+ if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
369
307
  break;
370
308
  }
371
309
 
372
- const int i = 4 * (threadIdx.x % (nbatch_fa/8));
310
+ const int i = 8 * (threadIdx.x % (nbatch_fa/8));
373
311
 
374
- cp_async_cg_16<preload>(tile_mask_32 + j*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half2), mask_h2 + j*stride_mask + i);
312
+ cp_async_cg_16<preload>(tile_mask_32 + j_sram*(nbatch_fa*sizeof(half) + 16) + i*sizeof(half), mask_h + j_vram*stride_mask + i);
375
313
  }
376
- return;
377
- }
314
+ } else if constexpr (oob_check) {
315
+ #pragma unroll
316
+ for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
317
+ const int j_sram = j1 + threadIdx.y;
318
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
319
+
320
+ if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
321
+ break;
322
+ }
378
323
 
379
- constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
380
- constexpr int stride_j = nwarps * cols_per_warp;
381
324
  #pragma unroll
382
- for (int j0 = 0; j0 < ncols1; j0 += stride_j) {
383
- const int j = j0 + threadIdx.y*cols_per_warp + (nbatch_fa == 2*WARP_SIZE ? 0 : threadIdx.x / (WARP_SIZE/cols_per_warp));
325
+ for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) {
326
+ const int i = i0 + threadIdx.x;
384
327
 
385
- if (j0 + stride_j > ncols1 && j >= ncols1) {
386
- break;
328
+ tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f);
329
+ }
387
330
  }
331
+ } else if constexpr (nbatch_fa < 2*WARP_SIZE) {
332
+ constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa;
333
+ constexpr int stride_j = nwarps * cols_per_warp;
334
+ #pragma unroll
335
+ for (int j1 = 0; j1 < ncols1; j1 += stride_j) {
336
+ const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp);
337
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
388
338
 
389
- const int i = nbatch_fa == 2*WARP_SIZE ? threadIdx.x : threadIdx.x % (WARP_SIZE/cols_per_warp);
339
+ if (j1 + stride_j > ncols1 && j_sram >= ncols1) {
340
+ break;
341
+ }
390
342
 
391
- tile_mask[j*(nbatch_fa/2 + 4) + i] = mask_h2[j*stride_mask + i];
343
+ const int i = threadIdx.x % (WARP_SIZE/cols_per_warp);
344
+
345
+ ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i);
346
+ }
347
+ } else {
348
+ #pragma unroll
349
+ for (int j1 = 0; j1 < ncols1; j1 += nwarps) {
350
+ const int j_sram = j1 + threadIdx.y;
351
+ const int j_vram = fastmodulo(j0 + j_sram, ne01);
352
+
353
+ if (j1 + nwarps > ncols1 && j_sram >= ncols1) {
354
+ break;
355
+ }
356
+
357
+ #pragma unroll
358
+ for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) {
359
+ const int i = i0 + 2*threadIdx.x;
360
+
361
+ ggml_cuda_memcpy_1<sizeof(half2)>(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i);
362
+ }
363
+ }
392
364
  }
393
365
  }
394
366
 
395
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
396
- bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
367
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps,
368
+ bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter, bool oob_check,
369
+ typename T_A_KQ, typename T_B_KQ, typename T_C_KQ, typename T_A_VKQ, typename T_B_VKQ, typename T_C_VKQ>
397
370
  static __device__ __forceinline__ void flash_attn_ext_f16_iter(
398
371
  const float2 * const __restrict__ Q_f2,
399
372
  const half2 * const __restrict__ K_h2,
400
373
  const half2 * const __restrict__ V_h2,
401
- const half2 * const __restrict__ mask_h2,
374
+ const half * const __restrict__ mask_h,
402
375
  float2 * const __restrict__ dstk,
403
376
  float2 * const __restrict__ dstk_fixup,
404
377
  const float scale,
405
378
  const float slope,
406
379
  const float logit_softcap,
407
- const int ne01,
380
+ const uint3 ne01,
408
381
  const int ne02,
409
382
  const int stride_K,
410
383
  const int stride_V,
@@ -412,27 +385,24 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
412
385
  half2 * const __restrict__ tile_Q,
413
386
  half2 * const __restrict__ tile_K,
414
387
  half2 * const __restrict__ tile_V,
415
- half2 * const __restrict__ tile_mask,
416
- const tile_B * const __restrict__ Q_B,
417
- tile_C_VKQ * const __restrict__ VKQ_C,
388
+ half * const __restrict__ tile_mask,
389
+ T_B_KQ * const __restrict__ Q_B,
390
+ T_C_VKQ * const __restrict__ VKQ_C,
418
391
  float * const __restrict__ KQ_max,
419
392
  float * const __restrict__ KQ_rowsum,
420
- const int kb0) {
421
- #ifdef TURING_MMA_AVAILABLE
422
- typedef fattn_mma_f16_config<DKQ, DV> c;
423
-
424
- #ifdef CP_ASYNC_AVAILABLE
425
- constexpr int nstages = c::nstages_target;
426
- #else
427
- constexpr int nstages = 0;
428
- #endif // CP_ASYNC_AVAILABLE
429
-
430
- constexpr int cols_per_warp = ntiles * tile_B::I;
431
- constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
432
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
433
- constexpr int ncols = ncols1 * ncols2;
434
- constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
435
- constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
393
+ const int jt,
394
+ const int kb0,
395
+ const int k_VKQ_sup) {
396
+ #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
397
+ constexpr int ncols = ncols1 * ncols2;
398
+ constexpr int cols_per_warp = T_B_KQ::I;
399
+ constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
400
+ constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
401
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
402
+ constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2(DKQ, DV, ncols);
403
+ constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2(DKQ, DV, ncols);
404
+ constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
405
+ constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
436
406
 
437
407
  constexpr int stride_tile_Q = DKQ/2 + 4;
438
408
  constexpr int stride_tile_K = nbatch_K2 + 4;
@@ -440,26 +410,27 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
440
410
  static_assert(!mla || nbatch_K2 >= nbatch_V2, "bad nbatch_K2, nbatch_V2 for MLA");
441
411
  constexpr int stride_tile_V = mla ? stride_tile_K : nbatch_V2 + 4;
442
412
 
443
- const int k_VKQ_0 = kb0 * c::nbatch_fa;
444
- tile_C_KQ KQ_C[c::nbatch_fa/(np*tile_C_KQ::I) * ntiles];
445
-
446
- // Use wide variants of tiles if ntiles >= 2.
447
- tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
448
- tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
449
- tile_C_KQ_16 * KQ_C_16 = (tile_C_KQ_16 *) KQ_C;
413
+ const int k_VKQ_0 = kb0 * nbatch_fa;
414
+ #if defined(TURING_MMA_AVAILABLE)
415
+ T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))];
416
+ #else // Volta
417
+ T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)];
418
+ #endif // defined(TURING_MMA_AVAILABLE)
450
419
 
451
420
  if constexpr (nstages > 1) {
421
+ static_assert(!oob_check, "OOB check incompatible with multi-stage pipeline");
452
422
  static_assert(!mla, "multi-stage loading not implemented for MLA");
453
423
  static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi stage loading");
454
424
  constexpr bool use_cp_async = true;
455
425
  cp_async_wait_all();
456
426
  __syncthreads();
457
- flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
458
- (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
427
+ flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
428
+ (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V, k_VKQ_sup);
459
429
  } else {
460
430
  constexpr bool use_cp_async = nstages == 1;
461
- if (ncols2 > 1 || mask_h2) {
462
- flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
431
+ if (ncols2 > 1 || mask_h) {
432
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
433
+ (mask_h + k_VKQ_0, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
463
434
  }
464
435
  }
465
436
 
@@ -468,10 +439,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
468
439
  const int k0_stop = k0_start + nbatch_K2 < DKQ/2 ? k0_start + nbatch_K2 : DKQ/2;
469
440
  const int k0_diff = k0_stop - k0_start;
470
441
 
471
- if (nstages <= 1) {
442
+ if constexpr (nstages <= 1) {
472
443
  constexpr bool use_cp_async = nstages == 1;
473
- flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
474
- (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
444
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
445
+ (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K, k_VKQ_sup);
475
446
  if (use_cp_async) {
476
447
  cp_async_wait_all();
477
448
  }
@@ -479,55 +450,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
479
450
  }
480
451
 
481
452
  // Calculate tile of KQ:
482
- if constexpr (c::Q_in_reg) {
453
+ if constexpr (Q_in_reg) {
483
454
  #pragma unroll
484
- for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
485
- const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
455
+ for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
456
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
486
457
  #pragma unroll
487
- for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
488
- tile_A K_A;
458
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
459
+ T_A_KQ K_A;
489
460
  load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
490
- if (ntiles == 1) {
491
- mma(KQ_C[i_KQ_00/(np*tile_A::I)], K_A, Q_B[k_KQ_0/tile_A::J]);
461
+ if constexpr (cols_per_warp == 8) {
462
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]);
492
463
  } else {
493
- #pragma unroll
494
- for (int t = 0; t < ntiles/2; ++t) {
495
- // Wide version of KQ_C is column-major => swap A and B.
496
- mma(KQ_C_16[i_KQ_00/(np*tile_A::I) * ntiles/2 + t], Q_B_16[k_KQ_0/tile_A::J * ntiles/2 + t], K_A);
497
- }
464
+ // Wide version of KQ_C is column-major => swap A and B.
465
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A);
498
466
  }
499
467
  }
500
468
  }
501
469
  } else {
502
- static_assert(ntiles == 2, "ntiles != 2 not implemented");
470
+ static_assert(cols_per_warp != 8, "cols_per_warp == 8 not implemented");
503
471
  #pragma unroll
504
- for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += tile_A::J) {
505
- load_ldmatrix(Q_B_16[0], tile_Q + (threadIdx.y / np)*(tile_B_16::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
472
+ for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += T_A_KQ::J) {
473
+ load_ldmatrix(Q_B[0], tile_Q + (threadIdx.y / np)*(T_B_KQ::I*stride_tile_Q) + k_KQ_0, stride_tile_Q);
506
474
 
507
475
  #pragma unroll
508
- for (int i_KQ_00 = 0; i_KQ_00 < c::nbatch_fa; i_KQ_00 += np*tile_A::I) {
509
- const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*tile_A::I;
476
+ for (int i_KQ_00 = 0; i_KQ_00 < nbatch_fa; i_KQ_00 += np*T_A_KQ::I) {
477
+ const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*T_A_KQ::I;
510
478
 
511
- tile_A K_A;
479
+ T_A_KQ K_A;
512
480
  load_ldmatrix(K_A, tile_K + i_KQ_0*stride_tile_K + (k_KQ_0 - k0_start), stride_tile_K);
513
481
 
514
482
  // Wide version of KQ_C is column-major => swap A and B.
515
- mma(KQ_C_16[i_KQ_00/(np*tile_A::I)], Q_B_16[0], K_A);
483
+ mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A);
516
484
  }
517
485
  }
518
486
  }
519
487
 
520
- if (nstages <= 1) {
488
+ if constexpr (nstages <= 1) {
521
489
  __syncthreads(); // Only needed if tile_K == tile_V.
522
490
  }
523
491
  }
524
492
 
525
493
  if (use_logit_softcap) {
526
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
494
+ constexpr int stride = cols_per_warp == 8 ? np*T_C_KQ::I : np*T_C_KQ::J;
495
+ static_assert(nbatch_fa % stride == 0, "bad loop size");
527
496
  #pragma unroll
528
- for (int i = 0; i < c::nbatch_fa/(np*tile_C_KQ::I) * ntiles; ++i) {
497
+ for (int i = 0; i < nbatch_fa/stride; ++i) {
529
498
  #pragma unroll
530
- for (int l = 0; l < tile_C_KQ::ne; ++l) {
499
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
531
500
  KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
532
501
  }
533
502
  }
@@ -540,34 +509,35 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
540
509
  }
541
510
  float KQ_rowsum_add[cols_per_thread] = {0.0f};
542
511
 
543
- if (ntiles == 1) {
544
- if (ncols2 > 1 || mask_h2) {
512
+ if constexpr (cols_per_warp == 8) {
513
+ if (ncols2 > 1 || mask_h) {
545
514
  #pragma unroll
546
- for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
547
- const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
515
+ for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::I) {
516
+ const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::I;
548
517
  #pragma unroll
549
- for (int l = 0; l < tile_C_KQ::ne; ++l) {
550
- const int i = i0 + tile_C_KQ::get_i(l);
551
- const int j = ((threadIdx.y / np)*tile_C_KQ::J + tile_C_KQ::get_j(l)) / ncols2;
518
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
519
+ const int i = i0 + T_C_KQ::get_i(l);
520
+ const int j = ((threadIdx.y / np)*T_C_KQ::J + T_C_KQ::get_j(l)) / ncols2;
552
521
 
553
- KQ_C[i00/(np*tile_C_KQ::I)].x[l] += slope *
554
- __half2float(((const half *) tile_mask)[j*(c::nbatch_fa + 8) + i]);
522
+ KQ_C[i00/(np*T_C_KQ::I)].x[l] += slope * __half2float(tile_mask[j*(nbatch_fa + 8) + i]);
555
523
  }
556
524
  }
557
525
  }
558
526
 
559
527
  // Calculate softmax for each KQ column using the current max. value.
560
528
  // The divisor is stored in KQ_rowsum and will be applied at the end.
561
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
529
+ static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
562
530
  #pragma unroll
563
- for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
531
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
564
532
  #pragma unroll
565
- for (int l = 0; l < tile_C_KQ::ne; ++l) {
566
- KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k].x[l]);
533
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
534
+ if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
535
+ KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
536
+ }
567
537
  }
568
538
  }
569
539
 
570
- // Values per KQ column are spread across 8 threads, does not need full warp reduce:
540
+ // Values per KQ column are spread across 8 threads:
571
541
  #pragma unroll
572
542
  for (int col = 0; col < cols_per_thread; ++col) {
573
543
  #pragma unroll
@@ -576,73 +546,78 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
576
546
  }
577
547
  }
578
548
 
579
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
549
+ static_assert(nbatch_fa % (np*T_C_KQ::I) == 0, "bad loop size");
580
550
  #pragma unroll
581
- for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ::I); ++k) {
551
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
582
552
  #pragma unroll
583
- for (int l = 0; l < tile_C_KQ::ne; ++l) {
584
- KQ_C[k].x[l] = expf(KQ_C[k].x[l] - KQ_max_new[l % 2]);
585
-
586
- KQ_rowsum_add[l % 2] += KQ_C[k].x[l];
553
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
554
+ if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) {
555
+ KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[l % 2]);
556
+ KQ_rowsum_add[l % 2] += KQ_C[k0/(np*T_C_KQ::I)].x[l];
557
+ } else {
558
+ KQ_C[k0/(np*T_C_KQ::I)].x[l] = 0.0f;
559
+ }
587
560
  }
588
561
  }
589
- } else { // ntiles > 1
590
- if (ncols2 > 1 || mask_h2) {
562
+ } else { // not Turing mma or T_B_KQ::I > 8
563
+ if (ncols2 > 1 || mask_h) {
591
564
  #pragma unroll
592
- for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ_16::J) {
593
- const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ_16::J;
565
+ for (int i00 = 0; i00 < nbatch_fa; i00 += np*T_C_KQ::J) {
566
+ const int i0 = i00 + (threadIdx.y % np)*T_C_KQ::J;
594
567
  #pragma unroll
595
- for (int t = 0; t < ntiles/2; ++t) {
596
- #pragma unroll
597
- for (int l0 = 0; l0 < tile_C_KQ_16::ne; l0 += 2) {
598
- const int i = (i0 + tile_C_KQ_16::get_j(l0)) / 2;
599
- const int j = ((threadIdx.y / np)*cols_per_warp + t*tile_C_KQ_16::I + tile_C_KQ_16::get_i(l0)) / ncols2;
568
+ for (int l0 = 0; l0 < T_C_KQ::ne; l0 += 2) {
569
+ const int i = (i0 + T_C_KQ::get_j(l0)) / 2;
570
+ const int j = ((threadIdx.y / np)*cols_per_warp + T_C_KQ::get_i(l0)) / ncols2;
600
571
 
601
- const float2 tmp = __half22float2(tile_mask[j*(c::nbatch_fa/2 + 4) + i]);
602
- const int KQ_index = i00/(np*tile_C_KQ_16::J) * ntiles/2 + t;
603
- KQ_C_16[KQ_index].x[l0 + 0] += slope*tmp.x;
604
- KQ_C_16[KQ_index].x[l0 + 1] += slope*tmp.y;
605
- }
572
+ const float2 tmp = __half22float2(((const half2 *)tile_mask)[j*(nbatch_fa/2 + 4) + i]);
573
+ KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 0] += slope*tmp.x;
574
+ KQ_C[i00/(np*T_C_KQ::J)].x[l0 + 1] += slope*tmp.y;
606
575
  }
607
576
  }
608
577
  }
609
578
 
610
579
  // Calculate softmax for each KQ column using the current max. value.
611
580
  // The divisor is stored in KQ_rowsum and will be applied at the end.
612
- static_assert(c::nbatch_fa % (np*tile_C_KQ::I) == 0, "bad loop size");
581
+ static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
613
582
  #pragma unroll
614
- for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
583
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
615
584
  #pragma unroll
616
- for (int t = 0; t < ntiles/2; ++t) {
617
- #pragma unroll
618
- for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
619
- const int KQ_index = 2*t + (l/2) % 2;
620
- KQ_max_new[KQ_index] = fmaxf(KQ_max_new[KQ_index], KQ_C_16[k*ntiles/2 + t].x[l]);
585
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
586
+ if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
587
+ // Turing + Volta:
588
+ KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
621
589
  }
622
590
  }
623
591
  }
624
592
 
625
- // Values per KQ column are spread across 4 threads, does not need full warp reduce:
626
593
  #pragma unroll
627
594
  for (int col = 0; col < cols_per_thread; ++col) {
595
+ #if defined(TURING_MMA_AVAILABLE)
596
+ // Values per KQ column are spread across 4 threads:
597
+ constexpr int offset_first = 2;
598
+ constexpr int offset_last = 1;
599
+ #else
600
+ // Values per KQ column are spread across 2 threads:
601
+ constexpr int offset_first = 2;
602
+ constexpr int offset_last = 2;
603
+ #endif // defined(TURING_MMA_AVAILABLE)
628
604
  #pragma unroll
629
- for (int offset = 2; offset >= 1; offset >>= 1) {
605
+ for (int offset = offset_first; offset >= offset_last; offset >>= 1) {
630
606
  KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE));
631
607
  }
632
608
  }
633
609
 
634
- static_assert(c::nbatch_fa % (np*tile_C_KQ_16::J) == 0, "bad loop size");
610
+ static_assert(nbatch_fa % (np*T_C_KQ::J) == 0, "bad loop size");
635
611
  #pragma unroll
636
- for (int k = 0; k < c::nbatch_fa/(np*tile_C_KQ_16::J); ++k) {
612
+ for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
637
613
  #pragma unroll
638
- for (int t = 0; t < ntiles/2; ++t) {
639
- #pragma unroll
640
- for (int l = 0; l < tile_C_KQ_16::ne; ++l) {
641
- const int KQ_index = 2*t + (l/2) % 2;
642
-
643
- KQ_C_16[k*ntiles/2 + t].x[l] = expf(KQ_C_16[k*ntiles/2 + t].x[l] - KQ_max_new[KQ_index]);
644
-
645
- KQ_rowsum_add[KQ_index] += KQ_C_16[k*ntiles/2 + t].x[l];
614
+ for (int l = 0; l < T_C_KQ::ne; ++l) {
615
+ // Turing + Volta:
616
+ if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) {
617
+ KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[(l/2) % 2]);
618
+ KQ_rowsum_add[(l/2) % 2] += KQ_C[(k0/(np*T_C_KQ::J))].x[l];
619
+ } else {
620
+ KQ_C[(k0/(np*T_C_KQ::J))].x[l] = 0.0f;
646
621
  }
647
622
  }
648
623
  }
@@ -662,12 +637,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
662
637
  KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_rowsum_add[col];
663
638
  }
664
639
 
665
- if (ntiles == 1) {
640
+ #if defined(TURING_MMA_AVAILABLE)
641
+ if constexpr (cols_per_warp == 8) {
666
642
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
667
643
  #pragma unroll
668
- for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
644
+ for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
669
645
  #pragma unroll
670
- for (int l = 0; l < tile_C_VKQ::ne; ++l) {
646
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
671
647
  VKQ_C[i].x[l] *= KQ_max_scale_h2;
672
648
  }
673
649
  }
@@ -676,46 +652,53 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
676
652
  for (int col = 0; col < cols_per_thread; ++col) {
677
653
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
678
654
  #pragma unroll
679
- for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
655
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
680
656
  #pragma unroll
681
- for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
682
- VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
657
+ for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
658
+ VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
683
659
  }
684
660
  }
685
661
  }
686
662
  }
663
+ #else // Volta
664
+ const half2 KQ_max_scale_h2 = make_half2(
665
+ KQ_max_scale[(threadIdx.x / 2) % 2], KQ_max_scale[(threadIdx.x / 2) % 2]);
666
+ #pragma unroll
667
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
668
+ #pragma unroll
669
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
670
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
671
+ }
672
+ }
673
+ #endif // defined(TURING_MMA_AVAILABLE)
687
674
  }
688
675
 
689
676
  // Convert KQ C tiles into B tiles for VKQ calculation:
690
- tile_B B[c::nbatch_fa/(np*2*tile_B::J) * ntiles];
691
- tile_B_16 * B_16 = (tile_B_16 *) B;
692
- static_assert(c::nbatch_fa % (np*2*tile_B::J) == 0, "bad loop size");
693
- if (ntiles == 1) {
677
+ T_B_VKQ B[nbatch_fa/(np*2*T_B_VKQ::J)];
678
+ static_assert(nbatch_fa % (np*2*T_B_VKQ::J) == 0, "bad loop size");
679
+ if constexpr (cols_per_warp == 8) {
694
680
  #pragma unroll
695
- for (int k = 0; k < c::nbatch_fa/(np*2*tile_B::J); ++k) {
681
+ for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
696
682
  B[k] = get_transposed(get_half2(KQ_C[k]));
697
683
  }
698
684
  } else {
699
- for (int k = 0; k < c::nbatch_fa/(np*2*tile_B_16::J); ++k) {
700
- #pragma unroll
701
- for (int t = 0; t < ntiles/2; ++t) {
702
- B_16[k*ntiles/2 + t] = get_half2(KQ_C_16[k*ntiles/2 + t]);
703
- }
685
+ for (int k = 0; k < nbatch_fa/(np*2*T_B_VKQ::J); ++k) {
686
+ B[k] = get_half2(KQ_C[k]);
704
687
  }
705
688
  }
706
689
 
707
- if (nstages > 1) {
690
+ if constexpr (nstages > 1) {
708
691
  // Preload K tile for next iteration:
709
692
  constexpr bool use_cp_async = true;
710
693
  cp_async_wait_all();
711
694
  __syncthreads();
712
695
  if (!last_iter) {
713
- if (ncols2 > 1 || mask_h2) {
714
- flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
715
- (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
696
+ if (ncols2 > 1 || mask_h) {
697
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
698
+ (mask_h + k_VKQ_0 + nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
716
699
  }
717
- flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
718
- (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
700
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
701
+ (K_h2 + int64_t(k_VKQ_0 + nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
719
702
  }
720
703
  }
721
704
 
@@ -724,72 +707,119 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
724
707
  // Therefore, iterate over V in reverse and re-use the data if possible.
725
708
  static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
726
709
  constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
710
+
711
+ // Calculate VKQ tile, need to use logical rather than physical elements for i0 due to transposition of V:
727
712
  #pragma unroll
728
713
  for (int i0_stop = DV; i0_stop > 0; i0_stop -= 2*nbatch_V2) {
729
714
  const int i0_start = i0_stop - 2*nbatch_V2 > 0 ? i0_stop - 2*nbatch_V2 : 0;
730
715
  const int i0_diff = i0_stop - i0_start;
731
716
 
732
- if (nstages <= 1 && i0_start < reusable_cutoff) {
733
- constexpr bool use_cp_async = nstages == 1;
734
- flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
735
- (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
736
- if (use_cp_async) {
737
- cp_async_wait_all();
717
+ if constexpr (nstages <= 1) {
718
+ if (i0_start < reusable_cutoff) {
719
+ constexpr bool use_cp_async = nstages == 1;
720
+ flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, nbatch_fa, use_cp_async, oob_check>
721
+ (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V, k_VKQ_sup);
722
+ if (use_cp_async) {
723
+ cp_async_wait_all();
724
+ }
725
+ __syncthreads();
738
726
  }
739
- __syncthreads();
740
727
  }
741
728
  const half2 * tile_V_i = i0_start < reusable_cutoff ? tile_V : tile_V + (i0_start - reusable_cutoff)/2;
742
729
 
743
- // Calculate VKQ tile:
730
+ #if defined(TURING_MMA_AVAILABLE)
731
+ constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J;
744
732
  #pragma unroll
745
- for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += tile_C_VKQ::I) {
746
- static_assert((c::nbatch_fa/2) % (np*tile_A::J) == 0, "bad loop size");
733
+ for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
734
+ static_assert((nbatch_fa/2) % (np*T_A_VKQ::J) == 0, "bad loop size");
747
735
  #pragma unroll
748
- for (int k00 = 0; k00 < c::nbatch_fa/2; k00 += np*tile_A::J) {
749
- const int k0 = k00 + (threadIdx.y % np)*tile_A::J;
736
+ for (int k00 = 0; k00 < nbatch_fa/2; k00 += np*T_A_VKQ::J) {
737
+ const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::J;
750
738
 
751
- tile_A A;
739
+ T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load.
752
740
  load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
753
- if (ntiles == 1) {
754
- mma(VKQ_C[i_VKQ_0/tile_C_VKQ::I], A, B[k00/(np*tile_A::J)]);
741
+ if constexpr (T_B_KQ::I == 8) {
742
+ mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]);
755
743
  } else {
756
- #pragma unroll
757
- for (int t = 0; t < ntiles/2; ++t) {
758
- // Wide version of VKQ_C is column-major => swap A and B.
759
- mma(VKQ_C_16[i_VKQ_0/tile_C_VKQ::I * ntiles/2 + t], B_16[k00/(np*tile_A::J) * ntiles/2 + t], A);
760
- }
744
+ // Wide version of VKQ_C is column-major => swap A and B.
745
+ mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A);
761
746
  }
762
747
  }
763
748
  }
749
+ #else // Volta
750
+ constexpr int i0_stride = 2*T_C_VKQ::J;
751
+ #pragma unroll
752
+ for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) {
753
+ static_assert(nbatch_fa % (np*T_A_VKQ::I) == 0, "bad loop size");
754
+ static_assert(2*T_B_VKQ::J == T_A_VKQ::I, "bad tile sizes");
755
+ #pragma unroll
756
+ for (int k00 = 0; k00 < nbatch_fa; k00 += np*T_A_VKQ::I) {
757
+ const int k0 = k00 + (threadIdx.y % np)*T_A_VKQ::I;
758
+
759
+ T_A_VKQ A; // Transposed in both SRAM and registers, load normally.
760
+ load_ldmatrix(A, tile_V_i + k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V);
761
+ mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A);
762
+ }
763
+ }
764
+ #endif // defined(TURING_MMA_AVAILABLE)
764
765
 
765
- if (nstages <= 1) {
766
+ if constexpr (nstages <= 1) {
766
767
  __syncthreads(); // Only needed if tile_K == tile_V.
767
768
  }
768
769
  }
769
770
  #else
770
- GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup,
771
+ GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup,
771
772
  scale, slope, logit_softcap, ne01, ne02,
772
773
  stride_K, stride_V, stride_mask,
773
774
  tile_Q, tile_K, tile_V, tile_mask,
774
775
  Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
775
776
  NO_DEVICE_CODE;
776
- #endif // TURING_MMA_AVAILABLE
777
+ #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
777
778
  }
778
779
 
779
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
780
+ #if defined(TURING_MMA_AVAILABLE)
781
+ template<int ncols> struct mma_tile_sizes {
782
+ using T_A_KQ = tile<16, 8, half2>; // row-major
783
+ using T_B_KQ = tile<16, 8, half2>; // column-major
784
+ using T_C_KQ = tile<16, 16, float>; // column-major
785
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
786
+ using T_B_VKQ = tile<16, 8, half2>; // column-major
787
+ using T_C_VKQ = tile<16, 8, half2>; // column-major
788
+ };
789
+ template<> struct mma_tile_sizes<8> {
790
+ using T_A_KQ = tile<16, 8, half2>; // row-major
791
+ using T_B_KQ = tile< 8, 8, half2>; // column-major
792
+ using T_C_KQ = tile<16, 8, float>; // row-major
793
+ using T_A_VKQ = tile<16, 8, half2>; // row-major
794
+ using T_B_VKQ = tile< 8, 8, half2>; // column-major
795
+ using T_C_VKQ = tile<16, 4, half2>; // row-major
796
+ };
797
+ #else // Volta
798
+ template<int ncols> struct mma_tile_sizes {
799
+ using T_A_KQ = tile< 8, 4, half2, DATA_LAYOUT_I_MAJOR_MIRRORED>; // row-major
800
+ using T_B_KQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
801
+ using T_C_KQ = tile<32, 8, float, DATA_LAYOUT_I_MAJOR>; // column-major
802
+ using T_A_VKQ = tile< 8, 4, half2, DATA_LAYOUT_J_MAJOR_MIRRORED>; // column-major
803
+ using T_B_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
804
+ using T_C_VKQ = tile<32, 4, half2, DATA_LAYOUT_I_MAJOR>; // column-major
805
+ };
806
+ #endif // defined(TURING_MMA_AVAILABLE)
807
+
808
+ template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup>
780
809
  static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
781
810
  const float2 * const __restrict__ Q_f2,
782
811
  const half2 * const __restrict__ K_h2,
783
812
  const half2 * const __restrict__ V_h2,
784
- const half2 * const __restrict__ mask_h2,
813
+ const half * const __restrict__ mask_h,
785
814
  const float * const __restrict__ sinks_f,
786
815
  float2 * const __restrict__ dstk,
787
816
  float2 * const __restrict__ dstk_fixup,
788
817
  const float scale,
789
818
  const float slope,
790
819
  const float logit_softcap,
791
- const int ne01,
820
+ const uint3 ne01,
792
821
  const int ne02,
822
+ const int ne11,
793
823
  const int stride_Q1,
794
824
  const int stride_Q2,
795
825
  const int stride_K,
@@ -798,23 +828,31 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
798
828
  const int jt,
799
829
  const int kb0_start,
800
830
  const int kb0_stop) {
801
- #ifdef TURING_MMA_AVAILABLE
831
+ #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
802
832
  //In this kernel Q, K, V are matrices while i, j, k are matrix indices.
803
833
 
804
- typedef fattn_mma_f16_config<DKQ, DV> c;
805
-
806
- #ifdef CP_ASYNC_AVAILABLE
807
- constexpr int nstages = c::nstages_target;
808
- #else
809
- constexpr int nstages = 0;
810
- #endif // CP_ASYNC_AVAILABLE
811
-
812
- constexpr int ncols = ncols1 * ncols2;
813
- constexpr int cols_per_warp = ntiles * tile_B::I;
814
- constexpr int cols_per_thread = ntiles == 1 ? 2 : ntiles;
815
- constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
816
- constexpr int nbatch_K2 = c::get_nbatch_K2_device(ncols);
817
- constexpr int nbatch_V2 = c::get_nbatch_V2_device(ncols);
834
+ constexpr int ncols = ncols1 * ncols2;
835
+ using T_A_KQ = typename mma_tile_sizes<ncols>::T_A_KQ;
836
+ using T_B_KQ = typename mma_tile_sizes<ncols>::T_B_KQ;
837
+ using T_C_KQ = typename mma_tile_sizes<ncols>::T_C_KQ;
838
+ using T_A_VKQ = typename mma_tile_sizes<ncols>::T_A_VKQ;
839
+ using T_B_VKQ = typename mma_tile_sizes<ncols>::T_B_VKQ;
840
+ using T_C_VKQ = typename mma_tile_sizes<ncols>::T_C_VKQ;
841
+
842
+ constexpr int cols_per_warp = T_B_KQ::I;
843
+ constexpr int cols_per_thread = 2; // This is specifically KQ columns, Volta only has a single VKQ column.
844
+ constexpr int np = nwarps * (cols_per_warp/ncols2) / ncols1; // Number of parallel CUDA warps per Q column.
845
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols);
846
+ constexpr int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols);
847
+ constexpr int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols);
848
+ constexpr int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols);
849
+ constexpr bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols);
850
+ constexpr int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2);
851
+
852
+ if (cols_per_warp > ncols) {
853
+ NO_DEVICE_CODE;
854
+ return;
855
+ }
818
856
 
819
857
  static_assert(nwarps * (cols_per_warp/ncols2) % ncols1 == 0, "bad nwarps");
820
858
 
@@ -826,15 +864,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
826
864
  constexpr int stride_tile_KV_max = stride_tile_K > stride_tile_V ? stride_tile_K : stride_tile_V;
827
865
 
828
866
  extern __shared__ half2 tile_Q[];
829
- half2 * tile_K = c::Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
830
- half2 * tile_V = nstages > 1 ? tile_K + c::nbatch_fa * stride_tile_K : tile_K;
831
- half2 * tile_mask = nstages > 1 ? tile_V + c::nbatch_fa * stride_tile_V : tile_V + c::nbatch_fa * stride_tile_KV_max;
832
-
833
- tile_B Q_B[(c::Q_in_reg ? DKQ/(2*tile_B::J) : 1) * ntiles];
834
- tile_C_VKQ VKQ_C[DV/tile_C_VKQ::I * ntiles];
867
+ half2 * tile_K = Q_in_reg ? tile_Q : tile_Q + ncols * stride_tile_Q;
868
+ half2 * tile_V = nstages > 1 ? tile_K + nbatch_fa * stride_tile_K : tile_K;
869
+ half * tile_mask = (half *) (nstages > 1 ? tile_V + nbatch_fa * stride_tile_V : tile_V + nbatch_fa * stride_tile_KV_max);
835
870
 
836
- tile_B_16 * Q_B_16 = (tile_B_16 *) Q_B;
837
- tile_C_VKQ_16 * VKQ_C_16 = (tile_C_VKQ_16 *) VKQ_C;
871
+ T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)];
872
+ #if defined(TURING_MMA_AVAILABLE)
873
+ T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)];
874
+ #else // Volta
875
+ T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)];
876
+ #endif // defined(TURING_MMA_AVAILABLE)
838
877
 
839
878
  float KQ_rowsum[cols_per_thread] = {0.0f};
840
879
  float KQ_max[cols_per_thread];
@@ -868,7 +907,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
868
907
  const int j = jc / ncols2;
869
908
  const int c = jc % ncols2;
870
909
 
871
- if (jt*ncols1 + j < ne01) {
910
+ if (jt*ncols1 + j < int(ne01.z)) {
872
911
  #pragma unroll
873
912
  for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
874
913
  const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
@@ -889,63 +928,93 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
889
928
 
890
929
  __syncthreads();
891
930
 
892
- if (c::Q_in_reg) {
931
+ if (Q_in_reg) {
893
932
  const int j0 = (threadIdx.y / np) * cols_per_warp;
894
933
 
895
934
  #pragma unroll
896
- for (int k0 = 0; k0 < DKQ/2; k0 += tile_B::J) {
897
- if (ntiles == 1) {
898
- load_ldmatrix(Q_B[k0/tile_B::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
899
- } else {
900
- #pragma unroll
901
- for (int t = 0; t < ntiles/2; ++t) {
902
- load_ldmatrix(Q_B_16[k0/tile_B_16::J * ntiles/2 + t],
903
- tile_Q + (j0 + t*tile_B_16::I)*stride_tile_Q + k0, stride_tile_Q);
904
- }
905
- }
935
+ for (int k0 = 0; k0 < DKQ/2; k0 += T_B_KQ::J) {
936
+ load_ldmatrix(Q_B[k0/T_B_KQ::J], tile_Q + j0*stride_tile_Q + k0, stride_tile_Q);
906
937
  }
907
938
  }
908
939
 
909
940
  __syncthreads();
910
941
 
942
+ int kb0 = kb0_start;
943
+
911
944
  // Preload mask and K data for first iteration when using cp_async with multiple stages:
912
945
  if constexpr (nstages > 1) {
913
946
  static_assert(nbatch_K2 == DKQ/2, "batching not implemented for multi-stage pipeline");
914
947
  constexpr bool use_cp_async = true;
915
- if (ncols2 > 1 || mask_h2) {
916
- flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>
917
- (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
948
+ constexpr bool oob_check = false;
949
+ constexpr int k_VKQ_sup = nbatch_fa;
950
+ if (ncols2 > 1 || mask_h) {
951
+ flash_attn_ext_f16_load_mask<ncols1, nwarps, nbatch_fa, use_cp_async, oob_check>
952
+ (mask_h + kb0*nbatch_fa, tile_mask, stride_mask, k_VKQ_sup, jt*ncols1, ne01);
953
+ }
954
+ flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, nbatch_fa, use_cp_async, oob_check>
955
+ (K_h2 + int64_t(kb0)*nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K, k_VKQ_sup);
956
+ }
957
+
958
+ // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
959
+ if constexpr (ncols2 == 1) {
960
+ constexpr bool oob_check = true;
961
+ for (; kb0 < kb0_stop-1; ++kb0) {
962
+ constexpr bool last_iter = false;
963
+ constexpr int k_VKQ_sup = nbatch_fa;
964
+ flash_attn_ext_f16_iter
965
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
966
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
967
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
968
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
969
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
970
+ }
971
+ constexpr bool last_iter = true;
972
+ const int k_VKQ_sup = ne11 - kb0*nbatch_fa;
973
+ flash_attn_ext_f16_iter
974
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
975
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
976
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
977
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
978
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
979
+ } else {
980
+ constexpr bool oob_check = false;
981
+ for (; kb0 < kb0_stop-1; ++kb0) {
982
+ constexpr bool last_iter = false;
983
+ constexpr int k_VKQ_sup = nbatch_fa;
984
+ flash_attn_ext_f16_iter
985
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
986
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
987
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
988
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
989
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
918
990
  }
919
- flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
920
- (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
921
- }
922
-
923
- // Iterate over ne11 == previous tokens:
924
- int kb0 = kb0_start;
925
- for (; kb0 < kb0_stop-1; ++kb0) {
926
- constexpr bool last_iter = false;
927
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
928
- (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
929
- ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
930
- }
931
- { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
932
991
  constexpr bool last_iter = true;
933
- flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
934
- (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
935
- ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
992
+ constexpr int k_VKQ_sup = nbatch_fa;
993
+ flash_attn_ext_f16_iter
994
+ <DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter, oob_check,
995
+ T_A_KQ, T_B_KQ, T_C_KQ, T_A_VKQ, T_B_VKQ, T_C_VKQ>
996
+ (Q_f2, K_h2, V_h2, mask_h, dstk, dstk_fixup, scale, slope, logit_softcap,
997
+ ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C,
998
+ KQ_max, KQ_rowsum, jt, kb0, k_VKQ_sup);
936
999
  }
937
1000
 
938
1001
  // With multi-stage loading there is no __syncthreads at the end of the iter,
939
1002
  // there can be a race condition on shared memory access for combining/writing back results.
940
- if (nstages > 1 && nwarps*cols_per_warp > c::nbatch_fa) {
1003
+ if constexpr (nstages > 1 && nwarps*cols_per_warp > nbatch_fa) {
941
1004
  __syncthreads();
942
1005
  }
943
1006
 
944
1007
  // Finally, sum up partial KQ rowsums.
945
- // The partial sums are spread across 8/4 threads each, does not need full reduce.
946
1008
  {
947
- constexpr int offset_first = ntiles == 1 ? 16 : 2;
948
- constexpr int offset_last = ntiles == 1 ? 4 : 1;
1009
+ #if defined(TURING_MMA_AVAILABLE)
1010
+ // The partial sums are spread across 8/4 threads.
1011
+ constexpr int offset_first = cols_per_warp == 8 ? 16 : 2;
1012
+ constexpr int offset_last = cols_per_warp == 8 ? 4 : 1;
1013
+ #else // Volta
1014
+ // The partial sums are spread across 2 threads.
1015
+ constexpr int offset_first = 2;
1016
+ constexpr int offset_last = 2;
1017
+ #endif // defined(TURING_MMA_AVAILABLE)
949
1018
  #pragma unroll
950
1019
  for (int col = 0; col < cols_per_thread; ++col) {
951
1020
  #pragma unroll
@@ -962,8 +1031,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
962
1031
  float KQ_max_scale[cols_per_thread];
963
1032
  #pragma unroll
964
1033
  for (int col = 0; col < cols_per_thread; ++col) {
965
- static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
966
- const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
1034
+ const int jc = cols_per_warp == 8 ? T_C_KQ::get_j(col) : T_C_KQ::get_i(2*col);
967
1035
  const float sink = sinks_f[jc % ncols2];
968
1036
 
969
1037
  const float KQ_max_new = fmaxf(KQ_max[col], sink);
@@ -977,12 +1045,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
977
1045
  KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
978
1046
  }
979
1047
 
980
- if (ntiles == 1) {
1048
+ #if defined(TURING_MMA_AVAILABLE)
1049
+ if constexpr (cols_per_warp == 8) {
981
1050
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
982
1051
  #pragma unroll
983
- for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
1052
+ for (int i = 0; i < DV/T_C_VKQ::I; ++i) {
984
1053
  #pragma unroll
985
- for (int l = 0; l < tile_C_VKQ::ne; ++l) {
1054
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
986
1055
  VKQ_C[i].x[l] *= KQ_max_scale_h2;
987
1056
  }
988
1057
  }
@@ -991,30 +1060,40 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
991
1060
  for (int col = 0; col < cols_per_thread; ++col) {
992
1061
  const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
993
1062
  #pragma unroll
994
- for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
1063
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
995
1064
  #pragma unroll
996
- for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
997
- VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
1065
+ for (int l0 = 0; l0 < T_C_VKQ::ne; l0 += 2) {
1066
+ VKQ_C[i].x[l0 + col] *= KQ_max_scale_h2;
998
1067
  }
999
1068
  }
1000
1069
  }
1001
1070
  }
1071
+ #else // Volta
1072
+ const int col = (threadIdx.x / 2) % 2;
1073
+ const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
1074
+ #pragma unroll
1075
+ for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) {
1076
+ #pragma unroll
1077
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
1078
+ VKQ_C[i].x[l] *= KQ_max_scale_h2;
1079
+ }
1080
+ }
1081
+ #endif // defined(TURING_MMA_AVAILABLE)
1002
1082
  }
1003
1083
 
1004
1084
  // Combine VKQ accumulator values if np > 1.
1005
1085
  // It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
1006
1086
  // So also write VKQ accumulators to shared memory in column-major format if np == 1.
1007
1087
 
1008
- constexpr int nbatch_combine = c::get_nbatch_combine_device(ncols);
1009
- constexpr int tile_stride = nbatch_combine + 4;
1088
+ constexpr int tile_stride = nbatch_combine + 4;
1010
1089
  static_assert((DV/2) % nbatch_combine == 0, "bad nbatch_combine");
1011
1090
 
1012
- if constexpr (ntiles == 1) {
1013
- const int jc_cwmo = (threadIdx.x % (2*tile_C_VKQ::J)) / tile_C_VKQ::J; // jc combine write meta offset
1014
- const int jc_cwm = threadIdx.y*(2*tile_C_VKQ::J) + 2*tile_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
1091
+ if constexpr (cols_per_warp == 8) {
1092
+ const int jc_cwmo = (threadIdx.x % (2*T_C_VKQ::J)) / T_C_VKQ::J; // jc combine write meta offset
1093
+ const int jc_cwm = threadIdx.y*(2*T_C_VKQ::J) + 2*T_C_VKQ::get_j(-1) + jc_cwmo; // jc combine write meta
1015
1094
  const float2 KQ_cmr = make_float2(KQ_max[jc_cwmo], KQ_rowsum[jc_cwmo]); // KQ combine max rowsum
1016
1095
 
1017
- if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*tile_C_VKQ::J) {
1096
+ if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*T_C_VKQ::J) {
1018
1097
  // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
1019
1098
  ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
1020
1099
  }
@@ -1023,24 +1102,30 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1023
1102
 
1024
1103
  if (np == 1) {
1025
1104
  // No combination is needed, the meta data can be directly written from registers to VRAM.
1026
- if (needs_fixup && threadIdx.x < tile_B::I) {
1105
+ if (needs_fixup && threadIdx.x < T_B_KQ::I) {
1027
1106
  float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1028
1107
  dstk_fixup_meta[jc_cwm] = KQ_cmr;
1029
1108
  }
1030
- if (is_fixup && threadIdx.x < tile_B::I) {
1109
+ if (is_fixup && threadIdx.x < T_B_KQ::I) {
1031
1110
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1032
1111
  dstk_fixup_meta[jc_cwm] = KQ_cmr;
1033
1112
  }
1034
1113
  }
1035
1114
  } else {
1036
- static_assert(ntiles == 2 || ntiles == 4, "bad ntiles");
1037
- const int jc_cwm = threadIdx.y*cols_per_warp // jc combine write meta
1038
- + (ntiles == 4 ? ((threadIdx.x % 4) / 2) * tile_C_VKQ_16::I : 0)
1039
- + tile_C_VKQ_16::get_i(threadIdx.x % 4);
1040
- const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); // KQ combine max rowsum
1041
-
1042
- if (((!needs_fixup && !is_fixup) || np > 1) && (ntiles == 4 || threadIdx.x % 4 < cols_per_thread)) {
1043
- // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
1115
+ // jc_cwm = jc combine write meta
1116
+ // KQ_cmr = KQ combine max rowsum
1117
+ // Use the 16 bytes of padding in each Q column to store the meta data: KQ max, KQ rowsum, KQ max scale.
1118
+ #if defined(TURING_MMA_AVAILABLE)
1119
+ const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4);
1120
+ const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]);
1121
+ const bool thread_should_write = threadIdx.x % 4 < cols_per_thread;
1122
+ #else // Volta
1123
+ const int jc_cwm = threadIdx.y*cols_per_warp + T_C_KQ::get_i(threadIdx.x & 2);
1124
+ const float2 KQ_cmr = make_float2(KQ_max[(threadIdx.x & 2) / 2], KQ_rowsum[(threadIdx.x & 2) / 2]);
1125
+ const bool thread_should_write = T_C_KQ::J == 8 || T_C_KQ::get_j(threadIdx.x & 2) < 8;
1126
+ #endif // defined(TURING_MMA_AVAILABLE)
1127
+
1128
+ if (((!needs_fixup && !is_fixup) || np > 1) && thread_should_write) {
1044
1129
  ((float2 *) tile_Q)[jc_cwm*(tile_stride/2) + nbatch_combine/2] = KQ_cmr;
1045
1130
  }
1046
1131
 
@@ -1048,18 +1133,17 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1048
1133
 
1049
1134
  if (np == 1) {
1050
1135
  // No combination is needed, the meta data can be directly written from registers to VRAM.
1051
- if (needs_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
1136
+ if (needs_fixup && thread_should_write) {
1052
1137
  float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
1053
1138
  dstk_fixup_meta[jc_cwm] = KQ_cmr;
1054
1139
  }
1055
- if (is_fixup && (ntiles == 4 || threadIdx.x % 4 < ntiles)) {
1140
+ if (is_fixup && thread_should_write) {
1056
1141
  float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
1057
1142
  dstk_fixup_meta[jc_cwm] = KQ_cmr;
1058
1143
  }
1059
1144
  }
1060
1145
  }
1061
1146
 
1062
- static_assert(np == 1 || ntiles == 1 || ntiles == 2, "bad ntiles");
1063
1147
  if (np > 1 && threadIdx.y % np == 0) {
1064
1148
  // Combine the meta data for parallel warps via shared memory.
1065
1149
  // Warps with threadIdx.y % np != 0 must NOT return early.
@@ -1135,32 +1219,29 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1135
1219
 
1136
1220
  #pragma unroll
1137
1221
  for (int k00 = 0; k00 < DV/2; k00 += nbatch_combine) {
1138
- if (ntiles == 1) {
1139
- const int jc_cwd = threadIdx.y*tile_B::I + tile_B::get_i(-1); // jc combine write data
1222
+ if constexpr (cols_per_warp == 8) {
1223
+ const int jc_cwd = threadIdx.y*T_B_KQ::I + T_B_KQ::get_i(-1); // jc combine write data
1140
1224
  #pragma unroll
1141
- for (int k0 = 0; k0 < nbatch_combine; k0 += tile_B::J) {
1142
- const tile_B B = get_transposed(VKQ_C[(k00 + k0)/tile_B::J]); // Conversion of C to B matrix puts it in column-major format.
1225
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_B_KQ::J) {
1226
+ const T_B_KQ B = get_transposed(VKQ_C[(k00 + k1)/T_B_KQ::J]); // Conversion of C to B matrix puts it in column-major format.
1143
1227
 
1144
1228
  #pragma unroll
1145
- for (int l = 0; l < tile_B::ne; ++l) {
1146
- const int k = k0 + tile_B::get_j(l);
1229
+ for (int l = 0; l < T_B_KQ::ne; ++l) {
1230
+ const int k = k1 + T_B_KQ::get_j(l);
1147
1231
 
1148
1232
  tile_Q[jc_cwd*tile_stride + k] = B.x[l];
1149
1233
  }
1150
1234
  }
1151
1235
  } else {
1236
+ const int j0 = threadIdx.y*cols_per_warp;
1152
1237
  #pragma unroll
1153
- for (int t = 0; t < ntiles/2; ++t) {
1154
- const int j0 = threadIdx.y*cols_per_warp + t*tile_C_VKQ_16::I;
1238
+ for (int k1 = 0; k1 < nbatch_combine; k1 += T_C_VKQ::J) {
1155
1239
  #pragma unroll
1156
- for (int k0 = 0; k0 < nbatch_combine; k0 += tile_C_VKQ_16::J) {
1157
- #pragma unroll
1158
- for (int l = 0; l < tile_C_VKQ_16::ne; ++l) {
1159
- const int j = j0 + tile_C_VKQ_16::get_i(l);
1160
- const int k = k0 + tile_C_VKQ_16::get_j(l);
1240
+ for (int l = 0; l < T_C_VKQ::ne; ++l) {
1241
+ const int j = j0 + T_C_VKQ::get_i(l);
1242
+ const int k = k1 + T_C_VKQ::get_j(l);
1161
1243
 
1162
- tile_Q[j*tile_stride + k] = VKQ_C_16[(k00 + k0)/tile_C_VKQ_16::J * ntiles/2 + t].x[l];
1163
- }
1244
+ tile_Q[j*tile_stride + k] = VKQ_C[(k00 + k1)/T_C_VKQ::J].x[l];
1164
1245
  }
1165
1246
  }
1166
1247
  }
@@ -1195,7 +1276,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1195
1276
  const int j_dst = jc_dst / ncols2;
1196
1277
  const int c_dst = jc_dst % ncols2;
1197
1278
 
1198
- if (!is_fixup && jt*ncols1 + j_dst >= ne01) {
1279
+ if (!is_fixup && jt*ncols1 + j_dst >= int(ne01.z)) {
1199
1280
  continue;
1200
1281
  }
1201
1282
 
@@ -1233,16 +1314,16 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
1233
1314
  }
1234
1315
  }
1235
1316
  #else
1236
- GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dstk_fixup,
1317
+ GGML_UNUSED_VARS(Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dstk_fixup,
1237
1318
  scale, slope, logit_softcap, ne01, ne02,
1238
1319
  stride_Q1, stride_Q2, stride_K, stride_V, stride_mask,
1239
1320
  jt, kb0_start, kb0_stop);
1240
1321
  NO_DEVICE_CODE;
1241
- #endif // TURING_MMA_AVAILABLE
1322
+ #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)
1242
1323
  }
1243
1324
 
1244
- template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla>
1245
- __launch_bounds__(nwarps*WARP_SIZE, 1)
1325
+ template<int DKQ, int DV, int ncols1, int ncols2, bool use_logit_softcap, bool mla>
1326
+ __launch_bounds__(ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_mma_get_occupancy(DKQ, DV, ncols1*ncols2))
1246
1327
  static __global__ void flash_attn_ext_f16(
1247
1328
  const char * __restrict__ Q,
1248
1329
  const char * __restrict__ K,
@@ -1258,14 +1339,14 @@ static __global__ void flash_attn_ext_f16(
1258
1339
  const float m1,
1259
1340
  const uint32_t n_head_log2,
1260
1341
  const float logit_softcap,
1261
- const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
1342
+ const int32_t ne00, const uint3 ne01, const int32_t ne02, const int32_t ne03,
1262
1343
  const int32_t nb01, const int32_t nb02, const int32_t nb03,
1263
1344
  const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
1264
1345
  const int32_t nb11, const int32_t nb12, const int64_t nb13,
1265
1346
  const int32_t nb21, const int32_t nb22, const int64_t nb23,
1266
1347
  const int32_t ne31, const int32_t ne32, const int32_t ne33,
1267
1348
  const int32_t nb31, const int32_t nb32, const int64_t nb33) {
1268
- #if defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
1349
+ #if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1269
1350
 
1270
1351
  // Skip unused kernel variants for faster compilation:
1271
1352
  if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) {
@@ -1281,27 +1362,26 @@ static __global__ void flash_attn_ext_f16(
1281
1362
 
1282
1363
  static_assert(!mla || DKQ >= DV, "MLA needs DKQ >= DV");
1283
1364
 
1284
- typedef fattn_mma_f16_config<DKQ, DV> c;
1285
-
1286
- static_assert(FATTN_KQ_STRIDE % fattn_mma_f16_config<DKQ, DV>::nbatch_fa == 0, "bad nbatch_fa");
1365
+ constexpr int ncols = ncols1 * ncols2;
1366
+ constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols);
1367
+ constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols);
1368
+ constexpr int nwarps = nthreads / WARP_SIZE;
1287
1369
 
1288
1370
  const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
1289
1371
 
1290
1372
  const int stride_Q1 = nb01 / sizeof(float2);
1291
1373
  const int stride_Q2 = nb02 / sizeof(float2);
1292
1374
  const int stride_K = nb11 / sizeof(half2);
1293
- const int stride_mask = nb31 / sizeof(half2);
1375
+ const int stride_mask = nb31 / sizeof(half);
1294
1376
 
1295
1377
  const int stride_V = mla ? stride_K : nb21 / sizeof(half2);
1296
1378
 
1297
- const int iter_k = ne11 / FATTN_KQ_STRIDE;
1298
- const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
1299
-
1300
- constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
1379
+ const int iter_k = (ne11 + (nbatch_fa - 1)) / nbatch_fa;
1380
+ const int iter_j = (ne01.z + (ncols1 - 1)) / ncols1;
1301
1381
 
1302
1382
  // kbc == k block continuous, current index in continuous ijk space.
1303
- int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1304
- const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1383
+ int kbc = int64_t(blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1384
+ const int kbc_stop = int64_t(blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
1305
1385
 
1306
1386
  // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
1307
1387
  // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
@@ -1318,35 +1398,31 @@ static __global__ void flash_attn_ext_f16(
1318
1398
 
1319
1399
  const int head0 = zt * ncols2;
1320
1400
 
1321
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1322
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1323
- const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1324
- (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1325
- float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
1401
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1402
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1403
+ const half * mask_h = ncols2 == 1 && !mask ? nullptr :
1404
+ (const half *) (mask + nb33*(sequence % ne33));
1405
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
1326
1406
 
1327
1407
  const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1328
1408
  const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
1329
1409
 
1330
1410
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
1331
1411
 
1332
- const int kb0_start_kernel = kb0_start * kb_niter;
1333
- int kb0_stop_kernel = kb0_stop * kb_niter;
1334
-
1335
1412
  if (KV_max) {
1336
- kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
1413
+ kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
1337
1414
  }
1338
-
1339
1415
  constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
1340
1416
  if (kb0_start == 0) {
1341
1417
  constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
1342
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1343
- (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1344
- ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1418
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
1419
+ (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1420
+ ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1345
1421
  } else {
1346
- constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
1347
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1348
- (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1349
- ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1422
+ constexpr bool needs_fixup = true; // CUDA block is missing the beginning of a tile.
1423
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
1424
+ (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1425
+ ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1350
1426
  }
1351
1427
 
1352
1428
  kbc += iter_k;
@@ -1366,29 +1442,26 @@ static __global__ void flash_attn_ext_f16(
1366
1442
 
1367
1443
  const int head0 = zt * ncols2;
1368
1444
 
1369
- const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1370
- const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1371
- const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
1372
- (const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
1373
- float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head0) * (DV/2);
1445
+ const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02* head0);
1446
+ const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio));
1447
+ const half * mask_h = ncols2 == 1 && !mask ? nullptr :
1448
+ (const half *) (mask + nb33*(sequence % ne33));
1449
+ float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);
1374
1450
 
1375
1451
  const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
1376
1452
  const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;
1377
1453
 
1378
1454
  const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
1379
1455
 
1380
- const int kb0_start_kernel = kb0_start * kb_niter;
1381
- int kb0_stop_kernel = kb0_stop * kb_niter;
1382
-
1383
1456
  if (KV_max) {
1384
- kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
1457
+ kb0_stop = min(kb0_stop, KV_max[sequence*iter_j + jt] / nbatch_fa);
1385
1458
  }
1386
1459
 
1387
1460
  constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
1388
1461
  constexpr bool needs_fixup = false;
1389
- flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
1390
- (Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1391
- ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
1462
+ flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, use_logit_softcap, mla, needs_fixup, is_fixup>
1463
+ (Q_f2, K_h2, V_h2, mask_h, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
1464
+ ne01, ne02, ne11, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop);
1392
1465
  #else
1393
1466
  GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
1394
1467
  max_bias, m0, m1, n_head_log2, logit_softcap,
@@ -1400,7 +1473,7 @@ static __global__ void flash_attn_ext_f16(
1400
1473
  ne31, ne32, ne33,
1401
1474
  nb31, nb32, nb33);
1402
1475
  NO_DEVICE_CODE;
1403
- #endif // defined(FLASH_ATTN_AVAILABLE) && defined(TURING_MMA_AVAILABLE)
1476
+ #endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE))
1404
1477
  }
1405
1478
 
1406
1479
  template <int DKQ, int DV, int ncols1, int ncols2>
@@ -1409,36 +1482,30 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1409
1482
  const int id = ggml_cuda_get_device();
1410
1483
  const int cc = ggml_cuda_info().devices[id].cc;
1411
1484
 
1412
- typedef fattn_mma_f16_config<DKQ, DV> c;
1485
+ constexpr int ncols = ncols1 * ncols2;
1413
1486
 
1414
- const int nstages = cp_async_available(cc) ? c::nstages_target : 0;
1487
+ const int nthreads = ggml_cuda_fattn_mma_get_nthreads (DKQ, DV, ncols, cc);
1488
+ const int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa (DKQ, DV, ncols, cc);
1489
+ const int nbatch_K2 = ggml_cuda_fattn_mma_get_nbatch_K2 (DKQ, DV, ncols, cc);
1490
+ const int nbatch_V2 = ggml_cuda_fattn_mma_get_nbatch_V2 (DKQ, DV, ncols, cc);
1491
+ const int nbatch_combine = ggml_cuda_fattn_mma_get_nbatch_combine(DKQ, DV, ncols, cc);
1492
+ const bool Q_in_reg = ggml_cuda_fattn_mma_get_Q_in_reg (DKQ, DV, ncols, cc);
1493
+ const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc);
1415
1494
 
1416
- constexpr int ncols = ncols1 * ncols2;
1417
- constexpr int ntiles = ncols <= 8 ? 1 : 2; // Number of tiles per warp.
1418
- constexpr int cols_per_warp = ntiles * tile_B::I;
1419
- constexpr int nwarps_max_x = ncols / cols_per_warp;
1420
- constexpr int nwarps_max_y = c::nbatch_fa / tile_A::I;
1421
- constexpr int nwarps = nwarps_max_x*nwarps_max_y <= c::nwarps_max ? nwarps_max_x*nwarps_max_y : c::nwarps_max;
1495
+ const int cols_per_warp = std::min(ncols, turing_mma_available(cc) ? 16 : 32);
1496
+ const int nwarps = nthreads / WARP_SIZE;
1422
1497
 
1423
1498
  constexpr bool mla = DKQ == 576;
1424
1499
 
1425
- const int nbatch_K2 = c::get_nbatch_K2_host (cc, ncols);
1426
- const int nbatch_V2 = c::get_nbatch_K2_host (cc, ncols);
1427
- const int nbatch_combine = c::get_nbatch_combine_host(cc, ncols);
1428
-
1429
- static_assert(DKQ % tile_B::J == 0, "bad DKQ");
1430
- static_assert(DV % tile_A::J == 0, "bad DV");
1431
- static_assert(ncols % cols_per_warp == 0, "bad ncols");
1432
-
1433
- const size_t nbytes_shared_KV_1stage = c::nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
1434
- const size_t nbytes_shared_KV_2stage = c::nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
1500
+ const size_t nbytes_shared_KV_1stage = nbatch_fa * std::max(nbatch_K2 + 4, nbatch_V2 + 4) * sizeof(half2);
1501
+ const size_t nbytes_shared_KV_2stage = nbatch_fa * (nbatch_K2 + 4 + nbatch_V2 + 4) * sizeof(half2);
1435
1502
  const size_t nbytes_shared_Q = ncols * (DKQ/2 + 4) * sizeof(half2);
1436
- const size_t nbytes_shared_mask = ncols1 * (c::nbatch_fa/2 + 4) * sizeof(half2);
1503
+ const size_t nbytes_shared_mask = ncols1 * (nbatch_fa/2 + 4) * sizeof(half2);
1437
1504
  const size_t nbytes_shared_combine = nwarps*cols_per_warp * (nbatch_combine + 4) * sizeof(half2);
1438
1505
 
1439
1506
  const size_t nbytes_shared_KV = nstages <= 1 ? nbytes_shared_KV_1stage : nbytes_shared_KV_2stage;
1440
1507
 
1441
- const size_t nbytes_shared_total = std::max(nbytes_shared_combine, c::Q_in_reg ?
1508
+ const size_t nbytes_shared_total = std::max(nbytes_shared_combine, Q_in_reg ?
1442
1509
  std::max(nbytes_shared_Q, nbytes_shared_KV + nbytes_shared_mask) :
1443
1510
  nbytes_shared_Q + nbytes_shared_KV + nbytes_shared_mask);
1444
1511
 
@@ -1448,7 +1515,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1448
1515
  fattn_kernel_t fattn_kernel;
1449
1516
  if (logit_softcap == 0.0f) {
1450
1517
  constexpr bool use_logit_softcap = false;
1451
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
1518
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
1452
1519
 
1453
1520
  #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1454
1521
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1459,7 +1526,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1459
1526
  #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1460
1527
  } else {
1461
1528
  constexpr bool use_logit_softcap = true;
1462
- fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>;
1529
+ fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, use_logit_softcap, mla>;
1463
1530
 
1464
1531
  #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
1465
1532
  static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
@@ -1471,7 +1538,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
1471
1538
  }
1472
1539
 
1473
1540
  launch_fattn<DV, ncols1, ncols2>
1474
- (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
1541
+ (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true);
1475
1542
  }
1476
1543
 
1477
1544