whispercpp 1.3.6 → 1.3.7

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (828)
  1. checksums.yaml +4 -4
  2. data/.document +3 -0
  3. data/.rdoc_options +2 -0
  4. data/README.md +38 -5
  5. data/Rakefile +18 -3
  6. data/ext/dependencies.rb +10 -4
  7. data/ext/dependencies_for_windows.rb +17 -0
  8. data/ext/extconf.rb +20 -8
  9. data/ext/options.rb +54 -14
  10. data/ext/options_for_windows.rb +51 -0
  11. data/ext/ruby_whisper.c +36 -42
  12. data/ext/ruby_whisper.h +135 -0
  13. data/ext/ruby_whisper_context.c +107 -28
  14. data/ext/ruby_whisper_log_queue.c +180 -0
  15. data/ext/ruby_whisper_log_settable.h +47 -0
  16. data/ext/ruby_whisper_parakeet.c +49 -0
  17. data/ext/ruby_whisper_parakeet_context.c +304 -0
  18. data/ext/ruby_whisper_parakeet_context_params.c +117 -0
  19. data/ext/ruby_whisper_parakeet_model.c +84 -0
  20. data/ext/ruby_whisper_parakeet_params.c +548 -0
  21. data/ext/ruby_whisper_parakeet_segment.c +157 -0
  22. data/ext/ruby_whisper_parakeet_token.c +188 -0
  23. data/ext/ruby_whisper_parakeet_transcribe.cpp +58 -0
  24. data/ext/ruby_whisper_params.c +256 -65
  25. data/ext/ruby_whisper_segment.c +6 -6
  26. data/ext/ruby_whisper_transcribe.cpp +42 -15
  27. data/ext/sources/CMakeLists.txt +41 -3
  28. data/ext/sources/CMakePresets.json +95 -0
  29. data/ext/sources/cmake/parakeet-config.cmake.in +30 -0
  30. data/ext/sources/cmake/parakeet.pc.in +10 -0
  31. data/ext/sources/cmake/whisper.pc.in +1 -1
  32. data/ext/sources/examples/CMakeLists.txt +4 -2
  33. data/ext/sources/examples/bench/bench.cpp +1 -1
  34. data/ext/sources/examples/cli/cli.cpp +43 -9
  35. data/ext/sources/examples/common-ggml.cpp +2 -0
  36. data/ext/sources/examples/common-whisper.cpp +139 -67
  37. data/ext/sources/examples/common-whisper.h +11 -0
  38. data/ext/sources/examples/ffmpeg-transcode.cpp +211 -341
  39. data/ext/sources/examples/parakeet-cli/CMakeLists.txt +8 -0
  40. data/ext/sources/examples/parakeet-cli/parakeet-cli.cpp +243 -0
  41. data/ext/sources/examples/parakeet-quantize/CMakeLists.txt +7 -0
  42. data/ext/sources/examples/parakeet-quantize/parakeet-quantize.cpp +230 -0
  43. data/ext/sources/examples/server/server.cpp +199 -163
  44. data/ext/sources/ggml/CMakeLists.txt +21 -13
  45. data/ext/sources/ggml/cmake/FindNCCL.cmake +36 -0
  46. data/ext/sources/ggml/cmake/ggml-config.cmake.in +12 -2
  47. data/ext/sources/ggml/include/ggml-alloc.h +1 -0
  48. data/ext/sources/ggml/include/ggml-backend.h +72 -10
  49. data/ext/sources/ggml/include/ggml-cuda.h +3 -0
  50. data/ext/sources/ggml/include/ggml-rpc.h +3 -3
  51. data/ext/sources/ggml/include/ggml.h +101 -9
  52. data/ext/sources/ggml/include/gguf.h +10 -2
  53. data/ext/sources/ggml/src/CMakeLists.txt +22 -5
  54. data/ext/sources/ggml/src/ggml-alloc.c +5 -1
  55. data/ext/sources/ggml/src/ggml-backend-impl.h +22 -2
  56. data/ext/sources/ggml/src/ggml-backend-meta.cpp +2263 -0
  57. data/ext/sources/ggml/src/ggml-backend-reg.cpp +12 -0
  58. data/ext/sources/ggml/src/ggml-backend.cpp +110 -9
  59. data/ext/sources/ggml/src/ggml-blas/ggml-blas.cpp +4 -0
  60. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.cpp +672 -257
  61. data/ext/sources/ggml/src/ggml-cann/aclnn_ops.h +71 -0
  62. data/ext/sources/ggml/src/ggml-cann/common.h +20 -10
  63. data/ext/sources/ggml/src/ggml-cann/ggml-cann.cpp +211 -30
  64. data/ext/sources/ggml/src/ggml-common.h +11 -0
  65. data/ext/sources/ggml/src/ggml-cpu/CMakeLists.txt +58 -29
  66. data/ext/sources/ggml/src/ggml-cpu/amx/amx.cpp +2 -0
  67. data/ext/sources/ggml/src/ggml-cpu/amx/mmq.cpp +16 -16
  68. data/ext/sources/ggml/src/ggml-cpu/arch/arm/quants.c +116 -7
  69. data/ext/sources/ggml/src/ggml-cpu/arch/arm/repack.cpp +65 -0
  70. data/ext/sources/ggml/src/ggml-cpu/arch/loongarch/quants.c +151 -1
  71. data/ext/sources/ggml/src/ggml-cpu/arch/powerpc/quants.c +0 -1
  72. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/quants.c +4279 -1292
  73. data/ext/sources/ggml/src/ggml-cpu/arch/riscv/repack.cpp +5 -35
  74. data/ext/sources/ggml/src/ggml-cpu/arch/s390/quants.c +0 -1
  75. data/ext/sources/ggml/src/ggml-cpu/arch/wasm/quants.c +72 -1
  76. data/ext/sources/ggml/src/ggml-cpu/arch/x86/quants.c +177 -27
  77. data/ext/sources/ggml/src/ggml-cpu/arch/x86/repack.cpp +1 -1
  78. data/ext/sources/ggml/src/ggml-cpu/arch-fallback.h +5 -0
  79. data/ext/sources/ggml/src/ggml-cpu/cmake/FindSMTIME.cmake +32 -0
  80. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu-impl.h +10 -0
  81. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.c +95 -5
  82. data/ext/sources/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  83. data/ext/sources/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +146 -134
  84. data/ext/sources/ggml/src/ggml-cpu/llamafile/sgemm.cpp +88 -70
  85. data/ext/sources/ggml/src/ggml-cpu/ops.cpp +372 -73
  86. data/ext/sources/ggml/src/ggml-cpu/ops.h +3 -0
  87. data/ext/sources/ggml/src/ggml-cpu/quants.c +55 -0
  88. data/ext/sources/ggml/src/ggml-cpu/quants.h +3 -0
  89. data/ext/sources/ggml/src/ggml-cpu/repack.cpp +3 -0
  90. data/ext/sources/ggml/src/ggml-cpu/simd-gemm.h +90 -0
  91. data/ext/sources/ggml/src/ggml-cpu/simd-mappings.h +3 -16
  92. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.cpp +1402 -687
  93. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime.h +8 -0
  94. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp +597 -2766
  95. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime2_kernels.cpp +5768 -0
  96. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.cpp +320 -0
  97. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_env.h +55 -0
  98. data/ext/sources/ggml/src/ggml-cpu/spacemit/ime_kernels.h +182 -19
  99. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.cpp +1795 -0
  100. data/ext/sources/ggml/src/ggml-cpu/spacemit/repack.h +14 -0
  101. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.cpp +3178 -0
  102. data/ext/sources/ggml/src/ggml-cpu/spacemit/rvv_kernels.h +95 -0
  103. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_barrier.h +34 -0
  104. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.cpp +760 -0
  105. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_mem_pool.h +32 -0
  106. data/ext/sources/ggml/src/ggml-cpu/spacemit/spine_tcm.h +409 -0
  107. data/ext/sources/ggml/src/ggml-cpu/vec.cpp +37 -53
  108. data/ext/sources/ggml/src/ggml-cpu/vec.h +225 -240
  109. data/ext/sources/ggml/src/ggml-cuda/CMakeLists.txt +17 -7
  110. data/ext/sources/ggml/src/ggml-cuda/allreduce.cu +971 -0
  111. data/ext/sources/ggml/src/ggml-cuda/allreduce.cuh +29 -0
  112. data/ext/sources/ggml/src/ggml-cuda/argsort.cu +62 -26
  113. data/ext/sources/ggml/src/ggml-cuda/binbcast.cu +44 -18
  114. data/ext/sources/ggml/src/ggml-cuda/binbcast.cuh +1 -0
  115. data/ext/sources/ggml/src/ggml-cuda/common.cuh +242 -28
  116. data/ext/sources/ggml/src/ggml-cuda/concat.cu +120 -114
  117. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cu +45 -21
  118. data/ext/sources/ggml/src/ggml-cuda/conv2d-transpose.cuh +1 -0
  119. data/ext/sources/ggml/src/ggml-cuda/convert.cu +53 -0
  120. data/ext/sources/ggml/src/ggml-cuda/convert.cuh +10 -0
  121. data/ext/sources/ggml/src/ggml-cuda/cpy.cu +14 -6
  122. data/ext/sources/ggml/src/ggml-cuda/dequantize.cuh +22 -0
  123. data/ext/sources/ggml/src/ggml-cuda/fattn-common.cuh +278 -44
  124. data/ext/sources/ggml/src/ggml-cuda/fattn-mma-f16.cuh +331 -130
  125. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cu +12 -0
  126. data/ext/sources/ggml/src/ggml-cuda/fattn-tile.cuh +126 -27
  127. data/ext/sources/ggml/src/ggml-cuda/fattn-vec.cuh +40 -15
  128. data/ext/sources/ggml/src/ggml-cuda/fattn-wmma-f16.cu +18 -9
  129. data/ext/sources/ggml/src/ggml-cuda/fattn.cu +152 -49
  130. data/ext/sources/ggml/src/ggml-cuda/fattn.cuh +2 -0
  131. data/ext/sources/ggml/src/ggml-cuda/fwht.cu +101 -0
  132. data/ext/sources/ggml/src/ggml-cuda/fwht.cuh +4 -0
  133. data/ext/sources/ggml/src/ggml-cuda/gated_delta_net.cu +84 -35
  134. data/ext/sources/ggml/src/ggml-cuda/getrows.cu +34 -12
  135. data/ext/sources/ggml/src/ggml-cuda/ggml-cuda.cu +1069 -609
  136. data/ext/sources/ggml/src/ggml-cuda/im2col.cu +32 -29
  137. data/ext/sources/ggml/src/ggml-cuda/mean.cu +4 -2
  138. data/ext/sources/ggml/src/ggml-cuda/mma.cuh +242 -195
  139. data/ext/sources/ggml/src/ggml-cuda/mmf.cuh +3 -3
  140. data/ext/sources/ggml/src/ggml-cuda/mmq.cu +18 -12
  141. data/ext/sources/ggml/src/ggml-cuda/mmq.cuh +502 -423
  142. data/ext/sources/ggml/src/ggml-cuda/mmvf.cu +19 -12
  143. data/ext/sources/ggml/src/ggml-cuda/mmvq.cu +485 -57
  144. data/ext/sources/ggml/src/ggml-cuda/mmvq.cuh +6 -1
  145. data/ext/sources/ggml/src/ggml-cuda/norm.cu +36 -10
  146. data/ext/sources/ggml/src/ggml-cuda/out-prod.cu +23 -7
  147. data/ext/sources/ggml/src/ggml-cuda/quantize.cu +133 -26
  148. data/ext/sources/ggml/src/ggml-cuda/quantize.cuh +1 -1
  149. data/ext/sources/ggml/src/ggml-cuda/reduce_rows.cuh +5 -1
  150. data/ext/sources/ggml/src/ggml-cuda/rope.cu +11 -4
  151. data/ext/sources/ggml/src/ggml-cuda/scale.cu +4 -1
  152. data/ext/sources/ggml/src/ggml-cuda/set-rows.cu +14 -6
  153. data/ext/sources/ggml/src/ggml-cuda/snake.cu +72 -0
  154. data/ext/sources/ggml/src/ggml-cuda/snake.cuh +8 -0
  155. data/ext/sources/ggml/src/ggml-cuda/softcap.cu +4 -1
  156. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cu +45 -13
  157. data/ext/sources/ggml/src/ggml-cuda/ssm-conv.cuh +1 -1
  158. data/ext/sources/ggml/src/ggml-cuda/ssm-scan.cu +40 -18
  159. data/ext/sources/ggml/src/ggml-cuda/sumrows.cu +8 -4
  160. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_16.cu +1 -0
  161. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_32.cu +1 -0
  162. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_1-ncols2_8.cu +2 -0
  163. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_16-ncols2_4.cu +1 -0
  164. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_16.cu +1 -0
  165. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_32.cu +1 -0
  166. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_4.cu +1 -0
  167. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_2-ncols2_8.cu +2 -0
  168. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_16.cu +1 -0
  169. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_4.cu +1 -0
  170. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_4-ncols2_8.cu +2 -0
  171. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_4.cu +1 -0
  172. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-mma-f16-instance-ncols1_8-ncols2_8.cu +2 -0
  173. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq192-dv128.cu +5 -0
  174. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq320-dv256.cu +5 -0
  175. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq512-dv512.cu +5 -0
  176. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu +7 -0
  177. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-f16.cu +7 -0
  178. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_0.cu +7 -0
  179. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q4_1.cu +7 -0
  180. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_0.cu +7 -0
  181. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q5_1.cu +7 -0
  182. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-bf16-q8_0.cu +7 -0
  183. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-f16-bf16.cu +7 -0
  184. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_0-bf16.cu +7 -0
  185. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q4_1-bf16.cu +7 -0
  186. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_0-bf16.cu +7 -0
  187. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q5_1-bf16.cu +7 -0
  188. data/ext/sources/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-bf16.cu +7 -0
  189. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu +5 -0
  190. data/ext/sources/ggml/src/ggml-cuda/template-instances/mmq-instance-q1_0.cu +5 -0
  191. data/ext/sources/ggml/src/ggml-cuda/top-k.cu +5 -4
  192. data/ext/sources/ggml/src/ggml-cuda/topk-moe.cu +26 -23
  193. data/ext/sources/ggml/src/ggml-cuda/unary.cu +31 -2
  194. data/ext/sources/ggml/src/ggml-cuda/unary.cuh +2 -0
  195. data/ext/sources/ggml/src/ggml-cuda/vecdotq.cuh +80 -0
  196. data/ext/sources/ggml/src/ggml-cuda/vendors/cuda.h +7 -2
  197. data/ext/sources/ggml/src/ggml-cuda/vendors/hip.h +22 -4
  198. data/ext/sources/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  199. data/ext/sources/ggml/src/ggml-hexagon/CMakeLists.txt +2 -1
  200. data/ext/sources/ggml/src/ggml-hexagon/ggml-hexagon.cpp +1428 -743
  201. data/ext/sources/ggml/src/ggml-hexagon/htp/CMakeLists.txt +45 -7
  202. data/ext/sources/ggml/src/ggml-hexagon/htp/act-ops.c +53 -84
  203. data/ext/sources/ggml/src/ggml-hexagon/htp/argsort-ops.c +25 -12
  204. data/ext/sources/ggml/src/ggml-hexagon/htp/binary-ops.c +165 -184
  205. data/ext/sources/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +5 -5
  206. data/ext/sources/ggml/src/ggml-hexagon/htp/concat-ops.c +277 -0
  207. data/ext/sources/ggml/src/ggml-hexagon/htp/cpy-ops.c +170 -127
  208. data/ext/sources/ggml/src/ggml-hexagon/htp/cumsum-ops.c +270 -0
  209. data/ext/sources/ggml/src/ggml-hexagon/htp/diag-ops.c +216 -0
  210. data/ext/sources/ggml/src/ggml-hexagon/htp/fill-ops.c +123 -0
  211. data/ext/sources/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +125 -97
  212. data/ext/sources/ggml/src/ggml-hexagon/htp/gated-delta-net-ops.c +1148 -0
  213. data/ext/sources/ggml/src/ggml-hexagon/htp/get-rows-ops.c +148 -42
  214. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.c +2 -2
  215. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dma.h +252 -62
  216. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-dump.h +9 -0
  217. data/ext/sources/ggml/src/ggml-hexagon/htp/hex-utils.h +87 -1
  218. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +1878 -0
  219. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +2066 -0
  220. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.c +6 -0
  221. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-ops.h +88 -0
  222. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-profile.h +34 -0
  223. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.c +158 -0
  224. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-queue.h +134 -0
  225. data/ext/sources/ggml/src/ggml-hexagon/htp/hmx-utils.h +200 -0
  226. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ctx.h +96 -13
  227. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-ops.h +182 -57
  228. data/ext/sources/ggml/src/ggml-hexagon/htp/htp_iface.idl +9 -3
  229. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-base.h +71 -3
  230. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-copy.h +27 -10
  231. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-div.h +63 -23
  232. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-exp.h +9 -8
  233. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-flash-attn.h +47 -0
  234. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-log.h +65 -0
  235. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-pow.h +42 -0
  236. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-repl.h +74 -0
  237. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +1 -0
  238. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-sin-cos.h +90 -0
  239. data/ext/sources/ggml/src/ggml-hexagon/htp/hvx-utils.h +5 -8
  240. data/ext/sources/ggml/src/ggml-hexagon/htp/main.c +529 -815
  241. data/ext/sources/ggml/src/ggml-hexagon/htp/matmul-ops.c +2522 -234
  242. data/ext/sources/ggml/src/ggml-hexagon/htp/pad-ops.c +547 -0
  243. data/ext/sources/ggml/src/ggml-hexagon/htp/repeat-ops.c +148 -0
  244. data/ext/sources/ggml/src/ggml-hexagon/htp/rope-ops.c +291 -95
  245. data/ext/sources/ggml/src/ggml-hexagon/htp/set-rows-ops.c +59 -37
  246. data/ext/sources/ggml/src/ggml-hexagon/htp/softmax-ops.c +121 -133
  247. data/ext/sources/ggml/src/ggml-hexagon/htp/solve-tri-ops.c +267 -0
  248. data/ext/sources/ggml/src/ggml-hexagon/htp/ssm-conv.c +244 -151
  249. data/ext/sources/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +6 -6
  250. data/ext/sources/ggml/src/ggml-hexagon/htp/unary-ops.c +719 -45
  251. data/ext/sources/ggml/src/ggml-hexagon/htp/vtcm-utils.h +16 -0
  252. data/ext/sources/ggml/src/ggml-hexagon/htp-opnode.h +272 -0
  253. data/ext/sources/ggml/src/ggml-hexagon/libggml-htp.inf +3 -1
  254. data/ext/sources/ggml/src/ggml-hip/CMakeLists.txt +22 -9
  255. data/ext/sources/ggml/src/ggml-impl.h +6 -1
  256. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.cpp +138 -13
  257. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.h +32 -1
  258. data/ext/sources/ggml/src/ggml-metal/ggml-metal-device.m +164 -28
  259. data/ext/sources/ggml/src/ggml-metal/ggml-metal-impl.h +80 -0
  260. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.cpp +190 -19
  261. data/ext/sources/ggml/src/ggml-metal/ggml-metal-ops.h +2 -0
  262. data/ext/sources/ggml/src/ggml-metal/ggml-metal.cpp +39 -26
  263. data/ext/sources/ggml/src/ggml-metal/ggml-metal.metal +823 -322
  264. data/ext/sources/ggml/src/ggml-musa/CMakeLists.txt +5 -6
  265. data/ext/sources/ggml/src/ggml-opencl/CMakeLists.txt +54 -5
  266. data/ext/sources/ggml/src/ggml-opencl/ggml-opencl.cpp +12248 -5907
  267. data/ext/sources/ggml/src/ggml-opencl/kernels/concat.cl +67 -0
  268. data/ext/sources/ggml/src/ggml-opencl/kernels/cpy.cl +59 -0
  269. data/ext/sources/ggml/src/ggml-opencl/kernels/cvt.cl +1819 -112
  270. data/ext/sources/ggml/src/ggml-opencl/kernels/gated_delta_net.cl +249 -0
  271. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_mxfp4_f32_ns.cl +306 -0
  272. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_0_f32_ns.cl +256 -0
  273. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_1_f32_ns.cl +258 -0
  274. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q4_k_f32_ns.cl +283 -0
  275. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_0_f32_ns.cl +260 -0
  276. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_1_f32_ns.cl +262 -0
  277. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q5_k_f32_ns.cl +288 -0
  278. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_moe_q6_k_f32_ns.cl +267 -0
  279. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_iq4_nl_f32.cl +150 -0
  280. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mat_Ab_Bi_8x4.cl → gemm_noshuffle_q4_0_f32.cl} +1 -1
  281. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_k_f32.cl +172 -0
  282. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_0_f32.cl +131 -0
  283. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_1_f32.cl +134 -0
  284. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q5_k_f32.cl +176 -0
  285. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q6_k_f32.cl +140 -0
  286. data/ext/sources/ggml/src/ggml-opencl/kernels/{mul_mm_q8_0_f32_8x4.cl → gemm_noshuffle_q8_0_f32.cl} +1 -1
  287. data/ext/sources/ggml/src/ggml-opencl/kernels/gemm_xmem_f16_f32_os8.cl +233 -0
  288. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_mxfp4_f32_ns.cl +165 -0
  289. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_0_f32_ns.cl +120 -0
  290. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_1_f32_ns.cl +123 -0
  291. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q4_k_f32_ns.cl +155 -0
  292. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_0_f32_ns.cl +123 -0
  293. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_1_f32_ns.cl +125 -0
  294. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q5_k_f32_ns.cl +160 -0
  295. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_moe_q6_k_f32_ns.cl +141 -0
  296. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_iq4_nl_f32.cl +302 -0
  297. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general.cl → gemv_noshuffle_q4_0_f32.cl} +5 -5
  298. data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle.cl → gemv_noshuffle_q4_0_f32_spec.cl} +5 -5
  299. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_k_f32.cl +318 -0
  300. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_0_f32.cl +291 -0
  301. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_1_f32.cl +294 -0
  302. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q5_k_f32.cl +326 -0
  303. data/ext/sources/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q6_k_f32.cl +293 -0
  304. data/ext/sources/ggml/src/ggml-opencl/kernels/get_rows.cl +15 -9
  305. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_reorder_b.cl +30 -0
  306. data/ext/sources/ggml/src/ggml-opencl/kernels/moe_sort_by_expert.cl +82 -0
  307. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_iq4_nl_f32_l4_lm.cl +171 -0
  308. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q4_k_f32_l4_lm.cl +179 -0
  309. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_0_f32_l4_lm.cl +173 -0
  310. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_1_f32_l4_lm.cl +175 -0
  311. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mm_q5_k_f32_l4_lm.cl +192 -0
  312. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32.cl +164 -0
  313. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_iq4_nl_f32_flat.cl +202 -0
  314. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32_flat.cl +196 -0
  315. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32.cl +241 -0
  316. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_0_f32_flat.cl +243 -0
  317. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32.cl +243 -0
  318. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_1_f32_flat.cl +247 -0
  319. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32.cl +187 -0
  320. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q5_k_f32_flat.cl +203 -0
  321. data/ext/sources/ggml/src/ggml-opencl/kernels/mul_mv_q6_k_f32_flat.cl +48 -64
  322. data/ext/sources/ggml/src/ggml-openvino/ggml-decoder.cpp +15 -5
  323. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino-extra.cpp +18 -11
  324. data/ext/sources/ggml/src/ggml-openvino/ggml-openvino.cpp +35 -13
  325. data/ext/sources/ggml/src/ggml-openvino/ggml-quants.cpp +264 -192
  326. data/ext/sources/ggml/src/ggml-openvino/openvino/op/rope.cpp +33 -7
  327. data/ext/sources/ggml/src/ggml-openvino/openvino/op/unary_gelu.cpp +25 -0
  328. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.cpp +1 -0
  329. data/ext/sources/ggml/src/ggml-openvino/openvino/op_table.h +1 -0
  330. data/ext/sources/ggml/src/ggml-openvino/openvino/rt_info/weightless_caching_attributes.hpp +41 -0
  331. data/ext/sources/ggml/src/ggml-openvino/openvino/translate_session.cpp +27 -3
  332. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.cpp +67 -36
  333. data/ext/sources/ggml/src/ggml-openvino/openvino/utils.h +1 -0
  334. data/ext/sources/ggml/src/ggml-openvino/utils.cpp +101 -44
  335. data/ext/sources/ggml/src/ggml-openvino/utils.h +23 -3
  336. data/ext/sources/ggml/src/ggml-opt.cpp +1 -0
  337. data/ext/sources/ggml/src/ggml-quants.c +289 -114
  338. data/ext/sources/ggml/src/ggml-quants.h +3 -0
  339. data/ext/sources/ggml/src/ggml-rpc/CMakeLists.txt +24 -0
  340. data/ext/sources/ggml/src/ggml-rpc/ggml-rpc.cpp +167 -311
  341. data/ext/sources/ggml/src/ggml-rpc/transport.cpp +683 -0
  342. data/ext/sources/ggml/src/ggml-rpc/transport.h +34 -0
  343. data/ext/sources/ggml/src/ggml-sycl/CMakeLists.txt +50 -4
  344. data/ext/sources/ggml/src/ggml-sycl/add-id.cpp +1 -1
  345. data/ext/sources/ggml/src/ggml-sycl/backend.hpp +3 -1
  346. data/ext/sources/ggml/src/ggml-sycl/common.cpp +74 -2
  347. data/ext/sources/ggml/src/ggml-sycl/common.hpp +41 -1
  348. data/ext/sources/ggml/src/ggml-sycl/convert.cpp +115 -13
  349. data/ext/sources/ggml/src/ggml-sycl/convert.hpp +9 -0
  350. data/ext/sources/ggml/src/ggml-sycl/cumsum.cpp +148 -0
  351. data/ext/sources/ggml/src/ggml-sycl/cumsum.hpp +5 -0
  352. data/ext/sources/ggml/src/ggml-sycl/dequantize.hpp +663 -0
  353. data/ext/sources/ggml/src/ggml-sycl/diag.cpp +67 -0
  354. data/ext/sources/ggml/src/ggml-sycl/diag.hpp +5 -0
  355. data/ext/sources/ggml/src/ggml-sycl/dmmv.cpp +586 -6
  356. data/ext/sources/ggml/src/ggml-sycl/element_wise.cpp +1 -90
  357. data/ext/sources/ggml/src/ggml-sycl/element_wise.hpp +0 -2
  358. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.cpp +56 -0
  359. data/ext/sources/ggml/src/ggml-sycl/fattn-buffers.hpp +63 -0
  360. data/ext/sources/ggml/src/ggml-sycl/fattn-common.hpp +7 -5
  361. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.cpp +4 -0
  362. data/ext/sources/ggml/src/ggml-sycl/fattn-tile.hpp +76 -168
  363. data/ext/sources/ggml/src/ggml-sycl/fattn-vec.hpp +7 -0
  364. data/ext/sources/ggml/src/ggml-sycl/fattn.cpp +3 -1
  365. data/ext/sources/ggml/src/ggml-sycl/fill.cpp +55 -0
  366. data/ext/sources/ggml/src/ggml-sycl/fill.hpp +5 -0
  367. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.cpp +69 -31
  368. data/ext/sources/ggml/src/ggml-sycl/gated_delta_net.hpp +1 -0
  369. data/ext/sources/ggml/src/ggml-sycl/gemm.hpp +3 -0
  370. data/ext/sources/ggml/src/ggml-sycl/getrows.cpp +79 -3
  371. data/ext/sources/ggml/src/ggml-sycl/ggml-sycl.cpp +823 -190
  372. data/ext/sources/ggml/src/ggml-sycl/im2col.cpp +353 -89
  373. data/ext/sources/ggml/src/ggml-sycl/im2col.hpp +5 -3
  374. data/ext/sources/ggml/src/ggml-sycl/mmvq.cpp +1344 -26
  375. data/ext/sources/ggml/src/ggml-sycl/mmvq.hpp +16 -0
  376. data/ext/sources/ggml/src/ggml-sycl/pad.cpp +27 -27
  377. data/ext/sources/ggml/src/ggml-sycl/quants.hpp +71 -0
  378. data/ext/sources/ggml/src/ggml-sycl/set_rows.cpp +7 -1
  379. data/ext/sources/ggml/src/ggml-sycl/solve_tri.cpp +172 -0
  380. data/ext/sources/ggml/src/ggml-sycl/solve_tri.hpp +8 -0
  381. data/ext/sources/ggml/src/ggml-sycl/ssm_conv.cpp +6 -1
  382. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.cpp +156 -0
  383. data/ext/sources/ggml/src/ggml-sycl/ssm_scan.hpp +5 -0
  384. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.cpp +62 -10
  385. data/ext/sources/ggml/src/ggml-sycl/sycl_hw.hpp +18 -6
  386. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-tile-instance-dkq512-dv512.cpp +6 -0
  387. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-f16.cpp +1 -0
  388. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_0.cpp +1 -0
  389. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q4_1.cpp +1 -0
  390. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_0.cpp +1 -0
  391. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q5_1.cpp +1 -0
  392. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-f16-q8_0.cpp +1 -0
  393. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-f16.cpp +1 -0
  394. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_0.cpp +1 -0
  395. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q4_1.cpp +1 -0
  396. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_0.cpp +1 -0
  397. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q5_1.cpp +1 -0
  398. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_0-q8_0.cpp +1 -0
  399. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-f16.cpp +1 -0
  400. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_0.cpp +1 -0
  401. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q4_1.cpp +1 -0
  402. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_0.cpp +1 -0
  403. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q5_1.cpp +1 -0
  404. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q4_1-q8_0.cpp +1 -0
  405. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-f16.cpp +1 -0
  406. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_0.cpp +1 -0
  407. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q4_1.cpp +1 -0
  408. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_0.cpp +1 -0
  409. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q5_1.cpp +1 -0
  410. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_0-q8_0.cpp +1 -0
  411. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-f16.cpp +1 -0
  412. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_0.cpp +1 -0
  413. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q4_1.cpp +1 -0
  414. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_0.cpp +1 -0
  415. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q5_1.cpp +1 -0
  416. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q5_1-q8_0.cpp +1 -0
  417. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-f16.cpp +1 -0
  418. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_0.cpp +1 -0
  419. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q4_1.cpp +1 -0
  420. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_0.cpp +1 -0
  421. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q5_1.cpp +1 -0
  422. data/ext/sources/ggml/src/ggml-sycl/template-instances/fattn-vec-instance-q8_0-q8_0.cpp +1 -0
  423. data/ext/sources/ggml/src/ggml-sycl/type.hpp +112 -0
  424. data/ext/sources/ggml/src/ggml-sycl/upscale.cpp +410 -0
  425. data/ext/sources/ggml/src/ggml-sycl/upscale.hpp +9 -0
  426. data/ext/sources/ggml/src/ggml-sycl/vecdotq.hpp +215 -53
  427. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-buffer.cpp +4 -0
  428. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +2 -0
  429. data/ext/sources/ggml/src/ggml-virtgpu/ggml-backend.cpp +2 -0
  430. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu-shm.cpp +1 -0
  431. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.cpp +1 -0
  432. data/ext/sources/ggml/src/ggml-virtgpu/virtgpu.h +0 -2
  433. data/ext/sources/ggml/src/ggml-vulkan/CMakeLists.txt +11 -0
  434. data/ext/sources/ggml/src/ggml-vulkan/ggml-vulkan.cpp +2060 -535
  435. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -0
  436. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp +6 -2
  437. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp +146 -13
  438. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp +3 -1
  439. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +1 -1
  440. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +25 -1
  441. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +88 -0
  442. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl +643 -1
  443. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_nvfp4.comp +32 -0
  444. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q1_0.comp +29 -0
  445. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/diag.comp +0 -1
  446. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/dot_product_funcs.glsl +27 -0
  447. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp +0 -1
  448. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/feature-tests/coopmat2_decode_vector.comp +7 -0
  449. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +197 -48
  450. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +60 -59
  451. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +115 -113
  452. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +122 -31
  453. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_dequant.glsl +131 -0
  454. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_mmq_funcs.glsl +203 -0
  455. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/fwht.comp +115 -0
  456. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp +125 -64
  457. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl +0 -1
  458. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl +10 -1
  459. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl +16 -6
  460. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp +76 -54
  461. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp +0 -1
  462. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/log.comp +0 -1
  463. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +122 -27
  464. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iface.glsl +6 -6
  465. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp +1 -1
  466. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp +1 -1
  467. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp +1 -1
  468. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +1 -0
  469. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +88 -55
  470. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +11 -17
  471. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +43 -10
  472. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +159 -125
  473. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl +8 -8
  474. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_shmem_types.glsl +24 -9
  475. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp +0 -1
  476. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_funcs.glsl +5 -2
  477. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl +0 -1
  478. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rope_params.glsl +3 -2
  479. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/snake.comp +49 -0
  480. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp +11 -1
  481. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/tri.comp +0 -1
  482. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +79 -2
  483. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +171 -147
  484. data/ext/sources/ggml/src/ggml-webgpu/CMakeLists.txt +5 -2
  485. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +2202 -283
  486. data/ext/sources/ggml/src/ggml-webgpu/ggml-webgpu.cpp +2610 -1403
  487. data/ext/sources/ggml/src/ggml-webgpu/pre_wgsl.hpp +37 -7
  488. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/add_id.wgsl +64 -0
  489. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +8 -7
  490. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +76 -95
  491. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl +19 -1
  492. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/conv2d.wgsl +165 -0
  493. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{cpy.tmpl.wgsl → cpy.wgsl} +25 -50
  494. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn.wgsl +107 -184
  495. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_quant_staging.tmpl +124 -0
  496. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_tile.wgsl +397 -0
  497. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_blk.wgsl +101 -0
  498. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_reduce.wgsl +84 -0
  499. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/flash_attn_vec_split.wgsl +619 -0
  500. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/gated_delta_net.wgsl +149 -0
  501. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl +183 -78
  502. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.wgsl +155 -0
  503. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/im2col.wgsl +101 -0
  504. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +655 -495
  505. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id.wgsl +195 -0
  506. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_gather.wgsl +52 -0
  507. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +154 -0
  508. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +8 -6
  509. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +5 -1
  510. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +80 -409
  511. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +1432 -0
  512. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +303 -0
  513. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quant_inner_loops.tmpl +21 -0
  514. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +173 -0
  515. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_mul.wgsl +152 -0
  516. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{rope.tmpl.wgsl → rope.wgsl} +71 -142
  517. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/row_norm.wgsl +153 -0
  518. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl +6 -4
  519. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set.wgsl +109 -0
  520. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows.wgsl +2 -3
  521. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/set_rows_quant.wgsl +224 -0
  522. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/{soft_max.tmpl.wgsl → soft_max.wgsl} +106 -206
  523. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/solve_tri.wgsl +121 -0
  524. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_conv.wgsl +65 -0
  525. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/ssm_scan.wgsl +193 -0
  526. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/unary.wgsl +68 -48
  527. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/upscale.wgsl +240 -0
  528. data/ext/sources/ggml/src/ggml-zdnn/ggml-zdnn.cpp +18 -14
  529. data/ext/sources/ggml/src/ggml-zendnn/CMakeLists.txt +1 -1
  530. data/ext/sources/ggml/src/ggml-zendnn/ggml-zendnn.cpp +244 -10
  531. data/ext/sources/ggml/src/ggml.c +110 -28
  532. data/ext/sources/ggml/src/gguf.cpp +173 -28
  533. data/ext/sources/include/parakeet.h +342 -0
  534. data/ext/sources/include/whisper.h +10 -0
  535. data/ext/sources/media/matmul.png +0 -0
  536. data/ext/sources/src/CMakeLists.txt +23 -0
  537. data/ext/sources/src/parakeet-arch.h +188 -0
  538. data/ext/sources/src/parakeet.cpp +3838 -0
  539. data/ext/sources/src/whisper.cpp +56 -12
  540. data/extsources.rb +26 -10
  541. data/lib/whisper/log_settable.rb +36 -0
  542. data/lib/whisper/model/uri.rb +13 -1
  543. data/lib/whisper/output.rb +74 -0
  544. data/sig/whisper.rbs +411 -62
  545. data/test/helper.rb +2 -0
  546. data/test/jfk_reader/jfk_reader.c +50 -7
  547. data/test/test_callback.rb +1 -0
  548. data/test/test_package.rb +6 -5
  549. data/test/test_parakeet.rb +28 -0
  550. data/test/test_parakeet_callback.rb +107 -0
  551. data/test/test_parakeet_context.rb +116 -0
  552. data/test/test_parakeet_context_params.rb +24 -0
  553. data/test/test_parakeet_model.rb +21 -0
  554. data/test/test_parakeet_params.rb +78 -0
  555. data/test/test_parakeet_segment.rb +42 -0
  556. data/test/test_parakeet_token.rb +73 -0
  557. data/test/test_params.rb +2 -0
  558. data/test/test_vad_segment.rb +1 -1
  559. data/test/test_whisper.rb +24 -6
  560. data/whispercpp.gemspec +2 -2
  561. metadata +215 -281
  562. data/ext/sources/bindings/javascript/CMakeLists.txt +0 -41
  563. data/ext/sources/bindings/javascript/emscripten.cpp +0 -93
  564. data/ext/sources/bindings/javascript/libwhisper.worker.js +0 -1
  565. data/ext/sources/bindings/javascript/package.json +0 -26
  566. data/ext/sources/bindings/javascript/whisper.js +0 -19
  567. data/ext/sources/examples/addon.node/CMakeLists.txt +0 -31
  568. data/ext/sources/examples/addon.node/__test__/whisper.spec.js +0 -133
  569. data/ext/sources/examples/addon.node/addon.cpp +0 -557
  570. data/ext/sources/examples/addon.node/index.js +0 -59
  571. data/ext/sources/examples/addon.node/package.json +0 -16
  572. data/ext/sources/examples/addon.node/vad-example.js +0 -132
  573. data/ext/sources/examples/bench.wasm/CMakeLists.txt +0 -49
  574. data/ext/sources/examples/bench.wasm/emscripten.cpp +0 -87
  575. data/ext/sources/examples/bench.wasm/index-tmpl.html +0 -285
  576. data/ext/sources/examples/coi-serviceworker.js +0 -146
  577. data/ext/sources/examples/command/CMakeLists.txt +0 -10
  578. data/ext/sources/examples/command/command.cpp +0 -802
  579. data/ext/sources/examples/command/commands.txt +0 -9
  580. data/ext/sources/examples/command.wasm/CMakeLists.txt +0 -50
  581. data/ext/sources/examples/command.wasm/emscripten.cpp +0 -327
  582. data/ext/sources/examples/command.wasm/index-tmpl.html +0 -415
  583. data/ext/sources/examples/generate-karaoke.sh +0 -57
  584. data/ext/sources/examples/helpers.js +0 -191
  585. data/ext/sources/examples/livestream.sh +0 -112
  586. data/ext/sources/examples/lsp/CMakeLists.txt +0 -10
  587. data/ext/sources/examples/lsp/lsp.cpp +0 -471
  588. data/ext/sources/examples/lsp/whisper.vim +0 -362
  589. data/ext/sources/examples/python/test_whisper_processor.py +0 -7
  590. data/ext/sources/examples/python/whisper_processor.py +0 -54
  591. data/ext/sources/examples/server/bench.js +0 -29
  592. data/ext/sources/examples/server.py +0 -120
  593. data/ext/sources/examples/stream/CMakeLists.txt +0 -10
  594. data/ext/sources/examples/stream/stream.cpp +0 -437
  595. data/ext/sources/examples/stream.wasm/CMakeLists.txt +0 -49
  596. data/ext/sources/examples/stream.wasm/emscripten.cpp +0 -216
  597. data/ext/sources/examples/stream.wasm/index-tmpl.html +0 -491
  598. data/ext/sources/examples/sycl/CMakeLists.txt +0 -9
  599. data/ext/sources/examples/sycl/build.sh +0 -22
  600. data/ext/sources/examples/sycl/ls-sycl-device.cpp +0 -11
  601. data/ext/sources/examples/sycl/run-whisper.sh +0 -17
  602. data/ext/sources/examples/talk-llama/CMakeLists.txt +0 -48
  603. data/ext/sources/examples/talk-llama/eleven-labs.py +0 -80
  604. data/ext/sources/examples/talk-llama/llama-adapter.cpp +0 -488
  605. data/ext/sources/examples/talk-llama/llama-adapter.h +0 -89
  606. data/ext/sources/examples/talk-llama/llama-arch.cpp +0 -2877
  607. data/ext/sources/examples/talk-llama/llama-arch.h +0 -628
  608. data/ext/sources/examples/talk-llama/llama-batch.cpp +0 -919
  609. data/ext/sources/examples/talk-llama/llama-batch.h +0 -173
  610. data/ext/sources/examples/talk-llama/llama-chat.cpp +0 -896
  611. data/ext/sources/examples/talk-llama/llama-chat.h +0 -71
  612. data/ext/sources/examples/talk-llama/llama-context.cpp +0 -3633
  613. data/ext/sources/examples/talk-llama/llama-context.h +0 -359
  614. data/ext/sources/examples/talk-llama/llama-cparams.cpp +0 -5
  615. data/ext/sources/examples/talk-llama/llama-cparams.h +0 -47
  616. data/ext/sources/examples/talk-llama/llama-ext.h +0 -12
  617. data/ext/sources/examples/talk-llama/llama-grammar.cpp +0 -1464
  618. data/ext/sources/examples/talk-llama/llama-grammar.h +0 -194
  619. data/ext/sources/examples/talk-llama/llama-graph.cpp +0 -2735
  620. data/ext/sources/examples/talk-llama/llama-graph.h +0 -1031
  621. data/ext/sources/examples/talk-llama/llama-hparams.cpp +0 -258
  622. data/ext/sources/examples/talk-llama/llama-hparams.h +0 -353
  623. data/ext/sources/examples/talk-llama/llama-impl.cpp +0 -171
  624. data/ext/sources/examples/talk-llama/llama-impl.h +0 -75
  625. data/ext/sources/examples/talk-llama/llama-io.cpp +0 -15
  626. data/ext/sources/examples/talk-llama/llama-io.h +0 -35
  627. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.cpp +0 -330
  628. data/ext/sources/examples/talk-llama/llama-kv-cache-iswa.h +0 -137
  629. data/ext/sources/examples/talk-llama/llama-kv-cache.cpp +0 -2285
  630. data/ext/sources/examples/talk-llama/llama-kv-cache.h +0 -389
  631. data/ext/sources/examples/talk-llama/llama-kv-cells.h +0 -533
  632. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.cpp +0 -275
  633. data/ext/sources/examples/talk-llama/llama-memory-hybrid-iswa.h +0 -140
  634. data/ext/sources/examples/talk-llama/llama-memory-hybrid.cpp +0 -268
  635. data/ext/sources/examples/talk-llama/llama-memory-hybrid.h +0 -139
  636. data/ext/sources/examples/talk-llama/llama-memory-recurrent.cpp +0 -1165
  637. data/ext/sources/examples/talk-llama/llama-memory-recurrent.h +0 -182
  638. data/ext/sources/examples/talk-llama/llama-memory.cpp +0 -59
  639. data/ext/sources/examples/talk-llama/llama-memory.h +0 -122
  640. data/ext/sources/examples/talk-llama/llama-mmap.cpp +0 -752
  641. data/ext/sources/examples/talk-llama/llama-mmap.h +0 -73
  642. data/ext/sources/examples/talk-llama/llama-model-loader.cpp +0 -1655
  643. data/ext/sources/examples/talk-llama/llama-model-loader.h +0 -206
  644. data/ext/sources/examples/talk-llama/llama-model-saver.cpp +0 -299
  645. data/ext/sources/examples/talk-llama/llama-model-saver.h +0 -40
  646. data/ext/sources/examples/talk-llama/llama-model.cpp +0 -9056
  647. data/ext/sources/examples/talk-llama/llama-model.h +0 -597
  648. data/ext/sources/examples/talk-llama/llama-quant.cpp +0 -1304
  649. data/ext/sources/examples/talk-llama/llama-quant.h +0 -1
  650. data/ext/sources/examples/talk-llama/llama-sampler.cpp +0 -3885
  651. data/ext/sources/examples/talk-llama/llama-sampler.h +0 -42
  652. data/ext/sources/examples/talk-llama/llama-vocab.cpp +0 -3970
  653. data/ext/sources/examples/talk-llama/llama-vocab.h +0 -187
  654. data/ext/sources/examples/talk-llama/llama.cpp +0 -1194
  655. data/ext/sources/examples/talk-llama/llama.h +0 -1573
  656. data/ext/sources/examples/talk-llama/models/afmoe.cpp +0 -190
  657. data/ext/sources/examples/talk-llama/models/apertus.cpp +0 -125
  658. data/ext/sources/examples/talk-llama/models/arcee.cpp +0 -135
  659. data/ext/sources/examples/talk-llama/models/arctic.cpp +0 -137
  660. data/ext/sources/examples/talk-llama/models/arwkv7.cpp +0 -86
  661. data/ext/sources/examples/talk-llama/models/baichuan.cpp +0 -123
  662. data/ext/sources/examples/talk-llama/models/bailingmoe.cpp +0 -143
  663. data/ext/sources/examples/talk-llama/models/bailingmoe2.cpp +0 -133
  664. data/ext/sources/examples/talk-llama/models/bert.cpp +0 -184
  665. data/ext/sources/examples/talk-llama/models/bitnet.cpp +0 -145
  666. data/ext/sources/examples/talk-llama/models/bloom.cpp +0 -101
  667. data/ext/sources/examples/talk-llama/models/chameleon.cpp +0 -178
  668. data/ext/sources/examples/talk-llama/models/chatglm.cpp +0 -132
  669. data/ext/sources/examples/talk-llama/models/codeshell.cpp +0 -111
  670. data/ext/sources/examples/talk-llama/models/cogvlm.cpp +0 -102
  671. data/ext/sources/examples/talk-llama/models/cohere2-iswa.cpp +0 -134
  672. data/ext/sources/examples/talk-llama/models/command-r.cpp +0 -122
  673. data/ext/sources/examples/talk-llama/models/dbrx.cpp +0 -122
  674. data/ext/sources/examples/talk-llama/models/deci.cpp +0 -135
  675. data/ext/sources/examples/talk-llama/models/deepseek.cpp +0 -142
  676. data/ext/sources/examples/talk-llama/models/deepseek2.cpp +0 -262
  677. data/ext/sources/examples/talk-llama/models/delta-net-base.cpp +0 -445
  678. data/ext/sources/examples/talk-llama/models/dots1.cpp +0 -132
  679. data/ext/sources/examples/talk-llama/models/dream.cpp +0 -105
  680. data/ext/sources/examples/talk-llama/models/ernie4-5-moe.cpp +0 -148
  681. data/ext/sources/examples/talk-llama/models/ernie4-5.cpp +0 -110
  682. data/ext/sources/examples/talk-llama/models/eurobert.cpp +0 -97
  683. data/ext/sources/examples/talk-llama/models/exaone-moe.cpp +0 -145
  684. data/ext/sources/examples/talk-llama/models/exaone.cpp +0 -114
  685. data/ext/sources/examples/talk-llama/models/exaone4.cpp +0 -123
  686. data/ext/sources/examples/talk-llama/models/falcon-h1.cpp +0 -111
  687. data/ext/sources/examples/talk-llama/models/falcon.cpp +0 -120
  688. data/ext/sources/examples/talk-llama/models/gemma-embedding.cpp +0 -116
  689. data/ext/sources/examples/talk-llama/models/gemma.cpp +0 -112
  690. data/ext/sources/examples/talk-llama/models/gemma2-iswa.cpp +0 -128
  691. data/ext/sources/examples/talk-llama/models/gemma3.cpp +0 -155
  692. data/ext/sources/examples/talk-llama/models/gemma3n-iswa.cpp +0 -384
  693. data/ext/sources/examples/talk-llama/models/glm4-moe.cpp +0 -170
  694. data/ext/sources/examples/talk-llama/models/glm4.cpp +0 -157
  695. data/ext/sources/examples/talk-llama/models/gpt2.cpp +0 -105
  696. data/ext/sources/examples/talk-llama/models/gptneox.cpp +0 -144
  697. data/ext/sources/examples/talk-llama/models/granite-hybrid.cpp +0 -195
  698. data/ext/sources/examples/talk-llama/models/granite.cpp +0 -210
  699. data/ext/sources/examples/talk-llama/models/grok.cpp +0 -159
  700. data/ext/sources/examples/talk-llama/models/grovemoe.cpp +0 -139
  701. data/ext/sources/examples/talk-llama/models/hunyuan-dense.cpp +0 -132
  702. data/ext/sources/examples/talk-llama/models/hunyuan-moe.cpp +0 -153
  703. data/ext/sources/examples/talk-llama/models/internlm2.cpp +0 -120
  704. data/ext/sources/examples/talk-llama/models/jais.cpp +0 -86
  705. data/ext/sources/examples/talk-llama/models/jais2.cpp +0 -123
  706. data/ext/sources/examples/talk-llama/models/jamba.cpp +0 -106
  707. data/ext/sources/examples/talk-llama/models/kimi-linear.cpp +0 -381
  708. data/ext/sources/examples/talk-llama/models/lfm2.cpp +0 -196
  709. data/ext/sources/examples/talk-llama/models/llada-moe.cpp +0 -122
  710. data/ext/sources/examples/talk-llama/models/llada.cpp +0 -99
  711. data/ext/sources/examples/talk-llama/models/llama-iswa.cpp +0 -178
  712. data/ext/sources/examples/talk-llama/models/llama.cpp +0 -175
  713. data/ext/sources/examples/talk-llama/models/maincoder.cpp +0 -117
  714. data/ext/sources/examples/talk-llama/models/mamba-base.cpp +0 -289
  715. data/ext/sources/examples/talk-llama/models/mamba.cpp +0 -54
  716. data/ext/sources/examples/talk-llama/models/mimo2-iswa.cpp +0 -129
  717. data/ext/sources/examples/talk-llama/models/minicpm3.cpp +0 -200
  718. data/ext/sources/examples/talk-llama/models/minimax-m2.cpp +0 -123
  719. data/ext/sources/examples/talk-llama/models/mistral3.cpp +0 -160
  720. data/ext/sources/examples/talk-llama/models/models.h +0 -704
  721. data/ext/sources/examples/talk-llama/models/modern-bert.cpp +0 -109
  722. data/ext/sources/examples/talk-llama/models/mpt.cpp +0 -126
  723. data/ext/sources/examples/talk-llama/models/nemotron-h.cpp +0 -162
  724. data/ext/sources/examples/talk-llama/models/nemotron.cpp +0 -122
  725. data/ext/sources/examples/talk-llama/models/neo-bert.cpp +0 -104
  726. data/ext/sources/examples/talk-llama/models/olmo.cpp +0 -121
  727. data/ext/sources/examples/talk-llama/models/olmo2.cpp +0 -150
  728. data/ext/sources/examples/talk-llama/models/olmoe.cpp +0 -124
  729. data/ext/sources/examples/talk-llama/models/openai-moe-iswa.cpp +0 -127
  730. data/ext/sources/examples/talk-llama/models/openelm.cpp +0 -124
  731. data/ext/sources/examples/talk-llama/models/orion.cpp +0 -123
  732. data/ext/sources/examples/talk-llama/models/paddleocr.cpp +0 -122
  733. data/ext/sources/examples/talk-llama/models/pangu-embedded.cpp +0 -121
  734. data/ext/sources/examples/talk-llama/models/phi2.cpp +0 -121
  735. data/ext/sources/examples/talk-llama/models/phi3.cpp +0 -152
  736. data/ext/sources/examples/talk-llama/models/plamo.cpp +0 -110
  737. data/ext/sources/examples/talk-llama/models/plamo2.cpp +0 -320
  738. data/ext/sources/examples/talk-llama/models/plamo3.cpp +0 -128
  739. data/ext/sources/examples/talk-llama/models/plm.cpp +0 -169
  740. data/ext/sources/examples/talk-llama/models/qwen.cpp +0 -108
  741. data/ext/sources/examples/talk-llama/models/qwen2.cpp +0 -126
  742. data/ext/sources/examples/talk-llama/models/qwen2moe.cpp +0 -151
  743. data/ext/sources/examples/talk-llama/models/qwen2vl.cpp +0 -117
  744. data/ext/sources/examples/talk-llama/models/qwen3.cpp +0 -120
  745. data/ext/sources/examples/talk-llama/models/qwen35.cpp +0 -381
  746. data/ext/sources/examples/talk-llama/models/qwen35moe.cpp +0 -422
  747. data/ext/sources/examples/talk-llama/models/qwen3moe.cpp +0 -131
  748. data/ext/sources/examples/talk-llama/models/qwen3next.cpp +0 -525
  749. data/ext/sources/examples/talk-llama/models/qwen3vl-moe.cpp +0 -140
  750. data/ext/sources/examples/talk-llama/models/qwen3vl.cpp +0 -132
  751. data/ext/sources/examples/talk-llama/models/refact.cpp +0 -94
  752. data/ext/sources/examples/talk-llama/models/rnd1.cpp +0 -126
  753. data/ext/sources/examples/talk-llama/models/rwkv6-base.cpp +0 -164
  754. data/ext/sources/examples/talk-llama/models/rwkv6.cpp +0 -94
  755. data/ext/sources/examples/talk-llama/models/rwkv6qwen2.cpp +0 -86
  756. data/ext/sources/examples/talk-llama/models/rwkv7-base.cpp +0 -137
  757. data/ext/sources/examples/talk-llama/models/rwkv7.cpp +0 -90
  758. data/ext/sources/examples/talk-llama/models/seed-oss.cpp +0 -124
  759. data/ext/sources/examples/talk-llama/models/smallthinker.cpp +0 -126
  760. data/ext/sources/examples/talk-llama/models/smollm3.cpp +0 -128
  761. data/ext/sources/examples/talk-llama/models/stablelm.cpp +0 -146
  762. data/ext/sources/examples/talk-llama/models/starcoder.cpp +0 -100
  763. data/ext/sources/examples/talk-llama/models/starcoder2.cpp +0 -121
  764. data/ext/sources/examples/talk-llama/models/step35-iswa.cpp +0 -165
  765. data/ext/sources/examples/talk-llama/models/t5-dec.cpp +0 -166
  766. data/ext/sources/examples/talk-llama/models/t5-enc.cpp +0 -96
  767. data/ext/sources/examples/talk-llama/models/wavtokenizer-dec.cpp +0 -149
  768. data/ext/sources/examples/talk-llama/models/xverse.cpp +0 -108
  769. data/ext/sources/examples/talk-llama/prompts/talk-alpaca.txt +0 -23
  770. data/ext/sources/examples/talk-llama/speak +0 -40
  771. data/ext/sources/examples/talk-llama/speak.bat +0 -1
  772. data/ext/sources/examples/talk-llama/speak.ps1 +0 -14
  773. data/ext/sources/examples/talk-llama/talk-llama.cpp +0 -813
  774. data/ext/sources/examples/talk-llama/unicode-data.cpp +0 -7034
  775. data/ext/sources/examples/talk-llama/unicode-data.h +0 -20
  776. data/ext/sources/examples/talk-llama/unicode.cpp +0 -1103
  777. data/ext/sources/examples/talk-llama/unicode.h +0 -111
  778. data/ext/sources/examples/wchess/CMakeLists.txt +0 -10
  779. data/ext/sources/examples/wchess/libwchess/CMakeLists.txt +0 -19
  780. data/ext/sources/examples/wchess/libwchess/Chessboard.cpp +0 -803
  781. data/ext/sources/examples/wchess/libwchess/Chessboard.h +0 -33
  782. data/ext/sources/examples/wchess/libwchess/WChess.cpp +0 -193
  783. data/ext/sources/examples/wchess/libwchess/WChess.h +0 -63
  784. data/ext/sources/examples/wchess/libwchess/test-chessboard.cpp +0 -117
  785. data/ext/sources/examples/wchess/wchess.cmd/CMakeLists.txt +0 -8
  786. data/ext/sources/examples/wchess/wchess.cmd/wchess.cmd.cpp +0 -253
  787. data/ext/sources/examples/whisper.wasm/CMakeLists.txt +0 -50
  788. data/ext/sources/examples/whisper.wasm/emscripten.cpp +0 -118
  789. data/ext/sources/examples/whisper.wasm/index-tmpl.html +0 -659
  790. data/ext/sources/ggml/src/ggml-cuda/template-instances/generate_cu_files.py +0 -99
  791. data/ext/sources/ggml/src/ggml-hexagon/htp/htp-msg.h +0 -155
  792. data/ext/sources/ggml/src/ggml-hexagon/op-desc.h +0 -153
  793. data/ext/sources/ggml/src/ggml-opencl/kernels/embed_kernel.py +0 -26
  794. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.cpp +0 -123
  795. data/ext/sources/ggml/src/ggml-openvino/openvino/pass/eliminate_zp.h +0 -17
  796. data/ext/sources/ggml/src/ggml-virtgpu/regenerate_remoting.py +0 -333
  797. data/ext/sources/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl +0 -5
  798. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +0 -182
  799. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl +0 -323
  800. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +0 -718
  801. data/ext/sources/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +0 -123
  802. data/ext/sources/tests/CMakeLists.txt +0 -112
  803. data/ext/sources/tests/earnings21/eval.mk +0 -58
  804. data/ext/sources/tests/earnings21/eval.py +0 -68
  805. data/ext/sources/tests/earnings21/normalizers/__init__.py +0 -2
  806. data/ext/sources/tests/earnings21/normalizers/basic.py +0 -80
  807. data/ext/sources/tests/earnings21/normalizers/english.json +0 -1741
  808. data/ext/sources/tests/earnings21/normalizers/english.py +0 -550
  809. data/ext/sources/tests/earnings21/requirements.txt +0 -6
  810. data/ext/sources/tests/en-0-ref.txt +0 -1
  811. data/ext/sources/tests/en-1-ref.txt +0 -1
  812. data/ext/sources/tests/en-2-ref.txt +0 -1
  813. data/ext/sources/tests/es-0-ref.txt +0 -1
  814. data/ext/sources/tests/librispeech/eval.mk +0 -39
  815. data/ext/sources/tests/librispeech/eval.py +0 -47
  816. data/ext/sources/tests/librispeech/normalizers/__init__.py +0 -2
  817. data/ext/sources/tests/librispeech/normalizers/basic.py +0 -80
  818. data/ext/sources/tests/librispeech/normalizers/english.json +0 -1741
  819. data/ext/sources/tests/librispeech/normalizers/english.py +0 -550
  820. data/ext/sources/tests/librispeech/requirements.txt +0 -6
  821. data/ext/sources/tests/run-tests.sh +0 -130
  822. data/ext/sources/tests/test-c.c +0 -3
  823. data/ext/sources/tests/test-vad-full.cpp +0 -56
  824. data/ext/sources/tests/test-vad.cpp +0 -83
  825. data/ext/sources/tests/test-whisper.js +0 -58
  826. data/lib/whisper/context.rb +0 -15
  827. data/lib/whisper/segment.rb +0 -58
  828. /data/ext/sources/ggml/src/ggml-opencl/kernels/{gemv_noshuffle_general_q8_0_f32.cl → gemv_noshuffle_q8_0_f32.cl} +0 -0
