@fugood/llama.node 0.3.3 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +1 -1
  21. package/src/LlamaContext.cpp +81 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
@@ -1,7 +1,8 @@
1
1
  #include "ggml-vulkan.h"
2
2
  #include <vulkan/vulkan_core.h>
3
- #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF)
3
+ #if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_PERF) || defined(GGML_VULKAN_CHECK_RESULTS)
4
4
  #include <chrono>
5
+ #include "ggml-cpu.h"
5
6
  #endif
6
7
 
7
8
  #include <vulkan/vulkan.hpp>
@@ -43,12 +44,6 @@
43
44
 
44
45
  #define MAX_VK_BUFFERS 256
45
46
 
46
- #ifndef K_QUANTS_PER_ITERATION
47
- #define K_QUANTS_PER_ITERATION 1
48
- #else
49
- static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
50
- #endif
51
-
52
47
  #define VK_CHECK(err, msg) \
53
48
  do { \
54
49
  vk::Result err_ = (err); \
@@ -158,29 +153,53 @@ struct vk_device_struct {
158
153
  std::string name;
159
154
  uint64_t max_memory_allocation_size;
160
155
  bool fp16;
156
+ bool pipeline_robustness;
161
157
  vk::Device device;
162
158
  uint32_t vendor_id;
163
159
  vk_queue compute_queue;
164
160
  vk_queue transfer_queue;
165
161
  bool single_queue;
166
162
  uint32_t subgroup_size;
163
+ uint32_t shader_core_count;
167
164
  bool uma;
165
+ bool float_controls_rte_fp16;
166
+
167
+ bool subgroup_size_control;
168
+ uint32_t subgroup_min_size;
169
+ uint32_t subgroup_max_size;
170
+ bool subgroup_require_full_support;
171
+
172
+ bool coopmat_support;
173
+ bool coopmat_acc_f32_support;
174
+ bool coopmat_acc_f16_support;
175
+ uint32_t coopmat_m;
176
+ uint32_t coopmat_n;
177
+ uint32_t coopmat_k;
178
+ bool coopmat2;
168
179
 
169
180
  size_t idx;
170
181
 
182
+ bool mul_mat_l;
183
+ bool mul_mat_m;
184
+ bool mul_mat_s;
185
+ bool mul_mat_id_l;
186
+ bool mul_mat_id_m;
187
+ bool mul_mat_id_s;
188
+
171
189
  vk_matmul_pipeline pipeline_matmul_f32;
172
190
  vk_matmul_pipeline pipeline_matmul_f32_f16;
173
191
  vk_matmul_pipeline2 pipeline_matmul_f16;
174
192
  vk_matmul_pipeline2 pipeline_matmul_f16_f32;
175
193
  vk_pipeline pipeline_matmul_split_k_reduce;
176
194
 
195
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
177
196
  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
178
197
 
179
198
  vk_matmul_pipeline pipeline_matmul_id_f32;
180
- vk_matmul_pipeline pipeline_matmul_id_f16;
181
- vk_matmul_pipeline pipeline_matmul_id_f16_f32;
199
+ vk_matmul_pipeline2 pipeline_matmul_id_f16;
200
+ vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
182
201
 
183
- vk_matmul_pipeline pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
202
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
184
203
 
185
204
  vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
186
205
  vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
@@ -218,6 +237,7 @@ struct vk_device_struct {
218
237
  vk_pipeline pipeline_tanh_f32;
219
238
  vk_pipeline pipeline_diag_mask_inf_f32;
220
239
  vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
240
+ vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
221
241
  vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
222
242
  vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
223
243
  vk_pipeline pipeline_argsort_f32;
@@ -225,6 +245,15 @@ struct vk_device_struct {
225
245
  vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
226
246
  vk_pipeline pipeline_timestep_embedding_f32;
227
247
  vk_pipeline pipeline_pool2d_f32;
248
+ vk_pipeline pipeline_rwkv_wkv6_f32;
249
+
250
+ // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
251
+ vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
252
+ vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
253
+ vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
254
+ vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
255
+ vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
256
+ vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
228
257
 
229
258
  std::unordered_map<std::string, vk_pipeline_ref> pipelines;
230
259
  std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@@ -337,6 +366,40 @@ struct vk_mat_vec_id_push_constants {
337
366
  uint32_t nei0; uint32_t ne11;
338
367
  };
339
368
 
369
+ struct vk_flash_attn_push_constants {
370
+ uint32_t N;
371
+ uint32_t KV;
372
+
373
+ uint32_t ne1;
374
+ uint32_t ne2;
375
+ uint32_t ne3;
376
+
377
+ uint32_t neq2;
378
+ uint32_t neq3;
379
+ uint32_t nek2;
380
+ uint32_t nek3;
381
+ uint32_t nev2;
382
+ uint32_t nev3;
383
+ uint32_t nem1;
384
+
385
+ uint32_t nb02;
386
+ uint32_t nb03;
387
+ uint32_t nb12;
388
+ uint32_t nb13;
389
+ uint32_t nb22;
390
+ uint32_t nb23;
391
+ uint32_t nb31;
392
+
393
+ float scale;
394
+ float max_bias;
395
+ float logit_softcap;
396
+
397
+ uint32_t mask;
398
+ uint32_t n_head_log2;
399
+ float m0;
400
+ float m1;
401
+ };
402
+
340
403
  struct vk_op_push_constants {
341
404
  uint32_t KX;
342
405
  uint32_t KY;
@@ -350,7 +413,46 @@ struct vk_op_unary_push_constants {
350
413
  uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
351
414
  uint32_t d_offset;
352
415
  float param1; float param2;
416
+ uint32_t ne0_012mp; uint32_t ne0_012L;
417
+ uint32_t ne0_01mp; uint32_t ne0_01L;
418
+ uint32_t ne0_0mp; uint32_t ne0_0L;
419
+ uint32_t ne1_012mp; uint32_t ne1_012L;
420
+ uint32_t ne1_01mp; uint32_t ne1_01L;
421
+ uint32_t ne1_0mp; uint32_t ne1_0L;
353
422
  };
423
+ static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128");
424
+
425
+ // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
426
+ // Precompute mp (m' in the paper) and L such that division
427
+ // can be computed using a multiply (high 32b of 64b result)
428
+ // and a shift:
429
+ //
430
+ // n/d = (mulhi(n, mp) + n) >> L;
431
+ static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L)
432
+ {
433
+ // compute L = ceil(log2(d));
434
+ L = 0;
435
+ while (L < 32 && (uint32_t{1} << L) < d) {
436
+ L++;
437
+ }
438
+
439
+ mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1);
440
+ }
441
+
442
+ template <typename T> void init_pushconst_fastdiv(T &p) {
443
+ GGML_UNUSED(p);
444
+ static_assert(!std::is_const<T>::value, "unexpected type");
445
+ }
446
+
447
+ template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) {
448
+ // Compute magic values to divide by these six numbers.
449
+ init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L);
450
+ init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L);
451
+ init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L);
452
+ init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L);
453
+ init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L);
454
+ init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L);
455
+ }
354
456
 
355
457
  struct vk_op_binary_push_constants {
356
458
  uint32_t ne;
@@ -388,6 +490,7 @@ struct vk_op_soft_max_push_constants {
388
490
  float m0;
389
491
  float m1;
390
492
  uint32_t n_head_log2;
493
+ uint32_t nrows_x;
391
494
  };
392
495
 
393
496
  struct vk_op_argsort_push_constants {
@@ -426,6 +529,13 @@ struct vk_op_pool2d_push_constants {
426
529
  int32_t p0; int32_t p1;
427
530
  };
428
531
 
532
+ struct vk_op_rwkv_wkv6_push_constants {
533
+ uint32_t B;
534
+ uint32_t T;
535
+ uint32_t C;
536
+ uint32_t H;
537
+ };
538
+
429
539
  // Allow pre-recording command buffers
430
540
  struct vk_staging_memcpy {
431
541
  vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -652,8 +762,12 @@ static uint32_t compile_count = 0;
652
762
  static std::mutex compile_count_mutex;
653
763
  static std::condition_variable compile_count_cond;
654
764
 
655
- static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants, uint32_t align) {
656
- VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")");
765
+ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, const std::string name, size_t spv_size, const void* spv_data, const std::string entrypoint,
766
+ uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<uint32_t> specialization_constants,
767
+ uint32_t align, bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) {
768
+ VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size <<
769
+ ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align <<
770
+ ", " << disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")");
657
771
  GGML_ASSERT(parameter_count > 0);
658
772
  GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
659
773
 
@@ -712,16 +826,39 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin
712
826
  specialization_constants.data()
713
827
  );
714
828
 
829
+ vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{};
830
+
831
+ if (device->subgroup_require_full_support && require_full_subgroups) {
832
+ pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT;
833
+ }
834
+
715
835
  vk::PipelineShaderStageCreateInfo pipeline_shader_create_info(
716
- vk::PipelineShaderStageCreateFlags(),
836
+ pipeline_shader_stage_create_flags,
717
837
  vk::ShaderStageFlagBits::eCompute,
718
838
  pipeline->shader_module,
719
839
  entrypoint.c_str(),
720
840
  &specialization_info);
841
+
842
+ vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info;
843
+ pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size;
844
+ if (device->subgroup_size_control && required_subgroup_size > 0) {
845
+ GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size);
846
+ pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info);
847
+ }
848
+
721
849
  vk::ComputePipelineCreateInfo compute_pipeline_create_info(
722
- vk::PipelineCreateFlags(),
850
+ vk::PipelineCreateFlags{},
723
851
  pipeline_shader_create_info,
724
852
  pipeline->layout);
853
+
854
+ vk::PipelineRobustnessCreateInfoEXT rci;
855
+
856
+ if (device->pipeline_robustness && disable_robustness) {
857
+ rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
858
+ rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled;
859
+ compute_pipeline_create_info.setPNext(&rci);
860
+ }
861
+
725
862
  pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value;
726
863
 
727
864
  {
@@ -1214,52 +1351,186 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
1214
1351
  );
1215
1352
  }
1216
1353
 
1354
+ // number of rows/cols for flash attention shader
1355
+ static constexpr uint32_t flash_attention_num_small_rows = 32;
1356
+ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
1357
+ GGML_UNUSED(clamp);
1358
+
1359
+ // small rows, large cols
1360
+ if (small_rows) {
1361
+ return {flash_attention_num_small_rows, 128};
1362
+ }
1363
+ // small cols to reduce register count
1364
+ if (ggml_is_quantized(type) || D == 256) {
1365
+ return {64, 32};
1366
+ }
1367
+ return {64, 64};
1368
+ };
1369
+
1370
+ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector<uint32_t>& warptile, bool mul_mat_id) {
1371
+ // Needs to be kept up to date on shader changes
1372
+ const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1;
1373
+ const uint32_t type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float);
1374
+ const uint32_t warps = warptile[0] / warptile[10];
1375
+
1376
+ const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
1377
+ const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
1378
+ const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
1379
+
1380
+ return (load_bufs + mmid_row_ids + coopmat_stage) <= device->properties.limits.maxComputeSharedMemorySize;
1381
+ }
1382
+
1217
1383
  static void ggml_vk_load_shaders(vk_device& device) {
1218
1384
  VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
1219
1385
 
1220
1386
  std::cerr << "ggml_vulkan: Compiling shaders";
1221
1387
 
1222
- // mulmat
1223
- std::vector<uint32_t> l_warptile, m_warptile, s_warptile, l_warptile_mmq, m_warptile_mmq, s_warptile_mmq;
1224
- std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms;
1225
- uint32_t l_align, m_align, s_align;
1226
-
1227
- l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
1228
- m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
1229
- s_warptile = { std::max(device->subgroup_size, 16u), 32, 32, 16, 32, 32, 2, 2, 2, device->subgroup_size };
1230
-
1231
- l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, 4, 4, device->subgroup_size };
1232
- m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, 4, 2, device->subgroup_size };
1233
- s_warptile_mmq = { std::max(device->subgroup_size, 16u), 32, 32, 32, 32, 32, 2, 2, 2, device->subgroup_size };
1388
+ // some shaders have a minimum subgroup size
1389
+ const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
1390
+ const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
1234
1391
 
1235
- l_wg_denoms = {128, 128, 1 };
1236
- m_wg_denoms = { 64, 64, 1 };
1237
- s_wg_denoms = { 32, 32, 1 };
1392
+ // mulmat
1393
+ std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
1394
+ l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
1395
+ l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
1396
+ l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
1397
+ std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
1398
+ l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms,
1399
+ l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k,
1400
+ l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms;
1238
1401
 
