@fugood/llama.node 0.3.2 → 0.3.4

This diff shows the changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (286)
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -1,7 +1,8 @@
  #include "ggml-vulkan.h"
  #include <vulkan/vulkan_core.h>
- #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)
+ #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS)
  #include <chrono>
+ #include "ggml-cpu.h"
  #endif

  #include <vulkan/vulkan.hpp>
@@ -43,12 +44,6 @@

  #define MAX_VK_BUFFERS 256

- #ifndef K_QUANTS_PER_ITERATION
- #define K_QUANTS_PER_ITERATION 1
- #else
- static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
- #endif
-
  #define VK_CHECK(err, msg) \
  do { \
  vk::Result err_ = (err); \
@@ -106,6 +101,15 @@ struct vk_matmul_pipeline_struct {

  typedef std::shared_ptr<vk_matmul_pipeline_struct> vk_matmul_pipeline;

+ struct vk_matmul_pipeline2 {
+ vk_matmul_pipeline2() {
+ f16acc = std::make_shared<vk_matmul_pipeline_struct>();
+ f32acc = std::make_shared<vk_matmul_pipeline_struct>();
+ }
+ vk_matmul_pipeline f32acc;
+ vk_matmul_pipeline f16acc;
+ };
+
  struct vk_device_struct;
  typedef std::shared_ptr<vk_device_struct> vk_device;
  typedef std::weak_ptr<vk_device_struct> vk_device_ref;
@@ -149,29 +153,53 @@ struct vk_device_struct {
  std::string name;
  uint64_t max_memory_allocation_size;
  bool fp16;
+ bool pipeline_robustness;
  vk::Device device;
  uint32_t vendor_id;
  vk_queue compute_queue;
  vk_queue transfer_queue;
  bool single_queue;
  uint32_t subgroup_size;
+ uint32_t shader_core_count;
  bool uma;
+ bool float_controls_rte_fp16;
+
+ bool subgroup_size_control;
+ uint32_t subgroup_min_size;
+ uint32_t subgroup_max_size;
+ bool subgroup_require_full_support;
+
+ bool coopmat_support;
+ bool coopmat_acc_f32_support;
+ bool coopmat_acc_f16_support;
+ uint32_t coopmat_m;
+ uint32_t coopmat_n;
+ uint32_t coopmat_k;
+ bool coopmat2;

  size_t idx;

+ bool mul_mat_l;
+ bool mul_mat_m;
+ bool mul_mat_s;
+ bool mul_mat_id_l;
+ bool mul_mat_id_m;
+ bool mul_mat_id_s;
+
  vk_matmul_pipeline pipeline_matmul_f32;
  vk_matmul_pipeline pipeline_matmul_f32_f16;
- vk_matmul_pipeline pipeline_matmul_f16;
- vk_matmul_pipeline pipeline_matmul_f16_f32;
+ vk_matmul_pipeline2 pipeline_matmul_f16;
+ vk_matmul_pipeline2 pipeline_matmul_f16_f32;
  vk_pipeline pipeline_matmul_split_k_reduce;

- vk_matmul_pipeline pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];

  vk_matmul_pipeline pipeline_matmul_id_f32;
- vk_matmul_pipeline pipeline_matmul_id_f16;
- vk_matmul_pipeline pipeline_matmul_id_f16_f32;
+ vk_matmul_pipeline2 pipeline_matmul_id_f16;
+ vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;

- vk_matmul_pipeline pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];

  vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
  vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
@@ -183,9 +211,10 @@ struct vk_device_struct {
  vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
  vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
  vk_pipeline pipeline_acc_f32;
- vk_pipeline pipeline_add_f32, pipeline_add_f16_f32_f16;
- vk_pipeline pipeline_mul_f32;
- vk_pipeline pipeline_div_f32;
+ vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
+ vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
+ vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
+ vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
  vk_pipeline pipeline_upscale_f32;
  vk_pipeline pipeline_scale_f32;
@@ -196,6 +225,7 @@
  vk_pipeline pipeline_pad_f32;
  vk_pipeline pipeline_repeat_f32;
  vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
+ vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
  vk_pipeline pipeline_norm_f32;
  vk_pipeline pipeline_group_norm_f32;
  vk_pipeline pipeline_rms_norm_f32;
@@ -207,12 +237,23 @@
  vk_pipeline pipeline_tanh_f32;
  vk_pipeline pipeline_diag_mask_inf_f32;
  vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
+ vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
  vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
  vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
  vk_pipeline pipeline_argsort_f32;
  vk_pipeline pipeline_sum_rows_f32;
  vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
  vk_pipeline pipeline_timestep_embedding_f32;
+ vk_pipeline pipeline_pool2d_f32;
+ vk_pipeline pipeline_rwkv_wkv6_f32;
+
+ // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
+ vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];

  std::unordered_map<std::string, vk_pipeline_ref> pipelines;
  std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@@ -325,6 +366,40 @@ struct vk_mat_vec_id_push_constants {
  uint32_t nei0; uint32_t ne11;
  };

+ struct vk_flash_attn_push_constants {
+ uint32_t N;
+ uint32_t KV;
+
+ uint32_t ne1;
+ uint32_t ne2;
+ uint32_t ne3;
+
+ uint32_t neq2;
+ uint32_t neq3;
+ uint32_t nek2;
+ uint32_t nek3;
+ uint32_t nev2;
+ uint32_t nev3;
+ uint32_t nem1;
+
+ uint32_t nb02;
+ uint32_t nb03;
+ uint32_t nb12;
+ uint32_t nb13;
+ uint32_t nb22;
+ uint32_t nb23;
+ uint32_t nb31;
+
+ float scale;
+ float max_bias;
+ float logit_softcap;
+
+ uint32_t mask;
+ uint32_t n_head_log2;
+ float m0;
+ float m1;
+ };
+
  struct vk_op_push_constants {
  uint32_t KX;
  uint32_t KY;
@@ -338,7 +413,46 @@ struct vk_op_unary_push_constants {
  uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
  uint32_t d_offset;
  float param1; float param2;
+ uint32_t ne0_012mp; uint32_t ne0_012L;
+ uint32_t ne0_01mp; uint32_t ne0_01L;
+ uint32_t ne0_0mp; uint32_t ne0_0L;
+ uint32_t ne1_012mp; uint32_t ne1_012L;
+ uint32_t ne1_01mp; uint32_t ne1_01L;
+ uint32_t ne1_0mp; uint32_t ne1_0L;
  };
+ static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
+
+ // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
+ // Precompute mp (m' in the paper) and L such that division
+ // can be computed using a multiply (high 32b of 64b result)
+ // and a shift:
+ //
+ // n/d = (mulhi(n, mp) + n) >> L;
+ static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
+ {
+ // compute L = ceil(log2(d));
+ L = 0;
+ while (L < 32 && (uint32_t{1} << L) < d) {
+ L++;
+ }
+
+ mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
+ }
+
+ template <typename T> void init_pushconst_fastdiv(T &p) {
+ GGML_UNUSED(p);
+ static_assert(!std::is_const<T>::value, "unexpected type");
+ }
+
+ template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {
+ // Compute magic values to divide by these six numbers.
+ init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L);
+ init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L);
+ init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L);
+ init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L);
+ init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L);
+ init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L);
+ }

  struct vk_op_binary_push_constants {
  uint32_t ne;
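The fastdiv comment added in the hunk above replaces division by a runtime-constant d with a multiply by a precomputed magic value and a shift. As a point of reference only, the identity can be exercised outside the shader path; the sketch below is illustrative and not part of the package diff: it copies init_fastdiv_values from the hunk and asserts that (mulhi(n, mp) + n) >> L equals n / d for a handful of divisors, doing the add in 64 bits so the check covers the full 32-bit range of n.

    // Illustrative, standalone check of the fastdiv identity from the diff above.
    // init_fastdiv_values is copied from the hunk; fastdiv() and the test values
    // are assumptions made for this sketch, not identifiers from the package.
    #include <cassert>
    #include <cstdint>

    static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
    {
        // compute L = ceil(log2(d));
        L = 0;
        while (L < 32 && (uint32_t{1} << L) < d) {
            L++;
        }

        mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
    }

    // n / d computed as one 32x32->64 multiply, an add, and a shift.
    // The add is done in 64 bits here so it cannot wrap for any 32-bit n.
    static uint32_t fastdiv(uint32_t n, uint32_t mp, uint32_t L)
    {
        const uint32_t hi = (uint32_t)(((uint64_t)n * mp) >> 32); // mulhi(n, mp)
        return (uint32_t)(((uint64_t)hi + n) >> L);
    }

    int main()
    {
        const uint32_t divisors[] = {1, 3, 7, 320, 4096};
        for (uint32_t d : divisors) {
            uint32_t mp, L;
            init_fastdiv_values(d, mp, L);
            const uint32_t samples[] = {0, 1, d - 1, d, 123456789};
            for (uint32_t n : samples) {
                assert(fastdiv(n, mp, L) == n / d); // matches plain integer division
            }
        }
        return 0;
    }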
@@ -376,6 +490,7 @@ struct vk_op_soft_max_push_constants {
  float m0;
  float m1;
  uint32_t n_head_log2;
+ uint32_t nrows_x;
  };

  struct vk_op_argsort_push_constants {
@@ -403,6 +518,24 @@ struct vk_op_timestep_embedding_push_constants {
  uint32_t max_period;
  };

+ struct vk_op_pool2d_push_constants {
+ uint32_t IW; uint32_t IH;
+ uint32_t OW; uint32_t OH;
+ uint32_t OC;
+ uint32_t pelements;
+ uint32_t op;
+ int32_t k0; int32_t k1;
+ int32_t s0; int32_t s1;
+ int32_t p0; int32_t p1;
+ };
+
+ struct vk_op_rwkv_wkv6_push_constants {
+ uint32_t B;
+ uint32_t T;
+ uint32_t C;
+ uint32_t H;
+ };
+
  // Allow pre-recording command buffers
  struct vk_staging_memcpy {
  vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -629,8 +762,12 @@ static uint32_t compile_count = 0;
  static std::mutex compile_count_mutex;
  static std::condition_variable compile_count_cond;

- static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
- VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
+ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint,
+ uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
+ uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
+ VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size <<
+ ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align <<
+ ", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
  GGML_ASSERT(parameter_count > 0);
  GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT

@@ -689,16 +826,39 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipelin
  specialization_constants.data()
  );

+ vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
+
+ if (device->subgroup_require_full_support && require_full_subgroups) {
+ pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
+ }
+
  vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
- vk::PipelineShaderStageCreateFlags(),
+ pipeline_shader_stage_create_flags,
  vk::ShaderStageFlagBits::eCompute,
  pipeline->shader_module,
  entrypoint.c_str(),
  &specialization_info);
+
+ vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
+ pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
+ if (device->subgroup_size_control && required_subgroup_size > 0) {
+ GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);
+ pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);
+ }
+
  vk::ComputePipelineCreateInfo compute_pipeline_create_info(
- vk::PipelineCreateFlags(),
+ vk::PipelineCreateFlags{},
  pipeline_shader_create_info,
  pipeline->layout);
+
+ vk::PipelineRobustnessCreateInfoEXT rci;
+
+ if (device->pipeline_robustness && disable_robustness) {
+ rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
+ rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
+ compute_pipeline_create_info.setPNext(&rci);
+ }
+
  pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;

  {
@@ -710,6 +870,12 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipelin
  std::lock_guard<std::mutex> guard(compile_count_mutex);
  assert(compile_count > 0);
  compile_count--;
+
+ // "Progress bar" for shader compiles
+ static uint32_t total_compile_count = 0;
+ if ((total_compile_count++ % 10) == 0) {
+ std::cerr << ".";
+ }
  }
  compile_count_cond.notify_all();
  }
@@ -1035,7 +1201,6 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
  return buf;
  }

- buf->size = size;
  vk::BufferCreateInfo buffer_create_info{
  vk::BufferCreateFlags(),
  size,
@@ -1063,7 +1228,6 @@

  if (memory_type_index == UINT32_MAX) {
  device->device.destroyBuffer(buf->buffer);
- buf->size = 0;
  throw vk::OutOfDeviceMemoryError("No suitable memory type found");
  }

@@ -1080,13 +1244,11 @@
  }
  catch (const vk::SystemError& e) {
  device->device.destroyBuffer(buf->buffer);
- buf->size = 0;
  throw e;
  }
  } else {
  // Out of Host/Device memory, clean up buffer
  device->device.destroyBuffer(buf->buffer);
- buf->size = 0;
  throw e;
  }
  }
@@ -1099,6 +1261,7 @@
  device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);

  buf->device = device;
+ buf->size = size;

  #ifdef GGML_VULKAN_MEMORY_DEBUG
  device->memory_logger->log_allocation(buf, size);
@@ -1188,59 +1351,186 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
  );
  }

+ // number of rows/cols for flash attention shader
+ static constexpr uint32_t flash_attention_num_small_rows = 32;
+ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
+ GGML_UNUSED(clamp);
+
+ // small rows, large cols
+ if (small_rows) {
+ return {flash_attention_num_small_rows, 128};
+ }
+ // small cols to reduce register count
+ if (ggml_is_quantized(type) || D == 256) {
+ return {64, 32};
+ }
+ return {64, 64};
+ };
+
+ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
+ // Needs to be kept up to date on shader changes
+ const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
+ const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
+ const uint32_t warps = warptile[0] / warptile[10];
+
+ const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
+ const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
+ const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
+
+ return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
+ }
+
  static void ggml_vk_load_shaders(vk_device& device) {
  VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");

- // mulmat
- std::initializer_list<uint32_t> warptile_l = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
- std::initializer_list<uint32_t> warptile_m = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
- std::initializer_list<uint32_t> warptile_s = { std::max(device->subgroup_size, 16u), 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
-
- std::initializer_list<uint32_t> warptile_mmq_l = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
- std::initializer_list<uint32_t> warptile_mmq_m = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
- std::initializer_list<uint32_t> warptile_mmq_s = { std::max(device->subgroup_size, 16u), 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
+ std::cerr << "ggml_vulkan: Compiling shaders";

- std::array<uint32_t, 3> l_wg_denoms = {128, 128, 1 };
- std::array<uint32_t, 3> m_wg_denoms = { 64, 64, 1 };
- std::array<uint32_t, 3> s_wg_denoms = { 32, 32, 1 };
+ // some shaders have a minimum subgroup size
+ const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
+ const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);

- uint32_t l_align = 128;
- uint32_t m_align = 64;
- uint32_t s_align = 32;
+ // mulmat
+ std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
+ l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
+ l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
+ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
+ std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
+ l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
+ l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
+ l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
+
+ uint32_t l_align, m_align, s_align;
+ if (device->coopmat2) {
+ // spec constants and tile sizes for non-quant matmul/matmul_id
+ l_warptile = { 256, 128, 256, 64 };
+ m_warptile = { 256, 128, 128, 64 };
+ s_warptile = { 128, 32, 16, 64 };
+ l_wg_denoms = {128, 256, 1 };
+ m_wg_denoms = {128, 128, 1 };
+ s_wg_denoms = { 32, 16, 1 };
+
+ // spec constants and tile sizes for quant matmul (non-Qi_K)
+ l_warptile_mmq = { 256, 128, 256, 64 };
+ m_warptile_mmq = { 256, 128, 128, 64 };
+ s_warptile_mmq = { 256, 128, 128, 64 };
+ l_mmq_wg_denoms = { 128, 256, 1 };
+ m_mmq_wg_denoms = { 128, 128, 1 };
+ s_mmq_wg_denoms = { 128, 128, 1 };
+
+ // spec constants and tile sizes for quant matmul (Qi_K)
+ l_warptile_mmq_k = { 256, 128, 512, 16 };
+ m_warptile_mmq_k = { 256, 128, 256, 16 };
+ s_warptile_mmq_k = { 256, 32, 128, 64 };
+ l_mmq_wg_denoms_k = { 128, 512, 1 };
+ m_mmq_wg_denoms_k = { 128, 256, 1 };
+ s_mmq_wg_denoms_k = { 32, 128, 1 };
+
+ // spec constants and tile sizes for quant matmul_id
+ l_warptile_mmqid = { 256, 128, 128, 16 };
+ m_warptile_mmqid = { 256, 128, 64, 16 };
+ s_warptile_mmqid = { 256, 64, 64, 16 };
+ l_mmqid_wg_denoms = { 128, 128, 1 };
+ m_mmqid_wg_denoms = { 128, 64, 1 };
+ s_mmqid_wg_denoms = { 64, 64, 1 };
+
+ l_align = 128;
+ m_align = 64;
+ s_align = 32;
+ } else {
+ // Matrix cores require different warp group sizes
+ const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
+ const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
+ const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
+ const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
+ const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
+ const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
+ const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
+ const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
+ const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
+
+ l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
+ m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
+
+ l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
+ m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
+ s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
+
+ l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
+ m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
+ s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
+ l_align = 128;
+ m_align = 64;
+ s_align = 32;
+
+ // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
+ // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
+ // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
+ // But the numbers happen to work out for 32KB shared memory size that when using the medium
+ // size there's enough room for everything, and we assert for this.
+ uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
+ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
+ l_warptile = m_warptile;
+ l_wg_denoms = m_wg_denoms;
+ shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
+ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
+ }
+ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
+ // assert mul_mat_mat_id shaders will fit.
+ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
+ }
+
+ shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
+ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
+ if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
+ l_warptile_mmq = m_warptile_mmq;
+ l_mmq_wg_denoms = m_mmq_wg_denoms;
+ } else {
+ l_warptile_mmq = s_warptile_mmq;
+ l_mmq_wg_denoms = s_mmq_wg_denoms;
+ }
+ shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
+ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
+ }
+ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
+ // assert mul_mat_mat_id shaders will fit.
+ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
+ }
+ // Disable medium and large matrix multiplication if not enough shared memory is available
+ // Check mmq warptiles as the largest configuration
+ // Throw an error if not enough for any matrix multiplication is available
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
+ std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
+ throw std::runtime_error("Shared memory size too small for matrix multiplication.");
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
+ device->mul_mat_m = false;
+ device->mul_mat_l = false;
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
+ device->mul_mat_l = false;
+ }
+
+ // Disable mul_mat_id if not enough shared memory is available
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
+ device->mul_mat_id_s = false;
+ device->mul_mat_id_m = false;
+ device->mul_mat_id_l = false;
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
+ device->mul_mat_id_m = false;
+ device->mul_mat_id_l = false;
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
+ device->mul_mat_id_l = false;
+ }
+ }

  device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
  device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_matmul_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_matmul_f16 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();

  device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();

  std::vector<std::future<void>> compiles;
- auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t>&& specialization_constants, uint32_t align) {
+ auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
+ uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
+ uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
  {
  // wait until fewer than N compiles are in progress
  uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@@ -1250,459 +1540,368 @@ static void ggml_vk_load_shaders(vk_device& device) {
  }
  compile_count++;
  }
- compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
+ compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint,
+ parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size));
  };

- if (device->fp16) {
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_len, matmul_f32_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_len, matmul_f32_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_len, matmul_f32_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_len, matmul_f32_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_len, matmul_f16_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_len, matmul_f16_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_len, matmul_f16_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_len, matmul_f16_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_len, matmul_q4_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_len, matmul_q4_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_len, matmul_q4_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_len, matmul_q4_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_len, matmul_q5_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_len, matmul_q5_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
-
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_len, matmul_q5_1_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1309
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1310
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1311
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_len, matmul_q5_1_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1312
-
1313
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1314
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1315
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_len, matmul_q8_0_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1316
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1317
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1318
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_len, matmul_q8_0_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1319
-
1320
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1321
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1322
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_len, matmul_q2_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1323
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1324
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1325
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_len, matmul_q2_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1326
-
1327
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1328
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1329
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_len, matmul_q3_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1330
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1331
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1332
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_len, matmul_q3_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1333
-
1334
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1335
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1336
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_len, matmul_q4_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1337
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1338
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1339
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_len, matmul_q4_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1340
-
1341
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1342
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1343
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_len, matmul_q5_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1344
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1345
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1346
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_len, matmul_q5_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1347
-
1348
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1349
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1350
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_len, matmul_q6_k_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1351
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1352
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1353
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_len, matmul_q6_k_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1354
-
1355
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->l, "matmul_iq4_nl_f32_l", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1356
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->m, "matmul_iq4_nl_f32_m", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1357
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->s, "matmul_iq4_nl_f32_s", matmul_iq4_nl_f32_len, matmul_iq4_nl_f32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1358
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_l, "matmul_iq4_nl_f32_aligned_l", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1359
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_m, "matmul_iq4_nl_f32_aligned_m", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1360
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_s, "matmul_iq4_nl_f32_aligned_s", matmul_iq4_nl_f32_aligned_len, matmul_iq4_nl_f32_aligned_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1361
-
1362
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
1363
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
1364
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_len, matmul_id_f32_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
1365
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
1366
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
1367
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_len, matmul_id_f32_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
1368
-
1369
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
1370
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
1371
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_len, matmul_id_f16_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
1372
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
1373
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
1374
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_len, matmul_id_f16_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
1375
-
1376
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
1377
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
1378
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_len, matmul_id_f16_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
1379
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
1380
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
1381
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_len, matmul_id_f16_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
1382
-
1383
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1384
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1385
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_len, matmul_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1386
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1387
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1388
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_len, matmul_id_q4_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1389
-
1390
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1391
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1392
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_len, matmul_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1393
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1394
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1395
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_len, matmul_id_q4_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1396
-
1397
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1398
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1399
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_len, matmul_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1400
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1401
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1402
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_len, matmul_id_q5_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1403
-
1404
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1405
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1406
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_len, matmul_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1407
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1408
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1409
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_len, matmul_id_q5_1_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1410
-
1411
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1412
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1413
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_len, matmul_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1414
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1415
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1416
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_len, matmul_id_q8_0_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1417
-
1418
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1419
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1420
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_len, matmul_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1421
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1422
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1423
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_len, matmul_id_q2_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1424
-
1425
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1426
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1427
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_len, matmul_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1428
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1429
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1430
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_len, matmul_id_q3_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1431
-
1432
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1433
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1434
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_len, matmul_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1435
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1436
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1437
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_len, matmul_id_q4_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1438
-
1439
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1440
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1441
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_len, matmul_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1442
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1443
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1444
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_len, matmul_id_q5_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1445
-
1446
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1447
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1448
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_len, matmul_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1449
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1450
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1451
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_len, matmul_id_q6_k_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1452
-
1453
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->l, "matmul_id_iq4_nl_f32_l", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1454
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->m, "matmul_id_iq4_nl_f32_m", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1455
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->s, "matmul_id_iq4_nl_f32_s", matmul_id_iq4_nl_f32_len, matmul_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1456
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_l, "matmul_id_iq4_nl_f32_aligned_l", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1457
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_m, "matmul_id_iq4_nl_f32_aligned_m", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1458
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_s, "matmul_id_iq4_nl_f32_aligned_s", matmul_id_iq4_nl_f32_aligned_len, matmul_id_iq4_nl_f32_aligned_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1547
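The block removed above spelled out every matmul pipeline by hand: for each source format (f32, f16, f16/f32 mixed, each quantized type, plus the `_id` expert-routing variants) there were six `ggml_vk_create_pipeline` calls covering the {l, m, s} tile sizes in unaligned and aligned flavours. The added code below collapses that repetition into token-pasting macros (`CREATE_FA`, `CREATE_MM`, `CREATE_MM2`) so the coopmat2, coopmat and fp16 device paths each list only the pipelines they support. As a rough, self-contained sketch of the technique, not the ggml-vulkan API, the following shows how one macro invocation can expand into the six {l,m,s} × {unaligned,aligned} registrations (`register_pipeline`, `shader_blob` and the tile arguments are made up for illustration):

```cpp
// Minimal sketch of macro-generated pipeline variants (hypothetical names, not
// the ggml-vulkan API). One CREATE_MM(...) expands to the six
// {l,m,s} x {unaligned,aligned} registrations the hand-written list used to spell out.
#include <cstdio>
#include <string>
#include <vector>

struct shader_blob { const char *name; };        // stand-in for the SPIR-V data/len pairs
static std::vector<std::string> g_registered;    // stand-in for the device->pipeline_* slots

static void register_pipeline(const std::string &name, const shader_blob &blob,
                              int tile, bool aligned) {
    g_registered.push_back(name);
    std::printf("create %-28s blob=%s tile=%d aligned=%d\n",
                name.c_str(), blob.name, tile, aligned);
}

// Token pasting builds both the shader symbol and the human-readable name,
// mirroring how NAMELC ## _aligned ## _len and #NAMELC "_aligned_l" are formed above.
#define CREATE_MM(NAMELC, L_TILE, M_TILE, S_TILE)                                    \
    register_pipeline(#NAMELC "_l",         NAMELC ## _blob,         L_TILE, false); \
    register_pipeline(#NAMELC "_m",         NAMELC ## _blob,         M_TILE, false); \
    register_pipeline(#NAMELC "_s",         NAMELC ## _blob,         S_TILE, false); \
    register_pipeline(#NAMELC "_aligned_l", NAMELC ## _aligned_blob, L_TILE, true);  \
    register_pipeline(#NAMELC "_aligned_m", NAMELC ## _aligned_blob, M_TILE, true);  \
    register_pipeline(#NAMELC "_aligned_s", NAMELC ## _aligned_blob, S_TILE, true);

static const shader_blob matmul_q4_0_f32_blob         = {"matmul_q4_0_f32"};
static const shader_blob matmul_q4_0_f32_aligned_blob = {"matmul_q4_0_f32_aligned"};

int main() {
    CREATE_MM(matmul_q4_0_f32, 128, 64, 32)  // one line instead of six explicit calls
    return g_registered.size() == 6 ? 0 : 1;
}
```

The real macros additionally thread through the push-constant struct, descriptor parameter count, workgroup denominators and warptile configuration, but the expansion pattern is the same.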
+ #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
1548
+ if (device->coopmat2) {
1549
+
1550
+ auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
1551
+ return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
1552
+ };
1553
+
1554
+ auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
1555
+ // For large number of rows, 128 invocations seems to work best.
1556
+ // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
1557
+ // can't use 256 for D==80.
1558
+ uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
1559
+ auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
1560
+ return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
1561
+ };
1562
+
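The two lambdas above compute, per head size D, the dispatch denominators and the specialization constants for the flash-attention pipelines; the constants pack {workgroup size, rows, cols, D, clamp}, with the workgroup size chosen as the comment describes: 256 invocations for the small-row case, unless D is not a multiple of 32 (e.g. D==80), otherwise 128. A minimal sketch of that selection, with `fa_rows_cols` stubbed out since its definition is not part of this hunk:

```cpp
// Sketch of the flash-attention spec-constant selection described above.
// fa_rows_cols() is stubbed with made-up tile shapes; only the workgroup-size
// rule (128 vs 256, and the D % 32 restriction) follows the comment in the diff.
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

static std::array<uint32_t, 2> fa_rows_cols_stub(uint32_t /*D*/, uint32_t /*clamp*/,
                                                 bool small_rows) {
    return small_rows ? std::array<uint32_t, 2>{32, 32}    // placeholder tile for small N
                      : std::array<uint32_t, 2>{64, 64};   // placeholder tile otherwise
}

static std::vector<uint32_t> fa_spec_constants_sketch(uint32_t D, uint32_t clamp,
                                                      bool small_rows) {
    // 256 invocations help when there are few rows, but the 256-wide path has a
    // matrix granularity of 32, so it is only usable when D is a multiple of 32.
    const uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
    const auto rc = fa_rows_cols_stub(D, clamp, small_rows);
    return {wg_size, rc[0], rc[1], D, clamp};
}

int main() {
    assert(fa_spec_constants_sketch(128, 1, /*small_rows=*/true)[0]  == 256);
    assert(fa_spec_constants_sketch(80,  1, /*small_rows=*/true)[0]  == 128); // D==80 is not a multiple of 32
    assert(fa_spec_constants_sketch(128, 1, /*small_rows=*/false)[0] == 128);
    return 0;
}
```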
1563
+ #define CREATE_FA2(TYPE, NAMELC, D) \
1564
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
1565
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
1566
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
1567
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
1568
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
1569
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
1570
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
1571
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
1572
+
1573
+ #define CREATE_FA(TYPE, NAMELC) \
1574
+ CREATE_FA2(TYPE, NAMELC, 64) \
1575
+ CREATE_FA2(TYPE, NAMELC, 80) \
1576
+ CREATE_FA2(TYPE, NAMELC, 96) \
1577
+ CREATE_FA2(TYPE, NAMELC, 112) \
1578
+ CREATE_FA2(TYPE, NAMELC, 128) \
1579
+ CREATE_FA2(TYPE, NAMELC, 256)
1580
+
1581
+ CREATE_FA(GGML_TYPE_F16, f16)
1582
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0)
1583
+ CREATE_FA(GGML_TYPE_Q4_1, q4_1)
1584
+ CREATE_FA(GGML_TYPE_Q5_0, q5_0)
1585
+ CREATE_FA(GGML_TYPE_Q5_1, q5_1)
1586
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0)
1587
+ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
1588
+ //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
1589
+ //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
1590
+ //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
1591
+ //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
1592
+ //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
1593
+ CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
1594
+ #undef CREATE_FA
1595
+
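Each CREATE_FA invocation above expands, via CREATE_FA2, into eight pipelines per head size (64, 80, 96, 112, 128, 256), indexed as `pipeline_flash_attn_f32_f16_D##D[type][accumulator][small_rows][aligned]`; the K-dequant types stay commented out because their D dimension is rounded up to 256 and runs inefficiently. A hedged sketch of that four-way variant table (the enum, array and naming below are illustrative only, not backend symbols):

```cpp
// Sketch of the variant table CREATE_FA/CREATE_FA2 fill in: for each head size D
// there is one pipeline per {K/V type, accumulator, small-rows, aligned} combination.
#include <cstdio>
#include <string>

enum kv_type { F16, Q4_0, Q8_0, IQ4_NL, KV_TYPE_COUNT };   // subset of the real type list

// [type][accumulator][small_rows][aligned] -> pipeline name, mirroring
// pipeline_flash_attn_f32_f16_D##D[TYPE][0/1][0/1][0/1] in the hunk above.
static std::string fa_table_64[KV_TYPE_COUNT][2][2][2];

static void fill_fa_table_64(kv_type t, const char *namelc) {
    static const char *acc[2]   = {"f16acc", "f32acc"};
    static const char *rows[2]  = {"", "_smallrows"};
    static const char *align[2] = {"", "_aligned"};
    for (int a = 0; a < 2; ++a)
        for (int r = 0; r < 2; ++r)
            for (int al = 0; al < 2; ++al)
                fa_table_64[t][a][r][al] = std::string("flash_attn_f32_f16_D64") + align[al] +
                                           "_" + acc[a] + rows[r] + "_" + namelc;
}

int main() {
    fill_fa_table_64(Q4_0, "q4_0");
    std::printf("%s\n", fa_table_64[Q4_0][1][0][1].c_str());  // aligned, f32 accumulator
    return 0;
}
```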
1596
+ // Create 6 variants, {s,m,l}x{unaligned,aligned}
1597
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1598
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1599
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1600
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1601
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1602
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1603
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1604
+
1605
+ // Create 2 variants, {f16,f32} accumulator
1606
+ #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1607
+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1608
+ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1609
+
1610
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1611
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1612
+
1613
+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1614
+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1615
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1616
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1617
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1618
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1619
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1620
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1621
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1622
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1623
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1624
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1625
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1626
+
1627
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1628
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1629
+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1630
+
1631
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1632
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1633
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1634
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1635
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1636
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1637
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1638
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1639
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1640
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1641
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1642
+ #undef CREATE_MM
1643
+ #undef CREATE_MM2
1644
+ } else
1645
+ #endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
1646
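Within the coopmat2 branch that just closed, CREATE_MM2 stacks two CREATE_MM expansions, one per accumulator precision, so each matmul family ends up with both an `.f16acc` and an `.f32acc` set of the six tile/alignment variants; the coopmat branch that follows reuses the same pattern but gates each expansion on `device->coopmat_acc_f16_support` / `coopmat_acc_f32_support` and on the per-size `mul_mat##ID##_{l,m,s}` flags. A small illustrative sketch of the dual-accumulator layout (struct and helper names are invented, not backend symbols):

```cpp
// Illustrative sketch of the {f16,f32}-accumulator pipeline pairing that
// CREATE_MM2 produces. Names (pipeline_pair, make_variants, ...) are made up.
#include <cstdio>
#include <string>

struct pipeline_set  { std::string l, m, s, a_l, a_m, a_s; };   // 6 tile/alignment variants
struct pipeline_pair { pipeline_set f16acc, f32acc; };          // 2 accumulator precisions

static pipeline_set make_variants(const std::string &base) {
    return { base + "_l", base + "_m", base + "_s",
             base + "_aligned_l", base + "_aligned_m", base + "_aligned_s" };
}

int main() {
    // CREATE_MM2(pipeline_matmul_f16, matmul_f16, ...) conceptually fills both members:
    pipeline_pair matmul_f16;
    matmul_f16.f16acc = make_variants("matmul_f16_f16acc");  // shaders compiled with fp16 accumulation
    matmul_f16.f32acc = make_variants("matmul_f16");         // shaders with fp32 accumulation
    std::printf("%s / %s\n", matmul_f16.f16acc.l.c_str(), matmul_f16.f32acc.a_s.c_str());
    return 0;
}
```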
+ if (device->coopmat_support) {
1647
+ // Create 6 variants, {s,m,l}x{unaligned,aligned}
1648
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1649
+ if (device->mul_mat ## ID ## _l) \
1650
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
1651
+ if (device->mul_mat ## ID ## _m) \
1652
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
1653
+ if (device->mul_mat ## ID ## _s) \
1654
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
1655
+ if (device->mul_mat ## ID ## _l) \
1656
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
1657
+ if (device->mul_mat ## ID ## _m) \
1658
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
1659
+ if (device->mul_mat ## ID ## _s) \
1660
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
1661
+
1662
+ // Create 2 variants, {f16,f32} accumulator
1663
+ #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1664
+ if (device->coopmat_acc_f16_support) { \
1665
+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1666
+ } \
1667
+ if (device->coopmat_acc_f32_support) { \
1668
+ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1669
+ } \
1670
+
1671
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1672
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1673
+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1674
+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1675
+
1676
+ if (device->coopmat_acc_f16_support) {
1677
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1678
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1679
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1680
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1681
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1682
+
1683
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1684
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1685
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1686
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1687
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1688
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1689
+ } else {
1690
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1691
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1692
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1693
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1694
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1695
+
1696
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1697
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1698
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1699
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1700
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1701
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1702
+ }
1703
+
1704
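When f16 accumulation is not supported in the coopmat path, the else branch above still populates the `.f16acc` slots, but from the f32-accumulator shader variants (the F16ACC macro argument is left empty), so later dispatch code can index `.f16acc` uniformly; the same pattern then repeats for the mul_mat_id pipelines below, which are only created if at least one of the `mul_mat_id_{s,m,l}` sizes has enough shared memory for the row_ids and the result tile. A minimal sketch of that capability-based fallback (flag and blob names are placeholders):

```cpp
// Sketch of the accumulator-precision fallback: if fp16 accumulation is not
// supported, the "f16acc" slot is simply created from the fp32-accumulator
// shader. All identifiers here are placeholders, not ggml-vulkan symbols.
#include <cstdio>

struct shader_blob { const char *name; };

static const shader_blob matmul_q4_0_f32_f16acc_blob = {"matmul_q4_0_f32_f16acc_coopmat"};
static const shader_blob matmul_q4_0_f32_blob        = {"matmul_q4_0_f32_coopmat"};

static void create_pipeline(const char *slot, const shader_blob &blob) {
    std::printf("slot %-12s <- %s\n", slot, blob.name);
}

int main() {
    const bool coopmat_acc_f16_support = false;   // e.g. hardware without fp16 accumulation

    // Same destination slot either way; only the shader variant differs.
    create_pipeline("q4_0.f16acc",
                    coopmat_acc_f16_support ? matmul_q4_0_f32_f16acc_blob
                                            : matmul_q4_0_f32_blob);
    return 0;
}
```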
+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1705
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1706
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1707
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1708
+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1709
+
1710
+ if (device->coopmat_acc_f16_support) {
1711
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1712
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1713
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1714
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1715
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1716
+
1717
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1718
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1719
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1720
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1721
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1722
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1723
+ } else {
1724
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1725
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1726
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1727
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1728
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1729
+
1730
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1731
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1732
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1733
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1734
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1735
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1736
+ }
1737
+ }
1738
+ #undef CREATE_MM2
1739
+ #undef CREATE_MM
1740
+ } else if (device->fp16) {
1741
+ // Create 6 variants, {s,m,l}x{unaligned,aligned}
1742
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1743
+ if (device->mul_mat ## ID ## _l) \
1744
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1745
+ if (device->mul_mat ## ID ## _m) \
1746
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1747
+ if (device->mul_mat ## ID ## _s) \
1748
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1749
+ if (device->mul_mat ## ID ## _l) \
1750
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1751
+ if (device->mul_mat ## ID ## _m) \
1752
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1753
+ if (device->mul_mat ## ID ## _s) \
1754
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1755
+
1756
+ // Create 2 variants, {f16,f32} accumulator
1757
+ #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1758
+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1759
+ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1760
+
1761
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1762
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1763
+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1764
+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1765
+
1766
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1767
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1768
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1769
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1770
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1771
+
1772
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1773
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1774
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1775
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1776
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1777
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1778
+
1779
+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1780
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1781
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1782
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1783
+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1784
+
1785
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1786
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1787
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1788
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1789
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1790
+
1791
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1792
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1793
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1794
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1795
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1796
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1797
+ }
1798
+ #undef CREATE_MM2
1799
+ #undef CREATE_MM
1459
1800
  } else {
1460
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->l, "matmul_f32_l", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
1461
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->m, "matmul_f32_m", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
1462
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->s, "matmul_f32_s", matmul_f32_f32_fp32_len, matmul_f32_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
1463
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_l, "matmul_f32_aligned_l", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
1464
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_m, "matmul_f32_aligned_m", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
1465
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32->a_s, "matmul_f32_aligned_s", matmul_f32_f32_aligned_fp32_len, matmul_f32_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
1466
-
1467
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->l, "matmul_f32_f16_l", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
1468
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->m, "matmul_f32_f16_m", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
1469
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->s, "matmul_f32_f16_s", matmul_f32_f16_fp32_len, matmul_f32_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
1470
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_l, "matmul_f32_f16_aligned_l", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
1471
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_m, "matmul_f32_f16_aligned_m", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
1472
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f32_f16->a_s, "matmul_f32_f16_aligned_s", matmul_f32_f16_aligned_fp32_len, matmul_f32_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
1473
-
1474
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->l, "matmul_f16_l", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
1475
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->m, "matmul_f16_m", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
1476
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->s, "matmul_f16_s", matmul_f16_fp32_len, matmul_f16_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
1477
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_l, "matmul_f16_aligned_l", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
1478
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_m, "matmul_f16_aligned_m", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
1479
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16->a_s, "matmul_f16_aligned_s", matmul_f16_aligned_fp32_len, matmul_f16_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
1480
-
1481
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->l, "matmul_f16_f32_l", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, 1);
1482
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->m, "matmul_f16_f32_m", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, 1);
1483
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->s, "matmul_f16_f32_s", matmul_f16_f32_fp32_len, matmul_f16_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, 1);
1484
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_l, "matmul_f16_f32_aligned_l", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_l, l_align);
1485
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_m, "matmul_f16_f32_aligned_m", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_m, m_align);
1486
- ggml_vk_create_pipeline(device, device->pipeline_matmul_f16_f32->a_s, "matmul_f16_f32_aligned_s", matmul_f16_f32_aligned_fp32_len, matmul_f16_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_s, s_align);
1487
-
1488
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->l, "matmul_q4_0_f32_l", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1489
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->m, "matmul_q4_0_f32_m", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1490
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->s, "matmul_q4_0_f32_s", matmul_q4_0_f32_fp32_len, matmul_q4_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1491
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_l, "matmul_q4_0_f32_aligned_l", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1492
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_m, "matmul_q4_0_f32_aligned_m", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1493
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0]->a_s, "matmul_q4_0_f32_aligned_s", matmul_q4_0_f32_aligned_fp32_len, matmul_q4_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1494
-
1495
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->l, "matmul_q4_1_f32_l", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1496
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->m, "matmul_q4_1_f32_m", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1497
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->s, "matmul_q4_1_f32_s", matmul_q4_1_f32_fp32_len, matmul_q4_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1498
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_l, "matmul_q4_1_f32_aligned_l", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1499
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_m, "matmul_q4_1_f32_aligned_m", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1500
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1]->a_s, "matmul_q4_1_f32_aligned_s", matmul_q4_1_f32_aligned_fp32_len, matmul_q4_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1501
-
1502
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->l, "matmul_q5_0_f32_l", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1503
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->m, "matmul_q5_0_f32_m", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1504
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->s, "matmul_q5_0_f32_s", matmul_q5_0_f32_fp32_len, matmul_q5_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1505
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_l, "matmul_q5_0_f32_aligned_l", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1506
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_m, "matmul_q5_0_f32_aligned_m", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1507
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0]->a_s, "matmul_q5_0_f32_aligned_s", matmul_q5_0_f32_aligned_fp32_len, matmul_q5_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1508
-
1509
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->l, "matmul_q5_1_f32_l", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1510
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->m, "matmul_q5_1_f32_m", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1511
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->s, "matmul_q5_1_f32_s", matmul_q5_1_f32_fp32_len, matmul_q5_1_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1512
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_l, "matmul_q5_1_f32_aligned_l", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1513
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_m, "matmul_q5_1_f32_aligned_m", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1514
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1]->a_s, "matmul_q5_1_f32_aligned_s", matmul_q5_1_f32_aligned_fp32_len, matmul_q5_1_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1515
-
1516
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->l, "matmul_q8_0_f32_l", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1517
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->m, "matmul_q8_0_f32_m", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1518
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->s, "matmul_q8_0_f32_s", matmul_q8_0_f32_fp32_len, matmul_q8_0_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1519
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_l, "matmul_q8_0_f32_aligned_l", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1520
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_m, "matmul_q8_0_f32_aligned_m", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1521
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0]->a_s, "matmul_q8_0_f32_aligned_s", matmul_q8_0_f32_aligned_fp32_len, matmul_q8_0_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1522
-
1523
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->l, "matmul_q2_k_f32_l", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1524
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->m, "matmul_q2_k_f32_m", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1525
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->s, "matmul_q2_k_f32_s", matmul_q2_k_f32_fp32_len, matmul_q2_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1526
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_l, "matmul_q2_k_f32_aligned_l", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1527
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_m, "matmul_q2_k_f32_aligned_m", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1528
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K]->a_s, "matmul_q2_k_f32_aligned_s", matmul_q2_k_f32_aligned_fp32_len, matmul_q2_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1529
-
1530
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->l, "matmul_q3_k_f32_l", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1531
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->m, "matmul_q3_k_f32_m", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1532
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->s, "matmul_q3_k_f32_s", matmul_q3_k_f32_fp32_len, matmul_q3_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1533
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_l, "matmul_q3_k_f32_aligned_l", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1534
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_m, "matmul_q3_k_f32_aligned_m", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1535
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K]->a_s, "matmul_q3_k_f32_aligned_s", matmul_q3_k_f32_aligned_fp32_len, matmul_q3_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1536
-
1537
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->l, "matmul_q4_k_f32_l", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1538
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->m, "matmul_q4_k_f32_m", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1539
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->s, "matmul_q4_k_f32_s", matmul_q4_k_f32_fp32_len, matmul_q4_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1540
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_l, "matmul_q4_k_f32_aligned_l", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1541
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_m, "matmul_q4_k_f32_aligned_m", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1542
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K]->a_s, "matmul_q4_k_f32_aligned_s", matmul_q4_k_f32_aligned_fp32_len, matmul_q4_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1543
-
1544
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->l, "matmul_q5_k_f32_l", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1545
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->m, "matmul_q5_k_f32_m", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1546
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->s, "matmul_q5_k_f32_s", matmul_q5_k_f32_fp32_len, matmul_q5_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1547
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_l, "matmul_q5_k_f32_aligned_l", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1548
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_m, "matmul_q5_k_f32_aligned_m", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1549
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K]->a_s, "matmul_q5_k_f32_aligned_s", matmul_q5_k_f32_aligned_fp32_len, matmul_q5_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1550
-
1551
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->l, "matmul_q6_k_f32_l", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1552
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->m, "matmul_q6_k_f32_m", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1553
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->s, "matmul_q6_k_f32_s", matmul_q6_k_f32_fp32_len, matmul_q6_k_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1554
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_l, "matmul_q6_k_f32_aligned_l", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1555
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_m, "matmul_q6_k_f32_aligned_m", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1556
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K]->a_s, "matmul_q6_k_f32_aligned_s", matmul_q6_k_f32_aligned_fp32_len, matmul_q6_k_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1557
-
1558
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->l, "matmul_iq4_nl_f32_l", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1559
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->m, "matmul_iq4_nl_f32_m", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1560
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->s, "matmul_iq4_nl_f32_s", matmul_iq4_nl_f32_fp32_len, matmul_iq4_nl_f32_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1561
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_l, "matmul_iq4_nl_f32_aligned_l", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1562
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_m, "matmul_iq4_nl_f32_aligned_m", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1563
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL]->a_s, "matmul_iq4_nl_f32_aligned_s", matmul_iq4_nl_f32_aligned_fp32_len, matmul_iq4_nl_f32_aligned_fp32_data, "main", 3, sizeof(vk_mat_mat_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1564
-
1565
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->l, "matmul_id_f32_l", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
1566
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->m, "matmul_id_f32_m", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
1567
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->s, "matmul_id_f32_s", matmul_id_f32_f32_fp32_len, matmul_id_f32_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
1568
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_l, "matmul_id_f32_aligned_l", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
1569
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_m, "matmul_id_f32_aligned_m", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
1570
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f32->a_s, "matmul_id_f32_aligned_s", matmul_id_f32_f32_aligned_fp32_len, matmul_id_f32_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
1571
-
1572
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->l, "matmul_id_f16_l", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
1573
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->m, "matmul_id_f16_m", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
1574
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->s, "matmul_id_f16_s", matmul_id_f16_fp32_len, matmul_id_f16_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
1575
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_l, "matmul_id_f16_aligned_l", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
1576
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_m, "matmul_id_f16_aligned_m", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
1577
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16->a_s, "matmul_id_f16_aligned_s", matmul_id_f16_aligned_fp32_len, matmul_id_f16_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
1578
-
1579
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->l, "matmul_id_f16_f32_l", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, 1);
1580
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->m, "matmul_id_f16_f32_m", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, 1);
1581
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->s, "matmul_id_f16_f32_s", matmul_id_f16_f32_fp32_len, matmul_id_f16_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, 1);
1582
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_l, "matmul_id_f16_f32_aligned_l", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_l, l_align);
1583
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_m, "matmul_id_f16_f32_aligned_m", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_m, m_align);
1584
- ggml_vk_create_pipeline(device, device->pipeline_matmul_id_f16_f32->a_s, "matmul_id_f16_f32_aligned_s", matmul_id_f16_f32_aligned_fp32_len, matmul_id_f16_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_s, s_align);
1585
-
1586
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->l, "matmul_id_q4_0_f32_l", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1587
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->m, "matmul_id_q4_0_f32_m", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1588
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->s, "matmul_id_q4_0_f32_s", matmul_id_q4_0_f32_fp32_len, matmul_id_q4_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1589
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_l, "matmul_id_q4_0_f32_aligned_l", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1590
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_m, "matmul_id_q4_0_f32_aligned_m", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1591
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0]->a_s, "matmul_id_q4_0_f32_aligned_s", matmul_id_q4_0_f32_aligned_fp32_len, matmul_id_q4_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1592
-
1593
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->l, "matmul_id_q4_1_f32_l", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1594
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->m, "matmul_id_q4_1_f32_m", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1595
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->s, "matmul_id_q4_1_f32_s", matmul_id_q4_1_f32_fp32_len, matmul_id_q4_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1596
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_l, "matmul_id_q4_1_f32_aligned_l", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1597
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_m, "matmul_id_q4_1_f32_aligned_m", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1598
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1]->a_s, "matmul_id_q4_1_f32_aligned_s", matmul_id_q4_1_f32_aligned_fp32_len, matmul_id_q4_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1599
-
1600
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->l, "matmul_id_q5_0_f32_l", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1601
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->m, "matmul_id_q5_0_f32_m", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1602
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->s, "matmul_id_q5_0_f32_s", matmul_id_q5_0_f32_fp32_len, matmul_id_q5_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1603
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_l, "matmul_id_q5_0_f32_aligned_l", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1604
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_m, "matmul_id_q5_0_f32_aligned_m", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1605
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0]->a_s, "matmul_id_q5_0_f32_aligned_s", matmul_id_q5_0_f32_aligned_fp32_len, matmul_id_q5_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1606
-
1607
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->l, "matmul_id_q5_1_f32_l", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1608
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->m, "matmul_id_q5_1_f32_m", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1609
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->s, "matmul_id_q5_1_f32_s", matmul_id_q5_1_f32_fp32_len, matmul_id_q5_1_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1610
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_l, "matmul_id_q5_1_f32_aligned_l", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1611
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_m, "matmul_id_q5_1_f32_aligned_m", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1612
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1]->a_s, "matmul_id_q5_1_f32_aligned_s", matmul_id_q5_1_f32_aligned_fp32_len, matmul_id_q5_1_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1613
-
1614
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->l, "matmul_id_q8_0_f32_l", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1615
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->m, "matmul_id_q8_0_f32_m", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1616
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->s, "matmul_id_q8_0_f32_s", matmul_id_q8_0_f32_fp32_len, matmul_id_q8_0_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1617
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_l, "matmul_id_q8_0_f32_aligned_l", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1618
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_m, "matmul_id_q8_0_f32_aligned_m", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1619
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0]->a_s, "matmul_id_q8_0_f32_aligned_s", matmul_id_q8_0_f32_aligned_fp32_len, matmul_id_q8_0_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1620
-
1621
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->l, "matmul_id_q2_k_f32_l", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1622
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->m, "matmul_id_q2_k_f32_m", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1623
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->s, "matmul_id_q2_k_f32_s", matmul_id_q2_k_f32_fp32_len, matmul_id_q2_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1624
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_l, "matmul_id_q2_k_f32_aligned_l", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1625
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_m, "matmul_id_q2_k_f32_aligned_m", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1626
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K]->a_s, "matmul_id_q2_k_f32_aligned_s", matmul_id_q2_k_f32_aligned_fp32_len, matmul_id_q2_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1627
-
1628
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->l, "matmul_id_q3_k_f32_l", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1629
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->m, "matmul_id_q3_k_f32_m", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1630
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->s, "matmul_id_q3_k_f32_s", matmul_id_q3_k_f32_fp32_len, matmul_id_q3_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1631
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_l, "matmul_id_q3_k_f32_aligned_l", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1632
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_m, "matmul_id_q3_k_f32_aligned_m", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1633
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K]->a_s, "matmul_id_q3_k_f32_aligned_s", matmul_id_q3_k_f32_aligned_fp32_len, matmul_id_q3_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1634
-
1635
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->l, "matmul_id_q4_k_f32_l", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1636
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->m, "matmul_id_q4_k_f32_m", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1637
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->s, "matmul_id_q4_k_f32_s", matmul_id_q4_k_f32_fp32_len, matmul_id_q4_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1638
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_l, "matmul_id_q4_k_f32_aligned_l", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1639
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_m, "matmul_id_q4_k_f32_aligned_m", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1640
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K]->a_s, "matmul_id_q4_k_f32_aligned_s", matmul_id_q4_k_f32_aligned_fp32_len, matmul_id_q4_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1641
-
1642
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->l, "matmul_id_q5_k_f32_l", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1643
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->m, "matmul_id_q5_k_f32_m", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1644
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->s, "matmul_id_q5_k_f32_s", matmul_id_q5_k_f32_fp32_len, matmul_id_q5_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1645
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_l, "matmul_id_q5_k_f32_aligned_l", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1646
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_m, "matmul_id_q5_k_f32_aligned_m", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1647
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K]->a_s, "matmul_id_q5_k_f32_aligned_s", matmul_id_q5_k_f32_aligned_fp32_len, matmul_id_q5_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1648
-
1649
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->l, "matmul_id_q6_k_f32_l", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1650
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->m, "matmul_id_q6_k_f32_m", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1651
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->s, "matmul_id_q6_k_f32_s", matmul_id_q6_k_f32_fp32_len, matmul_id_q6_k_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1652
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_l, "matmul_id_q6_k_f32_aligned_l", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1653
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_m, "matmul_id_q6_k_f32_aligned_m", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1654
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K]->a_s, "matmul_id_q6_k_f32_aligned_s", matmul_id_q6_k_f32_aligned_fp32_len, matmul_id_q6_k_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1655
-
1656
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->l, "matmul_id_iq4_nl_f32_l", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1657
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->m, "matmul_id_iq4_nl_f32_m", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1658
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->s, "matmul_id_iq4_nl_f32_s", matmul_id_iq4_nl_f32_fp32_len, matmul_id_iq4_nl_f32_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1659
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_l, "matmul_id_iq4_nl_f32_aligned_l", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), l_wg_denoms, warptile_mmq_l, l_align);
1660
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_m, "matmul_id_iq4_nl_f32_aligned_m", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), m_wg_denoms, warptile_mmq_m, m_align);
1661
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL]->a_s, "matmul_id_iq4_nl_f32_aligned_s", matmul_id_iq4_nl_f32_aligned_fp32_len, matmul_id_iq4_nl_f32_aligned_fp32_data, "main", 4, sizeof(vk_mat_mat_id_push_constants), s_wg_denoms, warptile_mmq_s, s_align);
1801
+ // Create 6 variants, {s,m,l}x{unaligned,aligned}
1802
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1803
+ if (device->mul_mat ## ID ## _l) \
1804
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1805
+ if (device->mul_mat ## ID ## _m) \
1806
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1807
+ if (device->mul_mat ## ID ## _s) \
1808
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1809
+ if (device->mul_mat ## ID ## _l) \
1810
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1811
+ if (device->mul_mat ## ID ## _m) \
1812
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1813
+ if (device->mul_mat ## ID ## _s) \
1814
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1815
+
1816
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1817
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1818
+ CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1819
+ CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1820
+
1821
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1822
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1823
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1824
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1825
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1826
+
1827
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1828
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1829
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1830
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1831
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1832
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1833
+
1834
+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1835
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1836
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1837
+ CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1838
+ CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1839
+
1840
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1841
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1842
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1843
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1844
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1845
+
1846
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1847
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1848
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1849
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1850
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1851
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1852
+ }
1853
+ #undef CREATE_MM
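The CREATE_MM macro introduced above replaces six near-identical ggml_vk_create_pipeline calls per quantization type with a single invocation: token pasting (##) builds the embedded SPIR-V symbol names (e.g. matmul_q4_0_f32_fp32_len/_data), stringization (#) builds the pipeline name with a size suffix, and each {s,m,l} x {unaligned,aligned} variant is guarded by the corresponding device->mul_mat*_{s,m,l} flag. A minimal standalone sketch of the same preprocessor technique, using a stub create function and a hypothetical matmul_demo blob rather than the real backend symbols:

    #include <cstdio>

    // Stand-ins for the generated SPIR-V arrays; in ggml-vulkan these come from
    // the shader-embedding step (e.g. matmul_q4_0_f32_fp32_len/_data).
    static const unsigned char matmul_demo_fp32_data[] = { 0x00 };
    static const size_t        matmul_demo_fp32_len    = sizeof(matmul_demo_fp32_data);

    // Stub standing in for ggml_vk_create_pipeline: just reports what it would build.
    static void create_pipeline(const char * name, const unsigned char * spv, size_t len) {
        std::printf("pipeline %-16s spv_bytes=%zu ptr=%p\n", name, len, (const void *) spv);
    }

    // Same token-pasting idea as CREATE_MM: NAMELC selects the blob pair,
    // the stringized name gets a size suffix appended.
    #define CREATE_DEMO(NAMELC, SUFFIX) \
        create_pipeline(#NAMELC SUFFIX, NAMELC ## _fp32_data, NAMELC ## _fp32_len)

    int main() {
        CREATE_DEMO(matmul_demo, "_l");
        CREATE_DEMO(matmul_demo, "_m");
        CREATE_DEMO(matmul_demo, "_s");
        return 0;
    }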
1662
1854
  }
1663
1855
 
1664
1856
  // mul mat vec
1665
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1666
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1667
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1668
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1669
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1670
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1671
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1672
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1673
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1674
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1675
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1676
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1677
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1678
-
1679
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1680
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1681
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1682
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1683
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1684
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1685
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1686
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1687
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1688
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1689
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1690
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1691
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1692
-
1693
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1694
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1695
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1696
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1697
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1698
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1699
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1700
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1701
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1702
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1703
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1704
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1705
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1857
+
1858
+ // AMD GCN and Intel graphics cards perform best when the number of rows per shader is doubled
1859
+ uint32_t rm = 1;
1860
+ if ((device->vendor_id == VK_VENDOR_ID_AMD && device->subgroup_min_size == 64 && device->subgroup_max_size == 64) || device->vendor_id == VK_VENDOR_ID_INTEL)
1861
+ rm = 2;
1862
+
1863
+ // computing additional rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
1864
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1865
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1866
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1867
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1868
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1869
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1870
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
1871
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1872
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1873
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1874
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1875
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1876
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
1877
+
1878
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1879
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1880
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1881
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1882
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1883
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1884
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
1885
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1886
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1887
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1888
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1889
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1890
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
1891
+
1892
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1893
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1894
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1895
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1896
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1897
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1898
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
1899
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1900
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1901
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1902
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1903
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1904
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
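The rm multiplier above doubles the number of matrix rows handled per mul_mat_vec workgroup on wave64-only AMD hardware (GCN, detected via subgroup_min_size == subgroup_max_size == 64) and on Intel, and feeds into both the workgroup denominators ({2*rm, 1, 1}) and the specialization constants ({subgroup_size, 2*rm}); Q8_0 stays at a single row per subgroup. A rough sketch of how the dispatch arithmetic works out under those settings (the helper name and the ceil-divide are assumptions for illustration, not backend code):

    #include <cstdint>
    #include <cstdio>

    // Assumed helper: workgroups needed for a matrix with `nrows` rows, given the
    // rows-per-workgroup baked in at pipeline creation time ({2*rm, 1, 1} for
    // Q4_0..Q5_1 and IQ4_NL above, {1*rm, 1, 1} for Q8_0).
    static uint32_t mmv_workgroups(uint32_t nrows, uint32_t rows_per_wg) {
        return (nrows + rows_per_wg - 1) / rows_per_wg;   // ceil-divide
    }

    int main() {
        const bool gcn_or_intel = true;                   // e.g. wave64-only AMD or an Intel iGPU
        const uint32_t rm = gcn_or_intel ? 2 : 1;
        const uint32_t nrows = 4096;                      // rows of the quantized weight matrix

        std::printf("q4_0: %u workgroups (%u rows each)\n", mmv_workgroups(nrows, 2 * rm), 2 * rm);
        std::printf("q8_0: %u workgroups (%u rows each)\n", mmv_workgroups(nrows, 1 * rm), 1 * rm);
        return 0;
    }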
1706
1905
 
1707
1906
  // dequant shaders
1708
1907
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -1737,7 +1936,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1737
1936
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
1738
1937
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
1739
1938
 
1740
- ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
1939
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
1741
1940
 
1742
1941
  ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
1743
1942
  ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
@@ -1750,13 +1949,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
1750
1949
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1751
1950
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1752
1951
 
1753
- ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1754
- ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1952
+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1953
+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1954
+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1955
+
1956
+ ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1957
+ ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
1958
+ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1959
+ ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
1755
1960
 
1756
1961
  ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1757
1962
 
1758
- ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1759
- ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1963
+ ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1964
+ ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
1965
+ ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1966
+ ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
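The new *_norepeat pipelines reuse the same add/mul/div SPIR-V but flip a specialization constant from {0} to {1}, which presumably lets the shader drop its broadcast ("repeat") indexing when the two operands already have matching shapes. A hedged sketch of how a caller might choose between the two variants; the shape test below is an assumption about what "norepeat" implies, not code lifted from the backend:

    #include <array>
    #include <cstdint>

    struct demo_tensor { std::array<int64_t, 4> ne; };    // element counts per dimension

    // True when src1 would have to be broadcast along some dimension to match src0.
    static bool needs_repeat(const demo_tensor & src0, const demo_tensor & src1) {
        for (int i = 0; i < 4; ++i) {
            if (src0.ne[i] != src1.ne[i]) {
                return true;
            }
        }
        return false;
    }

    enum class add_pipeline { add_f32, add_f32_norepeat };

    static add_pipeline pick_add_pipeline(const demo_tensor & a, const demo_tensor & b) {
        return needs_repeat(a, b) ? add_pipeline::add_f32 : add_pipeline::add_f32_norepeat;
    }

    int main() {
        demo_tensor a{{4096, 32, 1, 1}};
        demo_tensor bias{{4096, 1, 1, 1}};                 // would broadcast along dim 1
        return pick_add_pipeline(a, bias) == add_pipeline::add_f32 ? 0 : 1;
    }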
1760
1967
 
1761
1968
  ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
1762
1969
  ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -1785,27 +1992,43 @@ static void ggml_vk_load_shaders(vk_device& device) {
1785
1992
 
1786
1993
  ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
1787
1994
 
1788
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
1789
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
1995
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1996
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
1997
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1998
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
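soft_max now gets two specializations of the same shader: one whose workgroup width is the device subgroup size and a second fixed at 512 invocations, which gives long rows more parallelism in the reduction. A plausible host-side selection rule is sketched below; the 512-column threshold is an assumption for illustration, since the real dispatch logic is not part of this hunk:

    #include <cstdint>

    enum class softmax_pipeline { subgroup_sized, wg512 };

    // Switch to the 512-wide workgroup once a row is long enough to keep it busy.
    static softmax_pipeline pick_softmax_pipeline(uint32_t ncols) {
        return ncols > 512 ? softmax_pipeline::wg512 : softmax_pipeline::subgroup_sized;
    }

    int main() {
        return pick_softmax_pipeline(4096) == softmax_pipeline::wg512 ? 0 : 1;
    }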
1790
1999
 
1791
2000
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1792
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1793
-
1794
2001
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1795
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2002
+
2003
+ if (device->float_controls_rte_fp16) {
2004
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2005
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2006
+ } else {
2007
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2008
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2009
+ }
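Both rope variants (and im2col further down) now pick an *_rte shader build when the device guarantees round-to-nearest-even fp16 conversions, falling back to the original build otherwise. The flag is VkPhysicalDeviceVulkan12Properties::shaderRoundingModeRTEFloat16, stored into device->float_controls_rte_fp16 in ggml_vk_get_device later in this diff; a minimal fragment showing that query in isolation (assumes a Vulkan 1.2 physical device handle is already in hand):

    #include <vulkan/vulkan.hpp>

    // True when fp16 conversions are guaranteed to round to nearest even,
    // i.e. the *_rte shader variants can be selected.
    static bool supports_rte_fp16(vk::PhysicalDevice pd) {
        vk::PhysicalDeviceProperties2        props2;
        vk::PhysicalDeviceVulkan12Properties vk12_props;
        props2.pNext = &vk12_props;
        pd.getProperties2(&props2);
        return vk12_props.shaderRoundingModeRTEFloat16 == VK_TRUE;
    }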
1796
2010
 
1797
2011
  ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
1798
2012
 
1799
2013
  ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1800
2014
 
1801
2015
  ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1802
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
2016
+ if (device->float_controls_rte_fp16) {
2017
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
2018
+ } else {
2019
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
2020
+ }
1803
2021
 
1804
2022
  ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
1805
2023
 
2024
+ ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
2025
+
2026
+ ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
2027
+
1806
2028
  for (auto &c : compiles) {
1807
2029
  c.wait();
1808
2030
  }
2031
+ std::cerr << "Done!" << std::endl;
1809
2032
  }
1810
2033
 
1811
2034
  static vk_device ggml_vk_get_device(size_t idx) {
@@ -1835,12 +2058,40 @@ static vk_device ggml_vk_get_device(size_t idx) {
1835
2058
  device->physical_device = physical_devices[dev_num];
1836
2059
  const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
1837
2060
 
2061
+ bool fp16_storage = false;
2062
+ bool fp16_compute = false;
1838
2063
  bool maintenance4_support = false;
2064
+ bool sm_builtins = false;
2065
+ bool amd_shader_core_properties2 = false;
2066
+ bool pipeline_robustness = false;
2067
+ bool coopmat2_support = false;
2068
+ device->coopmat_support = false;
1839
2069
 
1840
2070
  // Check if maintenance4 is supported
1841
2071
  for (const auto& properties : ext_props) {
1842
2072
  if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
1843
2073
  maintenance4_support = true;
2074
+ } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
2075
+ fp16_storage = true;
2076
+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
2077
+ fp16_compute = true;
2078
+ } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
2079
+ sm_builtins = true;
2080
+ } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
2081
+ amd_shader_core_properties2 = true;
2082
+ } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
2083
+ pipeline_robustness = true;
2084
+ } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
2085
+ device->subgroup_size_control = true;
2086
+ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
2087
+ !getenv("GGML_VK_DISABLE_COOPMAT")) {
2088
+ device->coopmat_support = true;
2089
+ device->coopmat_m = 0;
2090
+ device->coopmat_n = 0;
2091
+ device->coopmat_k = 0;
2092
+ } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
2093
+ !getenv("GGML_VK_DISABLE_COOPMAT2")) {
2094
+ coopmat2_support = true;
1844
2095
  }
1845
2096
  }
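Capability discovery is now a single pass over enumerateDeviceExtensionProperties, with GGML_VK_DISABLE_COOPMAT and GGML_VK_DISABLE_COOPMAT2 acting as environment-variable kill switches for the cooperative-matrix paths. The same pattern, reduced to a subset of the flags above:

    #include <cstdlib>
    #include <cstring>
    #include <vulkan/vulkan.hpp>

    struct ext_caps {
        bool fp16_storage = false;
        bool fp16_compute = false;
        bool coopmat      = false;
    };

    static ext_caps query_ext_caps(vk::PhysicalDevice pd) {
        ext_caps caps;
        for (const auto & p : pd.enumerateDeviceExtensionProperties()) {
            if (strcmp("VK_KHR_16bit_storage", p.extensionName) == 0) {
                caps.fp16_storage = true;
            } else if (strcmp("VK_KHR_shader_float16_int8", p.extensionName) == 0) {
                caps.fp16_compute = true;
            } else if (strcmp("VK_KHR_cooperative_matrix", p.extensionName) == 0 &&
                       getenv("GGML_VK_DISABLE_COOPMAT") == nullptr) {
                caps.coopmat = true;   // opt-out via environment, as in the loop above
            }
        }
        return caps;
    }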
1846
2097
 
@@ -1848,18 +2099,51 @@ static vk_device ggml_vk_get_device(size_t idx) {
1848
2099
  vk::PhysicalDeviceMaintenance3Properties props3;
1849
2100
  vk::PhysicalDeviceMaintenance4Properties props4;
1850
2101
  vk::PhysicalDeviceSubgroupProperties subgroup_props;
2102
+ vk::PhysicalDeviceDriverProperties driver_props;
2103
+ vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
2104
+ vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2105
+ vk::PhysicalDeviceVulkan12Properties vk12_props;
2106
+ vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
2107
+
1851
2108
  props2.pNext = &props3;
1852
2109
  props3.pNext = &subgroup_props;
2110
+ subgroup_props.pNext = &driver_props;
2111
+ driver_props.pNext = &vk12_props;
2112
+
2113
+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
2114
+
1853
2115
  if (maintenance4_support) {
1854
- subgroup_props.pNext = &props4;
2116
+ last_struct->pNext = (VkBaseOutStructure *)&props4;
2117
+ last_struct = (VkBaseOutStructure *)&props4;
2118
+ }
2119
+ if (sm_builtins) {
2120
+ last_struct->pNext = (VkBaseOutStructure *)&sm_props;
2121
+ last_struct = (VkBaseOutStructure *)&sm_props;
1855
2122
  }
2123
+ if (amd_shader_core_properties2) {
2124
+ last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
2125
+ last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
2126
+ }
2127
+ if (device->subgroup_size_control) {
2128
+ last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
2129
+ last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
2130
+ }
2131
+
2132
+ #if defined(VK_NV_cooperative_matrix2)
2133
+ vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
2134
+ if (coopmat2_support) {
2135
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
2136
+ last_struct = (VkBaseOutStructure *)&coopmat2_props;
2137
+ }
2138
+ #endif
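The property query now grows its pNext chain dynamically: a VkBaseOutStructure pointer tracks the current tail, and each extension-specific struct is appended only when the matching extension was found, so the driver is never asked to fill properties it cannot report. The same pattern in isolation, trimmed to one optional struct:

    #include <vulkan/vulkan.hpp>

    // Append `next` to the chain whose current tail is *tail, then advance the tail.
    static void chain_append(VkBaseOutStructure ** tail, void * next) {
        (*tail)->pNext = (VkBaseOutStructure *) next;
        *tail          = (VkBaseOutStructure *) next;
    }

    static void query_props(vk::PhysicalDevice pd, bool has_maintenance4) {
        vk::PhysicalDeviceProperties2            props2;
        vk::PhysicalDeviceVulkan12Properties     vk12_props;
        vk::PhysicalDeviceMaintenance4Properties props4;

        props2.pNext = &vk12_props;
        VkBaseOutStructure * tail = (VkBaseOutStructure *) &vk12_props;

        if (has_maintenance4) {
            chain_append(&tail, &props4);   // chained only when the extension exists
        }

        pd.getProperties2(&props2);
    }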
2139
+
1856
2140
  device->physical_device.getProperties2(&props2);
1857
2141
  device->properties = props2.properties;
1858
2142
 
1859
2143
  const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
1860
2144
 
1861
2145
  if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
1862
- device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
2146
+ device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
1863
2147
  } else if (maintenance4_support) {
1864
2148
  device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
1865
2149
  } else {
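Note also the std::stoi to std::stoul change a few lines up: GGML_VK_FORCE_MAX_ALLOCATION_SIZE is now parsed as an unsigned long before landing in the 64-bit max_memory_allocation_size field, so forced limits above INT_MAX no longer overflow or throw. A tiny illustration:

    #include <cstdint>
    #include <cstdio>
    #include <string>

    int main() {
        const std::string forced = "3221225472";          // 3 GiB, larger than INT_MAX
        // std::stoi(forced) would throw std::out_of_range here;
        // std::stoul parses it and widens cleanly into the 64-bit field.
        const uint64_t max_alloc = std::stoul(forced);
        std::printf("max_memory_allocation_size = %llu\n", (unsigned long long) max_alloc);
        return 0;
    }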
@@ -1869,23 +2153,25 @@ static vk_device ggml_vk_get_device(size_t idx) {
1869
2153
  device->vendor_id = device->properties.vendorID;
1870
2154
  device->subgroup_size = subgroup_props.subgroupSize;
1871
2155
  device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1872
-
1873
- bool fp16_storage = false;
1874
- bool fp16_compute = false;
1875
-
1876
- for (const auto& properties : ext_props) {
1877
- if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1878
- fp16_storage = true;
1879
- } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1880
- fp16_compute = true;
1881
- }
2156
+ if (sm_builtins) {
2157
+ device->shader_core_count = sm_props.shaderSMCount;
2158
+ } else if (amd_shader_core_properties2) {
2159
+ device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
2160
+ } else {
2161
+ device->shader_core_count = 0;
1882
2162
  }
2163
+ device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
1883
2164
 
1884
- const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1885
- const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
2165
+ const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
1886
2166
 
1887
2167
  device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1888
2168
 
2169
+ if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
2170
+ // Intel drivers don't support coopmat properly yet
2171
+ // Only RADV supports coopmat properly on AMD
2172
+ device->coopmat_support = false;
2173
+ }
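Cooperative-matrix support is additionally vetoed per vendor and driver: Intel drivers and AMD's proprietary and AMDVLK drivers are excluded, leaving RADV as the only AMD driver where coopmat stays enabled. The driverID used for that decision comes from VkPhysicalDeviceDriverProperties, which was chained into the property query above; a condensed version of the check:

    #include <vulkan/vulkan.hpp>

    // Keep coopmat only where it is known to behave, mirroring the gating above.
    static bool allow_coopmat(uint32_t vendor_id, vk::DriverId driver_id) {
        if (vendor_id == 0x8086) {                         // Intel: not supported properly yet
            return false;
        }
        if (vendor_id == 0x1002 &&                         // AMD: only RADV handles coopmat well
            (driver_id == vk::DriverId::eAmdProprietary ||
             driver_id == vk::DriverId::eAmdOpenSource)) {
            return false;
        }
        return true;
    }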
2174
+
1889
2175
  std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
1890
2176
 
1891
2177
  // Try to find a non-graphics compute queue and transfer-focused queues
@@ -1923,10 +2209,149 @@ static vk_device ggml_vk_get_device(size_t idx) {
1923
2209
  vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1924
2210
  vk11_features.pNext = &vk12_features;
1925
2211
 
2212
+ last_struct = (VkBaseOutStructure *)&vk12_features;
2213
+
2214
+ VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
2215
+ pl_robustness_features.pNext = nullptr;
2216
+ pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
2217
+ pl_robustness_features.pipelineRobustness = VK_FALSE;
2218
+
2219
+ if (pipeline_robustness) {
2220
+ last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
2221
+ last_struct = (VkBaseOutStructure *)&pl_robustness_features;
2222
+ device_extensions.push_back("VK_EXT_pipeline_robustness");
2223
+ }
2224
+
2225
+ VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
2226
+ subgroup_size_control_features.pNext = nullptr;
2227
+ subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
2228
+ subgroup_size_control_features.computeFullSubgroups = false;
2229
+ subgroup_size_control_features.subgroupSizeControl = false;
2230
+
2231
+ if (device->subgroup_size_control) {
2232
+ last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
2233
+ last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
2234
+ }
2235
+
2236
+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
2237
+ coopmat_features.pNext = nullptr;
2238
+ coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
2239
+ coopmat_features.cooperativeMatrix = VK_FALSE;
2240
+
2241
+ if (device->coopmat_support) {
2242
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
2243
+ last_struct = (VkBaseOutStructure *)&coopmat_features;
2244
+ }
2245
+
2246
+ #if defined(VK_NV_cooperative_matrix2)
2247
+ VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
2248
+ coopmat2_features.pNext = nullptr;
2249
+ coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
2250
+ if (coopmat2_support) {
2251
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
2252
+ last_struct = (VkBaseOutStructure *)&coopmat2_features;
2253
+ device_extensions.push_back("VK_NV_cooperative_matrix2");
2254
+ }
2255
+ #endif
2256
+
1926
2257
  vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
1927
2258
 
1928
2259
  device->fp16 = device->fp16 && vk12_features.shaderFloat16;
1929
2260
 
2261
+ device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
2262
+
2263
+ if (device->subgroup_size_control) {
2264
+ device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
2265
+ device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
2266
+ }
2267
+
2268
+ device->subgroup_size_control = device->subgroup_size_control &&
2269
+ (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
2270
+ subgroup_size_control_features.subgroupSizeControl;
2271
+
2272
+ if (device->subgroup_size_control) {
2273
+ device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
2274
+ device_extensions.push_back("VK_EXT_subgroup_size_control");
2275
+ }
2276
+
2277
+ device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
2278
+
2279
+ if (coopmat2_support) {
2280
+ #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
2281
+ if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
2282
+ coopmat2_features.cooperativeMatrixFlexibleDimensions &&
2283
+ coopmat2_features.cooperativeMatrixReductions &&
2284
+ coopmat2_features.cooperativeMatrixConversions &&
2285
+ coopmat2_features.cooperativeMatrixPerElementOperations &&
2286
+ coopmat2_features.cooperativeMatrixTensorAddressing &&
2287
+ coopmat2_features.cooperativeMatrixBlockLoads &&
2288
+ vk12_features.bufferDeviceAddress) {
2289
+
2290
+ std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
2291
+ uint32_t count = 0;
2292
+
2293
+ PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
2294
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
2295
+ (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
2296
+ vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
2297
+
2298
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
2299
+
2300
+ VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
2301
+ empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
2302
+ flexible_dimensions.resize(count, empty_prop);
2303
+
2304
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
2305
+
2306
+ bool found_fp16_128 = false,
2307
+ found_fp16_256 = false,
2308
+ found_fp32_128 = false,
2309
+ found_fp32_256 = false;
2310
+ // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
2311
+ // with 32x16x16 and 256 with 32x32x16.
2312
+ for (auto &prop : flexible_dimensions) {
2313
+ if (prop.saturatingAccumulation == VK_FALSE &&
2314
+ prop.scope == VK_SCOPE_WORKGROUP_KHR &&
2315
+ prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
2316
+ prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
2317
+
2318
+ if (prop.workgroupInvocations == 128 &&
2319
+ prop.MGranularity <= 32 &&
2320
+ prop.NGranularity <= 16 &&
2321
+ prop.KGranularity <= 16) {
2322
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
2323
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
2324
+ found_fp16_128 = true;
2325
+ }
2326
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
2327
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
2328
+ found_fp32_128 = true;
2329
+ }
2330
+ }
2331
+ if (prop.workgroupInvocations == 256 &&
2332
+ prop.MGranularity <= 32 &&
2333
+ prop.NGranularity <= 32 &&
2334
+ prop.KGranularity <= 16) {
2335
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
2336
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
2337
+ found_fp16_256 = true;
2338
+ }
2339
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
2340
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
2341
+ found_fp32_256 = true;
2342
+ }
2343
+ }
2344
+ }
2345
+ }
2346
+ if (found_fp16_128 && found_fp16_256 &&
2347
+ found_fp32_128 && found_fp32_256 &&
2348
+ coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
2349
+ device->coopmat2 = true;
2350
+ }
2351
+ }
2352
+ #endif
2353
+ }
2354
+
1930
2355
  if (!vk11_features.storageBuffer16BitAccess) {
1931
2356
  std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1932
2357
  throw std::runtime_error("Unsupported device");
@@ -1941,7 +2366,75 @@ static vk_device ggml_vk_get_device(size_t idx) {
1941
2366
  if (device->fp16) {
1942
2367
  device_extensions.push_back("VK_KHR_shader_float16_int8");
1943
2368
  }
1944
- device->name = device->properties.deviceName.data();
2369
+
2370
+ if (device->coopmat_support) {
2371
+ // Query supported shapes
2372
+ std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
2373
+
2374
+ PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
2375
+ (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
2376
+
2377
+ uint32_t cm_props_num;
2378
+
2379
+ pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
2380
+
2381
+ cm_props.resize(cm_props_num);
2382
+
2383
+ for (auto& prop : cm_props) {
2384
+ prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
2385
+ }
2386
+
2387
+ pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
2388
+
2389
+ VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
2390
+
2391
+ for (auto& prop : cm_props) {
2392
+ VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
2393
+
2394
+ if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
2395
+ (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
2396
+ (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
2397
+ ) {
2398
+ if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
2399
+ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
2400
+ // coopmat sizes not set yet
2401
+ if (device->coopmat_m == 0) {
2402
+ device->coopmat_acc_f32_support = true;
2403
+ device->coopmat_m = prop.MSize;
2404
+ device->coopmat_n = prop.NSize;
2405
+ device->coopmat_k = prop.KSize;
2406
+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
2407
+ // Only enable if shape is identical
2408
+ device->coopmat_acc_f32_support = true;
2409
+ }
2410
+ } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
2411
+ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
2412
+ // coopmat sizes not set yet
2413
+ if (device->coopmat_m == 0) {
2414
+ device->coopmat_acc_f16_support = true;
2415
+ device->coopmat_m = prop.MSize;
2416
+ device->coopmat_n = prop.NSize;
2417
+ device->coopmat_k = prop.KSize;
2418
+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
2419
+ // Only enable if shape is identical
2420
+ device->coopmat_acc_f16_support = true;
2421
+ }
2422
+ }
2423
+ }
2424
+ }
2425
+
2426
+ if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
2427
+ // No suitable matmul mode found
2428
+ GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
2429
+ device->coopmat_support = false;
2430
+ }
2431
+ }
2432
+
2433
+ if (device->coopmat_support) {
2434
+ device_extensions.push_back("VK_KHR_cooperative_matrix");
2435
+ }
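The shape query above uses the standard two-call enumeration for VK_KHR_cooperative_matrix: fetch the count, size a vector whose elements already carry the right sType, then fetch the properties; the backend then records only the first fp16 x fp16, subgroup-scope M/N/K shape it sees and notes whether f16 and f32 accumulation are available for that shape. The enumeration pattern on its own, with the loaded entry point checked before use:

    #include <vector>
    #include <vulkan/vulkan.h>

    static std::vector<VkCooperativeMatrixPropertiesKHR>
    query_coopmat_shapes(VkInstance instance, VkPhysicalDevice pd) {
        auto fn = (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)
            vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
        if (fn == nullptr) {
            return {};                                     // extension not available
        }

        uint32_t count = 0;
        fn(pd, &count, nullptr);                           // first call: how many shapes

        VkCooperativeMatrixPropertiesKHR empty {};
        empty.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
        std::vector<VkCooperativeMatrixPropertiesKHR> props(count, empty);

        fn(pd, &count, props.data());                      // second call: the shapes themselves
        return props;
    }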
2436
+
2437
+ device->name = GGML_VK_NAME + std::to_string(idx);
1945
2438
 
1946
2439
  device_create_info = {
1947
2440
  vk::DeviceCreateFlags(),
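
The block above walks the shapes reported by VK_KHR_cooperative_matrix and only keeps an accumulator mode when every usable fp16 shape agrees on M/N/K. Below is a minimal standalone sketch of that selection rule, assuming a plain struct in place of VkCooperativeMatrixPropertiesKHR; CoopmatShape, DeviceCaps and pick_coopmat_shape are invented names, not backend API.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Stand-ins for the Vulkan property struct and the device fields filled in above.
    struct CoopmatShape { uint32_t M, N, K; bool a_f16, b_f16, c_f32, r_f32, subgroup_scope; };
    struct DeviceCaps   { uint32_t m = 0, n = 0, k = 0; bool acc_f32 = false; };

    // Keep the first fp16 x fp16 -> fp32 subgroup-scope shape; later shapes only
    // confirm support if they report exactly the same M/N/K.
    static void pick_coopmat_shape(const std::vector<CoopmatShape> & props, DeviceCaps & dev) {
        for (const auto & p : props) {
            if (!(p.a_f16 && p.b_f16 && p.subgroup_scope && p.c_f32 && p.r_f32)) {
                continue;
            }
            if (dev.m == 0) {
                dev.acc_f32 = true;
                dev.m = p.M; dev.n = p.N; dev.k = p.K;
            } else if (dev.m == p.M && dev.n == p.N && dev.k == p.K) {
                dev.acc_f32 = true;
            }
        }
    }

    int main() {
        const std::vector<CoopmatShape> props = {
            { 16, 16, 16, true, true, true, true, true },
            { 16, 16, 16, true, true, true, true, true },
        };
        DeviceCaps dev;
        pick_coopmat_shape(props, dev);
        printf("coopmat %ux%ux%u, f32 acc: %d\n", dev.m, dev.n, dev.k, dev.acc_f32);
    }
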
@@ -1956,6 +2449,37 @@ static vk_device ggml_vk_get_device(size_t idx) {
1956
2449
  ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
1957
2450
 
1958
2451
  // Shaders
2452
+ // Disable matmul tile sizes early if performance low or not supported
2453
+ switch (device->vendor_id) {
2454
+ #ifndef GGML_VULKAN_RUN_TESTS
2455
+ case VK_VENDOR_ID_AMD:
2456
+ case VK_VENDOR_ID_INTEL:
2457
+ device->mul_mat_l = false;
2458
+ device->mul_mat_m = true;
2459
+ device->mul_mat_s = true;
2460
+ device->mul_mat_id_l = false;
2461
+ device->mul_mat_id_m = true;
2462
+ device->mul_mat_id_s = true;
2463
+ break;
2464
+ case VK_VENDOR_ID_APPLE:
2465
+ device->mul_mat_l = false;
2466
+ device->mul_mat_m = true;
2467
+ device->mul_mat_s = false;
2468
+ device->mul_mat_id_l = false;
2469
+ device->mul_mat_id_m = true;
2470
+ device->mul_mat_id_s = false;
2471
+ break;
2472
+ #endif
2473
+ default:
2474
+ device->mul_mat_l = true;
2475
+ device->mul_mat_m = true;
2476
+ device->mul_mat_s = true;
2477
+ device->mul_mat_id_l = true;
2478
+ device->mul_mat_id_m = true;
2479
+ device->mul_mat_id_s = true;
2480
+ break;
2481
+ }
2482
+
1959
2483
  ggml_vk_load_shaders(device);
1960
2484
 
1961
2485
  if (!device->single_queue) {
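
The vendor switch in this hunk disables the matmul tile classes that measure poorly on AMD, Intel and Apple drivers. A compact restatement as a lookup function; the vendor constants are the usual PCI IDs and the struct and field names are placeholders, not the backend's.

    #include <cstdint>
    #include <cstdio>

    struct MulMatTiles { bool l, m, s; };

    // PCI vendor IDs (assumed values for illustration; the backend defines its own macros).
    constexpr uint32_t VENDOR_AMD   = 0x1002;
    constexpr uint32_t VENDOR_APPLE = 0x106b;
    constexpr uint32_t VENDOR_INTEL = 0x8086;

    static MulMatTiles default_tiles(uint32_t vendor_id) {
        switch (vendor_id) {
            case VENDOR_AMD:
            case VENDOR_INTEL: return { /*l=*/false, /*m=*/true, /*s=*/true  }; // large tile disabled
            case VENDOR_APPLE: return { /*l=*/false, /*m=*/true, /*s=*/false }; // only the medium tile
            default:           return { /*l=*/true,  /*m=*/true, /*s=*/true  };
        }
    }

    int main() {
        const MulMatTiles t = default_tiles(VENDOR_AMD);
        printf("AMD defaults: l=%d m=%d s=%d\n", t.l, t.m, t.s);
    }
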
@@ -1968,7 +2492,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
1968
2492
 
1969
2493
  device->buffer_type = {
1970
2494
  /* .iface = */ ggml_backend_vk_buffer_type_interface,
1971
- /* .device = */ nullptr,
2495
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
1972
2496
  /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
1973
2497
  };
1974
2498
 
@@ -2013,15 +2537,31 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2013
2537
 
2014
2538
  bool fp16_storage = false;
2015
2539
  bool fp16_compute = false;
2540
+ bool coopmat_support = false;
2541
+ bool coopmat2_support = false;
2016
2542
 
2017
2543
  for (auto properties : ext_props) {
2018
2544
  if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
2019
2545
  fp16_storage = true;
2020
2546
  } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
2021
2547
  fp16_compute = true;
2548
+ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
2549
+ !getenv("GGML_VK_DISABLE_COOPMAT")) {
2550
+ coopmat_support = true;
2551
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
2552
+ } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
2553
+ !getenv("GGML_VK_DISABLE_COOPMAT2")) {
2554
+ coopmat2_support = true;
2555
+ #endif
2022
2556
  }
2023
2557
  }
2024
2558
 
2559
+ if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
2560
+ // Intel drivers don't support coopmat properly yet
2561
+ // Only RADV supports coopmat properly on AMD
2562
+ coopmat_support = false;
2563
+ }
2564
+
2025
2565
  const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
2026
2566
  bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
2027
2567
 
@@ -2044,15 +2584,33 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2044
2584
  vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
2045
2585
  vk11_features.pNext = &vk12_features;
2046
2586
 
2587
+ // Pointer to the last chain element
2588
+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
2589
+
2590
+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
2591
+ coopmat_features.pNext = nullptr;
2592
+ coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
2593
+ coopmat_features.cooperativeMatrix = VK_FALSE;
2594
+
2595
+ if (coopmat_support) {
2596
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
2597
+ last_struct = (VkBaseOutStructure *)&coopmat_features;
2598
+ }
2599
+
2047
2600
  vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
2048
2601
 
2049
2602
  fp16 = fp16 && vk12_features.shaderFloat16;
2050
2603
 
2604
+ coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
2605
+
2606
+ std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
2607
+
2051
2608
  std::string device_name = props2.properties.deviceName.data();
2052
- std::cerr << GGML_VK_NAME << idx << ": " << device_name << " (" << driver_props.driverName << ") | uma: " << uma << " | fp16: " << fp16 << " | warp size: " << subgroup_size << std::endl;
2609
+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n",
2610
+ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str());
2053
2611
 
2054
2612
  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
2055
- std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
2613
+ GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
2056
2614
  }
2057
2615
  }
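
The feature query above appends VkPhysicalDeviceCooperativeMatrixFeaturesKHR to the pNext chain only when the extension was advertised, keeping a pointer to the last chain element so new entries can be appended in O(1). A reduced mock of that append pattern; BaseOutStructure and CoopmatFeatures here are simplified stand-ins, not the real Vulkan types.

    #include <cstdio>

    struct BaseOutStructure { int sType; BaseOutStructure * pNext; };
    struct CoopmatFeatures  { int sType; BaseOutStructure * pNext; bool cooperativeMatrix; };

    int main() {
        BaseOutStructure vk12    { 12, nullptr };
        CoopmatFeatures  coopmat { 99, nullptr, false };

        // Track the tail of the chain; append the optional feature struct only if supported.
        BaseOutStructure * last = &vk12;
        const bool coopmat_support = true; // pretend the extension was advertised
        if (coopmat_support) {
            last->pNext = reinterpret_cast<BaseOutStructure *>(&coopmat);
            last = reinterpret_cast<BaseOutStructure *>(&coopmat);
        }

        printf("chain: %d -> %d\n", vk12.sType, vk12.pNext ? vk12.pNext->sType : -1);
    }
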
2058
2616
 
@@ -2107,8 +2665,7 @@ void ggml_vk_instance_init() {
2107
2665
  };
2108
2666
  validation_features.setPNext(nullptr);
2109
2667
  instance_create_info.setPNext(&validation_features);
2110
-
2111
- std::cerr << "ggml_vulkan: Validation layers enabled" << std::endl;
2668
+ GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n");
2112
2669
  }
2113
2670
  vk_instance.instance = vk::createInstance(instance_create_info);
2114
2671
 
@@ -2222,8 +2779,7 @@ void ggml_vk_instance_init() {
2222
2779
  vk_instance.device_indices.push_back(0);
2223
2780
  }
2224
2781
  }
2225
-
2226
- std::cerr << "ggml_vulkan: Found " << vk_instance.device_indices.size() << " Vulkan devices:" << std::endl;
2782
+ GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
2227
2783
 
2228
2784
  for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
2229
2785
  ggml_vk_print_gpu_info(i);
@@ -2279,7 +2835,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
2279
2835
  return ctx->device->pipeline_dequant[type];
2280
2836
  }
2281
2837
 
2282
- static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
2838
+ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
2283
2839
  VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
2284
2840
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2285
2841
  return ctx->device->pipeline_matmul_f32;
@@ -2287,14 +2843,23 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2287
2843
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
2288
2844
  return ctx->device->pipeline_matmul_f32_f16;
2289
2845
  }
2290
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2291
- return ctx->device->pipeline_matmul_f16_f32;
2292
- }
2293
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2294
- return ctx->device->pipeline_matmul_f16;
2846
+ if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
2847
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2848
+ return ctx->device->pipeline_matmul_f16_f32.f16acc;
2849
+ }
2850
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2851
+ return ctx->device->pipeline_matmul_f16.f16acc;
2852
+ }
2853
+ } else {
2854
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2855
+ return ctx->device->pipeline_matmul_f16_f32.f32acc;
2856
+ }
2857
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2858
+ return ctx->device->pipeline_matmul_f16.f32acc;
2859
+ }
2295
2860
  }
2296
2861
 
2297
- if (src1_type != GGML_TYPE_F32) {
2862
+ if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
2298
2863
  return nullptr;
2299
2864
  }
2300
2865
 
@@ -2315,7 +2880,11 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2315
2880
  return nullptr;
2316
2881
  }
2317
2882
 
2318
- return ctx->device->pipeline_dequant_mul_mat_mat[src0_type];
2883
+ if (ctx->device->coopmat2) {
2884
+ assert(src1_type == GGML_TYPE_F16);
2885
+ return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
2886
+ }
2887
+ return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
2319
2888
  }
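
The rewritten pipeline getter above now hands back either the fp16- or fp32-accumulator variant, depending on the requested precision and on whether the cooperative-matrix path can accumulate in fp16. A condensed sketch of that condition, with made-up placeholder types instead of the real pipeline structs.

    #include <cstdio>

    struct MatmulPipeline { const char * f16acc; const char * f32acc; };

    enum class Prec { DEFAULT, F32 };

    // Use fp16 accumulation only when default precision is requested, the device has usable
    // fp16, and (if coopmat is in play) the coopmat path supports fp16 accumulators.
    static const char * pick_acc(const MatmulPipeline & p, Prec prec,
                                 bool dev_fp16, bool coopmat, bool coopmat_acc_f16) {
        const bool f16_ok = prec == Prec::DEFAULT && dev_fp16 && !(coopmat && !coopmat_acc_f16);
        return f16_ok ? p.f16acc : p.f32acc;
    }

    int main() {
        const MatmulPipeline p = { "matmul_f16_f16acc", "matmul_f16_f32acc" };
        // coopmat without fp16 accumulators falls back to the f32-accumulator variant
        printf("%s\n", pick_acc(p, Prec::DEFAULT, true, true, false));
    }
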
2320
2889
 
2321
2890
  static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@@ -2344,16 +2913,25 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
2344
2913
  return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
2345
2914
  }
2346
2915
 
2347
- static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
2916
+ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
2348
2917
  VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
2349
2918
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2350
2919
  return ctx->device->pipeline_matmul_id_f32;
2351
2920
  }
2352
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2353
- return ctx->device->pipeline_matmul_id_f16_f32;
2354
- }
2355
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2356
- return ctx->device->pipeline_matmul_id_f16;
2921
+ if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
2922
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2923
+ return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
2924
+ }
2925
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2926
+ return ctx->device->pipeline_matmul_id_f16.f16acc;
2927
+ }
2928
+ } else {
2929
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2930
+ return ctx->device->pipeline_matmul_id_f16_f32.f32acc;
2931
+ }
2932
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2933
+ return ctx->device->pipeline_matmul_id_f16.f32acc;
2934
+ }
2357
2935
  }
2358
2936
 
2359
2937
  GGML_ASSERT(src1_type == GGML_TYPE_F32);
@@ -2375,7 +2953,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
2375
2953
  return nullptr;
2376
2954
  }
2377
2955
 
2378
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
2956
+ return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
2379
2957
  }
2380
2958
 
2381
2959
  static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@@ -2944,55 +3522,44 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
2944
3522
  dst->device->device.resetFences({ dst->device->fence });
2945
3523
  }
2946
3524
 
2947
- static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
3525
+ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
2948
3526
  VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
2949
- // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
2950
- // return 4;
2951
- // }
2952
-
2953
- return 1;
2954
-
2955
- GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
2956
- }
2957
3527
 
2958
- static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
2959
- if (m <= 32 || n <= 32) {
2960
- return aligned ? mmp->a_s : mmp->s;
3528
+ uint32_t split_k = 1;
3529
+ if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
3530
+ // If k is 'large' and the SMs will fill less than halfway, use split_k.
3531
+ uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
3532
+ uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
3533
+ if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
3534
+ split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
3535
+ // Clamp to 2 or 4
3536
+ split_k = std::min(split_k, 4u);
3537
+ if (split_k == 3) {
3538
+ split_k = 2;
3539
+ }
3540
+ }
2961
3541
  }
2962
- return aligned ? mmp->a_m : mmp->m;
2963
-
2964
- GGML_UNUSED(ctx);
2965
- }
2966
-
2967
- static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
2968
- return aligned ? mmp->a_m : mmp->m;
2969
-
2970
- GGML_UNUSED(ctx);
2971
- }
2972
-
2973
- static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
2974
- return aligned ? mmp->a_s : mmp->s;
2975
3542
 
2976
- GGML_UNUSED(ctx);
3543
+ return split_k;
2977
3544
  }
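
The new ggml_vk_guess_split_k only splits the K dimension when K is at least 2048 and the m-by-n tile grid would fill less than half of the shader cores, clamping the result to 2 or 4. A standalone restatement of the heuristic; ceil_div and the parameter names are stand-ins for CEIL_DIV and the pipeline/device fields.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    static uint32_t guess_split_k(uint32_t m, uint32_t n, uint32_t k,
                                  uint32_t tile_m, uint32_t tile_n, uint32_t shader_cores) {
        uint32_t split_k = 1;
        if (shader_cores != 0 && m >= tile_m && n >= tile_n) {
            const uint32_t m_tiles = ceil_div(m, tile_m);
            const uint32_t n_tiles = ceil_div(n, tile_n);
            // Split only when K is large and the tile grid fills less than half the SMs.
            if (k >= 2048 && m_tiles * n_tiles < shader_cores / 2) {
                split_k = shader_cores / (m_tiles * n_tiles);
                split_k = std::min(split_k, 4u);   // clamp to at most 4 ...
                if (split_k == 3) { split_k = 2; } // ... and avoid the odd value 3
            }
        }
        return split_k;
    }

    int main() {
        // one 128x128 tile on a 32-SM GPU with K = 4096 -> the work is split into 4 chunks
        printf("split_k = %u\n", guess_split_k(128, 128, 4096, 128, 128, 32));
    }
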
2978
3545
 
2979
3546
  static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
2980
3547
  VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
2981
- switch (ctx->device->vendor_id) {
2982
- case VK_VENDOR_ID_AMD:
2983
- return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
2984
- case VK_VENDOR_ID_APPLE:
2985
- return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
2986
- case VK_VENDOR_ID_INTEL:
2987
- return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
2988
- default:
2989
- break;
3548
+
3549
+ if (ctx->device->coopmat2) {
3550
+ if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
3551
+ return aligned ? mmp->a_l : mmp->l;
3552
+ }
3553
+ if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
3554
+ return aligned ? mmp->a_m : mmp->m;
3555
+ }
3556
+ return aligned ? mmp->a_s : mmp->s;
2990
3557
  }
2991
3558
 
2992
- if (m <= 32 || n <= 32) {
3559
+ if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
2993
3560
  return aligned ? mmp->a_s : mmp->s;
2994
3561
  }
2995
- if (m <= 64 || n <= 64) {
3562
+ if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
2996
3563
  return aligned ? mmp->a_m : mmp->m;
2997
3564
  }
2998
3565
  return aligned ? mmp->a_l : mmp->l;
@@ -3027,6 +3594,33 @@ static void ggml_vk_matmul(
3027
3594
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
3028
3595
  }
3029
3596
 
3597
+ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
3598
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ")");
3599
+
3600
+ if (ctx->device->coopmat2) {
3601
+ if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
3602
+ return aligned ? mmp->a_l : mmp->l;
3603
+ }
3604
+ if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
3605
+ return aligned ? mmp->a_m : mmp->m;
3606
+ }
3607
+ return aligned ? mmp->a_s : mmp->s;
3608
+ }
3609
+
3610
+ if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
3611
+ return aligned ? mmp->a_s : mmp->s;
3612
+ }
3613
+ if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
3614
+ return aligned ? mmp->a_m : mmp->m;
3615
+ }
3616
+ return aligned ? mmp->a_l : mmp->l;
3617
+ }
3618
+
3619
+ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
3620
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline_align(" << m << ", " << n << ")");
3621
+ return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
3622
+ }
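
ggml_vk_guess_matmul_id_pipeline above (like its non-id counterpart earlier in this section) picks the small, medium or large tile from the problem size, skipping any tile class that was disabled per vendor. A condensed sketch of the non-coopmat branch, with the per-device enable flags passed in explicitly; the names are illustrative.

    enum class Tile { S, M, L };

    static Tile pick_tile(int m, int n, bool mul_mat_s, bool mul_mat_m, bool mul_mat_l) {
        if ((mul_mat_s && (m <= 32 || n <= 32)) || (!mul_mat_m && !mul_mat_l)) {
            return Tile::S;
        }
        if ((mul_mat_m && (m <= 64 || n <= 64)) || !mul_mat_l) {
            return Tile::M;
        }
        return Tile::L;
    }

    int main() {
        // m <= 64, so the medium tile is chosen even though the large tile is disabled
        const Tile t = pick_tile(48, 4096, true, true, false);
        return t == Tile::M ? 0 : 1;
    }
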
3623
+
3030
3624
  static void ggml_vk_matmul_id(
3031
3625
  ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
3032
3626
  vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
@@ -3050,18 +3644,34 @@ static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) {
3050
3644
  tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
3051
3645
  }
3052
3646
 
3053
- static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, ggml_type from, ggml_type to) {
3054
- if (from == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
3055
- return ctx->device->pipeline_cpy_f32_f32;
3647
+ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) {
3648
+
3649
+ // Choose "contiguous copy" shader if src/dst are contiguous
3650
+ bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst));
3651
+
3652
+ if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) {
3653
+ if (contig) {
3654
+ return ctx->device->pipeline_contig_cpy_f32_f32;
3655
+ } else {
3656
+ return ctx->device->pipeline_cpy_f32_f32;
3657
+ }
3056
3658
  }
3057
- if (from == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
3058
- return ctx->device->pipeline_cpy_f32_f16;
3659
+ if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) {
3660
+ if (contig) {
3661
+ return ctx->device->pipeline_contig_cpy_f32_f16;
3662
+ } else {
3663
+ return ctx->device->pipeline_cpy_f32_f16;
3664
+ }
3059
3665
  }
3060
- if (from == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
3061
- return ctx->device->pipeline_cpy_f16_f16;
3666
+ if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) {
3667
+ if (contig) {
3668
+ return ctx->device->pipeline_contig_cpy_f16_f16;
3669
+ } else {
3670
+ return ctx->device->pipeline_cpy_f16_f16;
3671
+ }
3062
3672
  }
3063
3673
 
3064
- std::cerr << "Missing CPY op for types: " << ggml_type_name(from) << " " << ggml_type_name(to) << std::endl;
3674
+ std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
3065
3675
  GGML_ABORT("fatal error");
3066
3676
  }
3067
3677
 
@@ -3071,16 +3681,27 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
3071
3681
  const int tensor_type_size = ggml_type_size(tensor->type);
3072
3682
 
3073
3683
  const uint32_t ne = ggml_nelements(tensor);
3684
+ std::array<uint32_t, 3> elements;
3685
+
3686
+ if (ne > 262144) {
3687
+ elements = { 512, 512, CEIL_DIV(ne, 262144) };
3688
+ } else if (ne > 512) {
3689
+ elements = { 512, CEIL_DIV(ne, 512), 1 };
3690
+ } else {
3691
+ elements = { ne, 1, 1 };
3692
+ }
3074
3693
 
3075
- const vk_op_unary_push_constants pc = {
3694
+ vk_op_unary_push_constants pc = {
3076
3695
  (uint32_t)ne,
3077
3696
  (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
3078
3697
  (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
3079
3698
  0,
3080
3699
  0.0f, 0.0f,
3700
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
3081
3701
  };
3702
+ init_pushconst_fastdiv(pc);
3082
3703
  ggml_vk_sync_buffers(subctx);
3083
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, { ne, 1, 1 });
3704
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
3084
3705
  }
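
The copy-to-contiguous path now spreads large element counts across a three-dimensional dispatch so that the x and y grid dimensions never exceed 512 (512 * 512 = 262144 elements per z slice). A standalone sketch of that decomposition:

    #include <array>
    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

    // Map a flat element count onto an {x, y, z} grid with x and y capped at 512.
    static std::array<uint32_t, 3> grid_for(uint32_t ne) {
        if (ne > 262144) {          // 512 * 512
            return { 512, 512, ceil_div(ne, 262144) };
        } else if (ne > 512) {
            return { 512, ceil_div(ne, 512), 1 };
        }
        return { ne, 1, 1 };
    }

    int main() {
        const auto g = grid_for(1u << 20); // 1M elements -> 512 x 512 x 4
        printf("%u x %u x %u\n", g[0], g[1], g[2]);
    }
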
3085
3706
 
3086
3707
  static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -3127,18 +3748,20 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
3127
3748
  }
3128
3749
 
3129
3750
  const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
3130
- const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
3751
+ // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
3752
+ const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
3753
+ !ggml_vk_dim01_contiguous(src1);
3131
3754
 
3132
3755
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
3133
3756
 
3134
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
3757
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
3135
3758
 
3136
3759
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3137
3760
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
3138
3761
 
3139
- if (mmp == nullptr) {
3762
+ if (qx_needs_dequant) {
3140
3763
  // Fall back to dequant + f16 mulmat
3141
- mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
3764
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
3142
3765
  }
3143
3766
 
3144
3767
  // Not implemented
@@ -3151,10 +3774,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
3151
3774
  const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
3152
3775
  const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
3153
3776
 
3154
- const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
3155
-
3156
3777
  vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
3157
3778
 
3779
+ const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
3780
+
3158
3781
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
3159
3782
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
3160
3783
  const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -3165,12 +3788,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
3165
3788
  vk_pipeline to_fp16_vk_1 = nullptr;
3166
3789
 
3167
3790
  if (x_non_contig) {
3168
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
3791
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
3169
3792
  } else {
3170
3793
  to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
3171
3794
  }
3172
3795
  if (y_non_contig) {
3173
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16);
3796
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
3174
3797
  } else {
3175
3798
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
3176
3799
  }
@@ -3180,7 +3803,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
3180
3803
  if (dryrun) {
3181
3804
  const uint64_t x_sz_upd = x_sz * ne02 * ne03;
3182
3805
  const uint64_t y_sz_upd = y_sz * ne12 * ne13;
3183
- const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
3806
+ const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
3184
3807
  if (
3185
3808
  (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
3186
3809
  (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
@@ -3350,10 +3973,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
3350
3973
  vk_pipeline to_fp16_vk_0 = nullptr;
3351
3974
  vk_pipeline to_fp16_vk_1 = nullptr;
3352
3975
  if (x_non_contig) {
3353
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
3976
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
3354
3977
  }
3355
3978
  if (y_non_contig) {
3356
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
3979
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
3357
3980
  } else {
3358
3981
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
3359
3982
  }
@@ -3447,7 +4070,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
3447
4070
 
3448
4071
  if (ne01 > max_groups_x) {
3449
4072
  groups_z = 64;
3450
- groups_x /= groups_z;
4073
+ groups_x = CEIL_DIV(groups_x, groups_z);
3451
4074
  }
3452
4075
 
3453
4076
  // compute
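
The change from groups_x /= groups_z to CEIL_DIV matters whenever groups_x is not a multiple of 64: truncating division dispatches too few workgroups and leaves rows unprocessed. A tiny worked example:

    #include <cstdio>

    static unsigned ceil_div(unsigned a, unsigned b) { return (a + b - 1) / b; }

    int main() {
        const unsigned groups_x = 130, groups_z = 64;
        // truncating division would launch 2 * 64 = 128 groups and miss 2 rows
        printf("truncated: %u, ceil: %u\n", groups_x / groups_z, ceil_div(groups_x, groups_z));
        // prints "truncated: 2, ceil: 3"
    }
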
@@ -3619,9 +4242,19 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
3619
4242
 
3620
4243
  static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
3621
4244
  VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")");
3622
- if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1) {
4245
+ if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 &&
4246
+ // detect 0213 permutation, and batch size of 1
4247
+ src0->nb[0] <= src0->nb[2] &&
4248
+ src0->nb[2] <= src0->nb[1] &&
4249
+ src0->nb[1] <= src0->nb[3] &&
4250
+ src1->nb[0] <= src1->nb[2] &&
4251
+ src1->nb[2] <= src1->nb[1] &&
4252
+ src1->nb[1] <= src1->nb[3] &&
4253
+ src0->ne[3] == 1 &&
4254
+ src1->ne[3] == 1) {
3623
4255
  ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
3624
- } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1) {
4256
+ } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
4257
+ !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
3625
4258
  ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
3626
4259
  } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
3627
4260
  ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -3692,12 +4325,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
3692
4325
 
3693
4326
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
3694
4327
 
3695
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
4328
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
3696
4329
 
3697
4330
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3698
4331
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
3699
4332
 
3700
- if (mmp == nullptr) {
4333
+ if (qx_needs_dequant) {
3701
4334
  GGML_ABORT("fatal error");
3702
4335
  }
3703
4336
 
@@ -3708,10 +4341,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
3708
4341
  const uint64_t y_ne = ne11 * ne10;
3709
4342
  const uint64_t d_ne = ne21 * ne20;
3710
4343
 
3711
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, nei1));
4344
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
3712
4345
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
3713
4346
 
3714
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, nei1, aligned);
4347
+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
3715
4348
 
3716
4349
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
3717
4350
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -3724,12 +4357,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
3724
4357
  vk_pipeline to_fp16_vk_1 = nullptr;
3725
4358
 
3726
4359
  if (x_non_contig) {
3727
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, GGML_TYPE_F16);
4360
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
3728
4361
  } else {
3729
4362
  to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
3730
4363
  }
3731
4364
  if (y_non_contig) {
3732
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, GGML_TYPE_F16);
4365
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
3733
4366
  } else {
3734
4367
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
3735
4368
  }
@@ -3917,10 +4550,10 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
3917
4550
  vk_pipeline to_fp16_vk_0 = nullptr;
3918
4551
  vk_pipeline to_fp16_vk_1 = nullptr;
3919
4552
  if (x_non_contig) {
3920
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0->type, src0->type);
4553
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type);
3921
4554
  }
3922
4555
  if (y_non_contig) {
3923
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1->type, src1->type);
4556
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type);
3924
4557
  } else {
3925
4558
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
3926
4559
  }
@@ -4014,7 +4647,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
4014
4647
 
4015
4648
  if (ne01 > max_groups_x) {
4016
4649
  groups_z = 64;
4017
- groups_x /= groups_z;
4650
+ groups_x = CEIL_DIV(groups_x, groups_z);
4018
4651
  }
4019
4652
 
4020
4653
  // compute
@@ -4039,38 +4672,199 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
4039
4672
  }
4040
4673
  }
4041
4674
 
4042
- static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
4043
- switch (op) {
4044
- case GGML_OP_GET_ROWS:
4045
- GGML_ASSERT(src1->type == GGML_TYPE_I32);
4046
- if (dst->type == GGML_TYPE_F16) {
4047
- return ctx->device->pipeline_get_rows[src0->type];
4048
- }
4049
- if (dst->type == GGML_TYPE_F32) {
4050
- return ctx->device->pipeline_get_rows_f32[src0->type];
4051
- }
4052
- return nullptr;
4053
- case GGML_OP_ACC:
4054
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4055
- return ctx->device->pipeline_acc_f32;
4675
+ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
4676
+ VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
4677
+ std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
4678
+ std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
4679
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
4680
+ std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
4681
+
4682
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
4683
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
4684
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
4685
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
4686
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
4687
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
4688
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
4689
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
4690
+
4691
+ const uint32_t nem1 = mask ? mask->ne[1] : 0;
4692
+ const uint32_t nbm1 = mask ? mask->nb[1] : 0;
4693
+
4694
+ const uint32_t D = neq0;
4695
+ const uint32_t N = neq1;
4696
+ const uint32_t KV = nek1;
4697
+
4698
+ GGML_ASSERT(ne0 == D);
4699
+ GGML_ASSERT(ne2 == N);
4700
+
4701
+ // input tensor rows must be contiguous
4702
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
4703
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
4704
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
4705
+
4706
+ GGML_ASSERT(neq0 == D);
4707
+ GGML_ASSERT(nek0 == D);
4708
+ GGML_ASSERT(nev0 == D);
4709
+
4710
+ GGML_ASSERT(neq1 == N);
4711
+ GGML_ASSERT(nev0 == D);
4712
+
4713
+ GGML_ASSERT(nev1 == nek1);
4714
+
4715
+ // dst cannot be transposed or permuted
4716
+ GGML_ASSERT(nb0 == sizeof(float));
4717
+ GGML_ASSERT(nb0 <= nb1);
4718
+ GGML_ASSERT(nb1 <= nb2);
4719
+ GGML_ASSERT(nb2 <= nb3);
4720
+
4721
+ assert(dst->type == GGML_TYPE_F32);
4722
+ assert(q->type == GGML_TYPE_F32);
4723
+ assert(k->type == v->type);
4724
+
4725
+ vk_pipeline *pipelines;
4726
+ // XXX TODO other backends may be changing accumulator precision to default to f32 soon
4727
+ bool f32acc = dst->op_params[3] == GGML_PREC_F32;
4728
+ bool small_rows = N <= flash_attention_num_small_rows;
4729
+ switch (D) {
4730
+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
4731
+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
4732
+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
4733
+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
4734
+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
4735
+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
4736
+ default:
4737
+ assert(!"unsupported D value");
4738
+ return;
4739
+ }
4740
+ assert(pipelines);
4741
+
4742
+ bool aligned = (KV % pipelines[1]->align) == 0;
4743
+ vk_pipeline pipeline = pipelines[aligned];
4744
+ assert(pipeline);
4745
+
4746
+ if (dryrun) {
4747
+ // Request descriptor sets
4748
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
4749
+ return;
4750
+ }
4751
+
4752
+ float scale = 1.0f;
4753
+ float max_bias = 0.0f;
4754
+ float logit_softcap = 0.0f;
4755
+
4756
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
4757
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
4758
+ memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
4759
+
4760
+ if (logit_softcap != 0) {
4761
+ scale /= logit_softcap;
4762
+ }
4763
+
4764
+ const uint32_t n_head_kv = neq2;
4765
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
4766
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
4767
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
4768
+
4769
+ ggml_vk_sync_buffers(subctx);
4770
+
4771
+ vk_buffer d_Q, d_K, d_V, d_D, d_M;
4772
+ uint64_t q_buf_offset, k_buf_offset, v_buf_offset, d_buf_offset, m_buf_offset;
4773
+
4774
+ bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
4775
+
4776
+ if (ctx->device->uma) {
4777
+ ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
4778
+ ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
4779
+ ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
4780
+ ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
4781
+ Q_uma = d_Q != nullptr;
4782
+ K_uma = d_K != nullptr;
4783
+ V_uma = d_V != nullptr;
4784
+ D_uma = d_D != nullptr;
4785
+ if (mask) {
4786
+ ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
4787
+ M_uma = d_M != nullptr;
4788
+ }
4789
+ }
4790
+
4791
+
4792
+ ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
4793
+ ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context;
4794
+ ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
4795
+ ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
4796
+
4797
+ if (!Q_uma) {
4798
+ d_Q = q_buf_ctx->dev_buffer;
4799
+ q_buf_offset = vk_tensor_offset(q) + q->view_offs;
4800
+ }
4801
+ if (!K_uma) {
4802
+ d_K = k_buf_ctx->dev_buffer;
4803
+ k_buf_offset = vk_tensor_offset(k) + k->view_offs;
4804
+ }
4805
+ if (!V_uma) {
4806
+ d_V = v_buf_ctx->dev_buffer;
4807
+ v_buf_offset = vk_tensor_offset(v) + v->view_offs;
4808
+ }
4809
+ if (!D_uma) {
4810
+ d_D = d_buf_ctx->dev_buffer;
4811
+ d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
4812
+ }
4813
+
4814
+ if (!M_uma) {
4815
+ d_M = d_Q;
4816
+ m_buf_offset = q_buf_offset;
4817
+ if (mask) {
4818
+ ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context;
4819
+ d_M = m_buf_ctx->dev_buffer;
4820
+ m_buf_offset = vk_tensor_offset(mask) + mask->view_offs;
4821
+ }
4822
+ }
4823
+
4824
+ const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
4825
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
4826
+ {
4827
+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
4828
+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
4829
+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
4830
+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
4831
+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
4832
+ },
4833
+ sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
4834
+ }
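
The flash-attention path above folds logit softcapping into the scale and derives the two ALiBi slope bases m0 and m1 from max_bias and the KV head count. A standalone restatement of that arithmetic:

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    struct AlibiParams { float scale; uint32_t n_head_log2; float m0, m1; };

    static AlibiParams alibi_params(float scale, float max_bias, float logit_softcap, uint32_t n_head_kv) {
        if (logit_softcap != 0.0f) {
            scale /= logit_softcap; // mirrors the scale adjustment above when softcapping is enabled
        }
        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
        const float m0 = powf(2.0f, -max_bias          / n_head_log2);
        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
        return { scale, n_head_log2, m0, m1 };
    }

    int main() {
        const AlibiParams p = alibi_params(0.125f, 8.0f, 0.0f, 32);
        printf("scale=%g n_head_log2=%u m0=%g m1=%g\n", p.scale, p.n_head_log2, p.m0, p.m1);
    }
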
4835
+
4836
+ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
4837
+ switch (op) {
4838
+ case GGML_OP_GET_ROWS:
4839
+ GGML_ASSERT(src1->type == GGML_TYPE_I32);
4840
+ if (dst->type == GGML_TYPE_F16) {
4841
+ return ctx->device->pipeline_get_rows[src0->type];
4842
+ }
4843
+ if (dst->type == GGML_TYPE_F32) {
4844
+ return ctx->device->pipeline_get_rows_f32[src0->type];
4845
+ }
4846
+ return nullptr;
4847
+ case GGML_OP_ACC:
4848
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4849
+ return ctx->device->pipeline_acc_f32;
4056
4850
  }
4057
4851
  return nullptr;
4058
4852
  case GGML_OP_ADD:
4059
4853
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4060
- return ctx->device->pipeline_add_f32;
4854
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
4061
4855
  }
4062
4856
  if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
4063
- return ctx->device->pipeline_add_f16_f32_f16;
4857
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
4064
4858
  }
4065
4859
  return nullptr;
4066
4860
  case GGML_OP_MUL:
4067
4861
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4068
- return ctx->device->pipeline_mul_f32;
4862
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
4069
4863
  }
4070
4864
  return nullptr;
4071
4865
  case GGML_OP_DIV:
4072
4866
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4073
- return ctx->device->pipeline_div_f32;
4867
+ return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
4074
4868
  }
4075
4869
  return nullptr;
4076
4870
  case GGML_OP_CONCAT:
@@ -4127,7 +4921,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
4127
4921
  case GGML_OP_CPY:
4128
4922
  case GGML_OP_CONT:
4129
4923
  case GGML_OP_DUP:
4130
- return ggml_vk_get_cpy_pipeline(ctx, src0->type, dst->type);
4924
+ return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
4131
4925
  case GGML_OP_NORM:
4132
4926
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4133
4927
  return ctx->device->pipeline_norm_f32;
@@ -4183,10 +4977,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
4183
4977
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
4184
4978
 
4185
4979
  if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
4186
- return ctx->device->pipeline_soft_max_f32;
4980
+ return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
4187
4981
  }
4188
4982
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
4189
- return ctx->device->pipeline_soft_max_f32_f16;
4983
+ return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
4190
4984
  }
4191
4985
  return nullptr;
4192
4986
  case GGML_OP_ROPE:
@@ -4234,6 +5028,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
4234
5028
  return ctx->device->pipeline_timestep_embedding_f32;
4235
5029
  }
4236
5030
  return nullptr;
5031
+ case GGML_OP_POOL_2D:
5032
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5033
+ return ctx->device->pipeline_pool2d_f32;
5034
+ }
5035
+ return nullptr;
5036
+ case GGML_OP_RWKV_WKV6:
5037
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5038
+ return ctx->device->pipeline_rwkv_wkv6_f32;
5039
+ }
5040
+ return nullptr;
4237
5041
  case GGML_OP_LEAKY_RELU:
4238
5042
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
4239
5043
  return ctx->device->pipeline_leaky_relu_f32;
@@ -4255,7 +5059,6 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
4255
5059
  case GGML_OP_DIV:
4256
5060
  case GGML_OP_CONCAT:
4257
5061
  case GGML_OP_UPSCALE:
4258
- case GGML_OP_SCALE:
4259
5062
  case GGML_OP_SQR:
4260
5063
  case GGML_OP_SIN:
4261
5064
  case GGML_OP_COS:
@@ -4269,7 +5072,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
4269
5072
  }
4270
5073
 
4271
5074
  template<typename PC>
4272
- static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc, bool dryrun = false) {
5075
+ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
4273
5076
  VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
4274
5077
  if (src1 != nullptr) {
4275
5078
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -4309,6 +5112,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
4309
5112
  const uint64_t ned3 = dst->ne[3];
4310
5113
  const uint64_t ned = ned0 * ned1;
4311
5114
 
5115
+ init_pushconst_fastdiv(pc);
5116
+
4312
5117
  vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
4313
5118
 
4314
5119
  if (pipeline == nullptr) {
@@ -4454,7 +5259,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
4454
5259
  const uint32_t OH = is_2D ? dst->ne[2] : 1;
4455
5260
  const uint32_t OW = dst->ne[1];
4456
5261
 
4457
- const uint32_t batch = src1->ne[3];
5262
+ const uint32_t batch = src1->ne[is_2D ? 3 : 2];
4458
5263
 
4459
5264
  elements = { OW * KW * KH, OH, batch * IC };
4460
5265
  } break;
@@ -4464,6 +5269,14 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
4464
5269
  uint32_t half_ceil = (dim + 1) / 2;
4465
5270
  elements = { half_ceil, (uint32_t)src0->ne[0], 1 };
4466
5271
  } break;
5272
+ case GGML_OP_POOL_2D:
5273
+ {
5274
+ const uint32_t N = dst->ne[3];
5275
+ const uint32_t OC = dst->ne[2];
5276
+ const uint32_t OH = dst->ne[1];
5277
+ const uint32_t OW = dst->ne[0];
5278
+ elements = { N * OC * OH * OW, 1, 1};
5279
+ } break;
4467
5280
  case GGML_OP_ADD:
4468
5281
  case GGML_OP_DIV:
4469
5282
  case GGML_OP_MUL:
@@ -4627,6 +5440,134 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
4627
5440
  }, dryrun);
4628
5441
  }
4629
5442
 
5443
+ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
5444
+ const ggml_tensor * k = dst->src[0];
5445
+ const ggml_tensor * v = dst->src[1];
5446
+ const ggml_tensor * r = dst->src[2];
5447
+ const ggml_tensor * tf = dst->src[3];
5448
+ const ggml_tensor * td = dst->src[4];
5449
+ const ggml_tensor * state = dst->src[5];
5450
+
5451
+ GGML_ASSERT(!ggml_is_quantized(k->type));
5452
+ GGML_ASSERT(!ggml_is_quantized(v->type));
5453
+ GGML_ASSERT(!ggml_is_quantized(r->type));
5454
+ GGML_ASSERT(!ggml_is_quantized(tf->type));
5455
+ GGML_ASSERT(!ggml_is_quantized(td->type));
5456
+ GGML_ASSERT(!ggml_is_quantized(state->type));
5457
+ GGML_ASSERT(dst->buffer != nullptr);
5458
+
5459
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
5460
+ GGML_ASSERT(pipeline != nullptr);
5461
+
5462
+ if (dryrun) {
5463
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
5464
+ return;
5465
+ }
5466
+
5467
+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
5468
+ ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
5469
+ ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
5470
+ ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
5471
+ ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
5472
+ ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
5473
+ ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
5474
+
5475
+ ggml_vk_sync_buffers(subctx);
5476
+
5477
+ vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
5478
+ uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
5479
+ bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
5480
+
5481
+ if (ctx->device->uma) {
5482
+ ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
5483
+ ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
5484
+ ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
5485
+ ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
5486
+ ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
5487
+ ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
5488
+ ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
5489
+
5490
+ K_uma = d_K != nullptr;
5491
+ V_uma = d_V != nullptr;
5492
+ R_uma = d_R != nullptr;
5493
+ TF_uma = d_TF != nullptr;
5494
+ TD_uma = d_TD != nullptr;
5495
+ STATE_uma = d_State != nullptr;
5496
+ DST_uma = d_D != nullptr;
5497
+ }
5498
+
5499
+ if (!K_uma) {
5500
+ d_K = k_buf_ctx->dev_buffer;
5501
+ k_offset = vk_tensor_offset(k) + k->view_offs;
5502
+ }
5503
+ if (!V_uma) {
5504
+ d_V = v_buf_ctx->dev_buffer;
5505
+ v_offset = vk_tensor_offset(v) + v->view_offs;
5506
+ }
5507
+ if (!R_uma) {
5508
+ d_R = r_buf_ctx->dev_buffer;
5509
+ r_offset = vk_tensor_offset(r) + r->view_offs;
5510
+ }
5511
+ if (!TF_uma) {
5512
+ d_TF = tf_buf_ctx->dev_buffer;
5513
+ tf_offset = vk_tensor_offset(tf) + tf->view_offs;
5514
+ }
5515
+ if (!TD_uma) {
5516
+ d_TD = td_buf_ctx->dev_buffer;
5517
+ td_offset = vk_tensor_offset(td) + td->view_offs;
5518
+ }
5519
+ if (!STATE_uma) {
5520
+ d_State = state_buf_ctx->dev_buffer;
5521
+ state_offset = vk_tensor_offset(state) + state->view_offs;
5522
+ }
5523
+ if (!DST_uma) {
5524
+ d_D = dst_buf_ctx->dev_buffer;
5525
+ dst_offset = vk_tensor_offset(dst) + dst->view_offs;
5526
+ }
5527
+
5528
+ const uint64_t k_size = ggml_nbytes(k);
5529
+ const uint64_t v_size = ggml_nbytes(v);
5530
+ const uint64_t r_size = ggml_nbytes(r);
5531
+ const uint64_t tf_size = ggml_nbytes(tf);
5532
+ const uint64_t td_size = ggml_nbytes(td);
5533
+ const uint64_t state_size = ggml_nbytes(state);
5534
+ const uint64_t dst_size = ggml_nbytes(dst);
5535
+
5536
+ std::array<uint32_t, 3> elements = {
5537
+ (uint32_t)(pc.B * pc.H),
5538
+ 1,
5539
+ 1
5540
+ };
5541
+
5542
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
5543
+ vk_subbuffer{ d_K, k_offset, k_size },
5544
+ vk_subbuffer{ d_V, v_offset, v_size },
5545
+ vk_subbuffer{ d_R, r_offset, r_size },
5546
+ vk_subbuffer{ d_TF, tf_offset, tf_size },
5547
+ vk_subbuffer{ d_TD, td_offset, td_size },
5548
+ vk_subbuffer{ d_State, state_offset, state_size },
5549
+ vk_subbuffer{ d_D, dst_offset, dst_size }
5550
+ }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
5551
+ }
5552
+
5553
+ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
5554
+ const size_t seq_length = dst->src[0]->ne[3];
5555
+ const size_t n_embed = dst->ne[0];
5556
+ const size_t n_heads = dst->src[0]->ne[2];
5557
+ const size_t n_seqs = dst->src[5]->ne[1];
5558
+
5559
+ ggml_vk_op_f32_rwkv6(
5560
+ ctx, subctx, dst,
5561
+ {
5562
+ (uint32_t)n_seqs,
5563
+ (uint32_t)seq_length,
5564
+ (uint32_t)n_embed,
5565
+ (uint32_t)n_heads,
5566
+ },
5567
+ dryrun
5568
+ );
5569
+ }
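
ggml_vk_rwkv_wkv6 above packs the push constants from the tensor shapes (sequence length from k's ne[3], head count from k's ne[2], embedding size from dst->ne[0], batch from the state tensor's ne[1]) and dispatches one workgroup per (sequence, head) pair. A sketch of that mapping, with plain arrays standing in for ggml tensors; the helper and field names are illustrative.

    #include <array>
    #include <cstdint>
    #include <cstdio>

    struct Wkv6PushConstants { uint32_t B, T, C, H; };

    static Wkv6PushConstants wkv6_pc(const std::array<int64_t, 4> & ne_k,
                                     const std::array<int64_t, 4> & ne_dst,
                                     int64_t n_seqs) {
        return {
            (uint32_t) n_seqs,    // B: number of sequences
            (uint32_t) ne_k[3],   // T: sequence length
            (uint32_t) ne_dst[0], // C: embedding size
            (uint32_t) ne_k[2],   // H: number of heads
        };
    }

    int main() {
        const Wkv6PushConstants pc = wkv6_pc({ 64, 1, 32, 128 }, { 2048, 128, 1, 1 }, 4);
        // one workgroup per (sequence, head): B * H
        printf("B=%u T=%u C=%u H=%u -> %u workgroups\n", pc.B, pc.T, pc.C, pc.H, pc.B * pc.H);
    }
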
5570
+
4630
5571
  static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
4631
5572
  int * op_params = (int *)dst->op_params;
4632
5573
 
@@ -4670,7 +5611,8 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
4670
5611
  (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
4671
5612
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4672
5613
  0,
4673
- op_params[0], 0.0f
5614
+ op_params[0], 0.0f,
5615
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4674
5616
  }, dryrun);
4675
5617
  }
4676
5618
 
@@ -4684,6 +5626,7 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
4684
5626
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4685
5627
  0,
4686
5628
  0.0f, 0.0f,
5629
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4687
5630
  }, dryrun);
4688
5631
  }
4689
5632
 
@@ -4697,6 +5640,7 @@ static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const
4697
5640
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4698
5641
  0,
4699
5642
  0.0f, 0.0f,
5643
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4700
5644
  }, dryrun);
4701
5645
  }
4702
5646
 
@@ -4710,6 +5654,7 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
4710
5654
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4711
5655
  0,
4712
5656
  0.0f, 0.0f,
5657
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4713
5658
  }, dryrun);
4714
5659
  }
4715
5660
 
@@ -4724,6 +5669,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con
4724
5669
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4725
5670
  0,
4726
5671
  op_params[0], op_params[1],
5672
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4727
5673
  }, dryrun);
4728
5674
  }
4729
5675
 
@@ -4737,6 +5683,7 @@ static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const
4737
5683
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4738
5684
  0,
4739
5685
  0.0f, 0.0f,
5686
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4740
5687
  }, dryrun);
4741
5688
  }
4742
5689
 
@@ -4750,6 +5697,7 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
4750
5697
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4751
5698
  0,
4752
5699
  0.0f, 0.0f,
5700
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4753
5701
  }, dryrun);
4754
5702
  }
4755
5703
 
@@ -4764,6 +5712,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
4764
5712
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4765
5713
  d_offset,
4766
5714
  0.0f, 0.0f,
5715
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4767
5716
  }, dryrun);
4768
5717
  }
4769
5718
 
@@ -4820,6 +5769,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
4820
5769
  scale, max_bias,
4821
5770
  m0, m1,
4822
5771
  n_head_log2,
5772
+ nrows_x,
4823
5773
  }, dryrun);
4824
5774
  }
4825
5775
 
@@ -4891,7 +5841,7 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
4891
5841
  const uint32_t OW = dst->ne[1];
4892
5842
 
4893
5843
  const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
4894
- const uint32_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
5844
+ const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
4895
5845
 
4896
5846
  const uint32_t pelements = OW * KW * KH;
4897
5847
 
@@ -4914,6 +5864,34 @@ static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context
4914
5864
  }, dryrun);
4915
5865
  }
4916
5866
 
5867
+ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
5868
+ uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
5869
+ const int32_t k1 = dst->op_params[1];
5870
+ const int32_t k0 = dst->op_params[2];
5871
+ const int32_t s1 = dst->op_params[3];
5872
+ const int32_t s0 = dst->op_params[4];
5873
+ const int32_t p1 = dst->op_params[5];
5874
+ const int32_t p0 = dst->op_params[6];
5875
+
5876
+ const uint32_t IH = src0->ne[1];
5877
+ const uint32_t IW = src0->ne[0];
5878
+
5879
+ const uint32_t N = dst->ne[3];
5880
+
5881
+ const uint32_t OC = dst->ne[2];
5882
+ const uint32_t OH = dst->ne[1];
5883
+ const uint32_t OW = dst->ne[0];
5884
+
5885
+ const uint32_t parallel_elements = N * OC * OH * OW;
5886
+
5887
+ ggml_vk_op_f32<vk_op_pool2d_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, {
5888
+ IW, IH, OW, OH, OC,
5889
+ parallel_elements,
5890
+ op,
5891
+ k0, k1, s0, s1, p0, p1,
5892
+ }, dryrun);
5893
+ }
5894
+
4917
5895
  static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
4918
5896
  const float * op_params = (const float *)dst->op_params;
4919
5897
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
@@ -4970,10 +5948,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
4970
5948
  p = ctx->device->pipeline_matmul_f32_f16->a_s;
4971
5949
  shname = "F32_F16_ALIGNED_S";
4972
5950
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
4973
- p = ctx->device->pipeline_matmul_f16_f32->a_s;
5951
+ p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s;
4974
5952
  shname = "F16_F32_ALIGNED_S";
4975
5953
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
4976
- p = ctx->device->pipeline_matmul_f16->a_s;
5954
+ p = ctx->device->pipeline_matmul_f16.f32acc->a_s;
4977
5955
  shname = "F16_ALIGNED_S";
4978
5956
  } else {
4979
5957
  GGML_ABORT("fatal error");
@@ -4986,10 +5964,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
4986
5964
  p = ctx->device->pipeline_matmul_f32_f16->a_m;
4987
5965
  shname = "F32_F16_ALIGNED_M";
4988
5966
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
4989
- p = ctx->device->pipeline_matmul_f16_f32->a_m;
5967
+ p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m;
4990
5968
  shname = "F16_F32_ALIGNED_M";
4991
5969
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
4992
- p = ctx->device->pipeline_matmul_f16->a_m;
5970
+ p = ctx->device->pipeline_matmul_f16.f32acc->a_m;
4993
5971
  shname = "F16_ALIGNED_M";
4994
5972
  } else {
4995
5973
  GGML_ABORT("fatal error");
@@ -5002,10 +5980,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5002
5980
  p = ctx->device->pipeline_matmul_f32_f16->a_l;
5003
5981
  shname = "F32_F16_ALIGNED_L";
5004
5982
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
5005
- p = ctx->device->pipeline_matmul_f16_f32->a_l;
5983
+ p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l;
5006
5984
  shname = "F16_F32_ALIGNED_L";
5007
5985
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
5008
- p = ctx->device->pipeline_matmul_f16->a_l;
5986
+ p = ctx->device->pipeline_matmul_f16.f32acc->a_l;
5009
5987
  shname = "F16_ALIGNED_L";
5010
5988
  } else {
5011
5989
  GGML_ABORT("fatal error");
@@ -5025,10 +6003,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5025
6003
  p = ctx->device->pipeline_matmul_f32_f16->s;
5026
6004
  shname = "F32_F16_S";
5027
6005
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
5028
- p = ctx->device->pipeline_matmul_f16_f32->s;
6006
+ p = ctx->device->pipeline_matmul_f16_f32.f32acc->s;
5029
6007
  shname = "F16_F32_S";
5030
6008
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
5031
- p = ctx->device->pipeline_matmul_f16->s;
6009
+ p = ctx->device->pipeline_matmul_f16.f32acc->s;
5032
6010
  shname = "F16_S";
5033
6011
  }
5034
6012
  } else if (shader_size == 1) {
@@ -5039,10 +6017,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5039
6017
  p = ctx->device->pipeline_matmul_f32_f16->m;
5040
6018
  shname = "F32_F16_M";
5041
6019
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
5042
- p = ctx->device->pipeline_matmul_f16_f32->m;
6020
+ p = ctx->device->pipeline_matmul_f16_f32.f32acc->m;
5043
6021
  shname = "F16_F32_M";
5044
6022
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
5045
- p = ctx->device->pipeline_matmul_f16->m;
6023
+ p = ctx->device->pipeline_matmul_f16.f32acc->m;
5046
6024
  shname = "F16_M";
5047
6025
  }
5048
6026
  } else if (shader_size == 2) {
@@ -5053,10 +6031,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5053
6031
  p = ctx->device->pipeline_matmul_f32_f16->l;
5054
6032
  shname = "F32_F16_L";
5055
6033
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<float, Y_TYPE>()) {
5056
- p = ctx->device->pipeline_matmul_f16_f32->l;
6034
+ p = ctx->device->pipeline_matmul_f16_f32.f32acc->l;
5057
6035
  shname = "F16_F32_L";
5058
6036
  } else if (std::is_same<ggml_fp16_t, X_TYPE>() && std::is_same<ggml_fp16_t, Y_TYPE>()) {
5059
- p = ctx->device->pipeline_matmul_f16->l;
6037
+ p = ctx->device->pipeline_matmul_f16.f32acc->l;
5060
6038
  shname = "F16_L";
5061
6039
  }
5062
6040
  }
@@ -5088,19 +6066,27 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5088
6066
  for (size_t i = 0; i < x_ne; i++) {
5089
6067
  if (std::is_same<float, X_TYPE>()) {
5090
6068
  x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
6069
+ // x[i] = 1.0f;
6070
+ // x[i] = i + 1;
6071
+ // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
5091
6072
  } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
5092
6073
  x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
6074
+ // x[i] = ggml_fp32_to_fp16(1.0f);
6075
+ // x[i] = ggml_fp32_to_fp16(i + 1);
6076
+ // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
5093
6077
  } else {
5094
6078
  GGML_ABORT("fatal error");
5095
6079
  }
5096
6080
  }
5097
6081
  for (size_t i = 0; i < y_ne; i++) {
5098
6082
  if (std::is_same<float, Y_TYPE>()) {
5099
- // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
5100
- y[i] = (i % k == i / k) ? 1.0f : 0.0f;
6083
+ y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
6084
+ // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
6085
+ // y[i] = i + 1;
5101
6086
  } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
5102
- // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
5103
- y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
6087
+ y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
6088
+ // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
6089
+ // y[i] = ggml_fp32_to_fp16(i + 1);
5104
6090
  } else {
5105
6091
  GGML_ABORT("fatal error");
5106
6092
  }
@@ -5110,16 +6096,16 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5110
6096
  ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
5111
6097
 
5112
6098
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
6099
+ ggml_vk_ctx_begin(ctx->device, subctx);
5113
6100
  for (size_t i = 0; i < num_it; i++) {
5114
- ggml_vk_ctx_begin(ctx->device, subctx);
5115
6101
  ggml_vk_matmul(
5116
6102
  ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
5117
6103
  m, n, k,
5118
6104
  k, k, m, k*m, k*n, m*n,
5119
6105
  split_k, batch, batch, batch, 1, 1
5120
6106
  );
5121
- ggml_vk_ctx_end(subctx);
5122
6107
  }
6108
+ ggml_vk_ctx_end(subctx);
5123
6109
 
5124
6110
  auto begin = std::chrono::high_resolution_clock::now();
5125
6111
  ggml_vk_submit(subctx, ctx->fence);
@@ -5184,7 +6170,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5184
6170
  double err = std::fabs(d[i] - d_chk[i]);
5185
6171
  avg_err += err;
5186
6172
 
5187
- if (err > 0.05f && first_err_n == -1) {
6173
+ if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
5188
6174
  first_err_b = i / (m * n);
5189
6175
  first_err_n = (i % (m * n)) / m;
5190
6176
  first_err_m = (i % (m * n)) % m;
@@ -5197,12 +6183,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
5197
6183
 
5198
6184
  std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
5199
6185
 
5200
- if (avg_err > 0.1) {
6186
+ if (avg_err > 0.1 || std::isnan(avg_err)) {
5201
6187
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
5202
6188
  std::cerr << "Actual result: " << std::endl << std::endl;
5203
6189
  ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
5204
- std::cerr << std::endl;
5205
- ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
5206
6190
  std::cerr << "Expected result: " << std::endl << std::endl;
5207
6191
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
5208
6192
 
@@ -5287,9 +6271,9 @@ static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, gg
5287
6271
  return;
5288
6272
  }
5289
6273
 
5290
- ggml_type_traits_t tt = ggml_internal_get_type_traits(quant);
6274
+ const auto * tt = ggml_get_type_traits(quant);
5291
6275
 
5292
- ggml_to_float_t dequant_fn = tt.to_float;
6276
+ ggml_to_float_t dequant_fn = tt->to_float;
5293
6277
 
5294
6278
  dequant_fn(from, to, ne);
5295
6279
  }
@@ -5385,13 +6369,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5385
6369
  vk_pipeline p;
5386
6370
  std::string shname;
5387
6371
  if (shader_size == 0) {
5388
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_s;
6372
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
5389
6373
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
5390
6374
  } else if (shader_size == 1) {
5391
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_m;
6375
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
5392
6376
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
5393
6377
  } else if (shader_size == 2) {
5394
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->a_l;
6378
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
5395
6379
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
5396
6380
  } else {
5397
6381
  GGML_ASSERT(0);
@@ -5401,13 +6385,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5401
6385
 
5402
6386
  if (k != kpad) {
5403
6387
  if (shader_size == 0) {
5404
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->s;
6388
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
5405
6389
  shname = std::string(ggml_type_name(quant)) + "_S";
5406
6390
  } else if (shader_size == 1) {
5407
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->m;
6391
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
5408
6392
  shname = std::string(ggml_type_name(quant)) + "_M";
5409
6393
  } else if (shader_size == 2) {
5410
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant]->l;
6394
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
5411
6395
  shname = std::string(ggml_type_name(quant)) + "_L";
5412
6396
  } else {
5413
6397
  GGML_ASSERT(0);
@@ -5457,16 +6441,16 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5457
6441
  ggml_vk_buffer_write(y_buf, 0, y, y_sz);
5458
6442
 
5459
6443
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
6444
+ ggml_vk_ctx_begin(ctx->device, subctx);
5460
6445
  for (size_t i = 0; i < num_it; i++) {
5461
- ggml_vk_ctx_begin(ctx->device, subctx);
5462
6446
  ggml_vk_matmul(
5463
6447
  ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
5464
6448
  m, n, k,
5465
6449
  k, k, m, k*m, k*n, m*n,
5466
6450
  split_k, batch, batch, batch, 1, 1
5467
6451
  );
5468
- ggml_vk_ctx_end(subctx);
5469
6452
  }
6453
+ ggml_vk_ctx_end(subctx);
5470
6454
 
5471
6455
  auto begin = std::chrono::high_resolution_clock::now();
5472
6456
 
@@ -5566,105 +6550,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5566
6550
 
5567
6551
  static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
5568
6552
  #if defined(GGML_VULKAN_RUN_TESTS)
5569
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
5570
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
5571
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
5572
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
5573
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
5574
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
5575
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
5576
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
5577
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
5578
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
5579
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
5580
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
5581
-
5582
- ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2);
5583
-
5584
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
5585
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
5586
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
5587
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
5588
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
5589
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
5590
-
5591
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
5592
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
5593
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
5594
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
5595
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
5596
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
5597
-
5598
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
5599
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
5600
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
5601
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
5602
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
5603
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
5604
-
5605
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
5606
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
5607
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
5608
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
5609
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
5610
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
5611
-
5612
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
5613
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
5614
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
5615
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
5616
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
5617
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
5618
-
5619
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
5620
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
5621
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
5622
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
5623
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
5624
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
5625
-
5626
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K);
5627
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K);
5628
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K);
5629
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K);
5630
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K);
5631
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K);
5632
-
5633
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K);
5634
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K);
5635
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K);
5636
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K);
5637
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K);
5638
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K);
5639
-
5640
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K);
5641
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K);
5642
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K);
5643
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K);
5644
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K);
5645
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K);
5646
-
5647
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K);
5648
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K);
5649
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K);
5650
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K);
5651
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K);
5652
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K);
5653
-
5654
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K);
5655
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K);
5656
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K);
5657
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K);
5658
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K);
5659
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K);
5660
-
5661
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL);
5662
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL);
5663
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL);
5664
-
5665
- std::cerr << std::endl;
5666
-
5667
6553
  const std::vector<size_t> vals {
6554
+ 512, 512, 128,
6555
+ 128, 512, 512,
6556
+ 4096, 512, 4096,
6557
+ 11008, 512, 4096,
6558
+ 4096, 512, 11008,
6559
+ 32000, 512, 4096,
5668
6560
  8, 8, 8,
5669
6561
  100, 46, 576,
5670
6562
  623, 111, 128,
@@ -5677,25 +6569,52 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
5677
6569
  49, 49, 128,
5678
6570
  128, 49, 49,
5679
6571
  4096, 49, 4096,
5680
- 11008, 49, 4096,
5681
- 4096, 49, 11008,
5682
- 32000, 49, 4096,
5683
- 512, 512, 128,
5684
- 128, 512, 512,
5685
- 4096, 512, 4096,
5686
- 11008, 512, 4096,
5687
- 4096, 512, 11008,
5688
- 32000, 512, 4096,
5689
6572
  };
5690
- const size_t num_it = 1;
6573
+ const size_t num_it = 100;
6574
+
5691
6575
  for (size_t i = 0; i < vals.size(); i += 3) {
5692
6576
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
5693
6577
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
5694
6578
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
5695
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
5696
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
5697
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
5698
- std::cerr << std::endl;
6579
+ std::cerr << '\n';
6580
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
6581
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
6582
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
6583
+ std::cerr << '\n';
6584
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
6585
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
6586
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
6587
+ std::cerr << '\n' << std::endl;
6588
+
6589
+ if (vals[i + 2] % 32 == 0) {
6590
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
6591
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
6592
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
6593
+ std::cerr << '\n';
6594
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
6595
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
6596
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
6597
+ std::cerr << '\n';
6598
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
6599
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
6600
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
6601
+ std::cerr << '\n' << std::endl;
6602
+ }
6603
+
6604
+ if (vals[i + 2] % 256 == 0) {
6605
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
6606
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
6607
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
6608
+ std::cerr << '\n';
6609
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
6610
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
6611
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
6612
+ std::cerr << '\n';
6613
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
6614
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
6615
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
6616
+ std::cerr << '\n' << std::endl;
6617
+ }
5699
6618
  }
5700
6619
 
5701
6620
  GGML_ABORT("fatal error");
@@ -5742,6 +6661,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5742
6661
  const ggml_tensor * src0 = node->src[0];
5743
6662
  const ggml_tensor * src1 = node->src[1];
5744
6663
  const ggml_tensor * src2 = node->src[2];
6664
+ const ggml_tensor * src3 = node->src[3];
5745
6665
 
5746
6666
  switch (node->op) {
5747
6667
  // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
@@ -5792,7 +6712,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5792
6712
  case GGML_OP_SUM_ROWS:
5793
6713
  case GGML_OP_IM2COL:
5794
6714
  case GGML_OP_TIMESTEP_EMBEDDING:
6715
+ case GGML_OP_POOL_2D:
6716
+ case GGML_OP_RWKV_WKV6:
5795
6717
  case GGML_OP_LEAKY_RELU:
6718
+ case GGML_OP_FLASH_ATTN_EXT:
5796
6719
  break;
5797
6720
  default:
5798
6721
  std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -5810,6 +6733,48 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5810
6733
  } else {
5811
6734
  compute_ctx = ctx->compute_ctx.lock();
5812
6735
  }
6736
+ } else {
6737
+ switch (node->op) {
6738
+ case GGML_OP_REPEAT:
6739
+ case GGML_OP_ACC:
6740
+ case GGML_OP_GET_ROWS:
6741
+ case GGML_OP_ADD:
6742
+ case GGML_OP_MUL:
6743
+ case GGML_OP_DIV:
6744
+ case GGML_OP_CONCAT:
6745
+ case GGML_OP_UPSCALE:
6746
+ case GGML_OP_SCALE:
6747
+ case GGML_OP_SQR:
6748
+ case GGML_OP_SIN:
6749
+ case GGML_OP_COS:
6750
+ case GGML_OP_CLAMP:
6751
+ case GGML_OP_PAD:
6752
+ case GGML_OP_CPY:
6753
+ case GGML_OP_CONT:
6754
+ case GGML_OP_DUP:
6755
+ case GGML_OP_NORM:
6756
+ case GGML_OP_GROUP_NORM:
6757
+ case GGML_OP_RMS_NORM:
6758
+ case GGML_OP_UNARY:
6759
+ case GGML_OP_DIAG_MASK_INF:
6760
+ case GGML_OP_SOFT_MAX:
6761
+ case GGML_OP_ROPE:
6762
+ case GGML_OP_ARGSORT:
6763
+ case GGML_OP_SUM_ROWS:
6764
+ case GGML_OP_IM2COL:
6765
+ case GGML_OP_TIMESTEP_EMBEDDING:
6766
+ case GGML_OP_POOL_2D:
6767
+ case GGML_OP_LEAKY_RELU:
6768
+ {
6769
+ // These operations all go through ggml_vk_op_f32, so short-circuit and
6770
+ // do the only thing needed for the dryrun.
6771
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
6772
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
6773
+ return false;
6774
+ }
6775
+ default:
6776
+ break;
6777
+ }
5813
6778
  }
5814
6779
 
5815
6780
  switch (node->op) {
@@ -5927,6 +6892,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5927
6892
  case GGML_OP_TIMESTEP_EMBEDDING:
5928
6893
  ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
5929
6894
 
6895
+ break;
6896
+ case GGML_OP_POOL_2D:
6897
+ ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
6898
+
5930
6899
  break;
5931
6900
  case GGML_OP_LEAKY_RELU:
5932
6901
  ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@@ -5939,6 +6908,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5939
6908
  case GGML_OP_MUL_MAT_ID:
5940
6909
  ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
5941
6910
 
6911
+ break;
6912
+
6913
+ case GGML_OP_FLASH_ATTN_EXT:
6914
+ ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
6915
+
6916
+ break;
6917
+
6918
+ case GGML_OP_RWKV_WKV6:
6919
+ ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
6920
+
5942
6921
  break;
5943
6922
  default:
5944
6923
  return false;
@@ -6018,6 +6997,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
6018
6997
  case GGML_OP_SUM_ROWS:
6019
6998
  case GGML_OP_IM2COL:
6020
6999
  case GGML_OP_TIMESTEP_EMBEDDING:
7000
+ case GGML_OP_POOL_2D:
7001
+ case GGML_OP_RWKV_WKV6:
6021
7002
  case GGML_OP_LEAKY_RELU:
6022
7003
  case GGML_OP_REPEAT:
6023
7004
  buf = tensor->buffer;
@@ -6038,6 +7019,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
6038
7019
  break;
6039
7020
  case GGML_OP_MUL_MAT:
6040
7021
  case GGML_OP_MUL_MAT_ID:
7022
+ case GGML_OP_FLASH_ATTN_EXT:
6041
7023
  buf = tensor->buffer;
6042
7024
 
6043
7025
  break;
@@ -6186,13 +7168,8 @@ static void ggml_vk_get_device_description(int device, char * description, size_
6186
7168
 
6187
7169
  // device backend
6188
7170
 
6189
- static const char * ggml_backend_vk_buffer_get_name(ggml_backend_buffer_t buffer) {
6190
- ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context;
6191
- return ctx->name.c_str();
6192
- }
6193
-
6194
7171
  static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) {
6195
- return buffer->iface.get_name == ggml_backend_vk_buffer_get_name;
7172
+ return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name;
6196
7173
  }
6197
7174
 
6198
7175
  static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) {
@@ -6256,7 +7233,6 @@ static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t v
6256
7233
  }
6257
7234
 
6258
7235
  static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = {
6259
- /* .get_name = */ ggml_backend_vk_buffer_get_name,
6260
7236
  /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer,
6261
7237
  /* .get_base = */ ggml_backend_vk_buffer_get_base,
6262
7238
  /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor,
@@ -6352,7 +7328,6 @@ static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_
6352
7328
 
6353
7329
  ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
6354
7330
  buffer->buft = buft;
6355
- buffer->iface.get_name = ggml_backend_vk_host_buffer_name;
6356
7331
  buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer;
6357
7332
 
6358
7333
  return buffer;
@@ -6378,7 +7353,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
6378
7353
  /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
6379
7354
  /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host,
6380
7355
  },
6381
- /* .device = */ nullptr,
7356
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
6382
7357
  /* .context = */ nullptr,
6383
7358
  };
6384
7359
 
@@ -6541,16 +7516,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
6541
7516
  bool first_node_in_batch = true; // true if next node will be first node in a batch
6542
7517
  int submit_node_idx = 0; // index to first node in a batch
6543
7518
 
6544
- // submit work every submit_count node to overlap CPU cmdbuffer generation with GPU execution
6545
- constexpr int submit_count = 100;
7519
+ // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
7520
+ // Start with a smaller count to get work submitted right away, and increase it after each submit.
7521
+ int nodes_per_submit = 20;
6546
7522
  int submitted_nodes = 0;
7523
+ int submit_count = 0;
6547
7524
  for (int i = 0; i < cgraph->n_nodes; i++) {
6548
7525
  if (first_node_in_batch) {
6549
7526
  submit_node_idx = i;
6550
7527
  }
6551
7528
 
6552
- bool submit = (submitted_nodes >= submit_count) || (i == last_node);
6553
-
7529
+ bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
6554
7530
 
6555
7531
  bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
6556
7532
 
@@ -6567,6 +7543,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
6567
7543
  if (submit) {
6568
7544
  first_node_in_batch = true;
6569
7545
  submitted_nodes = 0;
7546
+ switch (submit_count) {
7547
+ case 0:
7548
+ nodes_per_submit = 50;
7549
+ break;
7550
+ default:
7551
+ nodes_per_submit = 100;
7552
+ break;
7553
+ }
7554
+ submit_count++;
6570
7555
  }
6571
7556
  }
6572
7557
 
@@ -6581,9 +7566,132 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
6581
7566
  UNUSED(backend);
6582
7567
  }
6583
7568
 
6584
- static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
6585
- // ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
7569
+ // TODO: enable async and synchronize
7570
+ static ggml_backend_i ggml_backend_vk_interface = {
7571
+ /* .get_name = */ ggml_backend_vk_name,
7572
+ /* .free = */ ggml_backend_vk_free,
7573
+ /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
7574
+ /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
7575
+ /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
7576
+ /* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
7577
+ /* .graph_plan_create = */ NULL,
7578
+ /* .graph_plan_free = */ NULL,
7579
+ /* .graph_plan_update = */ NULL,
7580
+ /* .graph_plan_compute = */ NULL,
7581
+ /* .graph_compute = */ ggml_backend_vk_graph_compute,
7582
+ /* .event_record = */ NULL,
7583
+ /* .event_wait = */ NULL,
7584
+ };
7585
+
7586
+ static ggml_guid_t ggml_backend_vk_guid() {
7587
+ static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
7588
+ return &guid;
7589
+ }
7590
+
7591
+ ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
7592
+ VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
7593
+
7594
+ ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
7595
+ ggml_vk_init(ctx, dev_num);
7596
+
7597
+ ggml_backend_t vk_backend = new ggml_backend {
7598
+ /* .guid = */ ggml_backend_vk_guid(),
7599
+ /* .interface = */ ggml_backend_vk_interface,
7600
+ /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
7601
+ /* .context = */ ctx,
7602
+ };
7603
+
7604
+ return vk_backend;
7605
+ }
7606
+
7607
+ bool ggml_backend_is_vk(ggml_backend_t backend) {
7608
+ return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
7609
+ }
7610
+
7611
+ int ggml_backend_vk_get_device_count() {
7612
+ return ggml_vk_get_device_count();
7613
+ }
7614
+
7615
+ void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
7616
+ GGML_ASSERT(device < (int) vk_instance.device_indices.size());
7617
+ int dev_idx = vk_instance.device_indices[device];
7618
+ ggml_vk_get_device_description(dev_idx, description, description_size);
7619
+ }
7620
+
7621
+ void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
7622
+ GGML_ASSERT(device < (int) vk_instance.device_indices.size());
7623
+
7624
+ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
7625
+
7626
+ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
7627
+
7628
+ for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
7629
+ if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
7630
+ *total = heap.size;
7631
+ *free = heap.size;
7632
+ break;
7633
+ }
7634
+ }
7635
+ }
7636
+
7637
+ //////////////////////////
6586
7638
 
7639
+ struct ggml_backend_vk_device_context {
7640
+ size_t device;
7641
+ std::string name;
7642
+ std::string description;
7643
+ };
7644
+
7645
+ static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
7646
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7647
+ return ctx->name.c_str();
7648
+ }
7649
+
7650
+ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
7651
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7652
+ return ctx->description.c_str();
7653
+ }
7654
+
7655
+ static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
7656
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
7657
+ ggml_backend_vk_get_device_memory(ctx->device, free, total);
7658
+ }
7659
+
7660
+ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
7661
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7662
+ return ggml_backend_vk_buffer_type(ctx->device);
7663
+ }
7664
+
7665
+ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
7666
+ UNUSED(dev);
7667
+ return ggml_backend_vk_host_buffer_type();
7668
+ }
7669
+
7670
+ static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
7671
+ UNUSED(dev);
7672
+ return GGML_BACKEND_DEVICE_TYPE_GPU;
7673
+ }
7674
+
7675
+ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
7676
+ props->name = ggml_backend_vk_device_get_name(dev);
7677
+ props->description = ggml_backend_vk_device_get_description(dev);
7678
+ props->type = ggml_backend_vk_device_get_type(dev);
7679
+ ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
7680
+ props->caps = {
7681
+ /* .async = */ false,
7682
+ /* .host_buffer = */ true,
7683
+ /* .buffer_from_host_ptr = */ false,
7684
+ /* .events = */ false,
7685
+ };
7686
+ }
7687
+
7688
+ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
7689
+ UNUSED(params);
7690
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7691
+ return ggml_backend_vk_init(ctx->device);
7692
+ }
7693
+
7694
+ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
6587
7695
  switch (op->op) {
6588
7696
  case GGML_OP_UNARY:
6589
7697
  switch (ggml_get_unary_op(op)) {
@@ -6600,6 +7708,12 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
6600
7708
  case GGML_OP_MUL_MAT:
6601
7709
  case GGML_OP_MUL_MAT_ID:
6602
7710
  {
7711
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7712
+ const vk_device& device = ggml_vk_get_device(ctx->device);
7713
+ if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
7714
+ // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
7715
+ return false;
7716
+ }
6603
7717
  switch (op->src[0]->type) {
6604
7718
  case GGML_TYPE_F32:
6605
7719
  case GGML_TYPE_F16:
@@ -6630,8 +7744,64 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
6630
7744
  if (a->ne[3] != b->ne[3]) {
6631
7745
  return false;
6632
7746
  }
7747
+ if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
7748
+ !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
7749
+ return false;
7750
+ }
7751
+
6633
7752
  return true;
6634
7753
  } break;
7754
+ case GGML_OP_FLASH_ATTN_EXT:
7755
+ {
7756
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7757
+ if (!ggml_vk_get_device(ctx->device)->coopmat2) {
7758
+ return false;
7759
+ }
7760
+ switch (op->src[0]->ne[0]) {
7761
+ case 64:
7762
+ case 80:
7763
+ case 96:
7764
+ case 112:
7765
+ case 128:
7766
+ case 256:
7767
+ break;
7768
+ default:
7769
+ return false;
7770
+ }
7771
+ if (op->src[0]->type != GGML_TYPE_F32) {
7772
+ return false;
7773
+ }
7774
+ if (op->type != GGML_TYPE_F32) {
7775
+ return false;
7776
+ }
7777
+ if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
7778
+ return false;
7779
+ }
7780
+ // It's straightforward to support different K/V dequant, but would
7781
+ // significantly increase the number of pipelines
7782
+ if (op->src[1]->type != op->src[2]->type) {
7783
+ return false;
7784
+ }
7785
+ switch (op->src[1]->type) {
7786
+ case GGML_TYPE_F16:
7787
+ case GGML_TYPE_Q4_0:
7788
+ case GGML_TYPE_Q4_1:
7789
+ case GGML_TYPE_Q5_0:
7790
+ case GGML_TYPE_Q5_1:
7791
+ case GGML_TYPE_Q8_0:
7792
+ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
7793
+ //case GGML_TYPE_Q2_K:
7794
+ //case GGML_TYPE_Q3_K:
7795
+ //case GGML_TYPE_Q4_K:
7796
+ //case GGML_TYPE_Q5_K:
7797
+ //case GGML_TYPE_Q6_K:
7798
+ case GGML_TYPE_IQ4_NL:
7799
+ break;
7800
+ default:
7801
+ return false;
7802
+ }
7803
+ return true;
7804
+ }
6635
7805
  case GGML_OP_GET_ROWS:
6636
7806
  {
6637
7807
  switch (op->src[0]->type) {
@@ -6668,7 +7838,16 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
6668
7838
  case GGML_OP_REPEAT:
6669
7839
  return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
6670
7840
  case GGML_OP_ROPE:
6671
- return ggml_is_contiguous(op->src[0]);
7841
+ {
7842
+ const int mode = ((const int32_t *) op->op_params)[2];
7843
+ if (mode & GGML_ROPE_TYPE_MROPE) {
7844
+ return false;
7845
+ }
7846
+ if (mode & GGML_ROPE_TYPE_VISION) {
7847
+ return false;
7848
+ }
7849
+ return ggml_is_contiguous(op->src[0]);
7850
+ }
6672
7851
  case GGML_OP_NONE:
6673
7852
  case GGML_OP_RESHAPE:
6674
7853
  case GGML_OP_VIEW:
@@ -6695,103 +7874,110 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
6695
7874
  case GGML_OP_SUM_ROWS:
6696
7875
  case GGML_OP_IM2COL:
6697
7876
  case GGML_OP_TIMESTEP_EMBEDDING:
7877
+ case GGML_OP_POOL_2D:
7878
+ case GGML_OP_RWKV_WKV6:
6698
7879
  case GGML_OP_LEAKY_RELU:
6699
7880
  return true;
6700
7881
  default:
6701
7882
  return false;
6702
7883
  }
6703
7884
 
6704
- UNUSED(backend);
6705
- }
6706
-
6707
- static bool ggml_backend_vk_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
6708
- const int min_batch_size = 32;
6709
-
6710
- return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
6711
- (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
6712
-
6713
- UNUSED(backend);
7885
+ UNUSED(dev);
6714
7886
  }
6715
7887
 
6716
- static bool ggml_backend_vk_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
7888
+ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
6717
7889
  if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
6718
7890
  return false;
6719
7891
  }
6720
7892
 
7893
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
6721
7894
  ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
6722
- ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
6723
-
6724
- return buft_ctx->device == ctx->device;
6725
- }
6726
-
6727
- // TODO: enable async and synchronize
6728
- static ggml_backend_i ggml_backend_vk_interface = {
6729
- /* .get_name = */ ggml_backend_vk_name,
6730
- /* .free = */ ggml_backend_vk_free,
6731
- /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
6732
- /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async,
6733
- /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async,
6734
- /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async,
6735
- /* .synchronize = */ NULL, // ggml_backend_vk_synchronize,
6736
- /* .graph_plan_create = */ NULL,
6737
- /* .graph_plan_free = */ NULL,
6738
- /* .graph_plan_update = */ NULL,
6739
- /* .graph_plan_compute = */ NULL,
6740
- /* .graph_compute = */ ggml_backend_vk_graph_compute,
6741
- /* .supports_op = */ ggml_backend_vk_supports_op,
6742
- /* .supports_buft = */ ggml_backend_vk_supports_buft,
6743
- /* .offload_op = */ ggml_backend_vk_offload_op,
6744
- /* .event_record = */ NULL,
6745
- /* .event_wait = */ NULL,
6746
- };
6747
7895
 
6748
- static ggml_guid_t ggml_backend_vk_guid() {
6749
- static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
6750
- return &guid;
7896
+ return buft_ctx->device->idx == ctx->device;
6751
7897
  }
6752
7898
 
6753
- ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
6754
- VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
7899
+ static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
7900
+ const int min_batch_size = 32;
6755
7901
 
6756
- ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
6757
- ggml_vk_init(ctx, dev_num);
7902
+ return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
7903
+ (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
6758
7904
 
6759
- ggml_backend_t vk_backend = new ggml_backend {
6760
- /* .guid = */ ggml_backend_vk_guid(),
6761
- /* .interface = */ ggml_backend_vk_interface,
6762
- /* .device = */ nullptr,
6763
- /* .context = */ ctx,
6764
- };
7905
+ UNUSED(dev);
7906
+ }
7907
+
7908
+ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
7909
+ /* .get_name = */ ggml_backend_vk_device_get_name,
7910
+ /* .get_description = */ ggml_backend_vk_device_get_description,
7911
+ /* .get_memory = */ ggml_backend_vk_device_get_memory,
7912
+ /* .get_type = */ ggml_backend_vk_device_get_type,
7913
+ /* .get_props = */ ggml_backend_vk_device_get_props,
7914
+ /* .init_backend = */ ggml_backend_vk_device_init,
7915
+ /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
7916
+ /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
7917
+ /* .buffer_from_host_ptr = */ NULL,
7918
+ /* .supports_op = */ ggml_backend_vk_device_supports_op,
7919
+ /* .supports_buft = */ ggml_backend_vk_device_supports_buft,
7920
+ /* .offload_op = */ ggml_backend_vk_device_offload_op,
7921
+ /* .event_new = */ NULL,
7922
+ /* .event_free = */ NULL,
7923
+ /* .event_synchronize = */ NULL,
7924
+ };
6765
7925
 
6766
- return vk_backend;
7926
+ static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
7927
+ UNUSED(reg);
7928
+ return GGML_VK_NAME;
6767
7929
  }
6768
7930
 
6769
- bool ggml_backend_is_vk(ggml_backend_t backend) {
6770
- return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
7931
+ static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
7932
+ UNUSED(reg);
7933
+ return ggml_backend_vk_get_device_count();
6771
7934
  }
6772
7935
 
6773
- int ggml_backend_vk_get_device_count() {
6774
- return ggml_vk_get_device_count();
6775
- }
7936
+ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
7937
+ static std::vector<ggml_backend_dev_t> devices;
6776
7938
 
6777
- void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
6778
- ggml_vk_get_device_description(device, description, description_size);
6779
- }
7939
+ static bool initialized = false;
6780
7940
 
6781
- void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
6782
- GGML_ASSERT(device < (int) vk_instance.device_indices.size());
7941
+ {
7942
+ static std::mutex mutex;
7943
+ std::lock_guard<std::mutex> lock(mutex);
7944
+ if (!initialized) {
7945
+ for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) {
7946
+ ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
7947
+ char desc[256];
7948
+ ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
7949
+ ctx->device = i;
7950
+ ctx->name = GGML_VK_NAME + std::to_string(i);
7951
+ ctx->description = desc;
7952
+ devices.push_back(new ggml_backend_device {
7953
+ /* .iface = */ ggml_backend_vk_device_i,
7954
+ /* .reg = */ reg,
7955
+ /* .context = */ ctx,
7956
+ });
7957
+ }
7958
+ initialized = true;
7959
+ }
7960
+ }
6783
7961
 
6784
- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
7962
+ GGML_ASSERT(device < devices.size());
7963
+ return devices[device];
7964
+ }
6785
7965
 
6786
- vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
7966
+ static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
7967
+ /* .get_name = */ ggml_backend_vk_reg_get_name,
7968
+ /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
7969
+ /* .get_device = */ ggml_backend_vk_reg_get_device,
7970
+ /* .get_proc_address = */ NULL,
7971
+ };
6787
7972
 
6788
- for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
6789
- if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
6790
- *total = heap.size;
6791
- *free = heap.size;
6792
- break;
6793
- }
6794
- }
7973
+ ggml_backend_reg_t ggml_backend_vk_reg() {
7974
+ static ggml_backend_reg reg = {
7975
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
7976
+ /* .iface = */ ggml_backend_vk_reg_i,
7977
+ /* .context = */ nullptr,
7978
+ };
7979
+
7980
+ return &reg;
6795
7981
  }
6796
7982
 
6797
7983
  // Extension availability
@@ -6940,6 +8126,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
6940
8126
  ggml_tensor * src0 = tensor->src[0];
6941
8127
  ggml_tensor * src1 = tensor->src[1];
6942
8128
  ggml_tensor * src2 = tensor->src[2];
8129
+ ggml_tensor * src3 = tensor->src[3];
6943
8130
 
6944
8131
  struct ggml_init_params iparams = {
6945
8132
  /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
@@ -6952,15 +8139,18 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
6952
8139
  struct ggml_tensor * src0_clone = nullptr;
6953
8140
  struct ggml_tensor * src1_clone = nullptr;
6954
8141
  struct ggml_tensor * src2_clone = nullptr;
8142
+ struct ggml_tensor * src3_clone = nullptr;
6955
8143
  struct ggml_tensor * tensor_clone = nullptr;
6956
8144
 
6957
8145
  size_t src0_size;
6958
8146
  size_t src1_size;
6959
8147
  size_t src2_size;
8148
+ size_t src3_size;
6960
8149
 
6961
8150
  void * src0_buffer = nullptr;
6962
8151
  void * src1_buffer = nullptr;
6963
8152
  void * src2_buffer = nullptr;
8153
+ void * src3_buffer = nullptr;
6964
8154
 
6965
8155
  if (src0 != nullptr) {
6966
8156
  src0_clone = ggml_dup_tensor(ggml_ctx, src0);
@@ -7088,8 +8278,53 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
7088
8278
  ggml_vk_print_tensor(src2, "src2");
7089
8279
  }
7090
8280
  }
8281
+ if (src3 != nullptr) {
8282
+ src3_clone = ggml_dup_tensor(ggml_ctx, src3);
8283
+
8284
+ src3_size = ggml_nbytes(src3);
8285
+
8286
+ src3_buffer = malloc(src3_size);
8287
+ src3_clone->data = src3_buffer;
8288
+ if (ggml_backend_buffer_is_host(src3->buffer)) {
8289
+ memcpy(src3_clone->data, src3->data, src3_size);
8290
+ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
8291
+ } else if (ggml_backend_buffer_is_vk(src3->buffer)) {
8292
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
8293
+ vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
8294
+ uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
8295
+ if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
8296
+ for (int i3 = 0; i3 < src3->ne[3]; i3++) {
8297
+ for (int i2 = 0; i2 < src3->ne[2]; i2++) {
8298
+ const int idx = i3*src3->ne[2] + i2;
8299
+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
8300
+ }
8301
+ }
8302
+
8303
+ src3_clone->nb[0] = src3->nb[0];
8304
+ src3_clone->nb[1] = src3->nb[1];
8305
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
8306
+ src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
8307
+ }
8308
+ } else {
8309
+ if (offset + src3_size >= buffer_gpu->size) {
8310
+ src3_size = buffer_gpu->size - offset;
8311
+ }
8312
+ ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
8313
+ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
8314
+ }
8315
+ } else {
8316
+ GGML_ABORT("fatal error");
8317
+ }
8318
+
8319
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
8320
+ ggml_vk_print_tensor(src3, "src3");
8321
+ }
8322
+ }
7091
8323
 
7092
- if (tensor->op == GGML_OP_MUL_MAT) {
8324
+ if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
8325
+ const float *params = (const float *)tensor->op_params;
8326
+ tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
8327
+ } else if (tensor->op == GGML_OP_MUL_MAT) {
7093
8328
  tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
7094
8329
  } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
7095
8330
  tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
@@ -7204,10 +8439,24 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
7204
8439
  const int32_t dim = tensor->op_params[0];
7205
8440
  const int32_t max_period = tensor->op_params[1];
7206
8441
  tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
8442
+ } else if (tensor->op == GGML_OP_POOL_2D) {
8443
+ enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
8444
+ const int32_t k0 = tensor->op_params[1];
8445
+ const int32_t k1 = tensor->op_params[2];
8446
+ const int32_t s0 = tensor->op_params[3];
8447
+ const int32_t s1 = tensor->op_params[4];
8448
+ const int32_t p0 = tensor->op_params[5];
8449
+ const int32_t p1 = tensor->op_params[6];
8450
+
8451
+ tensor_clone = ggml_pool_2d(ggml_ctx, src0_clone, op, k0, k1, s0, s1, p0, p1);
7207
8452
  } else if (tensor->op == GGML_OP_LEAKY_RELU) {
7208
8453
  const float * op_params = (const float *)tensor->op_params;
7209
8454
  tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
7210
- } else {
8455
+ } else if (tensor->op == GGML_OP_RWKV_WKV6) {
8456
+ tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
8457
+ tensor->src[4], tensor->src[5]);
8458
+ }
8459
+ else {
7211
8460
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
7212
8461
  GGML_ABORT("fatal error");
7213
8462
  }
@@ -7404,3 +8653,5 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
7404
8653
  VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
7405
8654
  }
7406
8655
  #endif
8656
+
8657
+ GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg)