@fugood/llama.node 0.3.7 → 0.3.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (186):
  1. package/README.md +17 -2
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +2 -0
  19. package/lib/index.js +16 -1
  20. package/lib/index.ts +16 -0
  21. package/package.json +1 -1
  22. package/src/EmbeddingWorker.cpp +4 -3
  23. package/src/LlamaCompletionWorker.cpp +4 -2
  24. package/src/LlamaContext.cpp +61 -6
  25. package/src/LlamaContext.h +1 -0
  26. package/src/common.hpp +6 -11
  27. package/src/llama.cpp/.github/workflows/build.yml +19 -17
  28. package/src/llama.cpp/.github/workflows/docker.yml +77 -30
  29. package/src/llama.cpp/.github/workflows/editorconfig.yml +3 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +22 -3
  31. package/src/llama.cpp/CMakeLists.txt +49 -24
  32. package/src/llama.cpp/common/arg.cpp +82 -26
  33. package/src/llama.cpp/common/arg.h +3 -0
  34. package/src/llama.cpp/common/common.cpp +192 -72
  35. package/src/llama.cpp/common/common.h +51 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +12 -12
  37. package/src/llama.cpp/common/ngram-cache.h +2 -2
  38. package/src/llama.cpp/common/sampling.cpp +11 -6
  39. package/src/llama.cpp/common/speculative.cpp +18 -15
  40. package/src/llama.cpp/docs/build.md +2 -0
  41. package/src/llama.cpp/examples/batched/batched.cpp +9 -7
  42. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +3 -3
  43. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +10 -8
  44. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +11 -8
  45. package/src/llama.cpp/examples/cvector-generator/mean.hpp +1 -1
  46. package/src/llama.cpp/examples/cvector-generator/pca.hpp +1 -1
  47. package/src/llama.cpp/examples/embedding/embedding.cpp +8 -7
  48. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +7 -6
  49. package/src/llama.cpp/examples/export-lora/export-lora.cpp +8 -7
  50. package/src/llama.cpp/examples/gguf/gguf.cpp +10 -6
  51. package/src/llama.cpp/examples/gguf-hash/gguf-hash.cpp +1 -0
  52. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +8 -7
  53. package/src/llama.cpp/examples/gritlm/gritlm.cpp +13 -10
  54. package/src/llama.cpp/examples/imatrix/imatrix.cpp +13 -12
  55. package/src/llama.cpp/examples/infill/infill.cpp +23 -24
  56. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +44 -13
  57. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -6
  58. package/src/llama.cpp/examples/llava/clip.cpp +4 -2
  59. package/src/llama.cpp/examples/llava/llava-cli.cpp +9 -6
  60. package/src/llama.cpp/examples/llava/llava.cpp +2 -2
  61. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +8 -4
  62. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +11 -8
  63. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -7
  64. package/src/llama.cpp/examples/lookup/lookup-create.cpp +4 -9
  65. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +3 -7
  66. package/src/llama.cpp/examples/lookup/lookup.cpp +5 -6
  67. package/src/llama.cpp/examples/main/main.cpp +51 -29
  68. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -6
  69. package/src/llama.cpp/examples/passkey/passkey.cpp +7 -5
  70. package/src/llama.cpp/examples/perplexity/perplexity.cpp +37 -23
  71. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -14
  72. package/src/llama.cpp/examples/retrieval/retrieval.cpp +8 -8
  73. package/src/llama.cpp/examples/rpc/rpc-server.cpp +12 -0
  74. package/src/llama.cpp/examples/run/CMakeLists.txt +1 -1
  75. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +1351 -0
  76. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +114 -0
  77. package/src/llama.cpp/examples/run/run.cpp +175 -61
  78. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -25
  79. package/src/llama.cpp/examples/server/CMakeLists.txt +1 -0
  80. package/src/llama.cpp/examples/server/httplib.h +1295 -409
  81. package/src/llama.cpp/examples/server/server.cpp +387 -181
  82. package/src/llama.cpp/examples/server/tests/requirements.txt +1 -0
  83. package/src/llama.cpp/examples/server/utils.hpp +170 -58
  84. package/src/llama.cpp/examples/simple/simple.cpp +9 -8
  85. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +16 -12
  86. package/src/llama.cpp/examples/speculative/speculative.cpp +22 -23
  87. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +8 -12
  88. package/src/llama.cpp/examples/tokenize/tokenize.cpp +17 -5
  89. package/src/llama.cpp/examples/tts/tts.cpp +64 -23
  90. package/src/llama.cpp/ggml/CMakeLists.txt +5 -21
  91. package/src/llama.cpp/ggml/include/ggml-backend.h +2 -0
  92. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -0
  93. package/src/llama.cpp/ggml/include/ggml.h +36 -145
  94. package/src/llama.cpp/ggml/include/gguf.h +202 -0
  95. package/src/llama.cpp/ggml/src/CMakeLists.txt +6 -3
  96. package/src/llama.cpp/ggml/src/ggml-alloc.c +5 -0
  97. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +0 -1
  98. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +79 -49
  99. package/src/llama.cpp/ggml/src/ggml-backend.cpp +5 -2
  100. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +33 -23
  101. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +57 -72
  102. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +87 -2
  103. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +335 -66
  104. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +10 -2
  105. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1090 -378
  106. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.h +2 -2
  107. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +1 -0
  108. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +3 -0
  109. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -0
  110. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +3 -1
  111. package/src/llama.cpp/ggml/src/ggml-impl.h +11 -16
  112. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +16 -0
  113. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +6 -6
  114. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +154 -35
  115. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +9 -3
  117. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +18 -0
  118. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/concat.hpp +1 -2
  120. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +3 -2
  121. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +1 -2
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +40 -95
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +48 -48
  124. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +24 -24
  125. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +238 -164
  126. package/src/llama.cpp/ggml/src/ggml-sycl/gla.cpp +105 -0
  127. package/src/llama.cpp/ggml/src/ggml-sycl/gla.hpp +8 -0
  128. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +3 -3
  129. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +1 -2
  130. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +3 -2
  131. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +1 -2
  132. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +7 -5
  133. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +1 -2
  134. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +74 -4
  135. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +314 -116
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +4 -2
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +9 -3
  138. package/src/llama.cpp/ggml/src/ggml.c +117 -1327
  139. package/src/llama.cpp/ggml/src/gguf.cpp +1329 -0
  140. package/src/llama.cpp/include/llama-cpp.h +6 -1
  141. package/src/llama.cpp/include/llama.h +138 -75
  142. package/src/llama.cpp/src/CMakeLists.txt +13 -1
  143. package/src/llama.cpp/src/llama-adapter.cpp +347 -0
  144. package/src/llama.cpp/src/llama-adapter.h +74 -0
  145. package/src/llama.cpp/src/llama-arch.cpp +1487 -0
  146. package/src/llama.cpp/src/llama-arch.h +400 -0
  147. package/src/llama.cpp/src/llama-batch.cpp +368 -0
  148. package/src/llama.cpp/src/llama-batch.h +88 -0
  149. package/src/llama.cpp/src/llama-chat.cpp +578 -0
  150. package/src/llama.cpp/src/llama-chat.h +52 -0
  151. package/src/llama.cpp/src/llama-context.cpp +1775 -0
  152. package/src/llama.cpp/src/llama-context.h +128 -0
  153. package/src/llama.cpp/src/llama-cparams.cpp +1 -0
  154. package/src/llama.cpp/src/llama-cparams.h +37 -0
  155. package/src/llama.cpp/src/llama-grammar.cpp +5 -4
  156. package/src/llama.cpp/src/llama-grammar.h +3 -1
  157. package/src/llama.cpp/src/llama-hparams.cpp +71 -0
  158. package/src/llama.cpp/src/llama-hparams.h +139 -0
  159. package/src/llama.cpp/src/llama-impl.cpp +167 -0
  160. package/src/llama.cpp/src/llama-impl.h +16 -136
  161. package/src/llama.cpp/src/llama-kv-cache.cpp +718 -0
  162. package/src/llama.cpp/src/llama-kv-cache.h +218 -0
  163. package/src/llama.cpp/src/llama-mmap.cpp +589 -0
  164. package/src/llama.cpp/src/llama-mmap.h +67 -0
  165. package/src/llama.cpp/src/llama-model-loader.cpp +1124 -0
  166. package/src/llama.cpp/src/llama-model-loader.h +167 -0
  167. package/src/llama.cpp/src/llama-model.cpp +3953 -0
  168. package/src/llama.cpp/src/llama-model.h +370 -0
  169. package/src/llama.cpp/src/llama-quant.cpp +934 -0
  170. package/src/llama.cpp/src/llama-quant.h +1 -0
  171. package/src/llama.cpp/src/llama-sampling.cpp +147 -32
  172. package/src/llama.cpp/src/llama-sampling.h +3 -19
  173. package/src/llama.cpp/src/llama-vocab.cpp +1832 -575
  174. package/src/llama.cpp/src/llama-vocab.h +97 -142
  175. package/src/llama.cpp/src/llama.cpp +7160 -20314
  176. package/src/llama.cpp/src/unicode.cpp +8 -3
  177. package/src/llama.cpp/tests/CMakeLists.txt +2 -0
  178. package/src/llama.cpp/tests/test-autorelease.cpp +3 -3
  179. package/src/llama.cpp/tests/test-backend-ops.cpp +370 -59
  180. package/src/llama.cpp/tests/test-chat-template.cpp +162 -125
  181. package/src/llama.cpp/tests/test-gguf.cpp +222 -187
  182. package/src/llama.cpp/tests/test-model-load-cancel.cpp +1 -1
  183. package/src/llama.cpp/tests/test-sampling.cpp +0 -1
  184. package/src/llama.cpp/tests/test-tokenizer-0.cpp +4 -4
  185. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +9 -7
  186. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +8 -6