1239
- l_align = 128;
1240
- m_align = 64;
1241
- s_align = 32;
1402
+ uint32_t l_align, m_align, s_align;
1403
+ if (device->coopmat2) {
1404
+ // spec constants and tile sizes for non-quant matmul/matmul_id
1405
+ l_warptile = { 256, 128, 256, 64 };
1406
+ m_warptile = { 256, 128, 128, 64 };
1407
+ s_warptile = { 128, 32, 16, 64 };
1408
+ l_wg_denoms = {128, 256, 1 };
1409
+ m_wg_denoms = {128, 128, 1 };
1410
+ s_wg_denoms = { 32, 16, 1 };
1411
+
1412
+ // spec constants and tile sizes for quant matmul (non-Qi_K)
1413
+ l_warptile_mmq = { 256, 128, 256, 64 };
1414
+ m_warptile_mmq = { 256, 128, 128, 64 };
1415
+ s_warptile_mmq = { 256, 128, 128, 64 };
1416
+ l_mmq_wg_denoms = { 128, 256, 1 };
1417
+ m_mmq_wg_denoms = { 128, 128, 1 };
1418
+ s_mmq_wg_denoms = { 128, 128, 1 };
1419
+
1420
+ // spec constants and tile sizes for quant matmul (Qi_K)
1421
+ l_warptile_mmq_k = { 256, 128, 512, 16 };
1422
+ m_warptile_mmq_k = { 256, 128, 256, 16 };
1423
+ s_warptile_mmq_k = { 256, 32, 128, 64 };
1424
+ l_mmq_wg_denoms_k = { 128, 512, 1 };
1425
+ m_mmq_wg_denoms_k = { 128, 256, 1 };
1426
+ s_mmq_wg_denoms_k = { 32, 128, 1 };
1427
+
1428
+ // spec constants and tile sizes for quant matmul_id
1429
+ l_warptile_mmqid = { 256, 128, 128, 16 };
1430
+ m_warptile_mmqid = { 256, 128, 64, 16 };
1431
+ s_warptile_mmqid = { 256, 64, 64, 16 };
1432
+ l_mmqid_wg_denoms = { 128, 128, 1 };
1433
+ m_mmqid_wg_denoms = { 128, 64, 1 };
1434
+ s_mmqid_wg_denoms = { 64, 64, 1 };
1435
+
1436
+ l_align = 128;
1437
+ m_align = 64;
1438
+ s_align = 32;
1439
+ } else {
1440
+ // Matrix cores require different warp group sizes
1441
+ const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4;
1442
+ const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4;
1443
+ const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2;
1444
+ const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4;
1445
+ const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2;
1446
+ const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2;
1447
+ const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1;
1448
+ const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
1449
+ const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
1450
+
1451
+ l_warptile = { 128, 128, 128, 16, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
1452
+ m_warptile = { 128, 64, 64, 16, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
1453
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
1454
+
1455
+ l_warptile_mmq = { 128, 128, 128, 32, device->subgroup_size * 2, 64, 2, tm_l, tn_l, tk_l, device->subgroup_size };
1456
+ m_warptile_mmq = { 128, 64, 64, 32, device->subgroup_size, 32, 2, tm_m, tn_m, tk_m, device->subgroup_size };
1457
+ s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, device->subgroup_size };
1458
+
1459
+ l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
1460
+ m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
1461
+ s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
1462
+ l_align = 128;
1463
+ m_align = 64;
1464
+ s_align = 32;
1465
+
1466
+ // Fallback to smaller sizes if there's not enough shared memory. Given the current shaders
1467
+ // and tile sizes, this should handle 16KB, 32KB, and 48KB+.
1468
+ // This logic doesn't explicitly account for the 12KB row_ids in the mul_mat_mat_id shaders.
1469
+ // But the numbers happen to work out for 32KB shared memory size that when using the medium
1470
+ // size there's enough room for everything, and we assert for this.
1471
+ uint32_t shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
1472
+ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
1473
+ l_warptile = m_warptile;
1474
+ l_wg_denoms = m_wg_denoms;
1475
+ shmem_needed = (l_warptile[1] + l_warptile[2]) * (l_warptile[3] + 1) * sizeof(float);
1476
+ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
1477
+ }
1478
+ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1479
+ // assert mul_mat_mat_id shaders will fit.
1480
+ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1481
+ }
1482
+
1483
+ shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
1484
+ if (shmem_needed > device->properties.limits.maxComputeSharedMemorySize) {
1485
+ if (device->properties.limits.maxComputeSharedMemorySize == 32768) {
1486
+ l_warptile_mmq = m_warptile_mmq;
1487
+ l_mmq_wg_denoms = m_mmq_wg_denoms;
1488
+ } else {
1489
+ l_warptile_mmq = s_warptile_mmq;
1490
+ l_mmq_wg_denoms = s_mmq_wg_denoms;
1491
+ }
1492
+ shmem_needed = (l_warptile_mmq[1] + l_warptile_mmq[2]) * (l_warptile_mmq[3] + 1) * sizeof(float);
1493
+ GGML_ASSERT(shmem_needed <= device->properties.limits.maxComputeSharedMemorySize);
1494
+ }
1495
+ if (device->properties.limits.maxComputeSharedMemorySize >= 32768) {
1496
+ // assert mul_mat_mat_id shaders will fit.
1497
+ GGML_ASSERT(shmem_needed + 3072*4 <= device->properties.limits.maxComputeSharedMemorySize);
1498
+ }
1499
+ // Disable medium and large matrix multiplication if not enough shared memory is available
1500
+ // Check mmq warptiles as the largest configuration
1501
+ // Throw an error if not enough for any matrix multiplication is available
1502
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false)) {
1503
+ std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." << std::endl;
1504
+ throw std::runtime_error("Shared memory size too small for matrix multiplication.");
1505
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false)) {
1506
+ device->mul_mat_m = false;
1507
+ device->mul_mat_l = false;
1508
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false)) {
1509
+ device->mul_mat_l = false;
1510
+ }
1511
+
1512
+ // Disable mul_mat_id if not enough shared memory is available
1513
+ if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, true)) {
1514
+ device->mul_mat_id_s = false;
1515
+ device->mul_mat_id_m = false;
1516
+ device->mul_mat_id_l = false;
1517
+ } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, true)) {
1518
+ device->mul_mat_id_m = false;
1519
+ device->mul_mat_id_l = false;
1520
+ } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, true)) {
1521
+ device->mul_mat_id_l = false;
1522
+ }
1523
+ }
1242
1524
 
1243
1525
  device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1244
1526
  device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
1245
1527
 
1246
1528
  device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1247
- device->pipeline_matmul_id_f16_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1248
- device->pipeline_matmul_id_f16 = std::make_shared<vk_matmul_pipeline_struct>();
1249
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0] = std::make_shared<vk_matmul_pipeline_struct>();
1250
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1] = std::make_shared<vk_matmul_pipeline_struct>();
1251
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0] = std::make_shared<vk_matmul_pipeline_struct>();
1252
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1] = std::make_shared<vk_matmul_pipeline_struct>();
1253
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0] = std::make_shared<vk_matmul_pipeline_struct>();
1254
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K] = std::make_shared<vk_matmul_pipeline_struct>();
1255
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K] = std::make_shared<vk_matmul_pipeline_struct>();
1256
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K] = std::make_shared<vk_matmul_pipeline_struct>();
1257
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K] = std::make_shared<vk_matmul_pipeline_struct>();
1258
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K] = std::make_shared<vk_matmul_pipeline_struct>();
1259
- device->pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL] = std::make_shared<vk_matmul_pipeline_struct>();
1260
1529
 
1261
1530
  std::vector<std::future<void>> compiles;
1262
- auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants, uint32_t align) {
1531
+ auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
1532
+ uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
1533
+ uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
1263
1534
  {
1264
1535
  // wait until fewer than N compiles are in progress
1265
1536
  uint32_t N = std::max(1u, std::thread::hardware_concurrency());
@@ -1269,144 +1540,368 @@ static void ggml_vk_load_shaders(vk_device& device) {
1269
1540
  }
1270
1541
  compile_count++;
1271
1542
  }
1272
- compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint, parameter_count, push_constant_size, wg_denoms, specialization_constants, align));
1543
+ compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), name, spv_size, spv_data, entrypoint,
1544
+ parameter_count, push_constant_size, wg_denoms, specialization_constants, align, disable_robustness, require_full_subgroups, required_subgroup_size));
1273
1545
  };
1274
1546
 
