@fugood/llama.node 0.3.16 → 0.4.0

This diff shows the changes between publicly released versions of this package as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (281)
  1. package/CMakeLists.txt +6 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +44 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +374 -19
  24. package/src/LlamaCompletionWorker.h +31 -10
  25. package/src/LlamaContext.cpp +216 -7
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
  29. package/src/llama.cpp/.github/workflows/build.yml +89 -767
  30. package/src/llama.cpp/.github/workflows/docker.yml +9 -6
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +19 -23
  33. package/src/llama.cpp/CMakeLists.txt +11 -1
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +35 -4
  37. package/src/llama.cpp/common/arg.cpp +844 -121
  38. package/src/llama.cpp/common/arg.h +9 -0
  39. package/src/llama.cpp/common/chat.cpp +129 -107
  40. package/src/llama.cpp/common/chat.h +2 -0
  41. package/src/llama.cpp/common/common.cpp +64 -518
  42. package/src/llama.cpp/common/common.h +35 -45
  43. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  44. package/src/llama.cpp/common/llguidance.cpp +31 -47
  45. package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
  46. package/src/llama.cpp/common/minja/minja.hpp +186 -127
  47. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  48. package/src/llama.cpp/common/regex-partial.h +56 -0
  49. package/src/llama.cpp/common/sampling.cpp +60 -50
  50. package/src/llama.cpp/docs/build.md +122 -7
  51. package/src/llama.cpp/examples/CMakeLists.txt +2 -32
  52. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
  54. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  55. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  56. package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
  57. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  58. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  59. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  60. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  61. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  62. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  64. package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
  65. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  66. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  67. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  68. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  69. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  70. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  71. package/src/llama.cpp/ggml/include/ggml.h +76 -106
  72. package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
  73. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  74. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  75. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  76. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  77. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  78. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  79. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  80. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  81. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  82. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  83. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
  84. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  85. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  86. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  87. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  88. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
  89. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  90. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
  91. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
  93. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
  94. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
  95. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
  96. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  101. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  102. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
  103. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  104. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  105. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  106. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  107. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  108. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  109. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
  110. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  111. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
  112. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  113. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
  115. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
  116. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
  117. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  119. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  120. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
  121. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  122. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  123. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  124. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  136. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  137. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  138. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  140. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  141. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
  143. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
  144. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
  145. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
  146. package/src/llama.cpp/ggml/src/ggml.c +170 -265
  147. package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
  148. package/src/llama.cpp/include/llama.h +82 -22
  149. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  150. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  151. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  152. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  153. package/src/llama.cpp/requirements/requirements-all.txt +5 -3
  154. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  155. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  156. package/src/llama.cpp/src/CMakeLists.txt +4 -2
  157. package/src/llama.cpp/src/llama-adapter.cpp +43 -1
  158. package/src/llama.cpp/src/llama-arch.cpp +163 -17
  159. package/src/llama.cpp/src/llama-arch.h +16 -0
  160. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  161. package/src/llama.cpp/src/llama-batch.h +2 -1
  162. package/src/llama.cpp/src/llama-chat.cpp +91 -16
  163. package/src/llama.cpp/src/llama-chat.h +7 -2
  164. package/src/llama.cpp/src/llama-context.cpp +479 -575
  165. package/src/llama.cpp/src/llama-context.h +44 -33
  166. package/src/llama.cpp/src/llama-cparams.h +1 -0
  167. package/src/llama.cpp/src/llama-graph.cpp +209 -157
  168. package/src/llama.cpp/src/llama-graph.h +38 -14
  169. package/src/llama.cpp/src/llama-hparams.h +13 -0
  170. package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
  171. package/src/llama.cpp/src/llama-kv-cache.h +283 -171
  172. package/src/llama.cpp/src/llama-memory.h +12 -2
  173. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  174. package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
  175. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  176. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  177. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  178. package/src/llama.cpp/src/llama-model.cpp +1803 -330
  179. package/src/llama.cpp/src/llama-model.h +21 -2
  180. package/src/llama.cpp/src/llama-quant.cpp +33 -10
  181. package/src/llama.cpp/src/llama-sampling.cpp +25 -7
  182. package/src/llama.cpp/src/llama-vocab.cpp +86 -10
  183. package/src/llama.cpp/src/llama-vocab.h +6 -0
  184. package/src/llama.cpp/src/llama.cpp +15 -1
  185. package/src/llama.cpp/tests/CMakeLists.txt +52 -31
  186. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  187. package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
  188. package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
  189. package/src/llama.cpp/tests/test-chat.cpp +15 -3
  190. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  191. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  192. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  193. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  194. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  195. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  196. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  197. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  198. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  199. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  200. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  201. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  202. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  203. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  204. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
  205. package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
  206. package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
  207. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  208. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
  209. package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
  210. package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
  211. package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
  212. package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
  213. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  214. package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
  215. package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
  216. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  217. package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
  218. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  219. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
  220. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
  221. package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
  222. package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
  223. package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
  224. package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
  225. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  226. package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
  227. package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
  228. package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
  229. package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
  230. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  231. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  232. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  233. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  234. package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
  235. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  236. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  237. package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
  238. package/src/llama.cpp/examples/llava/clip.h +0 -118
  239. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  240. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  241. package/src/llama.cpp/examples/llava/llava.cpp +0 -574
  242. package/src/llama.cpp/examples/llava/llava.h +0 -49
  243. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  244. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
  245. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  246. package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
  247. package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
  248. package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
  249. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  250. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  251. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  252. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  253. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  254. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  255. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  256. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  257. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  258. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
  259. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  260. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  261. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  262. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  263. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  264. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  265. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  266. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  267. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  268. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  269. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  270. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  271. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  272. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  273. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  274. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  275. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  276. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  277. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  278. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  279. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  280. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  281. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
@@ -24,13 +24,54 @@
 #include <future>
 #include <thread>
 
+#if defined(_MSC_VER)
+#  define NOMINMAX 1
+#  include <windows.h>
+#  define YIELD() YieldProcessor()
+#elif defined(__clang__) || defined(__GNUC__)
+#  if defined(__x86_64__) || defined(__i386__)
+#    include <immintrin.h>
+#    define YIELD() _mm_pause()
+#  elif defined(__arm__) || defined(__aarch64__)
+#    if defined(__clang__)
+#      include <arm_acle.h>
+#      define YIELD() __yield()
+#    else
+#      define YIELD() asm volatile("yield")
+#    endif
+#  endif
+#endif
+
+#if !defined(YIELD)
+#define YIELD()
+#endif
+
 #include "ggml-impl.h"
 #include "ggml-backend-impl.h"
 
 #include "ggml-vulkan-shaders.hpp"
 
+// remove this once it's more widely available in the SDK
+#if !defined(VK_KHR_shader_bfloat16)
+
+#define VK_KHR_shader_bfloat16 1
+#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
+#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
+#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
+#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
+
+typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
+    VkStructureType sType;
+    void*           pNext;
+    VkBool32        shaderBFloat16Type;
+    VkBool32        shaderBFloat16DotProduct;
+    VkBool32        shaderBFloat16CooperativeMatrix;
+} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
+#endif
+
 #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
 
 #define VK_VENDOR_ID_AMD 0x1002
 #define VK_VENDOR_ID_APPLE 0x106b
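
The hunk above adds a portable YIELD() CPU spin hint (YieldProcessor on MSVC, _mm_pause on x86, yield on ARM, a no-op fallback elsewhere) plus a strict is_pow2 helper beside the existing rounding macros. A minimal standalone sketch of how the three helpers behave; the main() harness here is illustrative, not part of the diff:

    // Semantics of the helpers shown above. ROUNDUP_POW2 assumes N is a
    // power of two; CEIL_DIV works for any N > 0; is_pow2 deliberately
    // excludes 1 (the check is x > 1).
    #include <cassert>
    #include <cstdint>

    #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
    #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
    static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }

    int main() {
        assert(ROUNDUP_POW2(100u, 32u) == 128u); // round up to a multiple of 32
        assert(CEIL_DIV(100u, 32u) == 4u);       // 100/32 rounded up
        assert(is_pow2(64u) && !is_pow2(96u) && !is_pow2(1u));
        return 0;
    }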
@@ -223,6 +264,7 @@ struct vk_device_struct {
     bool pipeline_robustness;
     vk::Device device;
     uint32_t vendor_id;
+    vk::DriverId driver_id;
     vk_device_architecture architecture;
     vk_queue compute_queue;
     vk_queue transfer_queue;
@@ -233,6 +275,9 @@
     bool prefer_host_memory;
     bool float_controls_rte_fp16;
     bool subgroup_add;
+    bool subgroup_shuffle;
+
+    bool integer_dot_product;
 
     bool subgroup_size_control;
     uint32_t subgroup_min_size;
@@ -240,11 +285,21 @@
     bool subgroup_require_full_support;
 
     bool coopmat_support;
-    bool coopmat_acc_f32_support;
-    bool coopmat_acc_f16_support;
+    bool coopmat_acc_f32_support {};
+    bool coopmat_acc_f16_support {};
+    bool coopmat_bf16_support {};
+    bool coopmat_support_16x16x16_f16acc {};
+    bool coopmat_support_16x16x16_f32acc {};
+    bool coopmat1_fa_support {};
     uint32_t coopmat_m;
     uint32_t coopmat_n;
     uint32_t coopmat_k;
+
+    bool coopmat_int_support;
+    uint32_t coopmat_int_m;
+    uint32_t coopmat_int_n;
+    uint32_t coopmat_int_k;
+
     bool coopmat2;
 
     size_t idx;
@@ -261,19 +316,24 @@
 
     vk_matmul_pipeline pipeline_matmul_f32 {};
     vk_matmul_pipeline pipeline_matmul_f32_f16 {};
+    vk_matmul_pipeline pipeline_matmul_bf16 {};
     vk_matmul_pipeline2 pipeline_matmul_f16;
     vk_matmul_pipeline2 pipeline_matmul_f16_f32;
-    vk_pipeline pipeline_matmul_split_k_reduce;
 
-    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
+    vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
 
     vk_matmul_pipeline pipeline_matmul_id_f32 {};
+    vk_matmul_pipeline pipeline_matmul_id_bf16 {};
     vk_matmul_pipeline2 pipeline_matmul_id_f16;
     vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
 
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
 
+    vk_pipeline pipeline_matmul_split_k_reduce;
+    vk_pipeline pipeline_quantize_q8_1;
+
     vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@@ -284,11 +344,17 @@
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_acc_f32;
-    vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
-    vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
-    vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
-    vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
-    vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
+
+    // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
+    vk_pipeline pipeline_add[2][2][2];
+    vk_pipeline pipeline_add_norepeat[2][2][2];
+    vk_pipeline pipeline_sub[2][2][2];
+    vk_pipeline pipeline_sub_norepeat[2][2][2];
+    vk_pipeline pipeline_mul[2][2][2];
+    vk_pipeline pipeline_mul_norepeat[2][2][2];
+    vk_pipeline pipeline_div[2][2][2];
+    vk_pipeline pipeline_div_norepeat[2][2][2];
+
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
     vk_pipeline pipeline_upscale_f32;
     vk_pipeline pipeline_scale_f32;
@@ -298,8 +364,8 @@
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
-    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
-    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
+    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
+    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_norm_f32;
@@ -307,14 +373,17 @@
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
-    vk_pipeline pipeline_gelu_f32;
-    vk_pipeline pipeline_gelu_quick_f32;
-    vk_pipeline pipeline_silu_f32;
-    vk_pipeline pipeline_silu_back_f32;
-    vk_pipeline pipeline_relu_f32;
+
+    // [src/dst 0=fp32,1=fp16]
+    vk_pipeline pipeline_gelu[2];
+    vk_pipeline pipeline_gelu_quick[2];
+    vk_pipeline pipeline_silu[2];
+    vk_pipeline pipeline_relu[2];
+    vk_pipeline pipeline_tanh[2];
+    vk_pipeline pipeline_sigmoid[2];
+
     vk_pipeline pipeline_leaky_relu_f32;
-    vk_pipeline pipeline_tanh_f32;
-    vk_pipeline pipeline_sigmoid_f32;
+    vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
@@ -333,8 +402,24 @@
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_conv2d_dw_whcn_f32;
+    vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
+    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
+
+    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
+
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
@@ -342,6 +427,8 @@
     vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
 
+    vk_pipeline pipeline_flash_attn_split_k_reduce;
+
     std::unordered_map<std::string, vk_pipeline_ref> pipelines;
     std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
 
@@ -490,6 +577,10 @@ struct vk_flash_attn_push_constants {
     uint32_t n_head_log2;
     float m0;
     float m1;
+
+    uint32_t gqa_ratio;
+    uint32_t split_kv;
+    uint32_t k_num;
 };
 
 struct vk_op_push_constants {
@@ -640,13 +731,22 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t H;
 };
 
-// Allow pre-recording command buffers
-struct vk_staging_memcpy {
-    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
-
-    void * dst;
-    const void * src;
-    size_t n;
+struct vk_op_conv2d_dw_push_constants {
+    uint32_t ne;
+    uint32_t batches;
+    uint32_t channels;
+    uint32_t dst_w;
+    uint32_t dst_h;
+    uint32_t src_w;
+    uint32_t src_h;
+    uint32_t knl_w;
+    uint32_t knl_h;
+    int32_t stride_x;
+    int32_t stride_y;
+    int32_t pad_x;
+    int32_t pad_y;
+    int32_t dilation_x;
+    int32_t dilation_y;
 };
 
 struct vk_op_upscale_push_constants {
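
vk_op_conv2d_dw_push_constants above carries the full geometry for the new depthwise-conv2d pipelines (pipeline_conv2d_dw_whcn_f32 / pipeline_conv2d_dw_cwhn_f32). As a hedged sketch, the dst_w/dst_h fields would relate to the other fields by the standard convolution output-size formula; conv_out_size below is our illustrative helper, not code from the diff:

    #include <cstdint>
    #include <cstdio>

    // Standard conv output-size relation (assumed, not taken from the diff).
    static uint32_t conv_out_size(uint32_t src, uint32_t knl,
                                  int32_t stride, int32_t pad, int32_t dilation) {
        const int32_t s = (int32_t) src, k = (int32_t) knl;
        return (uint32_t) ((s + 2 * pad - dilation * (k - 1) - 1) / stride + 1);
    }

    int main() {
        // e.g. a 64x64 input, 3x3 kernel, stride 1, pad 1, dilation 1 -> 64x64 out
        printf("dst_w = %u\n", conv_out_size(64, 3, 1, 1, 1));
        return 0;
    }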
@@ -656,6 +756,15 @@ struct vk_op_upscale_push_constants {
     float sf0; float sf1; float sf2; float sf3;
 };
 
+// Allow pre-recording command buffers
+struct vk_staging_memcpy {
+    vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
+
+    void * dst;
+    const void * src;
+    size_t n;
+};
+
 struct vk_context_struct {
     vk_submission * s;
     std::vector<vk_sequence> seqs;
@@ -770,7 +879,8 @@ struct ggml_backend_vk_context {
     ggml_vk_garbage_collector gc;
     size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
    vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
-    vk::Fence fence;
+    vk::Fence fence, almost_ready_fence;
+    bool almost_ready_fence_pending {};
 
     vk_buffer buffer_pool[MAX_VK_BUFFERS];
 
@@ -861,6 +971,39 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx
 
 static void ggml_backend_vk_free(ggml_backend_t backend);
 
+// Wait for ctx->fence to be signaled.
+static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
+    // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
+    // during this wait.
+    if (ctx->almost_ready_fence_pending) {
+        VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
+        ctx->device->device.resetFences({ ctx->almost_ready_fence });
+        ctx->almost_ready_fence_pending = false;
+    }
+
+    // Spin (w/pause) waiting for the graph to finish executing.
+    vk::Result result;
+    while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {
+        if (result != vk::Result::eNotReady) {
+            fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__);
+            exit(1);
+        }
+        for (uint32_t i = 0; i < 100; ++i) {
+            YIELD();
+            YIELD();
+            YIELD();
+            YIELD();
+            YIELD();
+            YIELD();
+            YIELD();
+            YIELD();
+            YIELD();
+            YIELD();
+        }
+    }
+    ctx->device->device.resetFences({ ctx->fence });
+}
+
 // variables to track number of compiles in progress
 static uint32_t compile_count = 0;
 static std::mutex compile_count_mutex;
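
ggml_vk_wait_for_fence above waits in two phases: a blocking waitForFences on the new almost_ready_fence, so the CPU can sleep through most of the graph, then a pause-hinted spin on getFenceStatus for the final fence to minimize wake-up latency. A rough standalone sketch of the same pattern, with standard-library primitives standing in for the two Vulkan fences:

    // Illustration only: a condition variable plays the role of
    // almost_ready_fence, an atomic flag plays the role of ctx->fence.
    #include <atomic>
    #include <condition_variable>
    #include <mutex>

    #if defined(__x86_64__) || defined(__i386__)
    #include <immintrin.h>
    #define YIELD() _mm_pause()
    #else
    #define YIELD()
    #endif

    struct two_phase_waiter {
        std::mutex m;
        std::condition_variable cv;
        bool almost_ready = false;
        std::atomic<bool> done{false};

        void wait() {
            {
                // Phase 1: sleep until most of the work has executed.
                std::unique_lock<std::mutex> lock(m);
                cv.wait(lock, [&] { return almost_ready; });
            }
            // Phase 2: spin with a pause hint for the short remainder.
            while (!done.load(std::memory_order_acquire)) {
                for (int i = 0; i < 100; ++i) {
                    YIELD();
                }
            }
        }

        void signal_almost_ready() {
            { std::lock_guard<std::mutex> lock(m); almost_ready = true; }
            cv.notify_one();
        }
        void signal_done() { done.store(true, std::memory_order_release); }
    };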
@@ -1455,15 +1598,56 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
     );
 }
 
+enum FaCodePath {
+    FA_SCALAR,
+    FA_COOPMAT1,
+    FA_COOPMAT2,
+};
+
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
-static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
+static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
+
+// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
+// 128 threads split into four subgroups, each subgroup does 1/4
+// of the Bc dimension.
+static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
+static constexpr uint32_t scalar_flash_attention_Bc = 64;
+static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
+
+static uint32_t get_fa_num_small_rows(FaCodePath path) {
+    if (path == FA_COOPMAT2) {
+        return flash_attention_num_small_rows;
+    } else {
+        return scalar_flash_attention_num_small_rows;
+    }
+}
+
+static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
     GGML_UNUSED(clamp);
 
+    if (path == FA_SCALAR) {
+        if (small_rows) {
+            return {scalar_flash_attention_num_small_rows, 64};
+        } else {
+            return {scalar_flash_attention_num_large_rows, 32};
+        }
+    }
+
+    if (path == FA_COOPMAT1) {
+        if (small_rows) {
+            return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
+        } else {
+            return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
+        }
+    }
+
     // small rows, large cols
     if (small_rows) {
-        return {flash_attention_num_small_rows, 128};
+        return {get_fa_num_small_rows(FA_COOPMAT2), 32};
     }
+
     // small cols to reduce register count
     if (ggml_is_quantized(type) || D == 256) {
         return {64, 32};
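
The reworked fa_rows_cols picks a {rows, cols} tile per code path: 1x64 or 8x32 for the scalar shader, 1x64 or 16x64 for coopmat1 (Bc fixed at 64), and the coopmat2 tiles otherwise. A standalone restatement of just the selection shown in this hunk; fa_tile is an illustrative helper, not code from the diff:

    #include <array>
    #include <cstdint>

    enum FaCodePath { FA_SCALAR, FA_COOPMAT1, FA_COOPMAT2 };

    // Values copied from the hunk above.
    static std::array<uint32_t, 2> fa_tile(FaCodePath path, bool small_rows) {
        switch (path) {
            case FA_SCALAR:   return small_rows ? std::array<uint32_t, 2>{1, 64}
                                                : std::array<uint32_t, 2>{8, 32};
            case FA_COOPMAT1: return small_rows ? std::array<uint32_t, 2>{1, 64}
                                                : std::array<uint32_t, 2>{16, 64};
            default:          return {32, 32};  // FA_COOPMAT2 small-rows case; the
                                                // large-rows tiles continue in the hunk
        }
    }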
@@ -1508,7 +1692,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     const uint32_t warps = warptile[0] / warptile[10];
 
     const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
-    const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
+    const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
     const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
 
     const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
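
The only functional change here grows the mul_mat_id row-id scratch from 3072 to 4096 entries, i.e. from 12 KiB to 16 KiB of the shared-memory budget that total_size is checked against. Stated as compile-time arithmetic:

    #include <cstdint>

    // Shared-memory cost of the row-id scratch, before and after this change.
    static_assert(3072 * sizeof(uint32_t) == 12 * 1024, "old mmid_row_ids: 12 KiB");
    static_assert(4096 * sizeof(uint32_t) == 16 * 1024, "new mmid_row_ids: 16 KiB");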
@@ -1598,6 +1782,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // mulmat
     std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
                           l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
+                          l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
                           l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
                           l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
     std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
@@ -1662,6 +1847,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
        m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
        s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
 
+        l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+        m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
+        s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
+
+        // chip specific tuning
+        if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
+            m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+        }
+
        l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
        m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
        s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
@@ -1707,6 +1901,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     if (!device->pipeline_matmul_id_f32) {
         device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
     }
+    if (!device->pipeline_matmul_bf16) {
+        device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+    if (!device->pipeline_matmul_id_bf16) {
+        device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
 
     std::vector<std::future<void>> compiles;
     auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
@@ -1742,63 +1942,75 @@
             parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };
 
-#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
-    if (device->coopmat2) {
-
-        auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-            return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
-        };
+    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
+    };
 
-        auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
-            // For large number of rows, 128 invocations seems to work best.
-            // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
-            // can't use 256 for D==80.
-            uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
-            auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
-            return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
-        };
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+        // For large number of rows, 128 invocations seems to work best.
+        // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
+        // can't use 256 for D==80.
+        // For scalar, use 128 (arbitrary)
+        uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
+            ? scalar_flash_attention_workgroup_size
+            : ((small_rows && (D % 32) == 0) ? 256 : 128);
+        auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
+
+        // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
+        // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
+        const uint32_t D_lsb = D ^ (D & (D-1));
+        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
+
+        // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+        GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
+        return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
+    };
 
-#define CREATE_FA2(TYPE, NAMELC, D) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
-
-#define CREATE_FA(TYPE, NAMELC) \
-        CREATE_FA2(TYPE, NAMELC, 64) \
-        CREATE_FA2(TYPE, NAMELC, 80) \
-        CREATE_FA2(TYPE, NAMELC, 96) \
-        CREATE_FA2(TYPE, NAMELC, 112) \
-        CREATE_FA2(TYPE, NAMELC, 128) \
-        CREATE_FA2(TYPE, NAMELC, 256)
-
-        CREATE_FA(GGML_TYPE_F16, f16)
-        CREATE_FA(GGML_TYPE_Q4_0, q4_0)
-        CREATE_FA(GGML_TYPE_Q4_1, q4_1)
-        CREATE_FA(GGML_TYPE_Q5_0, q5_0)
-        CREATE_FA(GGML_TYPE_Q5_1, q5_1)
-        CREATE_FA(GGML_TYPE_Q8_0, q8_0)
-        // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
-        //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
-        //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
-        //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
-        //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
-        //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
-        //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
-        //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
-        //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
-        //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
-        //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
-        //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
-        //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
-        //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs)
-        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
+#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+
+#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
+
+    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
+    CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
+    CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
+#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+    if (device->coopmat1_fa_support) {
+        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
+    }
+#endif
+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    if (device->coopmat2) {
+        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
+    }
+#endif
+#undef CREATE_FA2
 #undef CREATE_FA
 
+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    if (device->coopmat2) {
+
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
         ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
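
fa_spec_constants above derives D_split from the lowest set bit of D: D & (D-1) clears that bit, so XOR-ing the result back against D isolates it. A quick standalone check of the identity and the resulting clamp, using only the 8u bound from the hunk (the real code also clamps to device->subgroup_size):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int main() {
        for (uint32_t D : {64u, 80u, 96u, 112u, 128u, 256u}) {
            const uint32_t D_lsb = D ^ (D & (D - 1)); // isolate lowest set bit
            assert(D_lsb == (D & (~D + 1u)));         // classic equivalent form
            const uint32_t D_split = std::min(8u, D_lsb / 4);
            // e.g. D = 80 = 0b1010000: D_lsb = 16, so D_split = 4
            assert(D_split * 4 <= D_lsb);
        }
        return 0;
    }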
@@ -1814,6 +2026,11 @@
         CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
 
         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (device->coopmat_bf16_support) {
+            CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+        }
+#endif
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1835,6 +2052,11 @@
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
         CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+        if (device->coopmat_bf16_support) {
+            CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+        }
+#endif
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
         CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -1863,17 +2085,17 @@
         // Create 6 variants, {s,m,l}x{unaligned,aligned}
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
         if (device->mul_mat ## ID ## _l[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
         if (device->mul_mat ## ID ## _m[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
         if (device->mul_mat ## ID ## _s[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
         if (device->mul_mat ## ID ## _l[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
         if (device->mul_mat ## ID ## _m[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
         if (device->mul_mat ## ID ## _s[TYPE]) \
-            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
+            ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
 
         // Create 2 variants, {f16,f32} accumulator
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -1888,6 +2110,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+ if (device->coopmat_bf16_support) {
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
+ }
+ #endif
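The bf16 hunk above shows the backend's two-level gating, which recurs throughout this diff: the `#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)` guard keeps the bf16 shader blobs out of binaries whose glslc could not compile them, and the runtime `device->coopmat_bf16_support` check skips pipeline creation on GPUs that never reported a usable bf16 coopmat shape. A reduced sketch of the pattern (the `Device` struct and `load_bf16_pipelines` are illustrative stand-ins, not the backend's real types):

```cpp
#include <cstdio>

struct Device { bool coopmat_bf16_support = false; };

// Hypothetical loader; in the real backend this is a batch of
// ggml_vk_create_pipeline calls guarded the same two ways.
static void load_bf16_pipelines(Device &) { std::puts("bf16 coopmat pipelines created"); }

static void load_shaders(Device &device) {
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)   // compile time: blobs exist in this build
    if (device.coopmat_bf16_support) {            // run time: GPU advertises the capability
        load_bf16_pipelines(device);
    }
#else
    (void) device;                                // bf16 shaders absent from this binary
#endif
}

int main() {
    Device d;
    d.coopmat_bf16_support = true;
    load_shaders(d);
    return 0;
}
```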
 
  if (device->coopmat_acc_f16_support) {
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -1936,6 +2163,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+ if (device->coopmat_bf16_support) {
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+ }
+ #endif
 
  if (device->coopmat_acc_f16_support) {
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2000,6 +2232,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
  if (device->mul_mat ## ID ## _s[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
 
+ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ if (device->mul_mat ## ID ## _l[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ if (device->mul_mat ## ID ## _m[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ if (device->mul_mat ## ID ## _s[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
  // Create 2 variants, {f16,f32} accumulator
  #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
  CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -2010,6 +2250,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2031,10 +2273,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ }
+ #endif
+
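The `_q8_1` MMQ pipelines registered above exist to exploit VK_KHR_shader_integer_dot_product, whose headline capability (`integerDotProduct4x8BitPackedSignedAccelerated`, checked later in this diff) is a fast signed dot product of four 8-bit lanes packed into a 32-bit word, accumulated at 32 bits. A scalar C++ reference of that operation, for intuition only (this is not the shader code):

```cpp
#include <cstdint>
#include <cstdio>

// Reference semantics of the 4x8-bit packed signed dot product with 32-bit
// accumulation that the extension promises to run fast in hardware.
static int32_t dot4x8_signed_acc(uint32_t a_packed, uint32_t b_packed, int32_t acc) {
    for (int i = 0; i < 4; ++i) {
        int8_t a = (int8_t)((a_packed >> (8 * i)) & 0xff);
        int8_t b = (int8_t)((b_packed >> (8 * i)) & 0xff);
        acc += (int32_t)a * (int32_t)b;
    }
    return acc;
}

int main() {
    // bytes (1, -2, 3, -4) . (5, 6, -7, 8) = 5 - 12 - 21 - 32 = -60
    uint32_t a = 0xFC03FE01u;  // little-endian packing of {1, -2, 3, -4}
    uint32_t b = 0x08F90605u;  // little-endian packing of {5, 6, -7, 8}
    std::printf("%d\n", dot4x8_signed_acc(a, b, 0));  // prints -60
    return 0;
}
```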
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2056,6 +2310,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  #undef CREATE_MM2
+ #undef CREATE_MMQ
  #undef CREATE_MM
  } else {
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -2073,11 +2328,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
  if (device->mul_mat ## ID ## _s[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
 
+ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ if (device->mul_mat ## ID ## _l[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ if (device->mul_mat ## ID ## _m[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ if (device->mul_mat ## ID ## _s[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2099,10 +2364,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ }
+ #endif
+
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2123,8 +2400,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
- #undef CREATE_MM
  }
+ // reusing CREATE_MM from the fp32 path
+ if ((device->coopmat2 || device->coopmat_support)
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ && !device->coopmat_bf16_support
+ #endif
+ ) {
+ // use scalar tile sizes
+ l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+ m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
+
+ l_wg_denoms = {128, 128, 1 };
+ m_wg_denoms = { 64, 64, 1 };
+ s_wg_denoms = { 32, 32, 1 };
+
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+ }
+ #undef CREATE_MM
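The tail block above covers an awkward middle ground: a device whose matrix cores handle f16 but that exposed no bf16 coopmat shape still gets bf16 matmul, by reusing the scalar-path CREATE_MM with the scalar warptile and workgroup-denominator configuration. A toy decision helper mirroring that gating (the capability flags are illustrative, not the backend's exact condition):

```cpp
#include <cstdio>

struct Caps { bool coopmat; bool coopmat_bf16; };

// With matrix cores but no bf16 coopmat shape, bf16 matmul falls back to
// the scalar shader and therefore needs the scalar tile sizes.
static const char *bf16_matmul_path(Caps c) {
    if (c.coopmat && c.coopmat_bf16) return "coopmat bf16 tiles";
    if (c.coopmat)                   return "scalar tiles (bf16 fallback)";
    return "scalar tiles";
}

int main() {
    std::printf("%s\n", bf16_matmul_path({true, false}));  // scalar tiles (bf16 fallback)
    return 0;
}
```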
 
  // mul mat vec
 
@@ -2132,7 +2427,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  uint32_t rm_stdq = 1;
  uint32_t rm_kq = 2;
  if (device->vendor_id == VK_VENDOR_ID_AMD) {
- if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN
+ if (device->architecture == AMD_GCN) {
  rm_stdq = 2;
  rm_kq = 4;
  }
@@ -2143,6 +2438,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
@@ -2165,6 +2461,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
@@ -2188,6 +2485,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
@@ -2233,6 +2531,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  // get_rows
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2250,6 +2549,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2266,6 +2566,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
 
  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
 
  for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
  if (device->subgroup_add && device->subgroup_require_full_support) {
@@ -2274,21 +2576,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
  }
  }
- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+
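The new `cpy_f32_bf16` / `contig_cpy_f32_bf16` pipelines above convert f32 tensors to bfloat16, which is simply the top 16 bits of the IEEE-754 single-precision encoding (sign, full 8-bit exponent, 7 mantissa bits). A host-side reference conversion, assuming round-to-nearest-even (the shader's exact rounding mode is not shown in this diff):

```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

// bf16 keeps the sign, the full exponent, and the top 7 mantissa bits of
// an f32 value; here rounded to nearest-even before truncation.
// (NaN inputs would need a special case; omitted for brevity.)
static uint16_t f32_to_bf16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    uint32_t rounding = 0x7fffu + ((bits >> 16) & 1u);  // round half to even
    return (uint16_t)((bits + rounding) >> 16);
}

static float bf16_to_f32(uint16_t h) {
    uint32_t bits = (uint32_t)h << 16;  // zero-fill the dropped mantissa bits
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main() {
    float x = 3.14159265f;
    uint16_t h = f32_to_bf16(x);
    std::printf("%f -> 0x%04x -> %f\n", x, h, bf16_to_f32(h));  // ~3.140625
    return 0;
}
```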
  if (device->float_controls_rte_fp16) {
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
@@ -2312,19 +2619,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
 
- ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
+ auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
+ std::string s;
+ s += std::string(src0_f16 ? "_f16" : "_f32");
+ s += std::string(src1_f16 ? "_f16" : "_f32");
+ s += std::string(dst_f16 ? "_f16" : "_f32");
+ return s;
+ };
 
- ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+ #define CREATE_BINARY(name, namemod, spec) \
+ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
+ #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
+ "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
+
+ CREATE_BINARY(add, , {0})
+ CREATE_BINARY(add, _norepeat, {1})
+ CREATE_BINARY(sub, , {0})
+ CREATE_BINARY(sub, _norepeat, {1})
+ CREATE_BINARY(mul, , {0})
+ CREATE_BINARY(mul, _norepeat, {1})
+ CREATE_BINARY(div, , {0})
+ CREATE_BINARY(div, _norepeat, {1})
+ #undef CREATE_BINARY
 
- ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
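The CREATE_BINARY refactor above replaces per-type hand-written registrations with a 2x2x2 loop over the {src0, src1, dst} float widths, so each of add/sub/mul/div now gets eight typed variants plus the `_norepeat` specialization. A standalone check of the names the `get_suffix` lambda produces:

```cpp
#include <cstdio>
#include <string>

int main() {
    // Same logic as the get_suffix lambda in the diff above.
    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
        std::string s;
        s += src0_f16 ? "_f16" : "_f32";
        s += src1_f16 ? "_f16" : "_f32";
        s += dst_f16  ? "_f16" : "_f32";
        return s;
    };
    for (int s0 : {0, 1}) for (int s1 : {0, 1}) for (int d : {0, 1}) {
        // prints add_f32_f32_f32 ... add_f16_f16_f16 -- eight variants per op
        std::printf("add%s\n", get_suffix(s0, s1, d).c_str());
    }
    return 0;
}
```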
 
  ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -2345,14 +2664,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
- ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ #define CREATE_UNARY(name) \
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
+ CREATE_UNARY(gelu)
+ CREATE_UNARY(gelu_quick)
+ CREATE_UNARY(silu)
+ CREATE_UNARY(relu)
+ CREATE_UNARY(tanh)
+ CREATE_UNARY(sigmoid)
+ #undef CREATE_UNARY
+
  ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
  ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
@@ -2404,6 +2729,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+
  for (auto &c : compiles) {
  c.wait();
  }
@@ -2452,6 +2780,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
  bool pipeline_robustness = false;
  bool coopmat2_support = false;
  device->coopmat_support = false;
+ device->integer_dot_product = false;
+ bool bfloat16_support = false;
 
  for (const auto& properties : ext_props) {
  if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -2477,6 +2807,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
  } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
  !getenv("GGML_VK_DISABLE_COOPMAT2")) {
  coopmat2_support = true;
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+ !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+ device->integer_dot_product = true;
+ #endif
+ } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
+ !getenv("GGML_VK_DISABLE_BFLOAT16")) {
+ bfloat16_support = true;
  }
  }
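Both device setup and the info printer now use the same gate for optional features: the extension must appear in the device's extension list and a per-feature kill-switch environment variable must be unset. The pattern in isolation, using the real vulkan.hpp enumeration call (the helper name is ours, not the backend's):

```cpp
#include <vulkan/vulkan.hpp>
#include <cstdlib>
#include <cstring>

// True when the device advertises `ext` and the user kill-switch env var is
// unset -- the same gate used above for VK_KHR_shader_bfloat16 and
// GGML_VK_DISABLE_BFLOAT16.
static bool ext_enabled(vk::PhysicalDevice dev, const char *ext, const char *disable_env) {
    if (std::getenv(disable_env) != nullptr) {
        return false;
    }
    for (const auto &p : dev.enumerateDeviceExtensionProperties()) {
        if (std::strcmp(p.extensionName, ext) == 0) {
            return true;
        }
    }
    return false;
}
// e.g. bool bf16 = ext_enabled(phys, "VK_KHR_shader_bfloat16", "GGML_VK_DISABLE_BFLOAT16");
```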
2482
2820
 
@@ -2490,6 +2828,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
  vk::PhysicalDeviceVulkan11Properties vk11_props;
  vk::PhysicalDeviceVulkan12Properties vk12_props;
  vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+ vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
 
  props2.pNext = &props3;
  props3.pNext = &subgroup_props;
@@ -2524,9 +2863,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
  }
  #endif
 
+ if (device->integer_dot_product) {
+ last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+ last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+ }
+
  device->physical_device.getProperties2(&props2);
  device->properties = props2.properties;
  device->vendor_id = device->properties.vendorID;
+ device->driver_id = driver_props.driverID;
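The properties query above grows by appending structs to a `pNext` chain through a `VkBaseOutStructure` tail pointer, so optional structs are only linked in when the matching extension was detected. Reduced to its essentials (a sketch, not the backend's exact code):

```cpp
#include <vulkan/vulkan.h>

// Append `s` to a pNext chain and advance the tail. VkBaseOutStructure is
// the generic "sType + pNext" view every chainable Vulkan struct shares.
static void chain_append(VkBaseOutStructure **tail, void *s) {
    (*tail)->pNext = (VkBaseOutStructure *) s;
    *tail = (VkBaseOutStructure *) s;
}

// Optional property structs are linked in only when the corresponding
// extension was detected earlier, mirroring the diff.
static void query_props(VkPhysicalDevice phys, bool integer_dot_product) {
    VkPhysicalDeviceProperties2 props2 {};
    props2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PROPERTIES_2;
    VkPhysicalDeviceDriverProperties driver_props {};
    driver_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_DRIVER_PROPERTIES;
    VkPhysicalDeviceShaderIntegerDotProductPropertiesKHR idp_props {};
    idp_props.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_PROPERTIES_KHR;

    VkBaseOutStructure *tail = (VkBaseOutStructure *) &props2;
    chain_append(&tail, &driver_props);
    if (integer_dot_product) {          // only chain what will be read back
        chain_append(&tail, &idp_props);
    }
    vkGetPhysicalDeviceProperties2(phys, &props2);
}
```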
 
  const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
 
@@ -2562,6 +2907,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
  device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
  (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
 
+ device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
+
  const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
  device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -2570,6 +2918,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
  device->coopmat_support = false;
  }
 
+ device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
+
  std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
 
  // Try to find a non-graphics compute queue and transfer-focused queues
@@ -2654,6 +3004,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
  }
  #endif
 
+ #if defined(VK_KHR_shader_bfloat16)
+ VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
+ bfloat16_features.pNext = nullptr;
+ bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
+ if (bfloat16_support) {
+ last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
+ last_struct = (VkBaseOutStructure *)&bfloat16_features;
+ device_extensions.push_back("VK_KHR_shader_bfloat16");
+ }
+ #endif
+
  VkPhysicalDeviceMaintenance4Features maint4_features {};
  maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
  if (maintenance4_support) {
@@ -2662,6 +3023,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
  device_extensions.push_back("VK_KHR_maintenance4");
  }
 
+ VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+ shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+ if (device->integer_dot_product) {
+ last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+ last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+ device_extensions.push_back("VK_KHR_shader_integer_dot_product");
+ }
+
  vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
  device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -2684,6 +3053,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
  #if defined(VK_KHR_cooperative_matrix)
  device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
+
+ // coopmat1 fa shader currently assumes 32 invocations per subgroup
+ device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
+ device->subgroup_size_control && device->subgroup_min_size <= 32 &&
+ device->subgroup_max_size >= 32;
  #endif
 
  if (coopmat2_support) {
@@ -2818,6 +3192,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
  // Only enable if shape is identical
  device->coopmat_acc_f32_support = true;
  }
+ if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
+ device->coopmat_support_16x16x16_f32acc = true;
+ }
  } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
  (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
  // coopmat sizes not set yet
@@ -2830,8 +3207,41 @@ static vk_device ggml_vk_get_device(size_t idx) {
  // Only enable if shape is identical
  device->coopmat_acc_f16_support = true;
  }
+ if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
+ device->coopmat_support_16x16x16_f16acc = true;
+ }
+ }
+ } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
+ (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
+ (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
+ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
+ (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
+ device->coopmat_int_m == 0
+ ) {
+ device->coopmat_int_support = true;
+ device->coopmat_int_m = prop.MSize;
+ device->coopmat_int_n = prop.NSize;
+ device->coopmat_int_k = prop.KSize;
+ }
+ #if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+ if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
+ prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
+ prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+ (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
+ ) {
+ // coopmat sizes not set yet
+ if (device->coopmat_m == 0) {
+ device->coopmat_bf16_support = true;
+ device->coopmat_m = prop.MSize;
+ device->coopmat_n = prop.NSize;
+ device->coopmat_k = prop.KSize;
+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
+ // Only enable if shape is identical
+ device->coopmat_bf16_support = true;
  }
  }
+ #endif
  }
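For context on the loop these branches live in: VK_KHR_cooperative_matrix is queried by enumerating `VkCooperativeMatrixPropertiesKHR` entries and filtering by component types, scope, and M/N/K shape; the new branches add the int8 (Sint8 x Sint8 -> Sint32) and bf16 (BFLOAT16 x BFLOAT16 -> FLOAT32) cases alongside the existing f16 ones. A sketch of the enumeration itself (the function pointer must come from `vkGetInstanceProcAddr`, since this entry point belongs to the extension, not core Vulkan):

```cpp
#include <vulkan/vulkan.h>
#include <vector>

// fn = vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"),
// valid once VK_KHR_cooperative_matrix is available.
std::vector<VkCooperativeMatrixPropertiesKHR>
query_coopmat_shapes(VkPhysicalDevice phys,
                     PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR fn) {
    uint32_t count = 0;
    fn(phys, &count, nullptr);                 // first call: how many entries
    VkCooperativeMatrixPropertiesKHR init {};
    init.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
    std::vector<VkCooperativeMatrixPropertiesKHR> props(count, init);
    fn(phys, &count, props.data());            // second call: fill them in
    // Callers then match fields, e.g. the bf16 branch above keeps entries
    // with AType/BType == VK_COMPONENT_TYPE_BFLOAT16_KHR and
    // CType/ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR at subgroup scope.
    return props;
}
```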
 
  if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
@@ -2839,11 +3249,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
  GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
  device->coopmat_support = false;
  }
+ if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
+ device->coopmat_bf16_support = false;
+ }
  }
 
  if (device->coopmat_support) {
  device_extensions.push_back("VK_KHR_cooperative_matrix");
  }
+ #if defined(VK_KHR_shader_bfloat16)
+ if (device->coopmat_bf16_support) {
+ device_extensions.push_back("VK_KHR_shader_bfloat16");
+ }
+ #endif
  #endif
  device->name = GGML_VK_NAME + std::to_string(idx);
 
@@ -2935,25 +3353,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
  vk::PhysicalDevice physical_device = devices[dev_num];
  std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
 
- vk::PhysicalDeviceProperties2 props2;
- vk::PhysicalDeviceMaintenance3Properties props3;
- vk::PhysicalDeviceSubgroupProperties subgroup_props;
- vk::PhysicalDeviceDriverProperties driver_props;
- props2.pNext = &props3;
- props3.pNext = &subgroup_props;
- subgroup_props.pNext = &driver_props;
- physical_device.getProperties2(&props2);
-
- vk_device_architecture arch = get_device_architecture(physical_device);
- uint32_t default_subgroup_size = get_subgroup_size("", arch);
- const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
-
- const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
  bool fp16_storage = false;
  bool fp16_compute = false;
  bool coopmat_support = false;
  bool coopmat2_support = false;
+ bool integer_dot_product = false;
 
  for (auto properties : ext_props) {
  if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -2969,27 +3373,44 @@ static void ggml_vk_print_gpu_info(size_t idx) {
  } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
  !getenv("GGML_VK_DISABLE_COOPMAT2")) {
  coopmat2_support = true;
+ #endif
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+ !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+ integer_dot_product = true;
  #endif
  }
  }
 
  const vk_device_architecture device_architecture = get_device_architecture(physical_device);
 
- if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
- coopmat_support = false;
- }
-
  const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
  bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
 
  bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
- vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
+ vk::PhysicalDeviceProperties2 props2;
+ vk::PhysicalDeviceMaintenance3Properties props3;
+ vk::PhysicalDeviceSubgroupProperties subgroup_props;
+ vk::PhysicalDeviceDriverProperties driver_props;
+ vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
+ props2.pNext = &props3;
+ props3.pNext = &subgroup_props;
+ subgroup_props.pNext = &driver_props;
+
+ // Pointer to the last chain element
+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
+
+ if (integer_dot_product) {
+ last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+ last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+ }
+
+ physical_device.getProperties2(&props2);
 
  VkPhysicalDeviceFeatures2 device_features2;
  device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
  device_features2.pNext = nullptr;
- device_features2.features = (VkPhysicalDeviceFeatures)device_features;
 
  VkPhysicalDeviceVulkan11Features vk11_features;
  vk11_features.pNext = nullptr;
@@ -3002,7 +3423,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
  vk11_features.pNext = &vk12_features;
 
  // Pointer to the last chain element
- VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
+ last_struct = (VkBaseOutStructure *)&vk12_features;
 
  #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
@@ -3014,20 +3435,39 @@ static void ggml_vk_print_gpu_info(size_t idx) {
  last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
  last_struct = (VkBaseOutStructure *)&coopmat_features;
  }
+ #endif
+
+ VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+ shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+ if (integer_dot_product) {
+ last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+ last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+ }
 
  vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
 
  fp16 = fp16 && vk12_features.shaderFloat16;
 
- coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
+ uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
+ const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+ const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+
+ integer_dot_product = integer_dot_product
+ && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
+ && shader_integer_dot_product_features.shaderIntegerDotProduct;
+
+ coopmat_support = coopmat_support
+ #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+ && coopmat_features.cooperativeMatrix
  #endif
+ && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
 
  std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
 
  std::string device_name = props2.properties.deviceName.data();
- GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
  idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
- props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
+ props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
 
  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
  GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -3229,6 +3669,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
  ctx->prealloc_size_split_k = 0;
 
  ctx->fence = ctx->device->device.createFence({});
+ ctx->almost_ready_fence = ctx->device->device.createFence({});
 
  #ifdef GGML_VULKAN_CHECK_RESULTS
  const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
@@ -3277,6 +3718,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
  return ctx->device->pipeline_matmul_f32_f16;
  }
+ if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+ return ctx->device->pipeline_matmul_bf16;
+ }
  if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
  return ctx->device->pipeline_matmul_f16_f32.f16acc;
@@ -3293,6 +3737,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
  }
  }
 
+ // MMQ
+ if (src1_type == GGML_TYPE_Q8_1) {
+ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
+
+ if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ return nullptr;
+ }
+
+ return pipelines;
+ }
+
  if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
  return nullptr;
  }
@@ -3337,6 +3792,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
  switch (a_type) {
  case GGML_TYPE_F32:
  case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
  case GGML_TYPE_Q5_0:
@@ -3369,6 +3825,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
  return ctx->device->pipeline_matmul_id_f32;
  }
+ if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+ return ctx->device->pipeline_matmul_id_bf16;
+ }
  if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
  return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
@@ -3422,6 +3881,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
  switch (a_type) {
  case GGML_TYPE_F32:
  case GGML_TYPE_F16:
+ case GGML_TYPE_BF16:
  case GGML_TYPE_Q4_0:
  case GGML_TYPE_Q4_1:
  case GGML_TYPE_Q5_0:
@@ -3585,8 +4045,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
  return s;
  }
 
-
-
  static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
  const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
  const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
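The CEIL_DIV lines above are where tile shapes meet dispatch size: element counts are divided by the pipeline's `wg_denoms`, rounding up so partial tiles still get a workgroup. Worked numerically:

```cpp
#include <cstdint>
#include <cstdio>

// Round-up division, as used to turn element counts into workgroup counts.
static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    // e.g. a 1000 x 520 output with the "large" tile denominators {128, 128, 1}
    // needs ceil(1000/128) x ceil(520/128) = 8 x 5 workgroups.
    std::printf("%u x %u\n", ceil_div(1000, 128), ceil_div(520, 128));  // 8 x 5
    return 0;
}
```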
@@ -4010,14 +4468,20 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
  if (split_k == 3) {
  split_k = 2;
  }
+ if (ctx->device->coopmat2) {
+ // coopmat2 shader expects splits to be aligned to 256
+ while (split_k > 1 && ((k / split_k) % 256) != 0) {
+ split_k /= 2;
+ }
+ }
  }
  }
 
  return split_k;
  }
4018
4482
 
4019
- static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
4020
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
4483
+ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
4484
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
4021
4485
 
4022
4486
  if (ctx->device->coopmat2) {
4023
4487
  // Use large shader when the N dimension is greater than the medium shader's tile size
@@ -4042,9 +4506,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     return aligned ? mmp->a_l : mmp->l;
 }
 
-static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
-    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
+static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
+    return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
 }
 
 static void ggml_vk_matmul(
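// ---------------------------------------------------------------------------
// [Editor's sketch — not part of the upstream diff] The coopmat2 split_k clamp
// added above can be exercised in isolation. Assuming the same 256-element
// alignment requirement, the loop halves split_k until each K-chunk is a
// multiple of 256 (or split_k reaches 1):
#include <cstdint>

static uint32_t pick_split_k_aligned(uint32_t split_k, uint32_t k) {
    // Halve until k/split_k is 256-aligned, mirroring the added while-loop.
    while (split_k > 1 && ((k / split_k) % 256) != 0) {
        split_k /= 2;
    }
    return split_k;
}
// e.g. pick_split_k_aligned(4, 4096) stays 4 (1024 per split, aligned),
// while pick_split_k_aligned(4, 2560) drops to 2 (1280 per split).
// ---------------------------------------------------------------------------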
@@ -4054,7 +4518,7 @@ static void ggml_vk_matmul(
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
         uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
         uint32_t padded_n) {
-    VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
+    VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
         const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
@@ -4072,7 +4536,7 @@ static void ggml_vk_matmul(
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
 }
 
-static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
+static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
     VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
@@ -4153,6 +4617,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_cpy_f16_f16;
         }
     }
+    if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_f16_f32;
+        } else {
+            return ctx->device->pipeline_cpy_f16_f32;
+        }
+    }
+    if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_f32_bf16;
+        } else {
+            return ctx->device->pipeline_cpy_f32_bf16;
+        }
+    }
     if (src->type == GGML_TYPE_F32) {
         switch (to) {
             case GGML_TYPE_Q4_0:
@@ -4214,6 +4692,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
 }
 
+static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
+    switch(type) {
+        case GGML_TYPE_Q8_1:
+            return ctx->device->pipeline_quantize_q8_1;
+        default:
+            std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
+            GGML_ABORT("fatal error");
+    }
+}
+
+static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
+    VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
+
+    vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+
+    ggml_vk_sync_buffers(subctx);
+    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
+}
+
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
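// ---------------------------------------------------------------------------
// [Editor's sketch — not part of the upstream diff] For reference when reading
// the buffer-size arithmetic below (y_ne * ggml_type_size(GGML_TYPE_Q8_1) /
// ggml_blck_size(GGML_TYPE_Q8_1)): a Q8_1 block covers 32 floats and stores one
// int8 per value plus two fp16 scalars, matching the block_q8_1 layout quoted
// in the commented-out test further down this diff.
#include <cstdint>

struct block_q8_1_sketch {
    uint16_t d;      // fp16 delta (scale)
    uint16_t s;      // fp16 d * sum(qs[i])
    int8_t   qs[32]; // 32 quantized values
};                   // 36 bytes per 32 elements -> 1.125 bytes/element
static_assert(sizeof(block_q8_1_sketch) == 36, "32 quants + 2 fp16 scalars");
// ---------------------------------------------------------------------------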
@@ -4261,30 +4758,43 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
         !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+        (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
         !ggml_vk_dim01_contiguous(src1);
 
+    // If src0 is BF16, try to use a BF16 x BF16 multiply
+    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
+
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
+    bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
+
+    // Check for mmq first
+    vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
+
+    if (mmp == nullptr) {
+        // Fall back to f16 dequant mul mat
+        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
+        quantize_y = false;
+    }
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
+    const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
 
     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
-        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
+        mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
     }
 
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
-    const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
+    const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
+    const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
+    vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
 
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
-    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
     const int x_ne = ne01 * ne00;
     const int y_ne = padded_n * ne10;
     const int d_ne = ne11 * ne01;
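// ---------------------------------------------------------------------------
// [Editor's sketch — not part of the upstream diff; simplified types] The new
// pipeline-selection order above tries the Q8_1 MMQ path first and only falls
// back to the f16/bf16 dequant path when no MMQ pipeline exists. Condensed
// under the same assumptions:
#include <cstdint>

enum class YType { F16, Q8_1 };
struct Pick { bool mmq; YType y_type; };

static Pick pick_mul_mat_path(bool integer_dot, bool y_is_f32, bool y_contig,
                              uint64_t ne10, uint64_t ne11, bool have_q8_1_pipeline) {
    bool quantize_y = integer_dot && y_is_f32 && y_contig && (ne11 * ne10) % 4 == 0;
    if (quantize_y && have_q8_1_pipeline) {
        return { true, YType::Q8_1 };   // MMQ: src1 quantized on the fly to Q8_1
    }
    return { false, YType::F16 };       // old behavior: dequant + f16/bf16 mulmat
}
// i.e. MMQ is strictly opportunistic: a missing Q8_1 pipeline for src0's type
// silently restores the previous code path.
// ---------------------------------------------------------------------------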
@@ -4294,25 +4804,30 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
-    const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
+    const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
     const uint64_t d_sz = sizeof(float) * d_ne;
 
     vk_pipeline to_fp16_vk_0 = nullptr;
     vk_pipeline to_fp16_vk_1 = nullptr;
+    vk_pipeline to_q8_1 = nullptr;
 
     if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
     } else {
         to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
     }
     if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
     GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
     GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
 
+    if (quantize_y) {
+        to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+    }
+
     if (dryrun) {
         const uint64_t x_sz_upd = x_sz * ne02 * ne03;
         const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@@ -4326,7 +4841,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
             ctx->prealloc_size_x = x_sz_upd;
         }
-        if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
+        if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
             ctx->prealloc_size_y = y_sz_upd;
         }
         if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
@@ -4341,6 +4856,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         if (qy_needs_dequant) {
             ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
         }
+        if (quantize_y) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
+        }
         if (split_k > 1) {
             ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
         }
@@ -4376,6 +4894,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     if (qy_needs_dequant) {
         d_Y = ctx->prealloc_y;
         GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
+    } else if (quantize_y) {
+        d_Y = ctx->prealloc_y;
+        GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
     } else {
         d_Y = d_Qy;
         y_buf_offset = qy_buf_offset;
@@ -4392,6 +4913,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     if (y_non_contig) {
         ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
     }
+    if (quantize_y) {
+        ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
+    }
 
     uint32_t stride_batch_x = ne00*ne01;
     uint32_t stride_batch_y = ne10*ne11;
@@ -4400,7 +4924,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
     }
 
-    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
+    if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
         stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
     }
 
@@ -4710,6 +5234,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint64_t nb01 = src0->nb[1];
     const uint64_t nb02 = src0->nb[2];
 
+    const uint64_t nb12 = src1->nb[2];
+
     // const uint64_t ne10 = src1->ne[0];
     const uint64_t ne11 = src1->ne[1];
     const uint64_t ne12 = src1->ne[2];
@@ -4735,6 +5261,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
 
     const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
     const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
+    const uint32_t channel_stride_y = nb12 / sizeof(float);
 
     const uint64_t qx_sz = ggml_nbytes(src0);
     const uint64_t qy_sz = ggml_nbytes(src1);
@@ -4765,7 +5292,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
     const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
 
     // compute
-    const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
+    const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
     ggml_vk_sync_buffers(subctx);
     ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
         { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
@@ -4790,7 +5317,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
     // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
     // when ne12 and ne13 are one.
     } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
-        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
+        (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
         ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
     } else {
         ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -4817,7 +5344,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 
     const uint64_t nei0 = ids->ne[0];
     const uint64_t nei1 = ids->ne[1];
-    GGML_ASSERT(nei0 * nei1 <= 3072);
+    GGML_ASSERT(nei0 * nei1 <= 4096);
 
     const uint32_t nbi1 = ids->nb[1];
     const uint32_t nbi2 = ids->nb[2];
@@ -4858,27 +5385,31 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
         !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+        (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
         !ggml_vk_dim01_contiguous(src1);
 
+    // If src0 is BF16, try to use a BF16 x BF16 multiply
+    ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
+
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
-    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
+    vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
 
     const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-    const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
+    const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
 
     if (qx_needs_dequant) {
         // Fall back to dequant + f16 mulmat
-        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
+        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
     }
 
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
+    const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
 
-    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
+    vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
 
     // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
     uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
@@ -4897,12 +5428,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     vk_pipeline to_fp16_vk_1 = nullptr;
 
     if (x_non_contig) {
-        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
+        to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
     } else {
         to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
     }
     if (y_non_contig) {
-        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
+        to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
     } else {
         to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
     }
@@ -5212,6 +5743,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
         }
     }
 
+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
+    // Needs to be kept up to date on shader changes
+    const uint32_t wg_size = scalar_flash_attention_workgroup_size;
+    const uint32_t Br = scalar_flash_attention_num_large_rows;
+    const uint32_t Bc = scalar_flash_attention_Bc;
+
+    const uint32_t acctype = f32acc ? 4 : 2;
+    const uint32_t f16vec4 = 8;
+
+    const uint32_t tmpsh = wg_size * sizeof(float);
+    const uint32_t tmpshv4 = wg_size * 4 * acctype;
+
+    const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
+
+    const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+    const uint32_t sfsh = Bc * sfshstride * acctype;
+
+    const uint32_t kshstride = D / 4 + 2;
+    const uint32_t ksh = Bc * kshstride * f16vec4;
+
+    const uint32_t slope = Br * sizeof(float);
+
+    const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
+    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
+
+    VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
+
+    return supported;
+}
+
 static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
     VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
     std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
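// ---------------------------------------------------------------------------
// [Editor's sketch — not part of the upstream diff] The shared-memory estimate
// above, restated as a standalone function. The wg_size/Br/Bc parameters are
// placeholders here; the real values come from the scalar_flash_attention_*
// constants in the shader headers.
#include <cstdint>

static uint32_t fa_coopmat_shmem_bytes(uint32_t D, bool f32acc,
                                       uint32_t wg_size, uint32_t Br, uint32_t Bc) {
    const uint32_t acctype    = f32acc ? 4 : 2;  // bytes per accumulator element
    const uint32_t f16vec4    = 8;               // bytes per f16vec4
    const uint32_t tmpsh      = wg_size * 4;     // wg_size floats
    const uint32_t tmpshv4    = wg_size * 4 * acctype;
    const uint32_t Qf         = Br * (D / 4 + 2) * f16vec4;
    const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
    const uint32_t sfsh       = Bc * sfshstride * acctype;
    const uint32_t kshstride  = D / 4 + 2;
    const uint32_t ksh        = Bc * kshstride * f16vec4;
    const uint32_t slope      = Br * 4;          // Br floats
    return tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
}
// The result is compared against maxComputeSharedMemorySize (commonly 32-48 KiB);
// head sizes whose tiles no longer fit fall back from coopmat1 to the scalar path.
// ---------------------------------------------------------------------------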
@@ -5232,7 +5793,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const uint32_t nbm1 = mask ? mask->nb[1] : 0;
 
     const uint32_t D = neq0;
-    const uint32_t N = neq1;
+    uint32_t N = neq1;
     const uint32_t KV = nek1;
 
     GGML_ASSERT(ne0 == D);
@@ -5262,20 +5823,110 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     assert(q->type == GGML_TYPE_F32);
     assert(k->type == v->type);
 
+    FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
+                      ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
+
+    if (path == FA_COOPMAT1) {
+        const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
+                                             (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
+
+        const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
+
+        if (!coopmat_shape_supported || !coopmat_shmem_supported) {
+            path = FA_SCALAR;
+        }
+    }
+
+    uint32_t gqa_ratio = 1;
+    uint32_t qk_ratio = neq2 / nek2;
+    uint32_t workgroups_x = (uint32_t)neq1;
+    uint32_t workgroups_y = (uint32_t)neq2;
+    uint32_t workgroups_z = (uint32_t)neq3;
+
+    // For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
+    // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
+    uint32_t max_gqa;
+    switch (path) {
+    case FA_SCALAR:
+    case FA_COOPMAT1:
+        // We may switch from coopmat1 to scalar, so use the scalar limit for both
+        max_gqa = scalar_flash_attention_num_large_rows;
+        break;
+    case FA_COOPMAT2:
+        max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
+        break;
+    default:
+        GGML_ASSERT(0);
+    }
+
+    if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
+        qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
+        // grouped query attention - make the N dimension equal to gqa_ratio, reduce
+        // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
+        // and change addressing calculations to index Q's dimension 2.
+        gqa_ratio = qk_ratio;
+        N = gqa_ratio;
+        workgroups_y /= N;
+    }
+
     vk_pipeline *pipelines;
-    // XXX TODO other backends may be changing accumulator precision to default to f32 soon
-    bool f32acc = dst->op_params[3] == GGML_PREC_F32;
-    bool small_rows = N <= flash_attention_num_small_rows;
-    switch (D) {
-    case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
-    case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
-    case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
-    case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
-    case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
-    case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
+    bool small_rows = N <= get_fa_num_small_rows(path);
+
+    // coopmat1 does not actually support "small rows" (it needs 16 rows).
+    // So use scalar instead.
+    if (small_rows && path == FA_COOPMAT1) {
+        path = FA_SCALAR;
+    }
+
+    // scalar is faster than coopmat2 when N==1
+    if (N == 1 && path == FA_COOPMAT2) {
+        path = FA_SCALAR;
+    }
+
+    bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
+
+    switch (path) {
+    case FA_SCALAR:
+        switch (D) {
+        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
+        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
+        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
+        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
+        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
+        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
+        default:
+            GGML_ASSERT(!"unsupported D value");
+            return;
+        }
+        break;
+    case FA_COOPMAT1:
+        switch (D) {
+        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
+        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
+        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
+        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
+        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
+        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
+        default:
+            GGML_ASSERT(!"unsupported D value");
+            return;
+        }
+        break;
+    case FA_COOPMAT2:
+        switch (D) {
+        case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
+        case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
+        case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
+        case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
+        case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
+        case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
+        default:
+            GGML_ASSERT(!"unsupported D value");
+            return;
+        }
+        break;
     default:
-        assert(!"unsupported D value");
-        return;
+        GGML_ASSERT(0);
     }
     assert(pipelines);
 
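// ---------------------------------------------------------------------------
// [Editor's note — illustrative arithmetic, not part of the upstream diff] For
// single-token decode (N == 1) with, say, 32 Q heads over 8 KV heads, qk_ratio
// is 4: the logic above folds the four Q heads that share one KV head into the
// row dimension (N becomes 4) and shrinks workgroups_y from 32 to 8, so each
// workgroup reuses its K/V tile four times instead of reloading it per head.
// ---------------------------------------------------------------------------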
@@ -5287,12 +5938,47 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         // the "aligned" shader variant will forcibly align strides, for performance
         (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
 
+    // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+    GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
+
     vk_pipeline pipeline = pipelines[aligned];
     assert(pipeline);
 
+    uint32_t split_kv = KV;
+    uint32_t split_k = 1;
+
+    // Use a placeholder core count if one isn't available. split_k is a big help for perf.
+    const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
+
+    // Try to use split_k when KV is large enough to be worth the overhead
+    if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
+        // Try to run two workgroups per SM.
+        split_k = ctx->device->shader_core_count * 2 / workgroups_y;
+        if (split_k > 1) {
+            // Try to evenly split KV into split_k chunks, but it needs to be a multiple
+            // of "align", so recompute split_k based on that.
+            split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
+            split_k = CEIL_DIV(KV, split_kv);
+            workgroups_x = split_k;
+        }
+    }
+
+    // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
+    // and the per-row m and L values (ne1 rows).
+    const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
+    if (split_k_size > ctx->device->max_memory_allocation_size) {
+        GGML_ABORT("Requested preallocation size is too large");
+    }
+    if (ctx->prealloc_size_split_k < split_k_size) {
+        ctx->prealloc_size_split_k = split_k_size;
+    }
+
     if (dryrun) {
         // Request descriptor sets
         ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
+        if (split_k > 1) {
+            ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
+        }
         return;
     }
 
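// ---------------------------------------------------------------------------
// [Editor's sketch — not part of the upstream diff] The split_k sizing above as
// standalone arithmetic; ROUNDUP_POW2/CEIL_DIV are assumed to behave as their
// names suggest (align is a power of two, so multiple-of rounding suffices):
#include <cstdint>

static inline uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }
static inline uint32_t round_up(uint32_t a, uint32_t b) { return ceil_div(a, b) * b; }

// e.g. KV = 4096, 16 shader cores, workgroups_y = 8, align = 256:
//   split_k  = 16 * 2 / 8            = 4
//   split_kv = round_up(4096/4, 256) = 1024
//   split_k  = ceil_div(4096, 1024)  = 4 splits of 1024 KV positions each,
// and the temporary buffer holds (D*ne1 + 2*ne1) floats per split.
// ---------------------------------------------------------------------------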
@@ -5313,8 +5999,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
     const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_vk_sync_buffers(subctx);
-
     vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
     size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
 
@@ -5379,16 +6063,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
         v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
         nbm1,
         scale, max_bias, logit_softcap,
-        mask != nullptr, n_head_log2, m0, m1 };
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
-        {
-            vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
-            vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
-        },
-        sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
+        mask != nullptr, n_head_log2, m0, m1,
+        gqa_ratio, split_kv, split_k };
+
+    ggml_vk_sync_buffers(subctx);
+
+    if (split_k > 1) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {
+                vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+            },
+            // We only use split_k when group query attention is enabled, which means
+            // there's no more than one tile of rows (i.e. workgroups_x would have been
+            // one). We reuse workgroups_x to mean the number of splits, so we need to
+            // cancel out the divide by wg_denoms[0].
+            sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
+
+        ggml_vk_sync_buffers(subctx);
+        const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
+        ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
+            {
+                vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
+                vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+            },
+            pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
+    } else {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
+            {
+                vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
+                vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
+            },
+            sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
+    }
 }
 
 static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
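// ---------------------------------------------------------------------------
// [Editor's sketch — not part of the upstream diff] Conceptually, the
// flash_attn_split_k_reduce pass dispatched above merges the per-split partial
// results (O_i, m_i, L_i) with the usual streaming-softmax renormalization. A
// CPU rendering of that merge for one output row, under an assumed layout:
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

static std::vector<float> merge_fa_splits(const std::vector<std::vector<float>>& O,
                                          const std::vector<float>& m,
                                          const std::vector<float>& L) {
    const size_t splits = O.size(), D = O[0].size();
    float m_max = -INFINITY;
    for (float mi : m) m_max = std::max(m_max, mi);       // global row max
    float L_total = 0.0f;
    std::vector<float> out(D, 0.0f);
    for (size_t i = 0; i < splits; i++) {
        const float scale = std::exp(m[i] - m_max);       // rescale each split
        L_total += L[i] * scale;
        for (size_t d = 0; d < D; d++) out[d] += O[i][d] * scale;
    }
    for (size_t d = 0; d < D; d++) out[d] /= L_total;     // final normalization
    return out;
}
// ---------------------------------------------------------------------------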
@@ -5408,26 +6121,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     }
     return nullptr;
     case GGML_OP_ADD:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
+    case GGML_OP_SUB:
+    case GGML_OP_MUL:
+    case GGML_OP_DIV:
+        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+            (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
+            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
+            return nullptr;
         }
-        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
+        switch (op) {
+        case GGML_OP_ADD:
+        {
+            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
-        return nullptr;
-    case GGML_OP_SUB:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
+        case GGML_OP_SUB:
+        {
+            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
+            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
-        return nullptr;
-    case GGML_OP_MUL:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
+        case GGML_OP_MUL:
+        {
+            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
+            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
-        return nullptr;
-    case GGML_OP_DIV:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
+        case GGML_OP_DIV:
+        {
+            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
+            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+        }
+        default:
+            break;
         }
         return nullptr;
     case GGML_OP_CONCAT:
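// ---------------------------------------------------------------------------
// [Editor's note — not part of the upstream diff] The refactor above replaces
// per-type fields (pipeline_add_f32, pipeline_add_f16_f32_f16, ...) with a 3-D
// table indexed by whether src0/src1/dst are F16. Assuming that layout:
//   pipeline_add[0][0][0]  -> the old f32 kernel
//   pipeline_add[1][0][1]  -> the old f16 + f32 -> f16 kernel
// Unsupported type combinations are rejected up front, and all four binary ops
// (ADD/SUB/MUL/DIV) now share one selection path.
// ---------------------------------------------------------------------------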
@@ -5442,7 +6166,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     }
     return nullptr;
     case GGML_OP_UPSCALE:
-        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
             return ctx->device->pipeline_upscale_f32;
         }
         return nullptr;
@@ -5521,37 +6245,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     }
     return nullptr;
     case GGML_OP_UNARY:
+        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
+            (src0->type != dst->type)) {
+            return nullptr;
+        }
+
         switch (ggml_get_unary_op(dst)) {
             case GGML_UNARY_OP_SILU:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_silu_f32;
-                }
-                break;
+                return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_GELU:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_gelu_f32;
-                }
-                break;
+                return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
            case GGML_UNARY_OP_GELU_QUICK:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_gelu_quick_f32;
-                }
-                break;
+                return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
            case GGML_UNARY_OP_RELU:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_relu_f32;
-                }
-                break;
+                return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
            case GGML_UNARY_OP_TANH:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_tanh_f32;
-                }
-                break;
+                return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
            case GGML_UNARY_OP_SIGMOID:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_sigmoid_f32;
-                }
-                break;
+                return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
            default:
                break;
        }
@@ -5674,6 +6386,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_leaky_relu_f32;
         }
         return nullptr;
+    case GGML_OP_CONV_2D_DW:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            if (ggml_is_contiguous(src1)) {
+                return ctx->device->pipeline_conv2d_dw_whcn_f32;
+            } else if (ggml_is_contiguous_channels(src1)) {
+                return ctx->device->pipeline_conv2d_dw_cwhn_f32;
+            }
+        }
+        return nullptr;
     default:
         return nullptr;
     }
@@ -5699,6 +6420,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
     case GGML_OP_REPEAT:
     case GGML_OP_REPEAT_BACK:
     case GGML_OP_ROPE:
+    case GGML_OP_RMS_NORM:
+    case GGML_OP_CONV_2D_DW:
         return true;
     default:
         return false;
@@ -5909,7 +6632,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 
     switch (op) {
     case GGML_OP_NORM:
-    case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
     case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
@@ -5926,6 +6648,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
             elements = { nr, 1, 1 };
         }
     } break;
+    case GGML_OP_RMS_NORM:
+        elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
+        break;
+
     case GGML_OP_SUM:
         // We use GGML_OP_SUM_ROWS with 1 row.
         elements = { 1, 1, 1 };
@@ -5992,6 +6718,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_CONCAT:
     case GGML_OP_UPSCALE:
     case GGML_OP_UNARY:
+    case GGML_OP_CONV_2D_DW:
         {
             const uint32_t ne = ggml_nelements(dst);
             if (ne > 262144) {
@@ -6576,7 +7303,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
 
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+    const uint32_t src0_type_size = ggml_type_size(src0->type);
+    const uint32_t dst_type_size = ggml_type_size(dst->type);
+
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+        (uint32_t)ggml_nelements(src0),
+        (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+        (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+        0,
+        op_params[0], 0.0f,
+        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+    }, dryrun);
 }
 
 static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
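// ---------------------------------------------------------------------------
// [Editor's note — not part of the upstream diff] RMS_NORM now goes through
// vk_op_unary_push_constants so it can handle non-contiguous inputs: strides
// are passed in elements (nb[i] / type size), and the dispatch uses one
// workgroup per row via elements = { ne01, ne02, ne03 }, as added to
// ggml_vk_op_f32 above. This pairs with RMS_NORM joining the
// ggml_vk_op_supports_incontiguous list earlier in the diff.
// ---------------------------------------------------------------------------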
@@ -6768,6 +7505,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
     }, dryrun);
 }
 
+static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    vk_op_conv2d_dw_push_constants p{};
+    p.ne = ggml_nelements(dst);
+    p.channels = dst->ne[2];
+    p.batches = dst->ne[3];
+    p.dst_w = dst->ne[0];
+    p.dst_h = dst->ne[1];
+    p.src_w = src1->ne[0];
+    p.src_h = src1->ne[1];
+    p.knl_w = src0->ne[0];
+    p.knl_h = src0->ne[1];
+    p.stride_x = dst->op_params[0];
+    p.stride_y = dst->op_params[1];
+    p.pad_x = dst->op_params[2];
+    p.pad_y = dst->op_params[3];
+    p.dilation_x = dst->op_params[4];
+    p.dilation_y = dst->op_params[5];
+
+    GGML_ASSERT(src0->ne[3] == p.channels);
+    GGML_ASSERT(src1->ne[3] == p.batches);
+
+    ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
+}
+
 static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     const float * op_params = (const float *)dst->op_params;
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
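// ---------------------------------------------------------------------------
// [Editor's note — not part of the upstream diff] The push constants above
// mirror ggml's depthwise-conv parameters; for a consistent configuration the
// spatial sizes relate by the standard convolution formula:
//
//   dst_w = (src_w + 2*pad_x - dilation_x*(knl_w - 1) - 1) / stride_x + 1
//
// e.g. src_w = 64, knl_w = 3, stride_x = 1, pad_x = 1, dilation_x = 1 gives
// dst_w = 64 (same-size output); the height dimension follows analogously.
// ---------------------------------------------------------------------------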
@@ -6929,6 +7690,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
         }
     }
 
+    if (ctx->device->need_compiles) {
+        ggml_vk_load_shaders(ctx->device);
+    }
+
     ggml_pipeline_allocate_descriptor_sets(ctx->device);
 
     vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -7177,6 +7942,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
 
     ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
 
+    if (ctx->device->need_compiles) {
+        ggml_vk_load_shaders(ctx->device);
+    }
+
     ggml_pipeline_allocate_descriptor_sets(ctx->device);
 
     ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
@@ -7236,66 +8005,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
7236
8005
  free(x_chk);
7237
8006
  }
7238
8007
 
7239
- static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
8008
+ // This does not work without ggml q8_1 quantization support
8009
+ //
8010
+ // typedef uint16_t ggml_half;
8011
+ // typedef uint32_t ggml_half2;
8012
+ //
8013
+ // #define QK8_1 32
8014
+ // typedef struct {
8015
+ // union {
8016
+ // struct {
8017
+ // ggml_half d; // delta
8018
+ // ggml_half s; // d * sum(qs[i])
8019
+ // } GGML_COMMON_AGGR_S;
8020
+ // ggml_half2 ds;
8021
+ // } GGML_COMMON_AGGR_U;
8022
+ // int8_t qs[QK8_1]; // quants
8023
+ // } block_q8_1;
8024
+ //
8025
+ // static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
8026
+ // VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
8027
+ // GGML_ASSERT(quant == GGML_TYPE_Q8_1);
8028
+ //
8029
+ // const size_t x_sz = sizeof(float) * ne;
8030
+ // const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
8031
+ // float * x = (float *) malloc(x_sz);
8032
+ // block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
8033
+ // block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
8034
+ // vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
8035
+ // vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
8036
+ //
8037
+ // for (size_t i = 0; i < ne; i++) {
8038
+ // x[i] = rand() / (float)RAND_MAX;
8039
+ // }
8040
+ //
8041
+ // vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
8042
+ //
8043
+ // ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
8044
+ //
8045
+ // if (ctx->device->need_compiles) {
8046
+ // ggml_vk_load_shaders(ctx->device);
8047
+ // }
8048
+ //
8049
+ // ggml_pipeline_allocate_descriptor_sets(ctx->device);
8050
+ //
8051
+ // ggml_vk_buffer_write(x_buf, 0, x, x_sz);
8052
+ //
8053
+ // vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
8054
+ // ggml_vk_ctx_begin(ctx->device, subctx);
8055
+ // ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
8056
+ // ggml_vk_ctx_end(subctx);
8057
+ //
8058
+ // auto begin = std::chrono::high_resolution_clock::now();
8059
+ //
8060
+ // ggml_vk_submit(subctx, ctx->fence);
8061
+ // VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
8062
+ // ctx->device->device.resetFences({ ctx->fence });
8063
+ //
8064
+ // auto end = std::chrono::high_resolution_clock::now();
8065
+ //
8066
+ // double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
8067
+ // ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
8068
+ //
8069
+ // ggml_vk_quantize_data(x, qx_res, ne, quant);
8070
+ //
8071
+ // int first_err = -1;
8072
+ //
8073
+ // for (size_t i = 0; i < ne / 32; i++) {
8074
+ // double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
8075
+ //
8076
+ // if (first_err < 0 && error > 0.1) {
8077
+ // first_err = i;
8078
+ // }
8079
+ //
8080
+ // error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
8081
+ //
8082
+ // if (first_err < 0 && error > 0.1) {
8083
+ // first_err = i;
8084
+ // }
8085
+ //
8086
+ // for (size_t j = 0; j < 32; j++) {
8087
+ // uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
8088
+ //
8089
+ // if (first_err < 0 && error > 1) {
8090
+ // first_err = i;
8091
+ // }
8092
+ // }
8093
+ // }
8094
+ //
8095
+ // std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
8096
+ //
8097
+ // if (first_err != -1) {
8098
+ // std::cerr << "first_error = " << first_err << std::endl;
8099
+ // std::cerr << "Actual result: " << std::endl << std::endl;
8100
+ // std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
8101
+ // for (size_t j = 0; j < 32; j++) {
8102
+ // std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
8103
+ // }
8104
+ // std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
8105
+    //     std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
+    //     for (size_t j = 0; j < 32; j++) {
+    //         std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
+    //     }
+    //     std::cerr << std::endl;
+    // }
+    //
+    // ggml_vk_destroy_buffer(x_buf);
+    // ggml_vk_destroy_buffer(qx_buf);
+    //
+    // free(x);
+    // free(qx);
+    // free(qx_res);
+// }
+
+static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
     VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
     const size_t x_ne = m * k * batch;
     const size_t y_ne = k * n * batch;
     const size_t d_ne = m * n * batch;
 
+    vk_matmul_pipeline2 * pipelines;
+
+    if (mmq) {
+        pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
+    } else {
+        pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
+    }
+
+    const bool fp16acc = ctx->device->fp16;
+
     vk_pipeline p;
     std::string shname;
     if (shader_size == 0) {
-        p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
+        p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
     } else if (shader_size == 1) {
-        p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
+        p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
     } else if (shader_size == 2) {
-        p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
+        p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
         shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
     } else {
         GGML_ASSERT(0);
     }
 
-    const size_t kpad = ggml_vk_align_size(k, p->align);
+    const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
 
-    if (k != kpad) {
+    if (mmq || k != kpad) {
         if (shader_size == 0) {
-            p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
+            p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
             shname = std::string(ggml_type_name(quant)) + "_S";
         } else if (shader_size == 1) {
-            p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
+            p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
             shname = std::string(ggml_type_name(quant)) + "_M";
         } else if (shader_size == 2) {
-            p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
+            p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
             shname = std::string(ggml_type_name(quant)) + "_L";
         } else {
            GGML_ASSERT(0);
        }
    }
 
+    if (p == nullptr) {
+        std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
+        return;
+    }
+
     const size_t x_sz = sizeof(float) * x_ne;
     const size_t y_sz = sizeof(float) * y_ne;
     const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
+    const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
     const size_t d_sz = sizeof(float) * d_ne;
     float * x = (float *) malloc(x_sz);
     float * y = (float *) malloc(y_sz);
     void * qx = malloc(qx_sz);
     vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
     vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
+    vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
     vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
     float * d = (float *) malloc(d_sz);
     float * d_chk = (float *) malloc(d_sz);
 
     for (size_t i = 0; i < x_ne; i++) {
         x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+        // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
+        // x[i] = i % k;
     }
 
     ggml_vk_quantize_data(x, qx, x_ne, quant);
 
     for (size_t i = 0; i < y_ne; i++) {
-        // y[i] = rand() / (float)RAND_MAX;
-        y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+        y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
+        // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
+        // y[i] = i % k;
     }
 
     ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
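Note on the hunk above: the test resolves one of six matmul pipeline variants (aligned/unaligned × small/medium/large tiles), each in an fp16- or fp32-accumulator flavor, and the new `mmq` flag forces the unaligned path by zeroing `kpad`. A minimal sketch of that selection logic, under stated assumptions — the struct and names below are hypothetical stand-ins for the `vk_matmul_pipeline2` fields used in the diff:

```cpp
#include <cstddef>

struct Variants {
    void *a_s, *a_m, *a_l;  // ALIGNED small / medium / large (k padded to shader alignment)
    void *s,   *m,   *l;    // unaligned fallbacks
};

// mmq kernels consume pre-quantized q8_1 data, so the k-padding trick used by the
// aligned variants does not apply; the diff forces kpad = 0 for mmq, which always
// routes it to the unaligned s/m/l pipelines.
static void * select_variant(const Variants & v, size_t shader_size, size_t k, size_t align, bool mmq) {
    const size_t kpad    = mmq ? 0 : (k + align - 1) / align * align;  // ggml_vk_align_size
    const bool unaligned = mmq || k != kpad;
    switch (shader_size) {
        case 0:  return unaligned ? v.s : v.a_s;
        case 1:  return unaligned ? v.m : v.a_m;
        default: return unaligned ? v.l : v.a_l;
    }
}
```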
@@ -7310,6 +8211,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
             ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
         }
     }
+    if (mmq) {
+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
+    }
+
+    if (ctx->device->need_compiles) {
+        ggml_vk_load_shaders(ctx->device);
+    }
 
     ggml_pipeline_allocate_descriptor_sets(ctx->device);
 
@@ -7318,13 +8226,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
 
     vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
     ggml_vk_ctx_begin(ctx->device, subctx);
-    for (size_t i = 0; i < num_it; i++) {
-        ggml_vk_matmul(
-            ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
-            m, n, k,
-            k, k, m, k*m, k*n, m*n,
-            split_k, batch, batch, batch, 1, 1, n
-        );
+    if (mmq) {
+        for (size_t i = 0; i < num_it; i++) {
+            ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
+            ggml_vk_matmul(
+                ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
+                m, n, k,
+                k, k, m, k*m, k*n, m*n,
+                split_k, batch, batch, batch, 1, 1, n
+            );
+        }
+    } else {
+        for (size_t i = 0; i < num_it; i++) {
+            ggml_vk_matmul(
+                ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
+                m, n, k,
+                k, k, m, k*m, k*n, m*n,
+                split_k, batch, batch, batch, 1, 1, n
+            );
+        }
     }
     ggml_vk_ctx_end(subctx);
 
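The mmq loop above quantizes src1 to Q8_1 on the GPU on every iteration before running the integer matmul. As a rough CPU reference of what that quantize step computes per 32-element block — an illustrative assumption, not the shader source, and using `float` where ggml packs fp16 fields:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

struct block_q8_1_ref {            // simplified: float instead of ggml_half
    float  d;                      // scale = amax / 127
    float  s;                      // d * sum(qs), precomputed so the mmq kernel can
                                   // fold in the asymmetric part of q4_0/q4_1 blocks
    int8_t qs[32];
};

static std::vector<block_q8_1_ref> quantize_q8_1_ref(const float * y, size_t n) {
    std::vector<block_q8_1_ref> out(n / 32);
    for (size_t b = 0; b < out.size(); b++) {
        const float * src = y + b * 32;
        float amax = 0.0f;
        for (int i = 0; i < 32; i++) amax = std::max(amax, std::fabs(src[i]));
        const float d  = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f / d : 0.0f;
        int sum = 0;
        for (int i = 0; i < 32; i++) {
            out[b].qs[i] = (int8_t) std::lround(src[i] * id);
            sum += out[b].qs[i];
        }
        out[b].d = d;
        out[b].s = d * sum;
    }
    return out;
}
```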
@@ -7382,7 +8302,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
 
     double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
 
-    std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
+    std::cerr << "TEST dequant matmul " << shname;
+    if (mmq) {
+        std::cerr << " mmq";
+    }
+    std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
 
     if (avg_err > 0.01 || std::isnan(avg_err)) {
         std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
@@ -7392,6 +8316,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
         std::cerr << "Expected result: " << std::endl << std::endl;
         ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
 
+        std::cerr << "src0: " << std::endl << std::endl;
+        ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
+        std::cerr << std::endl;
+        std::cerr << "src1: " << std::endl << std::endl;
+        ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
+
         if (split_k > 1) {
             float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
             ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
@@ -7414,6 +8344,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
 
     ggml_vk_destroy_buffer(qx_buf);
     ggml_vk_destroy_buffer(y_buf);
+    ggml_vk_destroy_buffer(qy_buf);
     ggml_vk_destroy_buffer(d_buf);
 
     free(x);
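For the shapes the next hunk feeds in (m = 4096, n = 512, k = 4096, batch = 2), the `tflops` expression above works out to about 34.36 GFLOP per iteration. A quick standalone sanity check of that arithmetic, with a hypothetical 10 ms/iteration timing:

```cpp
#include <cstdio>

int main() {
    const double m = 4096, n = 512, k = 4096, batch = 2;
    const double flop_per_it    = 2.0 * m * n * k * batch;  // 34 359 738 368 FLOP
    const double time_ms_per_it = 10.0;                     // hypothetical timing
    const double tflops = flop_per_it / (time_ms_per_it / 1000.0) / 1e12;
    std::printf("%.2f GFLOP/it -> %.2f TFLOPS\n", flop_per_it / 1e9, tflops);  // ~3.44
}
```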
@@ -7448,6 +8379,24 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     };
     const size_t num_it = 100;
 
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
+
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
+
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
+
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
+    ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
+
+    abort();
+
     for (size_t i = 0; i < vals.size(); i += 3) {
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
         ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
@@ -7522,11 +8471,11 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
     }
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence);
+static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
 
 // Returns true if node has enqueued work into the queue, false otherwise
 // If submit is true, all operations queued so far are submitted to Vulkan to overlap cmdlist creation and GPU execution.
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){
+static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
     if (ggml_is_empty(node) || !node->buffer) {
         return false;
     }
@@ -7600,6 +8549,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
@@ -7663,6 +8613,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_LEAKY_RELU:
         {
             // These operations all go through ggml_vk_op_f32, so short-circuit and
@@ -7836,6 +8787,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     case GGML_OP_POOL_2D:
         ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
 
+        break;
+    case GGML_OP_CONV_2D_DW:
+        ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
+
         break;
     case GGML_OP_LEAKY_RELU:
         ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
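The new `GGML_OP_CONV_2D_DW` case above dispatches depthwise 2D convolution to a dedicated shader. For reference, a minimal CPU sketch of the operation — stride 1, no padding, with an illustrative dense layout; the real shader handles ggml's strides and op parameters. Unlike a full conv2d, each channel is convolved only with its own kernel:

```cpp
#include <vector>

// src: [c][h][w], ker: [c][kh][kw]; dst: [c][h-kh+1][w-kw+1]
static std::vector<float> conv_2d_dw_ref(const std::vector<float> & src, const std::vector<float> & ker,
                                         int c, int h, int w, int kh, int kw) {
    const int oh = h - kh + 1, ow = w - kw + 1;
    std::vector<float> dst((size_t) c * oh * ow, 0.0f);
    for (int ch = 0; ch < c; ch++) {            // one kernel per channel, no cross-channel mixing
        for (int y = 0; y < oh; y++) {
            for (int x = 0; x < ow; x++) {
                float acc = 0.0f;
                for (int ky = 0; ky < kh; ky++) {
                    for (int kx = 0; kx < kw; kx++) {
                        acc += src[((size_t) ch * h + y + ky) * w + x + kx] *
                               ker[((size_t) ch * kh + ky) * kw + kx];
                    }
                }
                dst[((size_t) ch * oh + y) * ow + x] = acc;
            }
        }
    }
    return dst;
}
```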
@@ -7898,7 +8853,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
     ctx->compute_ctx.reset();
 
-    bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false);
+    bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
     if (!ok) {
         if (node->op == GGML_OP_UNARY) {
             std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -7912,7 +8867,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
     return true;
 }
 
-static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){
+static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
     ggml_backend_buffer * buf = nullptr;
 
     switch (tensor->op) {
@@ -7957,6 +8912,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
     case GGML_OP_IM2COL:
     case GGML_OP_TIMESTEP_EMBEDDING:
     case GGML_OP_POOL_2D:
+    case GGML_OP_CONV_2D_DW:
     case GGML_OP_RWKV_WKV6:
     case GGML_OP_RWKV_WKV7:
     case GGML_OP_LEAKY_RELU:
@@ -8015,12 +8971,15 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
         memcpy(cpy.dst, cpy.src, cpy.n);
     }
 
-    ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
+    if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) {
+        ggml_vk_submit(subctx, ctx->almost_ready_fence);
+        ctx->almost_ready_fence_pending = true;
+    } else {
+        ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
+    }
 
     if (use_fence) {
-        VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
-
-        ctx->device->device.resetFences({ ctx->fence });
+        ggml_vk_wait_for_fence(ctx);
     }
 #ifdef GGML_VULKAN_CHECK_RESULTS
     ggml_vk_check_results_1(tensor);
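The branch above introduces a second fence: batches normally go out unfenced, but once a graph is nearly done, a single submission signals `almost_ready_fence` so the host can start waiting slightly before the final fenced submit instead of only after it. A minimal plain-C++ model of just the control flow — an assumption for illustration; the real code submits Vulkan command buffers and waits on `vk::Fence` objects:

```cpp
#include <cstdio>

struct Ctx {
    bool fence_signaled        = false;  // models ctx->fence
    bool almost_ready_signaled = false;  // models ctx->almost_ready_fence
    bool almost_ready_pending  = false;  // models ctx->almost_ready_fence_pending
};

static void submit(Ctx & ctx, bool use_fence, bool almost_ready) {
    if (almost_ready && !ctx.almost_ready_pending && !use_fence) {
        ctx.almost_ready_signaled = true;  // GPU signals this fence early
        ctx.almost_ready_pending  = true;  // at most one early fence per graph
    } else if (use_fence) {
        ctx.fence_signaled = true;         // classic path: fence on the final submit
    }
}

int main() {
    Ctx ctx;
    submit(ctx, /*use_fence=*/false, /*almost_ready=*/false); // early batch: no fence
    submit(ctx, false, true);  // near the end: signals almost_ready_fence once
    submit(ctx, false, true);  // pending flag suppresses a second early fence
    submit(ctx, true,  false); // last batch: regular fence, then wait + reset
    std::printf("almost_ready=%d fence=%d\n", ctx.almost_ready_signaled, ctx.fence_signaled);
}
```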
@@ -8106,6 +9065,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
     ctx->gc.events.clear();
 
     ctx->device->device.destroyFence(ctx->fence);
+    ctx->device->device.destroyFence(ctx->almost_ready_fence);
 }
 
 static int ggml_vk_get_device_count() {
@@ -8452,8 +9412,7 @@ static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
     }
 
     ggml_vk_submit(transfer_ctx, ctx->fence);
-    VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
-    ctx->device->device.resetFences({ ctx->fence });
+    ggml_vk_wait_for_fence(ctx);
 
     for (auto& cpy : transfer_ctx->out_memcpys) {
         memcpy(cpy.dst, cpy.src, cpy.n);
@@ -8472,7 +9431,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 
     uint64_t total_mat_mul_bytes = 0;
     for (int i = 0; i < cgraph->n_nodes; i++) {
-        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
+        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
@@ -8514,11 +9473,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
         }
 
+        // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
+        bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
         bool submit = (submitted_nodes >= nodes_per_submit) ||
                       (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
-                      (i == last_node);
+                      (i == last_node) ||
+                      (almost_ready && !ctx->almost_ready_fence_pending);
 
-        bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
+        bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
 
         if (enqueued) {
             ++submitted_nodes;
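A worked example of the "< 20% remaining" heuristic just added: with `n_nodes = 100`, `almost_ready` first becomes true at node index 81 (100 − 81 = 19 < 100/5 = 20), which is where the early fence submission can kick in. A trivial standalone check:

```cpp
#include <cstdio>

int main() {
    const int n_nodes = 100;
    for (int i = 0; i < n_nodes; i++) {
        const bool almost_ready = (n_nodes - i) < n_nodes / 5;
        if (almost_ready) {
            std::printf("first almost_ready node: %d\n", i);  // prints 81
            break;
        }
    }
}
```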
@@ -8530,7 +9492,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
 #endif
         }
 
-        if (submit) {
+        if (submit && enqueued) {
             first_node_in_batch = true;
             submitted_nodes = 0;
             mul_mat_bytes = 0;
@@ -8687,7 +9649,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             case GGML_UNARY_OP_RELU:
             case GGML_UNARY_OP_TANH:
             case GGML_UNARY_OP_SIGMOID:
-                return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+                return ggml_is_contiguous(op->src[0]) &&
+                       (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                       (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                       (op->src[0]->type == op->type);
             default:
                 return false;
         }
@@ -8705,6 +9670,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             switch (src0_type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_BF16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
@@ -8740,19 +9706,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             if (a->ne[3] != b->ne[3]) {
                 return false;
             }
-            if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
+            if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
                 !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
                 return false;
             }
+            if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
+                // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
+                // So don't support this combination for now.
+                return false;
+            }
 
             return true;
         } break;
         case GGML_OP_FLASH_ATTN_EXT:
         {
             ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-            if (!ggml_vk_get_device(ctx->device)->coopmat2) {
-                return false;
-            }
+            auto device = ggml_vk_get_device(ctx->device);
+            bool coopmat2 = device->coopmat2;
             switch (op->src[0]->ne[0]) {
                 case 64:
                 case 80:
@@ -8764,6 +9734,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 default:
                     return false;
             }
+            if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
+                // different head sizes of K and V are not supported yet
+                return false;
+            }
             if (op->src[0]->type != GGML_TYPE_F32) {
                 return false;
             }
@@ -8781,10 +9755,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             switch (op->src[1]->type) {
                 case GGML_TYPE_F16:
                 case GGML_TYPE_Q4_0:
+                case GGML_TYPE_Q8_0:
+                    // supported in scalar and coopmat2 paths
+                    break;
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
                 case GGML_TYPE_Q5_1:
-                case GGML_TYPE_Q8_0:
                 // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
                 //case GGML_TYPE_Q2_K:
                 //case GGML_TYPE_Q3_K:
@@ -8800,10 +9776,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 //case GGML_TYPE_IQ3_S:
                 //case GGML_TYPE_IQ4_XS:
                 case GGML_TYPE_IQ4_NL:
+                    // currently supported only in coopmat2 path
+                    if (!coopmat2) {
+                        return false;
+                    }
                     break;
                 default:
                     return false;
             }
+            if (!coopmat2 && !device->subgroup_shuffle) {
+                // scalar FA uses subgroupShuffle
+                return false;
+            }
             return true;
         }
         case GGML_OP_GET_ROWS:
@@ -8811,6 +9795,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             switch (op->src[0]->type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_BF16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
@@ -8841,6 +9826,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             switch (src1_type) {
                 case GGML_TYPE_F32:
                 case GGML_TYPE_F16:
+                case GGML_TYPE_BF16:
                 case GGML_TYPE_Q4_0:
                 case GGML_TYPE_Q4_1:
                 case GGML_TYPE_Q5_0:
@@ -8854,6 +9840,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
             if (src1_type == GGML_TYPE_F32) {
                 switch (src0_type) {
+                    case GGML_TYPE_F16:
                     case GGML_TYPE_Q4_0:
                     case GGML_TYPE_Q4_1:
                     case GGML_TYPE_Q5_0:
@@ -8882,16 +9869,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_VIEW:
         case GGML_OP_PERMUTE:
         case GGML_OP_TRANSPOSE:
+        case GGML_OP_RMS_NORM:
             return true;
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
-        case GGML_OP_RMS_NORM:
         case GGML_OP_L2_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                   (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
+                   (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
         case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SQR:
@@ -8899,9 +9889,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_COS:
         case GGML_OP_CLAMP:
             return op->src[0]->type == GGML_TYPE_F32;
+        case GGML_OP_UPSCALE:
+            return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
         case GGML_OP_ACC:
         case GGML_OP_CONCAT:
-        case GGML_OP_UPSCALE:
         case GGML_OP_SCALE:
         case GGML_OP_PAD:
         case GGML_OP_DIAG_MASK_INF:
@@ -8914,6 +9905,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_COUNT_EQUAL:
         case GGML_OP_IM2COL:
         case GGML_OP_TIMESTEP_EMBEDDING:
+        case GGML_OP_CONV_2D_DW:
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
@@ -9254,7 +10246,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     }
 
     if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
-        const float *params = (const float *)tensor->op_params;
+        const float * params = (const float *)tensor->op_params;
         tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
     } else if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
@@ -9269,9 +10261,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_CONCAT) {
         tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
     } else if (tensor->op == GGML_OP_UPSCALE) {
-        tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
+        tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]);
     } else if (tensor->op == GGML_OP_SCALE) {
-        tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
+        const float * params = (const float *)tensor->op_params;
+        tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
     } else if (tensor->op == GGML_OP_SQR) {
         tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_SIN) {
@@ -9279,7 +10272,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_COS) {
         tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
     } else if (tensor->op == GGML_OP_CLAMP) {
-        tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+        const float * params = (const float *)tensor->op_params;
+        tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
     } else if (tensor->op == GGML_OP_PAD) {
         tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
     } else if (tensor->op == GGML_OP_REPEAT) {
@@ -9293,7 +10287,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_NORM) {
         tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_GROUP_NORM) {
-        tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
+        const float * float_params = (const float *)tensor->op_params;
+        tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
     } else if (tensor->op == GGML_OP_RMS_NORM) {
         tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
@@ -9306,14 +10301,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
-            tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+            const float * params = (const float *)tensor->op_params;
+            tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
         }
     } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
         tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
-        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
+        tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
     } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
         const int n_dims = ((int32_t *) tensor->op_params)[1];
         const int mode = ((int32_t *) tensor->op_params)[2];
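A common thread in these check_results hunks is how typed values are read out of `tensor->op_params`: it is int32 storage, with float parameters stored in it bit-for-bit. The diff uses the in-tree `(const float *)` cast idiom; a strict-aliasing-safe sketch of the same idea is below — the helper names are illustrative, not ggml API:

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

int32_t op_params[4];  // stands in for tensor->op_params

static void  set_param_f32(int idx, float v) { std::memcpy(&op_params[idx], &v, sizeof v); }
static float get_param_f32(int idx)          { float v; std::memcpy(&v, &op_params[idx], sizeof v); return v; }

int main() {
    op_params[0] = 7;          // int parameter, e.g. diag_mask_inf's n_past
    set_param_f32(1, 0.125f);  // float parameter, e.g. a scale or eps
    std::printf("int=%d float=%g\n", op_params[0], get_param_f32(1));
}
```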