@@ -145,6 +145,8 @@ class vk_perf_logger;
145
145
  #endif
146
146
  static void ggml_vk_destroy_buffer(vk_buffer& buf);
147
147
 
148
+ static constexpr uint32_t mul_mat_vec_max_cols = 8;
149
+
148
150
  struct vk_device_struct {
149
151
  std::mutex mutex;
150
152
 
@@ -202,8 +204,8 @@ struct vk_device_struct {
202
204
  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
203
205
 
204
206
  vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
205
- vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
206
- vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT];
207
+ vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
208
+ vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
207
209
  vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
208
210
 
209
211
  vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
@@ -226,6 +228,8 @@ struct vk_device_struct {
226
228
  vk_pipeline pipeline_repeat_f32;
227
229
  vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
228
230
  vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
231
+ vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
232
+ vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
229
233
  vk_pipeline pipeline_norm_f32;
230
234
  vk_pipeline pipeline_group_norm_f32;
231
235
  vk_pipeline pipeline_rms_norm_f32;
@@ -382,10 +386,13 @@ struct vk_flash_attn_push_constants {
382
386
  uint32_t nev3;
383
387
  uint32_t nem1;
384
388
 
389
+ uint32_t nb01;
385
390
  uint32_t nb02;
386
391
  uint32_t nb03;
392
+ uint32_t nb11;
387
393
  uint32_t nb12;
388
394
  uint32_t nb13;
395
+ uint32_t nb21;
389
396
  uint32_t nb22;
390
397
  uint32_t nb23;
391
398
  uint32_t nb31;
@@ -411,7 +418,7 @@ struct vk_op_unary_push_constants {
411
418
  uint32_t ne;
412
419
  uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
413
420
  uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
414
- uint32_t d_offset;
421
+ uint32_t misalign_offsets;
415
422
  float param1; float param2;
416
423
  uint32_t ne0_012mp; uint32_t ne0_012L;
417
424
  uint32_t ne0_01mp; uint32_t ne0_01L;
@@ -459,7 +466,7 @@ struct vk_op_binary_push_constants {
459
466
  uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
460
467
  uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
461
468
  uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
462
- uint32_t d_offset;
469
+ uint32_t misalign_offsets;
463
470
  float param1; float param2; int32_t param3;
464
471
  };
465
472
 
@@ -546,7 +553,7 @@ struct vk_staging_memcpy {
546
553
  };
547
554
 
548
555
  struct vk_op_upscale_push_constants {
549
- uint32_t ne; uint32_t d_offset;
556
+ uint32_t ne; uint32_t a_offset; uint32_t d_offset;
550
557
  uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
551
558
  uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
552
559
  float sf0; float sf1; float sf2; float sf3;
@@ -1404,10 +1411,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
1404
1411
  // spec constants and tile sizes for non-quant matmul/matmul_id
1405
1412
  l_warptile = { 256, 128, 256, 64 };
1406
1413
  m_warptile = { 256, 128, 128, 64 };
1407
- s_warptile = { 128, 32, 16, 64 };
1414
+ s_warptile = { 128, 64, 64, 64 };
1408
1415
  l_wg_denoms = {128, 256, 1 };
1409
1416
  m_wg_denoms = {128, 128, 1 };
1410
- s_wg_denoms = { 32, 16, 1 };
1417
+ s_wg_denoms = { 64, 64, 1 };
1411
1418
 
1412
1419
  // spec constants and tile sizes for quant matmul (non-Qi_K)
1413
1420
  l_warptile_mmq = { 256, 128, 256, 64 };
@@ -1643,6 +1650,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
1643
1650
  #undef CREATE_MM2
1644
1651
  } else
1645
1652
  #endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
1653
+ #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
1646
1654
  if (device->coopmat_support) {
1647
1655
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1648
1656
  #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -1737,7 +1745,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
1737
1745
  }
1738
1746
  #undef CREATE_MM2
1739
1747
  #undef CREATE_MM
1740
- } else if (device->fp16) {
1748
+ } else
1749
+ #endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
1750
+ if (device->fp16) {
1741
1751
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1742
1752
  #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1743
1753
  if (device->mul_mat ## ID ## _l) \
@@ -1855,53 +1865,60 @@ static void ggml_vk_load_shaders(vk_device& device) {
1855
1865
 
1856
1866
  // mul mat vec
1857
1867
 
1858
- // AMD GCN and Intel graphics cards perform best when the number of rows per shader is doubled
1859
- uint32_t rm = 1;
1860
- if ((device->vendor_id == VK_VENDOR_ID_AMD && device->subgroup_min_size == 64 && device->subgroup_max_size == 64) || device->vendor_id == VK_VENDOR_ID_INTEL)
1861
- rm = 2;
1862
-
1863
- // computing additional rows per workgroup is a benefit for Q4_0 -> Q5_1, but not for Q8_0.
1864
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1865
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1866
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1867
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1868
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1869
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1870
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
1871
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1872
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1873
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1874
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1875
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1876
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
1877
-
1878
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1879
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1880
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1881
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1882
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1883
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1884
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
1885
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1886
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1887
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1888
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1889
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1890
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
1868
+ // the number of rows computed per shader depends on GPU model and quant
1869
+ uint32_t rm_stdq = 1;
1870
+ uint32_t rm_kq = 2;
1871
+ if (device->vendor_id == VK_VENDOR_ID_AMD) {
1872
+ if (device->subgroup_min_size == 64 && device->subgroup_max_size == 64) { // GCN
1873
+ rm_stdq = 2;
1874
+ rm_kq = 4;
1875
+ }
1876
+ } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
1877
+ rm_stdq = 2;
1878
+
1879
+ for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
1880
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
1881
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
1882
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
1883
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
1884
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
1885
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
1886
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
1887
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1888
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1889
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1890
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1891
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1892
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
1893
+
1894
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
1895
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
1896
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
1897
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
1898
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
1899
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
1900
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
1901
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1902
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1903
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1904
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1905
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
1906
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
1907
+ }
1891
1908
 
1892
1909
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1893
1910
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
1894
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1895
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1896
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1897
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {device->subgroup_size, 2*rm}, 1, true);
1898
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm, 1, 1}, {device->subgroup_size, 1*rm}, 1, true);
1899
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1900
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1901
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1902
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1903
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1, 1, 1}, {subgroup_size_16}, 1, true);
1904
- ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm, 1, 1}, {subgroup_size_16, 2*rm}, 1, true);
1911
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
1912
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
1913
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
1914
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
1915
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
1916
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
1917
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
1918
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
1919
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
1920
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
1921
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
1905
1922
 