1275
- if (device->fp16) {
1547
+ #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
1548
+ if (device->coopmat2) {
1549
+
1550
+ auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
1551
+ return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
1552
+ };
1553
+
1554
+ auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
1555
+ // For large number of rows, 128 invocations seems to work best.
1556
+ // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
1557
+ // can't use 256 for D==80.
1558
+ uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
1559
+ auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
1560
+ return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
1561
+ };
1562
+
1563
+ #define CREATE_FA2(TYPE, NAMELC, D) \
1564
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
1565
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
1566
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
1567
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
1568
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
1569
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
1570
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
1571
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
1572
+
1573
+ #define CREATE_FA(TYPE, NAMELC) \
1574
+ CREATE_FA2(TYPE, NAMELC, 64) \
1575
+ CREATE_FA2(TYPE, NAMELC, 80) \
1576
+ CREATE_FA2(TYPE, NAMELC, 96) \
1577
+ CREATE_FA2(TYPE, NAMELC, 112) \
1578
+ CREATE_FA2(TYPE, NAMELC, 128) \
1579
+ CREATE_FA2(TYPE, NAMELC, 256)
1580
+
1581
+ CREATE_FA(GGML_TYPE_F16, f16)
1582
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0)
1583
+ CREATE_FA(GGML_TYPE_Q4_1, q4_1)
1584
+ CREATE_FA(GGML_TYPE_Q5_0, q5_0)
1585
+ CREATE_FA(GGML_TYPE_Q5_1, q5_1)
1586
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0)
1587
+ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
1588
+ //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
1589
+ //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
1590
+ //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
1591
+ //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
1592
+ //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
1593
+ CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
1594
+ #undef CREATE_FA
1595
+
1276
1596
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1277
1597
  #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1278
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1279
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1280
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1281
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1282
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1283
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1284
-
1285
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1286
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1287
- CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1288
- CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1289
-
1290
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1291
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1292
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1293
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1294
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1295
-
1296
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1297
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1298
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1299
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1300
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1301
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1302
-
1303
- CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1304
- CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1305
- CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1306
-
1307
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1308
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1309
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1310
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1311
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1312
-
1313
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1314
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1315
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1316
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1317
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1318
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1598
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1599
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1600
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1601
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1602
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1603
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1604
+
1605
+ // Create 2 variants, {f16,f32} accumulator
1606
+ #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1607
+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1608
+ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1609
+
1610
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1611
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1612
+
1613
+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1614
+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
1615
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1616
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1617
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1618
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1619
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1620
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1621
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1622
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1623
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1624
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
1625
+ CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1626
+
1627
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1628
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1629
+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
1630
+
1631
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1632
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1633
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1634
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1635
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1636
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1637
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1638
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1639
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1640
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1641
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1642
+ #undef CREATE_MM
1643
+ #undef CREATE_MM2
1644
+ } else
1645
+ #endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
1646
+ if (device->coopmat_support) {
1647
+ // Create 6 variants, {s,m,l}x{unaligned,aligned}
1648
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1649
+ if (device->mul_mat ## ID ## _l) \
1650
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
1651
+ if (device->mul_mat ## ID ## _m) \
1652
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
1653
+ if (device->mul_mat ## ID ## _s) \
1654
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
1655
+ if (device->mul_mat ## ID ## _l) \
1656
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
1657
+ if (device->mul_mat ## ID ## _m) \
1658
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
1659
+ if (device->mul_mat ## ID ## _s) \
1660
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
1661
+
1662
+ // Create 2 variants, {f16,f32} accumulator
1663
+ #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1664
+ if (device->coopmat_acc_f16_support) { \
1665
+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1666
+ } \
1667
+ if (device->coopmat_acc_f32_support) { \
1668
+ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1669
+ } \
1670
+
1671
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1672
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1673
+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1674
+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1675
+
1676
+ if (device->coopmat_acc_f16_support) {
1677
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1678
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1679
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1680
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1681
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1682
+
1683
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1684
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1685
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1686
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1687
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1688
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1689
+ } else {
1690
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1691
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1692
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1693
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1694
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1695
+
1696
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1697
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1698
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1699
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1700
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1701
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1702
+ }
1703
+
1704
+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1705
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1706
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1707
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1708
+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1709
+
1710
+ if (device->coopmat_acc_f16_support) {
1711
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1712
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1713
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1714
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1715
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1716
+
1717
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1718
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1719
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1720
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1721
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1722
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1723
+ } else {
1724
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1725
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1726
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1727
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1728
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1729
+
1730
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1731
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1732
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1733
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1734
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1735
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1736
+ }
1737
+ }
1738
+ #undef CREATE_MM2
1739
+ #undef CREATE_MM
1740
+ } else if (device->fp16) {
1741
+ // Create 6 variants, {s,m,l}x{unaligned,aligned}
1742
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1743
+ if (device->mul_mat ## ID ## _l) \
1744
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1745
+ if (device->mul_mat ## ID ## _m) \
1746
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1747
+ if (device->mul_mat ## ID ## _s) \
1748
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1749
+ if (device->mul_mat ## ID ## _l) \
1750
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1751
+ if (device->mul_mat ## ID ## _m) \
1752
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1753
+ if (device->mul_mat ## ID ## _s) \
1754
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1755
+
1756
+ // Create 2 variants, {f16,f32} accumulator
1757
+ #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1758
+ CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1759
+ CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1760
+
1761
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1762
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1763
+ CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1764
+ CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1765
+
1766
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1767
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1768
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1769
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1770
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1771
+
1772
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f16acc, matmul_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1773
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f16acc, matmul_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1774
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f16acc, matmul_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1775
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f16acc, matmul_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1776
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1777
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1778
+
1779
+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1780
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1781
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1782
+ CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1783
+ CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1784
+
1785
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1786
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1787
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1788
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1789
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1790
+
1791
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1792
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1793
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1794
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1795
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1796
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1797
+ }
1798
+ #undef CREATE_MM2
1319
1799
  #undef CREATE_MM
1320
1800
  } else {
1321
1801
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1322
- #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1323
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1324
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1325
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1326
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1327
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1328
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1329
-
1330
- CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1331
- CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1332
- CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1333
- CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3);
1334
-
1335
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1336
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1337
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1338
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1339
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1340
-
1341
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1342
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1343
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1344
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1345
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1346
- CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3);
1347
-
1348
- CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1349
- CREATE_MM(pipeline_matmul_id_f16, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1350
- CREATE_MM(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4);
1351
-
1352
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1353
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1354
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1355
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1356
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1357
-
1358
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1359
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1360
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1361
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1362
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1363
- CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, , wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4);
1802
+ #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1803
+ if (device->mul_mat ## ID ## _l) \
1804
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
1805
+ if (device->mul_mat ## ID ## _m) \
1806
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
1807
+ if (device->mul_mat ## ID ## _s) \
1808
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
1809
+ if (device->mul_mat ## ID ## _l) \
1810
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \
1811
+ if (device->mul_mat ## ID ## _m) \
1812
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \
1813
+ if (device->mul_mat ## ID ## _s) \
1814
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
1815
+
1816
+ CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1817
+ CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1818
+ CREATE_MM(pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1819
+ CREATE_MM(pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1820
+
1821
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1822
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1823
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1824
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1825
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1826
+
1827
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1828
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1829
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1830
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1831
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1832
+ CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
1833
+
1834
+ // If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
1835
+ if (device->mul_mat_id_s || device->mul_mat_id_m || device->mul_mat_id_l) {
1836
+ CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1837
+ CREATE_MM(pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1838
+ CREATE_MM(pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
1839
+
1840
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1841
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1842
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1843
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1844
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1845
+
1846
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1847
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1848
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1849
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1850
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1851
+ CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
1852
+ }
1364
1853
  #undef CREATE_MM
1365
1854
  }
1366
1855
 
1367
1856
  // mul mat vec
1368
- // computing two rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
1857
+
1858
+ // AMD GCN and Intel graphics cards perform best when the number of rows per shader is doubled
1859
+ uint32_t rm = 1;
1860
+ if ((device->vendor_id == VK_VENDOR_ID_AMD && device->subgroup_min_size == 64 && device->subgroup_max_size == 64) || device->vendor_id == VK_VENDOR_ID_INTEL)
1861
+ rm = 2;
1862
+
1863
+ // computing additional rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
1369
1864
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1370
1865
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1371
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1372
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1373
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1374
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1375
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
1376
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1377
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1378
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1379
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1380
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1381
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1866
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1867
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1868
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1869
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1870
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
1871
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1872
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1873
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1874
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1875
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1876
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
1382
1877
 
1383
1878
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1384
1879
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1385
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1386
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1387
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1388
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1389
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
1390
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1391
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1392
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1393
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1394
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1395
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size}, 1);
1880
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1881
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1882
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1883
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1884
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
1885
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1886
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1887
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1888
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1889
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1890
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
1396
1891
 
1397
1892
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1398
1893
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1399
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1400
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1401
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1402
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1403
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size, 1}, 1);
1404
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1405
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1406
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1407
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1408
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
1409
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1894
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1895
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1896
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1897
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1898
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
1899
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1900
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1901
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1902
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1903
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1904
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
1410
1905
 
1411
1906
  // dequant shaders
1412
1907
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -1441,7 +1936,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1441
1936
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
1442
1937
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
1443
1938
 
1444
- ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256, 1, 1}, {}, 1);
1939
+ ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
1445
1940
 
1446
1941
  ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
1447
1942
  ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
@@ -1497,26 +1992,39 @@ static void ggml_vk_load_shaders(vk_device& device) {
1497
1992
 
1498
1993
  ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
1499
1994
 
1500
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
1501
- ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
1995
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1996
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
1997
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1998
+ ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
1502
1999
 
1503
2000
  ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1504
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1505
-
1506
2001
  ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1507
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2002
+
2003
+ if (device->float_controls_rte_fp16) {
2004
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2005
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2006
+ } else {
2007
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2008
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
2009
+ }
1508
2010
 
1509
2011
  ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
1510
2012
 
1511
2013
  ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1512
2014
 
1513
2015
  ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1514
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
2016
+ if (device->float_controls_rte_fp16) {
2017
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
2018
+ } else {
2019
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
2020
+ }
1515
2021
 
1516
2022
  ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
1517
2023
 
1518
2024
  ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);
1519
2025
 
2026
+ ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
2027
+
1520
2028
  for (auto &c : compiles) {
1521
2029
  c.wait();
1522
2030
  }
@@ -1550,12 +2058,40 @@ static vk_device ggml_vk_get_device(size_t idx) {
1550
2058
  device->physical_device = physical_devices[dev_num];
1551
2059
  const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
1552
2060
 
2061
+ bool fp16_storage = false;
2062
+ bool fp16_compute = false;
1553
2063
  bool maintenance4_support = false;
2064
+ bool sm_builtins = false;
2065
+ bool amd_shader_core_properties2 = false;
2066
+ bool pipeline_robustness = false;
2067
+ bool coopmat2_support = false;
2068
+ device->coopmat_support = false;
1554
2069
 
1555
2070
  // Check if maintenance4 is supported
1556
2071
  for (const auto& properties : ext_props) {
1557
2072
  if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
1558
2073
  maintenance4_support = true;
2074
+ } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
2075
+ fp16_storage = true;
2076
+ } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
2077
+ fp16_compute = true;
2078
+ } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) {
2079
+ sm_builtins = true;
2080
+ } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) {
2081
+ amd_shader_core_properties2 = true;
2082
+ } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) {
2083
+ pipeline_robustness = true;
2084
+ } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
2085
+ device->subgroup_size_control = true;
2086
+ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
2087
+ !getenv("GGML_VK_DISABLE_COOPMAT")) {
2088
+ device->coopmat_support = true;
2089
+ device->coopmat_m = 0;
2090
+ device->coopmat_n = 0;
2091
+ device->coopmat_k = 0;
2092
+ } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
2093
+ !getenv("GGML_VK_DISABLE_COOPMAT2")) {
2094
+ coopmat2_support = true;
1559
2095
  }
1560
2096
  }
1561
2097
 
@@ -1563,18 +2099,51 @@ static vk_device ggml_vk_get_device(size_t idx) {
1563
2099
  vk::PhysicalDeviceMaintenance3Properties props3;
1564
2100
  vk::PhysicalDeviceMaintenance4Properties props4;
1565
2101
  vk::PhysicalDeviceSubgroupProperties subgroup_props;
2102
+ vk::PhysicalDeviceDriverProperties driver_props;
2103
+ vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
2104
+ vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2105
+ vk::PhysicalDeviceVulkan12Properties vk12_props;
2106
+ vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
2107
+
1566
2108
  props2.pNext = &props3;
1567
2109
  props3.pNext = &subgroup_props;
2110
+ subgroup_props.pNext = &driver_props;
2111
+ driver_props.pNext = &vk12_props;
2112
+
2113
+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
2114
+
1568
2115
  if (maintenance4_support) {
1569
- subgroup_props.pNext = &props4;
2116
+ last_struct->pNext = (VkBaseOutStructure *)&props4;
2117
+ last_struct = (VkBaseOutStructure *)&props4;
2118
+ }
2119
+ if (sm_builtins) {
2120
+ last_struct->pNext = (VkBaseOutStructure *)&sm_props;
2121
+ last_struct = (VkBaseOutStructure *)&sm_props;
2122
+ }
2123
+ if (amd_shader_core_properties2) {
2124
+ last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
2125
+ last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props;
1570
2126
  }
2127
+ if (device->subgroup_size_control) {
2128
+ last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props;
2129
+ last_struct = (VkBaseOutStructure *)&subgroup_size_control_props;
2130
+ }
2131
+
2132
+ #if defined(VK_NV_cooperative_matrix2)
2133
+ vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props;
2134
+ if (coopmat2_support) {
2135
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props;
2136
+ last_struct = (VkBaseOutStructure *)&coopmat2_props;
2137
+ }
2138
+ #endif
2139
+
1571
2140
  device->physical_device.getProperties2(&props2);
1572
2141
  device->properties = props2.properties;
1573
2142
 
1574
2143
  const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
1575
2144
 
1576
2145
  if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) {
1577
- device->max_memory_allocation_size = std::stoi(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
2146
+ device->max_memory_allocation_size = std::stoul(GGML_VK_FORCE_MAX_ALLOCATION_SIZE);
1578
2147
  } else if (maintenance4_support) {
1579
2148
  device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize);
1580
2149
  } else {
@@ -1584,23 +2153,25 @@ static vk_device ggml_vk_get_device(size_t idx) {
1584
2153
  device->vendor_id = device->properties.vendorID;
1585
2154
  device->subgroup_size = subgroup_props.subgroupSize;
1586
2155
  device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
1587
-
1588
- bool fp16_storage = false;
1589
- bool fp16_compute = false;
1590
-
1591
- for (const auto& properties : ext_props) {
1592
- if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1593
- fp16_storage = true;
1594
- } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1595
- fp16_compute = true;
1596
- }
2156
+ if (sm_builtins) {
2157
+ device->shader_core_count = sm_props.shaderSMCount;
2158
+ } else if (amd_shader_core_properties2) {
2159
+ device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount;
2160
+ } else {
2161
+ device->shader_core_count = 0;
1597
2162
  }
2163
+ device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
1598
2164
 
1599
- const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1600
- const bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
2165
+ const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
1601
2166
 
1602
2167
  device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
1603
2168
 
2169
+ if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
2170
+ // Intel drivers don't support coopmat properly yet
2171
+ // Only RADV supports coopmat properly on AMD
2172
+ device->coopmat_support = false;
2173
+ }
2174
+
1604
2175
  std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
1605
2176
 
1606
2177
  // Try to find a non-graphics compute queue and transfer-focused queues
@@ -1638,10 +2209,149 @@ static vk_device ggml_vk_get_device(size_t idx) {
1638
2209
  vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1639
2210
  vk11_features.pNext = &vk12_features;
1640
2211
 
2212
+ last_struct = (VkBaseOutStructure *)&vk12_features;
2213
+
2214
+ VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features;
2215
+ pl_robustness_features.pNext = nullptr;
2216
+ pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT;
2217
+ pl_robustness_features.pipelineRobustness = VK_FALSE;
2218
+
2219
+ if (pipeline_robustness) {
2220
+ last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features;
2221
+ last_struct = (VkBaseOutStructure *)&pl_robustness_features;
2222
+ device_extensions.push_back("VK_EXT_pipeline_robustness");
2223
+ }
2224
+
2225
+ VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features;
2226
+ subgroup_size_control_features.pNext = nullptr;
2227
+ subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
2228
+ subgroup_size_control_features.computeFullSubgroups = false;
2229
+ subgroup_size_control_features.subgroupSizeControl = false;
2230
+
2231
+ if (device->subgroup_size_control) {
2232
+ last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features;
2233
+ last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
2234
+ }
2235
+
2236
+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
2237
+ coopmat_features.pNext = nullptr;
2238
+ coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
2239
+ coopmat_features.cooperativeMatrix = VK_FALSE;
2240
+
2241
+ if (device->coopmat_support) {
2242
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
2243
+ last_struct = (VkBaseOutStructure *)&coopmat_features;
2244
+ }
2245
+
2246
+ #if defined(VK_NV_cooperative_matrix2)
2247
+ VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
2248
+ coopmat2_features.pNext = nullptr;
2249
+ coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV;
2250
+ if (coopmat2_support) {
2251
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features;
2252
+ last_struct = (VkBaseOutStructure *)&coopmat2_features;
2253
+ device_extensions.push_back("VK_NV_cooperative_matrix2");
2254
+ }
2255
+ #endif
2256
+
1641
2257
  vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
1642
2258
 
1643
2259
  device->fp16 = device->fp16 && vk12_features.shaderFloat16;
1644
2260
 
2261
+ device->pipeline_robustness = pl_robustness_features.pipelineRobustness;
2262
+
2263
+ if (device->subgroup_size_control) {
2264
+ device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
2265
+ device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
2266
+ }
2267
+
2268
+ device->subgroup_size_control = device->subgroup_size_control &&
2269
+ (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) &&
2270
+ subgroup_size_control_features.subgroupSizeControl;
2271
+
2272
+ if (device->subgroup_size_control) {
2273
+ device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
2274
+ device_extensions.push_back("VK_EXT_subgroup_size_control");
2275
+ }
2276
+
2277
+ device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
2278
+
2279
+ if (coopmat2_support) {
2280
+ #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
2281
+ if (coopmat2_features.cooperativeMatrixWorkgroupScope &&
2282
+ coopmat2_features.cooperativeMatrixFlexibleDimensions &&
2283
+ coopmat2_features.cooperativeMatrixReductions &&
2284
+ coopmat2_features.cooperativeMatrixConversions &&
2285
+ coopmat2_features.cooperativeMatrixPerElementOperations &&
2286
+ coopmat2_features.cooperativeMatrixTensorAddressing &&
2287
+ coopmat2_features.cooperativeMatrixBlockLoads &&
2288
+ vk12_features.bufferDeviceAddress) {
2289
+
2290
+ std::vector<VkCooperativeMatrixFlexibleDimensionsPropertiesNV> flexible_dimensions;
2291
+ uint32_t count = 0;
2292
+
2293
+ PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV
2294
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV =
2295
+ (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV)
2296
+ vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV");
2297
+
2298
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr);
2299
+
2300
+ VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {};
2301
+ empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV;
2302
+ flexible_dimensions.resize(count, empty_prop);
2303
+
2304
+ _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data());
2305
+
2306
+ bool found_fp16_128 = false,
2307
+ found_fp16_256 = false,
2308
+ found_fp32_128 = false,
2309
+ found_fp32_256 = false;
2310
+ // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128
2311
+ // with 32x16x16 and 256 with 32x32x16.
2312
+ for (auto &prop : flexible_dimensions) {
2313
+ if (prop.saturatingAccumulation == VK_FALSE &&
2314
+ prop.scope == VK_SCOPE_WORKGROUP_KHR &&
2315
+ prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
2316
+ prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
2317
+
2318
+ if (prop.workgroupInvocations == 128 &&
2319
+ prop.MGranularity <= 32 &&
2320
+ prop.NGranularity <= 16 &&
2321
+ prop.KGranularity <= 16) {
2322
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
2323
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
2324
+ found_fp16_128 = true;
2325
+ }
2326
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
2327
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
2328
+ found_fp32_128 = true;
2329
+ }
2330
+ }
2331
+ if (prop.workgroupInvocations == 256 &&
2332
+ prop.MGranularity <= 32 &&
2333
+ prop.NGranularity <= 32 &&
2334
+ prop.KGranularity <= 16) {
2335
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR &&
2336
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) {
2337
+ found_fp16_256 = true;
2338
+ }
2339
+ if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
2340
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) {
2341
+ found_fp32_256 = true;
2342
+ }
2343
+ }
2344
+ }
2345
+ }
2346
+ if (found_fp16_128 && found_fp16_256 &&
2347
+ found_fp32_128 && found_fp32_256 &&
2348
+ coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) {
2349
+ device->coopmat2 = true;
2350
+ }
2351
+ }
2352
+ #endif
2353
+ }
2354
+
1645
2355
  if (!vk11_features.storageBuffer16BitAccess) {
1646
2356
  std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." << std::endl;
1647
2357
  throw std::runtime_error("Unsupported device");
@@ -1656,6 +2366,74 @@ static vk_device ggml_vk_get_device(size_t idx) {
1656
2366
  if (device->fp16) {
1657
2367
  device_extensions.push_back("VK_KHR_shader_float16_int8");
1658
2368
  }
2369
+
2370
+ if (device->coopmat_support) {
2371
+ // Query supported shapes
2372
+ std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
2373
+
2374
+ PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR =
2375
+ (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
2376
+
2377
+ uint32_t cm_props_num;
2378
+
2379
+ pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr);
2380
+
2381
+ cm_props.resize(cm_props_num);
2382
+
2383
+ for (auto& prop : cm_props) {
2384
+ prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
2385
+ }
2386
+
2387
+ pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data());
2388
+
2389
+ VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size());
2390
+
2391
+ for (auto& prop : cm_props) {
2392
+ VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope));
2393
+
2394
+ if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 &&
2395
+ (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 &&
2396
+ (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
2397
+ ) {
2398
+ if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 &&
2399
+ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) {
2400
+ // coopmat sizes not set yet
2401
+ if (device->coopmat_m == 0) {
2402
+ device->coopmat_acc_f32_support = true;
2403
+ device->coopmat_m = prop.MSize;
2404
+ device->coopmat_n = prop.NSize;
2405
+ device->coopmat_k = prop.KSize;
2406
+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
2407
+ // Only enable if shape is identical
2408
+ device->coopmat_acc_f32_support = true;
2409
+ }
2410
+ } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
2411
+ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
2412
+ // coopmat sizes not set yet
2413
+ if (device->coopmat_m == 0) {
2414
+ device->coopmat_acc_f16_support = true;
2415
+ device->coopmat_m = prop.MSize;
2416
+ device->coopmat_n = prop.NSize;
2417
+ device->coopmat_k = prop.KSize;
2418
+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
2419
+ // Only enable if shape is identical
2420
+ device->coopmat_acc_f16_support = true;
2421
+ }
2422
+ }
2423
+ }
2424
+ }
2425
+
2426
+ if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
2427
+ // No suitable matmul mode found
2428
+ GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
2429
+ device->coopmat_support = false;
2430
+ }
2431
+ }
2432
+
2433
+ if (device->coopmat_support) {
2434
+ device_extensions.push_back("VK_KHR_cooperative_matrix");
2435
+ }
2436
+
1659
2437
  device->name = GGML_VK_NAME + std::to_string(idx);
1660
2438
 
1661
2439
  device_create_info = {
@@ -1671,6 +2449,37 @@ static vk_device ggml_vk_get_device(size_t idx) {
1671
2449
  ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false);
1672
2450
 
1673
2451
  // Shaders
2452
+ // Disable matmul tile sizes early if performance low or not supported
2453
+ switch (device->vendor_id) {
2454
+ #ifndef GGML_VULKAN_RUN_TESTS
2455
+ case VK_VENDOR_ID_AMD:
2456
+ case VK_VENDOR_ID_INTEL:
2457
+ device->mul_mat_l = false;
2458
+ device->mul_mat_m = true;
2459
+ device->mul_mat_s = true;
2460
+ device->mul_mat_id_l = false;
2461
+ device->mul_mat_id_m = true;
2462
+ device->mul_mat_id_s = true;
2463
+ break;
2464
+ case VK_VENDOR_ID_APPLE:
2465
+ device->mul_mat_l = false;
2466
+ device->mul_mat_m = true;
2467
+ device->mul_mat_s = false;
2468
+ device->mul_mat_id_l = false;
2469
+ device->mul_mat_id_m = true;
2470
+ device->mul_mat_id_s = false;
2471
+ break;
2472
+ #endif
2473
+ default:
2474
+ device->mul_mat_l = true;
2475
+ device->mul_mat_m = true;
2476
+ device->mul_mat_s = true;
2477
+ device->mul_mat_id_l = true;
2478
+ device->mul_mat_id_m = true;
2479
+ device->mul_mat_id_s = true;
2480
+ break;
2481
+ }
2482
+
1674
2483
  ggml_vk_load_shaders(device);
1675
2484
 
1676
2485
  if (!device->single_queue) {
@@ -1728,15 +2537,31 @@ static void ggml_vk_print_gpu_info(size_t idx) {
1728
2537
 
1729
2538
  bool fp16_storage = false;
1730
2539
  bool fp16_compute = false;
2540
+ bool coopmat_support = false;
2541
+ bool coopmat2_support = false;
1731
2542
 
1732
2543
  for (auto properties : ext_props) {
1733
2544
  if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
1734
2545
  fp16_storage = true;
1735
2546
  } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
1736
2547
  fp16_compute = true;
2548
+ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
2549
+ !getenv("GGML_VK_DISABLE_COOPMAT")) {
2550
+ coopmat_support = true;
2551
+ #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
2552
+ } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
2553
+ !getenv("GGML_VK_DISABLE_COOPMAT2")) {
2554
+ coopmat2_support = true;
2555
+ #endif
1737
2556
  }
1738
2557
  }
1739
2558
 
2559
+ if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
2560
+ // Intel drivers don't support coopmat properly yet
2561
+ // Only RADV supports coopmat properly on AMD
2562
+ coopmat_support = false;
2563
+ }
2564
+
1740
2565
  const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
1741
2566
  bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
1742
2567
 
@@ -1759,16 +2584,33 @@ static void ggml_vk_print_gpu_info(size_t idx) {
1759
2584
  vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
1760
2585
  vk11_features.pNext = &vk12_features;
1761
2586
 
2587
+ // Pointer to the last chain element
2588
+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
2589
+
2590
+ VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
2591
+ coopmat_features.pNext = nullptr;
2592
+ coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
2593
+ coopmat_features.cooperativeMatrix = VK_FALSE;
2594
+
2595
+ if (coopmat_support) {
2596
+ last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
2597
+ last_struct = (VkBaseOutStructure *)&coopmat_features;
2598
+ }
2599
+
1762
2600
  vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
1763
2601
 
1764
2602
  fp16 = fp16 && vk12_features.shaderFloat16;
1765
2603
 
2604
+ coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
2605
+
2606
+ std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
2607
+
1766
2608
  std::string device_name = props2.properties.deviceName.data();
1767
- GGML_LOG_DEBUG("ggml_vulkan: %d = %s (%s) | uma: %d | fp16: %d | warp size: %d\n",
1768
- idx, device_name.c_str(), driver_props.driverName, uma, fp16, subgroup_size);
2609
+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n",
2610
+ idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str());
1769
2611
 
1770
2612
  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
1771
- std::cerr << "ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want." << std::endl;
2613
+ GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
1772
2614
  }
1773
2615
  }
1774
2616
 
@@ -1937,8 +2779,7 @@ void ggml_vk_instance_init() {
1937
2779
  vk_instance.device_indices.push_back(0);
1938
2780
  }
1939
2781
  }
1940
- GGML_LOG_DEBUG("ggml_vulkan: Found %d Vulkan devices:\n", vk_instance.device_indices.size());
1941
-
2782
+ GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size());
1942
2783
 
1943
2784
  for (size_t i = 0; i < vk_instance.device_indices.size(); i++) {
1944
2785
  ggml_vk_print_gpu_info(i);
@@ -1994,7 +2835,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
1994
2835
  return ctx->device->pipeline_dequant[type];
1995
2836
  }
1996
2837
 
1997
- static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
2838
+ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
1998
2839
  VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
1999
2840
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2000
2841
  return ctx->device->pipeline_matmul_f32;
@@ -2002,14 +2843,23 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2002
2843
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
2003
2844
  return ctx->device->pipeline_matmul_f32_f16;
2004
2845
  }
2005
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2006
- return ctx->device->pipeline_matmul_f16_f32.f32acc;
2007
- }
2008
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2009
- return ctx->device->pipeline_matmul_f16.f32acc;
2846
+ if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
2847
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2848
+ return ctx->device->pipeline_matmul_f16_f32.f16acc;
2849
+ }
2850
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2851
+ return ctx->device->pipeline_matmul_f16.f16acc;
2852
+ }
2853
+ } else {
2854
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2855
+ return ctx->device->pipeline_matmul_f16_f32.f32acc;
2856
+ }
2857
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2858
+ return ctx->device->pipeline_matmul_f16.f32acc;
2859
+ }
2010
2860
  }
2011
2861
 
2012
- if (src1_type != GGML_TYPE_F32) {
2862
+ if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
2013
2863
  return nullptr;
2014
2864
  }
2015
2865
 
@@ -2030,7 +2880,11 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2030
2880
  return nullptr;
2031
2881
  }
2032
2882
 
2033
- return ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
2883
+ if (ctx->device->coopmat2) {
2884
+ assert(src1_type == GGML_TYPE_F16);
2885
+ return ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc;
2886
+ }
2887
+ return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
2034
2888
  }
2035
2889
 
2036
2890
  static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@@ -2059,16 +2913,25 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
2059
2913
  return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
2060
2914
  }
2061
2915
 
2062
- static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type) {
2916
+ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
2063
2917
  VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()");
2064
2918
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
2065
2919
  return ctx->device->pipeline_matmul_id_f32;
2066
2920
  }
