@fugood/llama.node 0.3.16 → 0.3.17

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (202)
  1. package/CMakeLists.txt +3 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +5 -0
  19. package/package.json +1 -1
  20. package/src/LlamaCompletionWorker.cpp +8 -0
  21. package/src/LlamaCompletionWorker.h +1 -0
  22. package/src/LlamaContext.cpp +3 -2
  23. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +124 -0
  24. package/src/llama.cpp/.github/workflows/build.yml +70 -27
  25. package/src/llama.cpp/.github/workflows/docker.yml +6 -6
  26. package/src/llama.cpp/.github/workflows/server.yml +7 -11
  27. package/src/llama.cpp/CMakeLists.txt +23 -1
  28. package/src/llama.cpp/common/CMakeLists.txt +6 -3
  29. package/src/llama.cpp/common/arg.cpp +809 -105
  30. package/src/llama.cpp/common/arg.h +9 -0
  31. package/src/llama.cpp/common/chat.cpp +1 -1
  32. package/src/llama.cpp/common/common.cpp +31 -521
  33. package/src/llama.cpp/common/common.h +17 -36
  34. package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
  35. package/src/llama.cpp/common/llguidance.cpp +30 -47
  36. package/src/llama.cpp/common/minja/chat-template.hpp +15 -7
  37. package/src/llama.cpp/common/minja/minja.hpp +119 -93
  38. package/src/llama.cpp/common/sampling.cpp +3 -0
  39. package/src/llama.cpp/docs/build.md +122 -7
  40. package/src/llama.cpp/examples/CMakeLists.txt +0 -9
  41. package/src/llama.cpp/examples/batched/batched.cpp +1 -1
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +1 -1
  43. package/src/llama.cpp/examples/embedding/embedding.cpp +7 -1
  44. package/src/llama.cpp/examples/export-lora/export-lora.cpp +1 -1
  45. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +15 -16
  46. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +210 -8
  48. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +39 -24
  50. package/src/llama.cpp/examples/llava/clip-impl.h +345 -0
  51. package/src/llama.cpp/examples/llava/clip.cpp +2152 -1803
  52. package/src/llama.cpp/examples/llava/clip.h +39 -22
  53. package/src/llama.cpp/examples/llava/deprecation-warning.cpp +22 -0
  54. package/src/llama.cpp/examples/llava/llava.cpp +64 -52
  55. package/src/llama.cpp/examples/llava/mtmd-cli.cpp +344 -0
  56. package/src/llama.cpp/examples/llava/mtmd.cpp +708 -0
  57. package/src/llama.cpp/examples/llava/mtmd.h +168 -0
  58. package/src/llama.cpp/examples/llava/{qwen2vl-cli.cpp → qwen2vl-test.cpp} +83 -31
  59. package/src/llama.cpp/examples/main/main.cpp +16 -5
  60. package/src/llama.cpp/examples/parallel/parallel.cpp +3 -1
  61. package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
  62. package/src/llama.cpp/examples/perplexity/perplexity.cpp +17 -3
  63. package/src/llama.cpp/examples/quantize/quantize.cpp +115 -2
  64. package/src/llama.cpp/examples/rpc/CMakeLists.txt +4 -2
  65. package/src/llama.cpp/examples/rpc/rpc-server.cpp +163 -8
  66. package/src/llama.cpp/examples/run/CMakeLists.txt +12 -1
  67. package/src/llama.cpp/examples/run/run.cpp +14 -28
  68. package/src/llama.cpp/examples/server/httplib.h +313 -247
  69. package/src/llama.cpp/examples/server/server.cpp +238 -139
  70. package/src/llama.cpp/examples/server/utils.hpp +51 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/build.sh +2 -2
  74. package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
  75. package/src/llama.cpp/examples/tts/tts.cpp +6 -9
  76. package/src/llama.cpp/ggml/CMakeLists.txt +8 -2
  77. package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
  78. package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
  79. package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
  80. package/src/llama.cpp/ggml/include/ggml.h +66 -99
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -7
  82. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
  83. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
  84. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
  85. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
  87. package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
  88. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
  89. package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
  90. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +48 -22
  91. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
  93. package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -192
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
  97. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +754 -404
  98. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1003 -13519
  99. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +2 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +2 -7
  101. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +0 -1
  102. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +3 -4
  103. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +533 -88
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8809 -0
  105. package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
  106. package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
  107. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
  108. package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
  109. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +258 -0
  110. package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
  111. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
  112. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
  113. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
  114. package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
  115. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +70 -3
  116. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
  117. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -260
  118. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +293 -40
  119. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +96 -22
  120. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +350 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
  123. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
  124. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +2 -292
  125. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
  126. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +967 -438
  127. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
  128. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
  129. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
  130. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +204 -280
  131. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
  132. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
  133. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
  134. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
  135. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
  136. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
  137. package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +23 -0
  139. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +646 -114
  140. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +12 -0
  141. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +17 -8
  142. package/src/llama.cpp/ggml/src/ggml.c +141 -245
  143. package/src/llama.cpp/ggml/src/gguf.cpp +1 -0
  144. package/src/llama.cpp/include/llama.h +30 -11
  145. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
  146. package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
  147. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
  148. package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
  149. package/src/llama.cpp/requirements/requirements-all.txt +2 -0
  150. package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
  151. package/src/llama.cpp/src/CMakeLists.txt +3 -2
  152. package/src/llama.cpp/src/llama-adapter.cpp +37 -1
  153. package/src/llama.cpp/src/llama-arch.cpp +160 -17
  154. package/src/llama.cpp/src/llama-arch.h +16 -0
  155. package/src/llama.cpp/src/llama-chat.cpp +82 -17
  156. package/src/llama.cpp/src/llama-chat.h +6 -2
  157. package/src/llama.cpp/src/llama-context.cpp +108 -92
  158. package/src/llama.cpp/src/llama-context.h +1 -2
  159. package/src/llama.cpp/src/llama-graph.cpp +189 -119
  160. package/src/llama.cpp/src/llama-graph.h +26 -6
  161. package/src/llama.cpp/src/llama-hparams.h +13 -0
  162. package/src/llama.cpp/src/llama-kv-cache.cpp +70 -123
  163. package/src/llama.cpp/src/llama-kv-cache.h +41 -115
  164. package/src/llama.cpp/src/llama-memory.h +1 -1
  165. package/src/llama.cpp/src/llama-mmap.cpp +1 -1
  166. package/src/llama.cpp/src/llama-model-loader.cpp +10 -5
  167. package/src/llama.cpp/src/llama-model-loader.h +5 -3
  168. package/src/llama.cpp/src/llama-model.cpp +1760 -534
  169. package/src/llama.cpp/src/llama-model.h +13 -1
  170. package/src/llama.cpp/src/llama-quant.cpp +29 -8
  171. package/src/llama.cpp/src/llama-sampling.cpp +7 -1
  172. package/src/llama.cpp/src/llama-vocab.cpp +44 -6
  173. package/src/llama.cpp/src/llama.cpp +1 -1
  174. package/src/llama.cpp/tests/CMakeLists.txt +43 -30
  175. package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
  176. package/src/llama.cpp/tests/test-backend-ops.cpp +82 -43
  177. package/src/llama.cpp/tests/test-chat-template.cpp +34 -13
  178. package/src/llama.cpp/tests/test-chat.cpp +12 -2
  179. package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
  180. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
  181. package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
  182. package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
  183. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
  184. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
  185. package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
  186. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
  187. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
  188. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
  189. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
  190. package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
  191. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
  192. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
  193. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
  194. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
  195. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
  196. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
  197. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
  198. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
  199. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
  200. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
  201. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
  202. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -24,6 +24,28 @@
 #include <future>
 #include <thread>

+#if defined(_MSC_VER)
+# define NOMINMAX 1
+# include <windows.h>
+# define YIELD() YieldProcessor()
+#elif defined(__clang__) || defined(__GNUC__)
+# if defined(__x86_64__) ||defined(__i386__)
+# include <immintrin.h>
+# define YIELD() _mm_pause()
+# elif defined(__arm__) || defined(__aarch64__)
+# if defined(__clang__)
+# include <arm_acle.h>
+# define YIELD() __yield()
+# else
+# define YIELD() asm volatile("yield")
+# endif
+# endif
+#endif
+
+#if !defined(YIELD)
+#define YIELD()
+#endif
+
 #include "ggml-impl.h"
 #include "ggml-backend-impl.h"

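The YIELD() macro added above gives the spin-wait introduced later in this diff (ggml_vk_wait_for_fence) a cheap CPU pause on MSVC, x86 and ARM, and degrades to a no-op elsewhere. As a rough illustration of how such a macro is typically consumed (a sketch, not code from this package; spin_wait, poll_fn and max_spins are made-up names):

    #include <cstdint>

    // Bounded busy-wait built on the YIELD() macro defined above.
    template <typename Poll>
    static bool spin_wait(Poll poll_fn, uint32_t max_spins) {
        for (uint32_t i = 0; i < max_spins; ++i) {
            if (poll_fn()) {
                return true;   // condition became true while spinning
            }
            YIELD();           // hint to the CPU that this is a busy-wait
        }
        return false;          // caller should fall back to a blocking wait
    }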
@@ -31,6 +53,7 @@

 #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
+static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }

 #define VK_VENDOR_ID_AMD 0x1002
 #define VK_VENDOR_ID_APPLE 0x106b
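For reference, a quick illustration of what these three helpers evaluate to (worked example only, not part of the diff; check_helpers is a made-up name):

    #include <cassert>
    #include <cstdint>

    static void check_helpers() {
        assert(ROUNDUP_POW2(13, 8) == 16);    // round 13 up to the next multiple of 8
        assert(CEIL_DIV(13, 8) == 2);         // ceil(13 / 8)
        assert(is_pow2(64));                  // 64 is a power of two
        assert(!is_pow2(1) && !is_pow2(48));  // note: is_pow2 deliberately rejects 1
    }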
@@ -223,6 +246,7 @@ struct vk_device_struct {
 bool pipeline_robustness;
 vk::Device device;
 uint32_t vendor_id;
+vk::DriverId driver_id;
 vk_device_architecture architecture;
 vk_queue compute_queue;
 vk_queue transfer_queue;
@@ -234,6 +258,8 @@
 bool float_controls_rte_fp16;
 bool subgroup_add;

+bool integer_dot_product;
+
 bool subgroup_size_control;
 uint32_t subgroup_min_size;
 uint32_t subgroup_max_size;
@@ -245,6 +271,12 @@
 uint32_t coopmat_m;
 uint32_t coopmat_n;
 uint32_t coopmat_k;
+
+bool coopmat_int_support;
+uint32_t coopmat_int_m;
+uint32_t coopmat_int_n;
+uint32_t coopmat_int_k;
+
 bool coopmat2;

 size_t idx;
@@ -263,10 +295,10 @@ struct vk_device_struct {
 vk_matmul_pipeline pipeline_matmul_f32_f16 {};
 vk_matmul_pipeline2 pipeline_matmul_f16;
 vk_matmul_pipeline2 pipeline_matmul_f16_f32;
-vk_pipeline pipeline_matmul_split_k_reduce;

-vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
+vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
+vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];

 vk_matmul_pipeline pipeline_matmul_id_f32 {};
 vk_matmul_pipeline2 pipeline_matmul_id_f16;
@@ -274,6 +306,9 @@

 vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];

+vk_pipeline pipeline_matmul_split_k_reduce;
+vk_pipeline pipeline_quantize_q8_1;
+
 vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
 vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
 vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
@@ -341,6 +376,7 @@ struct vk_device_struct {
 vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
 vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
 vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+vk_pipeline pipeline_flash_attn_split_k_reduce;

 std::unordered_map<std::string, vk_pipeline_ref> pipelines;
 std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
@@ -490,6 +526,10 @@ struct vk_flash_attn_push_constants {
 uint32_t n_head_log2;
 float m0;
 float m1;
+
+uint32_t gqa_ratio;
+uint32_t split_kv;
+uint32_t k_num;
 };

 struct vk_op_push_constants {
@@ -640,6 +680,13 @@ struct vk_op_rwkv_wkv7_push_constants {
 uint32_t H;
 };

+struct vk_op_upscale_push_constants {
+uint32_t ne; uint32_t a_offset; uint32_t d_offset;
+uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
+uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
+float sf0; float sf1; float sf2; float sf3;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
 vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -649,13 +696,6 @@ struct vk_staging_memcpy {
 size_t n;
 };

-struct vk_op_upscale_push_constants {
-uint32_t ne; uint32_t a_offset; uint32_t d_offset;
-uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
-uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
-float sf0; float sf1; float sf2; float sf3;
-};
-
 struct vk_context_struct {
 vk_submission * s;
 std::vector<vk_sequence> seqs;
@@ -770,7 +810,8 @@ struct ggml_backend_vk_context {
 ggml_vk_garbage_collector gc;
 size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
 vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
-vk::Fence fence;
+vk::Fence fence, almost_ready_fence;
+bool almost_ready_fence_pending {};

 vk_buffer buffer_pool[MAX_VK_BUFFERS];

@@ -861,6 +902,39 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx

 static void ggml_backend_vk_free(ggml_backend_t backend);

+// Wait for ctx->fence to be signaled.
+static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
+// Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
+// during this wait.
+if (ctx->almost_ready_fence_pending) {
+VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
+ctx->device->device.resetFences({ ctx->almost_ready_fence });
+ctx->almost_ready_fence_pending = false;
+}
+
+// Spin (w/pause) waiting for the graph to finish executing.
+vk::Result result;
+while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {
+if (result != vk::Result::eNotReady) {
+fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__);
+exit(1);
+}
+for (uint32_t i = 0; i < 100; ++i) {
+YIELD();
+YIELD();
+YIELD();
+YIELD();
+YIELD();
+YIELD();
+YIELD();
+YIELD();
+YIELD();
+YIELD();
+}
+}
+ctx->device->device.resetFences({ ctx->fence });
+}
+
 // variables to track number of compiles in progress
 static uint32_t compile_count = 0;
 static std::mutex compile_count_mutex;
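The new ggml_vk_wait_for_fence splits a wait into a cheap blocking phase (waitForFences on almost_ready_fence, letting the CPU sleep while most of the graph runs) and a short spin phase on the final fence using the YIELD() pause defined earlier. The same idea expressed as an analogy with standard C++ primitives plus that macro (two_phase_waiter is a made-up type, not part of this package):

    #include <atomic>
    #include <condition_variable>
    #include <mutex>

    struct two_phase_waiter {
        std::mutex m;
        std::condition_variable cv;
        bool almost_ready = false;      // plays the role of almost_ready_fence
        std::atomic<bool> done{false};  // plays the role of ctx->fence

        void wait() {
            {
                // phase 1: sleep until the bulk of the work has finished
                std::unique_lock<std::mutex> lock(m);
                cv.wait(lock, [&] { return almost_ready; });
            }
            // phase 2: spin with a CPU pause for the short remaining window
            while (!done.load(std::memory_order_acquire)) {
                YIELD();
            }
        }
    };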
@@ -1462,7 +1536,7 @@ static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type

 // small rows, large cols
 if (small_rows) {
-return {flash_attention_num_small_rows, 128};
+return {flash_attention_num_small_rows, 64};
 }
 // small cols to reduce register count
 if (ggml_is_quantized(type) || D == 256) {
@@ -1598,6 +1672,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 // mulmat
 std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
 l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
+l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
 l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
 l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
 std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
@@ -1662,6 +1737,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
 m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
 s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

+l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
+s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
+
+// chip specific tuning
+if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
+m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
+}
+
 l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
 m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
 s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
@@ -1755,6 +1839,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 // can't use 256 for D==80.
 uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
 auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
+// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
 return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
 };

@@ -2000,6 +2086,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
 if (device->mul_mat ## ID ## _s[TYPE]) \
 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \

+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+if (device->mul_mat ## ID ## _l[TYPE]) \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+if (device->mul_mat ## ID ## _m[TYPE]) \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+if (device->mul_mat ## ID ## _s[TYPE]) \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
 // Create 2 variants, {f16,f32} accumulator
 #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
 CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -2031,6 +2125,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+if (device->integer_dot_product) {
+CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+}
+#endif
+
 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2056,6 +2160,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
 #undef CREATE_MM2
+#undef CREATE_MMQ
 #undef CREATE_MM
 } else {
 // Create 6 variants, {s,m,l}x{unaligned,aligned}
@@ -2073,6 +2178,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
 if (device->mul_mat ## ID ## _s[TYPE]) \
 ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \

+#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
+if (device->mul_mat ## ID ## _l[TYPE]) \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
+if (device->mul_mat ## ID ## _m[TYPE]) \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
+if (device->mul_mat ## ID ## _s[TYPE]) \
+ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
+
 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
@@ -2099,6 +2212,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
 CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
 CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+if (device->integer_dot_product) {
+CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
+}
+#endif
+
 CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
@@ -2132,7 +2255,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 uint32_t rm_stdq = 1;
 uint32_t rm_kq = 2;
 if (device->vendor_id == VK_VENDOR_ID_AMD) {
-if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN
+if (device->architecture == AMD_GCN) {
 rm_stdq = 2;
 rm_kq = 4;
 }
@@ -2266,6 +2389,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

 ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
+ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);

 for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
 if (device->subgroup_add && device->subgroup_require_full_support) {
@@ -2278,7 +2403,7 @@ static void ggml_vk_load_shaders(vk_device& device) {

 ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
-ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -2452,6 +2577,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 bool pipeline_robustness = false;
 bool coopmat2_support = false;
 device->coopmat_support = false;
+device->integer_dot_product = false;

 for (const auto& properties : ext_props) {
 if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -2477,6 +2603,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
 } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
 !getenv("GGML_VK_DISABLE_COOPMAT2")) {
 coopmat2_support = true;
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+device->integer_dot_product = true;
+#endif
 }
 }

@@ -2490,6 +2621,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 vk::PhysicalDeviceVulkan11Properties vk11_props;
 vk::PhysicalDeviceVulkan12Properties vk12_props;
 vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;

 props2.pNext = &props3;
 props3.pNext = &subgroup_props;
@@ -2524,9 +2656,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
 }
 #endif

+if (device->integer_dot_product) {
+last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+}
+
 device->physical_device.getProperties2(&props2);
 device->properties = props2.properties;
 device->vendor_id = device->properties.vendorID;
+device->driver_id = driver_props.driverID;

 const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");

@@ -2570,6 +2708,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
 device->coopmat_support = false;
 }

+device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
+
 std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();

 // Try to find a non-graphics compute queue and transfer-focused queues
@@ -2662,6 +2802,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
 device_extensions.push_back("VK_KHR_maintenance4");
 }

+VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+if (device->integer_dot_product) {
+last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+device_extensions.push_back("VK_KHR_shader_integer_dot_product");
+}
+
 vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);

 device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -2831,6 +2979,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
 device->coopmat_acc_f16_support = true;
 }
 }
+} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
+(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
+(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
+(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
+(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
+device->coopmat_int_m == 0
+) {
+device->coopmat_int_support = true;
+device->coopmat_int_m = prop.MSize;
+device->coopmat_int_n = prop.NSize;
+device->coopmat_int_k = prop.KSize;
 }
 }

@@ -2935,25 +3094,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
 vk::PhysicalDevice physical_device = devices[dev_num];
 std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();

-vk::PhysicalDeviceProperties2 props2;
-vk::PhysicalDeviceMaintenance3Properties props3;
-vk::PhysicalDeviceSubgroupProperties subgroup_props;
-vk::PhysicalDeviceDriverProperties driver_props;
-props2.pNext = &props3;
-props3.pNext = &subgroup_props;
-subgroup_props.pNext = &driver_props;
-physical_device.getProperties2(&props2);
-
-vk_device_architecture arch = get_device_architecture(physical_device);
-uint32_t default_subgroup_size = get_subgroup_size("", arch);
-const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
-
-const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
-
 bool fp16_storage = false;
 bool fp16_compute = false;
 bool coopmat_support = false;
 bool coopmat2_support = false;
+bool integer_dot_product = false;

 for (auto properties : ext_props) {
 if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
@@ -2969,27 +3114,44 @@ static void ggml_vk_print_gpu_info(size_t idx) {
 } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
 !getenv("GGML_VK_DISABLE_COOPMAT2")) {
 coopmat2_support = true;
+#endif
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
+!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
+integer_dot_product = true;
 #endif
 }
 }

 const vk_device_architecture device_architecture = get_device_architecture(physical_device);

-if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
-coopmat_support = false;
-}
-
 const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
 bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;

 bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;

-vk::PhysicalDeviceFeatures device_features = physical_device.getFeatures();
+vk::PhysicalDeviceProperties2 props2;
+vk::PhysicalDeviceMaintenance3Properties props3;
+vk::PhysicalDeviceSubgroupProperties subgroup_props;
+vk::PhysicalDeviceDriverProperties driver_props;
+vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
+props2.pNext = &props3;
+props3.pNext = &subgroup_props;
+subgroup_props.pNext = &driver_props;
+
+// Pointer to the last chain element
+VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
+
+if (integer_dot_product) {
+last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
+}
+
+physical_device.getProperties2(&props2);

 VkPhysicalDeviceFeatures2 device_features2;
 device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
 device_features2.pNext = nullptr;
-device_features2.features = (VkPhysicalDeviceFeatures)device_features;

 VkPhysicalDeviceVulkan11Features vk11_features;
 vk11_features.pNext = nullptr;
@@ -3002,7 +3164,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
 vk11_features.pNext = &vk12_features;

 // Pointer to the last chain element
-VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
+last_struct = (VkBaseOutStructure *)&vk12_features;

 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
 VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
@@ -3014,20 +3176,39 @@ static void ggml_vk_print_gpu_info(size_t idx) {
 last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
 last_struct = (VkBaseOutStructure *)&coopmat_features;
 }
+#endif
+
+VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
+shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
+if (integer_dot_product) {
+last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
+}

 vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);

 fp16 = fp16 && vk12_features.shaderFloat16;

-coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
+uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
+const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
+
+integer_dot_product = integer_dot_product
+&& shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
+&& shader_integer_dot_product_features.shaderIntegerDotProduct;
+
+coopmat_support = coopmat_support
+#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+&& coopmat_features.cooperativeMatrix
 #endif
+&& ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);

 std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";

 std::string device_name = props2.properties.deviceName.data();
-GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
+GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
 idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
-props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
+props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());

 if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
 GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@@ -3229,6 +3410,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
 ctx->prealloc_size_split_k = 0;

 ctx->fence = ctx->device->device.createFence({});
+ctx->almost_ready_fence = ctx->device->device.createFence({});

 #ifdef GGML_VULKAN_CHECK_RESULTS
 const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
@@ -3293,6 +3475,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
 }
 }

+// MMQ
+if (src1_type == GGML_TYPE_Q8_1) {
+vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
+
+if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
+return nullptr;
+}
+
+return pipelines;
+}
+
 if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
 return nullptr;
 }
@@ -3585,8 +3778,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
 return s;
 }

-
-
 static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
 const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
 const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
@@ -4010,14 +4201,20 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
 if (split_k == 3) {
 split_k = 2;
 }
+if (ctx->device->coopmat2) {
+// coopmat2 shader expects splits to be aligned to 256
+while (split_k > 1 && ((k / split_k) % 256) != 0) {
+split_k /= 2;
+}
+}
 }
 }

 return split_k;
 }

-static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
-VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
+static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
+VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");

 if (ctx->device->coopmat2) {
 // Use large shader when the N dimension is greater than the medium shader's tile size
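To make the new alignment rule concrete (illustrative numbers only): with k = 1536 and an initial split_k of 4, k / split_k = 384 is not a multiple of 256, so split_k is halved to 2; 1536 / 2 = 768 = 3 * 256 is aligned, so the loop stops there. With k = 4096 the initial split_k = 4 already satisfies the constraint (4096 / 4 = 1024) and is kept unchanged.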
@@ -4042,9 +4239,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
 return aligned ? mmp->a_l : mmp->l;
 }

-static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
-VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
-return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
+static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
+VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
+return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
 }

 static void ggml_vk_matmul(
@@ -4054,7 +4251,7 @@ static void ggml_vk_matmul(
 uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
 uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
 uint32_t padded_n) {
-VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
+VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
 ggml_vk_sync_buffers(subctx);
 if (split_k == 1) {
 const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
@@ -4072,7 +4269,7 @@ static void ggml_vk_matmul(
 ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
 }

-static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
+static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
 VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");

 if (ctx->device->coopmat2) {
@@ -4214,6 +4411,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
 }

+static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
+switch(type) {
+case GGML_TYPE_Q8_1:
+return ctx->device->pipeline_quantize_q8_1;
+default:
+std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
+GGML_ABORT("fatal error");
+}
+}
+
+static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
+VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
+
+vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
+
+ggml_vk_sync_buffers(subctx);
+ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
+}
+
 static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
 VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
 std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
@@ -4265,10 +4481,19 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub

 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;

-vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
+bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
+
+// Check for mmq first
+vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
+
+if (mmp == nullptr) {
+// Fall back to f16 dequant mul mat
+mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
+quantize_y = false;
+}

 const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
+const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);

 if (qx_needs_dequant) {
 // Fall back to dequant + f16 mulmat
 mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
 }
@@ -4278,13 +4503,13 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4278
4503
  // Not implemented
4279
4504
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
4280
4505
 
4281
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
4282
- const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
4506
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
4507
+ const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
4283
4508
 
4284
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
4509
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
4285
4510
 
4286
4511
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4287
- uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
4512
+ uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
4288
4513
  const int x_ne = ne01 * ne00;
4289
4514
  const int y_ne = padded_n * ne10;
4290
4515
  const int d_ne = ne11 * ne01;
@@ -4294,11 +4519,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4294
4519
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
4295
4520
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
4296
4521
  const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
4297
- const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
4522
+ const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
4298
4523
  const uint64_t d_sz = sizeof(float) * d_ne;
4299
4524
 
4300
4525
  vk_pipeline to_fp16_vk_0 = nullptr;
4301
4526
  vk_pipeline to_fp16_vk_1 = nullptr;
4527
+ vk_pipeline to_q8_1 = nullptr;
4302
4528
 
4303
4529
  if (x_non_contig) {
4304
4530
  to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
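
A back-of-the-envelope check of the y_sz expression above (standalone, not part of the diff): a Q8_1 block packs 32 int8 quants plus two fp16 scalars (d and s), i.e. 36 bytes per 32 source floats, so the quantized staging buffer costs about 1.125 bytes per element versus 2 for f16 or 4 for f32. The example size below is illustrative:

// Illustrative sizing check for the Q8_1 staging buffer used by quantize_y.
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t y_ne       = 4096ull * 512;  // example ne10 * padded_n
    const uint64_t q8_1_block = 32;             // QK8_1
    const uint64_t q8_1_bytes = 32 + 2 * 2;     // qs[32] + fp16 d + fp16 s = 36
    printf("q8_1: %llu B, f16: %llu B, f32: %llu B\n",
           (unsigned long long)(y_ne * q8_1_bytes / q8_1_block),
           (unsigned long long)(y_ne * 2),
           (unsigned long long)(y_ne * 4));
}
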
@@ -4313,6 +4539,10 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4313
4539
  GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
4314
4540
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
4315
4541
 
4542
+ if (quantize_y) {
4543
+ to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
4544
+ }
4545
+
4316
4546
  if (dryrun) {
4317
4547
  const uint64_t x_sz_upd = x_sz * ne02 * ne03;
4318
4548
  const uint64_t y_sz_upd = y_sz * ne12 * ne13;
@@ -4326,7 +4556,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4326
4556
  if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
4327
4557
  ctx->prealloc_size_x = x_sz_upd;
4328
4558
  }
4329
- if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
4559
+ if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
4330
4560
  ctx->prealloc_size_y = y_sz_upd;
4331
4561
  }
4332
4562
  if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
@@ -4341,6 +4571,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4341
4571
  if (qy_needs_dequant) {
4342
4572
  ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
4343
4573
  }
4574
+ if (quantize_y) {
4575
+ ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
4576
+ }
4344
4577
  if (split_k > 1) {
4345
4578
  ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
4346
4579
  }
@@ -4376,6 +4609,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4376
4609
  if (qy_needs_dequant) {
4377
4610
  d_Y = ctx->prealloc_y;
4378
4611
  GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
4612
+ } else if (quantize_y) {
4613
+ d_Y = ctx->prealloc_y;
4614
+ GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
4379
4615
  } else {
4380
4616
  d_Y = d_Qy;
4381
4617
  y_buf_offset = qy_buf_offset;
@@ -4392,6 +4628,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4392
4628
  if (y_non_contig) {
4393
4629
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
4394
4630
  }
4631
+ if (quantize_y) {
4632
+ ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
4633
+ }
4395
4634
 
4396
4635
  uint32_t stride_batch_x = ne00*ne01;
4397
4636
  uint32_t stride_batch_y = ne10*ne11;
@@ -4400,7 +4639,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4400
4639
  stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
4401
4640
  }
4402
4641
 
4403
- if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
4642
+ if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
4404
4643
  stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
4405
4644
  }
4406
4645
 
@@ -5232,7 +5471,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5232
5471
  const uint32_t nbm1 = mask ? mask->nb[1] : 0;
5233
5472
 
5234
5473
  const uint32_t D = neq0;
5235
- const uint32_t N = neq1;
5474
+ uint32_t N = neq1;
5236
5475
  const uint32_t KV = nek1;
5237
5476
 
5238
5477
  GGML_ASSERT(ne0 == D);
@@ -5287,12 +5526,60 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5287
5526
  // the "aligned" shader variant will forcibly align strides, for performance
5288
5527
  (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
5289
5528
 
5529
+ // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
5530
+ GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
5531
+
5290
5532
  vk_pipeline pipeline = pipelines[aligned];
5291
5533
  assert(pipeline);
5292
5534
 
5535
+ uint32_t gqa_ratio = 1;
5536
+ uint32_t qk_ratio = neq2 / nek2;
5537
+ uint32_t workgroups_x = (uint32_t)neq1;
5538
+ uint32_t workgroups_y = (uint32_t)neq2;
5539
+ uint32_t workgroups_z = (uint32_t)neq3;
5540
+
5541
+ if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
5542
+ qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
5543
+ // grouped query attention - make the N dimension equal to gqa_ratio, reduce
5544
+ // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
5545
+ // and change addressing calculations to index Q's dimension 2.
5546
+ gqa_ratio = qk_ratio;
5547
+ N = gqa_ratio;
5548
+ workgroups_y /= N;
5549
+ }
5550
+
5551
+ uint32_t split_kv = KV;
5552
+ uint32_t split_k = 1;
5553
+
5554
+ // Try to use split_k when KV is large enough to be worth the overhead
5555
+ if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
5556
+ // Try to run two workgroups per SM.
5557
+ split_k = ctx->device->shader_core_count * 2 / workgroups_y;
5558
+ if (split_k > 1) {
5559
+ // Try to evenly split KV into split_k chunks, but it needs to be a multiple
5560
+ // of "align", so recompute split_k based on that.
5561
+ split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
5562
+ split_k = CEIL_DIV(KV, split_kv);
5563
+ workgroups_x = split_k;
5564
+ }
5565
+ }
5566
+
5567
+ // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
5568
+ // and the per-row m and L values (ne1 rows).
5569
+ const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
5570
+ if (split_k_size > ctx->device->max_memory_allocation_size) {
5571
+ GGML_ABORT("Requested preallocation size is too large");
5572
+ }
5573
+ if (ctx->prealloc_size_split_k < split_k_size) {
5574
+ ctx->prealloc_size_split_k = split_k_size;
5575
+ }
5576
+
5293
5577
  if (dryrun) {
5294
5578
  // Request descriptor sets
5295
5579
  ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
5580
+ if (split_k > 1) {
5581
+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
5582
+ }
5296
5583
  return;
5297
5584
  }
5298
5585
 
@@ -5313,8 +5600,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5313
5600
  const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
5314
5601
  const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
5315
5602
 
5316
- ggml_vk_sync_buffers(subctx);
5317
-
5318
5603
  vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
5319
5604
  size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
5320
5605
 
@@ -5379,16 +5664,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5379
5664
  v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
5380
5665
  nbm1,
5381
5666
  scale, max_bias, logit_softcap,
5382
- mask != nullptr, n_head_log2, m0, m1 };
5383
- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5384
- {
5385
- vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5386
- vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5387
- vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5388
- vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5389
- vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5390
- },
5391
- sizeof(vk_flash_attn_push_constants), &pc, { (uint32_t)neq1, (uint32_t)neq2, (uint32_t)neq3 });
5667
+ mask != nullptr, n_head_log2, m0, m1,
5668
+ gqa_ratio, split_kv, split_k };
5669
+
5670
+ ggml_vk_sync_buffers(subctx);
5671
+
5672
+ if (split_k > 1) {
5673
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5674
+ {
5675
+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5676
+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5677
+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5678
+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5679
+ vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
5680
+ },
5681
+ // We only use split_k when group query attention is enabled, which means
5682
+ // there's no more than one tile of rows (i.e. workgroups_x would have been
5683
+ // one). We reuse workgroups_x to mean the number of splits, so we need to
5684
+ // cancel out the divide by wg_denoms[0].
5685
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
5686
+
5687
+ ggml_vk_sync_buffers(subctx);
5688
+ const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
5689
+ ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
5690
+ {
5691
+ vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
5692
+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5693
+ },
5694
+ pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
5695
+ } else {
5696
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
5697
+ {
5698
+ vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
5699
+ vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
5700
+ vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
5701
+ vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
5702
+ vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
5703
+ },
5704
+ sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
5705
+ }
5392
5706
  }
5393
5707
 
5394
5708
  static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
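
For context on the reduce dispatch above: each split leaves an unnormalized partial output plus per-row m (running max) and L (running sum) in prealloc_split_k, and the reduce pipeline combines them with the usual log-sum-exp rescaling. The sketch below is a hedged host-side illustration of that combination for one output row, not the shader, and the exact layout/rescaling in pipeline_flash_attn_split_k_reduce may differ:

// Hedged sketch of a standard flash-attention split-k reduction per row.
#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> reduce_row(const std::vector<std::vector<float>> & O, // split_k x D
                              const std::vector<float> & m,              // per-split row max
                              const std::vector<float> & L,              // per-split row sum
                              std::size_t D) {
    const std::size_t split_k = O.size();
    float m_max = -INFINITY;
    for (std::size_t i = 0; i < split_k; i++) m_max = std::max(m_max, m[i]);

    std::vector<float> out(D, 0.0f);
    float L_total = 0.0f;
    for (std::size_t i = 0; i < split_k; i++) {
        const float s = std::exp(m[i] - m_max);   // rescale to the common max
        L_total += s * L[i];
        for (std::size_t d = 0; d < D; d++) out[d] += s * O[i][d];
    }
    for (std::size_t d = 0; d < D; d++) out[d] /= L_total;  // final softmax norm
    return out;
}
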
@@ -5442,7 +5756,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5442
5756
  }
5443
5757
  return nullptr;
5444
5758
  case GGML_OP_UPSCALE:
5445
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5759
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
5446
5760
  return ctx->device->pipeline_upscale_f32;
5447
5761
  }
5448
5762
  return nullptr;
@@ -5699,6 +6013,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
5699
6013
  case GGML_OP_REPEAT:
5700
6014
  case GGML_OP_REPEAT_BACK:
5701
6015
  case GGML_OP_ROPE:
6016
+ case GGML_OP_RMS_NORM:
5702
6017
  return true;
5703
6018
  default:
5704
6019
  return false;
@@ -5909,7 +6224,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5909
6224
 
5910
6225
  switch (op) {
5911
6226
  case GGML_OP_NORM:
5912
- case GGML_OP_RMS_NORM:
5913
6227
  case GGML_OP_RMS_NORM_BACK:
5914
6228
  case GGML_OP_L2_NORM:
5915
6229
  case GGML_OP_SOFT_MAX:
@@ -5926,6 +6240,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5926
6240
  elements = { nr, 1, 1 };
5927
6241
  }
5928
6242
  } break;
6243
+ case GGML_OP_RMS_NORM:
6244
+ elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
6245
+ break;
6246
+
5929
6247
  case GGML_OP_SUM:
5930
6248
  // We use GGML_OP_SUM_ROWS with 1 row.
5931
6249
  elements = { 1, 1, 1 };
@@ -6576,7 +6894,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
6576
6894
 
6577
6895
  static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
6578
6896
  float * op_params = (float *)dst->op_params;
6579
- ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
6897
+ const uint32_t src0_type_size = ggml_type_size(src0->type);
6898
+ const uint32_t dst_type_size = ggml_type_size(dst->type);
6899
+
6900
+ ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
6901
+ (uint32_t)ggml_nelements(src0),
6902
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
6903
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
6904
+ 0,
6905
+ op_params[0], 0.0f,
6906
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
6907
+ }, dryrun);
6580
6908
  }
6581
6909
 
6582
6910
  static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
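
For context on the push constants above: RMS_NORM normalizes each ne[0]-long row by the root-mean-square of its elements, with eps taken from op_params[0]; the switch to the unary-style constants simply carries the full ne/nb strides so non-contiguous tensors can be addressed, matching the per-row dispatch added earlier. A plain C++ reference of the per-row math (a sketch, not the shader):

// Reference per-row semantics of GGML_OP_RMS_NORM: y = x / sqrt(mean(x^2) + eps).
#include <cmath>
#include <cstddef>

void rms_norm_row(const float * x, float * y, std::size_t n, float eps) {
    double sum = 0.0;
    for (std::size_t i = 0; i < n; i++) sum += (double)x[i] * x[i];
    const float scale = 1.0f / std::sqrt((float)(sum / (double)n) + eps);
    for (std::size_t i = 0; i < n; i++) y[i] = x[i] * scale;
}
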
@@ -6929,6 +7257,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
6929
7257
  }
6930
7258
  }
6931
7259
 
7260
+ if (ctx->device->need_compiles) {
7261
+ ggml_vk_load_shaders(ctx->device);
7262
+ }
7263
+
6932
7264
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
6933
7265
 
6934
7266
  vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
@@ -7177,6 +7509,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
7177
7509
 
7178
7510
  ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
7179
7511
 
7512
+ if (ctx->device->need_compiles) {
7513
+ ggml_vk_load_shaders(ctx->device);
7514
+ }
7515
+
7180
7516
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
7181
7517
 
7182
7518
  ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
@@ -7236,66 +7572,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
7236
7572
  free(x_chk);
7237
7573
  }
7238
7574
 
7239
- static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant) {
7575
+ // This does not work without ggml q8_1 quantization support
7576
+ //
7577
+ // typedef uint16_t ggml_half;
7578
+ // typedef uint32_t ggml_half2;
7579
+ //
7580
+ // #define QK8_1 32
7581
+ // typedef struct {
7582
+ // union {
7583
+ // struct {
7584
+ // ggml_half d; // delta
7585
+ // ggml_half s; // d * sum(qs[i])
7586
+ // } GGML_COMMON_AGGR_S;
7587
+ // ggml_half2 ds;
7588
+ // } GGML_COMMON_AGGR_U;
7589
+ // int8_t qs[QK8_1]; // quants
7590
+ // } block_q8_1;
7591
+ //
7592
+ // static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
7593
+ // VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
7594
+ // GGML_ASSERT(quant == GGML_TYPE_Q8_1);
7595
+ //
7596
+ // const size_t x_sz = sizeof(float) * ne;
7597
+ // const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
7598
+ // float * x = (float *) malloc(x_sz);
7599
+ // block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
7600
+ // block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
7601
+ // vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7602
+ // vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7603
+ //
7604
+ // for (size_t i = 0; i < ne; i++) {
7605
+ // x[i] = rand() / (float)RAND_MAX;
7606
+ // }
7607
+ //
7608
+ // vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
7609
+ //
7610
+ // ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
7611
+ //
7612
+ // if (ctx->device->need_compiles) {
7613
+ // ggml_vk_load_shaders(ctx->device);
7614
+ // }
7615
+ //
7616
+ // ggml_pipeline_allocate_descriptor_sets(ctx->device);
7617
+ //
7618
+ // ggml_vk_buffer_write(x_buf, 0, x, x_sz);
7619
+ //
7620
+ // vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
7621
+ // ggml_vk_ctx_begin(ctx->device, subctx);
7622
+ // ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
7623
+ // ggml_vk_ctx_end(subctx);
7624
+ //
7625
+ // auto begin = std::chrono::high_resolution_clock::now();
7626
+ //
7627
+ // ggml_vk_submit(subctx, ctx->fence);
7628
+ // VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
7629
+ // ctx->device->device.resetFences({ ctx->fence });
7630
+ //
7631
+ // auto end = std::chrono::high_resolution_clock::now();
7632
+ //
7633
+ // double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
7634
+ // ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
7635
+ //
7636
+ // ggml_vk_quantize_data(x, qx_res, ne, quant);
7637
+ //
7638
+ // int first_err = -1;
7639
+ //
7640
+ // for (size_t i = 0; i < ne / 32; i++) {
7641
+ // double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
7642
+ //
7643
+ // if (first_err < 0 && error > 0.1) {
7644
+ // first_err = i;
7645
+ // }
7646
+ //
7647
+ // error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
7648
+ //
7649
+ // if (first_err < 0 && error > 0.1) {
7650
+ // first_err = i;
7651
+ // }
7652
+ //
7653
+ // for (size_t j = 0; j < 32; j++) {
7654
+ // uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
7655
+ //
7656
+ // if (first_err < 0 && error > 1) {
7657
+ // first_err = i;
7658
+ // }
7659
+ // }
7660
+ // }
7661
+ //
7662
+ // std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
7663
+ //
7664
+ // if (first_err != -1) {
7665
+ // std::cerr << "first_error = " << first_err << std::endl;
7666
+ // std::cerr << "Actual result: " << std::endl << std::endl;
7667
+ // std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
7668
+ // for (size_t j = 0; j < 32; j++) {
7669
+ // std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
7670
+ // }
7671
+ // std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
7672
+ // std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
7673
+ // for (size_t j = 0; j < 32; j++) {
7674
+ // std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
7675
+ // }
7676
+ // std::cerr << std::endl;
7677
+ // }
7678
+ //
7679
+ // ggml_vk_destroy_buffer(x_buf);
7680
+ // ggml_vk_destroy_buffer(qx_buf);
7681
+ //
7682
+ // free(x);
7683
+ // free(qx);
7684
+ // free(qx_res);
7685
+ // }
7686
+
7687
+ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
7240
7688
  VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
7241
7689
  const size_t x_ne = m * k * batch;
7242
7690
  const size_t y_ne = k * n * batch;
7243
7691
  const size_t d_ne = m * n * batch;
7244
7692
 
7693
+ vk_matmul_pipeline2 * pipelines;
7694
+
7695
+ if (mmq) {
7696
+ pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
7697
+ } else {
7698
+ pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
7699
+ }
7700
+
7701
+ const bool fp16acc = ctx->device->fp16;
7702
+
7245
7703
  vk_pipeline p;
7246
7704
  std::string shname;
7247
7705
  if (shader_size == 0) {
7248
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_s;
7706
+ p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
7249
7707
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
7250
7708
  } else if (shader_size == 1) {
7251
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_m;
7709
+ p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
7252
7710
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
7253
7711
  } else if (shader_size == 2) {
7254
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->a_l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->a_l;
7712
+ p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
7255
7713
  shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
7256
7714
  } else {
7257
7715
  GGML_ASSERT(0);
7258
7716
  }
7259
7717
 
7260
- const size_t kpad = ggml_vk_align_size(k, p->align);
7718
+ const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
7261
7719
 
7262
- if (k != kpad) {
7720
+ if (mmq || k != kpad) {
7263
7721
  if (shader_size == 0) {
7264
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->s : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->s;
7722
+ p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
7265
7723
  shname = std::string(ggml_type_name(quant)) + "_S";
7266
7724
  } else if (shader_size == 1) {
7267
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->m : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->m;
7725
+ p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
7268
7726
  shname = std::string(ggml_type_name(quant)) + "_M";
7269
7727
  } else if (shader_size == 2) {
7270
- p = ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[quant].f16acc->l : ctx->device->pipeline_dequant_mul_mat_mat[quant].f32acc->l;
7728
+ p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
7271
7729
  shname = std::string(ggml_type_name(quant)) + "_L";
7272
7730
  } else {
7273
7731
  GGML_ASSERT(0);
7274
7732
  }
7275
7733
  }
7276
7734
 
7735
+ if (p == nullptr) {
7736
+ std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
7737
+ return;
7738
+ }
7739
+
7277
7740
  const size_t x_sz = sizeof(float) * x_ne;
7278
7741
  const size_t y_sz = sizeof(float) * y_ne;
7279
7742
  const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
7743
+ const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
7280
7744
  const size_t d_sz = sizeof(float) * d_ne;
7281
7745
  float * x = (float *) malloc(x_sz);
7282
7746
  float * y = (float *) malloc(y_sz);
7283
7747
  void * qx = malloc(qx_sz);
7284
7748
  vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7285
7749
  vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7750
+ vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7286
7751
  vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
7287
7752
  float * d = (float *) malloc(d_sz);
7288
7753
  float * d_chk = (float *) malloc(d_sz);
7289
7754
 
7290
7755
  for (size_t i = 0; i < x_ne; i++) {
7291
7756
  x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
7757
+ // x[i] = (i % k == i / k) ? 1.0f : 0.0f;
7758
+ // x[i] = i % k;
7292
7759
  }
7293
7760
 
7294
7761
  ggml_vk_quantize_data(x, qx, x_ne, quant);
7295
7762
 
7296
7763
  for (size_t i = 0; i < y_ne; i++) {
7297
- // y[i] = rand() / (float)RAND_MAX;
7298
- y[i] = (i % k == i / k) ? 1.0f : 0.0f;
7764
+ y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
7765
+ // y[i] = (i % k == i / k) ? 1.0f : 0.0f;
7766
+ // y[i] = i % k;
7299
7767
  }
7300
7768
 
7301
7769
  ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
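
The commented-out test above spells out the block_q8_1 layout; for readers unfamiliar with it, the CPU-side reference the GPU path would be checked against quantizes each 32-float block roughly as below. This is a sketch outside the diff: d = max|x|/127, qs[i] = round(x[i]/d), s = d * sum(qs[i]); rounding and tie-breaking may differ slightly from ggml_vk_quantize_data and the shader.

// CPU-side sketch of quantizing one Q8_1 block (32 floats).
#include <algorithm>
#include <cmath>
#include <cstdint>

struct block_q8_1_ref { float d, s; int8_t qs[32]; };   // fp32 scalars for clarity

block_q8_1_ref quantize_block_q8_1_ref(const float * x) {
    block_q8_1_ref b{};
    float amax = 0.0f;
    for (int i = 0; i < 32; i++) amax = std::max(amax, std::fabs(x[i]));
    b.d = amax / 127.0f;                                 // delta
    const float id = b.d != 0.0f ? 1.0f / b.d : 0.0f;
    int sum = 0;
    for (int i = 0; i < 32; i++) {
        b.qs[i] = (int8_t)std::lround(x[i] * id);        // int8 quants
        sum += b.qs[i];
    }
    b.s = b.d * (float)sum;                              // d * sum(qs[i])
    return b;
}
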
@@ -7310,6 +7778,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7310
7778
  ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
7311
7779
  }
7312
7780
  }
7781
+ if (mmq) {
7782
+ ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
7783
+ }
7784
+
7785
+ if (ctx->device->need_compiles) {
7786
+ ggml_vk_load_shaders(ctx->device);
7787
+ }
7313
7788
 
7314
7789
  ggml_pipeline_allocate_descriptor_sets(ctx->device);
7315
7790
 
@@ -7318,13 +7793,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7318
7793
 
7319
7794
  vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
7320
7795
  ggml_vk_ctx_begin(ctx->device, subctx);
7321
- for (size_t i = 0; i < num_it; i++) {
7322
- ggml_vk_matmul(
7323
- ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
7324
- m, n, k,
7325
- k, k, m, k*m, k*n, m*n,
7326
- split_k, batch, batch, batch, 1, 1, n
7327
- );
7796
+ if (mmq) {
7797
+ for (size_t i = 0; i < num_it; i++) {
7798
+ ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
7799
+ ggml_vk_matmul(
7800
+ ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
7801
+ m, n, k,
7802
+ k, k, m, k*m, k*n, m*n,
7803
+ split_k, batch, batch, batch, 1, 1, n
7804
+ );
7805
+ }
7806
+ } else {
7807
+ for (size_t i = 0; i < num_it; i++) {
7808
+ ggml_vk_matmul(
7809
+ ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
7810
+ m, n, k,
7811
+ k, k, m, k*m, k*n, m*n,
7812
+ split_k, batch, batch, batch, 1, 1, n
7813
+ );
7814
+ }
7328
7815
  }
7329
7816
  ggml_vk_ctx_end(subctx);
7330
7817
 
@@ -7382,7 +7869,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7382
7869
 
7383
7870
  double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
7384
7871
 
7385
- std::cerr << "TEST MMQ " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
7872
+ std::cerr << "TEST dequant matmul " << shname;
7873
+ if (mmq) {
7874
+ std::cerr << " mmq";
7875
+ }
7876
+ std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
7386
7877
 
7387
7878
  if (avg_err > 0.01 || std::isnan(avg_err)) {
7388
7879
  std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
@@ -7392,6 +7883,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7392
7883
  std::cerr << "Expected result: " << std::endl << std::endl;
7393
7884
  ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
7394
7885
 
7886
+ std::cerr << "src0: " << std::endl << std::endl;
7887
+ ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
7888
+ std::cerr << std::endl;
7889
+ std::cerr << "src1: " << std::endl << std::endl;
7890
+ ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
7891
+
7395
7892
  if (split_k > 1) {
7396
7893
  float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
7397
7894
  ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
@@ -7414,6 +7911,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
7414
7911
 
7415
7912
  ggml_vk_destroy_buffer(qx_buf);
7416
7913
  ggml_vk_destroy_buffer(y_buf);
7914
+ ggml_vk_destroy_buffer(qy_buf);
7417
7915
  ggml_vk_destroy_buffer(d_buf);
7418
7916
 
7419
7917
  free(x);
@@ -7448,6 +7946,24 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
7448
7946
  };
7449
7947
  const size_t num_it = 100;
7450
7948
 
7949
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
7950
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
7951
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
7952
+
7953
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
7954
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
7955
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
7956
+
7957
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
7958
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
7959
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
7960
+
7961
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
7962
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
7963
+ ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
7964
+
7965
+ abort();
7966
+
7451
7967
  for (size_t i = 0; i < vals.size(); i += 3) {
7452
7968
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
7453
7969
  ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
@@ -7522,11 +8038,11 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
7522
8038
  }
7523
8039
  }
7524
8040
 
7525
- static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence);
8041
+ static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
7526
8042
 
7527
8043
  // Returns true if node has enqueued work into the queue, false otherwise
7528
8044
  // If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
7529
- static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){
8045
+ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
7530
8046
  if (ggml_is_empty(node) || !node->buffer) {
7531
8047
  return false;
7532
8048
  }
@@ -7898,7 +8414,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7898
8414
 
7899
8415
  ctx->compute_ctx.reset();
7900
8416
 
7901
- bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false);
8417
+ bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
7902
8418
  if (!ok) {
7903
8419
  if (node->op == GGML_OP_UNARY) {
7904
8420
  std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
@@ -7912,7 +8428,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
7912
8428
  return true;
7913
8429
  }
7914
8430
 
7915
- static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){
8431
+ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
7916
8432
  ggml_backend_buffer * buf = nullptr;
7917
8433
 
7918
8434
  switch (tensor->op) {
@@ -8015,12 +8531,15 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
8015
8531
  memcpy(cpy.dst, cpy.src, cpy.n);
8016
8532
  }
8017
8533
 
8018
- ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
8534
+ if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) {
8535
+ ggml_vk_submit(subctx, ctx->almost_ready_fence);
8536
+ ctx->almost_ready_fence_pending = true;
8537
+ } else {
8538
+ ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
8539
+ }
8019
8540
 
8020
8541
  if (use_fence) {
8021
- VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_compute_forward waitForFences");
8022
-
8023
- ctx->device->device.resetFences({ ctx->fence });
8542
+ ggml_vk_wait_for_fence(ctx);
8024
8543
  }
8025
8544
  #ifdef GGML_VULKAN_CHECK_RESULTS
8026
8545
  ggml_vk_check_results_1(tensor);
@@ -8106,6 +8625,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
8106
8625
  ctx->gc.events.clear();
8107
8626
 
8108
8627
  ctx->device->device.destroyFence(ctx->fence);
8628
+ ctx->device->device.destroyFence(ctx->almost_ready_fence);
8109
8629
  }
8110
8630
 
8111
8631
  static int ggml_vk_get_device_count() {
@@ -8452,8 +8972,7 @@ static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
8452
8972
  }
8453
8973
 
8454
8974
  ggml_vk_submit(transfer_ctx, ctx->fence);
8455
- VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_backend_vk_synchronize waitForFences");
8456
- ctx->device->device.resetFences({ ctx->fence });
8975
+ ggml_vk_wait_for_fence(ctx);
8457
8976
 
8458
8977
  for (auto& cpy : transfer_ctx->out_memcpys) {
8459
8978
  memcpy(cpy.dst, cpy.src, cpy.n);
@@ -8472,7 +8991,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
8472
8991
 
8473
8992
  uint64_t total_mat_mul_bytes = 0;
8474
8993
  for (int i = 0; i < cgraph->n_nodes; i++) {
8475
- ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
8994
+ ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
8476
8995
  if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
8477
8996
  total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
8478
8997
  }
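
Concretely, the threshold above fires once fewer than n_nodes/5 nodes remain, i.e. in the last ~20% of the graph; a quick standalone illustration (not part of the diff):

// Illustration of the almost_ready threshold used above.
#include <cstdio>

int main() {
    const int n_nodes = 100;
    for (int i = 0; i < n_nodes; i++) {
        const bool almost_ready = (n_nodes - i) < n_nodes / 5;  // last ~20%
        if (almost_ready) {
            printf("almost_ready first true at node %d of %d\n", i, n_nodes);
            break;
        }
    }
}
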
@@ -8514,11 +9033,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
8514
9033
  mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
8515
9034
  }
8516
9035
 
9036
+ // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
9037
+ bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
8517
9038
  bool submit = (submitted_nodes >= nodes_per_submit) ||
8518
9039
  (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
8519
- (i == last_node);
9040
+ (i == last_node) ||
9041
+ (almost_ready && !ctx->almost_ready_fence_pending);
8520
9042
 
8521
- bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
9043
+ bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
8522
9044
 
8523
9045
  if (enqueued) {
8524
9046
  ++submitted_nodes;
@@ -8530,7 +9052,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
8530
9052
  #endif
8531
9053
  }
8532
9054
 
8533
- if (submit) {
9055
+ if (submit && enqueued) {
8534
9056
  first_node_in_batch = true;
8535
9057
  submitted_nodes = 0;
8536
9058
  mul_mat_bytes = 0;
@@ -8760,10 +9282,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8760
9282
  case 112:
8761
9283
  case 128:
8762
9284
  case 256:
9285
+ case 575: // DeepSeek MLA
8763
9286
  break;
8764
9287
  default:
8765
9288
  return false;
8766
9289
  }
9290
+ if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
9291
+ // different head sizes of K and V are not supported yet
9292
+ return false;
9293
+ }
8767
9294
  if (op->src[0]->type != GGML_TYPE_F32) {
8768
9295
  return false;
8769
9296
  }
@@ -8882,10 +9409,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8882
9409
  case GGML_OP_VIEW:
8883
9410
  case GGML_OP_PERMUTE:
8884
9411
  case GGML_OP_TRANSPOSE:
9412
+ case GGML_OP_RMS_NORM:
8885
9413
  return true;
8886
9414
  case GGML_OP_NORM:
8887
9415
  case GGML_OP_GROUP_NORM:
8888
- case GGML_OP_RMS_NORM:
8889
9416
  case GGML_OP_L2_NORM:
8890
9417
  return ggml_is_contiguous(op->src[0]);
8891
9418
  case GGML_OP_ADD:
@@ -8899,9 +9426,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
8899
9426
  case GGML_OP_COS:
8900
9427
  case GGML_OP_CLAMP:
8901
9428
  return op->src[0]->type == GGML_TYPE_F32;
9429
+ case GGML_OP_UPSCALE:
9430
+ return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
8902
9431
  case GGML_OP_ACC:
8903
9432
  case GGML_OP_CONCAT:
8904
- case GGML_OP_UPSCALE:
8905
9433
  case GGML_OP_SCALE:
8906
9434
  case GGML_OP_PAD:
8907
9435
  case GGML_OP_DIAG_MASK_INF:
@@ -9254,7 +9782,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9254
9782
  }
9255
9783
 
9256
9784
  if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
9257
- const float *params = (const float *)tensor->op_params;
9785
+ const float * params = (const float *)tensor->op_params;
9258
9786
  tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
9259
9787
  } else if (tensor->op == GGML_OP_MUL_MAT) {
9260
9788
  tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
@@ -9269,9 +9797,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9269
9797
  } else if (tensor->op == GGML_OP_CONCAT) {
9270
9798
  tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
9271
9799
  } else if (tensor->op == GGML_OP_UPSCALE) {
9272
- tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
9800
+ tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]);
9273
9801
  } else if (tensor->op == GGML_OP_SCALE) {
9274
- tensor_clone = ggml_scale(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0]);
9802
+ const float * params = (const float *)tensor->op_params;
9803
+ tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
9275
9804
  } else if (tensor->op == GGML_OP_SQR) {
9276
9805
  tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
9277
9806
  } else if (tensor->op == GGML_OP_SIN) {
@@ -9279,7 +9808,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9279
9808
  } else if (tensor->op == GGML_OP_COS) {
9280
9809
  tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
9281
9810
  } else if (tensor->op == GGML_OP_CLAMP) {
9282
- tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
9811
+ const float * params = (const float *)tensor->op_params;
9812
+ tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
9283
9813
  } else if (tensor->op == GGML_OP_PAD) {
9284
9814
  tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
9285
9815
  } else if (tensor->op == GGML_OP_REPEAT) {
@@ -9293,7 +9823,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9293
9823
  } else if (tensor->op == GGML_OP_NORM) {
9294
9824
  tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
9295
9825
  } else if (tensor->op == GGML_OP_GROUP_NORM) {
9296
- tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
9826
+ const float * float_params = (const float *)tensor->op_params;
9827
+ tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
9297
9828
  } else if (tensor->op == GGML_OP_RMS_NORM) {
9298
9829
  tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
9299
9830
  } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
@@ -9306,14 +9837,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
9306
9837
  tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
9307
9838
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
9308
9839
  if (src1 != nullptr) {
9309
- tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
9840
+ const float * params = (const float *)tensor->op_params;
9841
+ tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
9310
9842
  } else {
9311
9843
  tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
9312
9844
  }
9313
9845
  } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
9314
9846
  tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
9315
9847
  } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
9316
- tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
9848
+ tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
9317
9849
  } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
9318
9850
  const int n_dims = ((int32_t *) tensor->op_params)[1];
9319
9851
  const int mode = ((int32_t *) tensor->op_params)[2];