@@ -21,6 +21,33 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
21
21
 
22
22
  #include <vulkan/vulkan.hpp>
23
23
 
24
+ // Fallback definitions for VK_NV_cooperative_matrix_decode_vector in case the
25
+ // installed Vulkan headers predate the extension.
26
+ #ifndef VK_NV_cooperative_matrix_decode_vector
27
+ #define VK_NV_cooperative_matrix_decode_vector 1
28
+ #define VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME "VK_NV_cooperative_matrix_decode_vector"
29
+ #define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV ((VkStructureType)1000689000)
30
+ typedef struct VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV {
31
+ VkStructureType sType;
32
+ void* pNext;
33
+ VkBool32 cooperativeMatrixDecodeVector;
34
+ } VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV;
35
+ #endif
36
+
37
+ // SPIR-V Headers: different SDK installations expose different include paths.
38
+ // LunarG Vulkan SDK on Windows typically provides <spirv-headers/spirv.hpp>.
39
+ // Linux packages, MSYS2 and MinGW often use the Khronos layout <spirv/unified1/spirv.hpp>.
40
+ #if __has_include(<spirv/unified1/spirv.hpp>)
41
+ # include <spirv/unified1/spirv.hpp>
42
+ #elif __has_include(<spirv-headers/spirv.hpp>)
43
+ # include <spirv-headers/spirv.hpp>
44
+ #elif __has_include(<spirv.hpp>)
45
+ # include <spirv.hpp>
46
+ #else
47
+ // Fallback to let the compiler throw a standard "file not found" error
48
+ # include <spirv/unified1/spirv.hpp>
49
+ #endif
50
+
24
51
  #include <algorithm>
25
52
  #include <cmath>
26
53
  #include <iomanip>
@@ -35,9 +62,10 @@ DispatchLoaderDynamic & ggml_vk_default_dispatcher();
35
62
  #include <map>
36
63
  #include <set>
37
64
  #include <unordered_map>
38
- #include <memory>
65
+ #include <shared_mutex>
39
66
  #include <mutex>
40
67
  #include <future>
68
+ #include <condition_variable>
41
69
  #include <thread>
42
70
 
43
71
  #if defined(_MSC_VER)
@@ -85,6 +113,21 @@ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
85
113
  } VkPhysicalDeviceShaderBfloat16FeaturesKHR;
86
114
  #endif
87
115
 
116
+ #if !defined(VK_VALVE_shader_mixed_float_dot_product)
117
+ #define VK_VALVE_shader_mixed_float_dot_product 1
118
+ #define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_SPEC_VERSION 1
119
+ #define VK_VALVE_SHADER_MIXED_FLOAT_DOT_PRODUCT_EXTENSION_NAME "VK_VALVE_shader_mixed_float_dot_product"
120
+ #define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE ((VkStructureType)1000673000)
121
+ typedef struct VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE {
122
+ VkStructureType sType;
123
+ void* pNext;
124
+ VkBool32 shaderMixedFloatDotProductFloat16AccFloat32;
125
+ VkBool32 shaderMixedFloatDotProductFloat16AccFloat16;
126
+ VkBool32 shaderMixedFloatDotProductBFloat16Acc;
127
+ VkBool32 shaderMixedFloatDotProductFloat8AccFloat32;
128
+ } VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE;
129
+ #endif
130
+
88
131
  #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
89
132
  #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
90
133
  static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
@@ -97,8 +140,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
97
140
 
98
141
  #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256
99
142
 
100
- #define GGML_VK_MAX_NODES 8192
101
-
102
143
  #define VK_CHECK(err, msg) \
103
144
  do { \
104
145
  vk::Result err_ = (err); \
@@ -134,8 +175,9 @@ struct vk_pipeline_struct {
134
175
  uint32_t align;
135
176
  // true if fields have been set by ggml_vk_create_pipeline
136
177
  bool initialized {};
137
- // set to true to request the pipeline is compiled
138
- std::atomic<bool> needed {};
178
+ // true while a compile is in flight, used to dedupe concurrent claims.
179
+ // Protected by device->compile_mutex.
180
+ bool compile_pending {};
139
181
  // set to true when the shader has been compiled
140
182
  std::atomic<bool> compiled {};
141
183
  // number of registers used, extracted from pipeline executable properties
@@ -191,6 +233,7 @@ struct vk_queue;
191
233
 
192
234
  struct vk_command_buffer {
193
235
  vk::CommandBuffer buf;
236
+ uint64_t use_counter = 0;
194
237
  bool in_use = false;
195
238
  };
196
239
 
@@ -386,6 +429,7 @@ enum vk_conv_shapes {
386
429
  CONV_SHAPE_128x128,
387
430
  CONV_SHAPE_64x32,
388
431
  CONV_SHAPE_32x256,
432
+ CONV_SHAPE_64x128,
389
433
  CONV_SHAPE_COUNT,
390
434
  };
391
435
 
@@ -400,6 +444,7 @@ vk_conv_block_size vk_conv_block_sizes[CONV_SHAPE_COUNT] = {
400
444
  { 128, 128, 16 }, // CONV_SHAPE_128x128
401
445
  { 64, 32, 32 }, // CONV_SHAPE_64x32
402
446
  { 32, 256, 16 }, // CONV_SHAPE_32x256
447
+ { 64, 128, 16 }, // CONV_SHAPE_64x128
403
448
  };
404
449
 
405
450
  enum dmmv_wg_sizes {
@@ -425,22 +470,26 @@ struct vk_fa_pipeline_state {
425
470
  bool f32acc;
426
471
  uint32_t flags;
427
472
  uint32_t limit_occupancy_shmem;
473
+ ggml_type k_type;
474
+ ggml_type v_type;
428
475
 
429
476
  bool operator<(const vk_fa_pipeline_state &b) const {
430
- return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) <
431
- std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem);
477
+ return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem, k_type, v_type) <
478
+ std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem, b.k_type, b.v_type);
432
479
  }
433
480
  };
434
481
 
435
482
  struct vk_conv2d_pipeline_state {
436
- vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH)
437
- : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH) {}
483
+ vk_conv2d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1, uint32_t d0, uint32_t d1, uint32_t KW, uint32_t KH, uint32_t aligned)
484
+ : s0(s0), s1(s1), p0(p0), p1(p1), d0(d0), d1(d1), KW(KW), KH(KH), aligned(aligned) {}
438
485
 
439
486
  uint32_t s0, s1, p0, p1, d0, d1, KW, KH;
487
+ // when set, shader can skip K/CRS/NPQ bounds checks and address clamps
488
+ uint32_t aligned;
440
489
 
441
490
  bool operator<(const vk_conv2d_pipeline_state &b) const {
442
- return std::tie(s0, s1, p0, p1, d0, d1, KW, KH) <
443
- std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH);
491
+ return std::tie(s0, s1, p0, p1, d0, d1, KW, KH, aligned) <
492
+ std::tie(b.s0, b.s1, b.p0, b.p1, b.d0, b.d1, b.KW, b.KH, b.aligned);
444
493
  }
445
494
  };
446
495
 
@@ -485,6 +534,12 @@ static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGM
485
534
  GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
486
535
  GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
487
536
 
537
+ // Snake activation: y = x + sin(a*x)^2 * inv_b. Used by the optimize_graph reorder
538
+ // pass so it keeps the chain contiguous and by the dispatcher to detect the fusion.
539
+ static constexpr std::initializer_list<ggml_op> snake_pattern { GGML_OP_MUL, GGML_OP_SIN,
540
+ GGML_OP_SQR, GGML_OP_MUL,
541
+ GGML_OP_ADD };
542
+
488
543
  //node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
489
544
  //node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
490
545
  //node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
@@ -581,6 +636,14 @@ static constexpr std::initializer_list<std::array<int, 3>> rms_norm_mul_rope_vie
581
636
 
