whispercpp 1.3.5 → 1.3.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (610)
  1. checksums.yaml +4 -4
  2. data/LICENSE +1 -1
  3. data/README.md +99 -2
  4. data/ext/extconf.rb +1 -0
  5. data/ext/ruby_whisper.c +20 -4
  6. data/ext/ruby_whisper.h +30 -2
  7. data/ext/ruby_whisper_context.c +216 -124
  8. data/ext/ruby_whisper_context_params.c +163 -0
  9. data/ext/ruby_whisper_model.c +0 -1
  10. data/ext/ruby_whisper_params.c +0 -1
  11. data/ext/ruby_whisper_segment.c +0 -1
  12. data/ext/ruby_whisper_token.c +29 -9
  13. data/ext/ruby_whisper_transcribe.cpp +4 -1
  14. data/ext/ruby_whisper_vad_context.c +48 -1
  15. data/ext/ruby_whisper_vad_context_detect.cpp +6 -5
  16. data/ext/ruby_whisper_vad_params.c +0 -1
  17. data/ext/ruby_whisper_vad_segment.c +0 -1
  18. data/ext/ruby_whisper_vad_segments.c +0 -1
  19. data/ext/sources/CMakeLists.txt +1 -1
  20. data/ext/sources/bindings/javascript/package.json +1 -1
  21. data/ext/sources/cmake/whisper-config.cmake.in +5 -40
  22. data/ext/sources/examples/bench/bench.cpp +23 -18
  23. data/ext/sources/examples/cli/cli.cpp +8 -0
  24. data/ext/sources/examples/common-ggml.cpp +2 -0
  25. data/ext/sources/examples/miniaudio.h +4507 -2131
  26. data/ext/sources/examples/server/server.cpp +18 -4
  27. data/ext/sources/examples/talk-llama/CMakeLists.txt +3 -2
  28. data/ext/sources/examples/talk-llama/llama-adapter.cpp +7 -13
  29. data/ext/sources/examples/talk-llama/llama-adapter.h +4 -3
  30. data/ext/sources/examples/talk-llama/llama-arch.cpp +335 -17
  31. data/ext/sources/examples/talk-llama/llama-arch.h +42 -0
  32. data/ext/sources/examples/talk-llama/llama-batch.cpp +3 -1
  33. data/ext/sources/examples/talk-llama/llama-chat.cpp +21 -1
  34. data/ext/sources/examples/talk-llama/llama-chat.h +1 -0
  35. data/ext/sources/examples/talk-llama/llama-context.cpp +508 -520
  36. data/ext/sources/examples/talk-llama/llama-context.h +27 -28
  37. data/ext/sources/examples/talk-llama/llama-cparams.h +5 -0
  38. data/ext/sources/examples/talk-llama/llama-ext.h +12 -0
  39. data/ext/sources/examples/talk-llama/llama-grammar.cpp +8 -8
  40. data/ext/sources/examples/talk-llama/llama-graph.cpp +583 -130
  41. data/ext/sources/examples/talk-llama/llama-graph.h +131 -10
  42. data/ext/sources/examples/talk-llama/llama-hparams.cpp +57 -40
  43. data/ext/sources/examples/talk-llama/llama-hparams.h +79 -10
  44. data/ext/sources/examples/talk-llama/llama-impl.cpp +4 -4
  45. data/ext/sources/examples/talk-llama/llama-impl.h +13 -1
  46. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +3 -1
  47. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +274 -89
  48. data/ext/sources/examples/talk-llama/llama-kv-cache.h +2 -3
  49. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +275 -0
  50. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +140 -0
  51. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +11 -13
  52. data/ext/sources/examples/talk-llama/llama-mmap.cpp +28 -11
  53. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +527 -119
  54. data/ext/sources/examples/talk-llama/llama-model-loader.h +35 -5
  55. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +60 -46
  56. data/ext/sources/examples/talk-llama/llama-model-saver.h +5 -2
  57. data/ext/sources/examples/talk-llama/llama-model.cpp +1365 -647
  58. data/ext/sources/examples/talk-llama/llama-model.h +72 -19
  59. data/ext/sources/examples/talk-llama/llama-quant.cpp +578 -346
  60. data/ext/sources/examples/talk-llama/{llama-sampling.cpp → llama-sampler.cpp} +190 -76
  61. data/ext/sources/examples/talk-llama/{llama-sampling.h → llama-sampler.h} +0 -2
  62. data/ext/sources/examples/talk-llama/llama-vocab.cpp +118 -48
  63. data/ext/sources/examples/talk-llama/llama-vocab.h +5 -0
  64. data/ext/sources/examples/talk-llama/llama.cpp +76 -22
  65. data/ext/sources/examples/talk-llama/llama.h +63 -30
  66. data/ext/sources/examples/talk-llama/models/afmoe.cpp +2 -3
  67. data/ext/sources/examples/talk-llama/models/apertus.cpp +3 -3
  68. data/ext/sources/examples/talk-llama/models/arcee.cpp +3 -3
  69. data/ext/sources/examples/talk-llama/models/arctic.cpp +4 -5
  70. data/ext/sources/examples/talk-llama/models/baichuan.cpp +4 -3
  71. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +1 -2
  72. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +3 -5
  73. data/ext/sources/examples/talk-llama/models/bert.cpp +13 -7
  74. data/ext/sources/examples/talk-llama/models/bitnet.cpp +9 -24
  75. data/ext/sources/examples/talk-llama/models/bloom.cpp +2 -2
  76. data/ext/sources/examples/talk-llama/models/chameleon.cpp +3 -3
  77. data/ext/sources/examples/talk-llama/models/chatglm.cpp +2 -2
  78. data/ext/sources/examples/talk-llama/models/codeshell.cpp +3 -3
  79. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +3 -3
  80. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +2 -2
  81. data/ext/sources/examples/talk-llama/models/command-r.cpp +2 -2
  82. data/ext/sources/examples/talk-llama/models/dbrx.cpp +4 -5
  83. data/ext/sources/examples/talk-llama/models/deci.cpp +3 -3
  84. data/ext/sources/examples/talk-llama/models/deepseek.cpp +4 -6
  85. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +24 -21
  86. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +445 -0
  87. data/ext/sources/examples/talk-llama/models/dots1.cpp +4 -6
  88. data/ext/sources/examples/talk-llama/models/dream.cpp +3 -3
  89. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +4 -6
  90. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +3 -3
  91. data/ext/sources/examples/talk-llama/models/eurobert.cpp +97 -0
  92. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +145 -0
  93. data/ext/sources/examples/talk-llama/models/exaone.cpp +3 -3
  94. data/ext/sources/examples/talk-llama/models/exaone4.cpp +3 -3
  95. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +2 -4
  96. data/ext/sources/examples/talk-llama/models/falcon.cpp +3 -3
  97. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +1 -1
  98. data/ext/sources/examples/talk-llama/models/gemma.cpp +1 -1
  99. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +1 -1
  100. data/ext/sources/examples/talk-llama/models/gemma3.cpp +1 -1
  101. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +7 -7
  102. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +3 -3
  103. data/ext/sources/examples/talk-llama/models/glm4.cpp +14 -7
  104. data/ext/sources/examples/talk-llama/models/gpt2.cpp +2 -2
  105. data/ext/sources/examples/talk-llama/models/gptneox.cpp +2 -2
  106. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +4 -5
  107. data/ext/sources/examples/talk-llama/models/granite.cpp +4 -5
  108. data/ext/sources/examples/talk-llama/models/grok.cpp +4 -4
  109. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +5 -7
  110. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +3 -3
  111. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +4 -5
  112. data/ext/sources/examples/talk-llama/models/internlm2.cpp +3 -3
  113. data/ext/sources/examples/talk-llama/models/jais.cpp +2 -2
  114. data/ext/sources/examples/talk-llama/models/jais2.cpp +123 -0
  115. data/ext/sources/examples/talk-llama/models/jamba.cpp +3 -3
  116. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +381 -0
  117. data/ext/sources/examples/talk-llama/models/lfm2.cpp +145 -124
  118. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +4 -4
  119. data/ext/sources/examples/talk-llama/models/llada.cpp +3 -3
  120. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +4 -4
  121. data/ext/sources/examples/talk-llama/models/llama.cpp +18 -11
  122. data/ext/sources/examples/talk-llama/models/maincoder.cpp +3 -3
  123. data/ext/sources/examples/talk-llama/models/{graph-context-mamba.cpp → mamba-base.cpp} +9 -3
  124. data/ext/sources/examples/talk-llama/models/mamba.cpp +1 -2
  125. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +11 -5
  126. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +14 -13
  127. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +4 -5
  128. data/ext/sources/examples/talk-llama/models/mistral3.cpp +4 -4
  129. data/ext/sources/examples/talk-llama/models/models.h +181 -46
  130. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +2 -9
  131. data/ext/sources/examples/talk-llama/models/mpt.cpp +2 -2
  132. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +26 -14
  133. data/ext/sources/examples/talk-llama/models/nemotron.cpp +3 -3
  134. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +2 -2
  135. data/ext/sources/examples/talk-llama/models/olmo.cpp +3 -3
  136. data/ext/sources/examples/talk-llama/models/olmo2.cpp +3 -3
  137. data/ext/sources/examples/talk-llama/models/olmoe.cpp +4 -4
  138. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +1 -1
  139. data/ext/sources/examples/talk-llama/models/openelm.cpp +3 -3
  140. data/ext/sources/examples/talk-llama/models/orion.cpp +3 -3
  141. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +122 -0
  142. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +3 -3
  143. data/ext/sources/examples/talk-llama/models/phi2.cpp +2 -2
  144. data/ext/sources/examples/talk-llama/models/phi3.cpp +3 -3
  145. data/ext/sources/examples/talk-llama/models/plamo.cpp +3 -3
  146. data/ext/sources/examples/talk-llama/models/plamo2.cpp +9 -5
  147. data/ext/sources/examples/talk-llama/models/plamo3.cpp +2 -2
  148. data/ext/sources/examples/talk-llama/models/plm.cpp +15 -14
  149. data/ext/sources/examples/talk-llama/models/qwen.cpp +2 -2
  150. data/ext/sources/examples/talk-llama/models/qwen2.cpp +3 -3
  151. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +4 -4
  152. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +3 -3
  153. data/ext/sources/examples/talk-llama/models/qwen3.cpp +12 -9
  154. data/ext/sources/examples/talk-llama/models/qwen35.cpp +381 -0
  155. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +422 -0
  156. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +15 -8
  157. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +84 -432
  158. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +9 -18
  159. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +8 -17
  160. data/ext/sources/examples/talk-llama/models/refact.cpp +2 -2
  161. data/ext/sources/examples/talk-llama/models/rnd1.cpp +4 -4
  162. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +2 -0
  163. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +2 -0
  164. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +3 -3
  165. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +4 -4
  166. data/ext/sources/examples/talk-llama/models/smollm3.cpp +3 -3
  167. data/ext/sources/examples/talk-llama/models/stablelm.cpp +2 -2
  168. data/ext/sources/examples/talk-llama/models/starcoder.cpp +2 -2
  169. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +3 -3
  170. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +165 -0
  171. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +2 -2
  172. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +2 -2
  173. data/ext/sources/examples/talk-llama/models/xverse.cpp +3 -3
  174. data/ext/sources/examples/talk-llama/unicode.cpp +21 -65
  175. data/ext/sources/ggml/CMakeLists.txt +9 -3
  176. data/ext/sources/ggml/include/ggml-backend.h +1 -1
  177. data/ext/sources/ggml/include/ggml-cann.h +1 -1
  178. data/ext/sources/ggml/include/ggml-cpu.h +5 -0
  179. data/ext/sources/ggml/include/ggml-openvino.h +37 -0
  180. data/ext/sources/ggml/include/ggml-opt.h +1 -1
  181. data/ext/sources/ggml/include/ggml-rpc.h +6 -1
  182. data/ext/sources/ggml/include/ggml-virtgpu.h +14 -0
  183. data/ext/sources/ggml/include/ggml.h +56 -9
  184. data/ext/sources/ggml/src/CMakeLists.txt +3 -0
  185. data/ext/sources/ggml/src/ggml-alloc.c +4 -9
  186. data/ext/sources/ggml/src/ggml-backend-dl.cpp +48 -0
  187. data/ext/sources/ggml/src/ggml-backend-dl.h +45 -0
  188. data/ext/sources/ggml/src/ggml-backend-reg.cpp +28 -86
  189. data/ext/sources/ggml/src/ggml-backend.cpp +5 -2
  190. data/ext/sources/ggml/src/ggml-blas/CMakeLists.txt +1 -1
  191. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +6 -2
  192. data/ext/sources/ggml/src/ggml-cann/acl_tensor.cpp +1 -1
  193. data/ext/sources/ggml/src/ggml-cann/acl_tensor.h +1 -1
  194. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +348 -189
  195. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +40 -85
  196. data/ext/sources/ggml/src/ggml-cann/common.h +3 -4
  197. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +44 -62
  198. data/ext/sources/ggml/src/ggml-common.h +11 -0
  199. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +16 -11
  200. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +42 -19
  201. data/ext/sources/ggml/src/ggml-cpu/amx/common.h +34 -10
  202. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +85 -85
  203. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +85 -1
  204. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +2744 -548
  205. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +1653 -0
  206. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +1391 -0
  207. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +8 -10
  208. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +9 -9
  209. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +118 -18
  210. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +107 -26
  211. data/ext/sources/ggml/src/ggml-cpu/binary-ops.cpp +2 -6
  212. data/ext/sources/ggml/src/ggml-cpu/common.h +8 -0
  213. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +3 -0
  214. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +59 -12
  215. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +15 -0
  216. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kernels.cpp +21 -20
  217. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +965 -252
  218. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +584 -197
  219. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +903 -188
  220. data/ext/sources/ggml/src/ggml-cpu/ops.h +1 -0
  221. data/ext/sources/ggml/src/ggml-cpu/quants.c +40 -0
  222. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  223. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +2890 -679
  224. data/ext/sources/ggml/src/ggml-cpu/repack.h +119 -8
  225. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +136 -0
  226. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +111 -3
  227. data/ext/sources/ggml/src/ggml-cpu/unary-ops.cpp +1 -1
  228. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +17 -0
  229. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +1 -1
  230. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +19 -10
  231. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +32 -30
  232. data/ext/sources/ggml/src/ggml-cuda/common.cuh +134 -18
  233. data/ext/sources/ggml/src/ggml-cuda/convert.cu +41 -27
  234. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +6 -3
  235. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +78 -64
  236. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +384 -143
  237. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +36 -22
  238. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +3 -3
  239. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +26 -5
  240. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +1 -1
  241. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +127 -12
  242. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +263 -0
  243. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cuh +4 -0
  244. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +595 -200
  245. data/ext/sources/ggml/src/ggml-cuda/mean.cu +9 -8
  246. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +173 -6
  247. data/ext/sources/ggml/src/ggml-cuda/mmf.cu +30 -10
  248. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +158 -85
  249. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +34 -22
  250. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +127 -67
  251. data/ext/sources/ggml/src/ggml-cuda/mmvf.cuh +2 -0
  252. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +157 -65
  253. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +1 -0
  254. data/ext/sources/ggml/src/ggml-cuda/norm.cu +18 -76
  255. data/ext/sources/ggml/src/ggml-cuda/pad.cu +13 -10
  256. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +1 -1
  257. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +2 -16
  258. data/ext/sources/ggml/src/ggml-cuda/rope.cu +233 -133
  259. data/ext/sources/ggml/src/ggml-cuda/softmax.cu +8 -83
  260. data/ext/sources/ggml/src/ggml-cuda/solve_tri.cu +1 -1
  261. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +56 -32
  262. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  263. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +5 -0
  264. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  265. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +5 -0
  266. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  267. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  268. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  269. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +3 -3
  270. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +0 -1
  271. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +199 -135
  272. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cuh +20 -14
  273. data/ext/sources/ggml/src/ggml-cuda/unary.cu +55 -0
  274. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  275. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +31 -17
  276. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +10 -0
  277. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +82 -45
  278. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +334 -160
  279. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +7 -5
  280. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +328 -197
  281. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +281 -0
  282. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +765 -234
  283. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +252 -0
  284. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +412 -265
  285. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +23 -23
  286. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.c → hex-dma.c} +1 -1
  287. data/ext/sources/ggml/src/ggml-hexagon/htp/{htp-dma.h → hex-dma.h} +28 -3
  288. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +77 -0
  289. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-fastdiv.h +37 -0
  290. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +51 -0
  291. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +1 -1
  292. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +27 -37
  293. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +6 -35
  294. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-arith.h +443 -0
  295. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +240 -0
  296. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +245 -0
  297. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +251 -0
  298. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-dump.h +129 -0
  299. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +215 -0
  300. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-floor.h +100 -0
  301. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.h +210 -0
  302. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-reduce.h +296 -0
  303. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-scale.h +133 -0
  304. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +141 -0
  305. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +126 -0
  306. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-types.h +36 -0
  307. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +20 -1347
  308. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +211 -13
  309. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +1119 -952
  310. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +254 -244
  311. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +36 -36
  312. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +155 -138
  313. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +339 -0
  314. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +128 -0
  315. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +209 -114
  316. data/ext/sources/ggml/src/ggml-hexagon/htp/worker-pool.c +1 -5
  317. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.cpp +418 -0
  318. data/ext/sources/ggml/src/ggml-hexagon/htp-drv.h +121 -0
  319. data/ext/sources/ggml/src/ggml-hexagon/libdl.h +79 -0
  320. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +38 -0
  321. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +6 -0
  322. data/ext/sources/ggml/src/ggml-impl.h +62 -0
  323. data/ext/sources/ggml/src/ggml-metal/CMakeLists.txt +10 -10
  324. data/ext/sources/ggml/src/ggml-metal/ggml-metal-common.cpp +13 -2
  325. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.h +8 -0
  326. data/ext/sources/ggml/src/ggml-metal/ggml-metal-context.m +147 -17
  327. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +274 -73
  328. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +22 -4
  329. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +102 -36
  330. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +174 -23
  331. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +580 -280
  332. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +5 -4
  333. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +320 -107
  334. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +1068 -825
  335. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +19 -1
  336. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +3108 -636
  337. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +41 -99
  338. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +45 -0
  339. data/ext/sources/ggml/src/ggml-opencl/kernels/cumsum.cl +139 -0
  340. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +204 -0
  341. data/ext/sources/ggml/src/ggml-opencl/kernels/diag.cl +27 -0
  342. data/ext/sources/ggml/src/ggml-opencl/kernels/exp.cl +125 -0
  343. data/ext/sources/ggml/src/ggml-opencl/kernels/expm1.cl +87 -56
  344. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl +132 -0
  345. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +195 -0
  346. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl +283 -0
  347. data/ext/sources/ggml/src/ggml-opencl/kernels/l2_norm.cl +71 -0
  348. data/ext/sources/ggml/src/ggml-opencl/kernels/mean.cl +114 -13
  349. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl +163 -0
  350. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl +165 -0
  351. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl +158 -0
  352. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q8_0_f32_8x4.cl +129 -0
  353. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl +219 -0
  354. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl +229 -0
  355. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl +180 -0
  356. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mv_q6_k.cl → mul_mv_q6_k_f32.cl} +4 -0
  357. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +194 -0
  358. data/ext/sources/ggml/src/ggml-opencl/kernels/neg.cl +125 -0
  359. data/ext/sources/ggml/src/ggml-opencl/kernels/repeat.cl +31 -32
  360. data/ext/sources/ggml/src/ggml-opencl/kernels/scale.cl +14 -4
  361. data/ext/sources/ggml/src/ggml-opencl/kernels/softplus.cl +88 -60
  362. data/ext/sources/ggml/src/ggml-opencl/kernels/solve_tri.cl +51 -0
  363. data/ext/sources/ggml/src/ggml-opencl/kernels/sum_rows.cl +114 -13
  364. data/ext/sources/ggml/src/ggml-opencl/kernels/tanh.cl +94 -48
  365. data/ext/sources/ggml/src/ggml-opencl/kernels/transpose.cl +26 -0
  366. data/ext/sources/ggml/src/ggml-opencl/kernels/tri.cl +32 -0
  367. data/ext/sources/ggml/src/ggml-openvino/.clang-format +154 -0
  368. data/ext/sources/ggml/src/ggml-openvino/CMakeLists.txt +22 -0
  369. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +975 -0
  370. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.h +294 -0
  371. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +373 -0
  372. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.h +182 -0
  373. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +1110 -0
  374. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +884 -0
  375. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.h +153 -0
  376. data/ext/sources/ggml/src/ggml-openvino/openvino/decoder.h +74 -0
  377. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.cpp +27 -0
  378. data/ext/sources/ggml/src/ggml-openvino/openvino/frontend.h +23 -0
  379. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.cpp +17 -0
  380. data/ext/sources/ggml/src/ggml-openvino/openvino/input_model.h +29 -0
  381. data/ext/sources/ggml/src/ggml-openvino/openvino/node_context.h +112 -0
  382. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cont.cpp +48 -0
  383. data/ext/sources/ggml/src/ggml-openvino/openvino/op/cpy.cpp +21 -0
  384. data/ext/sources/ggml/src/ggml-openvino/openvino/op/flash_attn_ext.cpp +90 -0
  385. data/ext/sources/ggml/src/ggml-openvino/openvino/op/get_rows.cpp +69 -0
  386. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_geglu.cpp +61 -0
  387. data/ext/sources/ggml/src/ggml-openvino/openvino/op/glu_swiglu.cpp +62 -0
  388. data/ext/sources/ggml/src/ggml-openvino/openvino/op/mulmat.cpp +90 -0
  389. data/ext/sources/ggml/src/ggml-openvino/openvino/op/permute.cpp +102 -0
  390. data/ext/sources/ggml/src/ggml-openvino/openvino/op/reshape.cpp +83 -0
  391. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rms_norm.cpp +46 -0
  392. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +123 -0
  393. data/ext/sources/ggml/src/ggml-openvino/openvino/op/scale.cpp +41 -0
  394. data/ext/sources/ggml/src/ggml-openvino/openvino/op/set_rows.cpp +76 -0
  395. data/ext/sources/ggml/src/ggml-openvino/openvino/op/softmax.cpp +89 -0
  396. data/ext/sources/ggml/src/ggml-openvino/openvino/op/transpose.cpp +23 -0
  397. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_silu.cpp +27 -0
  398. data/ext/sources/ggml/src/ggml-openvino/openvino/op/view.cpp +53 -0
  399. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +46 -0
  400. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +39 -0
  401. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +123 -0
  402. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +17 -0
  403. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.cpp +60 -0
  404. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/fuse_to_sdpa.h +17 -0
  405. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/mark_decompression_convert_constant_folding.h +29 -0
  406. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.cpp +58 -0
  407. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/squeeze_matmul.h +17 -0
  408. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +293 -0
  409. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.h +28 -0
  410. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +226 -0
  411. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +85 -0
  412. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +823 -0
  413. data/ext/sources/ggml/src/ggml-openvino/utils.h +123 -0
  414. data/ext/sources/ggml/src/ggml-quants.c +96 -5
  415. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  416. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +15 -88
  417. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +5 -1
  418. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/binbcast.cpp +21 -20
  420. data/ext/sources/ggml/src/ggml-sycl/common.hpp +315 -10
  421. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +69 -1
  422. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +22 -1
  423. data/ext/sources/ggml/src/ggml-sycl/count-equal.cpp +1 -1
  424. data/ext/sources/ggml/src/ggml-sycl/dpct/helper.hpp +791 -47
  425. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +78 -68
  426. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +2 -0
  427. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +1179 -0
  428. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +55 -0
  429. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +1338 -0
  430. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +667 -0
  431. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +225 -0
  432. data/ext/sources/ggml/src/ggml-sycl/fattn.hpp +22 -0
  433. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +309 -0
  434. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +8 -0
  435. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +316 -51
  436. data/ext/sources/ggml/src/ggml-sycl/norm.cpp +65 -66
  437. data/ext/sources/ggml/src/ggml-sycl/outprod.cpp +3 -3
  438. data/ext/sources/ggml/src/ggml-sycl/presets.hpp +3 -0
  439. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +1 -1
  440. data/ext/sources/ggml/src/ggml-sycl/rope.cpp +450 -287
  441. data/ext/sources/ggml/src/ggml-sycl/rope.hpp +6 -0
  442. data/ext/sources/ggml/src/ggml-sycl/softmax.cpp +6 -6
  443. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq112-dv112.cpp +5 -0
  444. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq128-dv128.cpp +5 -0
  445. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq256-dv256.cpp +5 -0
  446. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq40-dv40.cpp +5 -0
  447. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq576-dv512.cpp +5 -0
  448. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq64-dv64.cpp +5 -0
  449. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq72-dv72.cpp +5 -0
  450. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq80-dv80.cpp +5 -0
  451. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq96-dv96.cpp +5 -0
  452. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +7 -0
  453. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +7 -0
  454. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +7 -0
  455. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +7 -0
  456. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +7 -0
  457. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +7 -0
  458. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +7 -0
  459. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +7 -0
  460. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +7 -0
  461. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +7 -0
  462. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +7 -0
  463. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +7 -0
  464. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +7 -0
  465. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +7 -0
  466. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +7 -0
  467. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +7 -0
  468. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +7 -0
  469. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +7 -0
  470. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +7 -0
  471. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +7 -0
  472. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +7 -0
  473. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +7 -0
  474. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +7 -0
  475. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +7 -0
  476. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +7 -0
  477. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +7 -0
  478. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +7 -0
  479. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +7 -0
  480. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +7 -0
  481. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +7 -0
  482. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +7 -0
  483. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +7 -0
  484. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +7 -0
  485. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +7 -0
  486. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +7 -0
  487. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +7 -0
  488. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +13 -0
  489. data/ext/sources/ggml/src/ggml-sycl/wkv.cpp +1 -1
  490. data/ext/sources/ggml/src/ggml-virtgpu/CMakeLists.txt +70 -0
  491. data/ext/sources/ggml/src/ggml-virtgpu/apir_cs_ggml-rpc-front.cpp +87 -0
  492. data/ext/sources/ggml/src/ggml-virtgpu/backend/CMakeLists.txt +21 -0
  493. data/ext/sources/ggml/src/ggml-virtgpu/backend/apir_cs_ggml-rpc-back.cpp +115 -0
  494. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-convert.h +13 -0
  495. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +102 -0
  496. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +105 -0
  497. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +179 -0
  498. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched-device.cpp +148 -0
  499. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +51 -0
  500. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +73 -0
  501. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +27 -0
  502. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +32 -0
  503. data/ext/sources/ggml/src/ggml-virtgpu/backend/backend.cpp +144 -0
  504. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +95 -0
  505. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +94 -0
  506. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +50 -0
  507. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +378 -0
  508. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +232 -0
  509. data/ext/sources/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +58 -0
  510. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +81 -0
  511. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +119 -0
  512. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +158 -0
  513. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +213 -0
  514. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +69 -0
  515. data/ext/sources/ggml/src/ggml-virtgpu/ggml-remoting.h +71 -0
  516. data/ext/sources/ggml/src/ggml-virtgpu/ggmlremoting_functions.yaml +166 -0
  517. data/ext/sources/ggml/src/ggml-virtgpu/include/apir_hw.h +9 -0
  518. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +333 -0
  519. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-apir.h +15 -0
  520. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +58 -0
  521. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +110 -0
  522. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +173 -0
  523. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +192 -0
  524. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +36 -0
  525. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +53 -0
  526. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +98 -0
  527. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.h +23 -0
  528. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.cpp +179 -0
  529. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-utils.h +86 -0
  530. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +544 -0
  531. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +117 -0
  532. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +1 -1
  533. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1250 -465
  534. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +16 -8
  535. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/elu.comp +27 -0
  536. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +374 -170
  537. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +66 -22
  538. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +389 -201
  539. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +106 -58
  540. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mask_opt.comp +162 -0
  541. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +9 -8
  542. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +128 -0
  543. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +12 -9
  544. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +20 -17
  545. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -3
  546. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +8 -4
  547. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +3 -1
  548. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +5 -3
  549. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +3 -3
  550. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp +2 -3
  551. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +36 -63
  552. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +7 -4
  553. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp +7 -4
  554. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp +7 -4
  555. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +10 -5
  556. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp +7 -4
  557. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/sgn.comp +21 -0
  558. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +16 -10
  559. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +55 -35
  560. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +1314 -109
  561. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +1660 -1371
  562. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argmax.wgsl +72 -0
  563. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort.wgsl +106 -0
  564. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/argsort_merge.wgsl +134 -0
  565. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +141 -0
  566. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +65 -72
  567. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +75 -0
  568. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl +6 -0
  569. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/cumsum.wgsl +66 -0
  570. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +40 -5
  571. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +105 -60
  572. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{get_rows.tmpl.wgsl → get_rows.wgsl} +53 -259
  573. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat.tmpl.wgsl → mul_mat.wgsl} +68 -257
  574. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +692 -23
  575. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_reg_tile.tmpl.wgsl → mul_mat_reg_tile.wgsl} +28 -128
  576. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{mul_mat_subgroup_matrix.tmpl.wgsl → mul_mat_subgroup_matrix.wgsl} +31 -137
  577. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +480 -0
  578. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/pad.wgsl +86 -0
  579. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/repeat.wgsl +67 -0
  580. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{scale.tmpl.wgsl → scale.wgsl} +9 -36
  581. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +40 -12
  582. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/sum_rows.wgsl +55 -0
  583. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +193 -0
  584. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +6 -1
  585. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +31 -32
  586. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +9 -6
  587. data/ext/sources/ggml/src/ggml.c +167 -33
  588. data/ext/sources/ggml/src/gguf.cpp +229 -44
  589. data/ext/sources/src/whisper.cpp +6 -28
  590. data/sig/whisper.rbs +43 -2
  591. data/test/test_context_params.rb +82 -0
  592. data/test/test_token.rb +11 -0
  593. data/test/test_vad_context.rb +58 -8
  594. data/test/test_whisper.rb +20 -0
  595. data/whispercpp.gemspec +1 -1
  596. metadata +240 -28
  597. data/ext/sources/ggml/cmake/BuildTypes.cmake +0 -54
  598. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +0 -333
  599. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.c +0 -94
  600. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-inverse.c +0 -72
  601. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.c +0 -49
  602. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.c +0 -1020
  603. data/ext/sources/ggml/src/ggml-hexagon/htp/ops-utils.h +0 -149
  604. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.c +0 -454
  605. data/ext/sources/ggml/src/ggml-hexagon/htp-utils.h +0 -221
  606. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl +0 -188
  607. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary_head.tmpl +0 -45
  608. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +0 -267
  609. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.tmpl.wgsl +0 -112
  610. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +0 -483