2067
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2068
- return ctx->device->pipeline_matmul_id_f16_f32;
2069
- }
2070
- if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2071
- return ctx->device->pipeline_matmul_id_f16;
2921
+ if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
2922
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2923
+ return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
2924
+ }
2925
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2926
+ return ctx->device->pipeline_matmul_id_f16.f16acc;
2927
+ }
2928
+ } else {
2929
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
2930
+ return ctx->device->pipeline_matmul_id_f16_f32.f32acc;
2931
+ }
2932
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
2933
+ return ctx->device->pipeline_matmul_id_f16.f32acc;
2934
+ }
2072
2935
  }
2073
2936
 
2074
2937
  GGML_ASSERT(src1_type == GGML_TYPE_F32);
@@ -2090,7 +2953,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
2090
2953
  return nullptr;
2091
2954
  }
2092
2955
 
2093
- return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type];
2956
+ return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc;
2094
2957
  }
2095
2958
 
2096
2959
  static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
@@ -2659,55 +3522,44 @@ static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, siz
2659
3522
  dst->device->device.resetFences({ dst->device->fence });
2660
3523
  }
2661
3524
 
2662
- static uint32_t ggml_vk_guess_split_k(int m, int n, int k) {
3525
+ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int n, int k, const vk_pipeline& pipeline) {
2663
3526
  VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ")");
2664
- // if (k > 128 && (m < 128 || n < 128) && m > 2 && n > 2) {
2665
- // return 4;
2666
- // }
2667
3527
 
2668
- return 1;
2669
-
2670
- GGML_UNUSED(m); GGML_UNUSED(n); GGML_UNUSED(k);
2671
- }
2672
-
2673
- static vk_pipeline ggml_vk_guess_matmul_pipeline_amd(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
2674
- if (m <= 32 || n <= 32) {
2675
- return aligned ? mmp->a_s : mmp->s;
3528
+ uint32_t split_k = 1;
3529
+ if (ctx->device->shader_core_count != 0 && m >= (int)pipeline->wg_denoms[0] && n >= (int)pipeline->wg_denoms[1]) {
3530
+ // If k is 'large' and the SMs will fill less than halfway, use split_k.
3531
+ uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]);
3532
+ uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]);
3533
+ if (k >= 2048 && m_tiles * n_tiles < ctx->device->shader_core_count / 2) {
3534
+ split_k = ctx->device->shader_core_count / (m_tiles * n_tiles);
3535
+ // Clamp to 2 or 4
3536
+ split_k = std::min(split_k, 4u);
3537
+ if (split_k == 3) {
3538
+ split_k = 2;
3539
+ }
3540
+ }
2676
3541
  }
2677
- return aligned ? mmp->a_m : mmp->m;
2678
-
2679
- GGML_UNUSED(ctx);
2680
- }
2681
3542
 
2682
- static vk_pipeline ggml_vk_guess_matmul_pipeline_apple(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
2683
- return aligned ? mmp->a_m : mmp->m;
2684
-
2685
- GGML_UNUSED(ctx);
2686
- }
2687
-
2688
- static vk_pipeline ggml_vk_guess_matmul_pipeline_intel(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, bool aligned) {
2689
- return aligned ? mmp->a_s : mmp->s;
2690
-
2691
- GGML_UNUSED(ctx);
3543
+ return split_k;
2692
3544
  }
2693
3545
 
2694
3546
  static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
2695
3547
  VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
2696
- switch (ctx->device->vendor_id) {
2697
- case VK_VENDOR_ID_AMD:
2698
- return ggml_vk_guess_matmul_pipeline_amd(ctx, mmp, m, n, aligned);
2699
- case VK_VENDOR_ID_APPLE:
2700
- return ggml_vk_guess_matmul_pipeline_apple(ctx, mmp, aligned);
2701
- case VK_VENDOR_ID_INTEL:
2702
- return ggml_vk_guess_matmul_pipeline_intel(ctx, mmp, aligned);
2703
- default:
2704
- break;
3548
+
3549
+ if (ctx->device->coopmat2) {
3550
+ if ((ctx->device->mul_mat_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_s)) {
3551
+ return aligned ? mmp->a_l : mmp->l;
3552
+ }
3553
+ if ((ctx->device->mul_mat_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s) {
3554
+ return aligned ? mmp->a_m : mmp->m;
3555
+ }
3556
+ return aligned ? mmp->a_s : mmp->s;
2705
3557
  }
2706
3558
 
2707
- if (m <= 32 || n <= 32) {
3559
+ if ((ctx->device->mul_mat_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m && !ctx->device->mul_mat_l)) {
2708
3560
  return aligned ? mmp->a_s : mmp->s;
2709
3561
  }
2710
- if (m <= 64 || n <= 64) {
3562
+ if ((ctx->device->mul_mat_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l) {
2711
3563
  return aligned ? mmp->a_m : mmp->m;
2712
3564
  }
2713
3565
  return aligned ? mmp->a_l : mmp->l;
@@ -2742,6 +3594,33 @@ static void ggml_vk_matmul(
2742
3594
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
2743
3595
  }
2744
3596
 
3597
+ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned) {
3598
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ")");
3599
+
3600
+ if (ctx->device->coopmat2) {
3601
+ if ((ctx->device->mul_mat_id_l && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_s)) {
3602
+ return aligned ? mmp->a_l : mmp->l;
3603
+ }
3604
+ if ((ctx->device->mul_mat_id_m && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s) {
3605
+ return aligned ? mmp->a_m : mmp->m;
3606
+ }
3607
+ return aligned ? mmp->a_s : mmp->s;
3608
+ }
3609
+
3610
+ if ((ctx->device->mul_mat_id_s && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m && !ctx->device->mul_mat_id_l)) {
3611
+ return aligned ? mmp->a_s : mmp->s;
3612
+ }
3613
+ if ((ctx->device->mul_mat_id_m && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l) {
3614
+ return aligned ? mmp->a_m : mmp->m;
3615
+ }
3616
+ return aligned ? mmp->a_l : mmp->l;
3617
+ }
3618
+
3619
+ static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n) {
3620
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ")");
3621
+ return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true)->align;
3622
+ }
3623
+
2745
3624
  static void ggml_vk_matmul_id(
2746
3625
  ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline,
2747
3626
  vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
@@ -2812,13 +3691,15 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
2812
3691
  elements = { ne, 1, 1 };
2813
3692
  }
2814
3693
 
2815
- const vk_op_unary_push_constants pc = {
3694
+ vk_op_unary_push_constants pc = {
2816
3695
  (uint32_t)ne,
2817
3696
  (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size,
2818
3697
  (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]),
2819
3698
  0,
2820
3699
  0.0f, 0.0f,
3700
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
2821
3701
  };
3702
+ init_pushconst_fastdiv(pc);
2822
3703
  ggml_vk_sync_buffers(subctx);
2823
3704
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
2824
3705
  }
@@ -2867,18 +3748,20 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
2867
3748
  }
2868
3749
 
2869
3750
  const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
2870
- const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
3751
+ // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
3752
+ const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
3753
+ !ggml_vk_dim01_contiguous(src1);
2871
3754
 
2872
3755
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
2873
3756
 
2874
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
3757
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
2875
3758
 
2876
3759
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
2877
3760
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
2878
3761
 
2879
3762
  if (qx_needs_dequant) {
2880
3763
  // Fall back to dequant + f16 mulmat
2881
- mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16);
3764
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
2882
3765
  }
2883
3766
 
2884
3767
  // Not implemented
@@ -2891,10 +3774,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
2891
3774
  const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11));
2892
3775
  const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
2893
3776
 
2894
- const uint32_t split_k = ggml_vk_guess_split_k(ne01, ne11, ne10);
2895
-
2896
3777
  vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned);
2897
3778
 
3779
+ const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
3780
+
2898
3781
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
2899
3782
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
2900
3783
  const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -2920,7 +3803,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
2920
3803
  if (dryrun) {
2921
3804
  const uint64_t x_sz_upd = x_sz * ne02 * ne03;
2922
3805
  const uint64_t y_sz_upd = y_sz * ne12 * ne13;
2923
- const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * 4 : 0;
3806
+ const uint64_t split_k_size = split_k > 1 ? d_sz * ne12 * ne13 * split_k : 0;
2924
3807
  if (
2925
3808
  (qx_needs_dequant && x_sz_upd > ctx->device->max_memory_allocation_size) ||
2926
3809
  (qy_needs_dequant && y_sz_upd > ctx->device->max_memory_allocation_size) ||
@@ -3187,7 +4070,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
3187
4070
 
3188
4071
  if (ne01 > max_groups_x) {
3189
4072
  groups_z = 64;
3190
- groups_x /= groups_z;
4073
+ groups_x = CEIL_DIV(groups_x, groups_z);
3191
4074
  }
3192
4075
 
3193
4076
  // compute
@@ -3442,7 +4325,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
3442
4325
 
3443
4326
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
3444
4327
 
3445
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type);
4328
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
3446
4329
 
3447
4330
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
3448
4331
  const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
@@ -3458,10 +4341,10 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
3458
4341
  const uint64_t y_ne = ne11 * ne10;
3459
4342
  const uint64_t d_ne = ne21 * ne20;
3460
4343
 
3461
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, nei1));
4344
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1));
3462
4345
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
3463
4346
 
3464
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, nei1, aligned);
4347
+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned);
3465
4348
 
3466
4349
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
3467
4350
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
@@ -3764,7 +4647,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
3764
4647
 
3765
4648
  if (ne01 > max_groups_x) {
3766
4649
  groups_z = 64;
3767
- groups_x /= groups_z;
4650
+ groups_x = CEIL_DIV(groups_x, groups_z);
3768
4651
  }
3769
4652
 
3770
4653
  // compute
@@ -3789,6 +4672,167 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
3789
4672
  }
3790
4673
  }
3791
4674
 
4675
+ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
4676
+ VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
4677
+ std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
4678
+ std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3];
4679
+ std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
4680
+ std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
4681
+
4682
+ GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
4683
+ GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
4684
+ GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
4685
+ GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
4686
+ GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
4687
+ GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
4688
+ GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
4689
+ GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
4690
+
4691
+ const uint32_t nem1 = mask ? mask->ne[1] : 0;
4692
+ const uint32_t nbm1 = mask ? mask->nb[1] : 0;
4693
+
4694
+ const uint32_t D = neq0;
4695
+ const uint32_t N = neq1;
4696
+ const uint32_t KV = nek1;
4697
+
4698
+ GGML_ASSERT(ne0 == D);
4699
+ GGML_ASSERT(ne2 == N);
4700
+
4701
+ // input tensor rows must be contiguous
4702
+ GGML_ASSERT(nbq0 == ggml_type_size(q->type));
4703
+ GGML_ASSERT(nbk0 == ggml_type_size(k->type));
4704
+ GGML_ASSERT(nbv0 == ggml_type_size(v->type));
4705
+
4706
+ GGML_ASSERT(neq0 == D);
4707
+ GGML_ASSERT(nek0 == D);
4708
+ GGML_ASSERT(nev0 == D);
4709
+
4710
+ GGML_ASSERT(neq1 == N);
4711
+ GGML_ASSERT(nev0 == D);
4712
+
4713
+ GGML_ASSERT(nev1 == nek1);
4714
+
4715
+ // dst cannot be transposed or permuted
4716
+ GGML_ASSERT(nb0 == sizeof(float));
4717
+ GGML_ASSERT(nb0 <= nb1);
4718
+ GGML_ASSERT(nb1 <= nb2);
4719
+ GGML_ASSERT(nb2 <= nb3);
4720
+
4721
+ assert(dst->type == GGML_TYPE_F32);
4722
+ assert(q->type == GGML_TYPE_F32);
4723
+ assert(k->type == v->type);
4724
+
4725
+ vk_pipeline *pipelines;
4726
+ // XXX TODO other backends may be changing accumulator precision to default to f32 soon
4727
+ bool f32acc = dst->op_params[3] == GGML_PREC_F32;
4728
+ bool small_rows = N <= flash_attention_num_small_rows;
4729
+ switch (D) {
4730
+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
4731
+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
4732
+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
4733
+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
4734
+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
4735
+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
4736
+ default:
4737
+ assert(!"unsupported D value");
4738
+ return;
4739
+ }
4740
+ assert(pipelines);
4741
+
4742
+ bool aligned = (KV % pipelines[1]->align) == 0;
4743
+ vk_pipeline pipeline = pipelines[aligned];
4744
+ assert(pipeline);
4745
+
4746
+ if (dryrun) {
4747
+ // Request descriptor sets
4748
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
4749
+ return;
4750
+ }
4751
+
4752
+ float scale = 1.0f;
4753
+ float max_bias = 0.0f;
4754
+ float logit_softcap = 0.0f;
4755
+
4756
+ memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
4757
+ memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
4758
+ memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float));
4759
+
4760
+ if (logit_softcap != 0) {
4761
+ scale /= logit_softcap;
4762
+ }
4763
+
4764
+ const uint32_t n_head_kv = neq2;
4765
+ const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
4766
+ const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
4767
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
4768
+
4769
+ ggml_vk_sync_buffers(subctx);
4770
+
4771
+ vk_buffer d_Q, d_K, d_V, d_D, d_M;
4772
+ uint64_t q_buf_offset, k_buf_offset, v_buf_offset, d_buf_offset, m_buf_offset;
4773
+
4774
+ bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
4775
+
4776
+ if (ctx->device->uma) {
4777
+ ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
4778
+ ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
4779
+ ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
4780
+ ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
4781
+ Q_uma = d_Q != nullptr;
4782
+ K_uma = d_K != nullptr;
4783
+ V_uma = d_V != nullptr;
4784
+ D_uma = d_D != nullptr;
4785
+ if (mask) {
4786
+ ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
4787
+ M_uma = d_M != nullptr;
4788
+ }
4789
+ }
4790
+
4791
+
4792
+ ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
4793
+ ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context;
4794
+ ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
4795
+ ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
4796
+
4797
+ if (!Q_uma) {
4798
+ d_Q = q_buf_ctx->dev_buffer;
4799
+ q_buf_offset = vk_tensor_offset(q) + q->view_offs;
4800
+ }
4801
+ if (!K_uma) {
4802
+ d_K = k_buf_ctx->dev_buffer;
4803
+ k_buf_offset = vk_tensor_offset(k) + k->view_offs;
4804
+ }
4805
+ if (!V_uma) {
4806
+ d_V = v_buf_ctx->dev_buffer;
4807
+ v_buf_offset = vk_tensor_offset(v) + v->view_offs;
4808
+ }
4809
+ if (!D_uma) {
4810
+ d_D = d_buf_ctx->dev_buffer;
4811
+ d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
4812
+ }
4813
+
4814
+ if (!M_uma) {
4815
+ d_M = d_Q;
4816
+ m_buf_offset = q_buf_offset;
4817
+ if (mask) {
4818
+ ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context;
4819
+ d_M = m_buf_ctx->dev_buffer;
4820
+ m_buf_offset = vk_tensor_offset(mask) + mask->view_offs;
4821
+ }
4822
+ }
4823
+
4824
+ const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
4825
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
4826
+ {
4827
+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
4828
+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
4829
+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
4830
+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
4831
+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
4832
+ },
4833
+ sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
4834
+ }
4835
+
3792
4836
  static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