1906
1923
  // dequant shaders
1907
1924
  ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -1953,6 +1970,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
1953
1970
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1954
1971
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
1955
1972
 
1973
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
1974
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
1975
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
1976
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
1977
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
1978
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
1979
+
1980
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
1981
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
1982
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
1983
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
1984
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
1985
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
1986
+
1956
1987
  ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
1957
1988
  ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
1958
1989
  ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
@@ -2012,11 +2043,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
2012
2043
 
2013
2044
  ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
2014
2045
 
2015
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
2046
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2016
2047
  if (device->float_controls_rte_fp16) {
2017
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
2048
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2018
2049
  } else {
2019
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
2050
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
2020
2051
  }
2021
2052
 
2022
2053
  ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@@ -2031,6 +2062,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2031
2062
  std::cerr << "Done!" << std::endl;
2032
2063
  }
2033
2064
 
2065
+ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
2066
+
2034
2067
  static vk_device ggml_vk_get_device(size_t idx) {
2035
2068
  VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
2036
2069
 
@@ -2166,9 +2199,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2166
2199
 
2167
2200
  device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
2168
2201
 
2169
- if (device->vendor_id == VK_VENDOR_ID_INTEL || (device->vendor_id == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
2170
- // Intel drivers don't support coopmat properly yet
2171
- // Only RADV supports coopmat properly on AMD
2202
+ if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
2172
2203
  device->coopmat_support = false;
2173
2204
  }
2174
2205
 
@@ -2233,6 +2264,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2233
2264
  last_struct = (VkBaseOutStructure *)&subgroup_size_control_features;
2234
2265
  }
2235
2266
 
2267
+ #if defined(VK_KHR_cooperative_matrix)
2236
2268
  VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
2237
2269
  coopmat_features.pNext = nullptr;
2238
2270
  coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
@@ -2242,6 +2274,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2242
2274
  last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
2243
2275
  last_struct = (VkBaseOutStructure *)&coopmat_features;
2244
2276
  }