@@ -27,6 +27,7 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
27
27
  #include <iostream>
28
28
  #include <tuple>
29
29
  #include <vector>
30
+ #include <deque>
30
31
  #include <sstream>
31
32
  #include <utility>
32
33
  #include <memory>
@@ -92,6 +93,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
92
93
  #define VK_VENDOR_ID_APPLE 0x106b
93
94
  #define VK_VENDOR_ID_INTEL 0x8086
94
95
  #define VK_VENDOR_ID_NVIDIA 0x10de
96
+ #define VK_VENDOR_ID_QUALCOMM 0x5143
95
97
 
96
98
  #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
97
99
 
@@ -187,6 +189,11 @@ struct ggml_backend_vk_buffer_type_context {
187
189
 
188
190
  struct vk_queue;
189
191
 
192
+ struct vk_command_buffer {
193
+ vk::CommandBuffer buf;
194
+ bool in_use = false;
195
+ };
196
+
190
197
  // Stores command pool/buffers. There's an instance of this
191
198
  // for each (context,queue) pair and for each (device,queue) pair.
192
199
  struct vk_command_pool {
@@ -194,10 +201,16 @@ struct vk_command_pool {
194
201
  void destroy(vk::Device& device);
195
202
 
196
203
  vk::CommandPool pool;
197
- uint32_t cmd_buffer_idx;
198
- std::vector<vk::CommandBuffer> cmd_buffers;
204
+ // Using deque so the pointers to command buffers
205
+ // remain valid even if we add more
206
+ std::deque<vk_command_buffer> cmd_buffers;
199
207
 
200
208
  vk_queue *q;
209
+
210
+ size_t buffers_in_use() const {
211
+ return std::count_if(cmd_buffers.begin(), cmd_buffers.end(),
212
+ [](const auto& cb) { return cb.in_use; });
213
+ }
201
214
  };
202
215
 
203
216
  // Prevent simultaneous submissions to the same queue.
@@ -254,6 +267,7 @@ enum vk_device_architecture {
254
267
  AMD_RDNA3,
255
268
  INTEL_XE2,
256
269
  NVIDIA_PRE_TURING,
270
+ NVIDIA_TURING,
257
271
  };
258
272
 
259
273
  static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
@@ -336,18 +350,34 @@ static vk_device_architecture get_device_architecture(const vk::PhysicalDevice&
336
350
  const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
337
351
 
338
352
  bool cooperative_matrix = false;
353
+ bool sm_builtins = false;
339
354
 
340
355
  // Detect "pre-turing" based on lack of coopmat support.
341
356
  for (const auto& properties : ext_props) {
342
357
  if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) {
343
358
  cooperative_matrix = true;
344
- break;
359
+ } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
360
+ sm_builtins = true;
345
361
  }
346
362
  }
347
363
 
348
364
  if (!cooperative_matrix) {
349
365
  return vk_device_architecture::NVIDIA_PRE_TURING;
350
366
  }
367
+
368
+ if (sm_builtins) {
369
+ vk::PhysicalDeviceProperties2 props2;
370
+ vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
371
+
372
+ props2.pNext = &sm_props;
373
+
374
+ device.getProperties2(&props2);
375
+
376
+ // Turing has 32, following architectures have 48
377
+ if (sm_props.shaderWarpsPerSM == 32) {
378
+ return vk_device_architecture::NVIDIA_TURING;
379
+ }
380
+ }
351
381
  }
352
382
  return vk_device_architecture::OTHER;
353
383
  }
@@ -385,18 +415,20 @@ enum FaCodePath {
385
415
  };
386
416
 
387
417
  struct vk_fa_pipeline_state {
388
- vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc)
389
- : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc) {}
390
-
391
418
  uint32_t HSK, HSV;
392
- bool small_rows, small_cache;
419
+ uint32_t Br, Bc;
420
+ uint32_t D_split, row_split;
421
+ bool shmem_staging;
393
422
  FaCodePath path;
423
+ uint32_t workgroup_size, subgroup_size;
394
424
  bool aligned;
395
425
  bool f32acc;
426
+ uint32_t flags;
427
+ uint32_t limit_occupancy_shmem;
396
428
 
397
429
  bool operator<(const vk_fa_pipeline_state &b) const {
398
- return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc) <
399
- std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc);
430
+ return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
431
+ std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
400
432
  }
401
433
  };
402
434
 