3793
4837
  switch (op) {
3794
4838
  case GGML_OP_GET_ROWS:
@@ -3933,10 +4977,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
3933
4977
  GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
3934
4978
 
3935
4979
  if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
3936
- return ctx->device->pipeline_soft_max_f32;
4980
+ return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32;
3937
4981
  }
3938
4982
  if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
3939
- return ctx->device->pipeline_soft_max_f32_f16;
4983
+ return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
3940
4984
  }
3941
4985
  return nullptr;
3942
4986
  case GGML_OP_ROPE:
@@ -3989,6 +5033,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
3989
5033
  return ctx->device->pipeline_pool2d_f32;
3990
5034
  }
3991
5035
  return nullptr;
5036
+ case GGML_OP_RWKV_WKV6:
5037
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5038
+ return ctx->device->pipeline_rwkv_wkv6_f32;
5039
+ }
5040
+ return nullptr;
3992
5041
  case GGML_OP_LEAKY_RELU:
3993
5042
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
3994
5043
  return ctx->device->pipeline_leaky_relu_f32;
@@ -4023,7 +5072,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
4023
5072
  }
4024
5073
 
4025
5074
  template<typename PC>
4026
- static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc, bool dryrun = false) {
5075
+ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
4027
5076
  VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
4028
5077
  if (src1 != nullptr) {
4029
5078
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -4063,6 +5112,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
4063
5112
  const uint64_t ned3 = dst->ne[3];
4064
5113
  const uint64_t ned = ned0 * ned1;
4065
5114
 
5115
+ init_pushconst_fastdiv(pc);
5116
+
4066
5117
  vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
4067
5118
 
4068
5119
  if (pipeline == nullptr) {
@@ -4389,6 +5440,134 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
4389
5440
  }, dryrun);
4390
5441
  }
4391
5442
 
5443
+ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
5444
+ const ggml_tensor * k = dst->src[0];
5445
+ const ggml_tensor * v = dst->src[1];
5446
+ const ggml_tensor * r = dst->src[2];
5447
+ const ggml_tensor * tf = dst->src[3];
5448
+ const ggml_tensor * td = dst->src[4];
5449
+ const ggml_tensor * state = dst->src[5];
5450
+
5451
+ GGML_ASSERT(!ggml_is_quantized(k->type));
5452
+ GGML_ASSERT(!ggml_is_quantized(v->type));
5453
+ GGML_ASSERT(!ggml_is_quantized(r->type));
5454
+ GGML_ASSERT(!ggml_is_quantized(tf->type));
5455
+ GGML_ASSERT(!ggml_is_quantized(td->type));
5456
+ GGML_ASSERT(!ggml_is_quantized(state->type));
5457
+ GGML_ASSERT(dst->buffer != nullptr);
5458
+
5459
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
5460
+ GGML_ASSERT(pipeline != nullptr);
5461
+
5462
+ if (dryrun) {
5463
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
5464
+ return;
5465
+ }
5466
+
5467
+ ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
5468
+ ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
5469
+ ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
5470
+ ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
5471
+ ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
5472
+ ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
5473
+ ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
5474
+
5475
+ ggml_vk_sync_buffers(subctx);
5476
+
5477
+ vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
5478
+ uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
5479
+ bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
5480
+
5481
+ if (ctx->device->uma) {
5482
+ ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
5483
+ ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
5484
+ ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
5485
+ ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
5486
+ ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
5487
+ ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
5488
+ ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
5489
+
5490
+ K_uma = d_K != nullptr;
5491
+ V_uma = d_V != nullptr;
5492
+ R_uma = d_R != nullptr;
5493
+ TF_uma = d_TF != nullptr;
5494
+ TD_uma = d_TD != nullptr;
5495
+ STATE_uma = d_State != nullptr;
5496
+ DST_uma = d_D != nullptr;
5497
+ }
5498
+
5499
+ if (!K_uma) {
5500
+ d_K = k_buf_ctx->dev_buffer;
5501
+ k_offset = vk_tensor_offset(k) + k->view_offs;
5502
+ }
5503
+ if (!V_uma) {
5504
+ d_V = v_buf_ctx->dev_buffer;
5505
+ v_offset = vk_tensor_offset(v) + v->view_offs;
5506
+ }
5507
+ if (!R_uma) {
5508
+ d_R = r_buf_ctx->dev_buffer;
5509
+ r_offset = vk_tensor_offset(r) + r->view_offs;
5510
+ }
5511
+ if (!TF_uma) {
5512
+ d_TF = tf_buf_ctx->dev_buffer;
5513
+ tf_offset = vk_tensor_offset(tf) + tf->view_offs;
5514
+ }
5515
+ if (!TD_uma) {
5516
+ d_TD = td_buf_ctx->dev_buffer;
5517
+ td_offset = vk_tensor_offset(td) + td->view_offs;
5518
+ }
5519
+ if (!STATE_uma) {
5520
+ d_State = state_buf_ctx->dev_buffer;
5521
+ state_offset = vk_tensor_offset(state) + state->view_offs;
5522
+ }
5523
+ if (!DST_uma) {
5524
+ d_D = dst_buf_ctx->dev_buffer;
5525
+ dst_offset = vk_tensor_offset(dst) + dst->view_offs;
5526
+ }
5527
+
5528
+ const uint64_t k_size = ggml_nbytes(k);
5529
+ const uint64_t v_size = ggml_nbytes(v);
5530
+ const uint64_t r_size = ggml_nbytes(r);
5531
+ const uint64_t tf_size = ggml_nbytes(tf);
5532
+ const uint64_t td_size = ggml_nbytes(td);
5533
+ const uint64_t state_size = ggml_nbytes(state);
5534
+ const uint64_t dst_size = ggml_nbytes(dst);
5535
+
5536
+ std::array<uint32_t, 3> elements = {
5537
+ (uint32_t)(pc.B * pc.H),
5538
+ 1,
5539
+ 1
5540
+ };
5541
+
5542
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
5543
+ vk_subbuffer{ d_K, k_offset, k_size },
5544
+ vk_subbuffer{ d_V, v_offset, v_size },
5545
+ vk_subbuffer{ d_R, r_offset, r_size },
5546
+ vk_subbuffer{ d_TF, tf_offset, tf_size },
5547
+ vk_subbuffer{ d_TD, td_offset, td_size },
5548
+ vk_subbuffer{ d_State, state_offset, state_size },
5549
+ vk_subbuffer{ d_D, dst_offset, dst_size }
5550
+ }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
5551
+ }
5552
+
5553
+ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
5554
+ const size_t seq_length = dst->src[0]->ne[3];
5555
+ const size_t n_embed = dst->ne[0];
5556
+ const size_t n_heads = dst->src[0]->ne[2];
5557
+ const size_t n_seqs = dst->src[5]->ne[1];
5558
+
5559
+ ggml_vk_op_f32_rwkv6(
5560
+ ctx, subctx, dst,
5561
+ {
5562
+ (uint32_t)n_seqs,
5563
+ (uint32_t)seq_length,
5564
+ (uint32_t)n_embed,
5565
+ (uint32_t)n_heads,
5566
+ },
5567
+ dryrun
5568
+ );
5569
+ }
5570
+
4392
5571
  static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
4393
5572
  int * op_params = (int *)dst->op_params;
4394
5573
 
@@ -4432,7 +5611,8 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
4432
5611
  (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
4433
5612
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4434
5613
  0,
4435
- op_params[0], 0.0f
5614
+ op_params[0], 0.0f,
5615
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4436
5616
  }, dryrun);
4437
5617
  }
4438
5618
 
@@ -4446,6 +5626,7 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const
4446
5626
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4447
5627
  0,
4448
5628
  0.0f, 0.0f,
5629
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4449
5630
  }, dryrun);
4450
5631
  }
4451
5632
 
@@ -4459,6 +5640,7 @@ static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const
4459
5640
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4460
5641
  0,
4461
5642
  0.0f, 0.0f,
5643
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4462
5644
  }, dryrun);
4463
5645
  }
4464
5646
 
@@ -4472,6 +5654,7 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
4472
5654
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4473
5655
  0,
4474
5656
  0.0f, 0.0f,
5657
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4475
5658
  }, dryrun);
4476
5659
  }
4477
5660
 
@@ -4486,6 +5669,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con
4486
5669
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4487
5670
  0,
4488
5671
  op_params[0], op_params[1],
5672
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4489
5673
  }, dryrun);
4490
5674
  }
4491
5675
 
@@ -4499,6 +5683,7 @@ static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const
4499
5683
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4500
5684
  0,
4501
5685
  0.0f, 0.0f,
5686
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4502
5687
  }, dryrun);
4503
5688
  }
4504
5689
 
@@ -4512,6 +5697,7 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
4512
5697
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4513
5698
  0,
4514
5699
  0.0f, 0.0f,
5700
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4515
5701
  }, dryrun);
4516
5702
  }
4517
5703
 
@@ -4526,6 +5712,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
4526
5712
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
4527
5713
  d_offset,
4528
5714
  0.0f, 0.0f,
5715
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
4529
5716
  }, dryrun);
4530
5717
  }
4531
5718
 
@@ -4582,6 +5769,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
4582
5769
  scale, max_bias,
4583
5770
  m0, m1,
4584
5771
  n_head_log2,
5772
+ nrows_x,
4585
5773
  }, dryrun);
4586
5774
  }
4587
5775
 
@@ -4878,19 +6066,27 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
4878
6066
  for (size_t i = 0; i < x_ne; i++) {
4879
6067
  if (std::is_same<float, X_TYPE>()) {
4880
6068
  x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
6069
+ // x[i] = 1.0f;
6070
+ // x[i] = i + 1;
6071
+ // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
4881
6072
  } else if (std::is_same<ggml_fp16_t, X_TYPE>()) {
4882
6073
  x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
6074
+ // x[i] = ggml_fp32_to_fp16(1.0f);
6075
+ // x[i] = ggml_fp32_to_fp16(i + 1);
6076
+ // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
4883
6077
  } else {
4884
6078
  GGML_ABORT("fatal error");
4885
6079
  }
4886
6080
  }
4887
6081
  for (size_t i = 0; i < y_ne; i++) {
4888
6082
  if (std::is_same<float, Y_TYPE>()) {
4889
- // y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
4890
- y[i] = (i % k == i / k) ? 1.0f : 0.0f;
6083
+ y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
6084
+ // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
6085
+ // y[i] = i + 1;
4891
6086
  } else if (std::is_same<ggml_fp16_t, Y_TYPE>()) {
4892
- // y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
4893
- y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
6087
+ y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f);
6088
+ // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f);
6089
+ // y[i] = ggml_fp32_to_fp16(i + 1);
4894
6090
  } else {
4895
6091
  GGML_ABORT("fatal error");
4896
6092
  }