582
637
  struct vk_device_struct {
583
638
  std::recursive_mutex mutex;
639
+ mutable std::shared_mutex pinned_memory_mutex;
640
+
641
+ // Guards compile_pending, all_pipelines, and the dynamic pipeline maps
642
+ // (flash_attn, fa_mask_opt, solve_tri, conv2d, etc). The actual compile
643
+ // runs with no lock held, so different pipelines can compile in parallel.
644
+ // Lock order is device->mutex -> compile_mutex, never the reverse.
645
+ std::mutex compile_mutex;
646
+ std::condition_variable compile_cv;
584
647
 
585
648
  vk::PhysicalDevice physical_device;
586
649
  vk::PhysicalDeviceProperties properties;
@@ -654,6 +717,10 @@ struct vk_device_struct {
654
717
  uint32_t coopmat_int_k;
655
718
 
656
719
  bool coopmat2;
720
+ bool coopmat2_bf16_support {};
721
+ bool coopmat2_decode_vector;
722
+
723
+ bool dot2_f16 {};
657
724
 
658
725
  bool pipeline_executable_properties_support {};
659
726
 
@@ -666,6 +733,15 @@ struct vk_device_struct {
666
733
  bool mul_mat_id_m[GGML_TYPE_COUNT];
667
734
  bool mul_mat_id_s[GGML_TYPE_COUNT];
668
735
 
736
+ // Separate flags for the q8_1 (integer dot) mmq path, whose shader uses
737
+ // a different shared-memory layout than the float matmul shaders.
738
+ bool mul_mat_l_int[GGML_TYPE_COUNT];
739
+ bool mul_mat_m_int[GGML_TYPE_COUNT];
740
+ bool mul_mat_s_int[GGML_TYPE_COUNT];
741
+ bool mul_mat_id_l_int[GGML_TYPE_COUNT];
742
+ bool mul_mat_id_m_int[GGML_TYPE_COUNT];
743
+ bool mul_mat_id_s_int[GGML_TYPE_COUNT];
744
+
669
745
  vk::DescriptorSetLayout dsl;
670
746
 
671
747
  vk_matmul_pipeline pipeline_matmul_f32 {};
@@ -735,9 +811,10 @@ struct vk_device_struct {
735
811
  vk_pipeline pipeline_clamp_f32;
736
812
  vk_pipeline pipeline_pad_f32;
737
813
  vk_pipeline pipeline_roll_f32;
738
- vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
739
- vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
740
- vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
814
+ vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32;
815
+ vk_pipeline pipeline_repeat_i16;
816
+ vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_bf16_f32, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
817
+ vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_bf16_f32, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
741
818
  vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
742
819
  vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
743
820
  vk_pipeline pipeline_cpy_transpose_16, pipeline_cpy_transpose_32;
@@ -784,6 +861,7 @@ struct vk_device_struct {
784
861
  vk_pipeline pipeline_arange_f32;
785
862
 
786
863
  vk_pipeline pipeline_fill_f32;
864
+ vk_pipeline pipeline_fill_f16;
787
865
 
788
866
  vk_pipeline pipeline_geglu[2];
789
867
  vk_pipeline pipeline_reglu[2];
@@ -811,6 +889,7 @@ struct vk_device_struct {
811
889
  vk_pipeline pipeline_argsort_large_f32[num_argsort_pipelines];
812
890
  vk_pipeline pipeline_topk_f32[num_topk_pipelines];
813
891
  vk_pipeline pipeline_sum_rows_f32;
892
+ vk_pipeline pipeline_fwht_f32[4];
814
893
  vk_pipeline pipeline_cumsum_f32;
815
894
  vk_pipeline pipeline_cumsum_small_f32;
816
895
  vk_pipeline pipeline_cumsum_multipass1_f32;
@@ -822,6 +901,9 @@ struct vk_device_struct {
822
901
  vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
823
902
  vk_pipeline pipeline_timestep_embedding_f32;
824
903
  vk_pipeline pipeline_conv_transpose_1d_f32;
904
+ vk_pipeline pipeline_snake_f32;
905
+ vk_pipeline pipeline_snake_f16;
906
+ vk_pipeline pipeline_snake_bf16;
825
907
  vk_pipeline pipeline_pool2d_f32;
826
908
  vk_pipeline pipeline_rwkv_wkv6_f32;
827
909
  vk_pipeline pipeline_rwkv_wkv7_f32;
@@ -830,6 +912,8 @@ struct vk_device_struct {
830
912
  vk_pipeline pipeline_ssm_scan_f32_d128;
831
913
  vk_pipeline pipeline_ssm_scan_f32_d256;
832
914
  vk_pipeline pipeline_ssm_conv_f32;
915
+ vk_pipeline pipeline_ssm_conv_silu_f32;
916
+ vk_pipeline pipeline_ssm_conv_bias_silu_f32;
833
917
  vk_pipeline pipeline_opt_step_adamw_f32;
834
918
  vk_pipeline pipeline_opt_step_sgd_f32;
835
919
  std::map<vk_conv2d_pipeline_state, vk_pipeline> pipeline_conv2d_f32[CONV_SHAPE_COUNT];
@@ -839,7 +923,7 @@ struct vk_device_struct {
839
923
  vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
840
924
  vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
841
925
 
842
- std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT];
926
+ std::map<vk_fa_pipeline_state, vk_pipeline> pipeline_flash_attn_f32_f16;
843
927
 
844
928
  std::map<std::pair<uint32_t, uint32_t>, vk_pipeline> pipeline_fa_mask_opt;
845
929
 
@@ -938,19 +1022,24 @@ struct vk_subbuffer {
938
1022
  }
939
1023
  };
940
1024
 
941
- // vk_event is used for the event-related backend interfaces. It uses 'event' for
942
- // event_wait and 'fence' for event_synchronize. Polling on an event for
1025
+ struct vk_semaphore {
1026
+ vk::Semaphore s;
1027
+ uint64_t value;
1028
+ };
1029
+
1030
+ // vk_event is used for the event-related backend interfaces. It uses vk::Events for
1031
+ // event_wait and a timeline semaphore for event_synchronize. Polling on an event for
943
1032
  // event_synchronize wouldn't be sufficient to wait for command buffers to complete,
944
1033
  // and would lead to validation errors.
945
1034
  struct vk_event {
1035
+ std::vector<vk::Event> events_free; // Events available for reuse
1036
+ std::vector<vk::Event> events_submitted; // Events that are fully submitted and can be reused on next synchronize
946
1037
  vk::Event event;
947
- vk::Fence fence;
948
- vk_command_buffer* cmd_buffer = nullptr;
949
- };
1038
+ bool has_event;
950
1039
 
951
- struct vk_semaphore {
952
- vk::Semaphore s;
953
- uint64_t value;
1040
+ vk_semaphore tl_semaphore;
1041
+ vk_command_buffer* cmd_buffer = nullptr;
1042
+ uint64_t cmd_buffer_use_counter = 0;
954
1043
  };
955
1044
 
956
1045
  struct vk_submission {
@@ -1091,6 +1180,13 @@ struct vk_op_push_constants {
1091
1180
  float param4;
1092
1181
  };
1093
1182
 
1183
+ struct vk_op_fwht_push_constants {
1184
+ uint32_t n_rows;
1185
+ uint32_t src_offset;
1186
+ uint32_t dst_offset;
1187
+ float scale;
1188
+ };
1189
+
1094
1190
  struct vk_op_count_experts_push_constants {
1095
1191
  uint32_t ne00;
1096
1192
  uint32_t ne01;
@@ -1106,6 +1202,16 @@ struct vk_op_glu_push_constants {
1106
1202
  uint32_t mode; // 0: default, 1: swapped, 2: split
1107
1203
  float alpha; // for swiglu_oai
1108
1204
  float limit;
1205
+ uint32_t nb01;
1206
+ uint32_t nb02;
1207
+ uint32_t nb03;
1208
+ uint32_t ne01;
1209
+ uint32_t ne02;
1210
+ uint32_t nb11;
1211
+ uint32_t nb12;
1212
+ uint32_t nb13;
1213
+ uint32_t ne11;
1214
+ uint32_t ne12;
1109
1215
  };
1110
1216
 
1111
1217
  struct vk_op_unary_push_constants {
@@ -1313,6 +1419,8 @@ struct vk_op_rope_push_constants {
1313
1419
  uint32_t nb11;
1314
1420
  uint32_t nb12;
1315
1421
  uint32_t nb13;
1422
+ uint32_t a_offset;
1423
+ uint32_t d_offset;
1316
1424
  };
1317
1425
  static_assert(sizeof(vk_op_rope_push_constants) <= 128, "sizeof(vk_op_rope_push_constants) must be <= 128");
1318
1426
 
@@ -1371,7 +1479,7 @@ struct vk_op_im2col_push_constants {
1371
1479
  uint32_t IW; uint32_t IH;
1372
1480
  uint32_t OW; uint32_t OH;
1373
1481
  uint32_t KW; uint32_t KH;
1374
- uint32_t pelements;
1482
+ uint32_t OH_batch;
1375
1483
  uint32_t CHW;
1376
1484
  int32_t s0; int32_t s1;
1377
1485
  int32_t p0; int32_t p1;
@@ -1432,6 +1540,11 @@ struct vk_op_conv_transpose_1d_push_constants {
1432
1540
  int32_t s0;
1433
1541
  };
1434
1542
 
1543
+ struct vk_op_snake_push_constants {
1544
+ uint32_t ne0;
1545
+ uint32_t ne1;
1546
+ };
1547
+
1435
1548
  struct vk_op_pool2d_push_constants {
1436
1549
  uint32_t IW; uint32_t IH;
1437
1550
  uint32_t OW; uint32_t OH;
@@ -1466,6 +1579,7 @@ struct vk_op_gated_delta_net_push_constants {
1466
1579
  uint32_t sb1, sb2, sb3;
1467
1580
  uint32_t neq1, rq3;
1468
1581
  float scale;
1582
+ uint32_t K;
1469
1583
  };
1470
1584
 
1471
1585
  struct vk_op_ssm_scan_push_constants {
@@ -1641,7 +1755,7 @@ struct ggml_vk_garbage_collector {
1641
1755
  };
1642
1756
 
1643
1757
  static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_context subctx);
1644
- static void ggml_vk_load_shaders(vk_device& device);
1758
+ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested = nullptr);
1645
1759
  static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx);
1646
1760
 
1647
1761
  static bool vk_memory_logger_enabled = false;
@@ -1879,6 +1993,9 @@ struct ggml_backend_vk_context {
1879
1993
  // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert.
1880
1994
  vk_pipeline_struct * prealloc_y_last_pipeline_used {};
1881
1995
  const ggml_tensor * prealloc_y_last_tensor_used {};
1996
+ // True when prealloc_y holds the padded fp16 layout used by the coopmat2 B decode-vector callback.
1997
+ // If false, then it's contiguous.
1998
+ bool prealloc_y_last_decode_vector_staging {};
1882
1999
 
1883
2000
  // Track which nodes have been used since the last sync, and whether they were written to
1884
2001
  std::vector<const ggml_tensor *> unsynced_nodes_written;
@@ -1978,6 +2095,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
1978
2095
  GGML_UNUSED(src3);
1979
2096
  }
1980
2097
 
2098
+ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_fwht_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
2099
+ p.src_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
2100
+ p.dst_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
2101
+
2102
+ GGML_UNUSED(src1);
2103
+ GGML_UNUSED(src2);
2104
+ GGML_UNUSED(src3);
2105
+ }
2106
+
1981
2107
  struct ggml_backend_vk_buffer_context {
1982
2108
  vk_device_ref device;
1983
2109
  vk_buffer dev_buffer;
@@ -2018,9 +2144,9 @@ void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
2018
2144
  const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
2019
2145
  std::string type = device ? "device" : "host";
2020
2146
  auto it = allocations.find(buf->buffer);
2021
- total_device -= device ? it->second : 0;
2022
- total_host -= device ? 0 : it->second;
2023
2147
  if (it != allocations.end()) {
2148
+ total_device -= device ? it->second : 0;
2149
+ total_host -= device ? 0 : it->second;
2024
2150
  VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
2025
2151
  allocations.erase(it);
2026
2152
  } else {
@@ -2099,10 +2225,135 @@ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
2099
2225
  ctx->device->device.resetFences({ ctx->fence });
2100
2226
  }
2101
2227
 
2102
- // variables to track number of compiles in progress
2103
- static uint32_t compile_count = 0;
2104
- static std::mutex compile_count_mutex;
2105
- static std::condition_variable compile_count_cond;
2228
+ static constexpr uint32_t kSpvOpCooperativeMatrixLoadTensorNV = 5367;
2229
+ static constexpr uint32_t kSpvCapabilityCooperativeMatrixDecodeVectorNV = 5447;
2230
+ static constexpr uint32_t kSpvTensorAddressingDecodeVectorFuncBit = 0x4;
2231
+
2232
+ // Remove SPV_NV_cooperative_matrix_decode_vector usage from a SPIR-V module so it
2233
+ // can be loaded on drivers that only support SPV_NV_cooperative_matrix2. Drops the
2234
+ // OpExtension declaration, the CooperativeMatrixDecodeVectorNV OpCapability, and the
2235
+ // DecodeVectorFunc operand from any OpCooperativeMatrixLoadTensorNV instruction.
2236
+ // Returns true when the input used the extension (and `out` was populated with a
2237
+ // stripped copy); returns false otherwise without touching `out`.
2238
+ static bool ggml_vk_strip_decode_vector(const uint32_t * code, size_t word_count, std::vector<uint32_t> & out) {
2239
+ static const char kDecodeVectorExt[] = "SPV_NV_cooperative_matrix_decode_vector";
2240
+
2241
+ if (word_count < 5) {
2242
+ return false;
2243
+ }
2244
+
2245
+ bool uses_decode_vector = false;
2246
+ for (size_t pos = 5; pos < word_count; ) {
2247
+ uint32_t word = code[pos];
2248
+ uint32_t wc = word >> spv::WordCountShift;
2249
+ uint32_t op = word & spv::OpCodeMask;
2250
+ GGML_ASSERT(wc > 0 && pos + wc <= word_count);
2251
+ if (op == spv::OpExtension && wc >= 2) {
2252
+ const char * s = reinterpret_cast<const char *>(&code[pos + 1]);
2253
+ if (strcmp(s, kDecodeVectorExt) == 0) {
2254
+ uses_decode_vector = true;
2255
+ break;
2256
+ }
2257
+ }
2258
+ pos += wc;
2259
+ }
2260
+
2261
+ if (!uses_decode_vector) {
2262
+ return false;
2263
+ }
2264
+
2265
+ VK_LOG_DEBUG("ggml_vk_strip_decode_vector: stripping SPV_NV_cooperative_matrix_decode_vector");
2266
+
2267
+ // Bulk-copy unchanged runs and only break the run when an instruction needs to
2268
+ // be dropped or patched. Use reserve + insert/push_back so the destination buffer
2269
+ // is touched exactly once (no zero-initialization pass from resize()).
2270
+ out.clear();
2271
+ out.reserve(word_count);
2272
+
2273
+ size_t run_start = 0;
2274
+ auto flush_run = [&](size_t up_to) {
2275
+ if (up_to > run_start) {
2276
+ out.insert(out.end(), code + run_start, code + up_to);
2277
+ }
2278
+ };
2279
+
2280
+ for (size_t pos = 5; pos < word_count; ) {
2281
+ uint32_t word = code[pos];
2282
+ uint32_t wc = word >> spv::WordCountShift;
2283
+ uint32_t op = word & spv::OpCodeMask;
2284
+ GGML_ASSERT(wc > 0 && pos + wc <= word_count);
2285
+
2286
+ if (op == spv::OpExtension && wc >= 2) {
2287
+ const char * s = reinterpret_cast<const char *>(&code[pos + 1]);
2288
+ if (strcmp(s, kDecodeVectorExt) == 0) {
2289
+ flush_run(pos);
2290
+ pos += wc;
2291
+ run_start = pos;
2292
+ continue;
2293
+ }
2294
+ }
2295
+
2296
+ if (op == spv::OpCapability && wc == 2 && code[pos + 1] == kSpvCapabilityCooperativeMatrixDecodeVectorNV) {
2297
+ flush_run(pos);
2298
+ pos += wc;
2299
+ run_start = pos;
2300
+ continue;
2301
+ }
2302
+
2303
+ if (op == kSpvOpCooperativeMatrixLoadTensorNV) {
2304
+ // [opcode/wc][ResultType][Result][Pointer][Object][TensorLayout][MemOperand mask][mem extras...][TA mask][ta extras...]
2305
+ GGML_ASSERT(wc >= 8);
2306
+
2307
+ uint32_t mem_mask = code[pos + 6];
2308
+ size_t cur = pos + 7;
2309
+ // Each of these MemoryAccess bits (when set) carries one trailing operand.
2310
+ cur += (mem_mask & 0x2) ? 1 : 0; // Aligned
2311
+ cur += (mem_mask & 0x8) ? 1 : 0; // MakePointerAvailable
2312
+ cur += (mem_mask & 0x10) ? 1 : 0; // MakePointerVisible
2313
+ cur += (mem_mask & 0x10000) ? 1 : 0; // AliasScopeINTELMask
2314
+ cur += (mem_mask & 0x20000) ? 1 : 0; // NoAliasINTELMask
2315
+ GGML_ASSERT(cur < pos + wc);
2316
+
2317
+ uint32_t ta_mask = code[cur];
2318
+ if ((ta_mask & kSpvTensorAddressingDecodeVectorFuncBit) == 0) {
2319
+ pos += wc;
2320
+ continue; // leave instruction inside the current unchanged run
2321
+ }
2322
+
2323
+ flush_run(pos);
2324
+
2325
+ // Append unchanged prefix of the instruction (header through the mem-extras).
2326
+ size_t inst_start = out.size();
2327
+ size_t pre_n = cur - pos;
2328
+ out.insert(out.end(), code + pos, code + pos + pre_n);
2329
+
2330
+ // Emit TA mask with the DecodeVectorFunc bit cleared.
2331
+ out.push_back(ta_mask & ~kSpvTensorAddressingDecodeVectorFuncBit);
2332
+
2333
+ // TA extras: TensorView (0x1) and DecodeFunc (0x2) are kept verbatim;
2334
+ // DecodeVectorFunc (0x4) is dropped along with its trailing id operand.
2335
+ size_t keep_ta_extras = ((ta_mask & 0x1) ? 1 : 0) + ((ta_mask & 0x2) ? 1 : 0);
2336
+ if (keep_ta_extras) {
2337
+ out.insert(out.end(), code + cur + 1, code + cur + 1 + keep_ta_extras);
2338
+ }
2339
+
2340
+ GGML_ASSERT(wc == pre_n + 1 + keep_ta_extras + 1);
2341
+
2342
+ // Patch the instruction header with the new (one-shorter) word count.
2343
+ uint32_t new_wc = wc - 1;
2344
+ out[inst_start] = (new_wc << spv::WordCountShift) | op;
2345
+
2346
+ pos += wc;
2347
+ run_start = pos;
2348
+ continue;
2349
+ }
2350
+
2351
+ pos += wc;
2352
+ }
2353
+
2354
+ flush_run(word_count);
2355
+ return true;
2356
+ }
2106
2357
 
2107
2358
  static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint,
2108
2359
  uint32_t parameter_count, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
@@ -2115,6 +2366,78 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
2115
2366
  GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
2116
2367
 
2117
2368
  vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast<const uint32_t *>(spv_data));
2369
+
2370
+ // Patch SPIR-V to enable RTE rounding for FP16, avoiding the need for
2371
+ // separate shader variants compiled with -DRTE16.
2372
+ std::vector<uint32_t> spirv;
2373
+ if (device->float_controls_rte_fp16) {
2374
+ const uint32_t* spv_words = reinterpret_cast<const uint32_t *>(spv_data);
2375
+ size_t word_count = spv_size / sizeof(uint32_t);
2376
+ spirv.assign(spv_words, spv_words + word_count);
2377
+
2378
+ // Find insertion points respecting SPIR-V layout order:
2379
+ // Header(5) -> OpCapability -> OpExtension -> ... -> OpEntryPoint -> OpExecutionMode -> ...
2380
+ size_t pos = 5; // skip header
2381
+ size_t cap_insert_pos = pos;
2382
+ size_t ext_insert_pos = pos;
2383
+ size_t exec_insert_pos = pos;
2384
+ uint32_t entry_point_id = 0;
2385
+
2386
+ while (pos < spirv.size()) {
2387
+ uint32_t opcode = spirv[pos] & spv::OpCodeMask;
2388
+ uint32_t len = spirv[pos] >> spv::WordCountShift;
2389
+ if (len == 0) break;
2390
+
2391
+ if (opcode == spv::OpCapability) {
2392
+ cap_insert_pos = pos + len;
2393
+ ext_insert_pos = pos + len;
2394
+ } else if (opcode == spv::OpExtension) {
2395
+ ext_insert_pos = pos + len;
2396
+ } else if (opcode == spv::OpEntryPoint) {
2397
+ entry_point_id = spirv[pos + 2];
2398
+ exec_insert_pos = pos + len;
2399
+ } else if (opcode == spv::OpExecutionMode || opcode == spv::OpExecutionModeId) {
2400
+ exec_insert_pos = pos + len;
2401
+ } else if (entry_point_id != 0) {
2402
+ break;
2403
+ }
2404
+
2405
+ pos += len;
2406
+ }
2407
+
2408
+ // Insert from latest position first so earlier indices stay valid.
2409
+
2410
+ // OpExecutionMode %entrypoint RoundingModeRTE 16
2411
+ uint32_t exec_mode[] = { (4u << spv::WordCountShift) | spv::OpExecutionMode, entry_point_id, spv::ExecutionModeRoundingModeRTE, 16 };
2412
+ spirv.insert(spirv.begin() + exec_insert_pos, std::begin(exec_mode), std::end(exec_mode));
2413
+
2414
+ // OpExtension "SPV_KHR_float_controls"
2415
+ const char ext_str[] = "SPV_KHR_float_controls";
2416
+ size_t ext_str_words = CEIL_DIV(sizeof(ext_str), sizeof(uint32_t));
2417
+ std::vector<uint32_t> extension(1 + ext_str_words, 0);
2418
+ extension[0] = (uint32_t)((1 + ext_str_words) << spv::WordCountShift) | spv::OpExtension;
2419
+ memcpy(&extension[1], ext_str, sizeof(ext_str));
2420
+ spirv.insert(spirv.begin() + ext_insert_pos, extension.begin(), extension.end());
2421
+
2422
+ // OpCapability RoundingModeRTE
2423
+ uint32_t capability[] = { (2u << spv::WordCountShift) | spv::OpCapability, spv::CapabilityRoundingModeRTE };
2424
+ spirv.insert(spirv.begin() + cap_insert_pos, std::begin(capability), std::end(capability));
2425
+
2426
+ shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
2427
+ }
2428
+
2429
+ #if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
2430
+ if (device->coopmat2 && !device->coopmat2_decode_vector) {
2431
+ const uint32_t * src = spirv.empty() ? reinterpret_cast<const uint32_t *>(spv_data) : spirv.data();
2432
+ size_t src_n = spirv.empty() ? spv_size / sizeof(uint32_t) : spirv.size();
2433
+ std::vector<uint32_t> stripped;
2434
+ if (ggml_vk_strip_decode_vector(src, src_n, stripped)) {
2435
+ spirv = std::move(stripped);
2436
+ shader_module_create_info = vk::ShaderModuleCreateInfo({}, spirv.size() * sizeof(uint32_t), spirv.data());
2437
+ }
2438
+ }
2439
+ #endif
2440
+
2118
2441
  pipeline->shader_module = device->device.createShaderModule(shader_module_create_info);
2119
2442
 
2120
2443
  vk::PushConstantRange pcr(
@@ -2196,7 +2519,6 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
2196
2519
  std::cerr << "ggml_vulkan: " << e.what() << std::endl;
2197
2520
  throw e;
2198
2521
  }
2199
- pipeline->compiled = true;
2200
2522
 
2201
2523
  if (vk_instance.debug_utils_support) {
2202
2524
  vk::DebugUtilsObjectNameInfoEXT duoni;
@@ -2245,14 +2567,13 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
2245
2567
  }
2246
2568
  }
2247
2569
 
2248
- device->all_pipelines.push_back(pipeline);
2249
-
2250
2570
  {
2251
- std::lock_guard<std::mutex> guard(compile_count_mutex);
2252
- assert(compile_count > 0);
2253
- compile_count--;
2571
+ std::lock_guard<std::mutex> guard(device->compile_mutex);
2572
+ device->all_pipelines.push_back(pipeline);
2573
+ pipeline->compiled = true;
2574
+ pipeline->compile_pending = false;
2254
2575
  }
2255
- compile_count_cond.notify_all();
2576
+ device->compile_cv.notify_all();
2256
2577
  }
2257
2578
 
2258
2579
  static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) {
@@ -2268,8 +2589,7 @@ static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx,
2268
2589
  VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")");
2269
2590
  ctx->pipeline_descriptor_set_requirements += n;
2270
2591
  if (!pipeline->compiled) {
2271
- pipeline->needed = true;
2272
- ggml_vk_load_shaders(ctx->device);
2592
+ ggml_vk_load_shaders(ctx->device, pipeline);
2273
2593
  }
2274
2594
  ggml_pipeline_allocate_descriptor_sets(ctx);
2275
2595
  }
@@ -2319,7 +2639,7 @@ static vk_command_buffer* ggml_vk_create_cmd_buffer(vk_device& device, vk_comman
2319
2639
  vk::CommandBufferLevel::ePrimary,
2320
2640
  1);
2321
2641
  const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
2322
- p.cmd_buffers.push_back({ cmd_buffers.front(), true });
2642
+ p.cmd_buffers.push_back({ cmd_buffers.front(), 0, true });
2323
2643
  return &p.cmd_buffers[p.cmd_buffers.size()-1];
2324
2644
  }
2325
2645
 
@@ -2788,6 +3108,15 @@ static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subct
2788
3108
  );
2789
3109
  }
2790
3110
 
3111
+ static void ggml_vk_reset_event(vk_context& ctx, vk::Event& event) {
3112
+ VK_LOG_DEBUG("ggml_vk_set_event()");
3113
+
3114
+ ctx->s->buffer->buf.resetEvent(
3115
+ event,
3116
+ ctx->p->q->stage_flags
3117
+ );
3118
+ }
3119
+
2791
3120
  static void ggml_vk_set_event(vk_context& ctx, vk::Event& event) {
2792
3121
  VK_LOG_DEBUG("ggml_vk_set_event()");
2793
3122
 
@@ -2833,11 +3162,10 @@ struct vk_fa_tuning_params {
2833
3162
  }
2834
3163
  };
2835
3164
 
2836
- static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
2837
- static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc);
3165
+ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type);
3166
+ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type = GGML_TYPE_F16);
2838
3167
 
2839
- static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
2840
- GGML_UNUSED(kv_type);
3168
+ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
2841
3169
 
2842
3170
  vk_fa_tuning_params result{};
2843
3171
  result.path = FA_SCALAR;
@@ -2889,7 +3217,7 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
2889
3217
 
2890
3218
  result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0;
2891
3219
 
2892
- if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) {
3220
+ if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc, k_type, v_type)) {
2893
3221
  result.block_rows /= 2;
2894
3222
  }
2895
3223
 
@@ -2912,10 +3240,11 @@ static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device,
2912
3240
  return result;
2913
3241
  }
2914
3242
 
2915
- static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
3243
+ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
2916
3244
  GGML_UNUSED(n_rows);
2917
3245
  GGML_UNUSED(n_kv);
2918
- GGML_UNUSED(kv_type);
3246
+ GGML_UNUSED(k_type);
3247
+ GGML_UNUSED(v_type);
2919
3248
  GGML_UNUSED(f32acc);
2920
3249
 
2921
3250
  vk_fa_tuning_params result{};
@@ -2942,7 +3271,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device
2942
3271
  return result;
2943
3272
  }
2944
3273
 
2945
- static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
3274
+ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
2946
3275
  GGML_UNUSED(n_kv);
2947
3276
  GGML_UNUSED(f32acc);
2948
3277
 
@@ -2956,7 +3285,7 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
2956
3285
  if (small_rows) {
2957
3286
  result.block_rows = 32;
2958
3287
  result.block_cols = 32;
2959
- } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) {
3288
+ } else if (ggml_is_quantized(k_type) || ggml_is_quantized(v_type) || hsk >= 256 || hsv >= 256) {
2960
3289
  result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64;
2961
3290
  result.block_cols = 32;
2962
3291
  } else {
@@ -2970,10 +3299,17 @@ static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device
2970
3299
  return result;
2971
3300
  }
2972
3301
 
2973
- static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) {
3302
+ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type k_type, ggml_type v_type, bool f32acc) {
2974
3303
  FaCodePath path = device->coopmat2 ? FA_COOPMAT2 :
2975
3304
  device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
2976
3305
 
3306
+ if (path == FA_COOPMAT2 && k_type == GGML_TYPE_BF16 && !device->coopmat2_bf16_support) {
3307
+ path = FA_COOPMAT1;
3308
+ }
3309
+ if (path == FA_COOPMAT1 && k_type == GGML_TYPE_BF16 && !device->coopmat_bf16_support) {
3310
+ path = FA_SCALAR;
3311
+ }
3312
+
2977
3313
  if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) {
2978
3314
  // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090
2979
3315
  path = FA_SCALAR;
@@ -2982,8 +3318,8 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
2982
3318
  if (path == FA_COOPMAT1) {
2983
3319
  bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) ||
2984
3320
  (!f32acc && device->coopmat_support_16x16x16_f16acc);
2985
- const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
2986
- bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc);
3321
+ const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
3322
+ bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc, k_type);
2987
3323
 
2988
3324
  if (!shape_ok || !shmem_ok) {
2989
3325
  path = FA_SCALAR;
@@ -2995,20 +3331,25 @@ static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_
2995
3331
  path = FA_SCALAR;
2996
3332
  }
2997
3333
 
3334
+ // Q1_0 K/V is only implemented on coopmat2 (flash_attn_cm2); there is no scalar FA shader for it.
3335
+ if ((k_type == GGML_TYPE_Q1_0 || v_type == GGML_TYPE_Q1_0) && device->coopmat2) {
3336
+ path = FA_COOPMAT2;
3337
+ }
3338
+
2998
3339
  switch (path) {
2999
3340
  case FA_SCALAR:
3000
- return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
3341
+ return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
3001
3342
  case FA_COOPMAT1:
3002
- return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
3343
+ return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
3003
3344
  case FA_COOPMAT2:
3004
- return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc);
3345
+ return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, k_type, v_type, f32acc);
3005
3346
  default:
3006
3347
  throw std::runtime_error("unsupported FaCodePath");
3007
3348
  }
3008
3349
  }
3009
3350
 
3010
3351
  static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc,
3011
- bool use_mask, bool use_mask_opt, bool use_logit_softcap) {
3352
+ bool use_mask, bool use_mask_opt, bool use_logit_softcap, ggml_type k_type, ggml_type v_type) {
3012
3353
  const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary &&
3013
3354
  (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2);
3014
3355
 
@@ -3019,12 +3360,32 @@ static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const
3019
3360
 
3020
3361
  const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size;
3021
3362
 
3022
- return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem};
3363
+ return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem, k_type, v_type};
3023
3364
  }
3024
3365
 
3025
3366
  static std::vector<uint32_t> get_fa_spec_constants(const vk_fa_pipeline_state& state) {
3026
- return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split,
3027
- state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem};
3367
+ const auto fa_block_bytes = [](ggml_type t) -> uint32_t {
3368
+ if (t == GGML_TYPE_F32) return 16u;
3369
+ return (uint32_t) ggml_type_size(t);
3370
+ };
3371
+ return {
3372
+ /* 0 WorkGroupSize */ state.workgroup_size,
3373
+ /* 1 Br */ state.Br,
3374
+ /* 2 Bc */ state.Bc,
3375
+ /* 3 HSK */ state.HSK,
3376
+ /* 4 HSV */ state.HSV,
3377
+ /* 5 Clamp */ static_cast<uint32_t>(!state.aligned),
3378
+ /* 6 D_split */ state.D_split,
3379
+ /* 7 row_split */ state.row_split,
3380
+ /* 8 SubGroupSize */ state.subgroup_size,
3381
+ /* 9 SHMEM_STAGING */ state.shmem_staging ? 1u : 0u,
3382
+ /*10 Flags */ state.flags,
3383
+ /*11 LIMIT_OCCUPANCY_SHMEM */ state.limit_occupancy_shmem,
3384
+ /*12 FaTypeK */ static_cast<uint32_t>(state.k_type),
3385
+ /*13 FaTypeV */ static_cast<uint32_t>(state.v_type),
3386
+ /*14 FaBlockBytesK */ fa_block_bytes(state.k_type),
3387
+ /*15 FaBlockBytesV */ fa_block_bytes(state.v_type),
3388
+ };
3028
3389
  }
3029
3390
 
3030
3391
  static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
@@ -3033,7 +3394,9 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
3033
3394
  switch (src0_type) {
3034
3395
  case GGML_TYPE_IQ1_S:
3035
3396
  case GGML_TYPE_IQ1_M:
3036
- lut_size = 2*2048 + 4*2048;
3397
+ // Regular matmul uses the compact uint16_t IQ1 grid; the expanded
3398
+ // uint32_t grid is only enabled for the q8_1/int-dot vector path.
3399
+ lut_size = 2*2048;
3037
3400
  break;
3038
3401
  case GGML_TYPE_IQ2_XXS:
3039
3402
  lut_size = 8*256;
@@ -3055,6 +3418,10 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
3055
3418
  case GGML_TYPE_MXFP4:
3056
3419
  lut_size = 4*16;
3057
3420
  break;
3421
+ case GGML_TYPE_NVFP4:
3422
+ // Same kvalues budget as MXFP4 plus ue4m3_fp32_lut[128] (types.glsl, DATA_A_NVFP4).
3423
+ lut_size = 4*16 + 128u * (uint32_t)sizeof(float);
3424
+ break;
3058
3425
  default:
3059
3426
  break;
3060
3427
  }
@@ -3078,6 +3445,70 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
3078
3445
  return supported;
3079
3446
  }
3080
3447
 