@@ -570,6 +602,7 @@ struct vk_device_struct {
570
602
  vk_queue transfer_queue;
571
603
  bool single_queue;
572
604
  bool support_async;
605
+ bool async_use_transfer_queue;
573
606
  uint32_t subgroup_size;
574
607
  uint32_t subgroup_size_log2;
575
608
  uint32_t shader_core_count;
@@ -669,6 +702,7 @@ struct vk_device_struct {
669
702
  vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
670
703
  vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
671
704
  vk_pipeline pipeline_acc_f32;
705
+ vk_pipeline pipeline_set_f32;
672
706
 
673
707
  // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
674
708
  vk_pipeline pipeline_add[2][2][2];
@@ -722,6 +756,7 @@ struct vk_device_struct {
722
756
 
723
757
  // [src/dst 0=fp32,1=fp16]
724
758
  vk_pipeline pipeline_exp[2];
759
+ vk_pipeline pipeline_elu[2];
725
760
  vk_pipeline pipeline_gelu[2];
726
761
  vk_pipeline pipeline_gelu_erf[2];
727
762
  vk_pipeline pipeline_gelu_quick[2];
@@ -740,6 +775,7 @@ struct vk_device_struct {
740
775
  vk_pipeline pipeline_ceil[2];
741
776
  vk_pipeline pipeline_floor[2];
742
777
  vk_pipeline pipeline_trunc[2];
778
+ vk_pipeline pipeline_sgn[2];
743
779
 
744
780
  vk_pipeline pipeline_add1_f16_f16;
745
781
  vk_pipeline pipeline_add1_f16_f32;
@@ -789,6 +825,8 @@ struct vk_device_struct {
789
825
  vk_pipeline pipeline_pool2d_f32;
790
826
  vk_pipeline pipeline_rwkv_wkv6_f32;
791
827
  vk_pipeline pipeline_rwkv_wkv7_f32;
828
+ // [size_idx][kda] where size_idx: 0=d32, 1=d64, 2=d128
829
+ vk_pipeline pipeline_gated_delta_net[3][2];
792
830
  vk_pipeline pipeline_ssm_scan_f32_d128;
793
831
  vk_pipeline pipeline_ssm_scan_f32_d256;
794
832
  vk_pipeline pipeline_ssm_conv_f32;
@@ -803,6 +841,8 @@ struct vk_device_struct {
803
841
 
804
842
  std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
805
843
 
844
+ std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
845
+
806
846
  vk_pipeline pipeline_flash_attn_split_k_reduce;
807
847
  vk_pipeline pipeline_count_experts;
808
848
 
@@ -852,10 +892,12 @@ struct vk_device_struct {
852
892
  };
853
893
 
854
894
  void vk_command_pool::init(vk_device& device, vk_queue *q_) {
855
- cmd_buffer_idx = 0;
895
+ cmd_buffers.clear();
856
896
  q = q_;
857
897
 
858
- vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index);
898
+ vk::CommandPoolCreateInfo command_pool_create_info(
899
+ vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT | VK_COMMAND_POOL_CREATE_RESET_COMMAND_BUFFER_BIT),
900
+ q->queue_family_index);
859
901
  pool = device->device.createCommandPool(command_pool_create_info);
860
902
  }
861
903
 
@@ -903,6 +945,7 @@ struct vk_subbuffer {
903
945
  struct vk_event {
904
946
  vk::Event event;
905
947
  vk::Fence fence;
948
+ vk_command_buffer* cmd_buffer = nullptr;
906
949
  };
907
950
 
908
951
  struct vk_semaphore {
@@ -911,7 +954,7 @@ struct vk_semaphore {
911
954
  };
912
955
 
913
956
  struct vk_submission {
914
- vk::CommandBuffer buffer;
957
+ vk_command_buffer* buffer = nullptr;
915
958
  std::vector<vk_semaphore> wait_semaphores;
916
959
  std::vector<vk_semaphore> signal_semaphores;
917
960
  };
@@ -922,6 +965,7 @@ struct vk_mat_mat_push_constants {
922
965
  uint32_t M; uint32_t N; uint32_t K;
923
966
  uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
924
967
  uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
968
+ uint32_t base_work_group_z; uint32_t num_batches;
925
969
  uint32_t k_split;
926
970
  uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
927
971
  uint32_t padded_N;
@@ -941,6 +985,7 @@ struct vk_mat_vec_push_constants {
941
985
  uint32_t batch_stride_b;
942
986
  uint32_t batch_stride_d;
943
987
  uint32_t fusion_flags;
988
+ uint32_t base_work_group_y;
944
989
  uint32_t ne02;
945
990
  uint32_t ne12;
946
991
  uint32_t broadcast2;
@@ -991,6 +1036,8 @@ struct vk_mat_vec_id_push_constants {
991
1036
  uint32_t fusion_flags;
992
1037
  uint32_t nei0;
993
1038
  uint32_t ne11;
1039
+ uint32_t expert_i1;
1040
+ uint32_t nbi1;
994
1041
  };
995
1042
 
996
1043
  struct vk_flash_attn_push_constants {
@@ -1244,25 +1291,30 @@ struct vk_op_diag_mask_push_constants {
1244
1291
 
1245
1292
  struct vk_op_rope_push_constants {
1246
1293
  uint32_t rope_mode;
1247
- uint32_t ncols;
1248
1294
  uint32_t nrows;
1249
1295
  uint32_t n_dims;
1250
1296
  float freq_scale;
1251
- uint32_t p_delta_rows;
1252
1297
  float freq_base;
1253
1298
  float ext_factor;
1254
1299
  float attn_factor;
1255
1300
  float corr_dims[2];
1256
1301
  float theta_scale;
1257
1302
  uint32_t has_ff;
1258
- uint32_t ne02;
1259
- uint32_t s1;
1260
- uint32_t s2;
1261
1303
  int32_t sections[4];
1262
1304
  uint32_t is_imrope;
1263
1305
  uint32_t is_back;
1264
1306
  uint32_t set_rows_stride;
1307
+ uint32_t ne00;
1308
+ uint32_t ne01;
1309
+ uint32_t ne02;
1310
+ uint32_t nb01;
1311
+ uint32_t nb02;
1312
+ uint32_t nb03;
1313
+ uint32_t nb11;
1314
+ uint32_t nb12;
1315
+ uint32_t nb13;
1265
1316
  };
1317
+ static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
1266
1318
 
1267
1319
  // For fused rms_norm+mul+rope(+view+set_rows)
1268
1320
  struct vk_op_rms_norm_mul_rope_push_constants {
@@ -1404,6 +1456,18 @@ struct vk_op_rwkv_wkv7_push_constants {
1404
1456
  uint32_t C;
1405
1457
  uint32_t H;
1406
1458
  };
1459
+ struct vk_op_gated_delta_net_push_constants {
1460
+ uint32_t H;
1461
+ uint32_t n_tokens;
1462
+ uint32_t n_seqs;
1463
+ uint32_t s_off;
1464
+ uint32_t sq1, sq2, sq3;
1465
+ uint32_t sv1, sv2, sv3;
1466
+ uint32_t sb1, sb2, sb3;
1467
+ uint32_t neq1, rq3;
1468
+ float scale;
1469
+ };
1470
+
1407
1471
  struct vk_op_ssm_scan_push_constants {
1408
1472
  uint32_t nb02, nb03, nb12, nb13;
1409
1473
  uint32_t nb21, nb22, nb31;
@@ -1516,6 +1580,27 @@ struct vk_quantize_q8_1_push_constants {
1516
1580
  uint32_t num_blocks;
1517
1581
  };
1518
1582
 
1583
+ struct vk_op_flash_attn_split_k_reduce_push_constants {
1584
+ uint32_t D;
1585
+ uint32_t ne1;
1586
+ uint32_t ne2;
1587
+ uint32_t ne3;
1588
+ uint32_t k_num;
1589
+ uint32_t sinks;
1590
+ };
1591
+
1592
+ struct vk_op_flash_attn_mask_opt_push_constants {
1593
+ uint32_t nem0;
1594
+ uint32_t nem1;
1595
+ uint32_t nem2;
1596
+ uint32_t nbm1;
1597
+ uint32_t nbm2;
1598
+ uint32_t nbm3;
1599
+ uint32_t nbd1;
1600
+ uint32_t nbd2;
1601
+ uint32_t nbd3;
1602
+ };
1603
+
1519
1604
  // Allow pre-recording command buffers
1520
1605
  struct vk_staging_memcpy {
1521
1606
  vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1604,6 +1689,7 @@ static bool vk_perf_logger_concurrent = false;
1604
1689
  static bool vk_enable_sync_logger = false;
1605
1690
  // number of calls between perf logger prints
1606
1691
  static uint32_t vk_perf_logger_frequency = 1;
1692
+ static std::string vk_pipeline_stats_filter;
1607
1693
 
1608
1694
  class vk_perf_logger {
1609
1695
  public:
@@ -1724,6 +1810,7 @@ class vk_perf_logger {
1724
1810
  " k(" << k->ne[0] << "," << k->ne[1] << "," << k->ne[2] << "," << k->ne[3] << "), " <<
1725
1811
  " v(" << v->ne[0] << "," << v->ne[1] << "," << v->ne[2] << "," << v->ne[3] << "), " <<
1726
1812
  " m(" << (m?m->ne[0]:0) << "," << (m?m->ne[1]:0) << "," << (m?m->ne[2]:0) << "," << (m?m->ne[3]:0) << ")";
1813
+ *n_flops = 2ull * q->ne[1] * q->ne[2] * (k->ne[0] + v->ne[0]) * k->ne[1] * q->ne[3];
1727
1814
  return name.str();
1728
1815
  }
1729
1816
  if (node->op == GGML_OP_TOP_K) {
@@ -1802,7 +1889,10 @@ struct ggml_backend_vk_context {
1802
1889
  bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
1803
1890
 
1804
1891
  vk_context_ref compute_ctx;
1892
+
1805
1893
  vk_context_ref transfer_ctx;
1894
+ vk_semaphore transfer_semaphore;
1895
+ uint64_t transfer_semaphore_last_submitted {};
1806
1896
 
1807
1897
  std::vector<vk_context_ref> tensor_ctxs;
1808
1898
 
@@ -2121,7 +2211,32 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
2121
2211
  executableInfo.pipeline = pipeline->pipeline;
2122
2212
 
2123
2213
  auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo);
2214
+
2215
+ bool print_stats = !vk_pipeline_stats_filter.empty() &&
2216
+ pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos;
2217
+ if (print_stats) {
2218
+ std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl;
2219
+ }
2220
+
2124
2221
  for (auto & s : statistics) {
2222
+ if (print_stats) {
2223
+ std::cerr << "ggml_vulkan: " << s.name.data() << ": ";
2224
+ switch (s.format) {
2225
+ case vk::PipelineExecutableStatisticFormatKHR::eBool32:
2226
+ std::cerr << (s.value.b32 ? "true" : "false");
2227
+ break;
2228
+ case vk::PipelineExecutableStatisticFormatKHR::eInt64:
2229
+ std::cerr << s.value.i64;
2230
+ break;
2231
+ case vk::PipelineExecutableStatisticFormatKHR::eUint64:
2232
+ std::cerr << s.value.u64;
2233
+ break;
2234
+ case vk::PipelineExecutableStatisticFormatKHR::eFloat64:
2235
+ std::cerr << s.value.f64;
2236
+ break;
2237
+ }
2238
+ std::cerr << std::endl;
2239
+ }
2125
2240
  // "Register Count" is reported by NVIDIA drivers.
2126
2241
  if (strcmp(s.name, "Register Count") == 0) {
2127
2242
  VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers");
@@ -2197,25 +2312,15 @@ static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx
2197
2312
  }
2198
2313
  }
2199
2314
 
2200
- static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
2315
+ static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
2201
2316
  VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
2202
-
2203
- if (p.cmd_buffers.size() > p.cmd_buffer_idx) {
2204
- // Reuse command buffer
2205
- return p.cmd_buffers[p.cmd_buffer_idx++];
2206
- }
2207
-
2208
2317
  vk::CommandBufferAllocateInfo command_buffer_alloc_info(
2209
2318
  p.pool,
2210
2319
  vk::CommandBufferLevel::ePrimary,
2211
2320
  1);
2212
2321
  const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
2213
- auto buf = cmd_buffers.front();
2214
-
2215
- p.cmd_buffers.push_back(buf);
2216
- p.cmd_buffer_idx++;
2217
-
2218
- return buf;
2322
+ p.cmd_buffers.push_back({ cmd_buffers.front(), true });
2323
+ return &p.cmd_buffers[p.cmd_buffers.size()-1];
2219
2324
  }
2220
2325
 
2221
2326
  static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
@@ -2282,7 +2387,7 @@ static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
2282
2387
  tl_wait_semaphores[idx].data(),
2283
2388
  stage_flags[idx].data(),
2284
2389
  1,
2285
- &submission.buffer,
2390
+ &submission.buffer->buf,
2286
2391
  (uint32_t) submission.signal_semaphores.size(),
2287
2392
  tl_signal_semaphores[idx].data(),
2288
2393
  };
@@ -2406,7 +2511,11 @@ static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p)
2406
2511
 
2407
2512
  // Requires command buffers to be done
2408
2513
  device->device.resetCommandPool(p.pool);
2409
- p.cmd_buffer_idx = 0;
2514
+ // Don't clear the command buffers and mark them as not in use.
2515
+ // This allows us to reuse them
2516
+ for (auto& cmd_buffer : p.cmd_buffers) {
2517
+ cmd_buffer.in_use = false;
2518
+ }
2410
2519
  }
2411
2520
 
2412
2521
  static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
@@ -2415,10 +2524,10 @@ static void ggml_vk_queue_command_pools_cleanup(vk_device& device) {
2415
2524
  // Arbitrary frequency to cleanup/reuse command buffers
2416
2525
  static constexpr uint32_t cleanup_frequency = 10;
2417
2526
 
2418
- if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
2527
+ if (device->compute_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {
2419
2528
  ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool);
2420
2529
  }
2421
- if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) {
2530
+ if (device->transfer_queue.cmd_pool.buffers_in_use() >= cleanup_frequency) {
2422
2531
  ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool);
2423
2532
  }
2424
2533
  }
@@ -2666,7 +2775,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
2666
2775
  ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
2667
2776
  }
2668
2777
 
2669
- subctx->s->buffer.pipelineBarrier(
2778
+ subctx->s->buffer->buf.pipelineBarrier(
2670
2779
  subctx->p->q->stage_flags,
2671
2780
  subctx->p->q->stage_flags,
2672
2781
  {},
@@ -2682,7 +2791,7 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
2682
2791
  static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
2683
2792
  VK_LOG_DEBUG("ggml_vk_set_event()");
2684
2793
 
2685
- ctx->s->buffer.setEvent(
2794
+ ctx->s->buffer->buf.setEvent(
2686
2795
  event,
2687
2796
  ctx->p->q->stage_flags
2688
2797
  );
@@ -2694,7 +2803,7 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
2694
2803
  return;
2695
2804
  }
2696
2805
 
2697
- ctx->s->buffer.waitEvents(
2806
+ ctx->s->buffer->buf.waitEvents(
2698
2807
  events,
2699
2808
  ctx->p->q->stage_flags,
2700
2809
  ctx->p->q->stage_flags,
@@ -2704,78 +2813,218 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
2704
2813
  );
2705
2814
  }
2706
2815
 
2707
- // number of rows/cols for flash attention shader
2708
- static constexpr uint32_t flash_attention_num_small_rows = 32;
2709
- static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
2816
+ struct vk_fa_tuning_params {
2817
+ FaCodePath path;
2818
+ uint32_t workgroup_size;
2819
+ uint32_t subgroup_size;
2820
+ uint32_t block_rows;
2821
+ uint32_t block_cols;
2822
+ uint32_t d_split;
2823
+ uint32_t row_split;
2824
+ bool shmem_staging;
2825
+ bool disable_subgroups;
2826
+ uint32_t limit_occupancy_shmem;
2827
+
2828
+ void print() const {
2829
+ std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size <<
2830
+ " block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split <<
2831
+ " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups <<
2832
+ " limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl;
2833
+ }
2834
+ };
2835
+
2836
+ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
2837
+ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
2710
2838
 
2711
- static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) {
2712
- if (hsv >= 192) {
2713
- return 2;
2714
- } else if ((hsv | hsk) & 8 || small_cache) {
2715
- return 4;
2839
+ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
2840
+ GGML_UNUSED(kv_type);
2841
+
2842
+ vk_fa_tuning_params result{};
2843
+ result.path = FA_SCALAR;
2844
+
2845
+ if (device->vendor_id == VK_VENDOR_ID_INTEL) {
2846
+ // Disable subgroup use due to performance issues when enforcing subgroup sizes
2847
+ result.subgroup_size = 32;
2848
+ result.disable_subgroups = true;
2849
+ } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) {
2850
+ result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size;
2716
2851
  } else {
2717
- return 8;
2852
+ result.subgroup_size = device->subgroup_size;
2718
2853
  }
2719
- }
2720
2854
 
2721
- // The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
2722
- // 128 threads split into four subgroups, each subgroup does 1/4
2723
- // of the Bc dimension.
2724
- static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
2725
- static constexpr uint32_t scalar_flash_attention_Bc = 64;
2726
- static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
2855
+ // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers
2856
+ uint32_t row_split_max_hsk = 64;
2857
+ if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) {
2858
+ row_split_max_hsk = n_rows <= 8 ? 64 : 128;
2859
+ }
2860
+ result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4;
2727
2861
 
2728
- static uint32_t get_fa_num_small_rows(FaCodePath path) {
2729
- if (path == FA_COOPMAT2) {
2730
- return flash_attention_num_small_rows;
2862
+ if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) {
2863
+ result.workgroup_size = result.subgroup_size * 2;
2731
2864
  } else {
2732
- return scalar_flash_attention_num_small_rows;
2865
+ result.workgroup_size = result.subgroup_size * 4;
2733
2866
  }
2734
- }
2735
2867
 
2736
- static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) {
2737
- GGML_UNUSED(clamp);
2868
+ const uint32_t D = hsk | hsv;
2738
2869
 
2739
- if (path == FA_SCALAR) {
2740
- if (small_rows) {
2741
- return {scalar_flash_attention_num_small_rows, 64};
2870
+ const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL;
2871
+
2872
+ if (n_rows == 1) {
2873
+ result.block_rows = 1;
2874
+ result.block_cols = 64;
2875
+ } else {
2876
+ // row_split 1 means higher register use per row, so block size has to be adjusted
2877
+ if (result.row_split == 1) {
2878
+ result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8);
2742
2879
  } else {
2743
- if ((hsv | hsk) & 8) {
2744
- // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter
2745
- // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not.
2746
- return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64};
2747
- } else {
2748
- return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32};
2749
- }
2880
+ result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16);
2750
2881
  }
2882
+
2883
+ result.block_cols = (D & 8) ? 64 : 32;
2751
2884
  }
2752
2885
 
2753
- if (path == FA_COOPMAT1) {
2754
- if (small_rows) {
2755
- return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
2756
- } else {
2757
- return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
2886
+ const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit
2887
+
2888
+ result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
2889
+
2890
+ result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
2891
+
2892
+ if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
2893
+ result.block_rows /= 2;
2894
+ }
2895
+
2896
+ // On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled
2897
+ // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy.
2898
+ // This targets an occupancy of 4 subgroups per SIMD.
2899
+ if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) {
2900
+ if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) {
2901
+ // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size
2902
+ // Values are guessed, tested on RDNA2
2903
+ result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4;
2904
+ } else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) {
2905
+ // Same thing for GCN, with an occupancy target of 2 subgroups per SIMD.
2906
+ // Here low-batch FA with large head size is affected.
2907
+ // n_rows < 4 switch because workgroup size switches from 128 to 256 there.
2908
+ result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4;
2758
2909
  }
2759
2910
  }
2760
2911
 
2761
- // small rows, large cols
2912
+ return result;
2913
+ }
2914
+
2915
+ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
2916
+ GGML_UNUSED(n_rows);
2917
+ GGML_UNUSED(n_kv);
2918
+ GGML_UNUSED(kv_type);
2919
+ GGML_UNUSED(f32acc);
2920
+
2921
+ vk_fa_tuning_params result{};
2922
+ result.path = FA_COOPMAT1;
2923
+
2924
+ const uint32_t D = hsk | hsv;
2925
+
2926
+ const uint32_t coopmat_block_rows = 16;
2927
+ const uint32_t coopmat_block_cols = 16;
2928
+
2929
+ const uint32_t num_subgroups = 4;
2930
+
2931
+ result.block_rows = coopmat_block_rows;
2932
+ result.block_cols = coopmat_block_cols * num_subgroups;
2933
+ result.row_split = num_subgroups;
2934
+ result.subgroup_size = device->subgroup_size;
2935
+ result.workgroup_size = num_subgroups * result.subgroup_size;
2936
+
2937
+ const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit
2938
+ result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4);
2939
+
2940
+ result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
2941
+
2942
+ return result;
2943
+ }
2944
+
2945
+ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
2946
+ GGML_UNUSED(n_kv);
2947
+ GGML_UNUSED(f32acc);
2948
+
2949
+ vk_fa_tuning_params result{};
2950
+ result.path = FA_COOPMAT2;
2951
+
2952
+ const uint32_t D = hsk | hsv;
2953
+
2954
+ const bool small_rows = n_rows < 32;
2955
+
2762
2956
  if (small_rows) {
2763
- return {get_fa_num_small_rows(FA_COOPMAT2), 32};
2957
+ result.block_rows = 32;
2958
+ result.block_cols = 32;
2959
+ } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
2960
+ result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
2961
+ result.block_cols = 32;
2962
+ } else {
2963
+ result.block_rows = 64;
2964
+ result.block_cols = 64;
2764
2965
  }
2765
2966
 
2766
- // small cols to reduce register count
2767
- if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) {
2768
- if (hsk >= 512 || hsv >= 512) {
2769
- return {32, 32};
2770
- } else {
2771
- return {64, 32};
2967
+ result.subgroup_size = device->subgroup_size;
2968
+ result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128;
2969
+
2970
+ return result;
2971
+ }
2972
+
2973
+ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
2974
+ FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
2975
+ device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
2976
+
2977
+ if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
2978
+ // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
2979
+ path = FA_SCALAR;
2980
+ }
2981
+
2982
+ if (path == FA_COOPMAT1) {
2983
+ bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
2984
+ (!f32acc && device->coopmat_support_16x16x16_f16acc);
2985
+ const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
2986
+ bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
2987
+
2988
+ if (!shape_ok || !shmem_ok) {
2989
+ path = FA_SCALAR;
2772
2990
  }
2773
2991
  }
2774
- return {64, 64};
2992
+
2993
+ // scalar is faster than coopmat when N==1
2994
+ if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) {
2995
+ path = FA_SCALAR;
2996
+ }
2997
+
2998
+ switch (path) {
2999
+ case FA_SCALAR:
3000
+ return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
3001
+ case FA_COOPMAT1:
3002
+ return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
3003
+ case FA_COOPMAT2:
3004
+ return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
3005
+ default:
3006
+ throw std::runtime_error("unsupported FaCodePath");
3007
+ }
2775
3008
  }