@@ -4900,16 +6096,16 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
4900
6096
  ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch);
4901
6097
 
4902
6098
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
6099
+ ggml_vk_ctx_begin(ctx->device, subctx);
4903
6100
  for (size_t i = 0; i < num_it; i++) {
4904
- ggml_vk_ctx_begin(ctx->device, subctx);
4905
6101
  ggml_vk_matmul(
4906
6102
  ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
4907
6103
  m, n, k,
4908
6104
  k, k, m, k*m, k*n, m*n,
4909
6105
  split_k, batch, batch, batch, 1, 1
4910
6106
  );
4911
- ggml_vk_ctx_end(subctx);
4912
6107
  }
6108
+ ggml_vk_ctx_end(subctx);
4913
6109
 
4914
6110
  auto begin = std::chrono::high_resolution_clock::now();
4915
6111
  ggml_vk_submit(subctx, ctx->fence);
@@ -4974,7 +6170,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
4974
6170
  double err = std::fabs(d[i] - d_chk[i]);
4975
6171
  avg_err += err;
4976
6172
 
4977
- if (err > 0.05f && first_err_n == -1) {
6173
+ if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) {
4978
6174
  first_err_b = i / (m * n);
4979
6175
  first_err_n = (i % (m * n)) / m;
4980
6176
  first_err_m = (i % (m * n)) % m;
@@ -4987,12 +6183,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
4987
6183
 
4988
6184
  std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
4989
6185
 
4990
- if (avg_err > 0.1) {
6186
+ if (avg_err > 0.1 || std::isnan(avg_err)) {
4991
6187
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
4992
6188
  std::cerr << "Actual result: " << std::endl << std::endl;
4993
6189
  ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4994
- std::cerr << std::endl;
4995
- ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n + 15, first_err_b);
4996
6190
  std::cerr << "Expected result: " << std::endl << std::endl;
4997
6191
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
4998
6192
 
@@ -5175,13 +6369,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5175
6369
  vk_pipeline p;
5176
6370
  std::string shname;
5177
6371
  if (shader_size == 0) {
5178
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
6372
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
5179
6373
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
5180
6374
  } else if (shader_size == 1) {
5181
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
6375
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
5182
6376
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
5183
6377
  } else if (shader_size == 2) {
5184
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
6378
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
5185
6379
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
5186
6380
  } else {
5187
6381
  GGML_ASSERT(0);
@@ -5191,13 +6385,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5191
6385
 
5192
6386
  if (k != kpad) {
5193
6387
  if (shader_size == 0) {
5194
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
6388
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
5195
6389
  shname = std::string(ggml_type_name(quant)) + "_S";
5196
6390
  } else if (shader_size == 1) {
5197
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
6391
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
5198
6392
  shname = std::string(ggml_type_name(quant)) + "_M";
5199
6393
  } else if (shader_size == 2) {
5200
- p = ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
6394
+ p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
5201
6395
  shname = std::string(ggml_type_name(quant)) + "_L";
5202
6396
  } else {
5203
6397
  GGML_ASSERT(0);
@@ -5247,16 +6441,16 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5247
6441
  ggml_vk_buffer_write(y_buf, 0, y, y_sz);
5248
6442
 
5249
6443
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
6444
+ ggml_vk_ctx_begin(ctx->device, subctx);
5250
6445
  for (size_t i = 0; i < num_it; i++) {
5251
- ggml_vk_ctx_begin(ctx->device, subctx);
5252
6446
  ggml_vk_matmul(
5253
6447
  ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
5254
6448
  m, n, k,
5255
6449
  k, k, m, k*m, k*n, m*n,
5256
6450
  split_k, batch, batch, batch, 1, 1
5257
6451
  );
5258
- ggml_vk_ctx_end(subctx);
5259
6452
  }
6453
+ ggml_vk_ctx_end(subctx);
5260
6454
 
5261
6455
  auto begin = std::chrono::high_resolution_clock::now();
5262
6456
 
@@ -5356,105 +6550,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
5356
6550
 
5357
6551
  static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
5358
6552
  #if defined(GGML_VULKAN_RUN_TESTS)
5359
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_F32);
5360
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_0);
5361
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_1);
5362
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_0);
5363
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_1);
5364
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q8_0);
5365
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q2_K);
5366
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q3_K);
5367
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q4_K);
5368
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q5_K);
5369
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_Q6_K);
5370
- ggml_vk_test_dequant(ctx, 7680, GGML_TYPE_IQ4_NL);
5371
-
5372
- ggml_vk_test_matmul<ggml_fp16_t, ggml_fp16_t>(ctx, 512, 512, 100, 32, 100, 1, 2);
5373
-
5374
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 0);
5375
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 1);
5376
- ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 1, 2);
5377
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 0);
5378
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 1);
5379
- // ggml_vk_test_matmul<float, float>(ctx, 128, 512, 512, 2, 100, 4, 2);
5380
-
5381
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_0);
5382
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_0);
5383
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_0);
5384
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_0);
5385
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_0);
5386
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_0);
5387
-
5388
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_1);
5389
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_1);
5390
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_1);
5391
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_1);
5392
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_1);
5393
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_1);
5394
-
5395
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_0);
5396
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_0);
5397
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_0);
5398
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_0);
5399
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_0);
5400
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_0);
5401
-
5402
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_1);
5403
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_1);
5404
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_1);
5405
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_1);
5406
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_1);
5407
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_1);
5408
-
5409
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q8_0);
5410
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q8_0);
5411
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q8_0);
5412
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q8_0);
5413
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q8_0);
5414
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q8_0);
5415
-
5416
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q2_K);
5417
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q2_K);
5418
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q2_K);
5419
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q2_K);
5420
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q2_K);
5421
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q2_K);
5422
-
5423
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q3_K);
5424
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q3_K);
5425
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q3_K);
5426
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q3_K);
5427
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q3_K);
5428
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q3_K);
5429
-
5430
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q4_K);
5431
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q4_K);
5432
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q4_K);
5433
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q4_K);
5434
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q4_K);
5435
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q4_K);
5436
-
5437
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q5_K);
5438
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q5_K);
5439
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q5_K);
5440
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q5_K);
5441
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q5_K);
5442
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q5_K);
5443
-
5444
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_Q6_K);
5445
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_Q6_K);
5446
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_Q6_K);
5447
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 0, GGML_TYPE_Q6_K);
5448
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 1, GGML_TYPE_Q6_K);
5449
- // ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 4, 2, GGML_TYPE_Q6_K);
5450
-
5451
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 0, GGML_TYPE_IQ4_NL);
5452
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 1, GGML_TYPE_IQ4_NL);
5453
- ggml_vk_test_dequant_matmul(ctx, 128, 512, 512, 2, 100, 1, 2, GGML_TYPE_IQ4_NL);
5454
-
5455
- std::cerr << std::endl;
5456
-
5457
6553
  const std::vector<size_t> vals {
6554
+ 512, 512, 128,
6555
+ 128, 512, 512,
6556
+ 4096, 512, 4096,
6557
+ 11008, 512, 4096,
6558
+ 4096, 512, 11008,
6559
+ 32000, 512, 4096,
5458
6560
  8, 8, 8,
5459
6561
  100, 46, 576,
5460
6562
  623, 111, 128,
@@ -5467,25 +6569,52 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
5467
6569
  49, 49, 128,
5468
6570
  128, 49, 49,
5469
6571
  4096, 49, 4096,
5470
- 11008, 49, 4096,
5471
- 4096, 49, 11008,
5472
- 32000, 49, 4096,
5473
- 512, 512, 128,
5474
- 128, 512, 512,
5475
- 4096, 512, 4096,
5476
- 11008, 512, 4096,
5477
- 4096, 512, 11008,
5478
- 32000, 512, 4096,
5479
6572
  };
5480
- const size_t num_it = 1;
6573
+ const size_t num_it = 100;
6574
+
5481
6575
  for (size_t i = 0; i < vals.size(); i += 3) {
5482
6576
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
5483
6577
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
5484
6578
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2);
5485
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
5486
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
5487
- // ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
5488
- std::cerr << std::endl;
6579
+ std::cerr << '\n';
6580
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0);
6581
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1);
6582
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2);
6583
+ std::cerr << '\n';
6584
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0);
6585
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1);
6586
+ ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2);
6587
+ std::cerr << '\n' << std::endl;
6588
+
6589
+ if (vals[i + 2] % 32 == 0) {
6590
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0);
6591
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0);
6592
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0);
6593
+ std::cerr << '\n';
6594
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0);
6595
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0);
6596
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0);
6597
+ std::cerr << '\n';
6598
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0);
6599
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0);
6600
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0);
6601
+ std::cerr << '\n' << std::endl;
6602
+ }
6603
+
6604
+ if (vals[i + 2] % 256 == 0) {
6605
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K);
6606
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K);
6607
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K);
6608
+ std::cerr << '\n';
6609
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K);
6610
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K);
6611
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K);
6612
+ std::cerr << '\n';
6613
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K);
6614
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K);
6615
+ ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K);
6616
+ std::cerr << '\n' << std::endl;
6617
+ }
5489
6618
  }
5490
6619
 
5491
6620
  GGML_ABORT("fatal error");
@@ -5532,6 +6661,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5532
6661
  const ggml_tensor * src0 = node->src[0];
5533
6662
  const ggml_tensor * src1 = node->src[1];
5534
6663
  const ggml_tensor * src2 = node->src[2];
6664
+ const ggml_tensor * src3 = node->src[3];
5535
6665
 
5536
6666
  switch (node->op) {
5537
6667
  // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor
@@ -5583,7 +6713,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5583
6713
  case GGML_OP_IM2COL:
5584
6714
  case GGML_OP_TIMESTEP_EMBEDDING:
5585
6715
  case GGML_OP_POOL_2D:
6716
+ case GGML_OP_RWKV_WKV6:
5586
6717
  case GGML_OP_LEAKY_RELU:
6718
+ case GGML_OP_FLASH_ATTN_EXT:
5587
6719
  break;
5588
6720
  default:
5589
6721
  std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
@@ -5601,6 +6733,48 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5601
6733
  } else {
5602
6734
  compute_ctx = ctx->compute_ctx.lock();
5603
6735
  }
6736
+ } else {
6737
+ switch (node->op) {
6738
+ case GGML_OP_REPEAT:
6739
+ case GGML_OP_ACC:
6740
+ case GGML_OP_GET_ROWS:
6741
+ case GGML_OP_ADD:
6742
+ case GGML_OP_MUL:
6743
+ case GGML_OP_DIV:
6744
+ case GGML_OP_CONCAT:
6745
+ case GGML_OP_UPSCALE:
6746
+ case GGML_OP_SCALE:
6747
+ case GGML_OP_SQR:
6748
+ case GGML_OP_SIN:
6749
+ case GGML_OP_COS:
6750
+ case GGML_OP_CLAMP:
6751
+ case GGML_OP_PAD:
6752
+ case GGML_OP_CPY:
6753
+ case GGML_OP_CONT:
6754
+ case GGML_OP_DUP:
6755
+ case GGML_OP_NORM:
6756
+ case GGML_OP_GROUP_NORM:
6757
+ case GGML_OP_RMS_NORM:
6758
+ case GGML_OP_UNARY:
6759
+ case GGML_OP_DIAG_MASK_INF:
6760
+ case GGML_OP_SOFT_MAX:
6761
+ case GGML_OP_ROPE:
6762
+ case GGML_OP_ARGSORT:
6763
+ case GGML_OP_SUM_ROWS:
6764
+ case GGML_OP_IM2COL:
6765
+ case GGML_OP_TIMESTEP_EMBEDDING:
6766
+ case GGML_OP_POOL_2D:
6767
+ case GGML_OP_LEAKY_RELU:
6768
+ {
6769
+ // These operations all go through ggml_vk_op_f32, so short-circuit and
6770
+ // do the only thing needed for the dryrun.
6771
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op);
6772
+ ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
6773
+ return false;
6774
+ }
6775
+ default:
6776
+ break;
6777
+ }
5604
6778
  }
5605
6779
 
5606
6780
  switch (node->op) {
@@ -5734,6 +6908,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
5734
6908
  case GGML_OP_MUL_MAT_ID:
5735
6909
  ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun);
5736
6910
 
6911
+ break;
6912
+
6913
+ case GGML_OP_FLASH_ATTN_EXT:
6914
+ ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node, dryrun);
6915
+
6916
+ break;
6917
+
6918
+ case GGML_OP_RWKV_WKV6:
6919
+ ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun);
6920
+
5737
6921
  break;
