@fugood/llama.node 0.3.14 → 0.3.16

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +37 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +20 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +10 -28
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +1493 -12
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  58. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +284 -29
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  60. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  61. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  62. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  63. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  64. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +35 -12
  65. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  66. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +93 -27
  67. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  68. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  69. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  70. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +12 -43
  71. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  72. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +109 -40
  73. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  74. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  75. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  76. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  78. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  79. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +398 -158
  81. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  82. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +7 -2
  83. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  84. package/src/llama.cpp/include/llama.h +86 -22
  85. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  86. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  87. package/src/llama.cpp/src/llama-adapter.h +11 -9
  88. package/src/llama.cpp/src/llama-arch.cpp +103 -16
  89. package/src/llama.cpp/src/llama-arch.h +18 -0
  90. package/src/llama.cpp/src/llama-batch.h +2 -2
  91. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  92. package/src/llama.cpp/src/llama-context.h +214 -77
  93. package/src/llama.cpp/src/llama-cparams.h +1 -0
  94. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  95. package/src/llama.cpp/src/llama-graph.h +574 -0
  96. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  97. package/src/llama.cpp/src/llama-hparams.h +9 -0
  98. package/src/llama.cpp/src/llama-io.cpp +15 -0
  99. package/src/llama.cpp/src/llama-io.h +35 -0
  100. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  101. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  102. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  103. package/src/llama.cpp/src/llama-memory.h +21 -0
  104. package/src/llama.cpp/src/llama-model.cpp +8244 -173
  105. package/src/llama.cpp/src/llama-model.h +34 -1
  106. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  107. package/src/llama.cpp/src/llama.cpp +51 -9984
  108. package/src/llama.cpp/tests/test-backend-ops.cpp +145 -23
  109. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  110. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
@@ -29,6 +29,7 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
+#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
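For reference, the new ROUNDUP_POW2 macro rounds its first argument up to the next multiple of a power-of-two second argument, while the existing CEIL_DIV returns a rounded-up quotient. A minimal standalone check (illustrative only, not part of the package):

#include <cassert>

#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

int main() {
    // 70 rounded up to a multiple of 64 is 128; an exact multiple is unchanged.
    assert(ROUNDUP_POW2(70, 64) == 128);
    assert(ROUNDUP_POW2(64, 64) == 64);
    // CEIL_DIV gives the number of 64-wide tiles needed to cover 70 columns.
    assert(CEIL_DIV(70, 64) == 2);
    return 0;
}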
@@ -148,6 +149,67 @@ class vk_perf_logger;
 static void ggml_vk_destroy_buffer(vk_buffer& buf);
 
 static constexpr uint32_t mul_mat_vec_max_cols = 8;
+static constexpr uint32_t p021_max_gqa_ratio = 8;
+
+enum vk_device_architecture {
+    OTHER,
+    AMD_GCN,
+    AMD_RDNA1,
+    AMD_RDNA2,
+    AMD_RDNA3,
+};
+
+static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
+    vk::PhysicalDeviceProperties props = device.getProperties();
+
+    if (props.vendorID == VK_VENDOR_ID_AMD) {
+        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+        bool amd_shader_core_properties = false;
+        bool integer_dot_product = false;
+        bool subgroup_size_control = false;
+
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
+                amd_shader_core_properties = true;
+            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
+                integer_dot_product = true;
+            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
+                subgroup_size_control = true;
+            }
+        }
+
+        if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
+            return vk_device_architecture::OTHER;
+        }
+
+        vk::PhysicalDeviceProperties2 props2;
+        vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
+        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
+        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+
+        props2.pNext = &shader_core_props_amd;
+        shader_core_props_amd.pNext = &integer_dot_props;
+        integer_dot_props.pNext = &subgroup_size_control_props;
+
+        device.getProperties2(&props2);
+
+        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
+            return vk_device_architecture::AMD_GCN;
+        }
+        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
+            // RDNA
+            if (shader_core_props_amd.wavefrontsPerSimd == 20) {
+                return vk_device_architecture::AMD_RDNA1;
+            }
+            if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
+                return vk_device_architecture::AMD_RDNA3;
+            }
+            return vk_device_architecture::AMD_RDNA2;
+        }
+    }
+    return vk_device_architecture::OTHER;
+}
 
 struct vk_device_struct {
     std::mutex mutex;
@@ -161,6 +223,7 @@ struct vk_device_struct {
     bool pipeline_robustness;
     vk::Device device;
     uint32_t vendor_id;
+    vk_device_architecture architecture;
     vk_queue compute_queue;
     vk_queue transfer_queue;
     bool single_queue;
@@ -169,6 +232,7 @@ struct vk_device_struct {
     bool uma;
     bool prefer_host_memory;
     bool float_controls_rte_fp16;
+    bool subgroup_add;
 
     bool subgroup_size_control;
     uint32_t subgroup_min_size;
@@ -215,7 +279,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
     vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
 
-    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
+    vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
     vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
@@ -242,6 +306,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
+    vk_pipeline pipeline_l2_norm_f32;
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
@@ -266,6 +331,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
+    vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
@@ -368,6 +434,7 @@ struct vk_mat_mat_push_constants {
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t k_split;
     uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -380,6 +447,7 @@ struct vk_mat_mat_id_push_constants {
     uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_id_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -565,6 +633,13 @@ struct vk_op_rwkv_wkv6_push_constants {
     uint32_t H;
 };
 
+struct vk_op_rwkv_wkv7_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1445,6 +1520,73 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     return supported;
 }
 
+struct GpuPipelineConfig {
+    // GPU architecture identifier.
+    // Example: vk_device_architecture::AMD_GCN
+    vk_device_architecture arch;
+
+    // Mapping of pipeline names to their specific subgroup sizes.
+    // Example: {"soft_max_f32", 64}
+    std::unordered_map<std::string, uint32_t> pipelines;
+
+    // Default subgroup size for this GPU.
+    // Defaults to 0 if not explicitly provided.
+    uint32_t default_subgroup_size = 0;
+};
+
+// Pipeline configuration for RDNA1 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+    {"argmax", 64}, {"mul_mat_vec", 64},
+    {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
+};
+
+// Pipeline configuration for RDNA2 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+};
+
+static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
+
+// Define configurations for different GPUs.
+static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
+    {
+        vk_device_architecture::AMD_RDNA1,
+        {
+            rdna1_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+    {
+        vk_device_architecture::AMD_RDNA2,
+        {
+            rdna2_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+};
+
+static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
+    for (const auto &config : gpu_pipeline_configs) {
+        if (config.arch == arch) {
+            auto pipIt = config.pipelines.find(pipeline_name);
+            if (pipIt != config.pipelines.end()) {
+                return pipIt->second;
+            }
+            std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
+            std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
+                      [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
+            for (const auto &entry : sorted_pipelines) {
+                if (pipeline_name.find(entry.first) != std::string::npos) {
+                    return entry.second;
+                }
+            }
+            return config.default_subgroup_size;
+        }
+    }
+    return 0; // If no matching configuration is found
+}
+
 static void ggml_vk_load_shaders(vk_device& device) {
     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
 
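The helper added in this hunk resolves a per-pipeline subgroup size in three steps: exact name match, then longest-substring match over the table, then the architecture default (0 when the architecture has no entry at all). A self-contained sketch of that lookup order, using a trimmed copy of the RDNA1 table above and its default of 32 (illustrative values, not part of the package):

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Trimmed copy of the rdna1_pipelines table from the diff.
static const std::unordered_map<std::string, uint32_t> rdna1 = {
    {"soft_max", 64}, {"mul_mat_vec", 64}, {"mul_mat_vec_f16", 32},
};

static uint32_t pick_subgroup_size(const std::string &name) {
    auto it = rdna1.find(name);
    if (it != rdna1.end()) {
        return it->second;               // 1) exact match wins
    }
    std::vector<std::pair<std::string, uint32_t>> sorted(rdna1.begin(), rdna1.end());
    std::sort(sorted.begin(), sorted.end(),
              [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
    for (const auto &e : sorted) {       // 2) longest substring match next
        if (name.find(e.first) != std::string::npos) {
            return e.second;
        }
    }
    return 32;                           // 3) architecture default (RDNA_DEFAULT_SUBGROUP_SIZE)
}

int main() {
    std::cout << pick_subgroup_size("mul_mat_vec_f16") << "\n";          // 32 (exact)
    std::cout << pick_subgroup_size("mul_mat_vec_q4_0_f32_f32") << "\n"; // 64 (substring)
    std::cout << pick_subgroup_size("rms_norm_f32") << "\n";             // 32 (default)
    return 0;
}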
@@ -1466,36 +1608,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t l_align, m_align, s_align;
     if (device->coopmat2) {
         // spec constants and tile sizes for non-quant matmul/matmul_id
-        l_warptile = { 256, 128, 256, 64 };
-        m_warptile = { 256, 128, 128, 64 };
-        s_warptile = { 128, 64, 64, 64 };
+        l_warptile = { 256, 128, 256, 64, 1 };
+        m_warptile = { 256, 128, 128, 64, 0 };
+        s_warptile = { 128, 64, 64, 64, 0 };
         l_wg_denoms = {128, 256, 1 };
         m_wg_denoms = {128, 128, 1 };
         s_wg_denoms = { 64, 64, 1 };
 
         // spec constants and tile sizes for quant matmul (non-Qi_K)
-        l_warptile_mmq = { 256, 128, 256, 64 };
-        m_warptile_mmq = { 256, 128, 128, 64 };
-        s_warptile_mmq = { 256, 128, 128, 64 };
+        l_warptile_mmq = { 256, 128, 256, 64, 1 };
+        m_warptile_mmq = { 256, 128, 128, 64, 1 };
+        s_warptile_mmq = { 256, 32, 64, 128, 0 };
         l_mmq_wg_denoms = { 128, 256, 1 };
         m_mmq_wg_denoms = { 128, 128, 1 };
-        s_mmq_wg_denoms = { 128, 128, 1 };
+        s_mmq_wg_denoms = { 32, 64, 1 };
 
         // spec constants and tile sizes for quant matmul (Qi_K)
-        l_warptile_mmq_k = { 256, 128, 512, 16 };
-        m_warptile_mmq_k = { 256, 128, 256, 16 };
-        s_warptile_mmq_k = { 256, 32, 128, 64 };
-        l_mmq_wg_denoms_k = { 128, 512, 1 };
-        m_mmq_wg_denoms_k = { 128, 256, 1 };
-        s_mmq_wg_denoms_k = { 32, 128, 1 };
+        l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
+        m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
+        s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
+        l_mmq_wg_denoms_k = { 64, 128, 1 };
+        m_mmq_wg_denoms_k = { 32, 64, 1 };
+        s_mmq_wg_denoms_k = { 32, 32, 1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128, 16 };
-        m_warptile_mmqid = { 256, 128, 64, 16 };
-        s_warptile_mmqid = { 256, 64, 64, 16 };
-        l_mmqid_wg_denoms = { 128, 128, 1 };
+        l_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        m_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        s_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        l_mmqid_wg_denoms = { 128, 64, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
-        s_mmqid_wg_denoms = { 64, 64, 1 };
+        s_mmqid_wg_denoms = { 128, 64, 1 };
 
         l_align = 128;
         m_align = 64;
@@ -1571,6 +1713,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
         uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
 
+        if (!require_full_subgroups && required_subgroup_size == 0) {
+            required_subgroup_size = get_subgroup_size(name, device->architecture);
+        }
+
         if (!pipeline) {
             pipeline = std::make_shared<vk_pipeline_struct>();
             pipeline->name = name;
@@ -2121,13 +2267,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32, "mul_mat_vec_p021_f16_f32", mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
+    for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
+        if (device->subgroup_add && device->subgroup_require_full_support) {
+            ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true);
+        } else {
+            ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
+        }
+    }
     ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2136,13 +2289,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
-
-    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+    if (device->float_controls_rte_fp16) {
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+    } else {
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1);
        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
+    }
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
@@ -2239,6 +2400,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     for (auto &c : compiles) {
@@ -2247,7 +2410,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     device->need_compiles = false;
 }
 
-static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
 
 static vk_device ggml_vk_get_device(size_t idx) {
     VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
@@ -2276,6 +2439,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->physical_device = physical_devices[dev_num];
         const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
 
+        device->architecture = get_device_architecture(device->physical_device);
+
         const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
         device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
 
@@ -2288,7 +2453,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
         bool coopmat2_support = false;
         device->coopmat_support = false;
 
-        // Check if maintenance4 is supported
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
                 maintenance4_support = true;
@@ -2323,13 +2487,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
         vk::PhysicalDeviceDriverProperties driver_props;
         vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
         vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
+        vk::PhysicalDeviceVulkan11Properties vk11_props;
         vk::PhysicalDeviceVulkan12Properties vk12_props;
         vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
 
         props2.pNext = &props3;
         props3.pNext = &subgroup_props;
         subgroup_props.pNext = &driver_props;
-        driver_props.pNext = &vk12_props;
+        driver_props.pNext = &vk11_props;
+        vk11_props.pNext = &vk12_props;
 
         VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props;
 
@@ -2376,13 +2542,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
             device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
-#if defined(_WIN32)
-        } else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
+        } else {
             // Limit batching of allocations to 1GB by default to avoid fragmentation issues
             device->suballocation_block_size = 1024*1024*1024;
-#endif
-        } else {
-            device->suballocation_block_size = device->max_memory_allocation_size;
         }
         device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
 
@@ -2397,11 +2559,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
         }
         device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
 
+        device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+                               (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
+
         const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
 
         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
-        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
+        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
             device->coopmat_support = false;
         }
 
@@ -2779,7 +2944,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     subgroup_props.pNext = &driver_props;
     physical_device.getProperties2(&props2);
 
-    const size_t subgroup_size = subgroup_props.subgroupSize;
+    vk_device_architecture arch = get_device_architecture(physical_device);
+    uint32_t default_subgroup_size = get_subgroup_size("", arch);
+    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+
     const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
 
     bool fp16_storage = false;
@@ -2805,7 +2973,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         }
     }
 
-    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
+    const vk_device_architecture device_architecture = get_device_architecture(physical_device);
+
+    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
         coopmat_support = false;
     }
 
@@ -3850,10 +4020,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
-        if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;
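The change above swaps the old divisibility test for a plain threshold on n: the large variant is chosen once n exceeds the medium tile's N denominator, the medium variant once n exceeds the small tile's. A minimal sketch of that rule, assuming the coopmat2 non-quant denominators from earlier in this diff (medium N = 128, small N = 64) and that all three shader variants exist for the type:

#include <cstdint>
#include <iostream>

enum class tile { large, medium, small_ };

// Mirrors the new selection rule: pick by how many columns (n) the B matrix has.
static tile pick_tile(uint32_t n, uint32_t medium_tile_n = 128, uint32_t small_tile_n = 64) {
    if (n > medium_tile_n) return tile::large;
    if (n > small_tile_n)  return tile::medium;
    return tile::small_;
}

int main() {
    std::cout << (pick_tile(512) == tile::large)  << "\n"; // large prompt batch -> large tile
    std::cout << (pick_tile(96)  == tile::medium) << "\n"; // mid-size n         -> medium tile
    std::cout << (pick_tile(1)   == tile::small_) << "\n"; // single token       -> small tile
    return 0;
}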
@@ -3878,18 +4052,19 @@ static void ggml_vk_matmul(
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
+        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
+        uint32_t padded_n) {
     VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
+        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
         return;
     }
 
     GGML_ASSERT(batch_stride_d == m * n);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
+    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
@@ -3898,13 +4073,17 @@ static void ggml_vk_matmul(
 }
 
 static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
-        if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;
@@ -3929,14 +4108,15 @@ static void ggml_vk_matmul_id(
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
+        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
+        uint32_t padded_n) {
     VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
         "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
         "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
         "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
     ggml_vk_sync_buffers(subctx);
     const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
-                                              nei0, nei1, nbi1, ne11 };
+                                              nei0, nei1, nbi1, ne11, padded_n };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
 }
 
@@ -4098,15 +4278,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
 
+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = padded_n * ne10;
+    const int d_ne = ne11 * ne01;
+
     const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
 
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
@@ -4229,7 +4411,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
             { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
             ne01, ne11, ne10,
             ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
-            split_k, ne12*ne13, ne02, ne12, r2, r3
+            split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
         ); // NOLINT
     }
 
@@ -4466,9 +4648,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t d_sz = sizeof(float) * d_ne;
 
+    // With grouped query attention there are > 1 Q matrices per K, V matrix.
+    uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
+    if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
+        gqa_ratio = 1;
+    }
+
     if (dryrun) {
         // Request descriptor sets
-        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, 1);
+        ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
         return;
     }
 
@@ -4492,8 +4680,15 @@ static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_c
 
     // compute
     const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
+
+    uint32_t workgroups_z = (uint32_t)ne12;
+    // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
+    if (gqa_ratio > 1) {
+        workgroups_z /= gqa_ratio;
+    }
+
     ggml_vk_sync_buffers(subctx);
-    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 6 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, workgroups_z });
 }
 
 static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
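For context, gqa_ratio above is the number of query heads per K/V head (ne12 / ne02); it is only honored when it divides evenly and is at most 8 (the new p021_max_gqa_ratio), otherwise the code falls back to 1 and the Z workgroup count stays at ne12. A standalone restatement of that guard (illustrative only):

#include <cassert>
#include <cstdint>

// ne12 = number of Q heads, ne02 = number of K/V heads, as in the hunk above.
static uint32_t clamp_gqa_ratio(uint32_t ne12, uint32_t ne02) {
    uint32_t gqa_ratio = ne12 / ne02;
    if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
        gqa_ratio = 1;
    }
    return gqa_ratio;
}

int main() {
    assert(clamp_gqa_ratio(32, 8) == 4);  // 4 Q heads per K/V head -> 4x fewer Z workgroups
    assert(clamp_gqa_ratio(32, 32) == 1); // classic multi-head attention
    assert(clamp_gqa_ratio(30, 8) == 1);  // does not divide evenly -> fall back to 1
    return 0;
}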
@@ -4680,15 +4875,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne21 * ne20;
-
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
 
+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = padded_n * ne10;
+    const uint64_t d_ne = ne21 * ne20;
+
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -4807,7 +5004,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
             { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
             ne01, ne21, ne10, ne10, ne10, ne01,
             stride_batch_x, stride_batch_y, ne20*ne21,
-            n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
+            n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
         ); // NOLINT
     }
 
@@ -5318,6 +5515,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rms_norm_back_f32;
         }
         return nullptr;
+    case GGML_OP_L2_NORM:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_l2_norm_f32;
+        }
+        return nullptr;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(dst)) {
             case GGML_UNARY_OP_SILU:
@@ -5457,6 +5659,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_rwkv_wkv6_f32;
         }
         return nullptr;
+    case GGML_OP_RWKV_WKV7:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv7_f32;
+        }
+        return nullptr;
     case GGML_OP_OPT_STEP_ADAMW:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_opt_step_adamw_f32;
@@ -5704,6 +5911,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
@@ -5953,23 +6161,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
-static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
-    const ggml_tensor * k = dst->src[0];
-    const ggml_tensor * v = dst->src[1];
-    const ggml_tensor * r = dst->src[2];
-    const ggml_tensor * tf = dst->src[3];
-    const ggml_tensor * td = dst->src[4];
-    const ggml_tensor * state = dst->src[5];
-
-    GGML_ASSERT(!ggml_is_quantized(k->type));
-    GGML_ASSERT(!ggml_is_quantized(v->type));
-    GGML_ASSERT(!ggml_is_quantized(r->type));
-    GGML_ASSERT(!ggml_is_quantized(tf->type));
-    GGML_ASSERT(!ggml_is_quantized(td->type));
-    GGML_ASSERT(!ggml_is_quantized(state->type));
+static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
+    GGML_ASSERT(version == 6 || version == 7);
+    int num_srcs = version == 6 ? 6 : 7;
+
+    for (int i = 0; i < num_srcs; i++) {
+        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
+    }
+
 
     GGML_ASSERT(dst->buffer != nullptr);
-    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
     GGML_ASSERT(pipeline != nullptr);
 
     if (dryrun) {
@@ -5978,89 +6180,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
     }
 
     ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
-    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
-    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
-    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
-    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
-    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+    ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+    for (int i = 0; i < num_srcs; i++) {
+        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
+    }
 
     ggml_vk_sync_buffers(subctx);
 
-    vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
-    size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
-    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+    vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+    size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
 
     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
-        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
-        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
-        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
-        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
-        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
-        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+        for (int i = 0; i < num_srcs; i++) {
+            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
+            srcs_uma[i] = d_srcs[i] != nullptr;
+        }
 
-        K_uma = d_K != nullptr;
-        V_uma = d_V != nullptr;
-        R_uma = d_R != nullptr;
-        TF_uma = d_TF != nullptr;
-        TD_uma = d_TD != nullptr;
-        STATE_uma = d_State != nullptr;
-        DST_uma = d_D != nullptr;
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+        dst_uma = d_D != nullptr;
     }
 
-    if (!K_uma) {
-        d_K = k_buf_ctx->dev_buffer;
-        k_offset = vk_tensor_offset(k) + k->view_offs;
-    }
-    if (!V_uma) {
-        d_V = v_buf_ctx->dev_buffer;
-        v_offset = vk_tensor_offset(v) + v->view_offs;
-    }
-    if (!R_uma) {
-        d_R = r_buf_ctx->dev_buffer;
-        r_offset = vk_tensor_offset(r) + r->view_offs;
-    }
-    if (!TF_uma) {
-        d_TF = tf_buf_ctx->dev_buffer;
-        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
-    }
-    if (!TD_uma) {
-        d_TD = td_buf_ctx->dev_buffer;
-        td_offset = vk_tensor_offset(td) + td->view_offs;
-    }
-    if (!STATE_uma) {
-        d_State = state_buf_ctx->dev_buffer;
-        state_offset = vk_tensor_offset(state) + state->view_offs;
+    uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    for (int i = 0; i < num_srcs; i++) {
+        src_sizes[i] = ggml_nbytes(dst->src[i]);
+        if (!srcs_uma[i]) {
+            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
+            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
+        }
     }
-    if (!DST_uma) {
+
+    const uint64_t dst_size = ggml_nbytes(dst);
+    if (!dst_uma) {
         d_D = dst_buf_ctx->dev_buffer;
         dst_offset = vk_tensor_offset(dst) + dst->view_offs;
     }
 
-    const uint64_t k_size = ggml_nbytes(k);
-    const uint64_t v_size = ggml_nbytes(v);
-    const uint64_t r_size = ggml_nbytes(r);
-    const uint64_t tf_size = ggml_nbytes(tf);
-    const uint64_t td_size = ggml_nbytes(td);
-    const uint64_t state_size = ggml_nbytes(state);
-    const uint64_t dst_size = ggml_nbytes(dst);
-
     std::array<uint32_t, 3> elements = {
         (uint32_t)(pc.B * pc.H),
         1,
         1
     };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-        vk_subbuffer{ d_K, k_offset, k_size },
-        vk_subbuffer{ d_V, v_offset, v_size },
-        vk_subbuffer{ d_R, r_offset, r_size },
-        vk_subbuffer{ d_TF, tf_offset, tf_size },
-        vk_subbuffer{ d_TD, td_offset, td_size },
-        vk_subbuffer{ d_State, state_offset, state_size },
-        vk_subbuffer{ d_D, dst_offset, dst_size }
-    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+    if (version == 6) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+            vk_subbuffer{ d_D, dst_offset, dst_size }
+        }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+    } else if (version == 7) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+            vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
+            vk_subbuffer{ d_D, dst_offset, dst_size }
+        }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
+    } else {
+        // shouldn't happen
+        GGML_ASSERT(false);
+    }
 }
 
 static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -6069,7 +6255,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
6069
6255
  const size_t n_heads = dst->src[0]->ne[1];
6070
6256
  const size_t n_seqs = dst->src[5]->ne[1];
6071
6257
 
6072
- ggml_vk_op_f32_rwkv6(
6258
+ ggml_vk_op_f32_wkv(
6259
+ ctx, subctx, dst,
6260
+ {
6261
+ (uint32_t)n_seqs,
6262
+ (uint32_t)seq_length,
6263
+ (uint32_t)n_embed,
6264
+ (uint32_t)n_heads,
6265
+ },
6266
+ 6,
6267
+ dryrun
6268
+ );
6269
+ }
6270
+
6271
+ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
6272
+ const size_t seq_length = dst->src[0]->ne[2];
6273
+ const size_t n_embed = dst->ne[0];
6274
+ const size_t n_heads = dst->src[0]->ne[1];
6275
+ const size_t n_seqs = dst->src[6]->ne[1];
6276
+
6277
+ ggml_vk_op_f32_wkv(
6073
6278
  ctx, subctx, dst,
6074
6279
  {
6075
6280
  (uint32_t)n_seqs,
@@ -6077,6 +6282,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
6077
6282
  (uint32_t)n_embed,
6078
6283
  (uint32_t)n_heads,
6079
6284
  },
6285
+ 7,
6080
6286
  dryrun
6081
6287
  );
6082
6288
  }
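
The hunks above fold the per-tensor WKV6 path into a single ggml_vk_op_f32_wkv helper that walks dst->src[] generically and only branches on the op version at dispatch time. A minimal standalone sketch of that pattern follows; the struct and function names are invented for illustration and the Vulkan plumbing is omitted:

    #include <cstdint>
    #include <cstdio>

    struct FakeTensor { uint64_t nbytes; };

    // Sketch of the version-dispatched WKV helper: sources are sized generically,
    // and only the final dispatch differs between wkv6 (6 sources) and wkv7 (7 sources).
    static void wkv_dispatch_sketch(const FakeTensor * srcs, const FakeTensor & dst, int version) {
        const int num_srcs = (version == 6) ? 6 : 7;   // WKV7 carries one extra source tensor
        uint64_t total = 0;
        for (int i = 0; i < num_srcs; i++) {
            total += srcs[i].nbytes;                   // per-source sizes, gathered in a loop as in the diff
        }
        // the real helper selects vk_op_rwkv_wkv6_push_constants vs vk_op_rwkv_wkv7_push_constants here
        std::printf("wkv%d: %d sources (%llu bytes) + dst (%llu bytes)\n",
                    version, num_srcs, (unsigned long long) total, (unsigned long long) dst.nbytes);
    }
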
@@ -6378,6 +6584,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
  }

+ static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ }
+
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
  }
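
The new ggml_vk_l2_norm wrapper routes GGML_OP_L2_NORM through the same generic f32 op path as the other norm ops, with eps passed via op_params. For reference, a plain-C++ sketch of row-wise L2 normalization is shown below; the exact epsilon placement in ggml's CPU reference may differ, so treat this as illustrative only:

    #include <algorithm>
    #include <cmath>
    #include <cstddef>

    // Normalize each row of an n_rows x n_cols matrix by its Euclidean (L2) norm.
    static void l2_norm_rows(float * x, size_t n_rows, size_t n_cols, float eps) {
        for (size_t r = 0; r < n_rows; r++) {
            float * row = x + r * n_cols;
            float sum = 0.0f;
            for (size_t c = 0; c < n_cols; c++) {
                sum += row[c] * row[c];              // sum of squares over the row
            }
            const float scale = 1.0f / std::max(std::sqrt(sum), eps);  // eps guards against division by zero
            for (size_t c = 0; c < n_cols; c++) {
                row[c] *= scale;                     // scale the row toward unit L2 norm
            }
        }
    }
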
@@ -6767,7 +6978,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
  ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
  m, n, k,
  k, k, m, k*m, k*n, m*n,
- split_k, batch, batch, batch, 1, 1
+ split_k, batch, batch, batch, 1, 1, n
  );
  }
  ggml_vk_ctx_end(subctx);
@@ -7112,7 +7323,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
  ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
  m, n, k,
  k, k, m, k*m, k*n, m*n,
- split_k, batch, batch, batch, 1, 1
+ split_k, batch, batch, batch, 1, 1, n
  );
  }
  ggml_vk_ctx_end(subctx);
@@ -7373,6 +7584,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
  case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
  case GGML_OP_SOFT_MAX_BACK:
@@ -7389,6 +7601,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_POOL_2D:
  case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_FLASH_ATTN_EXT:
  case GGML_OP_OPT_STEP_ADAMW:
@@ -7435,6 +7648,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
  case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
  case GGML_OP_UNARY:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
@@ -7552,6 +7766,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_RMS_NORM_BACK:
  ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);

+ break;
+ case GGML_OP_L2_NORM:
+ ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
+
  break;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(node)) {
@@ -7642,6 +7860,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod

  break;

+ case GGML_OP_RWKV_WKV7:
+ ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
+
+ break;
+
  case GGML_OP_OPT_STEP_ADAMW:
  ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);

@@ -7715,6 +7938,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
  case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
  case GGML_OP_SOFT_MAX_BACK:
@@ -7734,6 +7958,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_POOL_2D:
  case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_REPEAT:
  case GGML_OP_REPEAT_BACK:
@@ -8245,8 +8470,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;

+ uint64_t total_mat_mul_bytes = 0;
  for (int i = 0; i < cgraph->n_nodes; i++) {
  ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
+ if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+ total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+ }
  }
  if (ctx->device->need_compiles) {
  ggml_vk_load_shaders(ctx->device);
@@ -8267,17 +8496,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  bool first_node_in_batch = true; // true if next node will be first node in a batch
  int submit_node_idx = 0; // index to first node in a batch

- // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
- // Start with a smaller count to get work submitted right away, and increase it after each submit.
- int nodes_per_submit = 20;
+ // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
+ // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
+ // (and scaled down based on model size, so smaller models submit earlier).
+ // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
+ int nodes_per_submit = 100;
  int submitted_nodes = 0;
  int submit_count = 0;
+ uint64_t mul_mat_bytes = 0;
+ uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
  for (int i = 0; i < cgraph->n_nodes; i++) {
  if (first_node_in_batch) {
  submit_node_idx = i;
  }

- bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
+ if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+ mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+ }
+
+ bool submit = (submitted_nodes >= nodes_per_submit) ||
+ (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
+ (i == last_node);

  bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);

@@ -8294,13 +8533,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  if (submit) {
  first_node_in_batch = true;
  submitted_nodes = 0;
- switch (submit_count) {
- case 0:
- nodes_per_submit = 50;
- break;
- default:
- nodes_per_submit = 100;
- break;
+ mul_mat_bytes = 0;
+ if (submit_count < 3) {
+ mul_mat_bytes_per_submit *= 2;
  }
  submit_count++;
  }
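
The hunks above replace the fixed nodes-per-submit schedule with a byte-based heuristic: matmul weight bytes are totaled in an initial pass over the graph, and a batch is submitted after roughly min(100 MB, total/40) of matmul work, after 100 nodes, or at the last node, with the byte threshold doubled for the first few submits. A standalone sketch of that bookkeeping, with invented names and none of the Vulkan specifics:

    #include <algorithm>
    #include <cstdint>

    struct submit_heuristic {
        uint64_t bytes_per_submit;
        uint64_t mul_mat_bytes   = 0;
        int      submitted_nodes = 0;
        int      submit_count    = 0;

        // threshold: at most 100 MB of matmul weights per submit, scaled down for small graphs
        explicit submit_heuristic(uint64_t total_mat_mul_bytes)
            : bytes_per_submit(std::min(uint64_t(100) * 1000 * 1000, total_mat_mul_bytes / 40)) {}

        // node_mat_mul_bytes is the weight size for MUL_MAT / MUL_MAT_ID nodes, 0 otherwise
        bool should_submit(uint64_t node_mat_mul_bytes, bool is_last_node) {
            mul_mat_bytes += node_mat_mul_bytes;
            submitted_nodes++;
            const bool submit = submitted_nodes >= 100 || mul_mat_bytes >= bytes_per_submit || is_last_node;
            if (submit) {
                submitted_nodes = 0;
                mul_mat_bytes   = 0;
                if (submit_count < 3) {
                    bytes_per_submit *= 2;   // submit early at first, then back off
                }
                submit_count++;
            }
            return submit;
        }
    };
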
@@ -8651,6 +8886,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_L2_NORM:
  return ggml_is_contiguous(op->src[0]);
  case GGML_OP_ADD:
  case GGML_OP_SUB:
@@ -8680,6 +8916,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_POOL_2D:
  case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_OPT_STEP_ADAMW:
  return true;
@@ -8826,7 +9063,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
  UNUSED(instance_extensions);
  }

- static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
+ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
  switch (props.vendorID) {
  case VK_VENDOR_ID_INTEL:
  // Intel drivers don't support coopmat properly yet
@@ -8834,10 +9071,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
  case VK_VENDOR_ID_AMD:
  if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
  // Workaround for AMD proprietary driver reporting support on all GPUs
- const std::string name = props.deviceName;
- return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
- name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
- name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
+ return arch == vk_device_architecture::AMD_RDNA3;
  }
  return true;
  default:
@@ -9067,6 +9301,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
  } else if (tensor->op == GGML_OP_SILU_BACK) {
  tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
+ } else if (tensor->op == GGML_OP_L2_NORM) {
+ const float eps = ((float *) tensor->op_params)[0];
+ tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
  if (src1 != nullptr) {
  tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
@@ -9186,6 +9423,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  } else if (tensor->op == GGML_OP_RWKV_WKV6) {
  tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
  src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
+ } else if (tensor->op == GGML_OP_RWKV_WKV7) {
+ tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
+ src_clone[4], src_clone[5], src_clone[6]);
  } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
  src_clone[0]->flags = src0->flags;
  tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],