2277
+ #endif
2245
2278
 
2246
2279
  #if defined(VK_NV_cooperative_matrix2)
2247
2280
  VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {};
@@ -2263,6 +2296,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2263
2296
  if (device->subgroup_size_control) {
2264
2297
  device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
2265
2298
  device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
2299
+ device_extensions.push_back("VK_EXT_subgroup_size_control");
2266
2300
  }
2267
2301
 
2268
2302
  device->subgroup_size_control = device->subgroup_size_control &&
@@ -2271,10 +2305,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
2271
2305
 
2272
2306
  if (device->subgroup_size_control) {
2273
2307
  device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
2274
- device_extensions.push_back("VK_EXT_subgroup_size_control");
2275
2308
  }
2276
2309
 
2310
+ #if defined(VK_KHR_cooperative_matrix)
2277
2311
  device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
2312
+ #endif
2278
2313
 
2279
2314
  if (coopmat2_support) {
2280
2315
  #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
@@ -2367,6 +2402,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2367
2402
  device_extensions.push_back("VK_KHR_shader_float16_int8");
2368
2403
  }
2369
2404
 
2405
+ #if defined(VK_KHR_cooperative_matrix)
2370
2406
  if (device->coopmat_support) {
2371
2407
  // Query supported shapes
2372
2408
  std::vector<VkCooperativeMatrixPropertiesKHR> cm_props;
@@ -2433,7 +2469,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2433
2469
  if (device->coopmat_support) {
2434
2470
  device_extensions.push_back("VK_KHR_cooperative_matrix");
2435
2471
  }
2436
-
2472
+ #endif
2437
2473
  device->name = GGML_VK_NAME + std::to_string(idx);
2438
2474
 
2439
2475
  device_create_info = {
@@ -2506,7 +2542,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
2506
2542
  return vk_instance.devices[idx];
2507
2543
  }
2508
2544
 
2509
-
2510
2545
  static void ggml_vk_print_gpu_info(size_t idx) {
2511
2546
  GGML_ASSERT(idx < vk_instance.device_indices.size());
2512
2547
  size_t dev_num = vk_instance.device_indices[idx];
@@ -2545,9 +2580,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2545
2580
  fp16_storage = true;
2546
2581
  } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) {
2547
2582
  fp16_compute = true;
2548
- } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
2583
+ #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
2584
+ } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 &&
2549
2585
  !getenv("GGML_VK_DISABLE_COOPMAT")) {
2550
2586
  coopmat_support = true;
2587
+ #endif
2551
2588
  #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
2552
2589
  } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
2553
2590
  !getenv("GGML_VK_DISABLE_COOPMAT2")) {
@@ -2556,9 +2593,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2556
2593
  }
2557
2594
  }
2558
2595
 
2559
- if (props2.properties.vendorID == VK_VENDOR_ID_INTEL || (props2.properties.vendorID == VK_VENDOR_ID_AMD && (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource))) {
2560
- // Intel drivers don't support coopmat properly yet
2561
- // Only RADV supports coopmat properly on AMD
2596
+ if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
2562
2597
  coopmat_support = false;
2563
2598
  }
2564
2599
 
@@ -2587,6 +2622,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2587
2622
  // Pointer to the last chain element
2588
2623
  VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_features;
2589
2624
 
2625
+ #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
2590
2626
  VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
2591
2627
  coopmat_features.pNext = nullptr;
2592
2628
  coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR;
@@ -2602,6 +2638,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
2602
2638
  fp16 = fp16 && vk12_features.shaderFloat16;
2603
2639
 
2604
2640
  coopmat_support = coopmat_support && coopmat_features.cooperativeMatrix;
2641
+ #endif
2605
2642
 
2606
2643
  std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