3448
+ // Shmem usage for the q8_1 mmq shader (mul_mmq.comp), which uses
3449
+ // block_a_cache / block_b_cache layouts (see mul_mmq_shmem_types.glsl) rather
3450
+ // than the float load buffers checked by ggml_vk_matmul_shmem_support.
3451
+ // Sizes follow std430 rules. Returns false for types without a q8_1 pipeline.
3452
+ static bool ggml_vk_matmul_int_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id, ggml_type src0_type) {
3453
+
3454
+ // FLOAT_TYPE in the shader is float16_t with fp16 support, otherwise float.
3455
+ const uint32_t fp_size = device->fp16 ? 2u : 4u;
3456
+ const uint32_t fp_align = fp_size;
3457
+ const uint32_t fp2_size = 2u * fp_size;
3458
+ const uint32_t fp2_align = device->fp16 ? 4u : 8u;
3459
+
3460
+ struct member { uint32_t size, align; };
3461
+ auto std430_size = [](std::initializer_list<member> members) {
3462
+ uint32_t off = 0, struct_align = 1;
3463
+ for (const auto &m : members) {
3464
+ off = (off + m.align - 1) & ~(m.align - 1);
3465
+ off += m.size;
3466
+ struct_align = std::max(struct_align, m.align);
3467
+ }
3468
+ return (off + struct_align - 1) & ~(struct_align - 1);
3469
+ };
3470
+
3471
+ uint32_t block_a_size = 0;
3472
+ switch (src0_type) {
3473
+ case GGML_TYPE_Q4_0: block_a_size = std430_size({{16, 4}, {fp_size, fp_align}}); break; // qs[16/4] + dm
3474
+ case GGML_TYPE_Q4_1: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + dm(vec2)
3475
+ case GGML_TYPE_Q5_0: block_a_size = std430_size({{16, 4}, {4, 4}, {fp_size, fp_align}}); break; // qs[16/4] + qh + dm
3476
+ case GGML_TYPE_Q5_1: block_a_size = std430_size({{16, 4}, {4, 4}, {fp2_size, fp2_align}}); break; // qs[16/4] + qh + dm(vec2)
3477
+ case GGML_TYPE_Q8_0: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + dm
3478
+ case GGML_TYPE_MXFP4: block_a_size = std430_size({{32, 4}, {fp_size, fp_align}}); break; // qs[8] + d
3479
+ case GGML_TYPE_Q2_K: block_a_size = std430_size({{ 8, 4}, {2, 2}, {fp2_size, fp2_align}}); break; // qs[2] + scales(u8vec2) + dm(vec2)
3480
+ case GGML_TYPE_Q3_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + d_scales(vec2)
3481
+ case GGML_TYPE_Q4_K: block_a_size = std430_size({{16, 4}, {fp2_size, fp2_align}}); break; // qs[4] + dm(vec2)
3482
+ case GGML_TYPE_Q5_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + dm(vec2)
3483
+ case GGML_TYPE_Q6_K: block_a_size = std430_size({{32, 4}, {fp2_size, fp2_align}}); break; // qs[8] + d_scales(vec2)
3484
+ default:
3485
+ return false;
3486
+ }
3487
+
3488
+ // block_b_cache: { int32_t qs[8]; FLOAT_TYPEV2 ds; }
3489
+ const uint32_t block_b_size = std430_size({{32, 4}, {fp2_size, fp2_align}});
3490
+
3491
+ const uint32_t BM = warptile[1];
3492
+ const uint32_t BN = warptile[2];
3493
+ // mul_mmq.comp: BK_STEP=1 for MUL_MAT_ID, 4 otherwise.
3494
+ const uint32_t BK_STEP = mul_mat_id ? 1u : 4u;
3495
+
3496
+ const uint32_t buf_a_size = BM * BK_STEP * block_a_size;
3497
+ const uint32_t buf_b_size = BN * BK_STEP * block_b_size;
3498
+ const uint32_t mmid_row_ids = mul_mat_id ? (BN * 2u * (uint32_t)sizeof(uint16_t)) : 0u;
3499
+
3500
+ const uint32_t warps = warptile[0] / warptile[10];
3501
+ const uint32_t ballots_sh = mul_mat_id ? (warps * 4u * (uint32_t)sizeof(uint32_t)) : 0u;
3502
+
3503
+ const uint32_t total_size = buf_a_size + buf_b_size + mmid_row_ids + ballots_sh;
3504
+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
3505
+
3506
+ VK_LOG_DEBUG("ggml_vk_matmul_int_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
3507
+ "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", total=" << total_size << ", supported=" << supported);
3508
+
3509
+ return supported;
3510
+ }
3511
+
3081
3512
  struct GpuPipelineConfig {
3082
3513
  // GPU architecture identifier.
3083
3514
  // Example: vk_device_architecture::AMD_GCN
@@ -3145,10 +3576,40 @@ static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_dev
3145
3576
  return 0; // If no matching configuration is found
3146
3577
  }
3147
3578
 
3148
- static void ggml_vk_load_shaders(vk_device& device) {
3579
+ // Whether scalar flash attention will use the MMQ path for the given k_type.
3580
+ static bool ggml_vk_fa_scalar_uses_mmq(const vk_device& device, ggml_type k_type) {
3581
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3582
+ return device->integer_dot_product && device->subgroup_clustered &&
3583
+ (k_type == GGML_TYPE_Q4_0 || k_type == GGML_TYPE_Q4_1 ||
3584
+ k_type == GGML_TYPE_Q5_0 || k_type == GGML_TYPE_Q5_1 ||
3585
+ k_type == GGML_TYPE_Q8_0);
3586
+ #else
3587
+ GGML_UNUSED(device);
3588
+ GGML_UNUSED(k_type);
3589
+ return false;
3590
+ #endif
3591
+ }
3592
+
3593
+ // load_shaders walks the pipeline list under compile_mutex and either claims
3594
+ // the requested pipeline for compilation or, if another thread is already
3595
+ // compiling it, drops the lock and waits on compile_cv. Compiles themselves
3596
+ // run unlocked.
3597
+ struct CompileTask {
3598
+ vk_pipeline pipeline;
3599
+ size_t spv_size;
3600
+ const void * spv_data;
3601
+ std::string entrypoint;
3602
+ uint32_t parameter_count;
3603
+ std::array<uint32_t, 3> wg_denoms;
3604
+ std::vector<uint32_t> specialization_constants;
3605
+ bool disable_robustness;
3606
+ bool require_full_subgroups;
3607
+ uint32_t required_subgroup_size;
3608
+ };
3609
+
3610
+ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) {
3149
3611
  VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
3150
3612
 
3151
- std::lock_guard<std::recursive_mutex> guard(device->mutex);
3152
3613
  // some shaders have a minimum subgroup size
3153
3614
  const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
3154
3615
  const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
@@ -3178,6 +3639,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
3178
3639
  l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
3179
3640
 
3180
3641
  uint32_t l_align, m_align, s_align;
3642
+
3643
+ vk_pipeline wait_pipeline;
3644
+ CompileTask claimed_task {};
3645
+ bool has_claimed_task = false;
3646
+
3647
+ // The rest of the walk reads and writes shared device state, so hold the
3648
+ // lock until we're done deciding what to compile.
3649
+ std::unique_lock<std::mutex> compile_lock(device->compile_mutex);
3650
+
3181
3651
  if (device->coopmat2) {
3182
3652
  // spec constants and tile sizes for non-quant matmul/matmul_id
3183
3653
  l_warptile = { 256, 128, 256, 64, 1 };
@@ -3204,9 +3674,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
3204
3674
  s_mmq_wg_denoms_k = { 32, 64, 1 };
3205
3675
 
3206
3676
  // spec constants and tile sizes for quant matmul_id
3207
- l_warptile_mmqid = { 256, 128, 128, 32, 1, device->subgroup_size };
3208
- m_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
3209
- s_warptile_mmqid = { 256, 128, 64, 32, 0, device->subgroup_size };
3677
+ const uint32_t mmqid_bk = device->coopmat2_decode_vector ? 64u : 32u;
3678
+ l_warptile_mmqid = { 256, 128, 128, mmqid_bk, 1, device->subgroup_size };
3679
+ m_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size };
3680
+ s_warptile_mmqid = { 256, 128, 64, mmqid_bk, 0, device->subgroup_size };
3210
3681
  l_mmqid_wg_denoms = { 128, 128, 1 };
3211
3682
  m_mmqid_wg_denoms = { 128, 64, 1 };
3212
3683
  s_mmqid_wg_denoms = { 128, 64, 1 };
@@ -3310,6 +3781,40 @@ static void ggml_vk_load_shaders(vk_device& device) {
3310
3781
  } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
3311
3782
  device->mul_mat_id_l[i] = false;
3312
3783
  }
3784
+
3785
+ // The q8_1 mmq path has its own (larger) shmem layout, check it separately.
3786
+ // K-quants use the _int_k warptiles, others use _int.
3787
+ const bool is_k_quant = (t == GGML_TYPE_Q2_K || t == GGML_TYPE_Q3_K ||
3788
+ t == GGML_TYPE_Q4_K || t == GGML_TYPE_Q5_K ||
3789
+ t == GGML_TYPE_Q6_K);
3790
+ const auto & s_int = is_k_quant ? s_warptile_mmq_int_k : s_warptile_mmq_int;
3791
+ const auto & m_int = is_k_quant ? m_warptile_mmq_int_k : m_warptile_mmq_int;
3792
+ const auto & l_int = is_k_quant ? l_warptile_mmq_int_k : l_warptile_mmq_int;
3793
+ const auto & s_intid = is_k_quant ? s_warptile_mmqid_int_k : s_warptile_mmqid_int;
3794
+ const auto & m_intid = is_k_quant ? m_warptile_mmqid_int_k : m_warptile_mmqid_int;
3795
+ const auto & l_intid = is_k_quant ? l_warptile_mmqid_int_k : l_warptile_mmqid_int;
3796
+
3797
+ if (!ggml_vk_matmul_int_shmem_support(device, s_int, false, t)) {
3798
+ device->mul_mat_s_int[i] = false;
3799
+ device->mul_mat_m_int[i] = false;
3800
+ device->mul_mat_l_int[i] = false;
3801
+ } else if (!ggml_vk_matmul_int_shmem_support(device, m_int, false, t)) {
3802
+ device->mul_mat_m_int[i] = false;
3803
+ device->mul_mat_l_int[i] = false;
3804
+ } else if (!ggml_vk_matmul_int_shmem_support(device, l_int, false, t)) {
3805
+ device->mul_mat_l_int[i] = false;
3806
+ }
3807
+
3808
+ if (!ggml_vk_matmul_int_shmem_support(device, s_intid, true, t)) {
3809
+ device->mul_mat_id_s_int[i] = false;
3810
+ device->mul_mat_id_m_int[i] = false;
3811
+ device->mul_mat_id_l_int[i] = false;
3812
+ } else if (!ggml_vk_matmul_int_shmem_support(device, m_intid, true, t)) {
3813
+ device->mul_mat_id_m_int[i] = false;
3814
+ device->mul_mat_id_l_int[i] = false;
3815
+ } else if (!ggml_vk_matmul_int_shmem_support(device, l_intid, true, t)) {
3816
+ device->mul_mat_id_l_int[i] = false;
3817
+ }
3313
3818
  }
3314
3819
  }
3315
3820
 
@@ -3329,7 +3834,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
3329
3834
  device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
3330
3835
  }
3331
3836
 
3332
- std::vector<std::future<void>> compiles;
3333
3837
  auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& base_pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
3334
3838
  uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
3335
3839
  uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
@@ -3363,23 +3867,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
3363
3867
  #endif
3364
3868
  }
3365
3869
 
3366
- if (!pipeline->needed || pipeline->compiled) {
3870
+ // We only care about the pipeline this call asked for; the rest
3871
+ // (including the 64-bit indexing variant) are handled by their
3872
+ // own request_descriptor_sets / load_shaders calls.
3873
+ if (pipeline.get() != requested.get()) {
3367
3874
  continue;
3368
3875
  }
3369
- // TODO: We're no longer benefitting from the async compiles (shaders are
3370
- // compiled individually, as needed) and this complexity can be removed.
3371
- {
3372
- // wait until fewer than N compiles are in progress
3373
- uint32_t N = std::max(1u, std::thread::hardware_concurrency());
3374
- std::unique_lock<std::mutex> guard(compile_count_mutex);
3375
- while (compile_count >= N) {
3376
- compile_count_cond.wait(guard);
3377
- }
3378
- compile_count++;
3876
+
3877
+ if (pipeline->compiled) {
3878
+ continue;
3379
3879
  }
3380
3880
 
3381
- compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
3382
- parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
3881
+ wait_pipeline = pipeline;
3882
+
3883
+ if (!pipeline->compile_pending) {
3884
+ pipeline->compile_pending = true;
3885
+ claimed_task.pipeline = pipeline;
3886
+ claimed_task.spv_size = spv_size;
3887
+ claimed_task.spv_data = spv_data;
3888
+ claimed_task.entrypoint = entrypoint;
3889
+ claimed_task.parameter_count = parameter_count;
3890
+ claimed_task.wg_denoms = wg_denoms;
3891
+ claimed_task.specialization_constants = specialization_constants;
3892
+ claimed_task.disable_robustness = disable_robustness;
3893
+ claimed_task.require_full_subgroups = require_full_subgroups;
3894
+ claimed_task.required_subgroup_size = required_subgroup_size;
3895
+ has_claimed_task = true;
3896
+ }
3383
3897
  }
3384
3898
  };
3385
3899
 
@@ -3391,64 +3905,132 @@ static void ggml_vk_load_shaders(vk_device& device) {
3391
3905
  align, disable_robustness, require_full_subgroups, required_subgroup_size);
3392
3906
  };
3393
3907
 
3394
- #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
3395
- for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \
3396
- FaCodePath path = fa.first.path; \
3397
- uint32_t Br = fa.first.Br; \
3398
- uint32_t Bc = fa.first.Bc; \
3399
- bool aligned = fa.first.aligned; \
3400
- bool f32acc = fa.first.f32acc; \
3401
- uint32_t fa_sgs = fa.first.subgroup_size; \
3402
- bool fa_ds = fa.first.subgroup_size == 0; \
3403
- if (path == FAPATH) { \
3404
- if (aligned) { \
3405
- if (f32acc) { \
3406
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
3407
- } else { \
3408
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
3409
- } \
3410
- } else { \
3411
- if (f32acc) { \
3412
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
3413
- } else { \
3414
- ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \
3415
- } \
3416
- } \
3417
- } \
3908
+ // FA scalar has two SPIR-V modules (MMQ vs non-MMQ); FA cm1 has one. K/V
3909
+ // quant type is selected at runtime via the FaTypeK / FaTypeV spec constants.
3910
+
3911
+ for (auto &fa : device->pipeline_flash_attn_f32_f16) {
3912
+ if (fa.first.path != FA_SCALAR) continue;
3913
+ const uint32_t Br = fa.first.Br;
3914
+ const uint32_t Bc = fa.first.Bc;
3915
+ const bool aligned = fa.first.aligned;
3916
+ const bool f32acc = fa.first.f32acc;
3917
+ const uint32_t fa_sgs = fa.first.subgroup_size;
3918
+ const bool fa_ds = fa.first.subgroup_size == 0;
3919
+
3920
+ const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
3921
+ const bool use_mmq = ggml_vk_fa_scalar_uses_mmq(device, fa.first.k_type);
3922
+ const void * spv_data = nullptr;
3923
+ size_t spv_size = 0;
3924
+ const char *name = nullptr;
3925
+ if (bf16_kv) {
3926
+ spv_data = flash_attn_f32_f16_fp32_data;
3927
+ spv_size = flash_attn_f32_f16_fp32_len;
3928
+ name = aligned ? "flash_attn_f32_bf16_aligned" : "flash_attn_f32_bf16";
3929
+ } else if (use_mmq) {
3930
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3931
+ if (device->fp16) {
3932
+ if (f32acc) { spv_data = flash_attn_f32_f16_int8_data; spv_size = flash_attn_f32_f16_int8_len; }
3933
+ else { spv_data = flash_attn_f32_f16_f16acc_int8_data; spv_size = flash_attn_f32_f16_f16acc_int8_len; }
3934
+ } else {
3935
+ spv_data = flash_attn_f32_f16_fp32_int8_data;
3936
+ spv_size = flash_attn_f32_f16_fp32_int8_len;
3937
+ }
3938
+ #endif
3939
+ name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
3940
+ } else {
3941
+ if (device->fp16) {
3942
+ if (device->dot2_f16) {
3943
+ if (f32acc) { spv_data = flash_attn_f32_f16_dot2_data; spv_size = flash_attn_f32_f16_dot2_len; }
3944
+ else { spv_data = flash_attn_f32_f16_dot2_f16acc_data; spv_size = flash_attn_f32_f16_dot2_f16acc_len; }
3945
+ } else {
3946
+ if (f32acc) { spv_data = flash_attn_f32_f16_data; spv_size = flash_attn_f32_f16_len; }
3947
+ else { spv_data = flash_attn_f32_f16_f16acc_data; spv_size = flash_attn_f32_f16_f16acc_len; }
3948
+ }
3949
+ } else {
3950
+ spv_data = flash_attn_f32_f16_fp32_data;
3951
+ spv_size = flash_attn_f32_f16_fp32_len;
3952
+ }
3953
+ name = aligned ? "flash_attn_f32_f16_aligned" : "flash_attn_f32_f16";
3418
3954
  }
3419
-
3420
- if (device->fp16) {
3421
- CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
3422
- CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
3423
- CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
3424
- CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
3425
- } else {
3426
- CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32)
3427
- CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32)
3428
- CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32)
3429
- CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32)
3955
+ ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
3956
+ sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
3957
+ get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
3958
+ !fa_ds, !fa_ds ? fa_sgs : 0);
3430
3959
  }
3960
+
3431
3961
  #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
3432
3962
  if (device->coopmat1_fa_support) {
3433
- CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
3434
- CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
3435
- CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
3436
- CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
3963
+ for (auto &fa : device->pipeline_flash_attn_f32_f16) {
3964
+ if (fa.first.path != FA_COOPMAT1) continue;
3965
+ const uint32_t Br = fa.first.Br;
3966
+ const uint32_t Bc = fa.first.Bc;
3967
+ const bool aligned = fa.first.aligned;
3968
+ const bool f32acc = fa.first.f32acc;
3969
+ const uint32_t fa_sgs = fa.first.subgroup_size;
3970
+ const bool fa_ds = fa.first.subgroup_size == 0;
3971
+
3972
+ const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
3973
+
3974
+ const void * spv_data;
3975
+ size_t spv_size;
3976
+ const char *name;
3977
+ if (bf16_kv) {
3978
+ #if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3979
+ if (!device->coopmat_bf16_support) continue;
3980
+ spv_data = flash_attn_f32_f16_bf16_cm1_data;
3981
+ spv_size = flash_attn_f32_f16_bf16_cm1_len;
3982
+ name = aligned ? "flash_attn_f32_bf16_aligned_cm1" : "flash_attn_f32_bf16_cm1";
3983
+ #else
3984
+ continue;
3985
+ #endif
3986
+ } else {
3987
+ if (f32acc) { spv_data = flash_attn_f32_f16_cm1_data; spv_size = flash_attn_f32_f16_cm1_len; }
3988
+ else { spv_data = flash_attn_f32_f16_f16acc_cm1_data; spv_size = flash_attn_f32_f16_f16acc_cm1_len; }
3989
+ name = aligned ? "flash_attn_f32_f16_aligned_cm1" : "flash_attn_f32_f16_cm1";
3990
+ }
3991
+ ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
3992
+ sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
3993
+ get_fa_spec_constants(fa.first), aligned ? Bc : 1, true,
3994
+ !fa_ds, !fa_ds ? fa_sgs : 0);
3995
+ }
3437
3996
  }
3438
3997
  #endif
3998
+
3439
3999
  #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
3440
4000
  if (device->coopmat2) {
3441
- CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
3442
- CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
3443
- CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
3444
- CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
3445
- CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
3446
- CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
3447
- CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
3448
- CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
4001
+ for (auto &fa : device->pipeline_flash_attn_f32_f16) {
4002
+ if (fa.first.path != FA_COOPMAT2) continue;
4003
+ const uint32_t Br = fa.first.Br;
4004
+ const uint32_t Bc = fa.first.Bc;
4005
+ const bool aligned = fa.first.aligned;
4006
+ const bool f32acc = fa.first.f32acc;
4007
+
4008
+ const bool bf16_kv = fa.first.k_type == GGML_TYPE_BF16;
4009
+ const void * spv_data;
4010
+ size_t spv_size;
4011
+ const char * name;
4012
+ if (bf16_kv) {
4013
+ #if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
4014
+ if (!device->coopmat2_bf16_support) continue;
4015
+ spv_data = flash_attn_f32_f16_bf16_cm2_data;
4016
+ spv_size = flash_attn_f32_f16_bf16_cm2_len;
4017
+ name = aligned ? "flash_attn_f32_bf16_aligned_cm2" : "flash_attn_f32_bf16_cm2";
4018
+ #else
4019
+ continue;
4020
+ #endif
4021
+ } else if (aligned) {
4022
+ if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_aligned_f32acc_cm2"; }
4023
+ else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_aligned_f16acc_cm2"; }
4024
+ } else {
4025
+ if (f32acc) { spv_data = flash_attn_f32_f16_cm2_data; spv_size = flash_attn_f32_f16_cm2_len; name = "flash_attn_f32_f16_f32acc_cm2"; }
4026
+ else { spv_data = flash_attn_f32_f16_f16acc_cm2_data; spv_size = flash_attn_f32_f16_f16acc_cm2_len; name = "flash_attn_f32_f16_f16acc_cm2"; }
4027
+ }
4028
+ ggml_vk_create_pipeline(device, fa.second, name, spv_size, spv_data, "main", 7,
4029
+ sizeof(vk_flash_attn_push_constants), {Br, 1, 1},
4030
+ get_fa_spec_constants(fa.first), aligned ? Bc : 1, true, false, 0);
4031
+ }
3449
4032
  }
3450
4033
  #endif
3451
- #undef CREATE_FA
3452
4034
 
3453
4035
  const int mul_mat_id_param_count = 5;
3454
4036
 
@@ -3475,6 +4057,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3475
4057
  CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
3476
4058
  }
3477
4059
  #endif
4060
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q1_0], matmul_q1_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3478
4061
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3479
4062
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3480
4063
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -3495,6 +4078,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3495
4078
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3496
4079
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3497
4080
  CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
4081
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_NVFP4], matmul_nvfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
3498
4082
 
3499
4083
  GGML_ASSERT(device->subgroup_ballot);
3500
4084
 
@@ -3504,6 +4088,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3504
4088
  CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 5)
3505
4089
  }
3506
4090
  #endif
4091
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3507
4092
  CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3508
4093
  CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3509
4094
  CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
@@ -3524,6 +4109,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3524
4109
  CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3525
4110
  CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3526
4111
  CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
4112
+ CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 5)
3527
4113
  #undef CREATE_MM
3528
4114
  #undef CREATE_MM2
3529
4115
  } else
@@ -3565,6 +4151,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3565
4151
  #endif
3566
4152
 
3567
4153
  if (device->coopmat_acc_f16_support) {
4154
+ CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3568
4155
  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3569
4156
  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3570
4157
  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -3586,7 +4173,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
3586
4173
  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3587
4174
  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3588
4175
  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
4176
+ CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3589
4177
  } else {
4178
+ CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3590
4179
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3591
4180
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3592
4181
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -3608,6 +4197,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3608
4197
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3609
4198
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3610
4199
  CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
4200
+ CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
3611
4201
  }
3612
4202
 
3613
4203
  GGML_ASSERT(device->subgroup_ballot);
@@ -3621,6 +4211,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3621
4211
  }
3622
4212
  #endif
3623
4213
 
4214
+ CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3624
4215
  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3625
4216
  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3626
4217
  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
@@ -3641,13 +4232,30 @@ static void ggml_vk_load_shaders(vk_device& device) {
3641
4232
  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3642
4233
  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3643
4234
  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
4235
+ CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id);
3644
4236
  #undef CREATE_MM2
3645
4237
  #undef CREATE_MM
3646
4238
  } else
3647
4239
  #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
3648
4240
  if (device->fp16) {
3649
4241
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
4242
+ // Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true
3650
4243
  #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
4244
+ if (device->mul_mat ## ID ## _l[TYPE]) \
4245
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
4246
+ if (device->mul_mat ## ID ## _m[TYPE]) \
4247
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
4248
+ if (device->mul_mat ## ID ## _s[TYPE]) \
4249
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
4250
+ if (device->mul_mat ## ID ## _l[TYPE]) \
4251
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
4252
+ if (device->mul_mat ## ID ## _m[TYPE]) \
4253
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
4254
+ if (device->mul_mat ## ID ## _s[TYPE]) \
4255
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
4256
+
4257
+ // bf16 scalar path promotes to f32, no dot2 variant
4258
+ #define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3651
4259
  if (device->mul_mat ## ID ## _l[TYPE]) \
3652
4260
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3653
4261
  if (device->mul_mat ## ID ## _m[TYPE]) \
@@ -3662,13 +4270,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
3662
4270
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3663
4271
 
3664
4272
  #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
3665
- if (device->mul_mat ## ID ## _l[TYPE]) { \
4273
+ if (device->mul_mat ## ID ## _l_int[TYPE]) { \
3666
4274
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3667
4275
  } \
3668
- if (device->mul_mat ## ID ## _m[TYPE]) { \
4276
+ if (device->mul_mat ## ID ## _m_int[TYPE]) { \
3669
4277
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3670
4278
  } \
3671
- if (device->mul_mat ## ID ## _s[TYPE]) { \
4279
+ if (device->mul_mat ## ID ## _s_int[TYPE]) { \
3672
4280
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3673
4281
  } \
3674
4282
 
@@ -3682,14 +4290,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
3682
4290
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3683
4291
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3684
4292
 
3685
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
4293
+ CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3686
4294
 
4295
+ CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0], matmul_q1_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3687
4296
  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3688
4297
  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3689
4298
  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3690
4299
  CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3691
4300
  CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3692
-
3693
4301
  CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3694
4302
  CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3695
4303
  CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@@ -3705,6 +4313,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3705
4313
  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3706
4314
  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3707
4315
  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
4316
+ CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4], matmul_nvfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3708
4317
 
3709
4318
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3710
4319
  if (device->integer_dot_product) {
@@ -3728,8 +4337,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
3728
4337
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3729
4338
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3730
4339
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3731
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3732
-
4340
+ CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
4341
+ CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_subgroup_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3733
4342
  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3734
4343
  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3735
4344
  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
@@ -3750,6 +4359,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3750
4359
  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3751
4360
  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3752
4361
  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
4362
+ CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_subgroup_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3753
4363
 
3754
4364
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3755
4365
  if (device->integer_dot_product) {
@@ -3772,8 +4382,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
3772
4382
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3773
4383
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3774
4384
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3775
- CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3776
-
4385
+ CREATE_MM_NODOT2(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
4386
+ CREATE_MM2(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0], matmul_id_q1_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3777
4387
  CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3778
4388
  CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3779
4389
  CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
@@ -3794,6 +4404,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3794
4404
  CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3795
4405
  CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3796
4406
  CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
4407
+ CREATE_MM2(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4], matmul_id_nvfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3797
4408
 
3798
4409
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3799
4410
  if (device->integer_dot_product) {
@@ -3816,6 +4427,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3816
4427
  #undef CREATE_MM2
3817
4428
  #undef CREATE_MMQ
3818
4429
  #undef CREATE_MM
4430
+ #undef CREATE_MM_NODOT2
3819
4431
  } else {
3820
4432
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
3821
4433
  #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \
@@ -3833,11 +4445,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
3833
4445
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \
3834
4446
 
3835
4447
  #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
3836
- if (device->mul_mat ## ID ## _l[TYPE]) \
4448
+ if (device->mul_mat ## ID ## _l_int[TYPE]) \
3837
4449
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
3838
- if (device->mul_mat ## ID ## _m[TYPE]) \
4450
+ if (device->mul_mat ## ID ## _m_int[TYPE]) \
3839
4451
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
3840
- if (device->mul_mat ## ID ## _s[TYPE]) \
4452
+ if (device->mul_mat ## ID ## _s_int[TYPE]) \
3841
4453
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
3842
4454
 
3843
4455
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
@@ -3847,6 +4459,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3847
4459
 
3848
4460
  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3849
4461
 
4462
+ CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q1_0].f32acc, matmul_q1_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3850
4463
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3851
4464
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3852
4465
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
@@ -3868,6 +4481,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3868
4481
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3869
4482
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3870
4483
  CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
4484
+ CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_NVFP4].f32acc, matmul_nvfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0);
3871
4485
 
3872
4486
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
3873
4487
  if (device->integer_dot_product) {
@@ -3891,6 +4505,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3891
4505
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3892
4506
  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16);
3893
4507
 
4508
+ CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_subgroup_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3894
4509
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3895
4510
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3896
4511
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
@@ -3911,12 +4526,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
3911
4526
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3912
4527
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3913
4528
  CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
4529
+ CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_subgroup_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size);
3914
4530
  } else {
3915
4531
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3916
4532
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3917
4533
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3918
4534
  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3919
4535
 
4536
+ CREATE_MM(GGML_TYPE_Q1_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q1_0].f32acc, matmul_id_q1_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3920
4537
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3921
4538
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3922
4539
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
@@ -3937,6 +4554,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
3937
4554
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3938
4555
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3939
4556
  CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
4557
+ CREATE_MM(GGML_TYPE_NVFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_NVFP4].f32acc, matmul_id_nvfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3940
4558
  }
3941
4559
  }
3942
4560
  // reusing CREATE_MM from the fp32 path
@@ -3956,11 +4574,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
3956
4574
  m_wg_denoms = { 64, 64, 1 };
3957
4575
  s_wg_denoms = { 32, 32, 1 };
3958
4576
 
3959
- if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) {
3960
- // Xe2/Xe3 - bf16 warptile performance tuning
3961
- l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 };
3962
- }
3963
-
3964
4577
  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
3965
4578
  CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
3966
4579
  }
@@ -4014,6 +4627,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4014
4627
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
4015
4628
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
4016
4629
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
4630
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f32_f32", arr_dmmv_q1_0_f32_f32_len[reduc], arr_dmmv_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
4017
4631
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
4018
4632
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
4019
4633
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
@@ -4034,10 +4648,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
4034
4648
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
4035
4649
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
4036
4650
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
4651
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f32_f32", arr_dmmv_nvfp4_f32_f32_len[reduc16], arr_dmmv_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
4037
4652
 
4038
4653
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {wg_size_subgroup, 1, i+1}, 1, false, use_subgroups, force_subgroup_size);
4039
4654
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
4040
4655
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
4656
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q1_0][i], "mul_mat_vec_q1_0_f16_f32", arr_dmmv_q1_0_f16_f32_len[reduc], arr_dmmv_q1_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
4041
4657
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
4042
4658
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
4043
4659
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
@@ -4058,6 +4674,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4058
4674
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
4059
4675
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
4060
4676
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
4677
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_NVFP4][i], "mul_mat_vec_nvfp4_f16_f32", arr_dmmv_nvfp4_f16_f32_len[reduc16], arr_dmmv_nvfp4_f16_f32_data[reduc16], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
4061
4678
 
4062
4679
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
4063
4680
  if (device->integer_dot_product) {
@@ -4088,6 +4705,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4088
4705
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {wg_size_subgroup, 1}, 1, false, use_subgroups, force_subgroup_size);
4089
4706
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
4090
4707
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
4708
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q1_0], "mul_mat_vec_id_q1_0_f32", arr_dmmv_id_q1_0_f32_f32_len[reduc], arr_dmmv_id_q1_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
4091
4709
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
4092
4710
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
4093
4711
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
@@ -4108,6 +4726,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4108
4726
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
4109
4727
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
4110
4728
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
4729
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_NVFP4], "mul_mat_vec_id_nvfp4_f32", arr_dmmv_id_nvfp4_f32_f32_len[reduc16], arr_dmmv_id_nvfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
4111
4730
 
4112
4731
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
4113
4732
  if (device->integer_dot_product) {
@@ -4142,6 +4761,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4142
4761
 
4143
4762
  // dequant shaders
4144
4763
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
4764
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q1_0], "dequant_q1_0", dequant_q1_0_len, dequant_q1_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 8, 1, 1}, {}, 1);
4145
4765
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
4146
4766
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
4147
4767
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -4162,11 +4782,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
4162
4782
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
4163
4783
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
4164
4784
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
4785
+ ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_NVFP4], "dequant_nvfp4", dequant_nvfp4_len, dequant_nvfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
4165
4786
 
4166
4787
  // get_rows
4167
4788
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
4168
4789
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
4169
4790
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
4791
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q1_0], "get_rows_q1_0", get_rows_q1_0_len, get_rows_q1_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4170
4792
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4171
4793
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4172
4794
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -4187,11 +4809,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
4187
4809
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4188
4810
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4189
4811
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4812
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_NVFP4], "get_rows_nvfp4", get_rows_nvfp4_len, get_rows_nvfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4190
4813
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_I32], "get_rows_i32", get_rows_i32_len, get_rows_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4191
4814
 
4192
4815
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
4193
4816
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
4194
4817
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
4818
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q1_0], "get_rows_q1_0_f32", get_rows_q1_0_f32_len, get_rows_q1_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4195
4819
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4196
4820
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4197
4821
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -4212,6 +4836,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4212
4836
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4213
4837
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4214
4838
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4839
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
4215
4840
 
4216
4841
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
4217
4842
  ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true);
@@ -4244,10 +4869,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
4244
4869
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true);
4245
4870
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true);
4246
4871
 
4247
- if (device->float_controls_rte_fp16 &&
4248
- sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
4872
+ if (sizeof(vk_op_rms_norm_mul_rope_push_constants) <= device->properties.limits.maxPushConstantsSize) {
4249
4873
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f32, "rms_norm_mul_rope_f32_f32", rms_norm_mul_rope_f32_f32_len, rms_norm_mul_rope_f32_f32_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
4250
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_rte_len, rms_norm_mul_rope_f32_f16_rte_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
4874
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_rope_f32_f16, "rms_norm_mul_rope_f32_f16", rms_norm_mul_rope_f32_f16_len, rms_norm_mul_rope_f32_f16_data, "main", 7, sizeof(vk_op_rms_norm_mul_rope_push_constants), {1, 1, 1}, {0, 1}, 1, true);
4251
4875
  }
4252
4876
 
4253
4877
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -4258,6 +4882,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
4258
4882
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4259
4883
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4260
4884
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4885
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_bf16_f32,"cpy_bf16_f32",cpy_bf16_f32_len,cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4261
4886
  ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4262
4887
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4263
4888
 
@@ -4266,49 +4891,39 @@ static void ggml_vk_load_shaders(vk_device& device) {
4266
4891
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4267
4892
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4268
4893
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4894
+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_bf16_f32,"contig_cpy_bf16_f32",contig_cpy_bf16_f32_len,contig_cpy_bf16_f32_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4269
4895
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4270
4896
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4271
4897
 
4272
4898
  ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_32, "cpy_transpose_32", cpy_transpose_32_len, cpy_transpose_32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
4273
4899
  ggml_vk_create_pipeline(device, device->pipeline_cpy_transpose_16, "cpy_transpose_16", cpy_transpose_16_len, cpy_transpose_16_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
4274
4900
 
4275
- if (device->float_controls_rte_fp16) {
4276
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4277
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4278
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4279
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4280
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4281
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4282
- } else {
4283
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4284
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4285
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4286
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4287
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4288
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4289
- }
4290
-
4291
- #define SET_ROWS(itype, rte) \
4292
- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4293
- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4294
- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4295
- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4296
- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4297
- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4298
- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4299
- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4300
- ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
4301
-
4302
- if (device->float_controls_rte_fp16) {
4303
- SET_ROWS(_i32, _rte)
4304
- SET_ROWS(_i64, _rte)
4305
- } else {
4306
- SET_ROWS(_i32, )
4307
- SET_ROWS(_i64, )
4308
- }
4901
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q1_0], "cpy_f32_q1_0", cpy_f32_q1_0_len, cpy_f32_q1_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4902
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4903
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4904
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4905
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4906
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4907
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
4908
+
4909
+ #define SET_ROWS(itype) \
4910
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## _len, set_rows_f32 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4911
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## _len, set_rows_f16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4912
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## _len, set_rows_bf16 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4913
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q1_0], "set_rows_q1_0" #itype, set_rows_q1_0 ## itype ## _len, set_rows_q1_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4914
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## _len, set_rows_q4_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4915
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## _len, set_rows_q4_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4916
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## _len, set_rows_q5_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4917
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## _len, set_rows_q5_1 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4918
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## _len, set_rows_q8_0 ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
4919
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## _len, set_rows_iq4_nl ## itype ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
4920
+
4921
+ SET_ROWS(_i32)
4922
+ SET_ROWS(_i64)
4309
4923
  #undef SET_ROWS
4310
4924
 
4311
4925
 
4926
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q1_0], "cpy_q1_0_f32", cpy_q1_0_f32_len, cpy_q1_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q1_0), 1, 1}, {}, 1);
4312
4927
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
4313
4928
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
4314
4929
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
@@ -4324,11 +4939,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
4324
4939
  return s;
4325
4940
  };
4326
4941
 
4327
- bool rte = device->float_controls_rte_fp16;
4328
4942
  #define CREATE_BINARY(name, namemod, spec, bindings) \
4329
4943
  for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
4330
4944
  ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
4331
- #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
4945
+ #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
4332
4946
  "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
4333
4947
 
4334
4948
  CREATE_BINARY(add, , {0}, 4)
@@ -4371,13 +4985,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
4371
4985
  ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4372
4986
  ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4373
4987
 
4374
- if (device->float_controls_rte_fp16) {
4375
- ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32_rte", log_f32_rte_len, log_f32_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4376
- ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16_rte", log_f16_rte_len, log_f16_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4377
- } else {
4378
- ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4379
- ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4380
- }
4988
+ ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4989
+ ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4381
4990
 
4382
4991
  ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4383
4992
  ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -4391,9 +5000,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
4391
5000
 
4392
5001
  ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4393