2776
3009
 
2777
- static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) {
2778
- return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1];
3010
+ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
3011
+ bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
3012
+ const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
3013
+ (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
3014
+
3015
+ uint32_t flags = (use_mask_opt ? 1 : 0) |
3016
+ (use_mask ? 2 : 0) |
3017
+ (use_logit_softcap ? 4 : 0) |
3018
+ (old_amd_windows ? 8 : 0);
3019
+
3020
+ const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
3021
+
3022
+ return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
3023
+ }
3024
+
3025
+ static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
3026
+ return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
3027
+ state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
2779
3028
  }
2780
3029
 
2781
3030
  static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -3142,60 +3391,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
3142
3391
  align, disable_robustness, require_full_subgroups, required_subgroup_size);
3143
3392
  };
3144
3393
 
3145
- auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array<uint32_t, 3> {
3146
- return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1};
3147
- };
3148
-
3149
- auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::vector<uint32_t> {
3150
- // For large number of rows, 128 invocations seems to work best.
3151
- // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
3152
- // can't use 256 for D==80.
3153
- // For scalar, use 128 (arbitrary)
3154
- // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs.
3155
- const uint32_t D = (hsk|hsv);
3156
- uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
3157
- ? scalar_flash_attention_workgroup_size
3158
- : ((small_rows && (D % 32) == 0) ? 256 : 128);
3159
- auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache);
3160
-
3161
- // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
3162
- // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
3163
- const uint32_t D_lsb = D ^ (D & (D-1));
3164
- uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
3165
-
3166
- return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split};
3167
- };
3168
-
3169
3394
  #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
3170
3395
  for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
3171
- uint32_t HSK = fa.first.HSK; \
3172
- uint32_t HSV = fa.first.HSV; \
3173
- bool small_rows = fa.first.small_rows; \
3174
- bool small_cache = fa.first.small_cache; \
3175
3396
  FaCodePath path = fa.first.path; \
3397
+ uint32_t Br = fa.first.Br; \
3398
+ uint32_t Bc = fa.first.Bc; \
3176
3399
  bool aligned = fa.first.aligned; \
3177
3400
  bool f32acc = fa.first.f32acc; \
3401
+ uint32_t fa_sgs = fa.first.subgroup_size; \
3402
+ bool fa_ds = fa.first.subgroup_size == 0; \
3178
3403
  if (path == FAPATH) { \
3179
3404
  if (aligned) { \
3180
3405
  if (f32acc) { \
3181
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3406
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
3182
3407
  } else { \
3183
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3408
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
3184
3409
  } \
3185
3410
  } else { \
3186
3411
  if (f32acc) { \
3187
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3412
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
3188
3413
  } else { \
3189
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), 1, true, true, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
3414
+ ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
3190
3415
  } \
3191
3416
  } \
3192
3417
  } \
3193
3418
  }
3194
3419
 
3195
- CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
3196
- CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
3197
- CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
3198
- CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
3420
+ if (device->fp16) {
3421
+ CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
3422
+ CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
3423
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
3424
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
3425
+ } else {
3426
+ CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
3427
+ CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
3428
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
3429
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
3430
+ }
3199
3431
  #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
3200
3432
  if (device->coopmat1_fa_support) {
3201
3433
  CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
@@ -3713,10 +3945,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
3713
3945
  && !device->coopmat_bf16_support
3714
3946
  #endif
3715
3947
  ) {
3948
+ const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;
3949
+
3716
3950
  // use scalar tile sizes
3717
3951
  l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
3718
3952
  m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
3719
- s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
3953
+ s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, 2, 2, 1, subgroup_size_8 };
3720
3954
 
3721
3955
  l_wg_denoms = {128, 128, 1 };
3722
3956
  m_wg_denoms = { 64, 64, 1 };
@@ -3980,7 +4214,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
3980
4214
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
3981
4215
 
3982
4216
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
3983
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
4217
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
4218
+
4219
+ for (auto &it : device->pipeline_fa_mask_opt) {
4220
+ auto BrBc = it.first;
4221
+ ggml_vk_create_pipeline(device, it.second, "fa_mask_opt", fa_mask_opt_len, fa_mask_opt_data, "main", 2, sizeof(vk_op_flash_attn_mask_opt_push_constants), {1, 1, 1}, {128, 128 / device->subgroup_size, BrBc.first, BrBc.second}, 1, true, true, device->subgroup_size);
4222
+ }
3984
4223
 
3985
4224
  if (device->subgroup_clustered && device->subgroup_require_full_support) {
3986
4225
  ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true);
@@ -4012,7 +4251,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4012
4251
  }
4013
4252
 
4014
4253
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
4015
- ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
4254
+ ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
4016
4255
 
4017
4256
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4018
4257
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -4113,7 +4352,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
4113
4352
 
4114
4353
  ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
4115
4354
 
4116
- ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4355
+ ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
4356
+ ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
4117
4357
 
4118
4358
  ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4119
4359
  ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -4158,6 +4398,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4158
4398
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4159
4399
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4160
4400
 
4401
+ CREATE_UNARY(elu)
4161
4402
  CREATE_UNARY(gelu)
4162
4403
  CREATE_UNARY(gelu_erf)
4163
4404
  CREATE_UNARY(gelu_quick)
@@ -4176,6 +4417,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4176
4417
  CREATE_UNARY(ceil)
4177
4418
  CREATE_UNARY(floor)
4178
4419
  CREATE_UNARY(trunc)
4420
+ CREATE_UNARY(sgn)
4179
4421
  #undef CREATE_UNARY
4180
4422
 
4181
4423
  #define CREATE_UNARY_RTE(name) \
@@ -4340,6 +4582,23 @@ static void ggml_vk_load_shaders(vk_device& device) {
4340
4582
 
4341
4583
  ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
4342
4584
 
4585
+ {
4586
+ const uint32_t gdn_sizes[] = {32, 64, 128};
4587
+ const char * gdn_names[][2] = {
4588
+ {"gated_delta_net_f32_d32", "gated_delta_net_f32_d32_kda"},
4589
+ {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
4590
+ {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
4591
+ };
4592
+ for (uint32_t si = 0; si < 3; si++) {
4593
+ for (uint32_t kda = 0; kda < 2; kda++) {
4594
+ ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
4595
+ gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
4596
+ "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
4597
+ {1, 1, 1}, {gdn_sizes[si], kda}, 1);
4598
+ }
4599
+ }
4600
+ }
4601
+
4343
4602
  if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
4344
4603
  ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
4345
4604
  ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
@@ -4348,7 +4607,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4348
4607
  ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
4349
4608
  }
4350
4609
 
4351
- ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
4610
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
4352
4611
 
4353
4612
  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4354
4613
 
@@ -4460,6 +4719,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4460
4719
  }
4461
4720
 
4462
4721
  static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
4722
+ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev);
4463
4723
 
4464
4724
  static vk_device ggml_vk_get_device(size_t idx) {
4465
4725
  VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
@@ -4676,6 +4936,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
4676
4936
  device->shader_core_count = sm_props.shaderSMCount;
4677
4937
  } else if (amd_shader_core_properties2) {
4678
4938
  device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
4939
+ } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
4940
+ device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device);
4679
4941
  } else {
4680
4942
  device->shader_core_count = 0;
4681
4943
  }
@@ -4719,8 +4981,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
4719
4981
  std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
4720
4982
 
4721
4983
  // Try to find a non-graphics compute queue and transfer-focused queues
4722
- const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1);
4723
- const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1);
4984
+ // On AMD, the graphics queue seems to be faster, so don't avoid it
4985
+ const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics;
4986
+ const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1);
4987
+ const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1);
4724
4988
 
4725
4989
  const float priorities[] = { 1.0f, 1.0f };
4726
4990
  device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1;
@@ -4895,11 +5159,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
4895
5159
 
4896
5160
  #if defined(VK_KHR_cooperative_matrix)
4897
5161
  device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
4898
-
4899
- // coopmat1 fa shader currently assumes 32 invocations per subgroup
4900
- device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
4901
- device->subgroup_size_control && device->subgroup_min_size <= 32 &&
4902
- device->subgroup_max_size >= 32;
5162
+ device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support;
4903
5163
  #endif
4904
5164
 
4905
5165
  if (coopmat2_support) {
@@ -5186,10 +5446,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
5186
5446
  if (!device->single_queue) {
5187
5447
  const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
5188
5448
  ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
5449
+
5450
+ device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
5189
5451
  } else {
5190
5452
  // TODO: Use pointer or reference to avoid copy
5191
5453
  device->transfer_queue.copyFrom(device->compute_queue);
5192
5454
  device->transfer_queue.cmd_pool.init(device, &device->transfer_queue);
5455
+
5456
+ device->async_use_transfer_queue = false;
5193
5457
  }
5194
5458
 
5195
5459
  device->buffer_type = {
@@ -5467,6 +5731,10 @@ static void ggml_vk_instance_init() {
5467
5731
  vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr;
5468
5732
  vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr;
5469
5733
  vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr;
5734
+ const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS");
5735
+ if (GGML_VK_PIPELINE_STATS != nullptr) {
5736
+ vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS;
5737
+ }
5470
5738
  const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY");
5471
5739
 
5472
5740
  if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) {
@@ -5513,22 +5781,30 @@ static void ggml_vk_instance_init() {
5513
5781
 
5514
5782
  if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) {
5515
5783
  // Check if there are two physical devices corresponding to the same GPU
5784
+ // This handles the case where the same GPU appears with different drivers (e.g., RADV + AMDVLK on Linux),
5785
+ // see https://github.com/ggml-org/llama.cpp/pull/7582 for original deduplication.
5786
+ // MoltenVK on macOS may report the same UUID for distinct GPUs on multi-GPU cards,
5787
+ // see https://github.com/KhronosGroup/MoltenVK/issues/2683. Skip when both old/new
5788
+ // driver is MoltenVK
5516
5789
  auto old_device = std::find_if(
5517
5790
  vk_instance.device_indices.begin(),
5518
5791
  vk_instance.device_indices.end(),
5519
- [&devices, &new_id](const size_t k){
5792
+ [&devices, &new_id, &new_driver](const size_t k){
5520
5793
  vk::PhysicalDeviceProperties2 old_props;
5794
+ vk::PhysicalDeviceDriverProperties old_driver;
5521
5795
  vk::PhysicalDeviceIDProperties old_id;
5522
- old_props.pNext = &old_id;
5796
+ old_props.pNext = &old_driver;
5797
+ old_driver.pNext = &old_id;
5523
5798
  devices[k].getProperties2(&old_props);
5524
5799
 
5525
- bool equals = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
5526
- equals = equals || (
5800
+ bool same_uuid = std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID));
5801
+ same_uuid = same_uuid || (
5527
5802
  old_id.deviceLUIDValid && new_id.deviceLUIDValid &&
5528
5803
  std::equal(std::begin(old_id.deviceLUID), std::end(old_id.deviceLUID), std::begin(new_id.deviceLUID))
5529
5804
  );
5805
+ bool both_molten_vk = (new_driver.driverID == vk::DriverId::eMoltenvk && old_driver.driverID == vk::DriverId::eMoltenvk);
5530
5806
 
5531
- return equals;
5807
+ return same_uuid && !both_molten_vk;
5532
5808
  }
5533
5809
  );
5534
5810
  if (old_device == vk_instance.device_indices.end()) {
@@ -5565,6 +5841,10 @@ static void ggml_vk_instance_init() {
5565
5841
  driver_priorities[vk::DriverId::eMesaNvk] = 2;
5566
5842
  #endif
5567
5843
  break;
5844
+ case VK_VENDOR_ID_QUALCOMM:
5845
+ driver_priorities[vk::DriverId::eQualcommProprietary] = 1;
5846
+ driver_priorities[vk::DriverId::eMesaTurnip] = 2;
5847
+ break;
5568
5848
  }
5569
5849
  driver_priorities[vk::DriverId::eMesaDozen] = 100;
5570
5850
 
@@ -5647,7 +5927,15 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
5647
5927
  ctx->almost_ready_fence = ctx->device->device.createFence({});
5648
5928
 
5649
5929
  ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue);
5650
- ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
5930
+ if (ctx->device->async_use_transfer_queue) {
5931
+ vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
5932
+ vk::SemaphoreCreateInfo ci{};
5933
+ ci.setPNext(&tci);
5934
+ ctx->transfer_semaphore.s = ctx->device->device.createSemaphore(ci);
5935
+ ctx->transfer_semaphore.value = 0;
5936
+
5937
+ ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue);
5938
+ }
5651
5939
 
5652
5940
  if (vk_perf_logger_enabled) {
5653
5941
  ctx->perf_logger = std::unique_ptr<vk_perf_logger>(new vk_perf_logger());
@@ -6100,13 +6388,24 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
6100
6388
  return vk_subbuffer{buffer, offset, size};
6101
6389
  }
6102
6390
 
6391
+ // Get a command buffer from pool. Create a new one if no reusable buffer is available
6392
+ static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) {
6393
+ for (auto& cmd_buffer : pool.cmd_buffers) {
6394
+ if (!cmd_buffer.in_use) {
6395
+ cmd_buffer.in_use = true;
6396
+ return &cmd_buffer;
6397
+ }
6398
+ }
6399
+ return ggml_vk_create_cmd_buffer(device, pool);
6400
+ }
6401
+
6103
6402
  static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) {
6104
6403
  vk_submission s;
6105
- s.buffer = ggml_vk_create_cmd_buffer(device, p);
6404
+ s.buffer = ggml_vk_get_or_create_cmd_buffer(device, p);
6106
6405
  if (one_time) {
6107
- s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
6406
+ s.buffer->buf.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit });
6108
6407
  } else {
6109
- s.buffer.begin({ vk::CommandBufferUsageFlags{} });
6408
+ s.buffer->buf.begin({ vk::CommandBufferUsageFlags{} });
6110
6409
  }
6111
6410
 
6112
6411
  return s;
@@ -6159,18 +6458,18 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
6159
6458
  vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() };
6160
6459
  ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {});
6161
6460
 
6162
- subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
6163
- subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
6164
- subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
6461
+ subctx->s->buffer->buf.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants));
6462
+ subctx->s->buffer->buf.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline);
6463
+ subctx->s->buffer->buf.bindDescriptorSets(vk::PipelineBindPoint::eCompute,
6165
6464
  pipeline->layout,
6166
6465
  0,
6167
6466
  { descriptor_set },
6168
6467
  {});
6169
- subctx->s->buffer.dispatch(wg0, wg1, wg2);
6468
+ subctx->s->buffer->buf.dispatch(wg0, wg1, wg2);
6170
6469
  }
6171
6470
 
6172
6471
  static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
6173
- s.buffer.end();
6472
+ s.buffer->buf.end();
6174
6473
 
6175
6474
  s.wait_semaphores = std::move(wait_semaphores);
6176
6475
  s.signal_semaphores = std::move(signal_semaphores);
@@ -6182,7 +6481,7 @@ static void ggml_vk_ctx_end(vk_context& ctx) {
6182
6481
  return;
6183
6482
  }
6184
6483
 
6185
- ctx->s->buffer.end();
6484
+ ctx->s->buffer->buf.end();
6186
6485
  ctx->s = nullptr;
6187
6486
  }
6188
6487
 
@@ -6196,6 +6495,47 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
6196
6495
  subctx->s = subctx->seqs[subctx->seqs.size() - 1].data();
6197
6496
  }
6198
6497
 
6498
+ static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {
6499
+ if (!ctx->compute_ctx.expired()) {
6500
+ return ctx->compute_ctx.lock();
6501
+ }
6502
+
6503
+ vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
6504
+
6505
+ ctx->compute_ctx = result;
6506
+ ggml_vk_ctx_begin(ctx->device, result);
6507
+
6508
+ if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
6509
+ result->s->wait_semaphores.push_back(ctx->transfer_semaphore);
6510
+ ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
6511
+ }
6512
+
6513
+ return result;
6514
+ }
6515
+
6516
+ // Submit any pending transfer queue work and signal the transfer semaphore.
6517
+ // The next compute context created via ggml_vk_get_compute_ctx will wait on this semaphore.
6518
+ // Returns true if work was submitted.
6519
+ static bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx) {
6520
+ if (!ctx->device->async_use_transfer_queue || ctx->transfer_ctx.expired()) {
6521
+ return false;
6522
+ }
6523
+
6524
+ vk_context cpy_ctx = ctx->transfer_ctx.lock();
6525
+ ggml_vk_ctx_end(cpy_ctx);
6526
+
6527
+ for (auto& cpy : cpy_ctx->in_memcpys) {
6528
+ memcpy(cpy.dst, cpy.src, cpy.n);
6529
+ }
6530
+
6531
+ ctx->transfer_semaphore.value++;
6532
+ cpy_ctx->seqs.back().back().signal_semaphores.push_back(ctx->transfer_semaphore);
6533
+
6534
+ ggml_vk_submit(cpy_ctx, {});
6535
+ ctx->transfer_ctx.reset();
6536
+ return true;
6537
+ }
6538
+
6199
6539
  static size_t ggml_vk_align_size(size_t width, size_t align) {
6200
6540
  VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")");
6201
6541
  return CEIL_DIV(width, align) * align;
@@ -6295,7 +6635,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
6295
6635
  }
6296
6636
 
6297
6637
  ggml_vk_sync_buffers(ctx, subctx);
6298
- subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
6638
+ subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);
6299
6639
  return;
6300
6640
  }
6301
6641
 
@@ -6310,7 +6650,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
6310
6650
  VkBufferCopy buf_copy{ 0, offset, copy_size };
6311
6651
 
6312
6652
  ggml_vk_sync_buffers(ctx, subctx);
6313
- vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
6653
+ vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
6314
6654
 
6315
6655
  for (uint64_t i3 = 0; i3 < ne3; i3++) {
6316
6656
  for (uint64_t i2 = 0; i2 < ne2; i2++) {
@@ -6359,7 +6699,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
6359
6699
  }
6360
6700
 
6361
6701
  ggml_vk_sync_buffers(nullptr, subctx);
6362
- subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices);
6702
+ subctx->s->buffer->buf.copyBuffer(buf->buffer, dst->buffer, slices);
6363
6703
  return true;
6364
6704
  }
6365
6705
  VK_LOG_DEBUG("STAGING");
@@ -6381,7 +6721,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
6381
6721
  copy_size};
6382
6722
 
6383
6723
  ggml_vk_sync_buffers(nullptr, subctx);
6384
- vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
6724
+ vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
6385
6725
 
6386
6726
  if (width == spitch) {
6387
6727
  deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
@@ -6467,7 +6807,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
6467
6807
  if (buf != nullptr) {
6468
6808
  // Memory is pinned, use as staging buffer
6469
6809
  ggml_vk_sync_buffers(nullptr, subctx);
6470
- subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices);
6810
+ subctx->s->buffer->buf.copyBuffer(src->buffer, buf->buffer, slices);
6471
6811
 
6472
6812
  return true;
6473
6813
  }
@@ -6485,7 +6825,7 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
6485
6825
  vk_buffer& staging_buffer = src->device->sync_staging;
6486
6826
 
6487
6827
  ggml_vk_sync_buffers(nullptr, subctx);
6488
- subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices);
6828
+ subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices);
6489
6829
 
6490
6830
  deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
6491
6831
  return true;
@@ -6532,7 +6872,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
6532
6872
 
6533
6873
  VkBufferCopy bc{ src_offset, dst_offset, size };
6534
6874
 