2607
2644
 
@@ -2887,9 +2924,10 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
2887
2924
  return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
2888
2925
  }
2889
2926
 
2890
- static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
2927
+ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
2891
2928
  VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
2892
2929
  GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
2930
+ GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
2893
2931
 
2894
2932
  switch (a_type) {
2895
2933
  case GGML_TYPE_F32:
@@ -2910,7 +2948,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
2910
2948
  return nullptr;
2911
2949
  }
2912
2950
 
2913
- return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
2951
+ return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1];
2914
2952
  }
2915
2953
 
2916
2954
  static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
@@ -3205,8 +3243,8 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
3205
3243
  GGML_ABORT("fatal error");
3206
3244
  }
3207
3245
  // Check if src is pinned memory
3208
- vk_buffer buf;
3209
- size_t buf_offset;
3246
+ vk_buffer buf = nullptr;
3247
+ size_t buf_offset = 0;
3210
3248
  ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset);
3211
3249
 
3212
3250
  const uint64_t ne0 = tensor->ne[0];
@@ -3269,7 +3307,7 @@ static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_cont
3269
3307
  VkBufferCopy buf_copy{ 0, offset, copy_size };
3270
3308
 
3271
3309
  ggml_vk_sync_buffers(subctx);
3272
- vkCmdCopyBuffer(subctx->s->buffer, staging->buffer, dst->buffer, 1, &buf_copy);
3310
+ vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
3273
3311
 
3274
3312
  for (uint64_t i3 = 0; i3 < ne3; i3++) {
3275
3313
  for (uint64_t i2 = 0; i2 < ne2; i2++) {
@@ -3302,7 +3340,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
3302
3340
  }
3303
3341
  // Check if src is pinned memory
3304
3342
  vk_buffer buf = nullptr;
3305
- size_t buf_offset;
3343
+ size_t buf_offset = 0;
3306
3344
  ggml_vk_host_get(dst->device, src, buf, buf_offset);
3307
3345
 
3308
3346
  if (buf != nullptr) {
@@ -3344,7 +3382,7 @@ static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, siz
3344
3382
  copy_size};
3345
3383
 
3346
3384
  ggml_vk_sync_buffers(subctx);
3347
- vkCmdCopyBuffer(subctx->s->buffer, staging_buffer->buffer, dst->buffer, 1, &buf_copy);
3385
+ vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy);
3348
3386
 
3349
3387
  if (width == spitch) {
3350
3388
  deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys);
@@ -3400,7 +3438,7 @@ static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size
3400
3438
 
3401
3439
  // Check if dst is pinned memory
3402
3440
  vk_buffer buf = nullptr;
3403
- size_t buf_offset;
3441
+ size_t buf_offset = 0;
3404
3442
  ggml_vk_host_get(src->device, dst, buf, buf_offset);
3405
3443
 
3406
3444
  std::vector<vk::BufferCopy> slices(1);
@@ -3480,7 +3518,7 @@ static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t ds
3480
3518
 
3481
3519
  VkBufferCopy bc{ src_offset, dst_offset, size };
3482
3520
 
3483
- vkCmdCopyBuffer(ctx->s->buffer, src->buffer, dst->buffer, 1, &bc);
3521
+ vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc);
3484
3522
  }
3485
3523
 
3486
3524
  static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) {
@@ -3670,6 +3708,33 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
3670
3708
  return ctx->device->pipeline_cpy_f16_f16;
3671
3709
  }
3672
3710
  }
3711
+ if (src->type == GGML_TYPE_F32) {
3712
+ switch (to) {
3713
+ case GGML_TYPE_Q4_0:
3714
+ case GGML_TYPE_Q4_1:
3715
+ case GGML_TYPE_Q5_0:
3716
+ case GGML_TYPE_Q5_1:
3717
+ case GGML_TYPE_Q8_0:
3718
+ case GGML_TYPE_IQ4_NL:
3719
+ return ctx->device->pipeline_cpy_f32_quant[to];
3720
+ default:
3721
+ break;
3722
+ }
3723
+ }
3724
+
3725
+ if (to == GGML_TYPE_F32) {
3726
+ switch (src->type) {
3727
+ case GGML_TYPE_Q4_0:
3728
+ case GGML_TYPE_Q4_1:
3729
+ case GGML_TYPE_Q5_0:
3730
+ case GGML_TYPE_Q5_1:
3731
+ case GGML_TYPE_Q8_0:
3732
+ case GGML_TYPE_IQ4_NL:
3733
+ return ctx->device->pipeline_cpy_quant_f32[src->type];
3734
+ default:
3735
+ break;
3736
+ }
3737
+ }
3673
3738
 
3674
3739
  std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
3675
3740
  GGML_ABORT("fatal error");
@@ -3732,9 +3797,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
3732
3797
  ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
3733
3798
  ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
3734
3799
 
3735
- vk_buffer d_Qx;
3800
+ vk_buffer d_Qx = nullptr;
3736
3801
  size_t qx_buf_offset = 0;
3737
- vk_buffer d_Qy;
3802
+ vk_buffer d_Qy = nullptr;
3738
3803
  size_t qy_buf_offset = 0;
3739
3804
 
3740
3805
  bool src0_uma = false;
@@ -3920,8 +3985,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
3920
3985
  const uint64_t ne12 = src1->ne[2];
3921
3986
  const uint64_t ne13 = src1->ne[3];
3922
3987
 
3923
- GGML_ASSERT(ne11 == 1);
3924
-
3925
3988
  const uint64_t ne20 = dst->ne[0];
3926
3989
  const uint64_t ne21 = dst->ne[1];