5002
 
4394
- ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
5003
+ ggml_vk_create_pipeline(device, device->pipeline_repeat_i32, "repeat_i32", repeat_i32_len, repeat_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4395
5004
  ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
4396
5005
 
5006
+ ggml_vk_create_pipeline(device, device->pipeline_repeat_i16, "repeat_i16", repeat_i16_len, repeat_i16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
5007
+
4397
5008
  #define CREATE_UNARY(name) \
4398
5009
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4399
5010
  ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
@@ -4418,19 +5029,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
4418
5029
  CREATE_UNARY(floor)
4419
5030
  CREATE_UNARY(trunc)
4420
5031
  CREATE_UNARY(sgn)
5032
+ CREATE_UNARY(exp)
4421
5033
  #undef CREATE_UNARY
4422
5034
 
4423
- #define CREATE_UNARY_RTE(name) \
4424
- if (device->float_controls_rte_fp16) { \
4425
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4426
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4427
- } else { \
4428
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4429
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
4430
- }
4431
- CREATE_UNARY_RTE(exp)
4432
- #undef CREATE_UNARY_RTE
4433
-
4434
5035
  ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f16, "add1_f16_f16", add1_f16_f16_len, add1_f16_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4435
5036
  ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
4436
5037
  ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -4438,15 +5039,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
4438
5039
  ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4439
5040
 
4440
5041
  ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
5042
+ ggml_vk_create_pipeline(device, device->pipeline_fill_f16, "fill_f16", fill_f16_len, fill_f16_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4441
5043
 
4442
5044
  #define CREATE_GLU(name) \
4443
- if (device->float_controls_rte_fp16) { \
4444
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
4445
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
4446
- } else { \
4447
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
4448
- ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
4449
- }
5045
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
5046
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
4450
5047
 
4451
5048
  CREATE_GLU(geglu)
4452
5049
  CREATE_GLU(reglu)
@@ -4479,25 +5076,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
4479
5076
  ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4480
5077
  ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4481
5078
 
4482
- if (device->float_controls_rte_fp16) {
4483
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4484
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4485
- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4486
- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4487
-
4488
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_rte_len, rope_norm_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4489
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_rte_len, rope_neox_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4490
- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_rte_len, rope_multi_f32_f16_rte_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4491
- } else {
4492
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4493
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4494
- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4495
- ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
5079
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
5080
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
5081
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
5082
+ ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4496
5083
 
4497
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4498
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4499
- ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4500
- }
5084
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32_f16, "rope_norm_f32_f16", rope_norm_f32_f16_len, rope_norm_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
5085
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32_f16, "rope_neox_f32_f16", rope_neox_f32_f16_len, rope_neox_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
5086
+ ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32_f16, "rope_multi_f32_f16", rope_multi_f32_f16_len, rope_multi_f32_f16_data, "main", 5, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
4501
5087
 
4502
5088
  for (uint32_t i = 0; i < num_argsort_pipelines; ++i) {
4503
5089
  uint32_t BLOCK_SIZE = 1u << std::min(i, device->max_workgroup_size_log2);
@@ -4531,6 +5117,24 @@ static void ggml_vk_load_shaders(vk_device& device) {
4531
5117
  ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
4532
5118
 
4533
5119
  ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
5120
+ // Intel Arc B390 was observed segfaulting with this shader.
5121
+ if (device->subgroup_basic && device->subgroup_shuffle && device->vendor_id != VK_VENDOR_ID_INTEL) {
5122
+ int idx = 0;
5123
+ for (uint32_t n : {64, 128, 256, 512}) {
5124
+ if (device->subgroup_size <= n) {
5125
+ ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_f32", fwht_f32_len, fwht_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { device->subgroup_size, n }, 1, true, true, device->subgroup_size);
5126
+ }
5127
+ ++idx;
5128
+ }
5129
+ } else if (device->driver_id != vk::DriverId::eIntelProprietaryWindows) {
5130
+ // Disabled on Intel Windows due to a driver bug: https://github.com/ggml-org/llama.cpp/pull/23964#issuecomment-4598226147
5131
+ int idx = 0;
5132
+ for (uint32_t n : {64, 128, 256, 512}) {
5133
+ const uint32_t block_size = std::min(device->subgroup_size, n);
5134
+ ggml_vk_create_pipeline(device, device->pipeline_fwht_f32[idx], "fwht_shmem_f32", fwht_shmem_f32_len, fwht_shmem_f32_data, "main", 2, sizeof(vk_op_fwht_push_constants), {1, 1, 1}, { block_size, n }, 1);
5135
+ ++idx;
5136
+ }
5137
+ }
4534
5138
 
4535
5139
  const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
4536
5140
  ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
@@ -4559,13 +5163,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
4559
5163
  #define IM2COL(bda) \
4560
5164
  ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
4561
5165
  ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
4562
- if (device->float_controls_rte_fp16) { \
4563
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
4564
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
4565
- } else { \
4566
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
4567
- ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
4568
- }
5166
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
5167
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## _len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
4569
5168
  if (device->shader_int64 && device->buffer_device_address) {
4570
5169
  IM2COL(_bda)
4571
5170
  } else {
@@ -4576,6 +5175,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
4576
5175
 
4577
5176
  ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);
4578
5177
 
5178
+ ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
5179
+ ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
5180
+ ggml_vk_create_pipeline(device, device->pipeline_snake_bf16, "snake_bf16", snake_bf16_len, snake_bf16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
5181
+
4579
5182
  ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
4580
5183
 
4581
5184
  ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
@@ -4589,12 +5192,42 @@ static void ggml_vk_load_shaders(vk_device& device) {
4589
5192
  {"gated_delta_net_f32_d64", "gated_delta_net_f32_d64_kda"},
4590
5193
  {"gated_delta_net_f32_d128", "gated_delta_net_f32_d128_kda"},
4591
5194
  };
5195
+ const bool use_subgroup_reduce = device->subgroup_arithmetic;
4592
5196
  for (uint32_t si = 0; si < 3; si++) {
5197
+ const uint32_t S_V = gdn_sizes[si];
5198
+ GGML_ASSERT(is_pow2(S_V));
5199
+
5200
+ uint32_t lanes_per_column;
5201
+ if (S_V >= 128u && device->subgroup_clustered) {
5202
+ lanes_per_column = 8u;
5203
+ } else {
5204
+ // Use largest power-of-two that divides both S_V and subgroup_size so that
5205
+ // (1) S_V % lanes_per_column == 0 and (2) S_V % (subgroup_size / lanes_per_column) == 0.
5206
+ // This means we don't need extra bounds checking logic in the shader.
5207
+ lanes_per_column = std::min(S_V, device->subgroup_size);
5208
+ }
5209
+
5210
+ const bool need_clustered_shader = lanes_per_column != 1 && (lanes_per_column < device->subgroup_size);
5211
+ size_t gdn_len;
5212
+ const void * gdn_data;
5213
+ if (use_subgroup_reduce && need_clustered_shader) {
5214
+ gdn_len = gated_delta_net_f32_len;
5215
+ gdn_data = (const void *)gated_delta_net_f32_data;
5216
+ } else if (use_subgroup_reduce) {
5217
+ gdn_len = gated_delta_net_f32_nocluster_len;
5218
+ gdn_data = (const void *)gated_delta_net_f32_nocluster_data;
5219
+ } else {
5220
+ gdn_len = gated_delta_net_f32_shmem_len;
5221
+ gdn_data = (const void *)gated_delta_net_f32_shmem_data;
5222
+ }
5223
+
5224
+ const uint32_t cols_per_wg = device->subgroup_size / lanes_per_column;
5225
+ const std::array<uint32_t, 3> wg_denoms = {1u, 1u, cols_per_wg};
5226
+
4593
5227
  for (uint32_t kda = 0; kda < 2; kda++) {
4594
5228
  ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net[si][kda],
4595
- gdn_names[si][kda], gated_delta_net_f32_len, gated_delta_net_f32_data,
4596
- "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
4597
- {1, 1, 1}, {gdn_sizes[si], kda}, 1);
5229
+ gdn_names[si][kda], gdn_len, gdn_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants),
5230
+ wg_denoms, {S_V, kda, device->subgroup_size, lanes_per_column}, 1, true, use_subgroup_reduce, device->subgroup_size);
4598
5231
  }
4599
5232
  }
4600
5233
  }
@@ -4607,7 +5240,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
4607
5240
  ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
4608
5241
  }
4609
5242
 
4610
- ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);
5243
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 0}, 1);
5244
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_silu_f32, "ssm_conv_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 0, 1}, 1);
5245
+ ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_bias_silu_f32, "ssm_conv_bias_silu_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 4, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16, 1, 1}, 1);
4611
5246
 
4612
5247
  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4613
5248
 
@@ -4615,7 +5250,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
4615
5250
 
4616
5251
  // conv2d, conv_transpose_2d
4617
5252
  for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
4618
- uint32_t conv2d_WG_SIZE = 256;
5253
+ // smaller WG for the small-tile fallback gives more concurrent WGs per SM
5254
+ uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256;
4619
5255
  uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices.
4620
5256
  uint32_t conv2d_TS_K = (s == CONV_SHAPE_64x32) ? 4 : 8;
4621
5257
  uint32_t conv2d_SHMEM_PAD = 4;
@@ -4654,18 +5290,77 @@ static void ggml_vk_load_shaders(vk_device& device) {
4654
5290
  conv2d_BS.CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used.
4655
5291
  }
4656
5292
 
4657
- uint32_t conv2d_shmem_req =
4658
- (conv2d_BS.K * (conv2d_BS.CRS + conv2d_SHMEM_PAD) + conv2d_BS.CRS * (conv2d_BS.NPQ + conv2d_SHMEM_PAD)) * sizeof(float);
4659
- if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) {
5293
+ // cm1 is used only when cm2 is unavailable; capped at 64x128 (due to shared memory size).
5294
+ // Requires 16x16x16 f16-acc since that's the fragment shape hard-coded in the shader.
5295
+ // Subgroup size must be 32 or 64 (to keep WG_SIZE sane) and we need
5296
+ // subgroup_size_control to force the driver to actually use it.
5297
+ bool conv2d_use_cm1 = false;
5298
+ #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
5299
+ conv2d_use_cm1 = !device->coopmat2 &&
5300
+ device->coopmat_support && device->coopmat_support_16x16x16_f16acc &&
5301
+ device->subgroup_size_control &&
5302
+ (device->subgroup_size == 32 || device->subgroup_size == 64) &&
5303
+ s != CONV_SHAPE_128x128;
5304
+ #endif
5305
+
5306
+ const uint32_t conv2d_cm1_shmem_pad = 8;
5307
+
5308
+ auto shmem_req = [&](uint32_t pad, bool csh_store, bool fp16_shmem) {
5309
+ const uint32_t elem_size = fp16_shmem ? (uint32_t)sizeof(uint16_t) : (uint32_t)sizeof(float);
5310
+ const uint32_t csh_elems = csh_store ? conv2d_BS.K * conv2d_BS.NPQ : 0u;
5311
+ return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size;
5312
+ };
5313
+
5314
+ // coopmat1 needs to store the output through shared memory, so check up front
5315
+ // whether it'll fit and disable it before applying coopmat1 parameters.
5316
+ if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) {
5317
+ conv2d_use_cm1 = false;
5318
+ }
5319
+
5320
+ uint32_t conv2d_WM = 16, conv2d_WN = 16; // cm1 subgroup tile, ignored otherwise
5321
+ if (conv2d_use_cm1) {
5322
+ conv2d_SHMEM_PAD = conv2d_cm1_shmem_pad;
5323
+ // 16x16x16 fragments; pick WM/WN to keep WG_SIZE at 256
5324
+ // (i.e. 8 subgroups for sg=32, 4 subgroups for sg=64).
5325
+ const bool sg64 = (device->subgroup_size == 64);
5326
+ switch (s) {
5327
+ case CONV_SHAPE_64x32: conv2d_WM = sg64 ? 32 : 16; conv2d_WN = 16; break;
5328
+ case CONV_SHAPE_64x128: conv2d_WM = 32; conv2d_WN = sg64 ? 64 : 32; break;
5329
+ case CONV_SHAPE_32x256: conv2d_WM = sg64 ? 16 : 32; conv2d_WN = sg64 ? 128 : 32; break;
5330
+ default: break;
5331
+ }
5332
+ const uint32_t warps_M = conv2d_BS.K / conv2d_WM;
5333
+ const uint32_t warps_N = conv2d_BS.NPQ / conv2d_WN;
5334
+ conv2d_WG_SIZE = warps_M * warps_N * device->subgroup_size;
5335
+ }
5336
+
5337
+ // stage cm2 accumulator through shmem for coalesced global stores;
5338
+ // skipped on 128x128 where the extra Csh footprint hurts occupancy.
5339
+ // cm1 always uses the staged path.
5340
+ uint32_t conv2d_csh_store = (device->coopmat2 && s != CONV_SHAPE_128x128) ? 1u : 0u;
5341
+ if (conv2d_use_cm1) {
5342
+ conv2d_csh_store = 1;
5343
+ }
5344
+
5345
+ // shmem is fp16 on cm2/cm1 (matches Csh), fp32 on scalar
5346
+ const bool conv2d_use_fp16_shmem = device->coopmat2 || conv2d_use_cm1;
5347
+
5348
+ // shrink CRS if the non-cm1 config still doesn't fit
5349
+ if (device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_SHMEM_PAD, conv2d_csh_store, conv2d_use_fp16_shmem)) {
5350
+ GGML_ASSERT(!conv2d_use_cm1);
4660
5351
  conv2d_BS.CRS = 8;
4661
5352
  if (use_collectives) {
4662
5353
  conv2d_BS.CRS = std::min(device->subgroup_size, conv2d_BS.CRS);
4663
5354
  }
5355
+ conv2d_csh_store = 0;
4664
5356
  }
4665
5357
 
4666
5358
  std::array<uint32_t, 3> wg_denoms = { conv2d_BS.K, 1, 1 };
4667
5359
  std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
4668
5360
 
5361
+ // cm1 needs a fixed subgroup width to match the WG_SIZE we computed
5362
+ const uint32_t conv2d_required_subgroup_size = conv2d_use_cm1 ? device->subgroup_size : 0;
5363
+
4669
5364
  #define CREATE_CONV(name, type_suffix, spv_suffix) \
4670
5365
  for (auto &c : device->pipeline_##name##type_suffix[s]) { \
4671
5366
  const vk_conv2d_pipeline_state &state = c.first; \
@@ -4678,10 +5373,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
4678
5373
  spec_constants_cpy.push_back(state.d1); \
4679
5374
  spec_constants_cpy.push_back(state.KW); \
4680
5375
  spec_constants_cpy.push_back(state.KH); \
5376
+ spec_constants_cpy.push_back(state.aligned); \
5377
+ spec_constants_cpy.push_back(conv2d_csh_store); \
5378
+ spec_constants_cpy.push_back(conv2d_WM); \
5379
+ spec_constants_cpy.push_back(conv2d_WN); \
4681
5380
  ggml_vk_create_pipeline( \
4682
5381
  device, c.second, #name #type_suffix, \
4683
5382
  name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \
4684
- sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives); \
5383
+ sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants_cpy, 1, true, use_collectives || conv2d_required_subgroup_size, conv2d_required_subgroup_size); \
4685
5384
  }
4686
5385
  #define CREATE_CONVS(spv_suffix) \
4687
5386
  CREATE_CONV(conv2d, _f32, spv_suffix) \
@@ -4692,6 +5391,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
4692
5391
  if (device->coopmat2) {
4693
5392
  CREATE_CONVS(_cm2)
4694
5393
  } else
5394
+ #endif
5395
+ #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
5396
+ if (conv2d_use_cm1) {
5397
+ CREATE_CONVS(_cm1)
5398
+ } else
4695
5399
  #endif
4696
5400
  if (conv2d_UNROLL) {
4697
5401
  CREATE_CONVS(_unroll)
@@ -4713,8 +5417,25 @@ static void ggml_vk_load_shaders(vk_device& device) {
4713
5417
  }
4714
5418
  }
4715
5419
 
4716
- for (auto &c : compiles) {
4717
- c.wait();
5420
+ // Drop compile_mutex so other threads can walk while we compile.
5421
+ compile_lock.unlock();
5422
+
5423
+ // Compile what we claimed; create_pipeline_func reacquires compile_mutex
5424
+ // at the end to flip compile_pending/compiled and notify waiters.
5425
+ if (has_claimed_task) {
5426
+ auto & task = claimed_task;
5427
+ ggml_vk_create_pipeline_func(device, task.pipeline, task.spv_size, task.spv_data,
5428
+ task.entrypoint, task.parameter_count, task.wg_denoms,
5429
+ task.specialization_constants, task.disable_robustness,
5430
+ task.require_full_subgroups, task.required_subgroup_size);
5431
+ }
5432
+
5433
+ // Another thread may be compiling the pipeline we need; block on it here.
5434
+ if (wait_pipeline) {
5435
+ std::unique_lock<std::mutex> wait_lock(device->compile_mutex);
5436
+ device->compile_cv.wait(wait_lock, [&] {
5437
+ return wait_pipeline->compiled.load();
5438
+ });
4718
5439
  }
4719
5440
  }
4720
5441
 
@@ -4764,11 +5485,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
4764
5485
  bool amd_shader_core_properties2 = false;
4765
5486
  bool pipeline_robustness = false;
4766
5487
  bool coopmat2_support = false;
5488
+ bool coopmat2_decode_vector_support = false;
4767
5489
  bool pipeline_executable_properties_support = false;
4768
5490
  device->coopmat_support = false;
4769
5491
  device->integer_dot_product = false;
4770
5492
  device->shader_64b_indexing = false;
4771
5493
  bool bfloat16_support = false;
5494
+ bool dot2_f16_support = false;
4772
5495
 
4773
5496
  for (const auto& properties : ext_props) {
4774
5497
  if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -4798,6 +5521,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
4798
5521
  !getenv("GGML_VK_DISABLE_COOPMAT2")) {
4799
5522
  coopmat2_support = true;
4800
5523
  #endif
5524
+ } else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 &&
5525
+ !getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) {
5526
+ coopmat2_decode_vector_support = true;
4801
5527
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
4802
5528
  } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
4803
5529
  !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
@@ -4808,6 +5534,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
4808
5534
  !getenv("GGML_VK_DISABLE_BFLOAT16")) {
4809
5535
  bfloat16_support = true;
4810
5536
  #endif
5537
+ } else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 &&
5538
+ !getenv("GGML_VK_DISABLE_DOT2")) {
5539
+ dot2_f16_support = true;
4811
5540
  } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) {
4812
5541
  pipeline_executable_properties_support = true;
4813
5542
  } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
@@ -4955,6 +5684,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
4955
5684
  #endif
4956
5685
  device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4957
5686
  (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
5687
+ #ifdef __APPLE__
5688
+ if (device->vendor_id == VK_VENDOR_ID_AMD) {
5689
+ device->subgroup_shuffle = false;
5690
+ }
5691
+ #endif
4958
5692
  device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
4959
5693
  (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered);
4960
5694
 
@@ -4981,8 +5715,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
4981
5715
  std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
4982
5716
 
4983
5717
  // Try to find a non-graphics compute queue and transfer-focused queues
4984
- // On AMD, the graphics queue seems to be faster, so don't avoid it
4985
- const vk::QueueFlagBits graphics_flag = device->vendor_id == VK_VENDOR_ID_AMD ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics;
5718
+ // Allow overriding avoiding the graphics queue because it can increase performance on RADV
5719
+ const bool allow_graphics_queue = (getenv("GGML_VK_ALLOW_GRAPHICS_QUEUE") != nullptr);
5720
+ const vk::QueueFlagBits graphics_flag = allow_graphics_queue ? (vk::QueueFlagBits)0 : vk::QueueFlagBits::eGraphics;
4986
5721
  const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, graphics_flag, -1, 1);
4987
5722
  const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | graphics_flag, compute_queue_family_index, 1);
4988
5723
 
@@ -4998,7 +5733,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
4998
5733
  } else {
4999
5734
  device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
5000
5735
  }
5001
- vk::DeviceCreateInfo device_create_info;
5736
+ vk::DeviceCreateInfo device_create_info{};
5002
5737
  std::vector<const char *> device_extensions;
5003
5738
  vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
5004
5739
 
@@ -5074,6 +5809,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
5074
5809
  }
5075
5810
  #endif
5076
5811
 
5812
+ VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
5813
+ coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
5814
+ if (coopmat2_decode_vector_support) {
5815
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
5816
+ last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
5817
+ device_extensions.push_back(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME);
5818
+ }
5819
+
5077
5820
  #if defined(VK_KHR_shader_bfloat16)
5078
5821
  VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
5079
5822
  bfloat16_features.pNext = nullptr;
@@ -5101,6 +5844,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
5101
5844
  device_extensions.push_back("VK_KHR_shader_integer_dot_product");
5102
5845
  }
5103
5846
 
5847
+ VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {};
5848
+ dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE;
5849
+ if (dot2_f16_support) {
5850
+ last_struct->pNext = (VkBaseOutStructure *)&dot2_features;
5851
+ last_struct = (VkBaseOutStructure *)&dot2_features;
5852
+ device_extensions.push_back("VK_VALVE_shader_mixed_float_dot_product");
5853
+ }
5854
+
5104
5855
  VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {};
5105
5856
  pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR;
5106
5857
  if (pipeline_executable_properties_support) {
@@ -5135,6 +5886,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
5135
5886
  device->bf16 = false;
5136
5887
  #endif
5137
5888
 
5889
+ device->dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32;
5890
+
5138
5891
  device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
5139
5892
 
5140
5893
  device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
@@ -5193,46 +5946,73 @@ static vk_device ggml_vk_get_device(size_t idx) {
5193
5946
  found_fp16_256 = false,
5194
5947
  found_fp32_128 = false,
5195
5948
  found_fp32_256 = false;
5949
+ bool found_bf16_128 = false,
5950
+ found_bf16_256 = false;
5196
5951
  // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
5197
5952
  // with 32x16x16 and 256 with 32x32x16.
5198
5953
  for (auto &prop : flexible_dimensions) {
5199
5954
  if (prop.saturatingAccumulation == VK_FALSE &&
5200
- prop.scope == VK_SCOPE_WORKGROUP_KHR &&
5201
- prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
5202
- prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
5203
-
5204
- if (prop.workgroupInvocations == 128 &&
5205
- prop.MGranularity <= 32 &&
5206
- prop.NGranularity <= 16 &&
5207
- prop.KGranularity <= 16) {
5208
- if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
5209
- prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
5210
- found_fp16_128 = true;
5955
+ prop.scope == VK_SCOPE_WORKGROUP_KHR) {
5956
+
5957
+ if (prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
5958
+ prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
5959
+
5960
+ if (prop.workgroupInvocations == 128 &&
5961
+ prop.MGranularity <= 32 &&
5962
+ prop.NGranularity <= 16 &&
5963
+ prop.KGranularity <= 16) {
5964
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
5965
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
5966
+ found_fp16_128 = true;
5967
+ }
5968
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
5969
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
5970
+ found_fp32_128 = true;
5971
+ }
5211
5972
  }
5212
- if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
5213
- prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
5214
- found_fp32_128 = true;
5973
+ if (prop.workgroupInvocations == 256 &&
5974
+ prop.MGranularity <= 32 &&
5975
+ prop.NGranularity <= 32 &&
5976
+ prop.KGranularity <= 16) {
5977
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
5978
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
5979
+ found_fp16_256 = true;
5980
+ }
5981
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
5982
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
5983
+ found_fp32_256 = true;
5984
+ }
5215
5985
  }
5216
5986
  }
5217
- if (prop.workgroupInvocations == 256 &&
5218
- prop.MGranularity <= 32 &&
5219
- prop.NGranularity <= 32 &&
5220
- prop.KGranularity <= 16) {
5221
- if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
5222
- prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
5223
- found_fp16_256 = true;
5987
+
5988
+ #if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
5989
+ if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
5990
+ prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
5991
+ prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
5992
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
5993
+
5994
+ if (prop.workgroupInvocations == 128 &&
5995
+ prop.MGranularity <= 32 &&
5996
+ prop.NGranularity <= 16 &&
5997
+ prop.KGranularity <= 16) {
5998
+ found_bf16_128 = true;
5224
5999
  }
5225
- if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
5226
- prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
5227
- found_fp32_256 = true;
6000
+ if (prop.workgroupInvocations == 256 &&
6001
+ prop.MGranularity <= 32 &&
6002
+ prop.NGranularity <= 32 &&
6003
+ prop.KGranularity <= 16) {
6004
+ found_bf16_256 = true;
5228
6005
  }
5229
6006
  }
6007
+ #endif
5230
6008
  }
5231
6009
  }
5232
6010
  if (found_fp16_128 && found_fp16_256 &&
5233
6011
  found_fp32_128 && found_fp32_256 &&
5234
6012
  coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
5235
6013
  device->coopmat2 = true;
6014
+ device->coopmat2_bf16_support = found_bf16_128 && found_bf16_256;
6015
+ device->coopmat2_decode_vector = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
5236
6016
  }
5237
6017
  }
5238
6018
  #endif
@@ -5367,12 +6147,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
5367
6147
  #endif
5368
6148
  device->name = GGML_VK_NAME + std::to_string(idx);
5369
6149
 
5370
- device_create_info = {
5371
- vk::DeviceCreateFlags(),
5372
- device_queue_create_infos,
5373
- {},
5374
- device_extensions
5375
- };
6150
+ device_create_info
6151
+ .setFlags(vk::DeviceCreateFlags())
6152
+ .setQueueCreateInfos(device_queue_create_infos)
6153
+ .setPEnabledExtensionNames(device_extensions);
5376
6154
  device_create_info.setPNext(&device_features2);
5377
6155
  device->device = device->physical_device.createDevice(device_create_info);
5378
6156
 
@@ -5392,19 +6170,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
5392
6170
  device->mul_mat_id_m[i] = true;
5393
6171
  device->mul_mat_id_s[i] = true;
5394
6172
  break;
5395
- case VK_VENDOR_ID_INTEL:
5396
- if (!device->coopmat_support || device->architecture != INTEL_XE2) {
5397
- device->mul_mat_l[i] = false;
5398
- device->mul_mat_id_l[i] = false;
5399
- } else {
5400
- device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel
5401
- device->mul_mat_id_l[i] = true;
5402
- }
6173
+ case VK_VENDOR_ID_INTEL: {
6174
+ // Current Windows driver does not expose BF16 support.
6175
+ // We only want to use l_warptile if coopmat is available and is Xe2+
6176
+ const bool xe2_with_coopmat = device->coopmat_support && device->architecture == INTEL_XE2;
6177
+ const bool use_l_warptile = (i == GGML_TYPE_BF16) ? (device->coopmat_bf16_support && xe2_with_coopmat) : xe2_with_coopmat;
6178
+ device->mul_mat_l[i] = use_l_warptile;
6179
+ device->mul_mat_id_l[i] = use_l_warptile;
5403
6180
  device->mul_mat_m[i] = true;
5404
6181
  device->mul_mat_s[i] = true;
5405
6182
  device->mul_mat_id_m[i] = true;
5406
6183
  device->mul_mat_id_s[i] = true;
5407
6184
  break;
6185
+ }
5408
6186
  case VK_VENDOR_ID_APPLE:
5409
6187
  device->mul_mat_l[i] = false;
5410
6188
  device->mul_mat_m[i] = true;
@@ -5423,6 +6201,26 @@ static vk_device ggml_vk_get_device(size_t idx) {
5423
6201
  device->mul_mat_id_s[i] = true;
5424
6202
  break;
5425
6203
  }
6204
+
6205
+ #if VK_HEADER_VERSION >= 287
6206
+ // Honeykrisp driver for Asahi Linux doesn't report VK_VENDOR_ID_APPLE.
6207
+ // Check for Honeykrisp driver and force same configuration as the VK_VENDOR_ID_APPLE case.
6208
+ if (device->driver_id == vk::DriverId::eMesaHoneykrisp) {
6209
+ device->mul_mat_l[i] = false;
6210
+ device->mul_mat_m[i] = true;
6211
+ device->mul_mat_s[i] = false;
6212
+ device->mul_mat_id_l[i] = false;
6213
+ device->mul_mat_id_m[i] = true;
6214
+ device->mul_mat_id_s[i] = false;
6215
+ }
6216
+ #endif
6217
+
6218
+ device->mul_mat_l_int[i] = device->mul_mat_l[i];
6219
+ device->mul_mat_m_int[i] = device->mul_mat_m[i];
6220
+ device->mul_mat_s_int[i] = device->mul_mat_s[i];
6221
+ device->mul_mat_id_l_int[i] = device->mul_mat_id_l[i];
6222
+ device->mul_mat_id_m_int[i] = device->mul_mat_id_m[i];
6223
+ device->mul_mat_id_s_int[i] = device->mul_mat_id_s[i];
5426
6224
  }
5427
6225
 
5428
6226
 
@@ -5443,11 +6241,18 @@ static vk_device ggml_vk_get_device(size_t idx) {
5443
6241
 
5444
6242
  ggml_vk_load_shaders(device);
5445
6243
 
6244
+ // Prefer a dedicated transfer queue on AMD dGPUs (non-GCN) when graphics queue use is disabled.
6245
+ const bool prefers_transfer_queue =
6246
+ device->vendor_id == VK_VENDOR_ID_AMD &&
6247
+ device->architecture != AMD_GCN &&
6248
+ !device->uma &&
6249
+ !allow_graphics_queue;
6250
+
5446
6251
  if (!device->single_queue) {
5447
6252
  const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
5448
6253
  ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true);
5449
6254
 
5450
- device->async_use_transfer_queue = (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
6255
+ device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr);
5451
6256
  } else {
5452
6257
  // TODO: Use pointer or reference to avoid copy
5453
6258
  device->transfer_queue.copyFrom(device->compute_queue);
@@ -5507,8 +6312,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
5507
6312
  bool fp16_compute = false;
5508
6313
  bool coopmat_support = false;
5509
6314
  bool coopmat2_support = false;
6315
+ bool coopmat2_decode_vector_support = false;
5510
6316
  bool integer_dot_product = false;
5511
6317
  bool bfloat16_support = false;
6318
+ bool dot2_f16_support = false;
5512
6319
 
5513
6320
  for (auto properties : ext_props) {
5514
6321
  if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -5525,6 +6332,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
5525
6332
  !getenv("GGML_VK_DISABLE_COOPMAT2")) {
5526
6333
  coopmat2_support = true;
5527
6334
  #endif
6335
+ } else if (strcmp(VK_NV_COOPERATIVE_MATRIX_DECODE_VECTOR_EXTENSION_NAME, properties.extensionName) == 0 &&
6336
+ !getenv("GGML_VK_DISABLE_COOPMAT2_DECODE_VECTOR")) {
6337
+ coopmat2_decode_vector_support = true;
5528
6338
  #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
5529
6339
  } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
5530
6340
  !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
@@ -5535,6 +6345,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
5535
6345
  !getenv("GGML_VK_DISABLE_BFLOAT16")) {
5536
6346
  bfloat16_support = true;
5537
6347
  #endif
6348
+ } else if (strcmp("VK_VALVE_shader_mixed_float_dot_product", properties.extensionName) == 0 &&
6349
+ !getenv("GGML_VK_DISABLE_DOT2")) {
6350
+ dot2_f16_support = true;
5538
6351
  }
5539
6352
  }
5540
6353
 
@@ -5609,6 +6422,29 @@ static void ggml_vk_print_gpu_info(size_t idx) {
5609
6422
  }
5610
6423
  #endif
5611
6424
 
6425
+ #if defined(VK_NV_cooperative_matrix2)
6426
+ VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
6427
+ coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
6428
+ if (coopmat2_support) {
6429
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
6430
+ last_struct = (VkBaseOutStructure *)&coopmat2_features;
6431
+ }
6432
+ #endif
6433
+
6434
+ VkPhysicalDeviceCooperativeMatrixDecodeVectorFeaturesNV coopmat2_decode_vector_features {};
6435
+ coopmat2_decode_vector_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_DECODE_VECTOR_FEATURES_NV;
6436
+ if (coopmat2_decode_vector_support) {
6437
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
6438
+ last_struct = (VkBaseOutStructure *)&coopmat2_decode_vector_features;
6439
+ }
6440
+
6441
+ VkPhysicalDeviceShaderMixedFloatDotProductFeaturesVALVE dot2_features {};
6442
+ dot2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_MIXED_FLOAT_DOT_PRODUCT_FEATURES_VALVE;
6443
+ if (dot2_f16_support) {
6444
+ last_struct->pNext = (VkBaseOutStructure *)&dot2_features;
6445
+ last_struct = (VkBaseOutStructure *)&dot2_features;
6446
+ }
6447
+
5612
6448
  vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
5613
6449
 
5614
6450
  fp16 = fp16 && vk12_features.shaderFloat16;