6535
- vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
6875
+ vkCmdCopyBuffer(ctx->s->buffer->buf, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
6536
6876
  }
6537
6877
 
6538
6878
  static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
@@ -6570,7 +6910,7 @@ static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t
6570
6910
  }
6571
6911
 
6572
6912
  // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers
6573
- ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
6913
+ ctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);
6574
6914
  }
6575
6915
 
6576
6916
  static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) {
@@ -6585,7 +6925,7 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
6585
6925
  std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
6586
6926
  vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
6587
6927
  ggml_vk_ctx_begin(dst->device, subctx);
6588
- subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c);
6928
+ subctx->s->buffer->buf.fillBuffer(dst->buffer, offset, size, c);
6589
6929
  ggml_vk_ctx_end(subctx);
6590
6930
 
6591
6931
  ggml_vk_submit(subctx, dst->device->fence);
@@ -6691,8 +7031,16 @@ static void ggml_vk_matmul(
6691
7031
  uint32_t padded_n) {
6692
7032
  VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
6693
7033
  if (split_k == 1) {
6694
- const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
6695
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch });
7034
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
7035
+
7036
+ uint32_t base_work_group_z = 0;
7037
+ while (base_work_group_z < batch) {
7038
+ uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
7039
+
7040
+ const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n };
7041
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z });
7042
+ base_work_group_z += groups_z;
7043
+ }
6696
7044
  return;
6697
7045
  }
6698
7046
 
@@ -6706,9 +7054,17 @@ static void ggml_vk_matmul(
6706
7054
  uint32_t k_split = CEIL_DIV(k, split_k);
6707
7055
  k_split = ROUNDUP_POW2(k_split, 256);
6708
7056
 
6709
- const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
6710
- // Make sure enough workgroups get assigned for split k to work
6711
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
7057
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2]));
7058
+
7059
+ uint32_t base_work_group_z = 0;
7060
+ while (base_work_group_z < batch) {
7061
+ uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
7062
+
7063
+ const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n };
7064
+ // Make sure enough workgroups get assigned for split k to work
7065
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z });
7066
+ base_work_group_z += groups_z;
7067
+ }
6712
7068
  ggml_vk_sync_buffers(ctx, subctx);
6713
7069
  const std::array<uint32_t, 2> pc2 = { (uint32_t)(m * n * batch), split_k };
6714
7070
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 });
@@ -7104,7 +7460,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
7104
7460
  }
7105
7461
 
7106
7462
  // Request descriptor sets
7107
- ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
7108
7463
  if (qx_needs_dequant) {
7109
7464
  ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1);
7110
7465
  }
@@ -7274,6 +7629,18 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
7274
7629
  return false;
7275
7630
  }
7276
7631
 
7632
+ if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
7633
+ // Intel Windows proprietary driver tuning
7634
+ switch (src0_type) {
7635
+ case GGML_TYPE_MXFP4:
7636
+ case GGML_TYPE_Q4_K:
7637
+ case GGML_TYPE_Q5_K:
7638
+ return false;
7639
+ default:
7640
+ return true;
7641
+ }
7642
+ }
7643
+
7277
7644
  switch (src0_type) {
7278
7645
  // From tests on A770 Linux, may need more tuning
7279
7646
  case GGML_TYPE_Q4_0:
@@ -7402,7 +7769,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
7402
7769
  if (quantize_y) {
7403
7770
  ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
7404
7771
  }
7405
- ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
7406
7772
  }
7407
7773
 
7408
7774
  vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -7497,22 +7863,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
7497
7863
  fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
7498
7864
  }
7499
7865
 
7500
- // compute
7501
- const vk_mat_vec_push_constants pc = {
7502
- (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
7503
- stride_batch_x, stride_batch_y, stride_batch_d,
7504
- fusion_flags,
7505
- (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
7506
- };
7507
- ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
7508
- {
7509
- d_X,
7510
- d_Y,
7511
- d_D,
7512
- d_F0,
7513
- d_F1,
7514
- },
7515
- pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z });
7866
+ ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1]));
7867
+
7868
+ uint32_t base_work_group_y = 0;
7869
+ while (base_work_group_y < ne12 * ne13) {
7870
+
7871
+ uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
7872
+ const vk_mat_vec_push_constants pc = {
7873
+ (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
7874
+ stride_batch_x, stride_batch_y, stride_batch_d,
7875
+ fusion_flags, base_work_group_y,
7876
+ (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
7877
+ };
7878
+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
7879
+ {
7880
+ d_X,
7881
+ d_Y,
7882
+ d_D,
7883
+ d_F0,
7884
+ d_F1,
7885
+ },
7886
+ pc, { groups_x, groups_y, groups_z });
7887
+ base_work_group_y += groups_y;
7888
+ }
7516
7889
 
7517
7890
  if (x_non_contig) {
7518
7891
  ctx->prealloc_x_need_sync = true;
@@ -7750,10 +8123,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
7750
8123
  src1->nb[2] <= src1->nb[1] &&
7751
8124
  src1->nb[1] <= src1->nb[3] &&
7752
8125
  src0->ne[3] == 1 &&
7753
- src1->ne[3] == 1) {
8126
+ src1->ne[3] == 1 &&
8127
+ src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
8128
+ src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
7754
8129
  ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx);
7755
8130
  } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
7756
- !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
8131
+ !ggml_is_permuted(src0) && !ggml_is_permuted(src1) &&
8132
+ src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] &&
8133
+ src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] &&
8134
+ src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) {
7757
8135
  ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx);
7758
8136
  // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
7759
8137
  // when ne12 and ne13 are one.
@@ -8083,8 +8461,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
8083
8461
 
8084
8462
  const uint64_t nei0 = ids->ne[0];
8085
8463
  const uint64_t nei1 = ids->ne[1];
8086
-
8087
- GGML_ASSERT(nei1 == 1);
8464
+ const uint32_t nbi1 = (uint32_t)(ids->nb[1] / sizeof(int));
8088
8465
 
8089
8466
  const uint64_t ne20 = dst->ne[0];
8090
8467
  const uint64_t ne21 = dst->ne[1];
@@ -8168,7 +8545,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
8168
8545
  if (quantize_y) {
8169
8546
  ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
8170
8547
  }
8171
- ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
8548
+ ggml_pipeline_request_descriptor_sets(ctx, dmmv, nei1);
8172
8549
  }
8173
8550
 
8174
8551
  vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
@@ -8226,7 +8603,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
8226
8603
  uint32_t stride_batch_y = ne10*ne11;
8227
8604
 
8228
8605
  if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
8229
- stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
8606
+ stride_batch_y = src1->nb[2] / ggml_type_size(src1->type);
8230
8607
  }
8231
8608
 
8232
8609
  const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
@@ -8262,23 +8639,25 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
8262
8639
  fusion_flags |= MAT_VEC_FUSION_FLAGS_SCALE1;
8263
8640
  }
8264
8641
 
8265
- // compute
8266
- const vk_mat_vec_id_push_constants pc = {
8267
- (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
8268
- (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
8269
- fusion_flags,
8270
- (uint32_t)nei0, (uint32_t)ne11,
8271
- };
8272
- ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
8273
- {
8274
- d_X,
8275
- d_Y,
8276
- d_D,
8277
- d_F0,
8278
- d_F1,
8279
- d_ids,
8280
- },
8281
- pc, { groups_x, (uint32_t)nei0, groups_z });
8642
+ // Loop over the batch dimension
8643
+ for (uint32_t expert_i1 = 0; expert_i1 < nei1; ++expert_i1) {
8644
+ const vk_mat_vec_id_push_constants pc = {
8645
+ (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
8646
+ (uint32_t)(ne00 * ne01), stride_batch_y, (uint32_t)(ne20 * ne21),
8647
+ fusion_flags,
8648
+ (uint32_t)nei0, (uint32_t)ne11, expert_i1, nbi1
8649
+ };
8650
+ ggml_vk_dispatch_pipeline(ctx, subctx, dmmv,
8651
+ {
8652
+ d_X,
8653
+ d_Y,
8654
+ d_D,
8655
+ d_F0,
8656
+ d_F1,
8657
+ d_ids,
8658
+ },
8659
+ pc, { groups_x, (uint32_t)nei0, groups_z });
8660
+ }
8282
8661
 
8283
8662
  if (x_non_contig) {
8284
8663
  ctx->prealloc_x_need_sync = true;
@@ -8292,7 +8671,7 @@ static bool ggml_vk_use_mul_mat_vec_id(const struct ggml_cgraph * cgraph, int no
8292
8671
  ggml_tensor * dst = cgraph->nodes[node_idx];
8293
8672
  ggml_tensor * src0 = dst->src[0];
8294
8673
  ggml_tensor * src2 = dst->src[2];
8295
- return src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
8674
+ return (src2->ne[1] <= 8) && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type));
8296
8675
  }
8297
8676
 
8298
8677
  static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -8308,55 +8687,70 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
8308
8687
  }
8309
8688
  }
8310
8689
 
8311
- static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) {
8690
+ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
8691
+ GGML_UNUSED(f32acc);
8312
8692
  // Needs to be kept up to date on shader changes
8313
- GGML_UNUSED(hsv);
8314
- const uint32_t wg_size = scalar_flash_attention_workgroup_size;
8315
- const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache);
8316
- const uint32_t Bc = scalar_flash_attention_Bc;
8693
+ const uint32_t wg_size = params.workgroup_size;
8694
+ const uint32_t Br = params.block_rows;
8695
+ const uint32_t Bc = params.block_cols;
8317
8696
 
8697
+ const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
8698
+
8699
+ // tmpsh is overestimated slightly
8318
8700
  const uint32_t tmpsh = wg_size * sizeof(float);
8319
- const uint32_t tmpshv4 = wg_size * 4 * sizeof(float);
8701
+ const uint32_t tmpshv4 = wg_size * 4 * float_type_size;
8702
+
8703
+ const uint32_t masksh = Bc * (Br + 1) * float_type_size;
8320
8704
 
8321
- const uint32_t masksh = Bc * Br * sizeof(float);
8705
+ const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
8322
8706
 
8323
- const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float);
8707
+ const uint32_t D = std::max(hsk, hsv);
8708
+ const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
8324
8709
 
8325
- const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf;
8710
+ const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
8326
8711
  const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
8327
8712
 
8328
- VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
8713
+ VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
8329
8714
 
8330
8715
  return supported;
8331
8716
  }
8332
8717
 
8333
- static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) {
8718
+ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
8334
8719
  // Needs to be kept up to date on shader changes
8335
- GGML_UNUSED(hsv);
8336
- const uint32_t wg_size = scalar_flash_attention_workgroup_size;
8337
- const uint32_t Br = coopmat1_flash_attention_num_large_rows;
8338
- const uint32_t Bc = scalar_flash_attention_Bc;
8720
+ const uint32_t Br = params.block_rows;
8721
+ const uint32_t Bc = params.block_cols;
8722
+
8723
+ const uint32_t MatBr = 16, MatBc = 16;
8724
+
8725
+ const uint32_t row_split = Bc / MatBc;
8339
8726
 
8340
8727
  const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16);
8728
+ const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16);
8341
8729
 
8342
8730
  const uint32_t acctype = f32acc ? 4 : 2;
8343
8731
  const uint32_t f16vec4 = 8;
8344
8732
 
8345
- const uint32_t tmpsh = wg_size * sizeof(float);
8346
- const uint32_t tmpshv4 = wg_size * 4 * acctype;
8733
+ const uint32_t tmpsh = (Bc / MatBc) * sizeof(float);
8347
8734
 
8348
8735
  const uint32_t qstride = hsk_pad / 4 + 2;
8349
8736
  const uint32_t Qf = Br * qstride * f16vec4;
8350
8737
 
8738
+ const uint32_t psh_stride = Br / 4 + 2;
8739
+ const uint32_t Psh = Bc * psh_stride * f16vec4;
8740
+
8351
8741
  const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br;
8352
8742
  const uint32_t sfsh = Bc * sfshstride * acctype;
8353
8743
 
8354
- const uint32_t kshstride = hsk_pad / 4 + 2;
8355
- const uint32_t ksh = Bc * kshstride * f16vec4;
8744
+ const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2;
8745
+ const uint32_t vsh_stride = MatBc / 4 * row_split;
8746
+ const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
8356
8747
 
8357
- const uint32_t slope = Br * sizeof(float);
8748
+ const uint32_t osh_stride = params.row_split * MatBr / 4;
8749
+ const uint32_t pvsh = MatBc * osh_stride * f16vec4;
8358
8750
 
8359
- const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
8751
+ const uint32_t slope = Br * acctype;
8752
+
8753
+ const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope;
8360
8754
  const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
8361
8755
 
8362
8756
  VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
@@ -8383,6 +8777,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8383
8777
  GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
8384
8778
  GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
8385
8779
 
8780
+ const uint32_t nem0 = mask ? mask->ne[0] : 0;
8386
8781
  const uint32_t nem1 = mask ? mask->ne[1] : 0;
8387
8782
  const uint32_t nem2 = mask ? mask->ne[2] : 0;
8388
8783
  const uint32_t nem3 = mask ? mask->ne[3] : 0;
@@ -8416,72 +8811,30 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8416
8811
  assert(q->type == GGML_TYPE_F32);
8417
8812
  assert(k->type == v->type);
8418
8813
 
8419
- FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
8420
- ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
8421
-
8422
- if (path == FA_COOPMAT1) {
8423
- const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
8424
- (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
8425
-
8426
- const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32);
8427
-
8428
- if (!coopmat_shape_supported || !coopmat_shmem_supported) {
8429
- path = FA_SCALAR;
8430
- }
8431
- }
8432
-
8433
8814
  uint32_t gqa_ratio = 1;
8434
8815
  uint32_t qk_ratio = neq2 / nek2;
8435
8816
  uint32_t workgroups_x = (uint32_t)neq1;
8436
8817
  uint32_t workgroups_y = (uint32_t)neq2;
8437
8818
  uint32_t workgroups_z = (uint32_t)neq3;
8438
8819
 
8439
- const bool small_cache = nek1 < 1024;
8820
+ const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
8440
8821
 
8441
8822
  // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
8442
8823
  // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
8443
- uint32_t max_gqa;
8444
- switch (path) {
8445
- case FA_SCALAR:
8446
- case FA_COOPMAT1:
8447
- // We may switch from coopmat1 to scalar, so use the scalar limit for both
8448
- max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache);
8449
- break;
8450
- case FA_COOPMAT2:
8451
- max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
8452
- break;
8453
- default:
8454
- GGML_ASSERT(0);
8455
- }
8824
+ vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
8825
+ const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
8456
8826
 
8457
- if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
8827
+ if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
8458
8828
  qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) {
8459
8829
  // grouped query attention - make the N dimension equal to gqa_ratio, reduce
8460
8830
  // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
8461
8831
  // and change addressing calculations to index Q's dimension 2.
8462
8832
  gqa_ratio = qk_ratio;
8463
8833
  N = gqa_ratio;
8464
- workgroups_y /= N;
8465
- }
8466
-
8467
- bool small_rows = N <= get_fa_num_small_rows(path);
8468
-
8469
- // coopmat1 does not actually support "small rows" (it needs 16 rows).
8470
- // So use scalar instead.
8471
- if (small_rows && path == FA_COOPMAT1) {
8472
- path = FA_SCALAR;
8834
+ workgroups_y /= gqa_ratio;
8473
8835
  }
8474
8836
 
8475
- // scalar is faster than coopmat2 when N==1
8476
- if (N == 1 && path == FA_COOPMAT2) {
8477
- path = FA_SCALAR;
8478
- }
8479
-
8480
- // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory
8481
- if (path == FA_SCALAR &&
8482
- !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) {
8483
- small_rows = true;
8484
- }
8837
+ tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
8485
8838
 
8486
8839
  const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
8487
8840
  uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
@@ -8495,19 +8848,32 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8495
8848
  v_stride /= 4;
8496
8849
  }
8497
8850
 
8498
- uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache);
8851
+ const uint32_t alignment = tuning_params.block_cols;
8499
8852
  bool aligned = (KV % alignment) == 0 &&
8500
8853
  // the "aligned" shader variant will forcibly align strides, for performance
8501
8854
  (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
8502
8855
 
8503
8856
  // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned.
8504
- if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) {
8857
+ if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) {
8505
8858
  aligned = false;
8506
8859
  }
8507
8860
 
8508
- bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
8861
+ float scale = 1.0f;
8862
+ float max_bias = 0.0f;
8863
+ float logit_softcap = 0.0f;
8509
8864
 
8510
- vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc);
8865
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
8866
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
8867
+ memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
8868
+
8869
+ if (logit_softcap != 0) {
8870
+ scale /= logit_softcap;
8871
+ }
8872
+
8873
+ // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
8874
+ bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
8875
+ vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
8876
+ mask != nullptr, use_mask_opt, logit_softcap != 0);
8511
8877
 
8512
8878
  vk_pipeline pipeline = nullptr;
8513
8879
 
@@ -8523,29 +8889,46 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8523
8889
  }
8524
8890
 
8525
8891
  assert(pipeline);
8892
+ // Compile early to initialize wg_denoms.
8893
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
8526
8894
 
8527
8895
  uint32_t split_kv = KV;
8528
8896
  uint32_t split_k = 1;
8529
8897
 
8898
+ // Intel Alchemist prefers more workgroups
8899
+ const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1;
8900
+
8530
8901
  // Use a placeholder core count if one isn't available. split_k is a big help for perf.
8531
- const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
8902
+ const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16;
8532
8903
 
8533
- // Try to use split_k when KV is large enough to be worth the overhead
8534
- if (workgroups_x == 1 && shader_core_count > 0) {
8535
- // Try to run two workgroups per SM.
8536
- split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
8537
- if (split_k > 1) {
8538
- // Try to evenly split KV into split_k chunks, but it needs to be a multiple
8539
- // of "align", so recompute split_k based on that.
8540
- split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
8541
- split_k = CEIL_DIV(KV, split_kv);
8542
- workgroups_x = split_k;
8904
+ const uint32_t Br = fa_pipeline_state.Br;
8905
+ const uint32_t Bc = fa_pipeline_state.Bc;
8906
+
8907
+ GGML_ASSERT(Br == pipeline->wg_denoms[0]);
8908
+ const uint32_t Tr = CEIL_DIV(N, Br);
8909
+
8910
+ // Try to use split_k when KV is large enough to be worth the overhead.
8911
+ if (gqa_ratio > 1 && workgroups_x <= Br) {
8912
+ split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z);
8913
+ } else if (gqa_ratio <= 1) {
8914
+ uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z;
8915
+ if (total_wgs_no_split < shader_core_count * 2) {
8916
+ split_k = shader_core_count * 2 / total_wgs_no_split;
8543
8917
  }
8544
8918
  }
8545
8919
 
8920
+ if (split_k > 1) {
8921
+ // Try to evenly split KV into split_k chunks, but it needs to be a multiple
8922
+ // of "align", so recompute split_k based on that.
8923
+ split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment);
8924
+ split_k = CEIL_DIV(KV, split_kv);
8925
+ }
8926
+
8546
8927
  // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
8547
8928
  // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
8548
- const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
8929
+ // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3].
8930
+ // For L/M, the order is (inner to outer) [ne1, k, ne2, ne3].
8931
+ const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne2 * ne3 : 0;
8549
8932
  if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) {
8550
8933
  GGML_ABORT("Requested preallocation size is too large");
8551
8934
  }