3927
3990
  const uint64_t ne22 = dst->ne[2];
@@ -3930,13 +3993,18 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
3930
3993
  const uint64_t r2 = ne12 / ne02;
3931
3994
  const uint64_t r3 = ne13 / ne03;
3932
3995
 
3996
+ // batch_n indicates that we need to compute a few vector results, and this assumes
3997
+ // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides.
3998
+ GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);
3999
+ bool batch_n = ne11 > 1;
4000
+
3933
4001
  ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
3934
4002
  ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
3935
4003
  ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
3936
4004
 
3937
- vk_buffer d_Qx;
4005
+ vk_buffer d_Qx = nullptr;
3938
4006
  size_t qx_buf_offset = 0;
3939
- vk_buffer d_Qy;
4007
+ vk_buffer d_Qy = nullptr;
3940
4008
  size_t qy_buf_offset = 0;
3941
4009
 
3942
4010
  bool src0_uma = false;
@@ -3980,7 +4048,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
3980
4048
  } else {
3981
4049
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
3982
4050
  }
3983
- vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
4051
+ vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11);
3984
4052
  GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
3985
4053
  GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
3986
4054
  GGML_ASSERT(dmmv != nullptr);
@@ -4052,8 +4120,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
4052
4120
  ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
4053
4121
  }
4054
4122
 
4055
- uint32_t stride_batch_x = ne00*ne01;
4056
- uint32_t stride_batch_y = ne10*ne11;
4123
+ // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
4124
+ uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
4125
+ uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);
4126
+ uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21);
4057
4127
 
4058
4128
  if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
4059
4129
  stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
@@ -4076,7 +4146,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
4076
4146
  // compute
4077
4147
  const vk_mat_vec_push_constants pc = {
4078
4148
  (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
4079
- stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
4149
+ stride_batch_x, stride_batch_y, stride_batch_d,
4080
4150
  (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
4081
4151
  };
4082
4152
  ggml_vk_sync_buffers(subctx);
@@ -4112,7 +4182,7 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
4112
4182
  ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
4113
4183
  ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
4114
4184
 
4115
- vk_buffer d_Qy;
4185
+ vk_buffer d_Qy = nullptr;
4116
4186
  size_t qy_buf_offset = 0;
4117
4187
 
4118
4188
  bool src1_uma = false;
@@ -4256,7 +4326,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
4256
4326
  } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
4257
4327
  !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
4258
4328
  ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
4259
- } else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
4329
+ // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
4330
+ // when ne12 and ne13 are one.
4331
+ } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
4332
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
4260
4333
  ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
4261
4334
  } else {
4262
4335
  ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -4300,11 +4373,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
4300
4373
  ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
4301
4374
  ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
4302
4375
 
4303
- vk_buffer d_Qx;
4376
+ vk_buffer d_Qx = nullptr;
4304
4377
  size_t qx_buf_offset = 0;
4305
- vk_buffer d_Qy;
4378
+ vk_buffer d_Qy = nullptr;
4306
4379
  size_t qy_buf_offset = 0;
4307
- vk_buffer d_ids;
4380
+ vk_buffer d_ids = nullptr;
4308
4381
  size_t ids_buf_offset = 0;
4309
4382
 
4310
4383
  bool src0_uma = false;
@@ -4505,11 +4578,11 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
4505
4578
  ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
4506
4579
  ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context;
4507
4580
 
4508
- vk_buffer d_Qx;
4581
+ vk_buffer d_Qx = nullptr;
4509
4582
  size_t qx_buf_offset = 0;
4510
- vk_buffer d_Qy;
4583
+ vk_buffer d_Qy = nullptr;
4511
4584
  size_t qy_buf_offset = 0;
4512
- vk_buffer d_ids;
4585
+ vk_buffer d_ids = nullptr;
4513
4586
  size_t ids_buf_offset = 0;
4514
4587
 
4515
4588
  bool src0_uma = false;
@@ -4739,7 +4812,14 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
4739
4812
  }
4740
4813
  assert(pipelines);
4741
4814
 
4742
- bool aligned = (KV % pipelines[1]->align) == 0;
4815
+ const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
4816
+ const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
4817
+ const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
4818
+
4819
+ bool aligned = (KV % pipelines[1]->align) == 0 &&
4820
+ // the "aligned" shader variant will forcibly align strides, for performance
4821
+ (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
4822
+
4743
4823
  vk_pipeline pipeline = pipelines[aligned];
4744
4824
  assert(pipeline);
4745
4825
 
@@ -4768,22 +4848,22 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
4768
4848
 
4769
4849
  ggml_vk_sync_buffers(subctx);
4770
4850
 
4771
- vk_buffer d_Q, d_K, d_V, d_D, d_M;
4772
- uint64_t q_buf_offset, k_buf_offset, v_buf_offset, d_buf_offset, m_buf_offset;
4851
+ vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
4852
+ size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
4773
4853
 
4774
4854
  bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false;
4775
4855
 
4776
4856
  if (ctx->device->uma) {
4777
4857
  ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset);
4778
- ggml_vk_host_get(ctx->device, k->data, d_K, q_buf_offset);
4779
- ggml_vk_host_get(ctx->device, v->data, d_V, q_buf_offset);
4780
- ggml_vk_host_get(ctx->device, dst->data, d_D, q_buf_offset);
4858
+ ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset);
4859
+ ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset);
4860
+ ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset);
4781
4861
  Q_uma = d_Q != nullptr;
4782
4862
  K_uma = d_K != nullptr;
4783
4863
  V_uma = d_V != nullptr;