@@ -5633,11 +6469,34 @@ static void ggml_vk_print_gpu_info(size_t idx) {
5633
6469
  #endif
5634
6470
  && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
5635
6471
 
5636
- std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
6472
+ #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
6473
+ coopmat2_support = coopmat2_support &&
6474
+ coopmat2_features.cooperativeMatrixWorkgroupScope &&
6475
+ coopmat2_features.cooperativeMatrixFlexibleDimensions &&
6476
+ coopmat2_features.cooperativeMatrixReductions &&
6477
+ coopmat2_features.cooperativeMatrixConversions &&
6478
+ coopmat2_features.cooperativeMatrixPerElementOperations &&
6479
+ coopmat2_features.cooperativeMatrixTensorAddressing &&
6480
+ coopmat2_features.cooperativeMatrixBlockLoads;
6481
+ #else
6482
+ coopmat2_support = false;
6483
+ #endif
6484
+
6485
+ coopmat2_decode_vector_support = coopmat2_decode_vector_support && coopmat2_decode_vector_features.cooperativeMatrixDecodeVector;
6486
+ #if !defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
6487
+ coopmat2_decode_vector_support = false;
6488
+ #endif
6489
+
6490
+ std::string matrix_cores = coopmat2_support ? (coopmat2_decode_vector_support ? "NV_coopmat2v" : "NV_coopmat2")
6491
+ : coopmat_support ? "KHR_coopmat"
6492
+ : "none";
6493
+
6494
+ bool dot2_f16 = dot2_f16_support && dot2_features.shaderMixedFloatDotProductFloat16AccFloat32;
6495
+ const char *fp16_str = fp16 ? (dot2_f16 ? "dot2" : "1") : "0";
5637
6496
 
5638
6497
  std::string device_name = props2.properties.deviceName.data();
5639
- GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
5640
- idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size,
6498
+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %s | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
6499
+ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16_str, bf16, subgroup_size,
5641
6500
  props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
5642
6501
 
5643
6502
  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
@@ -5953,6 +6812,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
5953
6812
  VK_LOG_DEBUG("ggml_vk_get_to_fp16()");
5954
6813
  switch (type) {
5955
6814
  case GGML_TYPE_F32:
6815
+ case GGML_TYPE_Q1_0:
5956
6816
  case GGML_TYPE_Q4_0:
5957
6817
  case GGML_TYPE_Q4_1:
5958
6818
  case GGML_TYPE_Q5_0:
@@ -5973,6 +6833,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
5973
6833
  case GGML_TYPE_IQ4_XS:
5974
6834
  case GGML_TYPE_IQ4_NL:
5975
6835
  case GGML_TYPE_MXFP4:
6836
+ case GGML_TYPE_NVFP4:
5976
6837
  break;
5977
6838
  default:
5978
6839
  return nullptr;
@@ -6024,6 +6885,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
6024
6885
  }
6025
6886
 
6026
6887
  switch (src0_type) {
6888
+ case GGML_TYPE_Q1_0:
6027
6889
  case GGML_TYPE_Q4_0:
6028
6890
  case GGML_TYPE_Q4_1:
6029
6891
  case GGML_TYPE_Q5_0:
@@ -6044,6 +6906,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
6044
6906
  case GGML_TYPE_IQ4_XS:
6045
6907
  case GGML_TYPE_IQ4_NL:
6046
6908
  case GGML_TYPE_MXFP4:
6909
+ case GGML_TYPE_NVFP4:
6047
6910
  break;
6048
6911
  default:
6049
6912
  return nullptr;
@@ -6089,6 +6952,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
6089
6952
  case GGML_TYPE_F32:
6090
6953
  case GGML_TYPE_F16:
6091
6954
  case GGML_TYPE_BF16:
6955
+ case GGML_TYPE_Q1_0:
6092
6956
  case GGML_TYPE_Q4_0:
6093
6957
  case GGML_TYPE_Q4_1:
6094
6958
  case GGML_TYPE_Q5_0:
@@ -6109,6 +6973,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
6109
6973
  case GGML_TYPE_IQ4_XS:
6110
6974
  case GGML_TYPE_IQ4_NL:
6111
6975
  case GGML_TYPE_MXFP4:
6976
+ case GGML_TYPE_NVFP4:
6112
6977
  break;
6113
6978
  default:
6114
6979
  return nullptr;
@@ -6179,6 +7044,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
6179
7044
  GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
6180
7045
 
6181
7046
  switch (src0_type) {
7047
+ case GGML_TYPE_Q1_0:
6182
7048
  case GGML_TYPE_Q4_0:
6183
7049
  case GGML_TYPE_Q4_1:
6184
7050
  case GGML_TYPE_Q5_0:
@@ -6199,6 +7065,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
6199
7065
  case GGML_TYPE_IQ4_XS:
6200
7066
  case GGML_TYPE_IQ4_NL:
6201
7067
  case GGML_TYPE_MXFP4:
7068
+ case GGML_TYPE_NVFP4:
6202
7069
  break;
6203
7070
  default:
6204
7071
  return nullptr;
@@ -6247,6 +7114,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
6247
7114
  case GGML_TYPE_F32:
6248
7115
  case GGML_TYPE_F16:
6249
7116
  case GGML_TYPE_BF16:
7117
+ case GGML_TYPE_Q1_0:
6250
7118
  case GGML_TYPE_Q4_0:
6251
7119
  case GGML_TYPE_Q4_1:
6252
7120
  case GGML_TYPE_Q5_0:
@@ -6267,6 +7135,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
6267
7135
  case GGML_TYPE_IQ4_XS:
6268
7136
  case GGML_TYPE_IQ4_NL:
6269
7137
  case GGML_TYPE_MXFP4:
7138
+ case GGML_TYPE_NVFP4:
6270
7139
  break;
6271
7140
  default:
6272
7141
  return nullptr;
@@ -6313,7 +7182,7 @@ static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
6313
7182
  return nullptr;
6314
7183
  }
6315
7184
 
6316
- std::lock_guard<std::recursive_mutex> guard(device->mutex);
7185
+ std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex);
6317
7186
  device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf));
6318
7187
 
6319
7188
  return buf->ptr;
@@ -6324,7 +7193,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
6324
7193
  return;
6325
7194
  }
6326
7195
  VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")");
6327
- std::lock_guard<std::recursive_mutex> guard(device->mutex);
7196
+ std::lock_guard<std::shared_mutex> guard(device->pinned_memory_mutex);
6328
7197
 
6329
7198
  vk_buffer buf;
6330
7199
  size_t index;
@@ -6348,7 +7217,7 @@ static void ggml_vk_host_free(vk_device& device, void* ptr) {
6348
7217
  }
6349
7218
 
6350
7219
  static void ggml_vk_host_get(const vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) {
6351
- std::lock_guard<std::recursive_mutex> guard(device->mutex);
7220
+ std::shared_lock<std::shared_mutex> guard(device->pinned_memory_mutex);
6352
7221
  buf = nullptr;
6353
7222
  buf_offset = 0;
6354
7223
  for (size_t i = 0; i < device->pinned_memory.size(); i++) {
@@ -6392,6 +7261,7 @@ static vk_subbuffer ggml_vk_tensor_subbuffer(
6392
7261
  static vk_command_buffer* ggml_vk_get_or_create_cmd_buffer(vk_device& device, vk_command_pool& pool) {
6393
7262
  for (auto& cmd_buffer : pool.cmd_buffers) {
6394
7263
  if (!cmd_buffer.in_use) {
7264
+ cmd_buffer.use_counter++;
6395
7265
  cmd_buffer.in_use = true;
6396
7266
  return &cmd_buffer;
6397
7267
  }
@@ -6468,13 +7338,6 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context&
6468
7338
  subctx->s->buffer->buf.dispatch(wg0, wg1, wg2);
6469
7339
  }
6470
7340
 
6471
- static void ggml_vk_end_submission(vk_submission& s, std::vector<vk_semaphore> wait_semaphores, std::vector<vk_semaphore> signal_semaphores) {
6472
- s.buffer->buf.end();
6473
-
6474
- s.wait_semaphores = std::move(wait_semaphores);
6475
- s.signal_semaphores = std::move(signal_semaphores);
6476
- }
6477
-
6478
7341
  static void ggml_vk_ctx_end(vk_context& ctx) {
6479
7342
  VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")");
6480
7343
  if (ctx->s == nullptr) {
@@ -6496,14 +7359,15 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) {
6496
7359
  }
6497
7360
 
6498
7361
  static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) {
7362
+ vk_context result;
6499
7363
  if (!ctx->compute_ctx.expired()) {
6500
- return ctx->compute_ctx.lock();
6501
- }
6502
-
6503
- vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
7364
+ result = ctx->compute_ctx.lock();
7365
+ } else {
7366
+ result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool);
6504
7367
 
6505
- ctx->compute_ctx = result;
6506
- ggml_vk_ctx_begin(ctx->device, result);
7368
+ ctx->compute_ctx = result;
7369
+ ggml_vk_ctx_begin(ctx->device, result);
7370
+ }
6507
7371
 
6508
7372
  if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) {
6509
7373
  result->s->wait_semaphores.push_back(ctx->transfer_semaphore);
@@ -6626,7 +7490,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
6626
7490
  const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1;
6627
7491
  const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1;
6628
7492
  for (uint64_t i0 = 0; i0 < ne0; i0++) {
6629
- slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 });
7493
+ slices.push_back({ s_off + i0*nb0, d_off + i0*dstnb0, dstnb0 });
6630
7494
  }
6631
7495
  }
6632
7496
  }
@@ -6674,7 +7538,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
6674
7538
  }
6675
7539
  }
6676
7540
 
6677
- static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height, bool sync_staging = false) {
7541
+ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
6678
7542
  VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")");
6679
7543
  // Check if src is pinned memory
6680
7544
  vk_buffer buf = nullptr;
@@ -6684,7 +7548,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
6684
7548
  if (buf != nullptr) {
6685
7549
  // Memory is pinned, use as staging buffer
6686
7550
  std::vector<vk::BufferCopy> slices(1);
6687
- if (width == spitch) {
7551
+ if (width == spitch && width == dpitch) {
6688
7552
  // Only do single write if stride is equal
6689
7553
  slices[0].srcOffset = buf_offset;
6690
7554
  slices[0].dstOffset = offset;
@@ -6693,7 +7557,7 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
6693
7557
  slices.resize(height);
6694
7558
  for (size_t i = 0; i < height; i++) {
6695
7559
  slices[i].srcOffset = buf_offset + i * spitch;
6696
- slices[i].dstOffset = offset + i * width;
7560
+ slices[i].dstOffset = offset + i * dpitch;
6697
7561
  slices[i].size = width;
6698
7562
  }
6699
7563
  }
@@ -6710,21 +7574,30 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
6710
7574
  }
6711
7575
 
6712
7576
  // Staging buffer required
6713
- const size_t copy_size = width*height;
6714
- ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size);
7577
+ const size_t staging_size = width * height;
7578
+ ggml_vk_ensure_sync_staging_buffer(dst->device, staging_size);
6715
7579
 
6716
7580
  vk_buffer& staging_buffer = dst->device->sync_staging;
6717
7581
 
6718
- VkBufferCopy buf_copy = {
6719
- 0,
6720
- offset,
6721
- copy_size};
7582
+ std::vector<vk::BufferCopy> slices(1);
7583
+ if (width == dpitch) {
7584
+ slices[0].srcOffset = 0;
7585
+ slices[0].dstOffset = offset;
7586
+ slices[0].size = staging_size;
7587
+ } else {
7588
+ slices.resize(height);
7589
+ for (size_t i = 0; i < height; i++) {
7590
+ slices[i].srcOffset = i * width;
7591
+ slices[i].dstOffset = offset + i * dpitch;
7592
+ slices[i].size = width;
7593
+ }
7594
+ }
6722
7595
 
6723
7596
  ggml_vk_sync_buffers(nullptr, subctx);
6724
- vkCmdCopyBuffer(subctx->s->buffer->buf, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
7597
+ subctx->s->buffer->buf.copyBuffer((VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, slices);
6725
7598
 
6726
7599
  if (width == spitch) {
6727
- deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
7600
+ deferred_memcpy((uint8_t *)staging_buffer->ptr, src, staging_size, &subctx->in_memcpys);
6728
7601
  } else {
6729
7602
  for (size_t i = 0; i < height; i++) {
6730
7603
  deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys);
@@ -6735,24 +7608,28 @@ static bool ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
6735
7608
 
6736
7609
  static bool ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) {
6737
7610
  VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")");
6738
- return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging);
7611
+ return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, size, 1, sync_staging);
6739
7612
  }
6740
7613
 
6741
- static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) {
7614
+ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t dpitch, size_t width, size_t height) {
6742
7615
  VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")");
6743
7616
  // Buffer is already mapped
6744
7617
  if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
6745
7618
  GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
6746
7619
 
6747
- for (size_t i = 0; i < height; i++) {
6748
- memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width);
7620
+ if (width == spitch && width == dpitch) {
7621
+ memcpy((uint8_t *)dst->ptr + offset, src, width * height);
7622
+ } else {
7623
+ for (size_t i = 0; i < height; i++) {
7624
+ memcpy((uint8_t *)dst->ptr + offset + i * dpitch, (const uint8_t *) src + i * spitch, width);
7625
+ }
6749
7626
  }
6750
7627
  } else {
6751
7628
  std::lock_guard<std::recursive_mutex> guard(dst->device->mutex);
6752
7629
 
6753
7630
  vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool);
6754
7631
  ggml_vk_ctx_begin(dst->device, subctx);
6755
- bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true);
7632
+ bool ret = ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, dpitch, width, height, true);
6756
7633
  GGML_ASSERT(ret);
6757
7634
  ggml_vk_ctx_end(subctx);
6758
7635
 
@@ -6773,7 +7650,7 @@ static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void *
6773
7650
 
6774
7651
  static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) {
6775
7652
  VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")");
6776
- ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1);
7653
+ ggml_vk_buffer_write_2d(dst, offset, src, size, size, size, 1);
6777
7654
  }
6778
7655
 
6779
7656
  static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) {
@@ -6819,15 +7696,35 @@ static bool ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
6819
7696
  }
6820
7697
 
6821
7698
  // Fall back to staging buffer
6822
- const size_t copy_size = dpitch * height;
6823
- ggml_vk_ensure_sync_staging_buffer(src->device, copy_size);
7699
+ const size_t staging_size = width * height;
7700
+ ggml_vk_ensure_sync_staging_buffer(src->device, staging_size);
6824
7701
 
6825
7702
  vk_buffer& staging_buffer = src->device->sync_staging;
6826
7703
 
7704
+ std::vector<vk::BufferCopy> staging_slices(1);
7705
+ if (width == spitch) {
7706
+ staging_slices[0].srcOffset = offset;
7707
+ staging_slices[0].dstOffset = 0;
7708
+ staging_slices[0].size = staging_size;
7709
+ } else {
7710
+ staging_slices.resize(height);
7711
+ for (size_t i = 0; i < height; i++) {
7712
+ staging_slices[i].srcOffset = offset + i * spitch;
7713
+ staging_slices[i].dstOffset = i * width;
7714
+ staging_slices[i].size = width;
7715
+ }
7716
+ }
7717
+
6827
7718
  ggml_vk_sync_buffers(nullptr, subctx);
6828
- subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, slices);
7719
+ subctx->s->buffer->buf.copyBuffer(src->buffer, staging_buffer->buffer, staging_slices);
6829
7720
 
6830
- deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys);
7721
+ if (width == dpitch) {
7722
+ deferred_memcpy(dst, staging_buffer->ptr, staging_size, &subctx->out_memcpys);
7723
+ } else {
7724
+ for (size_t i = 0; i < height; i++) {
7725
+ deferred_memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) staging_buffer->ptr + i * width, width, &subctx->out_memcpys);
7726
+ }
7727
+ }
6831
7728
  return true;
6832
7729
  }
6833
7730
 
@@ -6835,8 +7732,8 @@ static bool ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t
6835
7732
  return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging);
6836
7733
  }
6837
7734
 
6838
- static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
6839
- VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
7735
+ static void ggml_vk_buffer_read_2d(vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height) {
7736
+ VK_LOG_DEBUG("ggml_vk_buffer_read_2d(" << src->buffer << ", " << offset << ", " << width << ", " << height << ")");
6840
7737
 
6841
7738
  // If the device is not an UMA device the memory is host-accessible through rebar. While writing
6842
7739
  // through PCIe is sufficient fast reading back data from PCIe is slower than going through
@@ -6844,18 +7741,24 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
6844
7741
  if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) {
6845
7742
  GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent);
6846
7743
 
6847
- memcpy(dst, (uint8_t *) src->ptr + offset, size);
7744
+ if (width == spitch && width == dpitch) {
7745
+ memcpy(dst, (const uint8_t *) src->ptr + offset, width * height);
7746
+ } else {
7747
+ for (size_t i = 0; i < height; i++) {
7748
+ memcpy((uint8_t *) dst + i * dpitch, (const uint8_t *) src->ptr + offset + i * spitch, width);
7749
+ }
7750
+ }
6848
7751
  } else {
6849
7752
  std::lock_guard<std::recursive_mutex> guard(src->device->mutex);
6850
7753
 
6851
7754
  vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool);
6852
7755
  ggml_vk_ctx_begin(src->device, subctx);
6853
- bool ret = ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true);
7756
+ bool ret = ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, spitch, dpitch, width, height, true);
6854
7757
  GGML_ASSERT(ret);
6855
7758
  ggml_vk_ctx_end(subctx);
6856
7759
 
6857
7760
  ggml_vk_submit(subctx, src->device->fence);
6858
- VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences");
7761
+ VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read_2d waitForFences");
6859
7762
  src->device->device.resetFences({ src->device->fence });
6860
7763
  ggml_vk_queue_command_pools_cleanup(src->device);
6861
7764
 
@@ -6865,6 +7768,11 @@ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_
6865
7768
  }
6866
7769
  }
6867
7770
 
7771
+ static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) {
7772
+ VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")");
7773
+ ggml_vk_buffer_read_2d(src, offset, dst, size, size, size, 1);
7774
+ }
7775
+
6868
7776
  static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
6869
7777
  VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")");
6870
7778
  // Make sure both buffers are on same device
@@ -6896,7 +7804,7 @@ static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& sr
6896
7804
  // Copy to src staging buffer
6897
7805
  ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size);
6898
7806
  // Copy to dst buffer
6899
- ggml_vk_buffer_write_2d(dst, dst_offset, src->device->sync_staging->ptr, 0, size, 1);
7807
+ ggml_vk_buffer_write(dst, dst_offset, src->device->sync_staging->ptr, size);
6900
7808
  }
6901
7809
  }
6902
7810
 
@@ -6979,6 +7887,13 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m,
6979
7887
  static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
6980
7888
  VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
6981
7889
 
7890
+ // The q8_1 (integer dot) mmq path uses a different shader with its own
7891
+ // shared-memory layout, so use the int-specific availability flags.
7892
+ const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1);
7893
+ const bool mm_l = is_q8_1 ? ctx->device->mul_mat_l_int[src0_type] : ctx->device->mul_mat_l[src0_type];
7894
+ const bool mm_m = is_q8_1 ? ctx->device->mul_mat_m_int[src0_type] : ctx->device->mul_mat_m[src0_type];
7895
+ const bool mm_s = is_q8_1 ? ctx->device->mul_mat_s_int[src0_type] : ctx->device->mul_mat_s[src0_type];
7896
+
6982
7897
  if (ctx->device->coopmat2) {
6983
7898
  const uint32_t shader_core_count = ctx->device->shader_core_count;
6984
7899
  const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]);
@@ -6995,26 +7910,24 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
6995
7910
  // split_k==3 with large tiles likely better than medium tiles with no split_k.
6996
7911
  (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2);
6997
7912
 
6998
- if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
7913
+ if ((mm_l && (n > crossover_large && prefer_large)) || (!mm_m && !mm_s)) {
6999
7914
  return aligned ? mmp->a_l : mmp->l;
7000
7915
  }
7001
7916
  // Use medium shader when the N dimension is greater than the small shader's tile size
7002
7917
  uint32_t crossover_medium = mmp->s->wg_denoms[1];
7003
- if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
7918
+ if ((mm_m && (n > crossover_medium)) || !mm_s) {
7004
7919
  return aligned ? mmp->a_m : mmp->m;
7005
7920
  }
7006
7921
  return aligned ? mmp->a_s : mmp->s;
7007
7922
  }
7008
7923
 
7009
- if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) {
7924
+ if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) {
7010
7925
  return aligned ? mmp->a_s : mmp->s;
7011
7926
  }
7012
- if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) {
7927
+ if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) {
7013
7928
  return aligned ? mmp->a_m : mmp->m;
7014
7929
  }
7015
7930
  return aligned ? mmp->a_l : mmp->l;
7016
-
7017
- GGML_UNUSED(src1_type);
7018
7931
  }
7019
7932
 
7020
7933
  static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
@@ -7071,35 +7984,42 @@ static void ggml_vk_matmul(
7071
7984
  ctx->prealloc_split_k_need_sync = true;
7072
7985
  }
7073
7986
 
7074
- static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
7075
- VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
7987
+ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
7988
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
7989
+
7990
+ // The q8_1 (integer dot) mmq path uses a different shader with its own
7991
+ // shared-memory layout, so use the int-specific availability flags.
7992
+ const bool is_q8_1 = (src1_type == GGML_TYPE_Q8_1);
7993
+ const bool mm_l = is_q8_1 ? ctx->device->mul_mat_id_l_int[src0_type] : ctx->device->mul_mat_id_l[src0_type];
7994
+ const bool mm_m = is_q8_1 ? ctx->device->mul_mat_id_m_int[src0_type] : ctx->device->mul_mat_id_m[src0_type];
7995
+ const bool mm_s = is_q8_1 ? ctx->device->mul_mat_id_s_int[src0_type] : ctx->device->mul_mat_id_s[src0_type];
7076
7996
 
7077
7997
  if (ctx->device->coopmat2) {
7078
7998
  // Use large shader when the N dimension is greater than the medium shader's tile size
7079
7999
  uint32_t crossover_large = mmp->m->wg_denoms[1];
7080
- if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
8000
+ if ((mm_l && (n > crossover_large)) || (!mm_m && !mm_s)) {
7081
8001
  return aligned ? mmp->a_l : mmp->l;
7082
8002
  }
7083
8003
  // Use medium shader when the N dimension is greater than the small shader's tile size
7084
8004
  uint32_t crossover_medium = mmp->s->wg_denoms[1];
7085
- if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
8005
+ if ((mm_m && (n > crossover_medium)) || !mm_s) {
7086
8006
  return aligned ? mmp->a_m : mmp->m;
7087
8007
  }
7088
8008
  return aligned ? mmp->a_s : mmp->s;
7089
8009
  }
7090
8010
 
7091
- if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) {
8011
+ if ((mm_s && (m <= 32 || n <= 32)) || (!mm_m && !mm_l)) {
7092
8012
  return aligned ? mmp->a_s : mmp->s;
7093
8013
  }
7094
- if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) {
8014
+ if ((mm_m && (m <= 64 || n <= 64)) || !mm_l) {
7095
8015
  return aligned ? mmp->a_m : mmp->m;
7096
8016
  }
7097
8017
  return aligned ? mmp->a_l : mmp->l;
7098
8018
  }
7099
8019
 
7100
- static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
7101
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
7102
- return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align;
8020
+ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
8021
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
8022
+ return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
7103
8023
  }
7104
8024
 
7105
8025
  static void ggml_vk_matmul_id(
@@ -7176,6 +8096,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
7176
8096
  return ctx->device->pipeline_cpy_f32_bf16;
7177
8097
  }
7178
8098
  }
8099
+ if (src->type == GGML_TYPE_BF16 && to == GGML_TYPE_F32) {
8100
+ if (contig) {
8101
+ return ctx->device->pipeline_contig_cpy_bf16_f32;
8102
+ } else {
8103
+ return ctx->device->pipeline_cpy_bf16_f32;
8104
+ }
8105
+ }
7179
8106
  if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {
7180
8107
  if (contig) {
7181
8108
  return ctx->device->pipeline_contig_cpy_f32_i32;
@@ -7192,6 +8119,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
7192
8119
  }
7193
8120
  if (src->type == GGML_TYPE_F32) {
7194
8121
  switch (to) {
8122
+ case GGML_TYPE_Q1_0:
7195
8123
  case GGML_TYPE_Q4_0:
7196
8124
  case GGML_TYPE_Q4_1:
7197
8125
  case GGML_TYPE_Q5_0:
@@ -7206,6 +8134,7 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
7206
8134
 
7207
8135
  if (to == GGML_TYPE_F32) {
7208
8136
  switch (src->type) {
8137
+ case GGML_TYPE_Q1_0:
7209
8138
  case GGML_TYPE_Q4_0:
7210
8139
  case GGML_TYPE_Q4_1:
7211
8140
  case GGML_TYPE_Q5_0:
@@ -7272,6 +8201,40 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
7272
8201
  ggml_vk_sync_buffers(ctx, subctx);
7273
8202
  }
7274
8203
 
8204
+ // Copy/convert tensor into a caller-defined dense layout. Destination strides
8205
+ // are in output elements, not bytes.
8206
+ static void ggml_vk_cpy_to_strided(
8207
+ ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor,
8208
+ const vk_subbuffer & in, const vk_subbuffer & out,
8209
+ uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13) {
8210
+ VK_LOG_DEBUG("ggml_vk_cpy_to_strided((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), ";
8211
+ std::cerr << "dst_nb=(" << nb10 << ", " << nb11 << ", " << nb12 << ", " << nb13 << "), buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")");
8212
+ const int tensor_type_size = ggml_type_size(tensor->type);
8213
+
8214
+ const uint32_t ne = ggml_nelements(tensor);
8215
+ std::array<uint32_t, 3> elements;
8216
+
8217
+ if (ne > 262144) {
8218
+ elements = { 512, 512, CEIL_DIV(ne, 262144) };
8219
+ } else if (ne > 512) {
8220
+ elements = { 512, CEIL_DIV(ne, 512), 1 };
8221
+ } else {
8222
+ elements = { ne, 1, 1 };
8223
+ }
8224
+
8225
+ vk_op_unary_push_constants pc = {
8226
+ (uint32_t)ne,
8227
+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
8228
+ (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], nb10, nb11, nb12, nb13,
8229
+ 0,
8230
+ 0.0f, 0.0f,
8231
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
8232
+ };
8233
+ init_pushconst_fastdiv(pc);
8234
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements);
8235
+ ggml_vk_sync_buffers(ctx, subctx);
8236
+ }
8237
+
7275
8238
  static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
7276
8239
  switch(type) {
7277
8240
  case GGML_TYPE_Q8_1:
@@ -7393,10 +8356,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
7393
8356
  // Not implemented
7394
8357
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
7395
8358
 
7396
- const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
8359
+ const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type);
8360
+
8361
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, effective_src1_type));
7397
8362
  const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
7398
8363
 
7399
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
8364
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type);
7400
8365
 
7401
8366
  if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
7402
8367
  pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
@@ -7527,24 +8492,28 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
7527
8492
  }
7528
8493
  if (y_non_contig) {
7529
8494
  if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
7530
- ctx->prealloc_y_last_tensor_used != src1) {
8495
+ ctx->prealloc_y_last_tensor_used != src1 ||
8496
+ ctx->prealloc_y_last_decode_vector_staging) {
7531
8497
  if (ctx->prealloc_y_need_sync) {
7532
8498
  ggml_vk_sync_buffers(ctx, subctx);
7533
8499
  }
7534
8500
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
7535
8501
  ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
7536
8502
  ctx->prealloc_y_last_tensor_used = src1;
8503
+ ctx->prealloc_y_last_decode_vector_staging = false;
7537
8504
  }
7538
8505
  }
7539
8506
  if (quantize_y) {
7540
8507
  if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
7541
- ctx->prealloc_y_last_tensor_used != src1) {
8508
+ ctx->prealloc_y_last_tensor_used != src1 ||
8509
+ ctx->prealloc_y_last_decode_vector_staging) {
7542
8510
  if (ctx->prealloc_y_need_sync) {
7543
8511
  ggml_vk_sync_buffers(ctx, subctx);
7544
8512
  }
7545
8513
  ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
7546
8514
  ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
7547
8515
  ctx->prealloc_y_last_tensor_used = src1;
8516
+ ctx->prealloc_y_last_decode_vector_staging = false;
7548
8517
  }
7549
8518
  }
7550
8519
 
@@ -7585,8 +8554,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
7585
8554
  return false;
7586
8555
  }
7587
8556
 
7588
- // General performance issue with q3_k and q6_k due to 2-byte alignment
7589
- if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
8557
+ // q6_k only has 2-byte alignment which makes it somewhat problematic,
8558
+ // using MMVQ is only a win on Intel.
8559
+ bool mmvq_q6 = device->vendor_id == VK_VENDOR_ID_INTEL;
8560
+ if (src0_type == GGML_TYPE_Q6_K && !mmvq_q6) {
7590
8561
  return false;
7591
8562
  }
7592
8563
 
@@ -7598,7 +8569,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
7598
8569
  // Quantization overhead is not worth it for small k
7599
8570
  switch (device->vendor_id) {
7600
8571
  case VK_VENDOR_ID_NVIDIA:
7601
- if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
8572
+ if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
7602
8573
  return true;
7603
8574
  }
7604
8575
 
@@ -7625,20 +8596,21 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
7625
8596
  return true;
7626
8597
  }
7627
8598
  case VK_VENDOR_ID_INTEL:
7628
- if (k < 2048) {
7629
- return false;
8599
+ if (device->architecture == vk_device_architecture::INTEL_XE2) {
8600
+ if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
8601
+ return true;
8602
+ }
7630
8603
  }
7631
8604
 
7632
8605
  if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) {
7633
- // Intel Windows proprietary driver tuning
7634
- switch (src0_type) {
7635
- case GGML_TYPE_MXFP4:
7636
- case GGML_TYPE_Q4_K:
7637
- case GGML_TYPE_Q5_K:
7638
- return false;
7639
- default:
7640
- return true;
7641
- }
8606
+ // Intel Windows proprietary driver MMVQ performance for !Q2/Q3/Q6 is worse than fp16,
8607
+ // see https://github.com/ggml-org/llama.cpp/issues/17628 and
8608
+ // https://github.com/ggml-org/llama.cpp/pull/23056
8609
+ return false;
8610
+ }
8611
+
8612
+ if (k < 2048) {
8613
+ return false;
7642
8614
  }
7643
8615
 
7644
8616
  switch (src0_type) {
@@ -7799,24 +8771,28 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
7799
8771
  if (y_non_contig) {
7800
8772
  GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
7801
8773
  if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
7802
- ctx->prealloc_y_last_tensor_used != src1) {
8774
+ ctx->prealloc_y_last_tensor_used != src1 ||
8775
+ ctx->prealloc_y_last_decode_vector_staging) {
7803
8776
  if (ctx->prealloc_y_need_sync) {
7804
8777
  ggml_vk_sync_buffers(ctx, subctx);
7805
8778
  }
7806
8779
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
7807
8780
  ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
7808
8781
  ctx->prealloc_y_last_tensor_used = src1;
8782
+ ctx->prealloc_y_last_decode_vector_staging = false;
7809
8783
  }
7810
8784
  }
7811
8785
  if (quantize_y) {
7812
8786
  if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
7813
- ctx->prealloc_y_last_tensor_used != src1) {
8787
+ ctx->prealloc_y_last_tensor_used != src1 ||
8788
+ ctx->prealloc_y_last_decode_vector_staging) {
7814
8789
  if (ctx->prealloc_y_need_sync) {
7815
8790
  ggml_vk_sync_buffers(ctx, subctx);
7816
8791
  }
7817
8792
  ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
7818
8793
  ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
7819
8794
  ctx->prealloc_y_last_tensor_used = src1;
8795
+ ctx->prealloc_y_last_decode_vector_staging = false;
7820
8796
  }
7821
8797
  }
7822
8798
 
@@ -8060,25 +9036,87 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
8060
9036
  fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1;
8061
9037
  }
8062
9038
 
8063
- // compute
8064
- vk_mat_vec_nc_push_constants pc = {
8065
- (uint32_t)ne00, (uint32_t)ne01,
8066
- row_stride_x, channel_stride_x, channel_stride_y,
8067
- (uint32_t)(ne12 / ne02), (uint32_t)ne12,
8068
- 0, 0,
8069
- nb03, nb13, nb23, fusion_flags
8070
- };
9039
+ // compute
9040
+ vk_mat_vec_nc_push_constants pc = {
9041
+ (uint32_t)ne00, (uint32_t)ne01,
9042
+ row_stride_x, channel_stride_x, channel_stride_y,
9043
+ (uint32_t)(ne12 / ne02), (uint32_t)ne12,
9044
+ 0, 0,
9045
+ nb03, nb13, nb23, fusion_flags
9046
+ };
9047
+
9048
+ init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
9049
+
9050
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
9051
+ {
9052
+ d_Qx,
9053
+ d_Qy,
9054
+ d_D,
9055
+ d_F0,
9056
+ d_F1,
9057
+ }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
9058
+ }
9059
+
9060
+ static int ggml_vk_fwht_pipeline_idx(int64_t n) {
9061
+ switch (n) {
9062
+ case 64: return 0;
9063
+ case 128: return 1;
9064
+ case 256: return 2;
9065
+ case 512: return 3;
9066
+ default: return -1;
9067
+ }
9068
+ }
9069
+
9070
+ static bool ggml_vk_can_use_fwht(const ggml_backend_vk_context * ctx, const ggml_tensor * src1, const ggml_tensor * dst) {
9071
+ if (ctx->num_additional_fused_ops != 0) {
9072
+ return false;
9073
+ }
9074
+
9075
+ if (ggml_get_op_params_i32(dst, 1) != GGML_HINT_SRC0_IS_HADAMARD) {
9076
+ return false;
9077
+ }
9078
+
9079
+ const int idx = ggml_vk_fwht_pipeline_idx(src1->ne[0]);
9080
+ if (idx < 0 || ctx->device->pipeline_fwht_f32[idx] == nullptr) {
9081
+ return false;
9082
+ }
9083
+
9084
+ if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
9085
+ return false;
9086
+ }
9087
+
9088
+ if (!ggml_is_contiguous(src1)) {
9089
+ return false;
9090
+ }
9091
+ GGML_ASSERT(ggml_is_contiguous(dst));
9092
+
9093
+ return true;
9094
+ }
9095
+
9096
+ static void ggml_vk_fwht(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src, ggml_tensor * dst) {
9097
+ const int idx = ggml_vk_fwht_pipeline_idx(src->ne[0]);
9098
+ vk_pipeline pipeline = ctx->device->pipeline_fwht_f32[idx];
9099
+
9100
+ const uint32_t rows_per_workgroup = 4;
9101
+ const uint32_t n_rows = (uint32_t)ggml_nrows(src);
9102
+ const uint32_t max_workgroups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0];
9103
+
9104
+ const uint32_t total_workgroups = CEIL_DIV(n_rows, rows_per_workgroup);
9105
+ const uint32_t workgroups_x = std::min(total_workgroups, max_workgroups_x);
9106
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
9107
+
9108
+ const vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src, true);
9109
+ const vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
8071
9110
 
8072
- init_pushconst_tensor_offsets(ctx, pc, src0, src1, nullptr, nullptr, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]);
9111
+ vk_op_fwht_push_constants pc = {
9112
+ n_rows,
9113
+ 0,
9114
+ 0,
9115
+ 1.0f / std::sqrt((float)src->ne[0]),
9116
+ };
9117
+ init_pushconst_tensor_offsets(ctx, pc, src, nullptr, nullptr, nullptr, dst);
8073
9118
 
8074
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
8075
- {
8076
- d_Qx,
8077
- d_Qy,
8078
- d_D,
8079
- d_F0,
8080
- d_F1,
8081
- }, pc, { (uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 });
9119
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { src_buf, dst_buf }, pc, { workgroups_x, 1, 1 });
8082
9120
  }
8083
9121
 
8084
9122
  static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@@ -8114,6 +9152,8 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
8114
9152
 
8115
9153
  m_offset += cur_M_size;
8116
9154
  }
9155
+ } else if (ggml_vk_can_use_fwht(ctx, src1, dst)) {
9156
+ ggml_vk_fwht(ctx, subctx, src1, dst);
8117
9157
  } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
8118
9158
  // detect 0213 permutation, and batch size of 1
8119
9159
  src0->nb[0] <= src0->nb[2] &&
@@ -8203,12 +9243,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
8203
9243
  // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
8204
9244
  const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
8205
9245
  !ggml_vk_dim01_contiguous(src0);