@@ -8554,24 +8937,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8554
8937
  ggml_vk_preallocate_buffers(ctx, subctx);
8555
8938
  }
8556
8939
 
8557
- {
8558
- // Request descriptor sets
8559
- ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
8560
- if (split_k > 1) {
8561
- ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
8562
- }
8563
- }
8564
-
8565
- float scale = 1.0f;
8566
- float max_bias = 0.0f;
8567
- float logit_softcap = 0.0f;
8940
+ const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc);
8941
+ const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3;
8568
8942
 
8569
- memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
8570
- memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
8571
- memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
8943
+ vk_pipeline pipeline_fa_mask_opt = nullptr;
8944
+ if (use_mask_opt) {
8945
+ std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
8946
+ auto &pipelines = ctx->device->pipeline_fa_mask_opt;
8947
+ auto it = pipelines.find({Br, Bc});
8948
+ if (it != pipelines.end()) {
8949
+ pipeline_fa_mask_opt = it->second;
8950
+ } else {
8951
+ pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
8952
+ }
8953
+ assert(pipeline_fa_mask_opt);
8954
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
8572
8955
 
8573
- if (logit_softcap != 0) {
8574
- scale /= logit_softcap;
8956
+ if (ctx->prealloc_size_y < mask_opt_size) {
8957
+ ctx->prealloc_size_y = mask_opt_size;
8958
+ ggml_vk_preallocate_buffers(ctx, subctx);
8959
+ }
8960
+ if (ctx->prealloc_y_need_sync) {
8961
+ ggml_vk_sync_buffers(ctx, subctx);
8962
+ }
8575
8963
  }
8576
8964
 
8577
8965
  const uint32_t n_head_kv = neq2;
@@ -8585,8 +8973,29 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8585
8973
  vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
8586
8974
  vk_subbuffer mask_buf = mask ? ggml_vk_tensor_subbuffer(ctx, mask) : q_buf;
8587
8975
  vk_subbuffer sinks_buf = sinks ? ggml_vk_tensor_subbuffer(ctx, sinks) : q_buf;
8976
+ vk_subbuffer mask_opt_buf = use_mask_opt ? ggml_vk_subbuffer(ctx, ctx->prealloc_y, 0) : q_buf;
8977
+
8978
+ uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | n_head_log2;
8979
+
8980
+ if (use_mask_opt)
8981
+ {
8982
+ const vk_op_flash_attn_mask_opt_push_constants opt_pc = {
8983
+ nem0,
8984
+ nem1,
8985
+ nem2,
8986
+ (uint32_t)(mask->nb[1] / sizeof(ggml_fp16_t)),
8987
+ (uint32_t)(mask->nb[2] / sizeof(ggml_fp16_t)),
8988
+ (uint32_t)(mask->nb[3] / sizeof(ggml_fp16_t)),
8989
+ mask_opt_num_dwords,
8990
+ mask_opt_num_dwords * CEIL_DIV(nem1, Br),
8991
+ mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2,
8992
+ };
8588
8993
 
8589
- uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2;
8994
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline_fa_mask_opt,
8995
+ { mask_buf, mask_opt_buf }, opt_pc,
8996
+ { mask_opt_num_dwords, CEIL_DIV(nem1, Br), nem2 * nem3 });
8997
+ ggml_vk_sync_buffers(ctx, subctx);
8998
+ }
8590
8999
 
8591
9000
  const vk_flash_attn_push_constants pc = { N, KV,
8592
9001
  (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
@@ -8602,28 +9011,40 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8602
9011
  gqa_ratio, split_kv, split_k };
8603
9012
 
8604
9013
  if (split_k > 1) {
9014
+ ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
9015
+
8605
9016
  if (ctx->prealloc_split_k_need_sync) {
8606
9017
  ggml_vk_sync_buffers(ctx, subctx);
8607
9018
  }
8608
9019
 
9020
+ // We reuse workgroups_x to mean the number of splits, so we need to
9021
+ // cancel out the divide by wg_denoms[0].
9022
+ uint32_t dispatch_x;
9023
+ if (gqa_ratio > 1) {
9024
+ workgroups_x *= pipeline->wg_denoms[0];
9025
+ dispatch_x = split_k * workgroups_x;
9026
+ } else {
9027
+ dispatch_x = Tr * split_k * pipeline->wg_denoms[0];
9028
+ }
9029
+
8609
9030
  vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);
8610
9031
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
8611
- {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf},
8612
- // We only use split_k when group query attention is enabled, which means
8613
- // there's no more than one tile of rows (i.e. workgroups_x would have been
8614
- // one). We reuse workgroups_x to mean the number of splits, so we need to
8615
- // cancel out the divide by wg_denoms[0].
8616
- pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
9032
+ {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf},
9033
+ pc, { dispatch_x, workgroups_y, workgroups_z });
8617
9034
 
8618
9035
  ggml_vk_sync_buffers(ctx, subctx);
8619
- const std::array<uint32_t, 5> pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) };
9036
+ const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) };
8620
9037
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
8621
9038
  {split_k_buf, sinks_buf, dst_buf},
8622
- pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 });
9039
+ pc2, { (uint32_t)ne1, HSV, (uint32_t)(ne2 * ne3) });
8623
9040
  ctx->prealloc_split_k_need_sync = true;
8624
9041
  } else {
9042
+ if (gqa_ratio > 1) {
9043
+ // When using gqa, we want one actual workgroup per batch, so cancel out wg_denoms
9044
+ workgroups_x *= pipeline->wg_denoms[0];
9045
+ }
8625
9046
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
8626
- {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf},
9047
+ {q_buf, k_buf, v_buf, mask_buf, sinks_buf, dst_buf, mask_opt_buf},
8627
9048
  pc, { workgroups_x, workgroups_y, workgroups_z });
8628
9049
  }
8629
9050
  }
@@ -8668,6 +9089,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
8668
9089
  return ctx->device->pipeline_acc_f32;
8669
9090
  }
8670
9091
  return nullptr;
9092
+ case GGML_OP_SET:
9093
+ if (src0->type == src1->type && src0->type == dst->type &&
9094
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
9095
+ return ctx->device->pipeline_set_f32;
9096
+ }
9097
+ return nullptr;
8671
9098
  case GGML_OP_ADD:
8672
9099
  case GGML_OP_SUB:
8673
9100
  case GGML_OP_MUL:
@@ -8869,6 +9296,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
8869
9296
  switch (ggml_get_unary_op(dst)) {
8870
9297
  case GGML_UNARY_OP_EXP:
8871
9298
  return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16];
9299
+ case GGML_UNARY_OP_ELU:
9300
+ return ctx->device->pipeline_elu[dst->type == GGML_TYPE_F16];
8872
9301
  case GGML_UNARY_OP_SILU:
8873
9302
  return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
8874
9303
  case GGML_UNARY_OP_GELU:
@@ -8905,6 +9334,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
8905
9334
  return ctx->device->pipeline_floor[dst->type == GGML_TYPE_F16];
8906
9335
  case GGML_UNARY_OP_TRUNC:
8907
9336
  return ctx->device->pipeline_trunc[dst->type == GGML_TYPE_F16];
9337
+ case GGML_UNARY_OP_SGN:
9338
+ return ctx->device->pipeline_sgn[dst->type == GGML_TYPE_F16];
8908
9339
  default:
8909
9340
  break;
8910
9341
  }
@@ -9098,6 +9529,20 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
9098
9529
  return ctx->device->pipeline_rwkv_wkv7_f32;
9099
9530
  }
9100
9531
  return nullptr;
9532
+ case GGML_OP_GATED_DELTA_NET:
9533
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9534
+ const uint32_t S_v = dst->src[2]->ne[0];
9535
+ const uint32_t kda = (dst->src[3]->ne[0] == (int64_t)S_v) ? 1 : 0;
9536
+ uint32_t si;
9537
+ switch (S_v) {
9538
+ case 32: si = 0; break;
9539
+ case 64: si = 1; break;
9540
+ case 128: si = 2; break;
9541
+ default: return nullptr;
9542
+ }
9543
+ return ctx->device->pipeline_gated_delta_net[si][kda];
9544
+ }
9545
+ return nullptr;
9101
9546
  case GGML_OP_SSM_SCAN:
9102
9547
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9103
9548
  const uint32_t d_state = src0->ne[0];
@@ -9654,16 +10099,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
9654
10099
  const uint32_t src1_type_size = ggml_type_size(src1->type);
9655
10100
  const uint32_t dst_type_size = ggml_type_size(dst->type);
9656
10101
 
9657
- int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
9658
- int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
9659
- // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
9660
- int offset = dst->op_params[3] / 4; // offset in bytes
10102
+ int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32
10103
+ int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32
10104
+ int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
10105
+ int offset = dst->op_params[3] / src0_type_size; // offset in bytes
9661
10106
 
9662
- ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
10107
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
9663
10108
  (uint32_t)ggml_nelements(src0),
9664
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
10109
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
9665
10110
  (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
9666
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
10111
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
9667
10112
  0,
9668
10113
  0.0f, 0.0f, offset,
9669
10114
  });
@@ -9928,6 +10373,59 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
9928
10373
  );
9929
10374
  }
9930
10375
 
10376
+ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10377
+ const ggml_tensor * src_q = dst->src[0];
10378
+ const ggml_tensor * src_v = dst->src[2];
10379
+ const ggml_tensor * src_beta = dst->src[4];
10380
+
10381
+ GGML_ASSERT(dst->buffer != nullptr);
10382
+
10383
+ const uint32_t S_v = (uint32_t)src_v->ne[0];
10384
+ const uint32_t H = (uint32_t)src_v->ne[1];
10385
+ const uint32_t n_tokens = (uint32_t)src_v->ne[2];
10386
+ const uint32_t n_seqs = (uint32_t)src_v->ne[3];
10387
+
10388
+ const uint32_t s_off = S_v * H * n_tokens * n_seqs;
10389
+
10390
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
10391
+ GGML_ASSERT(pipeline != nullptr);
10392
+
10393
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
10394
+
10395
+ vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
10396
+ vk_subbuffer src_buf[6] = {};
10397
+ for (int i = 0; i < 6; i++) {
10398
+ src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
10399
+ }
10400
+
10401
+ const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float));
10402
+ const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float));
10403
+ const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float));
10404
+ const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float));
10405
+ const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float));
10406
+ const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float));
10407
+ const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float));
10408
+ const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float));
10409
+ const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float));
10410
+
10411
+ const uint32_t neq1 = (uint32_t)src_q->ne[1];
10412
+ const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]);
10413
+
10414
+ const float scale = 1.0f / sqrtf((float)S_v);
10415
+ const vk_op_gated_delta_net_push_constants pc = {
10416
+ H, n_tokens, n_seqs, s_off,
10417
+ sq1, sq2, sq3,
10418
+ sv1, sv2, sv3,
10419
+ sb1, sb2, sb3,
10420
+ neq1, rq3,
10421
+ scale
10422
+ };
10423
+
10424
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
10425
+ {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
10426
+ pc, { H, n_seqs, 1u });
10427
+ }
10428
+
9931
10429
  static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
9932
10430
  const ggml_tensor * src0 = dst->src[0];
9933
10431
  const ggml_tensor * src1 = dst->src[1];
@@ -10335,12 +10833,22 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
10335
10833
 
10336
10834
  uint32_t nb01 = src0->nb[1] / ggml_type_size(src0->type);
10337
10835
  uint32_t nb02 = src0->nb[2] / ggml_type_size(src0->type);
10836
+ uint32_t nb03 = src0->nb[3] / ggml_type_size(src0->type);
10837
+
10838
+ uint32_t nb11 = dst->nb[1] / ggml_type_size(dst->type);
10839
+ uint32_t nb12 = dst->nb[2] / ggml_type_size(dst->type);
10840
+ uint32_t nb13 = dst->nb[3] / ggml_type_size(dst->type);
10338
10841
 
10339
10842
  vk_op_rope_push_constants rope {
10340
- (uint32_t)mode, (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
10341
- freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
10342
- has_ff, (uint32_t)src0->ne[2], nb01, nb02,
10843
+ (uint32_t)mode, (uint32_t)ggml_nrows(src0), (uint32_t)n_dims, freq_scale,
10844
+ freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, has_ff,
10343
10845
  { sections[0], sections[1], sections[2], sections[3] }, is_imrope, backprop, set_rows_stride,
10846
+
10847
+ (uint32_t)src0->ne[0],
10848
+ (uint32_t)src0->ne[1],
10849
+ (uint32_t)src0->ne[2],
10850
+ nb01, nb02, nb03,
10851
+ nb11, nb12, nb13,
10344
10852
  };
10345
10853
 
10346
10854
  return rope;
@@ -10467,8 +10975,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
10467
10975
  }
10468
10976
 
10469
10977
  static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
10470
- float * op_params = (float *)dst->op_params;
10471
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f });
10978
+ const float * op_params = (const float *)dst->op_params;
10979
+ vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
10980
+ p.param1 = op_params[0];
10981
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p));
10472
10982
  }
10473
10983
 
10474
10984
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@@ -11386,7 +11896,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
11386
11896
  }
11387
11897
  }
11388
11898
 
11389
- ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
11390
11899
  if (split_k > 1) {
11391
11900
  ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
11392
11901
 
@@ -11560,7 +12069,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
11560
12069
  free(d_chk);
11561
12070
 
11562
12071
  ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
11563
- ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
11564
12072
 
11565
12073
  ggml_vk_destroy_buffer(d_X);
11566
12074
  ggml_vk_destroy_buffer(d_Y);
@@ -11896,7 +12404,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
11896
12404
  // y[i] = i % k;
11897
12405
  }
11898
12406
 
11899
- ggml_pipeline_request_descriptor_sets(ctx, p, num_it);
11900
12407
  if (split_k > 1) {
11901
12408
  ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it);
11902
12409
 
@@ -11909,7 +12416,8 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
11909
12416
  }
11910
12417
  }
11911
12418
  if (mmq) {
11912
- ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it);
12419
+ vk_pipeline pipeline_quantize_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
12420
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline_quantize_q8_1, num_it);
11913
12421
  }
11914
12422
 
11915
12423
  ggml_pipeline_allocate_descriptor_sets(ctx);
@@ -12145,7 +12653,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
12145
12653
  ggml_vk_submit(subctx, {});
12146
12654
  ctx->submit_pending = true;
12147
12655
  ggml_vk_synchronize(ctx);
12656
+ GGML_ASSERT(ctx->compute_ctx.expired());
12148
12657
  ggml_vk_ctx_begin(ctx->device, subctx);
12658
+ ctx->compute_ctx = subctx;
12149
12659
  }
12150
12660
 
12151
12661
  if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) {
@@ -12163,6 +12673,7 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
12163
12673
  ggml_vk_destroy_buffer(ctx->prealloc_y);
12164
12674
  }
12165
12675
  ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
12676
+ ctx->prealloc_y_last_tensor_used = nullptr;
12166
12677
  }
12167
12678
  if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
12168
12679
  VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
@@ -12191,6 +12702,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12191
12702
  if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
12192
12703
  return false;
12193
12704
  }
12705
+ if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
12706
+ return false;
12707
+ }
12194
12708
 
12195
12709
  VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
12196
12710
  ctx->semaphore_idx = 0;
@@ -12215,15 +12729,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12215
12729
  }
12216
12730
  }
12217
12731
 
12218
- vk_context compute_ctx;
12219
-
12220
- if (ctx->compute_ctx.expired()) {
12221
- compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
12222
- ctx->compute_ctx = compute_ctx;
12223
- ggml_vk_ctx_begin(ctx->device, compute_ctx);
12224
- } else {
12225
- compute_ctx = ctx->compute_ctx.lock();
12226
- }
12732
+ vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
12227
12733
 
12228
12734
  {
12229
12735
  // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers
@@ -12294,7 +12800,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12294
12800
 
12295
12801
  if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
12296
12802
  ctx->query_node_idx[ctx->query_idx] = node_idx;
12297
- compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
12803
+ compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
12298
12804
  }
12299
12805
  }
12300
12806
  // Add all fused nodes to the unsynchronized lists.
@@ -12337,6 +12843,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12337
12843
 
12338
12844
  break;
12339
12845
  case GGML_OP_ACC:
12846
+ case GGML_OP_SET:
12340
12847
  ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
12341
12848
 
12342
12849
  break;
@@ -12471,6 +12978,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12471
12978
  }
12472
12979
 
12473
12980
  switch (ggml_get_unary_op(node)) {
12981
+ case GGML_UNARY_OP_ELU:
12474
12982
  case GGML_UNARY_OP_EXP:
12475
12983
  case GGML_UNARY_OP_SILU:
12476
12984
  case GGML_UNARY_OP_GELU:
@@ -12489,6 +12997,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12489
12997
  case GGML_UNARY_OP_CEIL:
12490
12998
  case GGML_UNARY_OP_FLOOR:
12491
12999
  case GGML_UNARY_OP_TRUNC:
13000
+ case GGML_UNARY_OP_SGN:
12492
13001
  ggml_vk_unary(ctx, compute_ctx, src0, node);
12493
13002
  break;
12494
13003
  case GGML_UNARY_OP_XIELU:
@@ -12633,6 +13142,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12633
13142
 
12634
13143
  break;
12635
13144
 
13145
+ case GGML_OP_GATED_DELTA_NET:
13146
+ ggml_vk_gated_delta_net(ctx, compute_ctx, node);
13147
+
13148
+ break;
13149
+
12636
13150
  case GGML_OP_SSM_SCAN:
12637
13151
  ggml_vk_ssm_scan(ctx, compute_ctx, node);
12638
13152
 
@@ -12740,7 +13254,9 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
12740
13254
  ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false;
12741
13255
 
12742
13256
  ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool);
12743
- ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
13257
+ if (ctx->device->async_use_transfer_queue) {
13258
+ ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool);
13259
+ }
12744
13260
 
12745
13261
  for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) {
12746
13262
  ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s });
@@ -12769,7 +13285,7 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
12769
13285
  static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
12770
13286
  VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")");
12771
13287
  // discard any unsubmitted command buffers
12772
- ctx->transfer_ctx.reset();
13288
+ ctx->compute_ctx.reset();
12773
13289
  // wait for any pending command buffers to finish
12774
13290
  ggml_vk_synchronize(ctx);
12775
13291
 
@@ -12802,7 +13318,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
12802
13318
  ctx->descriptor_sets.clear();
12803
13319
 
12804
13320
  ctx->compute_cmd_pool.destroy(ctx->device->device);
12805
- ctx->transfer_cmd_pool.destroy(ctx->device->device);
13321
+ if (ctx->device->async_use_transfer_queue) {
13322
+ ctx->device->device.destroySemaphore(ctx->transfer_semaphore.s);
13323
+
13324
+ ctx->transfer_cmd_pool.destroy(ctx->device->device);
13325
+ }
12806
13326
  if (vk_perf_logger_enabled) {
12807
13327
  ctx->perf_logger->print_timings(true);
12808
13328
  }
@@ -12861,6 +13381,10 @@ static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, g
12861
13381
  ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
12862
13382
  vk_buffer buf = buf_ctx->dev_buffer;
12863
13383
 
13384
+ if (size == 0) {
13385
+ return;
13386
+ }
13387
+
12864
13388
  uint32_t val32 = (uint32_t)value * 0x01010101;
12865
13389
  ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size);
12866
13390
  }
@@ -12870,6 +13394,10 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
12870
13394
  ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
12871
13395
  vk_buffer buf = buf_ctx->dev_buffer;
12872
13396
 