4784
4864
  D_uma = d_D != nullptr;
4785
4865
  if (mask) {
4786
- ggml_vk_host_get(ctx->device, mask->data, d_M, q_buf_offset);
4866
+ ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset);
4787
4867
  M_uma = d_M != nullptr;
4788
4868
  }
4789
4869
  }
@@ -4821,7 +4901,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
4821
4901
  }
4822
4902
  }
4823
4903
 
4824
- const vk_flash_attn_push_constants pc = { N, KV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, (uint32_t)neq2, (uint32_t)neq3, (uint32_t)nek2, (uint32_t)nek3, (uint32_t)nev2, (uint32_t)nev3, nem1, (uint32_t)nbq2, (uint32_t)nbq3, (uint32_t)nbk2, (uint32_t)nbk3, (uint32_t)nbv2, (uint32_t)nbv3, nbm1, scale, max_bias, logit_softcap, mask != nullptr, n_head_log2, m0, m1 };
4904
+ const vk_flash_attn_push_constants pc = { N, KV,
4905
+ (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3,
4906
+ (uint32_t)neq2, (uint32_t)neq3,
4907
+ (uint32_t)nek2, (uint32_t)nek3,
4908
+ (uint32_t)nev2, (uint32_t)nev3,
4909
+ nem1,
4910
+ q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
4911
+ k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
4912
+ v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
4913
+ nbm1,
4914
+ scale, max_bias, logit_softcap,
4915
+ mask != nullptr, n_head_log2, m0, m1 };
4825
4916
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
4826
4917
  {
4827
4918
  vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
@@ -5071,6 +5162,57 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
5071
5162
  }
5072
5163
  }
5073
5164
 
5165
+ static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t)
5166
+ {
5167
+ return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
5168
+ }
5169
+
5170
+ template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
5171
+ GGML_UNUSED(p);
5172
+ GGML_UNUSED(src0);
5173
+ GGML_UNUSED(src1);
5174
+ GGML_UNUSED(src2);
5175
+ GGML_UNUSED(dst);
5176
+ static_assert(!std::is_const<T>::value, "unexpected type");
5177
+ GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
5178
+ GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
5179
+ GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
5180
+ GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
5181
+ }
5182
+
5183
+ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
5184
+ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
5185
+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
5186
+
5187
+ p.misalign_offsets = (a_offset << 16) | d_offset;
5188
+
5189
+ GGML_UNUSED(src1);
5190
+ GGML_UNUSED(src2);
5191
+ }
5192
+
5193
+ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
5194
+ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
5195
+ const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
5196
+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
5197
+
5198
+ GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0));
5199
+
5200
+ p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
5201
+
5202
+ GGML_UNUSED(src2);
5203
+ }
5204
+
5205
+ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
5206
+ const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
5207
+ const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
5208
+
5209
+ p.a_offset = a_offset;
5210
+ p.d_offset = d_offset;
5211
+
5212
+ GGML_UNUSED(src1);
5213
+ GGML_UNUSED(src2);
5214
+ }
5215
+
5074
5216
  template<typename PC>
5075
5217
  static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
5076
5218
  VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
@@ -5082,7 +5224,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5082
5224
  }
5083
5225
  std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
5084
5226
  std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")");
5085
- GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
5227
+ GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
5086
5228
  GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
5087
5229
  GGML_ASSERT(dst->buffer != nullptr);
5088
5230
  const uint64_t ne00 = src0->ne[0];
@@ -5174,8 +5316,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5174
5316
  }
5175
5317
 
5176
5318
  GGML_ASSERT(d_D != nullptr);