8206
- const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
9246
+ // If src0 is BF16, try to use a BF16 x BF16 multiply
9247
+ ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
9248
+ #if defined(GGML_VULKAN_COOPMAT2_DECODE_VECTOR_GLSLC_SUPPORT)
9249
+ // B must already be, or be convertible to, the matmul B type used by this path.
9250
+ const bool y_decode_vector_supported = ctx->device->coopmat2_decode_vector &&
9251
+ (f16_type != GGML_TYPE_BF16 || ctx->device->coopmat2_bf16_support) &&
9252
+ (src1->type == GGML_TYPE_F32 || src1->type == f16_type);
9253
+ // If B is copied to prealloc_y, we can choose a 4-element-aligned row stride.
9254
+ const bool y_decode_vector_uses_prealloc = !ggml_vk_dim01_contiguous(src1) || src1->type != f16_type;
9255
+ // Direct B reads are safe only if row starts and the original buffer offset are 4-element aligned.
9256
+ const bool y_decode_vector_aligned =
9257
+ (ne10 % 4 == 0) &&
9258
+ (y_decode_vector_uses_prealloc || get_misalign_bytes(ctx, src1) % (4 * ggml_type_size(src1->type)) == 0);
9259
+ // Stage B only when decode-vector is available and direct B reads would be misaligned.
9260
+ const bool y_decode_vector_staging = y_decode_vector_supported && !y_decode_vector_aligned;
9261
+ #else
9262
+ const bool y_decode_vector_staging = false;
9263
+ #endif
9264
+ const bool y_non_contig = y_decode_vector_staging ||
9265
+ (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
8207
9266
  (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
8208
9267
  !ggml_vk_dim01_contiguous(src1);
8209
9268
 
8210
- // If src0 is BF16, try to use a BF16 x BF16 multiply
8211
- ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
9269
+ const uint32_t y_staged_row_stride = y_decode_vector_staging ? (uint32_t)ggml_vk_align_size(ne10, 4) : (uint32_t)ne10;
8212
9270
 
8213
9271
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
8214
9272
 
@@ -8234,10 +9292,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
8234
9292
  // Not implemented
8235
9293
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
8236
9294
 
8237
- const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
9295
+ const ggml_type effective_src1_type = quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type);
9296
+
9297
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type, effective_src1_type));
8238
9298
  const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && nei1 > 8;
8239
9299
 
8240
- vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
9300
+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type, effective_src1_type);
8241
9301
 
8242
9302
  if (ggml_nbytes(src0) > ctx->device->properties.limits.maxStorageBufferRange) {
8243
9303
  pipeline = ggml_vk_get_64b_indexing_pipeline(ctx, pipeline);
@@ -8245,11 +9305,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
8245
9305
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
8246
9306
  uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
8247
9307
  const uint64_t x_ne = ggml_nelements(src0);
8248
- const uint64_t y_ne = padded_n * ne10 * ne12 * ne13;
9308
+ const uint64_t y_ne = (uint64_t)y_staged_row_stride * padded_n * ne12 * ne13;
8249
9309
  const uint64_t d_ne = ggml_nelements(dst);
8250
9310
 
8251
9311
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
8252
- const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
9312
+ const uint64_t qy_sz = ggml_type_size(src1->type) * ggml_nelements(src1) / ggml_blck_size(src1->type);
8253
9313
  const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
8254
9314
  const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
8255
9315
  const uint64_t ids_sz = nbi2;
@@ -8259,13 +9319,30 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
8259
9319
  vk_pipeline to_fp16_vk_1 = nullptr;
8260
9320
  vk_pipeline to_q8_1 = nullptr;
8261
9321
 
9322
+ auto make_y_staged_dst = [&]() {
9323
+ ggml_tensor y_staged_dst = *src1;
9324
+ y_staged_dst.type = f16_type;
9325
+ y_staged_dst.nb[0] = ggml_type_size(f16_type);
9326
+ y_staged_dst.nb[1] = y_staged_dst.nb[0] * y_staged_row_stride;
9327
+ y_staged_dst.nb[2] = y_staged_dst.nb[1] * padded_n;
9328
+ y_staged_dst.nb[3] = y_staged_dst.nb[2] * y_staged_dst.ne[2];
9329
+ return y_staged_dst;
9330
+ };
9331
+
8262
9332
  if (x_non_contig) {
8263
9333
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
8264
9334
  } else {
8265
9335
  to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
8266
9336
  }
8267
9337
  if (y_non_contig) {
8268
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
9338
+ ggml_tensor y_staged_dst;
9339
+ const ggml_tensor * y_staged_dst_ptr = nullptr;
9340
+ if (y_decode_vector_staging) {
9341
+ y_staged_dst = make_y_staged_dst();
9342
+ y_staged_dst_ptr = &y_staged_dst;
9343
+ }
9344
+
9345
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, y_staged_dst_ptr, f16_type);
8269
9346
  } else {
8270
9347
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
8271
9348
  }
@@ -8383,30 +9460,47 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
8383
9460
  }
8384
9461
  if (y_non_contig) {
8385
9462
  if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
8386
- ctx->prealloc_y_last_tensor_used != src1) {
9463
+ ctx->prealloc_y_last_tensor_used != src1 ||
9464
+ ctx->prealloc_y_last_decode_vector_staging != y_decode_vector_staging) {
8387
9465
  if (ctx->prealloc_y_need_sync) {
8388
9466
  ggml_vk_sync_buffers(ctx, subctx);
8389
9467
  }
8390
- ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
9468
+ if (y_decode_vector_staging) {
9469
+ const ggml_tensor y_staged_dst = make_y_staged_dst();
9470
+ const uint32_t y_staged_dst_type_size = ggml_type_size(y_staged_dst.type);
9471
+ ggml_vk_cpy_to_strided(
9472
+ ctx, subctx, to_fp16_vk_1, src1,
9473
+ ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0),
9474
+ (uint32_t)(y_staged_dst.nb[0] / y_staged_dst_type_size),
9475
+ (uint32_t)(y_staged_dst.nb[1] / y_staged_dst_type_size),
9476
+ (uint32_t)(y_staged_dst.nb[2] / y_staged_dst_type_size),
9477
+ (uint32_t)(y_staged_dst.nb[3] / y_staged_dst_type_size));
9478
+ } else {
9479
+ ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0));
9480
+ }
8391
9481
  ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
8392
9482
  ctx->prealloc_y_last_tensor_used = src1;
9483
+ ctx->prealloc_y_last_decode_vector_staging = y_decode_vector_staging;
8393
9484
  }
8394
9485
  }
8395
9486
  if (quantize_y) {
8396
9487
  if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
8397
- ctx->prealloc_y_last_tensor_used != src1) {
9488
+ ctx->prealloc_y_last_tensor_used != src1 ||
9489
+ ctx->prealloc_y_last_decode_vector_staging) {
8398
9490
  if (ctx->prealloc_y_need_sync) {
8399
9491
  ggml_vk_sync_buffers(ctx, subctx);
8400
9492
  }
8401
9493
  ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne);
8402
9494
  ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
8403
9495
  ctx->prealloc_y_last_tensor_used = src1;
9496
+ ctx->prealloc_y_last_decode_vector_staging = false;
8404
9497
  }
8405
9498
  }
8406
9499
  ggml_vk_sync_buffers(ctx, subctx);
8407
9500
 
8408
9501
  uint32_t stride_batch_x = ne00*ne01;
8409
- uint32_t stride_batch_y = ne10*ne11;
9502
+ uint32_t stride_b_y = y_decode_vector_staging ? y_staged_row_stride : ne10;
9503
+ uint32_t stride_batch_y = y_decode_vector_staging ? y_staged_row_stride * padded_n : ne10*ne11;
8410
9504
 
8411
9505
  if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
8412
9506
  stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
@@ -8421,7 +9515,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
8421
9515
  ctx, subctx, pipeline,
8422
9516
  { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz },
8423
9517
  { d_D, d_buf_offset, d_sz }, { d_ids, ids_buf_offset, ids_sz }, expert_count_buf,
8424
- ne01, ne21, ne10, ne10, ne10, ne01,
9518
+ ne01, ne21, ne10, ne10, stride_b_y, ne01,
8425
9519
  stride_batch_x, stride_batch_y, ne20*ne21,
8426
9520
  n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
8427
9521
  ); // NOLINT
@@ -8579,24 +9673,28 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
8579
9673
  if (y_non_contig) {
8580
9674
  GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne);
8581
9675
  if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() ||
8582
- ctx->prealloc_y_last_tensor_used != src1) {
9676
+ ctx->prealloc_y_last_tensor_used != src1 ||
9677
+ ctx->prealloc_y_last_decode_vector_staging) {
8583
9678
  if (ctx->prealloc_y_need_sync) {
8584
9679
  ggml_vk_sync_buffers(ctx, subctx);
8585
9680
  }
8586
9681
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, d_Qy, d_Y);
8587
9682
  ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get();
8588
9683
  ctx->prealloc_y_last_tensor_used = src1;
9684
+ ctx->prealloc_y_last_decode_vector_staging = false;
8589
9685
  }
8590
9686
  }
8591
9687
  if (quantize_y) {
8592
9688
  if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
8593
- ctx->prealloc_y_last_tensor_used != src1) {
9689
+ ctx->prealloc_y_last_tensor_used != src1 ||
9690
+ ctx->prealloc_y_last_decode_vector_staging) {
8594
9691
  if (ctx->prealloc_y_need_sync) {
8595
9692
  ggml_vk_sync_buffers(ctx, subctx);
8596
9693
  }
8597
9694
  ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
8598
9695
  ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
8599
9696
  ctx->prealloc_y_last_tensor_used = src1;
9697
+ ctx->prealloc_y_last_decode_vector_staging = false;
8600
9698
  }
8601
9699
  }
8602
9700
 
@@ -8687,14 +9785,18 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
8687
9785
  }
8688
9786
  }
8689
9787
 
8690
- static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
9788
+ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type, ggml_type v_type) {
8691
9789
  GGML_UNUSED(f32acc);
9790
+ GGML_UNUSED(v_type);
8692
9791
  // Needs to be kept up to date on shader changes
8693
9792
  const uint32_t wg_size = params.workgroup_size;
8694
9793
  const uint32_t Br = params.block_rows;
8695
9794
  const uint32_t Bc = params.block_cols;
8696
9795
 
8697
- const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
9796
+ // BF16 uses the fp32 shader (FLOAT_TYPE=float)
9797
+ const uint32_t float_type_size = (device->fp16 && k_type != GGML_TYPE_BF16) ? sizeof(ggml_fp16_t) : sizeof(float);
9798
+
9799
+ const bool mmq = ggml_vk_fa_scalar_uses_mmq(device, k_type);
8698
9800
 
8699
9801
  // tmpsh is overestimated slightly
8700
9802
  const uint32_t tmpsh = wg_size * sizeof(float);
@@ -8702,20 +9804,38 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con
8702
9804
 
8703
9805
  const uint32_t masksh = Bc * (Br + 1) * float_type_size;
8704
9806
 
8705
- const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
9807
+ uint32_t Qf, kvsh, kblocksh_size;
9808
+ if (mmq) {
9809
+ // block_b_cache: int32_t qs[8] + FLOAT_TYPEV2 ds
9810
+ const uint32_t block_b_size = 8 * sizeof(int32_t) + 2 * float_type_size;
9811
+ Qf = Br * (hsk / 32) * block_b_size;
9812
+
9813
+ // kvsh uses D = HSV (K goes through kblocksh instead)
9814
+ kvsh = params.shmem_staging ? Bc * (hsv / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
9815
+
9816
+ // The mixed MMQ shader uses a superset block_a_cache that fits every
9817
+ // FA-supported quant: int32_t qs[8] + uint32_t qh + FLOAT_TYPEV2 dm.
9818
+ // Single-scale types leave dm.y unused; non-Q5_* leave qh unused.
9819
+ const uint32_t block_a_size = 8 * sizeof(int32_t) + sizeof(uint32_t) + 2 * float_type_size;
9820
+ kblocksh_size = params.shmem_staging ? Bc * (hsk / 32) * block_a_size : block_a_size;
9821
+ } else {
9822
+ Qf = Br * (hsk / 4 + 1) * 4 * float_type_size;
9823
+
9824
+ const uint32_t D = std::max(hsk, hsv);
9825
+ kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
8706
9826
 
8707
- const uint32_t D = std::max(hsk, hsv);
8708
- const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size;
9827
+ kblocksh_size = 0;
9828
+ }
8709
9829
 
8710
- const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh;
9830
+ const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh + kblocksh_size;
8711
9831
  const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
8712
9832
 
8713
- VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported);
9833
+ VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", mmq=" << mmq << ", total_size=" << total_size << ", supported=" << supported);
8714
9834
 
8715
9835
  return supported;
8716
9836
  }
8717
9837
 
8718
- static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) {
9838
+ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type k_type) {
8719
9839
  // Needs to be kept up to date on shader changes
8720
9840
  const uint32_t Br = params.block_rows;
8721
9841
  const uint32_t Bc = params.block_cols;
@@ -8745,8 +9865,10 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co
8745
9865
  const uint32_t vsh_stride = MatBc / 4 * row_split;
8746
9866
  const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4;
8747
9867
 
9868
+ // BF16 PVMat accumulator is f32 (no bf16 accumulator support), so pvsh is vec4 (16 bytes)
9869
+ const uint32_t pvsh_elem_size = (k_type == GGML_TYPE_BF16) ? 16u : f16vec4;
8748
9870
  const uint32_t osh_stride = params.row_split * MatBr / 4;
8749
- const uint32_t pvsh = MatBc * osh_stride * f16vec4;
9871
+ const uint32_t pvsh = MatBc * osh_stride * pvsh_elem_size;
8750
9872
 
8751
9873
  const uint32_t slope = Br * acctype;
8752
9874
 
@@ -8809,19 +9931,17 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8809
9931
 
8810
9932
  assert(dst->type == GGML_TYPE_F32);
8811
9933
  assert(q->type == GGML_TYPE_F32);
8812
- assert(k->type == v->type);
8813
-
8814
9934
  uint32_t gqa_ratio = 1;
8815
9935
  uint32_t qk_ratio = neq2 / nek2;
8816
9936
  uint32_t workgroups_x = (uint32_t)neq1;
8817
9937
  uint32_t workgroups_y = (uint32_t)neq2;
8818
9938
  uint32_t workgroups_z = (uint32_t)neq3;
8819
9939
 
8820
- const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32;
9940
+ const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32 || k->type == GGML_TYPE_BF16;
8821
9941
 
8822
9942
  // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
8823
9943
  // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
8824
- vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc);
9944
+ vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, v->type, f32acc);
8825
9945
  const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u);
8826
9946
 
8827
9947
  if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa &&
@@ -8834,7 +9954,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8834
9954
  workgroups_y /= gqa_ratio;
8835
9955
  }
8836
9956
 
8837
- tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc);
9957
+ tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, v->type, f32acc);
8838
9958
 
8839
9959
  const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
8840
9960
  uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
@@ -8873,13 +9993,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8873
9993
  // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively.
8874
9994
  bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768 && nem0 >= tuning_params.block_cols * 16;
8875
9995
  vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc,
8876
- mask != nullptr, use_mask_opt, logit_softcap != 0);
9996
+ mask != nullptr, use_mask_opt, logit_softcap != 0, k->type, v->type);
8877
9997
 
8878
9998
  vk_pipeline pipeline = nullptr;
8879
9999
 
8880
10000
  {
8881
- std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
8882
- auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type];
10001
+ std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
10002
+ auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16;
8883
10003
  auto it = pipelines.find(fa_pipeline_state);
8884
10004
  if (it != pipelines.end()) {
8885
10005
  pipeline = it->second;
@@ -8942,13 +10062,15 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
8942
10062
 
8943
10063
  vk_pipeline pipeline_fa_mask_opt = nullptr;
8944
10064
  if (use_mask_opt) {
8945
- std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
8946
- auto &pipelines = ctx->device->pipeline_fa_mask_opt;
8947
- auto it = pipelines.find({Br, Bc});
8948
- if (it != pipelines.end()) {
8949
- pipeline_fa_mask_opt = it->second;
8950
- } else {
8951
- pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
10065
+ {
10066
+ std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
10067
+ auto &pipelines = ctx->device->pipeline_fa_mask_opt;
10068
+ auto it = pipelines.find({Br, Bc});
10069
+ if (it != pipelines.end()) {
10070
+ pipeline_fa_mask_opt = it->second;
10071
+ } else {
10072
+ pipelines[{Br, Bc}] = pipeline_fa_mask_opt = std::make_shared<vk_pipeline_struct>();
10073
+ }
8952
10074
  }
8953
10075
  assert(pipeline_fa_mask_opt);
8954
10076
  ggml_pipeline_request_descriptor_sets(ctx, pipeline_fa_mask_opt, 1);
@@ -9059,10 +10181,23 @@ static vk_conv_shapes ggml_vk_conv_select_shape(ggml_backend_vk_context * ctx, u
9059
10181
  // so small convolutions will still choose a smaller tile.
9060
10182
  const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ? ctx->device->shader_core_count : 32;
9061
10183
 
9062
- if (K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {
10184
+ // 128x128 isn't used with cm1 due to shared memory size; fall through to a smaller tile.
10185
+ bool allow_128x128 = true;
10186
+ #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
10187
+ if (!ctx->device->coopmat2 && ctx->device->coopmat_support && ctx->device->coopmat_support_16x16x16_f16acc) {
10188
+ allow_128x128 = false;
10189
+ }
10190
+ #endif
10191
+
10192
+ if (allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_128x128) >= shader_core_count * 2) {
9063
10193
  return CONV_SHAPE_128x128;
9064
10194
  } else if (K <= 32 && n_tiles(CONV_SHAPE_32x256) >= shader_core_count * 2) {
9065
10195
  return CONV_SHAPE_32x256;
10196
+ } else if (K <= 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) {
10197
+ return CONV_SHAPE_64x128;
10198
+ } else if (!allow_128x128 && K > 64 && n_tiles(CONV_SHAPE_64x128) >= shader_core_count * 2) {
10199
+ // cm1 fallback for large K when 128x128 isn't available
10200
+ return CONV_SHAPE_64x128;
9066
10201
  } else {
9067
10202
  return CONV_SHAPE_64x32;
9068
10203
  }
@@ -9234,7 +10369,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
9234
10369
  return nullptr;
9235
10370
  case GGML_OP_REPEAT:
9236
10371
  if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) {
9237
- return ctx->device->pipeline_repeat_f32;
10372
+ return ctx->device->pipeline_repeat_i32;
10373
+ }
10374
+ if (ggml_type_size(src0->type) == 2 && ggml_type_size(dst->type) == 2) {
10375
+ return ctx->device->pipeline_repeat_i16;
9238
10376
  }
9239
10377
  return nullptr;
9240
10378
  case GGML_OP_REPEAT_BACK:
@@ -9466,7 +10604,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
9466
10604
  vk_pipeline pipeline = nullptr;
9467
10605
 
9468
10606
  {
9469
- std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
10607
+ std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
9470
10608
  auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
9471
10609
  if (it != ctx->device->pipeline_solve_tri_f32.end()) {
9472
10610
  pipeline = it->second;
@@ -9555,7 +10693,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
9555
10693
  return nullptr;
9556
10694
  case GGML_OP_SSM_CONV:
9557
10695
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
9558
- return ctx->device->pipeline_ssm_conv_f32;
10696
+ switch (ctx->num_additional_fused_ops) {
10697
+ case 0: return ctx->device->pipeline_ssm_conv_f32;
10698
+ case 1: return ctx->device->pipeline_ssm_conv_silu_f32;
10699
+ case 2: return ctx->device->pipeline_ssm_conv_bias_silu_f32;
10700
+ default: return nullptr;
10701
+ }
9559
10702
  }
9560
10703
  return nullptr;
9561
10704
  case GGML_OP_OPT_STEP_ADAMW:
@@ -9589,7 +10732,18 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
9589
10732
  uint32_t p1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 3) : 0;
9590
10733
  uint32_t d0 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 4) : 1;
9591
10734
  uint32_t d1 = !transpose ? (uint32_t)ggml_get_op_params_i32(dst, 5) : 1;
9592
- vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH);
10735
+
10736
+ // tile-aligned shapes let the shader skip bounds checks
10737
+ const uint32_t Cin = (uint32_t)src1->ne[2];
10738
+ const uint32_t CRS = Cin * KW * KH;
10739
+ const uint32_t BS_K = vk_conv_block_sizes[shape].K;
10740
+ const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS;
10741
+ const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ;
10742
+ const uint32_t aligned = ((K % BS_K == 0) &&
10743
+ (CRS % BS_CRS == 0) &&
10744
+ (NPQ % BS_NPQ == 0)) ? 1u : 0u;
10745
+
10746
+ vk_conv2d_pipeline_state conv2d_pipeline_state(s0, s1, p0, p1, d0, d1, KW, KH, aligned);
9593
10747
 
9594
10748
  std::map<vk_conv2d_pipeline_state, vk_pipeline> *pipelines = nullptr;
9595
10749
  if (op == GGML_OP_CONV_2D) {
@@ -9609,7 +10763,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
9609
10763
  vk_pipeline pipeline = nullptr;
9610
10764
 
9611
10765
  {
9612
- std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
10766
+ std::lock_guard<std::mutex> guard(ctx->device->compile_mutex);
9613
10767
  auto it = pipelines->find(conv2d_pipeline_state);
9614
10768
  if (it != pipelines->end()) {
9615
10769
  pipeline = it->second;
@@ -9656,6 +10810,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
9656
10810
  if (dst->type == GGML_TYPE_F32) {
9657
10811
  return ctx->device->pipeline_fill_f32;
9658
10812
  }
10813
+ if (dst->type == GGML_TYPE_F16) {
10814
+ return ctx->device->pipeline_fill_f16;
10815
+ }
9659
10816
  return nullptr;
9660
10817
  default:
9661
10818
  return nullptr;
@@ -9733,6 +10890,15 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
9733
10890
  GGML_UNUSED(src3);
9734
10891
  }
9735
10892
 
10893
+ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_rope_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
10894
+ p.a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
10895
+ p.d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
10896
+
10897
+ GGML_UNUSED(src1);
10898
+ GGML_UNUSED(src2);
10899
+ GGML_UNUSED(src3);
10900
+ }
10901
+
9736
10902
  template<typename PC>
9737
10903
  static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst, ggml_op op, PC&& pc) {
9738
10904
  VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
@@ -9876,7 +11042,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
9876
11042
 
9877
11043
  const uint32_t batch = src1->ne[is_2D ? 3 : 2];
9878
11044
 
9879
- elements = { OW * KW * KH, OH, batch * IC };
11045
+ const uint32_t CHW = IC * KH * KW;
11046
+ // Cap X workgroups to limit concurrent IC channel reads.
11047
+ // The shader loops over X to cover the full CHW dimension.
11048
+ // AMD prefers a lower limit
11049
+ const uint32_t min_cap = ctx->device->vendor_id == VK_VENDOR_ID_AMD ? 512u : 4096u;
11050
+ const uint32_t x_elements = std::min(CHW, std::max(min_cap, OW * KH * KW));
11051
+ elements = { x_elements, OW, OH * batch };
9880
11052
  elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]);
9881
11053
  elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
9882
11054
  } break;
@@ -10385,6 +11557,9 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
10385
11557
  const uint32_t n_tokens = (uint32_t)src_v->ne[2];
10386
11558
  const uint32_t n_seqs = (uint32_t)src_v->ne[3];
10387
11559
 
11560
+ // K (snapshot slot count) is an op param; state holds s0 only [S_v, S_v, H, n_seqs].
11561
+ const uint32_t K = (uint32_t)ggml_get_op_params_i32(dst, 0);
11562
+
10388
11563
  const uint32_t s_off = S_v * H * n_tokens * n_seqs;
10389
11564
 
10390
11565
  vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
@@ -10418,12 +11593,13 @@ static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& s
10418
11593
  sv1, sv2, sv3,
10419
11594
  sb1, sb2, sb3,
10420
11595
  neq1, rq3,
10421
- scale
11596
+ scale,
11597
+ K
10422
11598
  };
10423
11599
 
10424
11600
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
10425
11601
  {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
10426
- pc, { H, n_seqs, 1u });
11602
+ pc, { H, n_seqs, S_v });
10427
11603
  }
10428
11604
 
10429
11605
  static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
@@ -10482,11 +11658,28 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx,
10482
11658
  pc, elements);
10483
11659
  }
10484
11660
 
10485
- static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
10486
- const ggml_tensor * src0 = dst->src[0];
10487
- const ggml_tensor * src1 = dst->src[1];
11661
+ static void ggml_vk_ssm_conv(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
11662
+ ggml_tensor * conv = cgraph->nodes[node_idx];
11663
+ const ggml_tensor * src0 = conv->src[0];
11664
+ const ggml_tensor * src1 = conv->src[1];
11665
+
11666
+ // Pick the destination tensor (last node in the fused chain) and the optional bias.
11667
+ // Fusion modes: 0 = ssm_conv, 1 = ssm_conv+silu, 2 = ssm_conv+add(bias)+silu.
11668
+ ggml_tensor * dst = conv;
11669
+ const ggml_tensor * bias = nullptr;
10488
11670
 
10489
- ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SSM_CONV, {
11671
+ if (ctx->num_additional_fused_ops == 1) {
11672
+ dst = cgraph->nodes[node_idx + 1]; // silu
11673
+ } else if (ctx->num_additional_fused_ops == 2) {
11674
+ ggml_tensor * add = cgraph->nodes[node_idx + 1];
11675
+ bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
11676
+ dst = cgraph->nodes[node_idx + 2]; // silu
11677
+ }
11678
+
11679
+ // The shader always declares 4 bindings; bind src0 as a dummy when bias isn't fused.
11680
+ const ggml_tensor * src2 = bias ? bias : src0;
11681
+
11682
+ ggml_vk_op_f32<vk_op_ssm_conv_push_constants>(ctx, subctx, src0, src1, src2, nullptr, dst, GGML_OP_SSM_CONV, {
10490
11683
  (uint32_t)src0->nb[1], (uint32_t)src0->nb[2],
10491
11684
  (uint32_t)src1->nb[1],
10492
11685
  (uint32_t)dst->nb[0], (uint32_t)dst->nb[1], (uint32_t)dst->nb[2],
@@ -10849,6 +12042,7 @@ static vk_op_rope_push_constants ggml_vk_make_rope_constants(const ggml_tensor *
10849
12042
  (uint32_t)src0->ne[2],
10850
12043
  nb01, nb02, nb03,
10851
12044
  nb11, nb12, nb13,
12045
+ 0, 0, // a_offset, d_offset filled in by init_pushconst_tensor_offsets
10852
12046
  };
10853
12047
 
10854
12048
  return rope;
@@ -10944,6 +12138,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
10944
12138
  GGML_ASSERT(buf[i] != nullptr);
10945
12139
  }
10946
12140
 
12141
+ // a_offset is unused (the fused path reads from shared memory), but the rope/set_rows dst can be misaligned.
12142
+ // Round the binding offset down to the storage buffer alignment; the in-element shift goes in pc.rope.d_offset.
12143
+ pc.rope.d_offset = get_misalign_bytes(ctx, tensors[5]) / ggml_type_size(tensors[5]->type);
12144
+ offset[5] &= ~(size_t(ctx->device->properties.limits.minStorageBufferOffsetAlignment) - 1);
12145
+
10947
12146
  std::array<uint32_t, 3> elements;
10948
12147
  elements = { (uint32_t)rms->src[0]->ne[1], (uint32_t)rms->src[0]->ne[2], (uint32_t)rms->src[0]->ne[3] };
10949
12148
 
@@ -11003,8 +12202,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
11003
12202
  const float alpha = op_params_f[2];
11004
12203
  const float limit = op_params_f[3];
11005
12204
 
11006
- GGML_ASSERT(ggml_is_contiguous(src0));
11007
-
11008
12205
  if (!split) {
11009
12206
  GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
11010
12207
  } else {
@@ -11022,7 +12219,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
11022
12219
  (uint32_t)dst->ne[0],
11023
12220
  mode,
11024
12221
  alpha,
11025
- limit
12222
+ limit,
12223
+ (uint32_t)(src0->nb[1] / src0->nb[0]),
12224
+ (uint32_t)(src0->nb[2] / src0->nb[0]),
12225
+ (uint32_t)(src0->nb[3] / src0->nb[0]),
12226
+ (uint32_t)src0->ne[1],
12227
+ (uint32_t)src0->ne[2],
12228
+ (uint32_t)(dst->nb[1] / dst->nb[0]),
12229
+ (uint32_t)(dst->nb[2] / dst->nb[0]),
12230
+ (uint32_t)(dst->nb[3] / dst->nb[0]),
12231
+ (uint32_t)dst->ne[1],
12232
+ (uint32_t)dst->ne[2]
11026
12233
  });
11027
12234
  }
11028
12235
 
@@ -11531,7 +12738,6 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
11531
12738
  const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
11532
12739
  const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
11533
12740
 
11534
- const uint32_t pelements = OW * KW * KH;
11535
12741
  const uint32_t batch = src1->ne[is_2D ? 3 : 2];
11536
12742
 
11537
12743
  const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
@@ -11543,7 +12749,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
11543
12749
  dst_addr,
11544
12750
  batch_offset, offset_delta,
11545
12751
  IC, IW, IH, OW, OH, KW, KH,
11546
- pelements,
12752
+ OH * batch,
11547
12753
  IC * KH * KW,
11548
12754
  s0, s1, p0, p1, d0, d1, batch * IC
11549
12755
  });
@@ -11656,6 +12862,45 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
11656
12862
  ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
11657
12863
  }
11658
12864
 
12865
+ // Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b.
12866
+ // Match the naive mul -> sin -> sqr -> mul -> add chain and run the
12867
+ // dedicated kernel directly. The pattern is validated by
12868
+ // ggml_vk_can_fuse_snake before this call.
12869
+ static void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
12870
+ const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0];
12871
+ const ggml_tensor * sqr = cgraph->nodes[node_idx + 2];
12872
+ const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3];
12873
+ ggml_tensor * add = cgraph->nodes[node_idx + 4];
12874
+
12875
+ // x carries the full activation shape, a is the broadcast operand
12876
+ const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
12877
+ const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];
12878
+
12879
+ // mul1 reads sqr and inv_b in either operand order
12880
+ const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
12881
+
12882
+ vk_pipeline pipeline = nullptr;
12883
+ switch (x->type) {
12884
+ case GGML_TYPE_F32: pipeline = ctx->device->pipeline_snake_f32; break;
12885
+ case GGML_TYPE_F16: pipeline = ctx->device->pipeline_snake_f16; break;
12886
+ case GGML_TYPE_BF16: pipeline = ctx->device->pipeline_snake_bf16; break;
12887
+ default: GGML_ABORT("unsupported type");
12888
+ }
12889
+ ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
12890
+
12891
+ vk_subbuffer x_buf = ggml_vk_tensor_subbuffer(ctx, x);
12892
+ vk_subbuffer a_buf = ggml_vk_tensor_subbuffer(ctx, a);
12893
+ vk_subbuffer inv_b_buf = ggml_vk_tensor_subbuffer(ctx, inv_b);
12894
+ vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, add);
12895
+
12896
+ vk_op_snake_push_constants pc{};
12897
+ pc.ne0 = static_cast<uint32_t>(x->ne[0]);
12898
+ pc.ne1 = static_cast<uint32_t>(x->ne[1]);
12899
+
12900
+ std::array<uint32_t, 3> elements = { pc.ne0, pc.ne1, 1 };
12901
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { x_buf, a_buf, inv_b_buf, dst_buf }, pc, elements);
12902
+ }
12903
+
11659
12904
  static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
11660
12905
  uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
11661
12906
  const int32_t k1 = dst->op_params[1];
@@ -12673,7 +13918,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx, vk_contex
12673
13918
  ggml_vk_destroy_buffer(ctx->prealloc_y);
12674
13919
  }
12675
13920
  ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y);
13921
+ ctx->prealloc_y_last_pipeline_used = nullptr;
12676
13922
  ctx->prealloc_y_last_tensor_used = nullptr;
13923
+ ctx->prealloc_y_last_decode_vector_staging = false;
12677
13924
  }
12678
13925
  if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) {
12679
13926
  VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")");
@@ -12801,6 +14048,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12801
14048
  if (vk_perf_logger_enabled && vk_perf_logger_concurrent) {
12802
14049
  ctx->query_node_idx[ctx->query_idx] = node_idx;
12803
14050
  compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
14051
+ ggml_vk_sync_buffers(ctx, compute_ctx);
12804
14052
  }
12805
14053
  }
12806
14054
  // Add all fused nodes to the unsynchronized lists.
@@ -12863,7 +14111,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
12863
14111
 
12864
14112
  break;
12865
14113
  case GGML_OP_MUL:
12866
- ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
14114
+ if (ctx->num_additional_fused_ops) {
14115
+ ggml_vk_snake_dispatch_fused(ctx, compute_ctx, cgraph, node_idx);
14116
+ } else {
14117
+ ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
14118
+ }
12867
14119
 
12868
14120
  break;
12869
14121
  case GGML_OP_DIV:
@@ -13153,7 +14405,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
13153
14405
  break;
13154
14406
 
13155
14407
  case GGML_OP_SSM_CONV:
13156
- ggml_vk_ssm_conv(ctx, compute_ctx, node);
14408
+ ggml_vk_ssm_conv(ctx, compute_ctx, cgraph, node_idx);
13157
14409
 
13158
14410
  break;
13159
14411
 
@@ -13248,6 +14500,8 @@ static void ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
13248
14500
  static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
13249
14501
  VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
13250
14502
  ctx->prealloc_y_last_pipeline_used = {};
14503
+ ctx->prealloc_y_last_tensor_used = nullptr;
14504
+ ctx->prealloc_y_last_decode_vector_staging = false;
13251
14505
 
13252
14506
  ctx->unsynced_nodes_written.clear();
13253
14507
  ctx->unsynced_nodes_read.clear();
@@ -13298,6 +14552,8 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
13298
14552
  ggml_vk_destroy_buffer(ctx->sync_staging);
13299
14553
 
13300
14554
  ctx->prealloc_y_last_pipeline_used = nullptr;
14555
+ ctx->prealloc_y_last_tensor_used = nullptr;
14556
+ ctx->prealloc_y_last_decode_vector_staging = false;
13301
14557
 
13302
14558
  ctx->prealloc_size_x = 0;
13303
14559
  ctx->prealloc_size_y = 0;
@@ -13401,6 +14657,20 @@ static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml
13401
14657
  ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
13402
14658
  }
13403
14659
 
14660
+ static void ggml_backend_vk_buffer_set_tensor_2d(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset,
14661
+ size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
14662
+ VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " <<
14663
+ n_copies << ", " << stride_tensor << ", " << stride_data << ")");
14664
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
14665
+ vk_buffer buf = buf_ctx->dev_buffer;
14666
+
14667
+ if (size == 0) {
14668
+ return;
14669
+ }
14670
+
14671
+ ggml_vk_buffer_write_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_data, stride_tensor, size, n_copies);
14672
+ }
14673
+
13404
14674
  static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
13405
14675
  VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")");
13406
14676
  ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