13397
+ if (size == 0) {
13398
+ return;
13399
+ }
13400
+
12873
13401
  ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
12874
13402
  }
12875
13403
 
@@ -12877,12 +13405,20 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
12877
13405
  VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
12878
13406
  ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
12879
13407
 
13408
+ if (size == 0) {
13409
+ return;
13410
+ }
13411
+
12880
13412
  vk_buffer buf = buf_ctx->dev_buffer;
12881
13413
 
12882
13414
  ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
12883
13415
  }
12884
13416
 
12885
13417
  static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
13418
+ if (ggml_nbytes(src) == 0) {
13419
+ return true;
13420
+ }
13421
+
12886
13422
  if (ggml_backend_buffer_is_vk(src->buffer)) {
12887
13423
  ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
12888
13424
  ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
@@ -13072,36 +13608,44 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
13072
13608
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13073
13609
  GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
13074
13610
 
13611
+ if (size == 0) {
13612
+ return;
13613
+ }
13614
+
13075
13615
  ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
13076
13616
 
13077
- vk_context transfer_ctx;
13617
+ vk_context cpy_ctx;
13078
13618
 
13079
- if (ctx->transfer_ctx.expired()) {
13080
- // Initialize new transfer context
13081
- transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13082
- ctx->transfer_ctx = transfer_ctx;
13083
- ggml_vk_ctx_begin(ctx->device, transfer_ctx);
13619
+ if (ctx->device->async_use_transfer_queue) {
13620
+ if (ctx->transfer_ctx.expired()) {
13621
+ // Initialize new transfer context
13622
+ cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
13623
+ ctx->transfer_ctx = cpy_ctx;
13624
+ ggml_vk_ctx_begin(ctx->device, cpy_ctx);
13625
+ } else {
13626
+ cpy_ctx = ctx->transfer_ctx.lock();
13627
+ }
13084
13628
  } else {
13085
- transfer_ctx = ctx->transfer_ctx.lock();
13629
+ cpy_ctx = ggml_vk_get_compute_ctx(ctx);
13086
13630
  }
13087
13631
 
13088
13632
  vk_buffer buf = buf_ctx->dev_buffer;
13089
13633
 
13090
13634
  auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
13091
13635
 
13092
- bool ret = ggml_vk_buffer_write_async(transfer_ctx, buf, dst_offset, data, size);
13636
+ bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size);
13093
13637
 
13094
13638
  if (!ret) {
13095
13639
  ggml_vk_ensure_sync_staging_buffer(ctx, size);
13096
- ggml_vk_sync_buffers(nullptr, transfer_ctx);
13640
+ ggml_vk_sync_buffers(nullptr, cpy_ctx);
13097
13641
 
13098
13642
  vk::BufferCopy buffer_cpy;
13099
13643
  buffer_cpy.srcOffset = 0;
13100
13644
  buffer_cpy.dstOffset = dst_offset;
13101
13645
  buffer_cpy.size = size;
13102
13646
 
13103
- transfer_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
13104
- deferred_memcpy(ctx->sync_staging->ptr, data, size, &transfer_ctx->in_memcpys);
13647
+ cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
13648
+ deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys);
13105
13649
  ggml_vk_synchronize(ctx);
13106
13650
  }
13107
13651
  }
@@ -13111,101 +13655,156 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
13111
13655
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13112
13656
  GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
13113
13657
 
13114
- ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
13658
+ if (size == 0) {
13659
+ return;
13660
+ }
13115
13661
 
13116
- vk_context transfer_ctx;
13662
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context;
13117
13663
 
13118
- if (ctx->transfer_ctx.expired()) {
13119
- // Initialize new transfer context
13120
- transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13121
- ctx->transfer_ctx = transfer_ctx;
13122
- ggml_vk_ctx_begin(ctx->device, transfer_ctx);
13123
- } else {
13124
- transfer_ctx = ctx->transfer_ctx.lock();
13125
- }
13664
+ vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
13126
13665
 
13127
13666
  vk_buffer buf = buf_ctx->dev_buffer;
13128
13667
 
13129
13668
  auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
13130
- bool ret = ggml_vk_buffer_read_async(transfer_ctx, buf, src_offset, data, size);
13669
+ bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);
13131
13670
 
13132
13671
  // If that failed, copy synchronously through a staging buffer
13133
13672
  if (!ret) {
13134
13673
  ggml_vk_ensure_sync_staging_buffer(ctx, size);
13135
- ggml_vk_sync_buffers(nullptr, transfer_ctx);
13674
+ ggml_vk_sync_buffers(nullptr, compute_ctx);
13136
13675
 
13137
13676
  vk::BufferCopy buffer_cpy;
13138
13677
  buffer_cpy.srcOffset = src_offset;
13139
13678
  buffer_cpy.dstOffset = 0;
13140
13679
  buffer_cpy.size = size;
13141
13680
 
13142
- transfer_ctx->s->buffer.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
13143
- deferred_memcpy(data, ctx->sync_staging->ptr, size, &transfer_ctx->out_memcpys);
13681
+ compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
13682
+ deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
13144
13683
  ggml_vk_synchronize(ctx);
13145
13684
  }
13146
13685
  }
13147
13686
 
13148
- static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) {
13149
- VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()");
13150
- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13151
- if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) {
13152
- ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
13153
- ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
13687
+ static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
13688
+ VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")");
13689
+ ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;
13154
13690
 
13155
- vk_context transfer_ctx;
13691
+ // Skip zero-size tensors
13692
+ if (ggml_nbytes(src) == 0) {
13693
+ return true;
13694
+ }
13156
13695
 
13157
- if (ctx->transfer_ctx.expired()) {
13158
- // Initialize new transfer context
13159
- transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13160
- ctx->transfer_ctx = transfer_ctx;
13161
- ggml_vk_ctx_begin(ctx->device, transfer_ctx);
13162
- } else {
13163
- transfer_ctx = ctx->transfer_ctx.lock();
13696
+ if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) {
13697
+ return false;
13698
+ }
13699
+
13700
+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
13701
+ vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
13702
+
13703
+ if (ggml_backend_buffer_is_vk(src->buffer)) {
13704
+ ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context;
13705
+
13706
+ // Async copy only works within the same device
13707
+ if (src_buf_ctx->dev_buffer->device != dst_buf->device) {
13708
+ return false;
13164
13709
  }
13165
13710
 
13166
- vk_buffer src_buf = src_buf_ctx->dev_buffer;
13167
- vk_buffer dst_buf = dst_buf_ctx->dev_buffer;
13711
+ vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
13168
13712
 
13169
- ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src));
13713
+ ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs,
13714
+ src_buf_ctx->dev_buffer, vk_tensor_offset(src) + src->view_offs,
13715
+ ggml_nbytes(src));
13170
13716
  return true;
13171
13717
  }
13172
13718
 
13719
+ if (ggml_backend_buffer_is_host(src->buffer)) {
13720
+ vk_buffer pinned_buf = nullptr;
13721
+ size_t pinned_offset = 0;
13722
+ ggml_vk_host_get(ctx->device, src->data, pinned_buf, pinned_offset);
13723
+ if (pinned_buf == nullptr) {
13724
+ return false;
13725
+ }
13726
+
13727
+ vk_context cpy_ctx;
13728
+ if (ctx->device->async_use_transfer_queue) {
13729
+ if (ctx->transfer_ctx.expired()) {
13730
+ cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
13731
+ ctx->transfer_ctx = cpy_ctx;
13732
+ ggml_vk_ctx_begin(ctx->device, cpy_ctx);
13733
+ } else {
13734
+ cpy_ctx = ctx->transfer_ctx.lock();
13735
+ }
13736
+ } else {
13737
+ cpy_ctx = ggml_vk_get_compute_ctx(ctx);
13738
+ }
13739
+
13740
+ return ggml_vk_buffer_write_async(cpy_ctx, dst_buf,
13741
+ vk_tensor_offset(dst) + dst->view_offs,
13742
+ src->data, ggml_nbytes(src));
13743
+ }
13744
+
13745
+ GGML_UNUSED(backend_src);
13173
13746
  return false;
13174
13747
  }
13175
13748
 
13176
13749
  static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
13177
13750
  VK_LOG_DEBUG("ggml_vk_synchronize()");
13178
13751
 
13179
- bool do_transfer = !ctx->transfer_ctx.expired();
13752
+ bool do_transfer = !ctx->compute_ctx.expired();
13180
13753
 
13181
- vk_context transfer_ctx;
13754
+ if (ggml_vk_submit_transfer_ctx(ctx)) {
13755
+ ctx->submit_pending = true;
13756
+ }
13757
+
13758
+ vk_context compute_ctx;
13759
+ vk_command_buffer* cmd_buf = nullptr;
13182
13760
  if (do_transfer) {
13183
- transfer_ctx = ctx->transfer_ctx.lock();
13761
+ compute_ctx = ctx->compute_ctx.lock();
13762
+ if (compute_ctx->s) {
13763
+ cmd_buf = compute_ctx->s->buffer;
13764
+ }
13184
13765
 
13185
- ggml_vk_ctx_end(transfer_ctx);
13766
+ ggml_vk_ctx_end(compute_ctx);
13186
13767
 
13187
- for (auto& cpy : transfer_ctx->in_memcpys) {
13768
+ for (auto& cpy : compute_ctx->in_memcpys) {
13188
13769
  memcpy(cpy.dst, cpy.src, cpy.n);
13189
13770
  }
13190
13771
 
13191
- ggml_vk_submit(transfer_ctx, {});
13772
+ ggml_vk_submit(compute_ctx, {});
13192
13773
  ctx->submit_pending = true;
13193
13774
  }
13194
13775
 
13195
13776
  if (ctx->submit_pending) {
13196
- {
13777
+ if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
13778
+ vk::TimelineSemaphoreSubmitInfo tl_info{
13779
+ 1, &ctx->transfer_semaphore.value,
13780
+ 0, nullptr,
13781
+ };
13782
+ vk::PipelineStageFlags stage = ctx->device->transfer_queue.stage_flags;
13783
+ vk::SubmitInfo si{
13784
+ 1, &ctx->transfer_semaphore.s, &stage,
13785
+ 0, nullptr,
13786
+ 0, nullptr,
13787
+ };
13788
+ si.setPNext(&tl_info);
13789
+ std::lock_guard<std::mutex> guard(queue_mutex);
13790
+ ctx->device->compute_queue.queue.submit({ si }, ctx->fence);
13791
+ ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value;
13792
+ } else {
13197
13793
  std::lock_guard<std::mutex> guard(queue_mutex);
13198
13794
  ctx->device->compute_queue.queue.submit({}, ctx->fence);
13199
13795
  }
13200
13796
  ggml_vk_wait_for_fence(ctx);
13201
13797
  ctx->submit_pending = false;
13798
+ if (cmd_buf) {
13799
+ cmd_buf->in_use = false;
13800
+ }
13202
13801
  }
13203
13802
 
13204
13803
  if (do_transfer) {
13205
- for (auto& cpy : transfer_ctx->out_memcpys) {
13804
+ for (auto& cpy : compute_ctx->out_memcpys) {
13206
13805
  memcpy(cpy.dst, cpy.src, cpy.n);
13207
13806
  }
13208
- ctx->transfer_ctx.reset();
13807
+ ctx->compute_ctx.reset();
13209
13808
  }
13210
13809
  }
13211
13810
 
@@ -13505,12 +14104,11 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
13505
14104
  return true;
13506
14105
  }
13507
14106
 
13508
- // Check whether the tensors overlap in memory but are not equal.
13509
- // Fusions can potenitally overwrite src tensors in ways that are not prevented
13510
- // by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them
13511
- // to overlap if they are exactly equal.
13512
- // XXX TODO this check is probably missing from several fusion optimizations.
13513
- static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) {
14107
+ // Check whether the tensors overlap in memory.
14108
+ // Fusions can potentially overwrite src tensors in ways that are not prevented
14109
+ // by ggml-alloc. If the fusion src is being applied in a way that's elementwise
14110
+ // with the destination, then it's OK for them to overlap if they are exactly equal.
14111
+ static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) {
13514
14112
  ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context;
13515
14113
  vk_buffer a_buf = a_buf_ctx->dev_buffer;
13516
14114
  ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context;
@@ -13521,7 +14119,7 @@ static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const g
13521
14119
  auto b_base = vk_tensor_offset(b) + b->view_offs;
13522
14120
  auto b_size = ggml_nbytes(b);
13523
14121
 
13524
- if (a_base == b_base && a_size == b_size) {
14122
+ if (elementwise && a_base == b_base && a_size == b_size) {
13525
14123
  return false;
13526
14124
  }
13527
14125
 
@@ -13559,13 +14157,6 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
13559
14157
  return false;
13560
14158
  }
13561
14159
 
13562
- // must not overwrite srcs in a way that's not elementwise
13563
- ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0];
13564
- if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) ||
13565
- ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) {
13566
- return false;
13567
- }
13568
-
13569
14160
  // conditions for pipeline creation
13570
14161
  if (!(ctx->device->float_controls_rte_fp16 &&
13571
14162
  sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
@@ -13627,6 +14218,18 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru
13627
14218
  return num_adds;
13628
14219
  }
13629
14220
 
14221
+ static int32_t find_first_set(uint32_t x) {
14222
+ int32_t ret = 0;
14223
+ if (!x) {
14224
+ return -1;
14225
+ }
14226
+ while (!(x & 1)) {
14227
+ x >>= 1;
14228
+ ret++;
14229
+ }
14230
+ return ret;
14231
+ }
14232
+
13630
14233
  static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
13631
14234
  VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
13632
14235
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
@@ -13645,7 +14248,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13645
14248
  int last_node = cgraph->n_nodes - 1;
13646
14249
 
13647
14250
  // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
13648
- while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
14251
+ while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {
13649
14252
  last_node -= 1;
13650
14253
  }
13651
14254
 
@@ -13655,6 +14258,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13655
14258
  bool first_node_in_batch = true; // true if next node will be first node in a batch
13656
14259
  int submit_node_idx = 0; // index to first node in a batch
13657
14260
 
14261
+ ggml_vk_submit_transfer_ctx(ctx);
14262
+
13658
14263
  vk_context compute_ctx;
13659
14264
  if (vk_perf_logger_enabled) {
13660
14265
  // allocate/resize the query pool
@@ -13680,11 +14285,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13680
14285
  std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0);
13681
14286
 
13682
14287
  GGML_ASSERT(ctx->compute_ctx.expired());
13683
- compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13684
- ctx->compute_ctx = compute_ctx;
13685
- ggml_vk_ctx_begin(ctx->device, compute_ctx);
14288
+ compute_ctx = ggml_vk_get_compute_ctx(ctx);
13686
14289
  ctx->query_idx = 0;
13687
- compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
14290
+ compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
13688
14291
  }
13689
14292
 
13690
14293
  ctx->prealloc_y_last_pipeline_used = nullptr;
@@ -13692,13 +14295,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13692
14295
 
13693
14296
  if (ctx->prealloc_size_add_rms_partials) {
13694
14297
  ggml_vk_preallocate_buffers(ctx, nullptr);
13695
- if (ctx->compute_ctx.expired()) {
13696
- compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13697
- ctx->compute_ctx = compute_ctx;
13698
- ggml_vk_ctx_begin(ctx->device, compute_ctx);
13699
- } else {
13700
- compute_ctx = ctx->compute_ctx.lock();
13701
- }
14298
+ compute_ctx = ggml_vk_get_compute_ctx(ctx);
13702
14299
  // initialize partial sums to zero.
13703
14300
  ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials);
13704
14301
  ggml_vk_sync_buffers(ctx, compute_ctx);
@@ -13725,6 +14322,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13725
14322
  total_mul_mat_bytes += bytes;
13726
14323
  }
13727
14324
 
14325
+ // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to
14326
+ // the fused result in an elementwise-way. This affects whether the memory for
14327
+ // the src is allowed to overlap the memory for the destination.
14328
+ // The array is sized to handle the largest fusion (asserted later).
14329
+ bool op_srcs_fused_elementwise[12];
14330
+
13728
14331
  ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
13729
14332
  ctx->fused_topk_moe_scale = false;
13730
14333
  const char *fusion_string {};
@@ -13733,39 +14336,68 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13733
14336
  if (num_adds) {
13734
14337
  ctx->num_additional_fused_ops = num_adds - 1;
13735
14338
  fusion_string = "MULTI_ADD";
14339
+ std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true);
13736
14340
  } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) {
13737
14341
  ctx->num_additional_fused_ops = 2;
13738
14342
  fusion_string = "MUL_MAT_ADD_ADD";
14343
+ op_srcs_fused_elementwise[0] = false;
14344
+ op_srcs_fused_elementwise[1] = true;
14345
+ op_srcs_fused_elementwise[2] = true;
13739
14346
  } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) {
13740
14347
  ctx->num_additional_fused_ops = 1;
13741
14348
  fusion_string = "MUL_MAT_ADD";
14349
+ op_srcs_fused_elementwise[0] = false;
14350
+ op_srcs_fused_elementwise[1] = true;
13742
14351
  } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) {
13743
14352
  ctx->num_additional_fused_ops = 2;
13744
14353
  fusion_string = "MUL_MAT_ID_ADD_ID_MUL";
14354
+ op_srcs_fused_elementwise[0] = false;
14355
+ op_srcs_fused_elementwise[1] = true;
14356
+ op_srcs_fused_elementwise[2] = true;
13745
14357
  } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) {
13746
14358
  ctx->num_additional_fused_ops = 1;
13747
14359
  fusion_string = "MUL_MAT_ID_ADD_ID";
14360
+ op_srcs_fused_elementwise[0] = false;
14361
+ op_srcs_fused_elementwise[1] = true;
13748
14362
  } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) {
13749
14363
  ctx->num_additional_fused_ops = 1;
13750
14364
  fusion_string = "MUL_MAT_ID_MUL";
14365
+ op_srcs_fused_elementwise[0] = false;
14366
+ op_srcs_fused_elementwise[1] = true;
13751
14367
  } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) &&
13752
14368
  ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) &&
13753
14369
  ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) &&
13754
14370
  ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) {
13755
14371
  ctx->num_additional_fused_ops = 4;
13756
14372
  fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS";
14373
+ op_srcs_fused_elementwise[0] = false;
14374
+ op_srcs_fused_elementwise[1] = false;
14375
+ op_srcs_fused_elementwise[2] = false;
14376
+ op_srcs_fused_elementwise[3] = false;
14377
+ op_srcs_fused_elementwise[4] = false;
13757
14378
  } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&&
13758
14379
  ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) {
13759
14380
  ctx->num_additional_fused_ops = 2;
13760
14381
  fusion_string = "RMS_NORM_MUL_ROPE";
14382
+ // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise
14383
+ op_srcs_fused_elementwise[0] = false;
14384
+ op_srcs_fused_elementwise[1] = true;
14385
+ op_srcs_fused_elementwise[2] = true;
13761
14386
  } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
13762
14387
  ctx->num_additional_fused_ops = 1;
13763
14388
  fusion_string = "RMS_NORM_MUL";
14389
+ // rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before
14390
+ // they are overwritten, and one workgroup per row. So close enough.
14391
+ op_srcs_fused_elementwise[0] = true;
14392
+ op_srcs_fused_elementwise[1] = true;
13764
14393
  } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
13765
14394
  ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
13766
14395
  ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
13767
14396
  ctx->num_additional_fused_ops = 2;
13768
14397
  fusion_string = "ROPE_VIEW_SET_ROWS";
14398
+ op_srcs_fused_elementwise[0] = false;
14399
+ op_srcs_fused_elementwise[1] = false;
14400
+ op_srcs_fused_elementwise[2] = false;
13769
14401
  } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
13770
14402
  ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
13771
14403
  ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
@@ -13774,6 +14406,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13774
14406
  ctx->fused_ops_write_mask |= 1 << 3;
13775
14407
  ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
13776
14408
  fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
14409
+ std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
13777
14410
  } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
13778
14411
  ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
13779
14412
  ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
@@ -13782,6 +14415,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13782
14415
  ctx->fused_ops_write_mask |= 1 << 4;
13783
14416
  ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
13784
14417
  fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
14418
+ std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
13785
14419
  } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
13786
14420
  ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
13787
14421
  ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
@@ -13790,6 +14424,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13790
14424
  ctx->fused_ops_write_mask |= 1 << 3;
13791
14425
  ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
13792
14426
  fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
14427
+ std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
13793
14428
  } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
13794
14429
  ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
13795
14430
  ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) {
@@ -13798,6 +14433,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13798
14433
  ctx->fused_ops_write_mask |= 1 << 1;
13799
14434
  ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
13800
14435
  fusion_string = "TOPK_MOE_LATE_SOFTMAX";
14436
+ std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false);
13801
14437
  }
13802
14438
  if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
13803
14439
  // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
@@ -13805,11 +14441,73 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13805
14441
  ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
13806
14442
  ctx->fused_topk_moe_scale = true;
13807
14443
  ctx->num_additional_fused_ops++;
14444
+ op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false;
13808
14445
  }
13809
14446
  }
13810
14447
  }
14448
+ GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0])));
13811
14449
  ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
13812
14450
 
14451
+ // Check whether fusion would overwrite src operands while they're still in use.
14452
+ // If so, disable fusion.
14453
+ if (ctx->num_additional_fused_ops) {
14454
+ // There are up to two output nodes - topk_moe has two.
14455
+ uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops);
14456
+ ggml_tensor *output_nodes[2] {};
14457
+ output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops];
14458
+ if (bits) {
14459
+ int output_idx = find_first_set(bits);
14460
+ GGML_ASSERT(bits == (1u << output_idx));
14461
+ output_nodes[1] = cgraph->nodes[i + output_idx];
14462
+ }
14463
+
14464
+ bool need_disable = false;
14465
+
14466
+ // topk_moe often overwrites the source, but for a given row all the src values are
14467
+ // loaded before anything is stored. If there's only one row, this is safe, so treat
14468
+ // this as a special case.
14469
+ bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT &&
14470
+ ggml_nrows(cgraph->nodes[i]->src[0]) == 1;
14471
+
14472
+ if (!is_topk_moe_single_row) {
14473
+ for (int j = 0; j < 2; ++j) {
14474
+ ggml_tensor *dst = output_nodes[j];
14475
+ if (!dst) {
14476
+ continue;
14477
+ }
14478
+ // Loop over all srcs of all nodes in the fusion. If the src overlaps
14479
+ // the destination and the src is not an intermediate node that's being
14480
+ // elided, then disable fusion.
14481
+ for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) {
14482
+ for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
14483
+ ggml_tensor *src = cgraph->nodes[i + k]->src[s];
14484
+ if (!src || src->op == GGML_OP_NONE) {
14485
+ continue;
14486
+ }
14487
+ if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) {
14488
+ bool found = false;
14489
+ for (int n = 0; n < k; ++n) {
14490
+ if (cgraph->nodes[i + n] == src) {
14491
+ found = true;
14492
+ break;
14493
+ }
14494
+ }
14495
+ if (!found) {
14496
+ need_disable = true;
14497
+ }
14498
+ }
14499
+ }
14500
+ }
14501
+ }
14502
+ }
14503
+ if (need_disable) {
14504
+ ctx->num_additional_fused_ops = 0;
14505
+ ctx->fused_ops_write_mask = 1;
14506
+ ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
14507
+ ctx->fused_topk_moe_scale = false;
14508
+ }
14509
+ }
14510
+
13813
14511
  // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
13814
14512
  bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
13815
14513
  bool submit = (submitted_nodes >= nodes_per_submit) ||
@@ -13820,18 +14518,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13820
14518
  bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit);
13821
14519
 
13822
14520
  if (vk_perf_logger_enabled && enqueued) {
13823
- if (ctx->compute_ctx.expired()) {
13824
- compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
13825
- ctx->compute_ctx = compute_ctx;
13826
- ggml_vk_ctx_begin(ctx->device, compute_ctx);
13827
- } else {
13828
- compute_ctx = ctx->compute_ctx.lock();
13829
- }
14521
+ compute_ctx = ggml_vk_get_compute_ctx(ctx);
13830
14522
  if (!vk_perf_logger_concurrent) {
13831
14523
  // track a single node/fusion for the current query
13832
14524
  ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
13833
14525
  ctx->query_fusion_names[ctx->query_idx] = fusion_string;
13834
- compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
14526
+ compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
13835
14527
  } else {
13836
14528
  // track a fusion string and number of fused ops for the current node_idx
13837
14529
  ctx->query_fusion_names[i] = fusion_string;
@@ -13874,6 +14566,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
13874
14566
  ggml_vk_submit(compute_ctx, ctx->device->fence);
13875
14567
  VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences");
13876
14568
  ctx->device->device.resetFences({ ctx->device->fence });
14569
+ ctx->compute_ctx.reset();
13877
14570
 
13878
14571
  // Get the results and pass them to the logger
13879
14572
  std::vector<uint64_t> timestamps(cgraph->n_nodes + 1);
@@ -14160,29 +14853,24 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
14160
14853
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
14161
14854
  vk_event *vkev = (vk_event *)event->context;
14162
14855
 
14163
- vk_context transfer_ctx;
14856
+ ggml_vk_submit_transfer_ctx(ctx);
14164
14857
 
14165
- if (ctx->transfer_ctx.expired()) {
14166
- // Initialize new transfer context
14167
- transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
14168
- ctx->transfer_ctx = transfer_ctx;
14169
- ggml_vk_ctx_begin(ctx->device, transfer_ctx);
14170
- } else {
14171
- transfer_ctx = ctx->transfer_ctx.lock();
14172
- }
14858
+ vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
14859
+ auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset
14173
14860
 
14174
14861
  // the backend interface doesn't have an explicit reset, so reset it here
14175
14862
  // before we record the command to set it
14176
14863
  ctx->device->device.resetEvent(vkev->event);
14177
14864
  ctx->device->device.resetFences({ vkev->fence });
14178
14865
 
14179
- ggml_vk_set_event(transfer_ctx, vkev->event);
14866
+ ggml_vk_set_event(compute_ctx, vkev->event);
14180
14867
 
14181
- ggml_vk_ctx_end(transfer_ctx);
14868
+ ggml_vk_ctx_end(compute_ctx);
14182
14869
 
14183
- ggml_vk_submit(transfer_ctx, {vkev->fence});
14870
+ ggml_vk_submit(compute_ctx, {vkev->fence});
14184
14871
  ctx->submit_pending = true;
14185
- ctx->transfer_ctx.reset();
14872
+ vkev->cmd_buffer = cmd_buf;
14873
+ ctx->compute_ctx.reset();
14186
14874
  }
14187
14875
 
14188
14876
  static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
@@ -14190,20 +14878,11 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
14190
14878
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
14191
14879
  vk_event *vkev = (vk_event *)event->context;
14192
14880
 
14193
- vk_context transfer_ctx;
14194
-
14195
- if (ctx->transfer_ctx.expired()) {
14196
- // Initialize new transfer context
14197
- transfer_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
14198
- ctx->transfer_ctx = transfer_ctx;
14199
- ggml_vk_ctx_begin(ctx->device, transfer_ctx);
14200
- } else {
14201
- transfer_ctx = ctx->transfer_ctx.lock();
14202
- }
14881
+ vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
14203
14882
 
14204
- ggml_vk_wait_events(transfer_ctx, {vkev->event});
14205
- ggml_vk_ctx_end(transfer_ctx);
14206
- ctx->transfer_ctx.reset();
14883
+ ggml_vk_wait_events(compute_ctx, {vkev->event});
14884
+ ggml_vk_ctx_end(compute_ctx);
14885
+ ctx->compute_ctx.reset();
14207
14886
  }
14208
14887
 
14209
14888
  // TODO: enable async and synchronize
@@ -14212,7 +14891,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
14212
14891
  /* .free = */ ggml_backend_vk_free,
14213
14892
  /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
14214
14893
  /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
14215
- /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
14894
+ /* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async,
14216
14895
  /* .synchronize = */ ggml_backend_vk_synchronize,
14217
14896
  /* .graph_plan_create = */ NULL,
14218
14897
  /* .graph_plan_free = */ NULL,
@@ -14413,13 +15092,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
14413
15092
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
14414
15093
  const vk_device& device = ggml_vk_get_device(ctx->device);
14415
15094
 
15095
+ const bool uses_bda = (op->op == GGML_OP_IM2COL || op->op == GGML_OP_IM2COL_3D) &&
15096
+ device->shader_int64 && device->buffer_device_address;
15097
+
15098
+ auto const & tensor_size_supported = [&](size_t tensor_size) {
15099
+ if (tensor_size > device->max_buffer_size) {
15100
+ return false;
15101
+ }
15102
+ // For im2col shaders using BDA, maxStorageBufferRange limit doesn't apply.
15103
+ // If shader64BitIndexing is enabled, maxStorageBufferRange limit doesn't apply.
15104
+ if (!uses_bda && !device->shader_64b_indexing) {
15105
+ if (tensor_size > device->properties.limits.maxStorageBufferRange) {
15106
+ return false;
15107
+ }
15108
+ }
15109
+ return true;
15110
+ };
14416
15111
  // reject any tensors larger than the max buffer size
14417
15112
  for (int i = 0; i < GGML_MAX_SRC; i++) {
14418
- if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) {
15113
+ if (op->src[i] && !tensor_size_supported(ggml_nbytes(op->src[i]))) {
14419
15114
  return false;
14420
15115
  }
14421
15116
  }
14422
- if (ggml_nbytes(op) > device->max_buffer_size) {
15117
+ if (!tensor_size_supported(ggml_nbytes(op))) {
14423
15118
  return false;
14424
15119
  }
14425
15120
 
@@ -14427,6 +15122,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
14427
15122
  case GGML_OP_UNARY:
14428
15123
  switch (ggml_get_unary_op(op)) {
14429
15124
  case GGML_UNARY_OP_EXP:
15125
+ case GGML_UNARY_OP_ELU:
14430
15126
  case GGML_UNARY_OP_GELU:
14431
15127
  case GGML_UNARY_OP_GELU_ERF:
14432
15128
  case GGML_UNARY_OP_GELU_QUICK:
@@ -14445,6 +15141,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
14445
15141
  case GGML_UNARY_OP_CEIL:
14446
15142
  case GGML_UNARY_OP_FLOOR:
14447
15143
  case GGML_UNARY_OP_TRUNC:
15144
+ case GGML_UNARY_OP_SGN:
14448
15145
  return ggml_is_contiguous(op->src[0]) &&
14449
15146
  (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
14450
15147
  (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
@@ -14707,6 +15404,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
14707
15404
  case GGML_OP_REPEAT_BACK:
14708
15405
  return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
14709
15406
  case GGML_OP_ROPE:
15407
+ return ggml_is_contiguous_rows(op) && ggml_is_contiguous_rows(op->src[0]);
14710
15408
  case GGML_OP_ROPE_BACK:
14711
15409
  case GGML_OP_NONE:
14712
15410
  case GGML_OP_RESHAPE:
@@ -14717,8 +15415,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
14717
15415
  return true;
14718
15416
  case GGML_OP_NORM:
14719
15417
  case GGML_OP_GROUP_NORM:
14720
- case GGML_OP_L2_NORM:
14721
15418
  return ggml_is_contiguous(op->src[0]);
15419
+ case GGML_OP_L2_NORM:
15420
+ return ggml_is_contiguous_rows(op->src[0]) &&
15421
+ op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
14722
15422
  case GGML_OP_ADD:
14723
15423
  case GGML_OP_SUB:
14724
15424
  case GGML_OP_MUL:
@@ -14781,7 +15481,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
14781
15481
  }
14782
15482
  return op->src[0]->type == GGML_TYPE_F32;
14783
15483
  case GGML_OP_ACC:
14784
- return op->src[0]->type == GGML_TYPE_F32;
15484
+ return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
15485
+ case GGML_OP_SET:
15486
+ return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
15487
+ (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
14785
15488
  case GGML_OP_CONCAT:
14786
15489
  return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
14787
15490
  case GGML_OP_ADD1:
@@ -14855,6 +15558,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
14855
15558
  case GGML_OP_RWKV_WKV6:
14856
15559
  case GGML_OP_RWKV_WKV7:
14857
15560
  return true; // all inputs are contiguous, see ggml.c
15561
+ case GGML_OP_GATED_DELTA_NET:
15562
+ {
15563
+ const uint32_t S_v = op->src[2]->ne[0];
15564
+ if (S_v != 32 && S_v != 64 && S_v != 128) {
15565
+ return false;
15566
+ }
15567
+ for (int i = 0; i < 6; i++) {
15568
+ if (op->src[i] == nullptr || op->src[i]->type != GGML_TYPE_F32) {
15569
+ return false;
15570
+ }
15571
+ }
15572
+ return op->type == GGML_TYPE_F32;
15573
+ }
14858
15574
  case GGML_OP_SSM_SCAN:
14859
15575
  {
14860
15576
  for (int i = 0; i < 6; i++) {
@@ -14926,11 +15642,25 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba
14926
15642
  return buft_ctx->device->idx == ctx->device;
14927
15643
  }
14928
15644
 
15645
+ static int64_t ggml_vk_get_op_batch_size(const ggml_tensor * op) {
15646
+ switch (op->op) {
15647
+ case GGML_OP_GET_ROWS:
15648
+ return 0;
15649
+ case GGML_OP_MUL_MAT:
15650
+ return op->ne[1];
15651
+ case GGML_OP_MUL_MAT_ID:
15652
+ case GGML_OP_ROPE:
15653
+ case GGML_OP_ROPE_BACK:
15654
+ return op->ne[2];
15655
+ default:
15656
+ return ggml_nrows(op);
15657
+ }
15658
+ }
15659
+
14929
15660
  static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
14930
15661
  ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context;
14931
15662
 
14932
- return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) ||
14933
- (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
15663
+ return ggml_vk_get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size;
14934
15664
  }
14935
15665
 
14936
15666
  static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) {
@@ -14972,6 +15702,10 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
14972
15702
  vk_event *vkev = (vk_event *)event->context;
14973
15703
 
14974
15704
  VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
15705
+ // Finished using current command buffer so we flag for reuse
15706
+ if (vkev->cmd_buffer) {
15707
+ vkev->cmd_buffer->in_use = false;
15708
+ }
14975
15709
  }
14976
15710
 
14977
15711
  static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
@@ -15190,6 +15924,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
15190
15924
  }
15191
15925
  }
15192
15926
 
15927
+ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) {
15928
+ VkPhysicalDeviceProperties2 props = vkdev.getProperties2();
15929
+
15930
+ if (props.properties.vendorID != VK_VENDOR_ID_INTEL) {
15931
+ return 0;
15932
+ }
15933
+
15934
+ const uint32_t device_id = props.properties.deviceID;
15935
+
15936
+ switch (device_id) {
15937
+ case 0x56A6: // A310
15938
+ return 6;
15939
+ case 0x5693: // A370M
15940
+ case 0x56A5: // A380
15941
+ case 0x56B1: // Pro A40/A50
15942
+ return 8;
15943
+ case 0x5697: // A530M
15944
+ return 12;
15945
+ case 0x5692: // A550M
15946
+ case 0x56B3: // Pro A60
15947
+ return 16;
15948
+ case 0x56A2: // A580
15949
+ return 24;
15950
+ case 0x5691: // A730M
15951
+ case 0x56A1: // A750
15952
+ return 28;
15953
+ case 0x56A0: // A770
15954
+ case 0x5690: // A770M
15955
+ return 32;
15956
+ case 0xE212: // Pro B50
15957
+ return 16;
15958
+ case 0xE20C: // B570
15959
+ return 18;
15960
+ case 0xE20B: // B580
15961
+ return 20;
15962
+ default:
15963
+ return 0;
15964
+ }
15965
+ }
15966
+
15193
15967
  // checks
15194
15968
 
15195
15969
  #ifdef GGML_VULKAN_CHECK_RESULTS
@@ -15403,7 +16177,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
15403
16177
  tensor_clone = ggml_arange(ggml_ctx, start, stop, step);
15404
16178
  } else if (tensor->op == GGML_OP_FILL) {
15405
16179
  const float value = ggml_get_op_params_f32(tensor, 0);
15406
- tensor_clone = ggml_fill(ggml_ctx, tensor_clone, value);
16180
+ tensor_clone = ggml_fill(ggml_ctx, src_clone[0], value);
15407
16181
  } else if (tensor->op == GGML_OP_SQR) {
15408
16182
  tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
15409
16183
  } else if (tensor->op == GGML_OP_SQRT) {
@@ -15432,6 +16206,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
15432
16206
  tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
15433
16207
  } else if (tensor->op == GGML_OP_ACC) {
15434
16208
  tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
16209
+ } else if (tensor->op == GGML_OP_SET) {
16210
+ tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
15435
16211
  } else if (tensor->op == GGML_OP_NORM) {
15436
16212
  tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
15437
16213
  } else if (tensor->op == GGML_OP_GROUP_NORM) {
@@ -15488,6 +16264,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
15488
16264
  case GGML_UNARY_OP_EXP:
15489
16265
  tensor_clone = ggml_exp(ggml_ctx, src_clone[0]);
15490
16266
  break;
16267
+ case GGML_UNARY_OP_ELU:
16268
+ tensor_clone = ggml_elu(ggml_ctx, src_clone[0]);
16269
+ break;
15491
16270
  case GGML_UNARY_OP_SILU:
15492
16271
  tensor_clone = ggml_silu(ggml_ctx, src_clone[0]);
15493
16272
  break;
@@ -15546,6 +16325,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
15546
16325
  case GGML_UNARY_OP_TRUNC:
15547
16326
  tensor_clone = ggml_trunc(ggml_ctx, src_clone[0]);
15548
16327
  break;
16328
+ case GGML_UNARY_OP_SGN:
16329
+ tensor_clone = ggml_sgn(ggml_ctx, src_clone[0]);
16330
+ break;
15549
16331
  default:
15550
16332
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
15551
16333
  GGML_ABORT("fatal error");
@@ -15666,6 +16448,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
15666
16448
  } else if (tensor->op == GGML_OP_RWKV_WKV7) {
15667
16449
  tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
15668
16450
  src_clone[4], src_clone[5], src_clone[6]);
16451
+ } else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
16452
+ tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
16453
+ src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
15669
16454
  } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
15670
16455
  src_clone[0]->flags = tensor->src[0]->flags;
15671
16456
  tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
@@ -15864,7 +16649,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph *
15864
16649
  ggml_vk_print_graph_origin(tensor, done);
15865
16650
  }
15866
16651
 
15867
- if (avg_err > 0.5 || std::isnan(avg_err)) {
16652
+ if (avg_err > 0.01 || std::isnan(avg_err)) {
15868
16653
  std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl;
15869
16654
  std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl;
15870
16655
  if (src0 != nullptr) {