5177
- uint64_t d_buf_offset = ((vk_tensor_offset(dst) + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
5178
- GGML_ASSERT(d_buf_offset == vk_tensor_offset(dst) || op == GGML_OP_CPY); // NOLINT
5319
+ uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
5179
5320
  if(!src0_uma) {
5180
5321
  d_X = src0_buf_ctx->dev_buffer;
5181
5322
  x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
@@ -5191,6 +5332,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
5191
5332
  z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
5192
5333
  GGML_ASSERT(d_Z != nullptr);
5193
5334
  }
5335
+ // Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
5336
+ init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst);
5337
+ x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
5338
+ y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
5339
+ z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
5340
+ d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
5194
5341
 
5195
5342
  if (op_supports_incontiguous) {
5196
5343
  x_sz = ggml_nbytes(src0);
@@ -5378,7 +5525,6 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
5378
5525
  const uint32_t src0_type_size = ggml_type_size(src0->type);
5379
5526
  const uint32_t src1_type_size = ggml_type_size(src1->type);
5380
5527
  const uint32_t dst_type_size = ggml_type_size(dst->type);
5381
- const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
5382
5528
 
5383
5529
  int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
5384
5530
  int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
@@ -5390,7 +5536,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
5390
5536
  (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
5391
5537
  (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
5392
5538
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
5393
- d_offset,
5539
+ 0,
5394
5540
  0.0f, 0.0f, offset,
5395
5541
  }, dryrun);
5396
5542
  }
@@ -5474,8 +5620,8 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
5474
5620
 
5475
5621
  ggml_vk_sync_buffers(subctx);
5476
5622
 
5477
- vk_buffer d_D, d_K, d_V, d_R, d_TF, d_TD, d_State;
5478
- uint64_t k_offset, v_offset, r_offset, tf_offset, td_offset, state_offset, dst_offset;
5623
+ vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
5624
+ size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
5479
5625
  bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
5480
5626
 
5481
5627
  if (ctx->device->uma) {
@@ -5551,9 +5697,9 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
5551
5697
  }
5552
5698
 
5553
5699
  static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
5554
- const size_t seq_length = dst->src[0]->ne[3];
5700
+ const size_t seq_length = dst->src[0]->ne[2];
5555
5701
  const size_t n_embed = dst->ne[0];
5556
- const size_t n_heads = dst->src[0]->ne[2];
5702
+ const size_t n_heads = dst->src[0]->ne[1];
5557
5703
  const size_t n_seqs = dst->src[5]->ne[1];
5558
5704
 
5559
5705
  ggml_vk_op_f32_rwkv6(
@@ -5594,7 +5740,7 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
5594
5740
  const float sf3 = (float)dst->ne[3] / src0->ne[3];
5595
5741
 
5596
5742
  ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
5597
- (uint32_t)ggml_nelements(dst), 0,
5743
+ (uint32_t)ggml_nelements(dst), 0, 0,
5598
5744
  (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
5599
5745
  (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
5600
5746
  sf0, sf1, sf2, sf3,
@@ -5704,13 +5850,12 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
5704
5850
  static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
5705
5851
  const uint32_t src0_type_size = ggml_type_size(src0->type);
5706
5852
  const uint32_t dst_type_size = ggml_type_size(dst->type);
5707
- const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
5708
5853
 
5709
5854
  ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
5710
5855
  (uint32_t)ggml_nelements(src0),
5711
5856
  (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
5712
5857
  (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
5713
- d_offset,
5858
+ 0,
5714
5859
  0.0f, 0.0f,
5715
5860
  0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
5716
5861
  }, dryrun);
@@ -7824,12 +7969,36 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
7824
7969
  {
7825
7970
  ggml_type src0_type = op->src[0]->type;
7826
7971
  ggml_type src1_type = op->src[1] != nullptr ? op->src[1]->type : src0_type;
7827
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
7828
- return true;
7972
+
7973
+ if (src0_type == GGML_TYPE_F32) {
7974
+ switch (src1_type) {
7975
+ case GGML_TYPE_F32:
7976
+ case GGML_TYPE_F16:
7977
+ case GGML_TYPE_Q4_0:
7978
+ case GGML_TYPE_Q4_1:
7979
+ case GGML_TYPE_Q5_0:
7980
+ case GGML_TYPE_Q5_1:
7981
+ case GGML_TYPE_Q8_0:
7982
+ case GGML_TYPE_IQ4_NL:
7983
+ return true;
7984
+ default:
7985
+ break;
7986
+ }
7829
7987
  }
7830
- if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
7831
- return true;
7988
+ if (src1_type == GGML_TYPE_F32) {
7989
+ switch (src0_type) {
7990
+ case GGML_TYPE_Q4_0:
7991
+ case GGML_TYPE_Q4_1:
7992
+ case GGML_TYPE_Q5_0:
7993
+ case GGML_TYPE_Q5_1:
7994
+ case GGML_TYPE_Q8_0:
7995
+ case GGML_TYPE_IQ4_NL:
7996
+ return true;
7997
+ default:
7998
+ break;
7999
+ }
7832
8000
  }
8001
+
7833
8002
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
7834
8003
  return true;
7835
8004
  }
@@ -8016,6 +8185,25 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
8016
8185
  UNUSED(instance_extensions);
8017
8186
  }
8018
8187
 
8188
+ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
8189
+ switch (props.vendorID) {
8190
+ case VK_VENDOR_ID_INTEL:
8191
+ // Intel drivers don't support coopmat properly yet
8192
+ return false;
8193
+ case VK_VENDOR_ID_AMD:
8194
+ if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
8195
+ // Workaround for AMD proprietary driver reporting support on all GPUs
8196
+ const std::string name = props.deviceName;
8197
+ return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
8198
+ name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
8199
+ name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
8200
+ }
8201
+ return true;
8202
+ default:
8203
+ return true;
8204
+ }
8205
+ }
8206
+
8019
8207
  // checks
8020
8208
 
8021
8209
  #ifdef GGML_VULKAN_CHECK_RESULTS
@@ -8501,6 +8689,7 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
8501
8689
  ggml_tensor * src0 = tensor->src[0];
8502
8690
  ggml_tensor * src1 = tensor->src[1];
8503
8691
  ggml_tensor * src2 = tensor->src[2];
8692
+ ggml_tensor * src3 = tensor->src[3];
8504
8693
 
8505
8694
  void * tensor_data = tensor->data;
8506
8695
 
@@ -8563,6 +8752,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
8563
8752
  if (src2 != nullptr) {
8564
8753
  std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
8565
8754
  }
8755
+ if (src3 != nullptr) {
8756
+ std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
8757
+ }
8566
8758
  std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
8567
8759
  std::cerr << std::endl << "Result:" << std::endl;
8568
8760
  ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3);
@@ -8607,6 +8799,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
8607
8799
  if (src2 != nullptr) {
8608
8800
  std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
8609
8801
  }
8802
+ if (src3 != nullptr) {
8803
+ std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
8804
+ }
8610
8805
  std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
8611
8806
  std::cerr << std::endl << "Result:" << std::endl;
8612
8807
  ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0);
@@ -8629,6 +8824,9 @@ static void ggml_vk_check_results_1(ggml_tensor * tensor) {
8629
8824
  if (src2 != nullptr) {
8630
8825
  std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl;
8631
8826
  }
8827
+ if (src3 != nullptr) {
8828
+ std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl;
8829
+ }
8632
8830
  std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl;
8633
8831
  std::cerr << std::endl << "Result:" << std::endl;
8634
8832
  ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]);