@fugood/llama.node 0.3.15 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +243 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +14 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -8
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2413 -228
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1004 -13516
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +127 -33
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +29 -293
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  130. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  131. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +210 -286
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  133. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  135. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  136. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  138. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +692 -126
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  142. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +21 -10
  143. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  144. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  145. package/src/llama.cpp/include/llama.h +30 -11
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  147. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  149. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  150. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  151. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  152. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  153. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  154. package/src/llama.cpp/src/llama-arch.cpp +161 -17
  155. package/src/llama.cpp/src/llama-arch.h +16 -0
  156. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  157. package/src/llama.cpp/src/llama-chat.h +6 -2
  158. package/src/llama.cpp/src/llama-context.cpp +108 -92
  159. package/src/llama.cpp/src/llama-context.h +1 -2
  160. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  161. package/src/llama.cpp/src/llama-graph.h +26 -6
  162. package/src/llama.cpp/src/llama-hparams.h +13 -0
  163. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  164. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  165. package/src/llama.cpp/src/llama-memory.h +1 -1
  166. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  167. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  168. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  169. package/src/llama.cpp/src/llama-model.cpp +1544 -291
  170. package/src/llama.cpp/src/llama-model.h +13 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  172. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  173. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  174. package/src/llama.cpp/src/llama.cpp +1 -1
  175. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  176. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  177. package/src/llama.cpp/tests/test-backend-ops.cpp +139 -57
  178. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  179. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  180. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  181. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  182. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  183. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  184. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  185. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  186. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  188. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  189. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  190. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  191. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  192. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  193. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  203. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
@@ -24,6 +24,28 @@
  #include <future>
  #include <thread>

+ #if defined(_MSC_VER)
+ # define NOMINMAX 1
+ # include <windows.h>
+ # define YIELD() YieldProcessor()
+ #elif defined(__clang__) || defined(__GNUC__)
+ # if defined(__x86_64__) ||defined(__i386__)
+ # include <immintrin.h>
+ # define YIELD() _mm_pause()
+ # elif defined(__arm__) || defined(__aarch64__)
+ # if defined(__clang__)
+ # include <arm_acle.h>
+ # define YIELD() __yield()
+ # else
+ # define YIELD() asm volatile("yield")
+ # endif
+ # endif
+ #endif
+
+ #if !defined(YIELD)
+ #define YIELD()
+ #endif
+
  #include "ggml-impl.h"
  #include "ggml-backend-impl.h"

@@ -31,6 +53,7 @@

  #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
  #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }

  #define VK_VENDOR_ID_AMD 0x1002
  #define VK_VENDOR_ID_APPLE 0x106b
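
As a quick illustrative check of the helpers in the hunk above (a standalone sketch, not part of the package diff): ROUNDUP_POW2 rounds up to the next multiple of a power-of-two, CEIL_DIV is an integer ceiling division, and the new is_pow2() deliberately returns false for 0 and 1.

    // Illustrative only; reproduces the macros/helper from the diff and checks a few values.
    #include <cassert>
    #include <cstdint>

    #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))   // N must be a power of two
    #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

    static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }

    int main() {
        assert(ROUNDUP_POW2(70, 32) == 96);                  // 70 rounded up to a multiple of 32
        assert(CEIL_DIV(70, 32) == 3);                        // ceil(70 / 32)
        assert(is_pow2(64) && !is_pow2(1) && !is_pow2(96));   // 96 = 64 + 32 is not a power of two
        return 0;
    }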
@@ -149,6 +172,7 @@ class vk_perf_logger;
  static void ggml_vk_destroy_buffer(vk_buffer& buf);

  static constexpr uint32_t mul_mat_vec_max_cols = 8;
+ static constexpr uint32_t p021_max_gqa_ratio = 8;

  enum vk_device_architecture {
  OTHER,
@@ -222,6 +246,7 @@ struct vk_device_struct {
  bool pipeline_robustness;
  vk::Device device;
  uint32_t vendor_id;
+ vk::DriverId driver_id;
  vk_device_architecture architecture;
  vk_queue compute_queue;
  vk_queue transfer_queue;
@@ -231,6 +256,9 @@ struct vk_device_struct {
  bool uma;
  bool prefer_host_memory;
  bool float_controls_rte_fp16;
+ bool subgroup_add;
+
+ bool integer_dot_product;

  bool subgroup_size_control;
  uint32_t subgroup_min_size;
@@ -243,6 +271,12 @@ struct vk_device_struct {
  uint32_t coopmat_m;
  uint32_t coopmat_n;
  uint32_t coopmat_k;
+
+ bool coopmat_int_support;
+ uint32_t coopmat_int_m;
+ uint32_t coopmat_int_n;
+ uint32_t coopmat_int_k;
+
  bool coopmat2;

  size_t idx;
@@ -261,10 +295,10 @@ struct vk_device_struct {
  vk_matmul_pipeline pipeline_matmul_f32_f16 {};
  vk_matmul_pipeline2 pipeline_matmul_f16;
  vk_matmul_pipeline2 pipeline_matmul_f16_f32;
- vk_pipeline pipeline_matmul_split_k_reduce;

- vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
+ vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];

  vk_matmul_pipeline pipeline_matmul_id_f32 {};
  vk_matmul_pipeline2 pipeline_matmul_id_f16;
@@ -272,12 +306,15 @@ struct vk_device_struct {

  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];

+ vk_pipeline pipeline_matmul_split_k_reduce;
+ vk_pipeline pipeline_quantize_q8_1;
+
  vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
  vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
  vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
  vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];

- vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
+ vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
  vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
  vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
  vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -339,6 +376,7 @@ struct vk_device_struct {
  vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
  vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
  vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+ vk_pipeline pipeline_flash_attn_split_k_reduce;

  std::unordered_map<std::string, vk_pipeline_ref> pipelines;
  std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@@ -488,6 +526,10 @@ struct vk_flash_attn_push_constants {
  uint32_t n_head_log2;
  float m0;
  float m1;
+
+ uint32_t gqa_ratio;
+ uint32_t split_kv;
+ uint32_t k_num;
  };

  struct vk_op_push_constants {
@@ -638,6 +680,13 @@ struct vk_op_rwkv_wkv7_push_constants {
  uint32_t H;
  };

+ struct vk_op_upscale_push_constants {
+ uint32_t ne; uint32_t a_offset; uint32_t d_offset;
+ uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+ uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
+ float sf0; float sf1; float sf2; float sf3;
+ };
+
  // Allow pre-recording command buffers
  struct vk_staging_memcpy {
  vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -647,13 +696,6 @@ struct vk_staging_memcpy {
  size_t n;
  };

- struct vk_op_upscale_push_constants {
- uint32_t ne; uint32_t a_offset; uint32_t d_offset;
- uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
- uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
- float sf0; float sf1; float sf2; float sf3;
- };
-
  struct vk_context_struct {
  vk_submission * s;
  std::vector<vk_sequence> seqs;
@@ -768,7 +810,8 @@ struct ggml_backend_vk_context {
  ggml_vk_garbage_collector gc;
  size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
  vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
- vk::Fence fence;
+ vk::Fence fence, almost_ready_fence;
+ bool almost_ready_fence_pending {};

  vk_buffer buffer_pool[MAX_VK_BUFFERS];

@@ -859,6 +902,39 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx

  static void ggml_backend_vk_free(ggml_backend_t backend);

+ // Wait for ctx->fence to be signaled.
+ static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
+ // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
+ // during this wait.
+ if (ctx->almost_ready_fence_pending) {
+ VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
+ ctx->device->device.resetFences({ ctx->almost_ready_fence });
+ ctx->almost_ready_fence_pending = false;
+ }
+
+ // Spin (w/pause) waiting for the graph to finish executing.
+ vk::Result result;
+ while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {
+ if (result != vk::Result::eNotReady) {
+ fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__);
+ exit(1);
+ }
+ for (uint32_t i = 0; i < 100; ++i) {
+ YIELD();
+ YIELD();
+ YIELD();
+ YIELD();
+ YIELD();
+ YIELD();
+ YIELD();
+ YIELD();
+ YIELD();
+ YIELD();
+ }
+ }
+ ctx->device->device.resetFences({ ctx->fence });
+ }
+
  // variables to track number of compiles in progress
  static uint32_t compile_count = 0;
  static std::mutex compile_count_mutex;
@@ -1460,7 +1536,7 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_typ

  // small rows, large cols
  if (small_rows) {
- return {flash_attention_num_small_rows, 128};
+ return {flash_attention_num_small_rows, 64};
  }
  // small cols to reduce register count
  if (ggml_is_quantized(type) || D == 256) {
@@ -1596,6 +1672,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  // mulmat
  std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
  l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
+ l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
  l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
  l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
  std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
@@ -1660,6 +1737,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
  m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
  s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

+ l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+ m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
+ s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
+
+ // chip specific tuning
+ if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
+ m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+ }
+
  l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
  m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
  s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
@@ -1753,6 +1839,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
  // can't use 256 for D==80.
  uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
  auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
+ // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+ GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
  return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
  };

@@ -1998,6 +2086,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
  if (device->mul_mat ## ID ## _s[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \

+ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ if (device->mul_mat ## ID ## _l[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ if (device->mul_mat ## ID ## _m[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ if (device->mul_mat ## ID ## _s[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
  // Create 2 variants, {f16,f32} accumulator
  #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
  CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -2029,6 +2125,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ }
+ #endif
+
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2054,6 +2160,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
  #undef CREATE_MM2
+ #undef CREATE_MMQ
  #undef CREATE_MM
  } else {
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -2071,6 +2178,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
  if (device->mul_mat ## ID ## _s[TYPE]) \
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \

+ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+ if (device->mul_mat ## ID ## _l[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+ if (device->mul_mat ## ID ## _m[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+ if (device->mul_mat ## ID ## _s[TYPE]) \
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@@ -2097,6 +2212,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ if (device->integer_dot_product) {
+ CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+ }
+ #endif
+
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2130,7 +2255,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
  uint32_t rm_stdq = 1;
  uint32_t rm_kq = 2;
  if (device->vendor_id == VK_VENDOR_ID_AMD) {
- if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN
+ if (device->architecture == AMD_GCN) {
  rm_stdq = 2;
  rm_kq = 4;
  }
@@ -2264,13 +2389,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

  ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
+ ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);

- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+ for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
+ if (device->subgroup_add && device->subgroup_require_full_support) {
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
+ } else {
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
+ }
+ }
  ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);

  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -2281,13 +2414,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+ if (device->float_controls_rte_fp16) {
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+ } else {
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+ }

  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
@@ -2436,6 +2577,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
  bool pipeline_robustness = false;
  bool coopmat2_support = false;
  device->coopmat_support = false;
+ device->integer_dot_product = false;

  for (const auto& properties : ext_props) {
  if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -2461,6 +2603,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
  } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
  !getenv("GGML_VK_DISABLE_COOPMAT2")) {
  coopmat2_support = true;
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+ !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+ device->integer_dot_product = true;
+ #endif
  }
  }

@@ -2471,13 +2618,16 @@ static vk_device ggml_vk_get_device(size_t idx) {
  vk::PhysicalDeviceDriverProperties driver_props;
  vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
  vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
+ vk::PhysicalDeviceVulkan11Properties vk11_props;
  vk::PhysicalDeviceVulkan12Properties vk12_props;
  vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+ vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;

  props2.pNext = &props3;
  props3.pNext = &subgroup_props;
  subgroup_props.pNext = &driver_props;
- driver_props.pNext = &vk12_props;
+ driver_props.pNext = &vk11_props;
+ vk11_props.pNext = &vk12_props;

  VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;

@@ -2506,9 +2656,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
  }
  #endif

+ if (device->integer_dot_product) {
+ last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+ last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+ }
+
  device->physical_device.getProperties2(&props2);
  device->properties = props2.properties;
  device->vendor_id = device->properties.vendorID;
+ device->driver_id = driver_props.driverID;

  const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");

@@ -2541,6 +2697,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
  }
  device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;

+ device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
+
  const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;

  device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -2549,6 +2708,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
  device->coopmat_support = false;
  }

+ device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
+
  std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();

  // Try to find a non-graphics compute queue and transfer-focused queues
@@ -2641,6 +2802,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
  device_extensions.push_back("VK_KHR_maintenance4");
  }

+ VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+ shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+ if (device->integer_dot_product) {
+ last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+ last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+ device_extensions.push_back("VK_KHR_shader_integer_dot_product");
+ }
+
  vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);

  device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -2810,6 +2979,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
  device->coopmat_acc_f16_support = true;
  }
  }
+ } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
+ (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
+ (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
+ (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
+ (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
+ device->coopmat_int_m == 0
+ ) {
+ device->coopmat_int_support = true;
+ device->coopmat_int_m = prop.MSize;
+ device->coopmat_int_n = prop.NSize;
+ device->coopmat_int_k = prop.KSize;
  }
  }

@@ -2914,25 +3094,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
  vk::PhysicalDevice physical_device = devices[dev_num];
  std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();

- vk::PhysicalDeviceProperties2 props2;
- vk::PhysicalDeviceMaintenance3Properties props3;
- vk::PhysicalDeviceSubgroupProperties subgroup_props;
- vk::PhysicalDeviceDriverProperties driver_props;
- props2.pNext = &props3;
- props3.pNext = &subgroup_props;
- subgroup_props.pNext = &driver_props;
- physical_device.getProperties2(&props2);
-
- vk_device_architecture arch = get_device_architecture(physical_device);
- uint32_t default_subgroup_size = get_subgroup_size("", arch);
- const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
-
- const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
  bool fp16_storage = false;
  bool fp16_compute = false;
  bool coopmat_support = false;
  bool coopmat2_support = false;
+ bool integer_dot_product = false;

  for (auto properties : ext_props) {
  if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -2948,27 +3114,44 @@ static void ggml_vk_print_gpu_info(size_t idx) {
  } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
  !getenv("GGML_VK_DISABLE_COOPMAT2")) {
  coopmat2_support = true;
+ #endif
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+ } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+ !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+ integer_dot_product = true;
  #endif
  }
  }

  const vk_device_architecture device_architecture = get_device_architecture(physical_device);

- if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
- coopmat_support = false;
- }
-
  const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
  bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;

  bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;

- vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
+ vk::PhysicalDeviceProperties2 props2;
+ vk::PhysicalDeviceMaintenance3Properties props3;
+ vk::PhysicalDeviceSubgroupProperties subgroup_props;
+ vk::PhysicalDeviceDriverProperties driver_props;
+ vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
+ props2.pNext = &props3;
+ props3.pNext = &subgroup_props;
+ subgroup_props.pNext = &driver_props;
+
+ // Pointer to the last chain element
+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
+
+ if (integer_dot_product) {
+ last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+ last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+ }
+
+ physical_device.getProperties2(&props2);

  VkPhysicalDeviceFeatures2 device_features2;
  device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
  device_features2.pNext = nullptr;
- device_features2.features = (VkPhysicalDeviceFeatures)device_features;

  VkPhysicalDeviceVulkan11Features vk11_features;
  vk11_features.pNext = nullptr;
@@ -2981,7 +3164,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
  vk11_features.pNext = &vk12_features;

  // Pointer to the last chain element
- VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
+ last_struct = (VkBaseOutStructure *)&vk12_features;

  #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
  VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
@@ -2993,20 +3176,39 @@ static void ggml_vk_print_gpu_info(size_t idx) {
  last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
  last_struct = (VkBaseOutStructure *)&coopmat_features;
  }
+ #endif
+
+ VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+ shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+ if (integer_dot_product) {
+ last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+ last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+ }

  vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);

  fp16 = fp16 && vk12_features.shaderFloat16;

- coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
+ uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
+ const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+ const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+
+ integer_dot_product = integer_dot_product
+ && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
+ && shader_integer_dot_product_features.shaderIntegerDotProduct;
+
+ coopmat_support = coopmat_support
+ #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+ && coopmat_features.cooperativeMatrix
  #endif
+ && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);

  std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";

  std::string device_name = props2.properties.deviceName.data();
- GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
+ GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
  idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
- props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
+ props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());

  if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
  GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -3208,6 +3410,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
  ctx->prealloc_size_split_k = 0;

  ctx->fence = ctx->device->device.createFence({});
+ ctx->almost_ready_fence = ctx->device->device.createFence({});

  #ifdef GGML_VULKAN_CHECK_RESULTS
  const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
@@ -3272,6 +3475,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
  }
  }

+ // MMQ
+ if (src1_type == GGML_TYPE_Q8_1) {
+ vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
+
+ if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+ return nullptr;
+ }
+
+ return pipelines;
+ }
+
  if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
  return nullptr;
  }
@@ -3564,8 +3778,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
  return s;
  }

-
-
  static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
  const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
  const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
@@ -3989,14 +4201,20 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
  if (split_k == 3) {
  split_k = 2;
  }
+ if (ctx->device->coopmat2) {
+ // coopmat2 shader expects splits to be aligned to 256
+ while (split_k > 1 && ((k / split_k) % 256) != 0) {
+ split_k /= 2;
+ }
+ }
  }
  }

  return split_k;
  }

- static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
+ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");

  if (ctx->device->coopmat2) {
  // Use large shader when the N dimension is greater than the medium shader's tile size
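
To illustrate the new coopmat2 alignment rule in the hunk above, here is a worked example under assumed values (the helper name is hypothetical, not from the package): with k = 1536 and an initial split_k of 4, k / split_k = 384, which is not a multiple of 256, so split_k is halved to 2; 1536 / 2 = 768 is aligned, and the loop stops.

    // Standalone sketch of the alignment loop added in ggml_vk_guess_split_k.
    #include <cstdint>
    #include <cstdio>

    static uint32_t align_split_k_for_coopmat2(uint32_t k, uint32_t split_k) {
        // Halve split_k until each split covers a multiple of 256 rows of K (or split_k reaches 1).
        while (split_k > 1 && ((k / split_k) % 256) != 0) {
            split_k /= 2;
        }
        return split_k;
    }

    int main() {
        printf("%u\n", align_split_k_for_coopmat2(1536, 4)); // prints 2: 1536/4 = 384 is unaligned, 1536/2 = 768 is
        return 0;
    }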
@@ -4021,9 +4239,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
  return aligned ? mmp->a_l : mmp->l;
  }

- static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
- return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
+ static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
+ return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
  }

  static void ggml_vk_matmul(
@@ -4033,7 +4251,7 @@ static void ggml_vk_matmul(
  uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
  uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
  uint32_t padded_n) {
- VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
+ VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
  ggml_vk_sync_buffers(subctx);
  if (split_k == 1) {
  const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
@@ -4051,7 +4269,7 @@ static void ggml_vk_matmul(
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
  }

- static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
+ static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
  VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");

  if (ctx->device->coopmat2) {
@@ -4193,6 +4411,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
  }

+ static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
4415
+ switch(type) {
4416
+ case GGML_TYPE_Q8_1:
4417
+ return ctx->device->pipeline_quantize_q8_1;
4418
+ default:
4419
+ std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
4420
+ GGML_ABORT("fatal error");
4421
+ }
4422
+ }
4423
+
4424
+ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
4425
+ VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
4426
+
4427
+ vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
4428
+
4429
+ ggml_vk_sync_buffers(subctx);
4430
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
4431
+ }
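The GLSL source of the quantize shader dispatched here is not part of this file, but the per-block arithmetic is the standard ggml Q8_1 scheme. A CPU reference sketch (assumed equivalent; plain floats are used instead of fp16 for brevity):

    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    struct block_q8_1_ref {
        float  d;        // scale (fp16 in the real on-disk/on-GPU format)
        float  s;        // d * sum(qs), lets the dot kernels fold offsets back in cheaply
        int8_t qs[32];
    };

    // Quantize one 32-value block: d = max|x| / 127, round each value, store s = d * sum(q).
    static void quantize_block_q8_1(const float * x, block_q8_1_ref * out) {
        float amax = 0.0f;
        for (int i = 0; i < 32; ++i) amax = std::fmax(amax, std::fabs(x[i]));
        const float d  = amax / 127.0f;
        const float id = d != 0.0f ? 1.0f / d : 0.0f;
        int sum = 0;
        for (int i = 0; i < 32; ++i) {
            out->qs[i] = (int8_t) std::lround(x[i] * id);
            sum += out->qs[i];
        }
        out->d = d;
        out->s = d * (float) sum;
    }

    int main() {
        float x[32];
        for (int i = 0; i < 32; ++i) x[i] = 0.01f * (float) i;
        block_q8_1_ref b;
        quantize_block_q8_1(x, &b);
        std::printf("d=%f s=%f q[31]=%d\n", b.d, b.s, b.qs[31]);
        return 0;
    }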
4432
+
4196
4433
  static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
4197
4434
  VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
4198
4435
  std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -4244,10 +4481,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4244
4481
 
4245
4482
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
4246
4483
 
4247
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
4484
+ bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
4485
+
4486
+ // Check for mmq first
4487
+ vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
4488
+
4489
+ if (mmp == nullptr) {
4490
+ // Fall back to f16 dequant mul mat
4491
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
4492
+ quantize_y = false;
4493
+ }
4248
4494
 
4249
4495
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
4250
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
4496
+ const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
4251
4497
 
4252
4498
  if (qx_needs_dequant) {
4253
4499
  // Fall back to dequant + f16 mulmat
@@ -4257,13 +4503,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4257
4503
  // Not implemented
4258
4504
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
4259
4505
 
4260
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
4261
- const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
4506
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
4507
+ const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
4262
4508
 
4263
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
4509
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
4264
4510
 
4265
4511
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4266
- uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
4512
+ uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
4267
4513
  const int x_ne = ne01 * ne00;
4268
4514
  const int y_ne = padded_n * ne10;
4269
4515
  const int d_ne = ne11 * ne01;
@@ -4273,11 +4519,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4273
4519
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
4274
4520
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
4275
4521
  const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
4276
- const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
4522
+ const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
4277
4523
  const uint64_t d_sz = sizeof(float) * d_ne;
4278
4524
 
4279
4525
  vk_pipeline to_fp16_vk_0 = nullptr;
4280
4526
  vk_pipeline to_fp16_vk_1 = nullptr;
4527
+ vk_pipeline to_q8_1 = nullptr;
4281
4528
 
4282
4529
  if (x_non_contig) {
4283
4530
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
@@ -4292,6 +4539,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4292
4539
  GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
4293
4540
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
4294
4541
 
4542
+ if (quantize_y) {
4543
+ to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
4544
+ }
4545
+
4295
4546
  if (dryrun) {
4296
4547
  const uint64_t x_sz_upd = x_sz * ne02 * ne03;
4297
4548
  const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@@ -4305,7 +4556,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4305
4556
  if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
4306
4557
  ctx->prealloc_size_x = x_sz_upd;
4307
4558
  }
4308
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
4559
+ if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
4309
4560
  ctx->prealloc_size_y = y_sz_upd;
4310
4561
  }
4311
4562
  if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
@@ -4320,6 +4571,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4320
4571
  if (qy_needs_dequant) {
4321
4572
  ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
4322
4573
  }
4574
+ if (quantize_y) {
4575
+ ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
4576
+ }
4323
4577
  if (split_k > 1) {
4324
4578
  ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
4325
4579
  }
@@ -4355,6 +4609,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4355
4609
  if (qy_needs_dequant) {
4356
4610
  d_Y = ctx->prealloc_y;
4357
4611
  GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
4612
+ } else if (quantize_y) {
4613
+ d_Y = ctx->prealloc_y;
4614
+ GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
4358
4615
  } else {
4359
4616
  d_Y = d_Qy;
4360
4617
  y_buf_offset = qy_buf_offset;
@@ -4371,6 +4628,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4371
4628
  if (y_non_contig) {
4372
4629
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
4373
4630
  }
4631
+ if (quantize_y) {
4632
+ ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
4633
+ }
4374
4634
 
4375
4635
  uint32_t stride_batch_x = ne00*ne01;
4376
4636
  uint32_t stride_batch_y = ne10*ne11;
@@ -4379,7 +4639,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4379
4639
  stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
4380
4640
  }
4381
4641
 
4382
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
4642
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
4383
4643
  stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
4384
4644
  }
4385
4645
 
@@ -4627,9 +4887,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
4627
4887
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
4628
4888
  const uint64_t d_sz = sizeof(float) * d_ne;
4629
4889
 
4890
+ // With grouped query attention there are > 1 Q matrices per K, V matrix.
4891
+ uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
4892
+ if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
4893
+ gqa_ratio = 1;
4894
+ }
4895
+
4630
4896
  if (dryrun) {
4631
4897
  // Request descriptor sets
4632
- ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
4898
+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
4633
4899
  return;
4634
4900
  }
4635
4901
 
@@ -4653,8 +4919,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
4653
4919
 
4654
4920
  // compute
4655
4921
  const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
4922
+
4923
+ uint32_t workgroups_z = (uint32_t)ne12;
4924
+ // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
4925
+ if (gqa_ratio > 1) {
4926
+ workgroups_z /= gqa_ratio;
4927
+ }
4928
+
4656
4929
  ggml_vk_sync_buffers(subctx);
4657
- ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
4930
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
4658
4931
  }
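In the grouped-query-attention case handled in this function, several Q heads share one K/V head, so a single workgroup can cover gqa_ratio heads and the z dimension of the dispatch shrinks accordingly. A tiny sketch of that bookkeeping (standalone, with made-up head counts):

    #include <cstdint>
    #include <cstdio>

    // Mirrors the gqa_ratio guard from the diff: only ratios 1..8 that divide evenly are used.
    static uint32_t pick_gqa_ratio(uint32_t ne12 /* Q heads */, uint32_t ne02 /* K/V heads */) {
        uint32_t gqa_ratio = ne12 / ne02;
        if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
            gqa_ratio = 1;
        }
        return gqa_ratio;
    }

    int main() {
        const uint32_t ne02 = 8, ne12 = 32; // e.g. 32 Q heads over 8 K/V heads
        const uint32_t gqa_ratio = pick_gqa_ratio(ne12, ne02);
        const uint32_t workgroups_z = gqa_ratio > 1 ? ne12 / gqa_ratio : ne12;
        std::printf("gqa_ratio=%u, workgroups_z=%u (was %u)\n", gqa_ratio, workgroups_z, ne12);
        return 0;
    }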
4659
4932
 
4660
4933
  static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -5198,7 +5471,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5198
5471
  const uint32_t nbm1 = mask ? mask->nb[1] : 0;
5199
5472
 
5200
5473
  const uint32_t D = neq0;
5201
- const uint32_t N = neq1;
5474
+ uint32_t N = neq1;
5202
5475
  const uint32_t KV = nek1;
5203
5476
 
5204
5477
  GGML_ASSERT(ne0 == D);
@@ -5253,12 +5526,60 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5253
5526
  // the "aligned" shader variant will forcibly align strides, for performance
5254
5527
  (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
5255
5528
 
5529
+ // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
5530
+ GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
5531
+
5256
5532
  vk_pipeline pipeline = pipelines[aligned];
5257
5533
  assert(pipeline);
5258
5534
 
5535
+ uint32_t gqa_ratio = 1;
5536
+ uint32_t qk_ratio = neq2 / nek2;
5537
+ uint32_t workgroups_x = (uint32_t)neq1;
5538
+ uint32_t workgroups_y = (uint32_t)neq2;
5539
+ uint32_t workgroups_z = (uint32_t)neq3;
5540
+
5541
+ if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
5542
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
5543
+ // grouped query attention - make the N dimension equal to gqa_ratio, reduce
5544
+ // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
5545
+ // and change addressing calculations to index Q's dimension 2.
5546
+ gqa_ratio = qk_ratio;
5547
+ N = gqa_ratio;
5548
+ workgroups_y /= N;
5549
+ }
5550
+
5551
+ uint32_t split_kv = KV;
5552
+ uint32_t split_k = 1;
5553
+
5554
+ // Try to use split_k when KV is large enough to be worth the overhead
5555
+ if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
5556
+ // Try to run two workgroups per SM.
5557
+ split_k = ctx->device->shader_core_count * 2 / workgroups_y;
5558
+ if (split_k > 1) {
5559
+ // Try to evenly split KV into split_k chunks, but it needs to be a multiple
5560
+ // of "align", so recompute split_k based on that.
5561
+ split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
5562
+ split_k = CEIL_DIV(KV, split_kv);
5563
+ workgroups_x = split_k;
5564
+ }
5565
+ }
5566
+
5567
+ // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
5568
+ // and the per-row m and L values (ne1 rows).
5569
+ const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
5570
+ if (split_k_size > ctx->device->max_memory_allocation_size) {
5571
+ GGML_ABORT("Requested preallocation size is too large");
5572
+ }
5573
+ if (ctx->prealloc_size_split_k < split_k_size) {
5574
+ ctx->prealloc_size_split_k = split_k_size;
5575
+ }
5576
+
5259
5577
  if (dryrun) {
5260
5578
  // Request descriptor sets
5261
5579
  ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
5580
+ if (split_k > 1) {
5581
+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
5582
+ }
5262
5583
  return;
5263
5584
  }
5264
5585
 
@@ -5279,8 +5600,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5279
5600
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
5280
5601
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5281
5602
 
5282
- ggml_vk_sync_buffers(subctx);
5283
-
5284
5603
  vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
5285
5604
  size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
5286
5605
 
@@ -5345,16 +5664,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5345
5664
  v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
5346
5665
  nbm1,
5347
5666
  scale, max_bias, logit_softcap,
5348
- mask != nullptr, n_head_log2, m0, m1 };
5349
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5350
- {
5351
- vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5352
- vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5353
- vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5354
- vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5355
- vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5356
- },
5357
- sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
5667
+ mask != nullptr, n_head_log2, m0, m1,
5668
+ gqa_ratio, split_kv, split_k };
5669
+
5670
+ ggml_vk_sync_buffers(subctx);
5671
+
5672
+ if (split_k > 1) {
5673
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5674
+ {
5675
+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5676
+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5677
+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5678
+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5679
+ vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
5680
+ },
5681
+ // We only use split_k when group query attention is enabled, which means
5682
+ // there's no more than one tile of rows (i.e. workgroups_x would have been
5683
+ // one). We reuse workgroups_x to mean the number of splits, so we need to
5684
+ // cancel out the divide by wg_denoms[0].
5685
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
5686
+
5687
+ ggml_vk_sync_buffers(subctx);
5688
+ const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
5689
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
5690
+ {
5691
+ vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
5692
+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5693
+ },
5694
+ pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
5695
+ } else {
5696
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5697
+ {
5698
+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5699
+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5700
+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5701
+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5702
+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5703
+ },
5704
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
5705
+ }
5358
5706
  }
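The split_k sizing added to this function is easier to see in isolation. The sketch below uses assumed values for shader_core_count and the pipeline alignment, and spells out ROUNDUP_POW2/CEIL_DIV as the usual round-up helpers since their definitions are not part of this hunk.

    #include <cstdint>
    #include <cstdio>

    static uint32_t ceil_div(uint32_t a, uint32_t b)      { return (a + b - 1) / b; }
    static uint32_t roundup_pow2(uint32_t a, uint32_t b)  { return (a + b - 1) & ~(b - 1); } // b must be a power of two

    int main() {
        const uint32_t KV = 8192;              // total K/V sequence length
        const uint32_t workgroups_y = 4;       // number of heads after the GQA fold
        const uint32_t shader_core_count = 32; // illustrative device value
        const uint32_t align = 256;            // stand-in for pipelines[1]->align

        uint32_t split_kv = KV;
        uint32_t split_k  = 1;
        if (KV >= 512) {
            split_k = shader_core_count * 2 / workgroups_y;   // aim for roughly two workgroups per SM
            if (split_k > 1) {
                split_kv = roundup_pow2(KV / split_k, align); // each split handles an aligned KV chunk
                split_k  = ceil_div(KV, split_kv);            // recompute how many splits are really needed
            }
        }
        std::printf("split_k=%u, split_kv=%u\n", split_k, split_kv);
        return 0;
    }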
5359
5707
 
5360
5708
  static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
@@ -5408,7 +5756,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5408
5756
  }
5409
5757
  return nullptr;
5410
5758
  case GGML_OP_UPSCALE:
5411
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5759
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
5412
5760
  return ctx->device->pipeline_upscale_f32;
5413
5761
  }
5414
5762
  return nullptr;
@@ -5665,6 +6013,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
5665
6013
  case GGML_OP_REPEAT:
5666
6014
  case GGML_OP_REPEAT_BACK:
5667
6015
  case GGML_OP_ROPE:
6016
+ case GGML_OP_RMS_NORM:
5668
6017
  return true;
5669
6018
  default:
5670
6019
  return false;
@@ -5875,7 +6224,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5875
6224
 
5876
6225
  switch (op) {
5877
6226
  case GGML_OP_NORM:
5878
- case GGML_OP_RMS_NORM:
5879
6227
  case GGML_OP_RMS_NORM_BACK:
5880
6228
  case GGML_OP_L2_NORM:
5881
6229
  case GGML_OP_SOFT_MAX:
@@ -5892,6 +6240,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5892
6240
  elements = { nr, 1, 1 };
5893
6241
  }
5894
6242
  } break;
6243
+ case GGML_OP_RMS_NORM:
6244
+ elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
6245
+ break;
6246
+
5895
6247
  case GGML_OP_SUM:
5896
6248
  // We use GGML_OP_SUM_ROWS with 1 row.
5897
6249
  elements = { 1, 1, 1 };
@@ -6542,7 +6894,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
6542
6894
 
6543
6895
  static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6544
6896
  float * op_params = (float *)dst->op_params;
6545
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
6897
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
6898
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
6899
+
6900
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
6901
+ (uint32_t)ggml_nelements(src0),
6902
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
6903
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
6904
+ 0,
6905
+ op_params[0], 0.0f,
6906
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
6907
+ }, dryrun);
6546
6908
  }
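The switch to vk_op_unary_push_constants above is what lets RMS_NORM accept non-contiguous inputs: instead of just ne0/ne1, the shader now receives the full shape plus per-dimension strides expressed in elements (nb[i] divided by the type size). A small sketch of that stride conversion, with a hypothetical row-padded tensor:

    #include <cstdint>
    #include <cstdio>

    int main() {
        // Hypothetical f32 tensor of shape [64, 8, 4, 1] whose rows are padded to 80 elements.
        const uint32_t type_size = sizeof(float);
        const uint32_t nb[4] = { 4, 80 * 4, 80 * 4 * 8, 80 * 4 * 8 * 4 }; // byte strides

        // What the push constants carry: strides divided by the element size.
        uint32_t stride_elems[4];
        for (int i = 0; i < 4; ++i) {
            stride_elems[i] = nb[i] / type_size;
        }
        std::printf("element strides: %u %u %u %u (a contiguous [64,8,4,1] tensor would give 1 64 512 2048)\n",
                    stride_elems[0], stride_elems[1], stride_elems[2], stride_elems[3]);
        return 0;
    }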
6547
6909
 
6548
6910
  static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
@@ -6895,6 +7257,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
6895
7257
  }
6896
7258
  }
6897
7259
 
7260
+ if (ctx->device->need_compiles) {
7261
+ ggml_vk_load_shaders(ctx->device);
7262
+ }
7263
+
6898
7264
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
6899
7265
 
6900
7266
  vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -7143,6 +7509,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
7143
7509
 
7144
7510
  ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
7145
7511
 
7512
+ if (ctx->device->need_compiles) {
7513
+ ggml_vk_load_shaders(ctx->device);
7514
+ }
7515
+
7146
7516
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
7147
7517
 
7148
7518
  ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
@@ -7202,66 +7572,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
7202
7572
  free(x_chk);
7203
7573
  }
7204
7574
 
7205
- static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
7575
+ // This does not work without ggml q8_1 quantization support
7576
+ //
7577
+ // typedef uint16_t ggml_half;
7578
+ // typedef uint32_t ggml_half2;
7579
+ //
7580
+ // #define QK8_1 32
7581
+ // typedef struct {
7582
+ // union {
7583
+ // struct {
7584
+ // ggml_half d; // delta
7585
+ // ggml_half s; // d * sum(qs[i])
7586
+ // } GGML_COMMON_AGGR_S;
7587
+ // ggml_half2 ds;
7588
+ // } GGML_COMMON_AGGR_U;
7589
+ // int8_t qs[QK8_1]; // quants
7590
+ // } block_q8_1;
7591
+ //
7592
+ // static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
7593
+ // VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
7594
+ // GGML_ASSERT(quant == GGML_TYPE_Q8_1);
7595
+ //
7596
+ // const size_t x_sz = sizeof(float) * ne;
7597
+ // const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
7598
+ // float * x = (float *) malloc(x_sz);
7599
+ // block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
7600
+ // block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
7601
+ // vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7602
+ // vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7603
+ //
7604
+ // for (size_t i = 0; i < ne; i++) {
7605
+ // x[i] = rand() / (float)RAND_MAX;
7606
+ // }
7607
+ //
7608
+ // vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
7609
+ //
7610
+ // ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
7611
+ //
7612
+ // if (ctx->device->need_compiles) {
7613
+ // ggml_vk_load_shaders(ctx->device);
7614
+ // }
7615
+ //
7616
+ // ggml_pipeline_allocate_descriptor_sets(ctx->device);
7617
+ //
7618
+ // ggml_vk_buffer_write(x_buf, 0, x, x_sz);
7619
+ //
7620
+ // vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
7621
+ // ggml_vk_ctx_begin(ctx->device, subctx);
7622
+ // ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
7623
+ // ggml_vk_ctx_end(subctx);
7624
+ //
7625
+ // auto begin = std::chrono::high_resolution_clock::now();
7626
+ //
7627
+ // ggml_vk_submit(subctx, ctx->fence);
7628
+ // VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
7629
+ // ctx->device->device.resetFences({ ctx->fence });
7630
+ //
7631
+ // auto end = std::chrono::high_resolution_clock::now();
7632
+ //
7633
+ // double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
7634
+ // ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
7635
+ //
7636
+ // ggml_vk_quantize_data(x, qx_res, ne, quant);
7637
+ //
7638
+ // int first_err = -1;
7639
+ //
7640
+ // for (size_t i = 0; i < ne / 32; i++) {
7641
+ // double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
7642
+ //
7643
+ // if (first_err < 0 && error > 0.1) {
7644
+ // first_err = i;
7645
+ // }
7646
+ //
7647
+ // error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
7648
+ //
7649
+ // if (first_err < 0 && error > 0.1) {
7650
+ // first_err = i;
7651
+ // }
7652
+ //
7653
+ // for (size_t j = 0; j < 32; j++) {
7654
+ // uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
7655
+ //
7656
+ // if (first_err < 0 && error > 1) {
7657
+ // first_err = i;
7658
+ // }
7659
+ // }
7660
+ // }
7661
+ //
7662
+ // std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
7663
+ //
7664
+ // if (first_err != -1) {
7665
+ // std::cerr << "first_error = " << first_err << std::endl;
7666
+ // std::cerr << "Actual result: " << std::endl << std::endl;
7667
+ // std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
7668
+ // for (size_t j = 0; j < 32; j++) {
7669
+ // std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
7670
+ // }
7671
+ // std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
7672
+ // std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
7673
+ // for (size_t j = 0; j < 32; j++) {
7674
+ // std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
7675
+ // }
7676
+ // std::cerr << std::endl;
7677
+ // }
7678
+ //
7679
+ // ggml_vk_destroy_buffer(x_buf);
7680
+ // ggml_vk_destroy_buffer(qx_buf);
7681
+ //
7682
+ // free(x);
7683
+ // free(qx);
7684
+ // free(qx_res);
7685
+ // }
7686
+
7687
+ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
7206
7688
  VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
7207
7689
  const size_t x_ne = m * k * batch;
7208
7690
  const size_t y_ne = k * n * batch;
7209
7691
  const size_t d_ne = m * n * batch;
7210
7692
 
7693
+ vk_matmul_pipeline2 * pipelines;
7694
+
7695
+ if (mmq) {
7696
+ pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
7697
+ } else {
7698
+ pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
7699
+ }
7700
+
7701
+ const bool fp16acc = ctx->device->fp16;
7702
+
7211
7703
  vk_pipeline p;
7212
7704
  std::string shname;
7213
7705
  if (shader_size == 0) {
7214
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
7706
+ p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
7215
7707
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
7216
7708
  } else if (shader_size == 1) {
7217
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
7709
+ p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
7218
7710
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
7219
7711
  } else if (shader_size == 2) {
7220
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
7712
+ p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
7221
7713
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
7222
7714
  } else {
7223
7715
  GGML_ASSERT(0);
7224
7716
  }
7225
7717
 
7226
- const size_t kpad = ggml_vk_align_size(k, p->align);
7718
+ const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
7227
7719
 
7228
- if (k != kpad) {
7720
+ if (mmq || k != kpad) {
7229
7721
  if (shader_size == 0) {
7230
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
7722
+ p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
7231
7723
  shname = std::string(ggml_type_name(quant)) + "_S";
7232
7724
  } else if (shader_size == 1) {
7233
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
7725
+ p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
7234
7726
  shname = std::string(ggml_type_name(quant)) + "_M";
7235
7727
  } else if (shader_size == 2) {
7236
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
7728
+ p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
7237
7729
  shname = std::string(ggml_type_name(quant)) + "_L";
7238
7730
  } else {
7239
7731
  GGML_ASSERT(0);
7240
7732
  }
7241
7733
  }
7242
7734
 
7735
+ if (p == nullptr) {
7736
+ std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
7737
+ return;
7738
+ }
7739
+
7243
7740
  const size_t x_sz = sizeof(float) * x_ne;
7244
7741
  const size_t y_sz = sizeof(float) * y_ne;
7245
7742
  const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
7743
+ const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
7246
7744
  const size_t d_sz = sizeof(float) * d_ne;
7247
7745
  float * x = (float *) malloc(x_sz);
7248
7746
  float * y = (float *) malloc(y_sz);
7249
7747
  void * qx = malloc(qx_sz);
7250
7748
  vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7251
7749
  vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7750
+ vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7252
7751
  vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7253
7752
  float * d = (float *) malloc(d_sz);
7254
7753
  float * d_chk = (float *) malloc(d_sz);
7255
7754
 
7256
7755
  for (size_t i = 0; i < x_ne; i++) {
7257
7756
  x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
7757
+ // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
7758
+ // x[i] = i % k;
7258
7759
  }
7259
7760
 
7260
7761
  ggml_vk_quantize_data(x, qx, x_ne, quant);
7261
7762
 
7262
7763
  for (size_t i = 0; i < y_ne; i++) {
7263
- // y[i] = rand() / (float)RAND_MAX;
7264
- y[i] = (i % k == i / k) ? 1.0f : 0.0f;
7764
+ y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
7765
+ // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
7766
+ // y[i] = i % k;
7265
7767
  }
7266
7768
 
7267
7769
  ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
@@ -7276,6 +7778,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7276
7778
  ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
7277
7779
  }
7278
7780
  }
7781
+ if (mmq) {
7782
+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
7783
+ }
7784
+
7785
+ if (ctx->device->need_compiles) {
7786
+ ggml_vk_load_shaders(ctx->device);
7787
+ }
7279
7788
 
7280
7789
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
7281
7790
 
@@ -7284,13 +7793,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7284
7793
 
7285
7794
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
7286
7795
  ggml_vk_ctx_begin(ctx->device, subctx);
7287
- for (size_t i = 0; i < num_it; i++) {
7288
- ggml_vk_matmul(
7289
- ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
7290
- m, n, k,
7291
- k, k, m, k*m, k*n, m*n,
7292
- split_k, batch, batch, batch, 1, 1, n
7293
- );
7796
+ if (mmq) {
7797
+ for (size_t i = 0; i < num_it; i++) {
7798
+ ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
7799
+ ggml_vk_matmul(
7800
+ ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
7801
+ m, n, k,
7802
+ k, k, m, k*m, k*n, m*n,
7803
+ split_k, batch, batch, batch, 1, 1, n
7804
+ );
7805
+ }
7806
+ } else {
7807
+ for (size_t i = 0; i < num_it; i++) {
7808
+ ggml_vk_matmul(
7809
+ ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
7810
+ m, n, k,
7811
+ k, k, m, k*m, k*n, m*n,
7812
+ split_k, batch, batch, batch, 1, 1, n
7813
+ );
7814
+ }
7294
7815
  }
7295
7816
  ggml_vk_ctx_end(subctx);
7296
7817
 
@@ -7348,7 +7869,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7348
7869
 
7349
7870
  double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
7350
7871
 
7351
- std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
7872
+ std::cerr << "TEST dequant matmul " << shname;
7873
+ if (mmq) {
7874
+ std::cerr << " mmq";
7875
+ }
7876
+ std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
7352
7877
 
7353
7878
  if (avg_err > 0.01 || std::isnan(avg_err)) {
7354
7879
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
@@ -7358,6 +7883,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7358
7883
  std::cerr << "Expected result: " << std::endl << std::endl;
7359
7884
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
7360
7885
 
7886
+ std::cerr << "src0: " << std::endl << std::endl;
7887
+ ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
7888
+ std::cerr << std::endl;
7889
+ std::cerr << "src1: " << std::endl << std::endl;
7890
+ ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
7891
+
7361
7892
  if (split_k > 1) {
7362
7893
  float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
7363
7894
  ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
@@ -7380,6 +7911,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7380
7911
 
7381
7912
  ggml_vk_destroy_buffer(qx_buf);
7382
7913
  ggml_vk_destroy_buffer(y_buf);
7914
+ ggml_vk_destroy_buffer(qy_buf);
7383
7915
  ggml_vk_destroy_buffer(d_buf);
7384
7916
 
7385
7917
  free(x);
@@ -7414,6 +7946,24 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
7414
7946
  };
7415
7947
  const size_t num_it = 100;
7416
7948
 
7949
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
7950
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
7951
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
7952
+
7953
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
7954
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
7955
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
7956
+
7957
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
7958
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
7959
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
7960
+
7961
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
7962
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
7963
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
7964
+
7965
+ abort();
7966
+
7417
7967
  for (size_t i = 0; i < vals.size(); i += 3) {
7418
7968
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
7419
7969
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
@@ -7488,11 +8038,11 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
7488
8038
  }
7489
8039
  }
7490
8040
 
7491
- static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence);
8041
+ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
7492
8042
 
7493
8043
  // Returns true if node has enqueued work into the queue, false otherwise
7494
8044
  // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
7495
- static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){
8045
+ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
7496
8046
  if (ggml_is_empty(node) || !node->buffer) {
7497
8047
  return false;
7498
8048
  }
@@ -7864,7 +8414,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7864
8414
 
7865
8415
  ctx->compute_ctx.reset();
7866
8416
 
7867
- bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false);
8417
+ bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
7868
8418
  if (!ok) {
7869
8419
  if (node->op == GGML_OP_UNARY) {
7870
8420
  std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -7878,7 +8428,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7878
8428
  return true;
7879
8429
  }
7880
8430
 
7881
- static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){
8431
+ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
7882
8432
  ggml_backend_buffer * buf = nullptr;
7883
8433
 
7884
8434
  switch (tensor->op) {
@@ -7981,12 +8531,15 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
7981
8531
  memcpy(cpy.dst, cpy.src, cpy.n);
7982
8532
  }
7983
8533
 
7984
- ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
8534
+ if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) {
8535
+ ggml_vk_submit(subctx, ctx->almost_ready_fence);
8536
+ ctx->almost_ready_fence_pending = true;
8537
+ } else {
8538
+ ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
8539
+ }
7985
8540
 
7986
8541
  if (use_fence) {
7987
- VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
7988
-
7989
- ctx->device->device.resetFences({ ctx->fence });
8542
+ ggml_vk_wait_for_fence(ctx);
7990
8543
  }
7991
8544
  #ifdef GGML_VULKAN_CHECK_RESULTS
7992
8545
  ggml_vk_check_results_1(tensor);
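The almost_ready plumbing introduced in this hunk (together with the graph-compute changes further down) boils down to: once fewer than a fifth of the graph nodes remain, the next submission is tagged with a dedicated fence so the final wait can begin on work that was enqueued early rather than only on the very last submission. A condensed sketch of that decision logic, with the Vulkan calls stubbed out (illustrative only, not the backend's actual ggml_vk_wait_for_fence behaviour):

    #include <cstdio>

    struct fake_ctx {
        bool almost_ready_fence_pending = false;
    };

    // Stand-ins for the real submissions; they only record which fence was attached.
    static void submit_with_almost_ready_fence(fake_ctx & ctx) {
        ctx.almost_ready_fence_pending = true;
        std::printf("  submit -> almost_ready_fence\n");
    }
    static void submit_without_fence() { std::printf("  submit -> no fence\n"); }

    int main() {
        fake_ctx ctx;
        const int n_nodes = 20;
        const int nodes_per_submit = 6;
        int submitted = 0;
        for (int i = 0; i < n_nodes; ++i) {
            // Same shape as the diff: "almost ready" once fewer than n_nodes/5 nodes remain.
            const bool almost_ready = (n_nodes - i) < n_nodes / 5;
            const bool last_node    = (i == n_nodes - 1);
            const bool submit = (++submitted >= nodes_per_submit) || last_node ||
                                (almost_ready && !ctx.almost_ready_fence_pending);
            if (!submit) continue;
            std::printf("node %d:\n", i);
            if (almost_ready && !ctx.almost_ready_fence_pending && !last_node) {
                submit_with_almost_ready_fence(ctx);
            } else {
                submit_without_fence();
            }
            submitted = 0;
        }
        return 0;
    }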
@@ -8072,6 +8625,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
8072
8625
  ctx->gc.events.clear();
8073
8626
 
8074
8627
  ctx->device->device.destroyFence(ctx->fence);
8628
+ ctx->device->device.destroyFence(ctx->almost_ready_fence);
8075
8629
  }
8076
8630
 
8077
8631
  static int ggml_vk_get_device_count() {
@@ -8418,8 +8972,7 @@ static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
8418
8972
  }
8419
8973
 
8420
8974
  ggml_vk_submit(transfer_ctx, ctx->fence);
8421
- VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
8422
- ctx->device->device.resetFences({ ctx->fence });
8975
+ ggml_vk_wait_for_fence(ctx);
8423
8976
 
8424
8977
  for (auto& cpy : transfer_ctx->out_memcpys) {
8425
8978
  memcpy(cpy.dst, cpy.src, cpy.n);
@@ -8438,7 +8991,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
8438
8991
 
8439
8992
  uint64_t total_mat_mul_bytes = 0;
8440
8993
  for (int i = 0; i < cgraph->n_nodes; i++) {
8441
- ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
8994
+ ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
8442
8995
  if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
8443
8996
  total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
8444
8997
  }
@@ -8480,11 +9033,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
8480
9033
  mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
8481
9034
  }
8482
9035
 
9036
+ // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
9037
+ bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
8483
9038
  bool submit = (submitted_nodes >= nodes_per_submit) ||
8484
9039
  (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
8485
- (i == last_node);
9040
+ (i == last_node) ||
9041
+ (almost_ready && !ctx->almost_ready_fence_pending);
8486
9042
 
8487
- bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
9043
+ bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
8488
9044
 
8489
9045
  if (enqueued) {
8490
9046
  ++submitted_nodes;
@@ -8496,7 +9052,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
8496
9052
  #endif
8497
9053
  }
8498
9054
 
8499
- if (submit) {
9055
+ if (submit && enqueued) {
8500
9056
  first_node_in_batch = true;
8501
9057
  submitted_nodes = 0;
8502
9058
  mul_mat_bytes = 0;
@@ -8726,10 +9282,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8726
9282
  case 112:
8727
9283
  case 128:
8728
9284
  case 256:
9285
+ case 575: // DeepSeek MLA
8729
9286
  break;
8730
9287
  default:
8731
9288
  return false;
8732
9289
  }
9290
+ if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
9291
+ // different head sizes of K and V are not supported yet
9292
+ return false;
9293
+ }
8733
9294
  if (op->src[0]->type != GGML_TYPE_F32) {
8734
9295
  return false;
8735
9296
  }
@@ -8848,10 +9409,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8848
9409
  case GGML_OP_VIEW:
8849
9410
  case GGML_OP_PERMUTE:
8850
9411
  case GGML_OP_TRANSPOSE:
9412
+ case GGML_OP_RMS_NORM:
8851
9413
  return true;
8852
9414
  case GGML_OP_NORM:
8853
9415
  case GGML_OP_GROUP_NORM:
8854
- case GGML_OP_RMS_NORM:
8855
9416
  case GGML_OP_L2_NORM:
8856
9417
  return ggml_is_contiguous(op->src[0]);
8857
9418
  case GGML_OP_ADD:
@@ -8865,9 +9426,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8865
9426
  case GGML_OP_COS:
8866
9427
  case GGML_OP_CLAMP:
8867
9428
  return op->src[0]->type == GGML_TYPE_F32;
9429
+ case GGML_OP_UPSCALE:
9430
+ return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
8868
9431
  case GGML_OP_ACC:
8869
9432
  case GGML_OP_CONCAT:
8870
- case GGML_OP_UPSCALE:
8871
9433
  case GGML_OP_SCALE:
8872
9434
  case GGML_OP_PAD:
8873
9435
  case GGML_OP_DIAG_MASK_INF:
@@ -9220,7 +9782,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9220
9782
  }
9221
9783
 
9222
9784
  if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
9223
- const float *params = (const float *)tensor->op_params;
9785
+ const float * params = (const float *)tensor->op_params;
9224
9786
  tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
9225
9787
  } else if (tensor->op == GGML_OP_MUL_MAT) {
9226
9788
  tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
@@ -9235,9 +9797,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9235
9797
  } else if (tensor->op == GGML_OP_CONCAT) {
9236
9798
  tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
9237
9799
  } else if (tensor->op == GGML_OP_UPSCALE) {
9238
- tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
9800
+ tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]);
9239
9801
  } else if (tensor->op == GGML_OP_SCALE) {
9240
- tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
9802
+ const float * params = (const float *)tensor->op_params;
9803
+ tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
9241
9804
  } else if (tensor->op == GGML_OP_SQR) {
9242
9805
  tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
9243
9806
  } else if (tensor->op == GGML_OP_SIN) {
@@ -9245,7 +9808,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9245
9808
  } else if (tensor->op == GGML_OP_COS) {
9246
9809
  tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
9247
9810
  } else if (tensor->op == GGML_OP_CLAMP) {
9248
- tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
9811
+ const float * params = (const float *)tensor->op_params;
9812
+ tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
9249
9813
  } else if (tensor->op == GGML_OP_PAD) {
9250
9814
  tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
9251
9815
  } else if (tensor->op == GGML_OP_REPEAT) {
@@ -9259,7 +9823,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9259
9823
  } else if (tensor->op == GGML_OP_NORM) {
9260
9824
  tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
9261
9825
  } else if (tensor->op == GGML_OP_GROUP_NORM) {
9262
- tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
9826
+ const float * float_params = (const float *)tensor->op_params;
9827
+ tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
9263
9828
  } else if (tensor->op == GGML_OP_RMS_NORM) {
9264
9829
  tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
9265
9830
  } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
@@ -9272,14 +9837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9272
9837
  tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
9273
9838
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
9274
9839
  if (src1 != nullptr) {
9275
- tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
9840
+ const float * params = (const float *)tensor->op_params;
9841
+ tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
9276
9842
  } else {
9277
9843
  tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
9278
9844
  }
9279
9845
  } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
9280
9846
  tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
9281
9847
  } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
9282
- tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
9848
+ tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
9283
9849
  } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
9284
9850
  const int n_dims = ((int32_t *) tensor->op_params)[1];
9285
9851
  const int mode = ((int32_t *) tensor->op_params)[2];