@@ -13414,6 +14684,21 @@ static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, cons
13414
14684
  ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size);
13415
14685
  }
13416
14686
 
14687
+ static void ggml_backend_vk_buffer_get_tensor_2d(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset,
14688
+ size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
14689
+ VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor_2d(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ", " <<
14690
+ n_copies << ", " << stride_tensor << ", " << stride_data << ")");
14691
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context;
14692
+
14693
+ if (size == 0) {
14694
+ return;
14695
+ }
14696
+
14697
+ vk_buffer buf = buf_ctx->dev_buffer;
14698
+
14699
+ ggml_vk_buffer_read_2d(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, stride_tensor, stride_data, size, n_copies);
14700
+ }
14701
+
13417
14702
  static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
13418
14703
  if (ggml_nbytes(src) == 0) {
13419
14704
  return true;
@@ -13448,6 +14733,8 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
13448
14733
  /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor,
13449
14734
  /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor,
13450
14735
  /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor,
14736
+ /* .set_tensor_2d = */ ggml_backend_vk_buffer_set_tensor_2d,
14737
+ /* .get_tensor_2d = */ ggml_backend_vk_buffer_get_tensor_2d,
13451
14738
  /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor,
13452
14739
  /* .clear = */ ggml_backend_vk_buffer_clear,
13453
14740
  /* .reset = */ NULL,
@@ -13510,12 +14797,6 @@ static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_ty
13510
14797
  UNUSED(buft);
13511
14798
  }
13512
14799
 
13513
- static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) {
13514
- return GGML_VK_NAME "_Host";
13515
-
13516
- UNUSED(buffer);
13517
- }
13518
-
13519
14800
  static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
13520
14801
  VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()");
13521
14802
  ggml_vk_host_free(vk_instance.devices[0], buffer->context);
@@ -13603,8 +14884,9 @@ static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_b
13603
14884
  return &ctx->device->buffer_type;
13604
14885
  }
13605
14886
 
13606
- static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
13607
- VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
14887
+ static void ggml_backend_vk_set_tensor_2d_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset,
14888
+ size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
14889
+ VK_LOG_DEBUG("ggml_backend_vk_set_tensor_2d_async(" << size << ", " << n_copies << ")");
13608
14890
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13609
14891
  GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
13610
14892
 
@@ -13618,7 +14900,6 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
13618
14900
 
13619
14901
  if (ctx->device->async_use_transfer_queue) {
13620
14902
  if (ctx->transfer_ctx.expired()) {
13621
- // Initialize new transfer context
13622
14903
  cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool);
13623
14904
  ctx->transfer_ctx = cpy_ctx;
13624
14905
  ggml_vk_ctx_begin(ctx->device, cpy_ctx);
@@ -13633,25 +14914,48 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor
13633
14914
 
13634
14915
  auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
13635
14916
 
13636
- bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size);
14917
+ bool ret = ggml_vk_buffer_write_2d_async(cpy_ctx, buf, dst_offset, data, stride_data, stride_tensor, size, n_copies);
13637
14918
 
13638
14919
  if (!ret) {
13639
- ggml_vk_ensure_sync_staging_buffer(ctx, size);
14920
+ const size_t staging_size = size * n_copies;
14921
+ ggml_vk_ensure_sync_staging_buffer(ctx, staging_size);
13640
14922
  ggml_vk_sync_buffers(nullptr, cpy_ctx);
13641
14923
 
13642
- vk::BufferCopy buffer_cpy;
13643
- buffer_cpy.srcOffset = 0;
13644
- buffer_cpy.dstOffset = dst_offset;
13645
- buffer_cpy.size = size;
14924
+ std::vector<vk::BufferCopy> slices(1);
14925
+ if (size == stride_tensor) {
14926
+ slices[0].srcOffset = 0;
14927
+ slices[0].dstOffset = dst_offset;
14928
+ slices[0].size = staging_size;
14929
+ } else {
14930
+ slices.resize(n_copies);
14931
+ for (size_t i = 0; i < n_copies; i++) {
14932
+ slices[i].srcOffset = i * size;
14933
+ slices[i].dstOffset = dst_offset + i * stride_tensor;
14934
+ slices[i].size = size;
14935
+ }
14936
+ }
14937
+
14938
+ cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, slices);
13646
14939
 
13647
- cpy_ctx->s->buffer->buf.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy });
13648
- deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys);
14940
+ if (size == stride_data) {
14941
+ deferred_memcpy(ctx->sync_staging->ptr, data, staging_size, &cpy_ctx->in_memcpys);
14942
+ } else {
14943
+ for (size_t i = 0; i < n_copies; i++) {
14944
+ deferred_memcpy((uint8_t *)ctx->sync_staging->ptr + i * size, (const uint8_t *)data + i * stride_data, size, &cpy_ctx->in_memcpys);
14945
+ }
14946
+ }
13649
14947
  ggml_vk_synchronize(ctx);
13650
14948
  }
13651
14949
  }
13652
14950
 
13653
- static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
13654
- VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
14951
+ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
14952
+ VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")");
14953
+ ggml_backend_vk_set_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size);
14954
+ }
14955
+
14956
+ static void ggml_backend_vk_get_tensor_2d_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset,
14957
+ size_t size, size_t n_copies, size_t stride_tensor, size_t stride_data) {
14958
+ VK_LOG_DEBUG("ggml_backend_vk_get_tensor_2d_async(" << size << ", " << n_copies << ")");
13655
14959
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
13656
14960
  GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type");
13657
14961
 
@@ -13666,24 +14970,45 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_
13666
14970
  vk_buffer buf = buf_ctx->dev_buffer;
13667
14971
 
13668
14972
  auto src_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset;
13669
- bool ret = ggml_vk_buffer_read_async(compute_ctx, buf, src_offset, data, size);
14973
+ bool ret = ggml_vk_buffer_read_2d_async(compute_ctx, buf, src_offset, data, stride_tensor, stride_data, size, n_copies);
13670
14974
 
13671
- // If that failed, copy synchronously through a staging buffer
13672
14975
  if (!ret) {
13673
- ggml_vk_ensure_sync_staging_buffer(ctx, size);
14976
+ const size_t staging_size = size * n_copies;
14977
+ ggml_vk_ensure_sync_staging_buffer(ctx, staging_size);
13674
14978
  ggml_vk_sync_buffers(nullptr, compute_ctx);
13675
14979
 
13676
- vk::BufferCopy buffer_cpy;
13677
- buffer_cpy.srcOffset = src_offset;
13678
- buffer_cpy.dstOffset = 0;
13679
- buffer_cpy.size = size;
14980
+ std::vector<vk::BufferCopy> slices(1);
14981
+ if (size == stride_tensor) {
14982
+ slices[0].srcOffset = src_offset;
14983
+ slices[0].dstOffset = 0;
14984
+ slices[0].size = staging_size;
14985
+ } else {
14986
+ slices.resize(n_copies);
14987
+ for (size_t i = 0; i < n_copies; i++) {
14988
+ slices[i].srcOffset = src_offset + i * stride_tensor;
14989
+ slices[i].dstOffset = i * size;
14990
+ slices[i].size = size;
14991
+ }
14992
+ }
14993
+
14994
+ compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, slices);
13680
14995
 
13681
- compute_ctx->s->buffer->buf.copyBuffer(buf->buffer, ctx->sync_staging->buffer, { buffer_cpy });
13682
- deferred_memcpy(data, ctx->sync_staging->ptr, size, &compute_ctx->out_memcpys);
14996
+ if (size == stride_data) {
14997
+ deferred_memcpy(data, ctx->sync_staging->ptr, staging_size, &compute_ctx->out_memcpys);
14998
+ } else {
14999
+ for (size_t i = 0; i < n_copies; i++) {
15000
+ deferred_memcpy((uint8_t *)data + i * stride_data, (const uint8_t *)ctx->sync_staging->ptr + i * size, size, &compute_ctx->out_memcpys);
15001
+ }
15002
+ }
13683
15003
  ggml_vk_synchronize(ctx);
13684
15004
  }
13685
15005
  }
13686
15006
 
15007
+ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
15008
+ VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")");
15009
+ ggml_backend_vk_get_tensor_2d_async(backend, tensor, data, offset, size, 1, size, size);
15010
+ }
15011
+
13687
15012
  static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) {
13688
15013
  VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async(" << src << " -> " << dst << ", size=" << ggml_nbytes(src) << ")");
13689
15014
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context;
@@ -13797,6 +15122,7 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) {
13797
15122
  ctx->submit_pending = false;
13798
15123
  if (cmd_buf) {
13799
15124
  cmd_buf->in_use = false;
15125
+ cmd_buf->buf.reset();
13800
15126
  }
13801
15127
  }
13802
15128
 
@@ -13974,6 +15300,62 @@ static bool ggml_vk_can_fuse(const ggml_backend_vk_context * ctx, const struct g
13974
15300
  return true;
13975
15301
  }
13976
15302
 
15303
+ // Match SSM_CONV + UNARY(SILU) or SSM_CONV + ADD + UNARY(SILU). num_extra is 1 or 2.
15304
+ static bool ggml_vk_can_fuse_ssm_conv(const ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
15305
+ int node_idx, int num_extra) {
15306
+ const ggml_tensor * conv = cgraph->nodes[node_idx];
15307
+ if (conv->op != GGML_OP_SSM_CONV) {
15308
+ return false;
15309
+ }
15310
+
15311
+ const ggml_tensor * silu = nullptr;
15312
+ const ggml_tensor * bias = nullptr;
15313
+
15314
+ if (num_extra == 1) {
15315
+ if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_UNARY })) {
15316
+ return false;
15317
+ }
15318
+ silu = cgraph->nodes[node_idx + 1];
15319
+ } else if (num_extra == 2) {
15320
+ if (!ggml_can_fuse(cgraph, node_idx, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY })) {
15321
+ return false;
15322
+ }
15323
+ const ggml_tensor * add = cgraph->nodes[node_idx + 1];
15324
+ silu = cgraph->nodes[node_idx + 2];
15325
+ bias = (add->src[0] == conv) ? add->src[1] : add->src[0];
15326
+
15327
+ if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
15328
+ return false;
15329
+ }
15330
+ // bias must be channel-wise (one element per channel of the conv output)
15331
+ if (ggml_nelements(bias) != conv->ne[0] || bias->ne[0] != conv->ne[0]) {
15332
+ return false;
15333
+ }
15334
+ if (add->type != GGML_TYPE_F32) {
15335
+ return false;
15336
+ }
15337
+ // The shader doesn't apply per-tensor offsets, so reject misaligned bias.
15338
+ if (get_misalign_bytes(ctx, bias) != 0) {
15339
+ return false;
15340
+ }
15341
+ } else {
15342
+ return false;
15343
+ }
15344
+
15345
+ if (ggml_get_unary_op(silu) != GGML_UNARY_OP_SILU) {
15346
+ return false;
15347
+ }
15348
+ if (conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
15349
+ return false;
15350
+ }
15351
+ // The shader writes to the fused dst using its own strides, but the push constants don't
15352
+ // carry a per-tensor offset, so the binding must be naturally aligned.
15353
+ if (get_misalign_bytes(ctx, silu) != 0) {
15354
+ return false;
15355
+ }
15356
+ return true;
15357
+ }
15358
+
13977
15359
  static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph,
13978
15360
  int node_idx, topk_moe_mode mode) {
13979
15361
 
@@ -14104,6 +15486,65 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
14104
15486
  return true;
14105
15487
  }
14106
15488
 
15489
+ // Pattern check for the 5-op Snake fusion: mul -> sin -> sqr -> mul -> add.
15490
+ // Verifies the chain shape, the closure x_in_add == x_in_mul0, and that
15491
+ // the broadcast operands a and inv_b share a [1, C] layout.
15492
+ static bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
15493
+ GGML_UNUSED(ctx);
15494
+ if (!ggml_can_fuse(cgraph, node_idx, snake_pattern)) {
15495
+ return false;
15496
+ }
15497
+
15498
+ const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0];
15499
+ const ggml_tensor * sin_node = cgraph->nodes[node_idx + 1];
15500
+ const ggml_tensor * sqr = cgraph->nodes[node_idx + 2];
15501
+ const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3];
15502
+ const ggml_tensor * add = cgraph->nodes[node_idx + 4];
15503
+
15504
+ const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
15505
+ const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];
15506
+
15507
+ const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
15508
+ const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];
15509
+
15510
+ if (x_in_add != x) {
15511
+ return false;
15512
+ }
15513
+ if (x->type != GGML_TYPE_F32 && x->type != GGML_TYPE_F16 && x->type != GGML_TYPE_BF16) {
15514
+ return false;
15515
+ }
15516
+ // Shader bindings: data_a is A_TYPE so it follows x's precision, while
15517
+ // data_b and data_c are hardcoded float, so the broadcast operands must
15518
+ // be F32 regardless of x's type.
15519
+ if (a->type != GGML_TYPE_F32) return false;
15520
+ if (inv_b->type != GGML_TYPE_F32) return false;
15521
+ // Chain intermediates and output share x's precision (single A_TYPE / D_TYPE pipeline).
15522
+ if (mul0->type != x->type) return false;
15523
+ if (sin_node->type != x->type) return false;
15524
+ if (sqr->type != x->type) return false;
15525
+ if (mul1->type != x->type) return false;
15526
+ if (add->type != x->type) return false;
15527
+ if (!ggml_are_same_shape(a, inv_b)) {
15528
+ return false;
15529
+ }
15530
+ if (a->ne[0] != 1 || a->ne[1] != x->ne[1]) {
15531
+ return false;
15532
+ }
15533
+ // Dispatch is 2D over (ne0, ne1), so x and add must be 2D and a / inv_b
15534
+ // must collapse to [1, C, 1, 1]. Higher dims are not handled by the shader.
15535
+ if (x->ne[2] != 1 || x->ne[3] != 1) return false;
15536
+ if (add->ne[2] != 1 || add->ne[3] != 1) return false;
15537
+ if (a->ne[2] != 1 || a->ne[3] != 1) return false;
15538
+ if (inv_b->ne[2] != 1 || inv_b->ne[3] != 1) return false;
15539
+ // Shader uses idx = i0 + i1 * ne0 and reads data_b[i1] / data_c[i1],
15540
+ // so every operand must be contiguous.
15541
+ if (!ggml_is_contiguous(x) || !ggml_is_contiguous(add) ||
15542
+ !ggml_is_contiguous(a) || !ggml_is_contiguous(inv_b)) {
15543
+ return false;
15544
+ }
15545
+ return true;
15546
+ }
15547
+
14107
15548
  // Check whether the tensors overlap in memory.
14108
15549
  // Fusions can potentially overwrite src tensors in ways that are not prevented
14109
15550
  // by ggml-alloc. If the fusion src is being applied in a way that's elementwise
@@ -14158,8 +15599,7 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co
14158
15599
  }
14159
15600
 
14160
15601
  // conditions for pipeline creation
14161
- if (!(ctx->device->float_controls_rte_fp16 &&
14162
- sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) {
15602
+ if (sizeof(vk_op_rms_norm_mul_rope_push_constants) > ctx->device->properties.limits.maxPushConstantsSize) {
14163
15603
  return false;
14164
15604
  }
14165
15605
 
@@ -14288,10 +15728,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
14288
15728
  compute_ctx = ggml_vk_get_compute_ctx(ctx);
14289
15729
  ctx->query_idx = 0;
14290
15730
  compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
15731
+ ggml_vk_sync_buffers(ctx, compute_ctx);
14291
15732
  }
14292
15733
 
14293
15734
  ctx->prealloc_y_last_pipeline_used = nullptr;
14294
15735
  ctx->prealloc_y_last_tensor_used = nullptr;
15736
+ ctx->prealloc_y_last_decode_vector_staging = false;
14295
15737
 
14296
15738
  if (ctx->prealloc_size_add_rms_partials) {
14297
15739
  ggml_vk_preallocate_buffers(ctx, nullptr);
@@ -14390,6 +15832,19 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
14390
15832
  // they are overwritten, and one workgroup per row. So close enough.
14391
15833
  op_srcs_fused_elementwise[0] = true;
14392
15834
  op_srcs_fused_elementwise[1] = true;
15835
+ } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 2)) {
15836
+ ctx->num_additional_fused_ops = 2;
15837
+ fusion_string = "SSM_CONV_BIAS_SILU";
15838
+ // ssm_conv reads multiple input tokens per output, so it's not elementwise w.r.t. its srcs.
15839
+ // The downstream add and silu are elementwise on the conv output.
15840
+ op_srcs_fused_elementwise[0] = false;
15841
+ op_srcs_fused_elementwise[1] = true;
15842
+ op_srcs_fused_elementwise[2] = true;
15843
+ } else if (ggml_vk_can_fuse_ssm_conv(ctx, cgraph, i, 1)) {
15844
+ ctx->num_additional_fused_ops = 1;
15845
+ fusion_string = "SSM_CONV_SILU";
15846
+ op_srcs_fused_elementwise[0] = false;
15847
+ op_srcs_fused_elementwise[1] = true;
14393
15848
  } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) &&
14394
15849
  ggml_check_edges(cgraph, i, rope_view_set_rows_edges) &&
14395
15850
  ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) {
@@ -14398,6 +15853,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
14398
15853
  op_srcs_fused_elementwise[0] = false;
14399
15854
  op_srcs_fused_elementwise[1] = false;
14400
15855
  op_srcs_fused_elementwise[2] = false;
15856
+ } else if (ggml_vk_can_fuse_snake(ctx, cgraph, i)) {
15857
+ ctx->num_additional_fused_ops = 4;
15858
+ fusion_string = "SNAKE";
15859
+ // elementwise=true: snake.comp is safe under exact aliasing because each
15860
+ // thread reads data_x[idx] into a register before writing data_d[idx]
15861
+ // with a data dependency on that register. The overlap check still
15862
+ // rejects partial overlaps (different base or size).
15863
+ std::fill_n(op_srcs_fused_elementwise, 5, true);
14401
15864
  } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
14402
15865
  ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
14403
15866
  ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
@@ -14524,6 +15987,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
14524
15987
  ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i];
14525
15988
  ctx->query_fusion_names[ctx->query_idx] = fusion_string;
14526
15989
  compute_ctx->s->buffer->buf.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++);
15990
+ ggml_vk_sync_buffers(ctx, compute_ctx);
14527
15991
  } else {
14528
15992
  // track a fusion string and number of fused ops for the current node_idx
14529
15993
  ctx->query_fusion_names[i] = fusion_string;
@@ -14687,6 +16151,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
14687
16151
  if (keep_pattern(topk_moe_late_softmax)) {
14688
16152
  continue;
14689
16153
  }
16154
+ if (keep_pattern(snake_pattern)) {
16155
+ continue;
16156
+ }
14690
16157
 
14691
16158
  // First, grab the next unused node.
14692
16159
  current_set.push_back(first_unused);
@@ -14709,7 +16176,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
14709
16176
  if (match_pattern(topk_moe_early_softmax_norm, j) ||
14710
16177
  match_pattern(topk_moe_sigmoid_norm_bias, j) ||
14711
16178
  match_pattern(topk_moe_early_softmax, j) ||
14712
- match_pattern(topk_moe_late_softmax, j)) {
16179
+ match_pattern(topk_moe_late_softmax, j) ||
16180
+ match_pattern(snake_pattern, j)) {
14713
16181
  continue;
14714
16182
  }
14715
16183
  bool ok = true;
@@ -14720,7 +16188,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
14720
16188
  !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT && graph->nodes[j]->op == GGML_OP_ADD) &&
14721
16189
  !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_ADD_ID) &&
14722
16190
  !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_MUL_MAT_ID && graph->nodes[j]->op == GGML_OP_MUL) &&
14723
- !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD)) {
16191
+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_ADD && graph->nodes[j]->op == GGML_OP_ADD) &&
16192
+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_ADD) &&
16193
+ !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_SSM_CONV && graph->nodes[j]->op == GGML_OP_UNARY)) {
14724
16194
  ok = false;
14725
16195
  break;
14726
16196
  }
@@ -14803,6 +16273,19 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
14803
16273
  }
14804
16274
  }
14805
16275
  }
16276
+ // SSM_CONV + ADD + UNARY: pull the consuming UNARY forward
16277
+ if (j > 0 &&
16278
+ graph->nodes[j]->op == GGML_OP_ADD &&
16279
+ graph->nodes[j-1]->op == GGML_OP_SSM_CONV) {
16280
+ for (int k = j + 1; k < std::min(j + 15, graph->n_nodes); ++k) {
16281
+ if (graph->nodes[k]->op == GGML_OP_UNARY &&
16282
+ graph->nodes[k]->src[0] == graph->nodes[j]) {
16283
+ current_set.push_back(k);
16284
+ used[k] = true;
16285
+ break;
16286
+ }
16287
+ }
16288
+ }
14806
16289
  }
14807
16290
  }
14808
16291
  // Second pass grabs view nodes.
@@ -14858,18 +16341,31 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev
14858
16341
  vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
14859
16342
  auto* cmd_buf = compute_ctx->s->buffer; // retrieve pointer before it gets reset
14860
16343
 
14861
- // the backend interface doesn't have an explicit reset, so reset it here
14862
- // before we record the command to set it
14863
- ctx->device->device.resetEvent(vkev->event);
14864
- ctx->device->device.resetFences({ vkev->fence });
16344
+ if (vkev->has_event) {
16345
+ // Move existing event into submitted
16346
+ vkev->events_submitted.push_back(vkev->event);
16347
+ }
16348
+
16349
+ // Grab the next event and record it, create one if necessary
16350
+ if (vkev->events_free.empty()) {
16351
+ vkev->event = ctx->device->device.createEvent({});
16352
+ } else {
16353
+ vkev->event = vkev->events_free.back();
16354
+ vkev->events_free.pop_back();
16355
+ }
16356
+
16357
+ vkev->has_event = true;
14865
16358
 
14866
16359
  ggml_vk_set_event(compute_ctx, vkev->event);
14867
16360
 
16361
+ vkev->tl_semaphore.value++;
16362
+ compute_ctx->s->signal_semaphores.push_back(vkev->tl_semaphore);
14868
16363
  ggml_vk_ctx_end(compute_ctx);
14869
16364
 
14870
- ggml_vk_submit(compute_ctx, {vkev->fence});
16365
+ ggml_vk_submit(compute_ctx, {});
14871
16366
  ctx->submit_pending = true;
14872
16367
  vkev->cmd_buffer = cmd_buf;
16368
+ vkev->cmd_buffer_use_counter = cmd_buf->use_counter;
14873
16369
  ctx->compute_ctx.reset();
14874
16370
  }
14875
16371
 
@@ -14880,9 +16376,10 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even
14880
16376
 
14881
16377
  vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx);
14882
16378
 
14883
- ggml_vk_wait_events(compute_ctx, {vkev->event});
14884
- ggml_vk_ctx_end(compute_ctx);
14885
- ctx->compute_ctx.reset();
16379
+ if (vkev->has_event) {
16380
+ // Wait for latest event
16381
+ ggml_vk_wait_events(compute_ctx, { vkev->event });
16382
+ }
14886
16383
  }
14887
16384
 
14888
16385
  // TODO: enable async and synchronize
@@ -14891,6 +16388,8 @@ static ggml_backend_i ggml_backend_vk_interface = {
14891
16388
  /* .free = */ ggml_backend_vk_free,
14892
16389
  /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async,
14893
16390
  /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async,
16391
+ /* .set_tensor_2d_async = */ ggml_backend_vk_set_tensor_2d_async,
16392
+ /* .get_tensor_2d_async = */ ggml_backend_vk_get_tensor_2d_async,
14894
16393
  /* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async,
14895
16394
  /* .synchronize = */ ggml_backend_vk_synchronize,
14896
16395
  /* .graph_plan_create = */ NULL,
@@ -15157,8 +16656,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15157
16656
  case GGML_GLU_OP_SWIGLU_OAI:
15158
16657
  case GGML_GLU_OP_GEGLU_ERF:
15159
16658
  case GGML_GLU_OP_GEGLU_QUICK:
15160
- return ggml_is_contiguous(op->src[0]) &&
15161
- (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
16659
+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
15162
16660
  (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
15163
16661
  (op->src[0]->type == op->type);
15164
16662
  default:
@@ -15178,6 +16676,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15178
16676
  case GGML_TYPE_F32:
15179
16677
  case GGML_TYPE_F16:
15180
16678
  case GGML_TYPE_BF16:
16679
+ case GGML_TYPE_Q1_0:
15181
16680
  case GGML_TYPE_Q4_0:
15182
16681
  case GGML_TYPE_Q4_1:
15183
16682
  case GGML_TYPE_Q5_0:
@@ -15198,6 +16697,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15198
16697
  case GGML_TYPE_IQ4_XS:
15199
16698
  case GGML_TYPE_IQ4_NL:
15200
16699
  case GGML_TYPE_MXFP4:
16700
+ case GGML_TYPE_NVFP4:
15201
16701
  break;
15202
16702
  default:
15203
16703
  return false;
@@ -15246,42 +16746,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15246
16746
  if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
15247
16747
  return false;
15248
16748
  }
15249
- // It's straightforward to support different K/V dequant, but would
15250
- // significantly increase the number of pipelines
15251
- if (op->src[1]->type != op->src[2]->type) {
15252
- return false;
15253
- }
15254
- switch (op->src[1]->type) {
15255
- case GGML_TYPE_F16:
15256
- case GGML_TYPE_F32:
15257
- case GGML_TYPE_Q4_0:
15258
- case GGML_TYPE_Q8_0:
15259
- // supported in scalar and coopmat2 paths
15260
- break;
15261
- case GGML_TYPE_Q4_1:
15262
- case GGML_TYPE_Q5_0:
15263
- case GGML_TYPE_Q5_1:
15264
- // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
15265
- //case GGML_TYPE_Q2_K:
15266
- //case GGML_TYPE_Q3_K:
15267
- //case GGML_TYPE_Q4_K:
15268
- //case GGML_TYPE_Q5_K:
15269
- //case GGML_TYPE_Q6_K:
15270
- //case GGML_TYPE_IQ1_S:
15271
- //case GGML_TYPE_IQ1_M:
15272
- //case GGML_TYPE_IQ2_XXS:
15273
- //case GGML_TYPE_IQ2_XS:
15274
- //case GGML_TYPE_IQ2_S:
15275
- //case GGML_TYPE_IQ3_XXS:
15276
- //case GGML_TYPE_IQ3_S:
15277
- //case GGML_TYPE_IQ4_XS:
15278
- case GGML_TYPE_IQ4_NL:
15279
- // currently supported only in coopmat2 path
15280
- if (!coopmat2) {
16749
+ auto fa_kv_ok = [coopmat2](ggml_type t) {
16750
+ switch (t) {
16751
+ case GGML_TYPE_F32:
16752
+ case GGML_TYPE_F16:
16753
+ case GGML_TYPE_BF16:
16754
+ case GGML_TYPE_Q8_0:
16755
+ case GGML_TYPE_Q5_1:
16756
+ case GGML_TYPE_Q5_0:
16757
+ case GGML_TYPE_Q4_1:
16758
+ case GGML_TYPE_Q4_0:
16759
+ return true;
16760
+ case GGML_TYPE_Q1_0:
16761
+ return coopmat2;
16762
+ default:
15281
16763
  return false;
15282
16764
  }
15283
- break;
15284
- default:
16765
+ };
16766
+ if (!fa_kv_ok(op->src[1]->type) || !fa_kv_ok(op->src[2]->type)) {
16767
+ return false;
16768
+ }
16769
+ if ((op->src[1]->type == GGML_TYPE_BF16) != (op->src[2]->type == GGML_TYPE_BF16)) {
15285
16770
  return false;
15286
16771
  }
15287
16772
  if (!coopmat2 && !(device->subgroup_shuffle && device->subgroup_vote)) {
@@ -15296,6 +16781,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15296
16781
  case GGML_TYPE_F32:
15297
16782
  case GGML_TYPE_F16:
15298
16783
  case GGML_TYPE_BF16:
16784
+ case GGML_TYPE_Q1_0:
15299
16785
  case GGML_TYPE_Q4_0:
15300
16786
  case GGML_TYPE_Q4_1:
15301
16787
  case GGML_TYPE_Q5_0:
@@ -15316,6 +16802,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15316
16802
  case GGML_TYPE_IQ4_XS:
15317
16803
  case GGML_TYPE_IQ4_NL:
15318
16804
  case GGML_TYPE_MXFP4:
16805
+ case GGML_TYPE_NVFP4:
15319
16806
  case GGML_TYPE_I32:
15320
16807
  return true;
15321
16808
  default:
@@ -15328,6 +16815,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15328
16815
  case GGML_TYPE_F32:
15329
16816
  case GGML_TYPE_F16:
15330
16817
  case GGML_TYPE_BF16:
16818
+ case GGML_TYPE_Q1_0:
15331
16819
  case GGML_TYPE_Q4_0:
15332
16820
  case GGML_TYPE_Q4_1:
15333
16821
  case GGML_TYPE_Q5_0:
@@ -15351,6 +16839,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15351
16839
  case GGML_TYPE_F32:
15352
16840
  case GGML_TYPE_F16:
15353
16841
  case GGML_TYPE_BF16:
16842
+ case GGML_TYPE_Q1_0:
15354
16843
  case GGML_TYPE_Q4_0:
15355
16844
  case GGML_TYPE_Q4_1:
15356
16845
  case GGML_TYPE_Q5_0:
@@ -15365,6 +16854,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15365
16854
  if (src1_type == GGML_TYPE_F32) {
15366
16855
  switch (src0_type) {
15367
16856
  case GGML_TYPE_F16:
16857
+ case GGML_TYPE_BF16:
16858
+ case GGML_TYPE_Q1_0:
15368
16859
  case GGML_TYPE_Q4_0:
15369
16860
  case GGML_TYPE_Q4_1:
15370
16861
  case GGML_TYPE_Q5_0:
@@ -15400,7 +16891,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15400
16891
  return false;
15401
16892
  }
15402
16893
  case GGML_OP_REPEAT:
15403
- return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
16894
+ return ggml_type_size(op->type) == ggml_type_size(op->src[0]->type) &&
16895
+ (ggml_type_size(op->type) == sizeof(float) || ggml_type_size(op->type) == 2);
15404
16896
  case GGML_OP_REPEAT_BACK:
15405
16897
  return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
15406
16898
  case GGML_OP_ROPE:
@@ -15492,8 +16984,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
15492
16984
  || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
15493
16985
  || (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
15494
16986
  case GGML_OP_ARANGE:
15495
- case GGML_OP_FILL:
15496
16987
  return op->type == GGML_TYPE_F32;
16988
+ case GGML_OP_FILL:
16989
+ return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
15497
16990
  case GGML_OP_SCALE:
15498
16991
  return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
15499
16992
  case GGML_OP_PAD:
@@ -15672,10 +17165,13 @@ static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t
15672
17165
  return nullptr;
15673
17166
  }
15674
17167
 
15675
- // The event/fence is expected to initially be in the signaled state.
15676
- vkev->event = device->device.createEvent({});
15677
- vkev->fence = device->device.createFence({vk::FenceCreateFlagBits::eSignaled});
15678
- device->device.setEvent(vkev->event);
17168
+ // No events initially, they get created on demand
17169
+ vkev->has_event = false;
17170
+
17171
+ vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 };
17172
+ vk::SemaphoreCreateInfo ci{};
17173
+ ci.setPNext(&tci);
17174
+ vkev->tl_semaphore = { device->device.createSemaphore(ci), 0 };
15679
17175
 
15680
17176
  return new ggml_backend_event {
15681
17177
  /* .device = */ dev,
@@ -15689,8 +17185,16 @@ static void ggml_backend_vk_device_event_free(ggml_backend_dev_t dev, ggml_backe
15689
17185
 
15690
17186
  vk_event *vkev = (vk_event *)event->context;
15691
17187
 
15692
- device->device.destroyFence(vkev->fence);
15693
- device->device.destroyEvent(vkev->event);
17188
+ device->device.destroySemaphore(vkev->tl_semaphore.s);
17189
+ for (auto& event : vkev->events_free) {
17190
+ device->device.destroyEvent(event);
17191
+ }
17192
+ for (auto& event : vkev->events_submitted) {
17193
+ device->device.destroyEvent(event);
17194
+ }
17195
+ if (vkev->has_event) {
17196
+ device->device.destroyEvent(vkev->event);
17197
+ }
15694
17198
  delete vkev;
15695
17199
  delete event;
15696
17200
  }
@@ -15701,10 +17205,29 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
15701
17205
  auto device = ggml_vk_get_device(ctx->device);
15702
17206
  vk_event *vkev = (vk_event *)event->context;
15703
17207
 
15704
- VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
15705
- // Finished using current command buffer so we flag for reuse
15706
- if (vkev->cmd_buffer) {
15707
- vkev->cmd_buffer->in_use = false;
17208
+ // Only do something if the event has actually been used
17209
+ if (vkev->has_event) {
17210
+ vk::Semaphore sem = vkev->tl_semaphore.s;
17211
+ uint64_t val = vkev->tl_semaphore.value;
17212
+ vk::SemaphoreWaitInfo swi{vk::SemaphoreWaitFlags{}, sem, val};
17213
+ VK_CHECK(device->device.waitSemaphores(swi, UINT64_MAX), "event_synchronize");
17214
+
17215
+ // Reset and move submitted events
17216
+ for (auto& event : vkev->events_submitted) {
17217
+ device->device.resetEvent(event);
17218
+ }
17219
+ vkev->events_free.insert(vkev->events_free.end(), vkev->events_submitted.begin(), vkev->events_submitted.end());
17220
+ vkev->events_submitted.clear();
17221
+
17222
+ // Finished using current command buffer so we flag for reuse
17223
+ if (vkev->cmd_buffer) {
17224
+ // Only flag for reuse if it hasn't been reused already
17225
+ if (vkev->cmd_buffer_use_counter == vkev->cmd_buffer->use_counter) {
17226
+ vkev->cmd_buffer->in_use = false;
17227
+ vkev->cmd_buffer->buf.reset();
17228
+ }
17229
+ vkev->cmd_buffer = nullptr;
17230
+ }
15708
17231
  }
15709
17232
  }
15710
17233
 
@@ -15958,6 +17481,7 @@ static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev)
15958
17481
  case 0xE20C: // B570
15959
17482
  return 18;
15960
17483
  case 0xE20B: // B580
17484
+ case 0xE211: // Pro B60
15961
17485
  return 20;
15962
17486
  default:
15963
17487
  return 0;
@@ -16450,7 +17974,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
16450
17974
  src_clone[4], src_clone[5], src_clone[6]);
16451
17975
  } else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
16452
17976
  tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
16453
- src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
17977
+ src_clone[2], src_clone[3], src_clone[4], src_clone[5],
17978
+ ggml_get_op_params_i32(tensor, 0));
16454
17979
  } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
16455
17980
  src_clone[0]->flags = tensor->src[0]->flags;
16456
17981
  tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],