5738
6922
  default:
5739
6923
  return false;
@@ -5814,6 +6998,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
5814
6998
  case GGML_OP_IM2COL:
5815
6999
  case GGML_OP_TIMESTEP_EMBEDDING:
5816
7000
  case GGML_OP_POOL_2D:
7001
+ case GGML_OP_RWKV_WKV6:
5817
7002
  case GGML_OP_LEAKY_RELU:
5818
7003
  case GGML_OP_REPEAT:
5819
7004
  buf = tensor->buffer;
@@ -5834,6 +7019,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
5834
7019
  break;
5835
7020
  case GGML_OP_MUL_MAT:
5836
7021
  case GGML_OP_MUL_MAT_ID:
7022
+ case GGML_OP_FLASH_ATTN_EXT:
5837
7023
  buf = tensor->buffer;
5838
7024
 
5839
7025
  break;
@@ -6330,16 +7516,17 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
6330
7516
  bool first_node_in_batch = true; // true if next node will be first node in a batch
6331
7517
  int submit_node_idx = 0; // index to first node in a batch
6332
7518
 
6333
- // submit work every submit_count node to overlap CPU cmdbuffer generation with GPU execution
6334
- constexpr int submit_count = 100;
7519
+ // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
7520
+ // Start with a smaller count to get work submitted right away, and increase it after each submit.
7521
+ int nodes_per_submit = 20;
6335
7522
  int submitted_nodes = 0;
7523
+ int submit_count = 0;
6336
7524
  for (int i = 0; i < cgraph->n_nodes; i++) {
6337
7525
  if (first_node_in_batch) {
6338
7526
  submit_node_idx = i;
6339
7527
  }
6340
7528
 
6341
- bool submit = (submitted_nodes >= submit_count) || (i == last_node);
6342
-
7529
+ bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
6343
7530
 
6344
7531
  bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
6345
7532
 
@@ -6356,6 +7543,15 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
6356
7543
  if (submit) {
6357
7544
  first_node_in_batch = true;
6358
7545
  submitted_nodes = 0;
7546
+ switch (submit_count) {
7547
+ case 0:
7548
+ nodes_per_submit = 50;
7549
+ break;
7550
+ default:
7551
+ nodes_per_submit = 100;
7552
+ break;
7553
+ }
7554
+ submit_count++;
6359
7555
  }
6360
7556
  }
6361
7557
 
@@ -6512,6 +7708,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
6512
7708
  case GGML_OP_MUL_MAT:
6513
7709
  case GGML_OP_MUL_MAT_ID:
6514
7710
  {
7711
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7712
+ const vk_device& device = ggml_vk_get_device(ctx->device);
7713
+ if (op->op == GGML_OP_MUL_MAT_ID && !device->mul_mat_id_s && !device->mul_mat_id_m && !device->mul_mat_id_l) {
7714
+ // If there's not enough shared memory for row_ids and the result tile, fallback to CPU
7715
+ return false;
7716
+ }
6515
7717
  switch (op->src[0]->type) {
6516
7718
  case GGML_TYPE_F32:
6517
7719
  case GGML_TYPE_F16:
@@ -6549,6 +7751,57 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
6549
7751
 
6550
7752
  return true;
6551
7753
  } break;
7754
+ case GGML_OP_FLASH_ATTN_EXT:
7755
+ {
7756
+ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
7757
+ if (!ggml_vk_get_device(ctx->device)->coopmat2) {
7758
+ return false;
7759
+ }
7760
+ switch (op->src[0]->ne[0]) {
7761
+ case 64:
7762
+ case 80:
7763
+ case 96:
7764
+ case 112:
7765
+ case 128:
7766
+ case 256:
7767
+ break;
7768
+ default:
7769
+ return false;
7770
+ }
7771
+ if (op->src[0]->type != GGML_TYPE_F32) {
7772
+ return false;
7773
+ }
7774
+ if (op->type != GGML_TYPE_F32) {
7775
+ return false;
7776
+ }
7777
+ if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
7778
+ return false;
7779
+ }
7780
+ // It's straightforward to support different K/V dequant, but would
7781
+ // significantly increase the number of pipelines
7782
+ if (op->src[1]->type != op->src[2]->type) {
7783
+ return false;
7784
+ }
7785
+ switch (op->src[1]->type) {
7786
+ case GGML_TYPE_F16:
7787
+ case GGML_TYPE_Q4_0:
7788
+ case GGML_TYPE_Q4_1:
7789
+ case GGML_TYPE_Q5_0:
7790
+ case GGML_TYPE_Q5_1:
7791
+ case GGML_TYPE_Q8_0:
7792
+ // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
7793
+ //case GGML_TYPE_Q2_K:
7794
+ //case GGML_TYPE_Q3_K:
7795
+ //case GGML_TYPE_Q4_K:
7796
+ //case GGML_TYPE_Q5_K:
7797
+ //case GGML_TYPE_Q6_K:
7798
+ case GGML_TYPE_IQ4_NL:
7799
+ break;
7800
+ default:
7801
+ return false;
7802
+ }
7803
+ return true;
7804
+ }
6552
7805
  case GGML_OP_GET_ROWS:
6553
7806
  {
6554
7807
  switch (op->src[0]->type) {
@@ -6585,7 +7838,16 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
6585
7838
  case GGML_OP_REPEAT:
6586
7839
  return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float);
6587
7840
  case GGML_OP_ROPE:
6588
- return ggml_is_contiguous(op->src[0]);
7841
+ {
7842
+ const int mode = ((const int32_t *) op->op_params)[2];
7843
+ if (mode & GGML_ROPE_TYPE_MROPE) {
7844
+ return false;
7845
+ }
7846
+ if (mode & GGML_ROPE_TYPE_VISION) {
7847
+ return false;
7848
+ }
7849
+ return ggml_is_contiguous(op->src[0]);
7850
+ }
6589
7851
  case GGML_OP_NONE:
6590
7852
  case GGML_OP_RESHAPE:
6591
7853
  case GGML_OP_VIEW:
@@ -6613,6 +7875,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
6613
7875
  case GGML_OP_IM2COL:
6614
7876
  case GGML_OP_TIMESTEP_EMBEDDING:
6615
7877
  case GGML_OP_POOL_2D:
7878
+ case GGML_OP_RWKV_WKV6:
6616
7879
  case GGML_OP_LEAKY_RELU:
6617
7880
  return true;
6618
7881
  default:
@@ -6709,8 +7972,9 @@ static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
6709
7972
 
6710
7973
  ggml_backend_reg_t ggml_backend_vk_reg() {
6711
7974
  static ggml_backend_reg reg = {
6712
- /* .iface = */ ggml_backend_vk_reg_i,
6713
- /* .context = */ nullptr,
7975
+ /* .api_version = */ GGML_BACKEND_API_VERSION,
7976
+ /* .iface = */ ggml_backend_vk_reg_i,
7977
+ /* .context = */ nullptr,
6714
7978
  };
6715
7979
 
6716
7980
  return &reg;
@@ -6862,6 +8126,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
6862
8126
  ggml_tensor * src0 = tensor->src[0];
6863
8127
  ggml_tensor * src1 = tensor->src[1];
6864
8128
  ggml_tensor * src2 = tensor->src[2];
8129
+ ggml_tensor * src3 = tensor->src[3];
6865
8130
 
6866
8131
  struct ggml_init_params iparams = {
6867
8132
  /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul,
@@ -6874,15 +8139,18 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
6874
8139
  struct ggml_tensor * src0_clone = nullptr;
6875
8140
  struct ggml_tensor * src1_clone = nullptr;
6876
8141
  struct ggml_tensor * src2_clone = nullptr;
8142
+ struct ggml_tensor * src3_clone = nullptr;
6877
8143
  struct ggml_tensor * tensor_clone = nullptr;
6878
8144
 
6879
8145
  size_t src0_size;
6880
8146
  size_t src1_size;
6881
8147
  size_t src2_size;
8148
+ size_t src3_size;
6882
8149
 
6883
8150
  void * src0_buffer = nullptr;
6884
8151
  void * src1_buffer = nullptr;
6885
8152
  void * src2_buffer = nullptr;
8153
+ void * src3_buffer = nullptr;
6886
8154
 
6887
8155
  if (src0 != nullptr) {
6888
8156
  src0_clone = ggml_dup_tensor(ggml_ctx, src0);
@@ -7010,8 +8278,53 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
7010
8278
  ggml_vk_print_tensor(src2, "src2");
7011
8279
  }
7012
8280
  }
8281
+ if (src3 != nullptr) {
8282
+ src3_clone = ggml_dup_tensor(ggml_ctx, src3);
7013
8283
 
7014
- if (tensor->op == GGML_OP_MUL_MAT) {
8284
+ src3_size = ggml_nbytes(src3);
8285
+
8286
+ src3_buffer = malloc(src3_size);
8287
+ src3_clone->data = src3_buffer;
8288
+ if (ggml_backend_buffer_is_host(src3->buffer)) {
8289
+ memcpy(src3_clone->data, src3->data, src3_size);
8290
+ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
8291
+ } else if (ggml_backend_buffer_is_vk(src3->buffer)) {
8292
+ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)src3->buffer->context;
8293
+ vk_buffer& buffer_gpu = buf_ctx->dev_buffer;
8294
+ uint64_t offset = vk_tensor_offset(src3) + src3->view_offs;
8295
+ if (!ggml_is_contiguous(src3) && ggml_vk_dim01_contiguous(src3)) {
8296
+ for (int i3 = 0; i3 < src3->ne[3]; i3++) {
8297
+ for (int i2 = 0; i2 < src3->ne[2]; i2++) {
8298
+ const int idx = i3*src3->ne[2] + i2;
8299
+ ggml_vk_buffer_read(buffer_gpu, offset + idx * src3->nb[2], ((char *)src3_clone->data + idx * src3_clone->nb[2]), src3->ne[1] * src3->nb[1]);
8300
+ }
8301
+ }
8302
+
8303
+ src3_clone->nb[0] = src3->nb[0];
8304
+ src3_clone->nb[1] = src3->nb[1];
8305
+ for (int i = 2; i < GGML_MAX_DIMS; i++) {
8306
+ src3_clone->nb[i] = src3_clone->nb[i - 1]*src3_clone->ne[i - 1];
8307
+ }
8308
+ } else {
8309
+ if (offset + src3_size >= buffer_gpu->size) {
8310
+ src3_size = buffer_gpu->size - offset;
8311
+ }
8312
+ ggml_vk_buffer_read(buffer_gpu, offset, src3_clone->data, src3_size);
8313
+ memcpy(src3_clone->nb, src3->nb, sizeof(size_t) * GGML_MAX_DIMS);
8314
+ }
8315
+ } else {
8316
+ GGML_ABORT("fatal error");
8317
+ }
8318
+
8319
+ if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
8320
+ ggml_vk_print_tensor(src3, "src3");
8321
+ }
8322
+ }
8323
+
8324
+ if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
8325
+ const float *params = (const float *)tensor->op_params;
8326
+ tensor_clone = ggml_flash_attn_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, src3_clone, params[0], params[1], params[2]);
8327
+ } else if (tensor->op == GGML_OP_MUL_MAT) {
7015
8328
  tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
7016
8329
  } else if (tensor->op == GGML_OP_MUL_MAT_ID) {
7017
8330
  tensor_clone = ggml_mul_mat_id(ggml_ctx, src0_clone, src1_clone, src2_clone);
@@ -7127,7 +8440,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
7127
8440
  const int32_t max_period = tensor->op_params[1];
7128
8441
  tensor_clone = ggml_timestep_embedding(ggml_ctx, src0_clone, dim, max_period);
7129
8442
  } else if (tensor->op == GGML_OP_POOL_2D) {
7130
- enum ggml_op_pool op = static_cast<ggml_op_pool>(dst->op_params[0]);
8443
+ enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]);
7131
8444
  const int32_t k0 = tensor->op_params[1];
7132
8445
  const int32_t k1 = tensor->op_params[2];
7133
8446
  const int32_t s0 = tensor->op_params[3];
@@ -7139,7 +8452,11 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
7139
8452
  } else if (tensor->op == GGML_OP_LEAKY_RELU) {
7140
8453
  const float * op_params = (const float *)tensor->op_params;
7141
8454
  tensor_clone = ggml_leaky_relu(ggml_ctx, src0_clone, op_params[0], false);
7142
- } else {
8455
+ } else if (tensor->op == GGML_OP_RWKV_WKV6) {
8456
+ tensor_clone = ggml_rwkv_wkv6(ggml_ctx, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3],
8457
+ tensor->src[4], tensor->src[5]);
8458
+ }
8459
+ else {
7143
8460
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
7144
8461
  GGML_ABORT("fatal error");
7145
8462
  }
@@ -7336,3 +8653,5 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
7336
8653
  VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")");
7337
8654
  }
7338
8655
  #endif
8656
+
8657
+ GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg)