@fugood/llama.node 0.3.14 → 0.3.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/package.json +1 -1
  18. package/src/llama.cpp/.github/workflows/build.yml +30 -1
  19. package/src/llama.cpp/CMakeLists.txt +9 -1
  20. package/src/llama.cpp/cmake/common.cmake +2 -0
  21. package/src/llama.cpp/common/arg.cpp +20 -2
  22. package/src/llama.cpp/common/common.cpp +6 -3
  23. package/src/llama.cpp/common/speculative.cpp +4 -4
  24. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  25. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +1 -1
  26. package/src/llama.cpp/examples/embedding/embedding.cpp +1 -1
  27. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  28. package/src/llama.cpp/examples/imatrix/imatrix.cpp +1 -1
  29. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  30. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  31. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +4 -4
  32. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +1 -1
  33. package/src/llama.cpp/examples/lookahead/lookahead.cpp +6 -6
  34. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  35. package/src/llama.cpp/examples/main/main.cpp +6 -6
  36. package/src/llama.cpp/examples/parallel/parallel.cpp +5 -5
  37. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  38. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  39. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  40. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  41. package/src/llama.cpp/examples/run/run.cpp +91 -46
  42. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  43. package/src/llama.cpp/examples/server/server.cpp +32 -15
  44. package/src/llama.cpp/examples/server/utils.hpp +3 -1
  45. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  46. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  47. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  48. package/src/llama.cpp/examples/tts/tts.cpp +12 -9
  49. package/src/llama.cpp/ggml/CMakeLists.txt +1 -0
  50. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  51. package/src/llama.cpp/ggml/include/ggml.h +24 -0
  52. package/src/llama.cpp/ggml/src/CMakeLists.txt +5 -27
  53. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  54. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +0 -5
  55. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +15 -7
  56. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +150 -1
  57. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +253 -2
  58. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  59. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  60. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +7 -0
  61. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +0 -4
  62. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +95 -22
  63. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +3 -0
  64. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +1 -1
  65. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +66 -26
  66. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  67. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +12 -13
  68. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  69. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +1 -2
  70. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +103 -34
  71. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  72. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +19 -20
  73. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  74. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  75. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +1 -1
  76. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  77. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  78. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +352 -146
  79. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  80. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +3 -0
  81. package/src/llama.cpp/ggml/src/ggml.c +85 -2
  82. package/src/llama.cpp/include/llama.h +86 -22
  83. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  84. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  85. package/src/llama.cpp/src/llama-adapter.h +11 -9
  86. package/src/llama.cpp/src/llama-arch.cpp +102 -16
  87. package/src/llama.cpp/src/llama-arch.h +18 -0
  88. package/src/llama.cpp/src/llama-batch.h +2 -2
  89. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  90. package/src/llama.cpp/src/llama-context.h +214 -77
  91. package/src/llama.cpp/src/llama-cparams.h +1 -0
  92. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  93. package/src/llama.cpp/src/llama-graph.h +574 -0
  94. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  95. package/src/llama.cpp/src/llama-hparams.h +9 -0
  96. package/src/llama.cpp/src/llama-io.cpp +15 -0
  97. package/src/llama.cpp/src/llama-io.h +35 -0
  98. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  99. package/src/llama.cpp/src/llama-kv-cache.h +178 -110
  100. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  101. package/src/llama.cpp/src/llama-memory.h +21 -0
  102. package/src/llama.cpp/src/llama-model.cpp +8207 -163
  103. package/src/llama.cpp/src/llama-model.h +34 -1
  104. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  105. package/src/llama.cpp/src/llama.cpp +51 -9984
  106. package/src/llama.cpp/tests/test-backend-ops.cpp +88 -9
  107. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  108. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
@@ -29,6 +29,7 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
+#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
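Editorial note (not part of the published diff): the new ROUNDUP_POW2 macro rounds M up to the next multiple of N and is only valid when N is a power of two, while the existing CEIL_DIV counts how many N-sized blocks cover M. A minimal, self-contained sketch with made-up values:

    #include <cstdio>

    #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
    #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

    int main() {
        // 70 rows padded out to a 64-wide tile: 128 rows of storage, 2 tiles of work.
        std::printf("ROUNDUP_POW2(70, 64) = %d\n", ROUNDUP_POW2(70, 64)); // prints 128
        std::printf("CEIL_DIV(70, 64)     = %d\n", CEIL_DIV(70, 64));     // prints 2
        return 0;
    }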
@@ -149,6 +150,66 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf);
 
 static constexpr uint32_t mul_mat_vec_max_cols = 8;
 
+enum vk_device_architecture {
+    OTHER,
+    AMD_GCN,
+    AMD_RDNA1,
+    AMD_RDNA2,
+    AMD_RDNA3,
+};
+
+static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
+    vk::PhysicalDeviceProperties props = device.getProperties();
+
+    if (props.vendorID == VK_VENDOR_ID_AMD) {
+        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+        bool amd_shader_core_properties = false;
+        bool integer_dot_product = false;
+        bool subgroup_size_control = false;
+
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
+                amd_shader_core_properties = true;
+            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
+                integer_dot_product = true;
+            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
+                subgroup_size_control = true;
+            }
+        }
+
+        if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
+            return vk_device_architecture::OTHER;
+        }
+
+        vk::PhysicalDeviceProperties2 props2;
+        vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
+        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
+        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+
+        props2.pNext = &shader_core_props_amd;
+        shader_core_props_amd.pNext = &integer_dot_props;
+        integer_dot_props.pNext = &subgroup_size_control_props;
+
+        device.getProperties2(&props2);
+
+        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
+            return vk_device_architecture::AMD_GCN;
+        }
+        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
+            // RDNA
+            if (shader_core_props_amd.wavefrontsPerSimd == 20) {
+                return vk_device_architecture::AMD_RDNA1;
+            }
+            if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
+                return vk_device_architecture::AMD_RDNA3;
+            }
+            return vk_device_architecture::AMD_RDNA2;
+        }
+    }
+    return vk_device_architecture::OTHER;
+}
+
 struct vk_device_struct {
     std::mutex mutex;
 
@@ -161,6 +222,7 @@ struct vk_device_struct {
     bool pipeline_robustness;
     vk::Device device;
     uint32_t vendor_id;
+    vk_device_architecture architecture;
     vk_queue compute_queue;
     vk_queue transfer_queue;
     bool single_queue;
@@ -242,6 +304,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
+    vk_pipeline pipeline_l2_norm_f32;
     vk_pipeline pipeline_gelu_f32;
    vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
@@ -266,6 +329,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
+    vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
@@ -368,6 +432,7 @@ struct vk_mat_mat_push_constants {
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t k_split;
     uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -380,6 +445,7 @@ struct vk_mat_mat_id_push_constants {
     uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_id_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -565,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants {
     uint32_t H;
 };
 
+struct vk_op_rwkv_wkv7_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1445,6 +1518,73 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     return supported;
 }
 
+struct GpuPipelineConfig {
+    // GPU architecture identifier.
+    // Example: vk_device_architecture::AMD_GCN
+    vk_device_architecture arch;
+
+    // Mapping of pipeline names to their specific subgroup sizes.
+    // Example: {"soft_max_f32", 64}
+    std::unordered_map<std::string, uint32_t> pipelines;
+
+    // Default subgroup size for this GPU.
+    // Defaults to 0 if not explicitly provided.
+    uint32_t default_subgroup_size = 0;
+};
+
+// Pipeline configuration for RDNA1 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+    {"argmax", 64}, {"mul_mat_vec", 64},
+    {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
+};
+
+// Pipeline configuration for RDNA2 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+};
+
+static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
+
+// Define configurations for different GPUs.
+static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
+    {
+        vk_device_architecture::AMD_RDNA1,
+        {
+            rdna1_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+    {
+        vk_device_architecture::AMD_RDNA2,
+        {
+            rdna2_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+};
+
+static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
+    for (const auto &config : gpu_pipeline_configs) {
+        if (config.arch == arch) {
+            auto pipIt = config.pipelines.find(pipeline_name);
+            if (pipIt != config.pipelines.end()) {
+                return pipIt->second;
+            }
+            std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
+            std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
+                      [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
+            for (const auto &entry : sorted_pipelines) {
+                if (pipeline_name.find(entry.first) != std::string::npos) {
+                    return entry.second;
+                }
+            }
+            return config.default_subgroup_size;
+        }
+    }
+    return 0; // If no matching configuration is found
+}
+
 static void ggml_vk_load_shaders(vk_device& device) {
     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
 
@@ -1466,36 +1606,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t l_align, m_align, s_align;
     if (device->coopmat2) {
         // spec constants and tile sizes for non-quant matmul/matmul_id
-        l_warptile = { 256, 128, 256, 64 };
-        m_warptile = { 256, 128, 128, 64 };
-        s_warptile = { 128, 64, 64, 64 };
+        l_warptile = { 256, 128, 256, 64, 1 };
+        m_warptile = { 256, 128, 128, 64, 0 };
+        s_warptile = { 128, 64, 64, 64, 0 };
         l_wg_denoms = {128, 256, 1 };
         m_wg_denoms = {128, 128, 1 };
         s_wg_denoms = { 64, 64, 1 };
 
         // spec constants and tile sizes for quant matmul (non-Qi_K)
-        l_warptile_mmq = { 256, 128, 256, 64 };
-        m_warptile_mmq = { 256, 128, 128, 64 };
-        s_warptile_mmq = { 256, 128, 128, 64 };
+        l_warptile_mmq = { 256, 128, 256, 64, 1 };
+        m_warptile_mmq = { 256, 128, 128, 64, 1 };
+        s_warptile_mmq = { 256, 32, 64, 128, 0 };
         l_mmq_wg_denoms = { 128, 256, 1 };
         m_mmq_wg_denoms = { 128, 128, 1 };
-        s_mmq_wg_denoms = { 128, 128, 1 };
+        s_mmq_wg_denoms = { 32, 64, 1 };
 
         // spec constants and tile sizes for quant matmul (Qi_K)
-        l_warptile_mmq_k = { 256, 128, 512, 16 };
-        m_warptile_mmq_k = { 256, 128, 256, 16 };
-        s_warptile_mmq_k = { 256, 32, 128, 64 };
-        l_mmq_wg_denoms_k = { 128, 512, 1 };
-        m_mmq_wg_denoms_k = { 128, 256, 1 };
-        s_mmq_wg_denoms_k = { 32, 128, 1 };
+        l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
+        m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
+        s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
+        l_mmq_wg_denoms_k = { 64, 128, 1 };
+        m_mmq_wg_denoms_k = { 32, 64, 1 };
+        s_mmq_wg_denoms_k = { 32, 32, 1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128, 16 };
-        m_warptile_mmqid = { 256, 128, 64, 16 };
-        s_warptile_mmqid = { 256, 64, 64, 16 };
-        l_mmqid_wg_denoms = { 128, 128, 1 };
+        l_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        m_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        s_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        l_mmqid_wg_denoms = { 128, 64, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
-        s_mmqid_wg_denoms = { 64, 64, 1 };
+        s_mmqid_wg_denoms = { 128, 64, 1 };
 
         l_align = 128;
         m_align = 64;
@@ -1571,6 +1711,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
         uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
 
+        if (!require_full_subgroups && required_subgroup_size == 0) {
+            required_subgroup_size = get_subgroup_size(name, device->architecture);
+        }
+
         if (!pipeline) {
             pipeline = std::make_shared<vk_pipeline_struct>();
             pipeline->name = name;
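Editorial note (outside the diff): the hunk above routes every pipeline that does not already require full subgroups or an explicit subgroup size through get_subgroup_size(). A condensed, self-contained restatement of that lookup, hard-coding the rdna1_pipelines table shown earlier, just to illustrate the exact-match / longest-substring / per-arch-default fallback order:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <string>
    #include <unordered_map>
    #include <vector>

    static uint32_t lookup_rdna1(const std::string & name) {
        static const std::unordered_map<std::string, uint32_t> table = {
            {"soft_max", 64}, {"im2col", 64}, {"argmax", 64}, {"mul_mat_vec", 64},
            {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32},
        };
        auto it = table.find(name);
        if (it != table.end()) return it->second;             // 1. exact pipeline-name match
        std::vector<std::pair<std::string, uint32_t>> sorted(table.begin(), table.end());
        std::sort(sorted.begin(), sorted.end(),
                  [](const auto & a, const auto & b) { return a.first.size() > b.first.size(); });
        for (const auto & e : sorted) {                       // 2. longest substring match
            if (name.find(e.first) != std::string::npos) return e.second;
        }
        return 32;                                            // 3. RDNA_DEFAULT_SUBGROUP_SIZE
    }

    int main() {
        std::printf("%u %u %u\n",
            lookup_rdna1("soft_max_f32"),         // 64 via substring "soft_max"
            lookup_rdna1("mul_mat_vec_f16_f32"),  // 32 via the longer "mul_mat_vec_f16"
            lookup_rdna1("rms_norm_f32"));        // 32, falls through to the default
        return 0;
    }

For architectures without a configuration (vk_device_architecture::OTHER), the real function returns 0, i.e. no required subgroup size is forced onto the pipeline.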
@@ -2128,6 +2272,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2239,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     for (auto &c : compiles) {
@@ -2247,7 +2394,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     device->need_compiles = false;
 }
 
-static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
 
 static vk_device ggml_vk_get_device(size_t idx) {
     VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
@@ -2276,6 +2423,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
         device->physical_device = physical_devices[dev_num];
         const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
 
+        device->architecture = get_device_architecture(device->physical_device);
+
         const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
         device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
 
@@ -2288,7 +2437,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
         bool coopmat2_support = false;
         device->coopmat_support = false;
 
-        // Check if maintenance4 is supported
         for (const auto& properties : ext_props) {
             if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
                 maintenance4_support = true;
@@ -2376,13 +2524,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
            device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
-#if defined(_WIN32)
-        } else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
+        } else {
             // Limit batching of allocations to 1GB by default to avoid fragmentation issues
             device->suballocation_block_size = 1024*1024*1024;
-#endif
-        } else {
-            device->suballocation_block_size = device->max_memory_allocation_size;
         }
         device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
 
@@ -2401,7 +2545,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
-        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
+        if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
             device->coopmat_support = false;
         }
 
@@ -2779,7 +2923,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     subgroup_props.pNext = &driver_props;
     physical_device.getProperties2(&props2);
 
-    const size_t subgroup_size = subgroup_props.subgroupSize;
+    vk_device_architecture arch = get_device_architecture(physical_device);
+    uint32_t default_subgroup_size = get_subgroup_size("", arch);
+    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+
     const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
 
     bool fp16_storage = false;
@@ -2805,7 +2952,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         }
     }
 
-    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
+    const vk_device_architecture device_architecture = get_device_architecture(physical_device);
+
+    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
         coopmat_support = false;
     }
 
@@ -3850,10 +3999,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
            return aligned ? mmp->a_l : mmp->l;
         }
-        if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;
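Editorial note (outside the diff): shader-size selection for coopmat2 no longer requires m and n to divide the tile exactly; it now compares n against the next-smaller shader's tile height. A small sketch using the default coopmat2 tile heights from ggml_vk_load_shaders (128 for the medium shader, 64 for the small one), assuming all three shader variants are available:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t crossover_large  = 128; // mmp->m->wg_denoms[1]
        const uint32_t crossover_medium = 64;  // mmp->s->wg_denoms[1]
        for (uint32_t n : {16u, 64u, 65u, 129u}) {
            const char * choice = (n > crossover_large)  ? "large"
                                : (n > crossover_medium) ? "medium"
                                :                          "small";
            std::printf("n = %3u -> %s shader\n", n, choice); // 16/64 small, 65 medium, 129 large
        }
        return 0;
    }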
@@ -3878,18 +4031,19 @@ static void ggml_vk_matmul(
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
+        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
+        uint32_t padded_n) {
     VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
+        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
         return;
     }
 
     GGML_ASSERT(batch_stride_d == m * n);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
+    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
@@ -3898,13 +4052,17 @@ static void ggml_vk_matmul(
 }
 
 static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
-    VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
+    VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
-        if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;
@@ -3929,14 +4087,15 @@ static void ggml_vk_matmul_id(
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
+        uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
+        uint32_t padded_n) {
     VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
         "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
         "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
         "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
     ggml_vk_sync_buffers(subctx);
     const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
-                                              nei0, nei1, nbi1, ne11 };
+                                              nei0, nei1, nbi1, ne11, padded_n };
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
 }
 
@@ -4098,15 +4257,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
 
+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = padded_n * ne10;
+    const int d_ne = ne11 * ne01;
+
     const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
 
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
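Editorial note (outside the diff): both mul_mat paths now size the dequantized Y staging buffer with padded_n instead of ne11, so a matmul tile can read past the logical end of Y without per-element bounds checks. A rough arithmetic sketch with invented shapes (ne10, ne11 and the tile width below are hypothetical):

    #include <cstdint>
    #include <cstdio>

    #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))

    int main() {
        const uint32_t ne10 = 4096;  // K (columns shared by X and Y)
        const uint32_t ne11 = 70;    // N (rows of Y actually used)
        const uint32_t tile_n = 64;  // pipeline->wg_denoms[1] in this example
        const uint32_t padded_n = ROUNDUP_POW2(ne11, tile_n);       // 128
        std::printf("y_ne without padding: %u\n", ne11 * ne10);     // 286720 elements
        std::printf("y_ne with padding:    %u\n", padded_n * ne10); // 524288 elements
        return 0;
    }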
@@ -4229,7 +4390,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
             { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
             ne01, ne11, ne10,
             ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
-            split_k, ne12*ne13, ne02, ne12, r2, r3
+            split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
         );  // NOLINT
     }
 
@@ -4680,15 +4841,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     // Not implemented
     GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
 
-    const uint64_t x_ne = ne01 * ne00;
-    const uint64_t y_ne = ne11 * ne10;
-    const uint64_t d_ne = ne21 * ne20;
-
     const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
     const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
 
     vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
 
+    // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+    uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
+    const uint64_t x_ne = ne01 * ne00;
+    const uint64_t y_ne = padded_n * ne10;
+    const uint64_t d_ne = ne21 * ne20;
+
     const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
     const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
     const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -4807,7 +4970,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
             { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
             ne01, ne21, ne10, ne10, ne10, ne01,
             stride_batch_x, stride_batch_y, ne20*ne21,
-            n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
+            n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
         );  // NOLINT
     }
 
@@ -5318,6 +5481,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_rms_norm_back_f32;
         }
         return nullptr;
+    case GGML_OP_L2_NORM:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_l2_norm_f32;
+        }
+        return nullptr;
     case GGML_OP_UNARY:
         switch (ggml_get_unary_op(dst)) {
             case GGML_UNARY_OP_SILU:
@@ -5457,6 +5625,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
            return ctx->device->pipeline_rwkv_wkv6_f32;
         }
         return nullptr;
+    case GGML_OP_RWKV_WKV7:
+        if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+            return ctx->device->pipeline_rwkv_wkv7_f32;
+        }
+        return nullptr;
     case GGML_OP_OPT_STEP_ADAMW:
         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_opt_step_adamw_f32;
@@ -5704,6 +5877,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_NORM:
     case GGML_OP_RMS_NORM:
     case GGML_OP_RMS_NORM_BACK:
+    case GGML_OP_L2_NORM:
     case GGML_OP_SOFT_MAX:
     case GGML_OP_SOFT_MAX_BACK:
     case GGML_OP_SUM_ROWS:
@@ -5953,23 +6127,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
     }, dryrun);
 }
 
-static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
-    const ggml_tensor * k = dst->src[0];
-    const ggml_tensor * v = dst->src[1];
-    const ggml_tensor * r = dst->src[2];
-    const ggml_tensor * tf = dst->src[3];
-    const ggml_tensor * td = dst->src[4];
-    const ggml_tensor * state = dst->src[5];
-
-    GGML_ASSERT(!ggml_is_quantized(k->type));
-    GGML_ASSERT(!ggml_is_quantized(v->type));
-    GGML_ASSERT(!ggml_is_quantized(r->type));
-    GGML_ASSERT(!ggml_is_quantized(tf->type));
-    GGML_ASSERT(!ggml_is_quantized(td->type));
-    GGML_ASSERT(!ggml_is_quantized(state->type));
+static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
+    GGML_ASSERT(version == 6 || version == 7);
+    int num_srcs = version == 6 ? 6 : 7;
+
+    for (int i = 0; i < num_srcs; i++) {
+        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
+    }
+
     GGML_ASSERT(dst->buffer != nullptr);
 
-    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
     GGML_ASSERT(pipeline != nullptr);
 
     if (dryrun) {
@@ -5978,89 +6146,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
     }
 
     ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
-    ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
-    ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
-    ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
-    ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
-    ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
-    ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+    ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+    for (int i = 0; i < num_srcs; i++) {
+        src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
+    }
 
     ggml_vk_sync_buffers(subctx);
 
-    vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
-    size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
-    bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+    vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+    size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };
 
     if (ctx->device->uma) {
-        ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
-        ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
-        ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
-        ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
-        ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
-        ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
-        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+        for (int i = 0; i < num_srcs; i++) {
+            ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
+            srcs_uma[i] = d_srcs[i] != nullptr;
+        }
 
-        K_uma = d_K != nullptr;
-        V_uma = d_V != nullptr;
-        R_uma = d_R != nullptr;
-        TF_uma = d_TF != nullptr;
-        TD_uma = d_TD != nullptr;
-        STATE_uma = d_State != nullptr;
-        DST_uma = d_D != nullptr;
+        ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+        dst_uma = d_D != nullptr;
     }
 
-    if (!K_uma) {
-        d_K = k_buf_ctx->dev_buffer;
-        k_offset = vk_tensor_offset(k) + k->view_offs;
-    }
-    if (!V_uma) {
-        d_V = v_buf_ctx->dev_buffer;
-        v_offset = vk_tensor_offset(v) + v->view_offs;
-    }
-    if (!R_uma) {
-        d_R = r_buf_ctx->dev_buffer;
-        r_offset = vk_tensor_offset(r) + r->view_offs;
-    }
-    if (!TF_uma) {
-        d_TF = tf_buf_ctx->dev_buffer;
-        tf_offset = vk_tensor_offset(tf) + tf->view_offs;
-    }
-    if (!TD_uma) {
-        d_TD = td_buf_ctx->dev_buffer;
-        td_offset = vk_tensor_offset(td) + td->view_offs;
-    }
-    if (!STATE_uma) {
-        d_State = state_buf_ctx->dev_buffer;
-        state_offset = vk_tensor_offset(state) + state->view_offs;
+    uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
+    for (int i = 0; i < num_srcs; i++) {
+        src_sizes[i] = ggml_nbytes(dst->src[i]);
+        if (!srcs_uma[i]) {
+            d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
+            src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
+        }
     }
-    if (!DST_uma) {
+
+    const uint64_t dst_size = ggml_nbytes(dst);
+    if (!dst_uma) {
         d_D = dst_buf_ctx->dev_buffer;
         dst_offset = vk_tensor_offset(dst) + dst->view_offs;
     }
 
-    const uint64_t k_size = ggml_nbytes(k);
-    const uint64_t v_size = ggml_nbytes(v);
-    const uint64_t r_size = ggml_nbytes(r);
-    const uint64_t tf_size = ggml_nbytes(tf);
-    const uint64_t td_size = ggml_nbytes(td);
-    const uint64_t state_size = ggml_nbytes(state);
-    const uint64_t dst_size = ggml_nbytes(dst);
-
     std::array<uint32_t, 3> elements = {
         (uint32_t)(pc.B * pc.H),
         1,
         1
     };
 
-    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
-        vk_subbuffer{ d_K, k_offset, k_size },
-        vk_subbuffer{ d_V, v_offset, v_size },
-        vk_subbuffer{ d_R, r_offset, r_size },
-        vk_subbuffer{ d_TF, tf_offset, tf_size },
-        vk_subbuffer{ d_TD, td_offset, td_size },
-        vk_subbuffer{ d_State, state_offset, state_size },
-        vk_subbuffer{ d_D, dst_offset, dst_size }
-    }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+    if (version == 6) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+            vk_subbuffer{ d_D, dst_offset, dst_size }
+        }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+    } else if (version == 7) {
+        ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+            vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+            vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+            vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+            vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+            vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+            vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+            vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
+            vk_subbuffer{ d_D, dst_offset, dst_size }
+        }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
+    } else {
+        // shouldn't happen
+        GGML_ASSERT(false);
+    }
 }
 
 static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
@@ -6069,7 +6221,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
     const size_t n_heads = dst->src[0]->ne[1];
     const size_t n_seqs = dst->src[5]->ne[1];
 
-    ggml_vk_op_f32_rwkv6(
+    ggml_vk_op_f32_wkv(
         ctx, subctx, dst,
         {
             (uint32_t)n_seqs,
@@ -6077,6 +6229,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
             (uint32_t)n_embed,
             (uint32_t)n_heads,
         },
+        6,
+        dryrun
+    );
+}
+
+static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+    const size_t seq_length = dst->src[0]->ne[2];
+    const size_t n_embed = dst->ne[0];
+    const size_t n_heads = dst->src[0]->ne[1];
+    const size_t n_seqs = dst->src[6]->ne[1];
+
+    ggml_vk_op_f32_wkv(
+        ctx, subctx, dst,
+        {
+            (uint32_t)n_seqs,
+            (uint32_t)seq_length,
+            (uint32_t)n_embed,
+            (uint32_t)n_heads,
+        },
+        7,
         dryrun
     );
 }
@@ -6378,6 +6550,11 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
 }
 
+static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+    float * op_params = (float *)dst->op_params;
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+}
+
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
 }
@@ -6767,7 +6944,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
            ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
            m, n, k,
            k, k, m, k*m, k*n, m*n,
-           split_k, batch, batch, batch, 1, 1
+           split_k, batch, batch, batch, 1, 1, n
        );
    }
    ggml_vk_ctx_end(subctx);
@@ -7112,7 +7289,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
            ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
            m, n, k,
            k, k, m, k*m, k*n, m*n,
-           split_k, batch, batch, batch, 1, 1
+           split_k, batch, batch, batch, 1, 1, n
        );
    }
    ggml_vk_ctx_end(subctx);
@@ -7373,6 +7550,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
    case GGML_OP_RMS_NORM_BACK:
+   case GGML_OP_L2_NORM:
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
    case GGML_OP_SOFT_MAX_BACK:
@@ -7389,6 +7567,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_TIMESTEP_EMBEDDING:
    case GGML_OP_POOL_2D:
    case GGML_OP_RWKV_WKV6:
+   case GGML_OP_RWKV_WKV7:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_FLASH_ATTN_EXT:
    case GGML_OP_OPT_STEP_ADAMW:
@@ -7435,6 +7614,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
    case GGML_OP_RMS_NORM_BACK:
+   case GGML_OP_L2_NORM:
    case GGML_OP_UNARY:
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
@@ -7552,6 +7732,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
    case GGML_OP_RMS_NORM_BACK:
        ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
 
+        break;
+    case GGML_OP_L2_NORM:
+        ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
+
        break;
    case GGML_OP_UNARY:
        switch (ggml_get_unary_op(node)) {
@@ -7642,6 +7826,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
        break;
 
+    case GGML_OP_RWKV_WKV7:
+        ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
+
+        break;
+
    case GGML_OP_OPT_STEP_ADAMW:
        ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
 
@@ -7715,6 +7904,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_GROUP_NORM:
    case GGML_OP_RMS_NORM:
    case GGML_OP_RMS_NORM_BACK:
+   case GGML_OP_L2_NORM:
    case GGML_OP_DIAG_MASK_INF:
    case GGML_OP_SOFT_MAX:
    case GGML_OP_SOFT_MAX_BACK:
@@ -7734,6 +7924,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
    case GGML_OP_TIMESTEP_EMBEDDING:
    case GGML_OP_POOL_2D:
    case GGML_OP_RWKV_WKV6:
+   case GGML_OP_RWKV_WKV7:
    case GGML_OP_LEAKY_RELU:
    case GGML_OP_REPEAT:
    case GGML_OP_REPEAT_BACK:
@@ -8245,8 +8436,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
    VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
 
+    uint64_t total_mat_mul_bytes = 0;
    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
+        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+            total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        }
    }
    if (ctx->device->need_compiles) {
        ggml_vk_load_shaders(ctx->device);
@@ -8267,17 +8462,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     bool first_node_in_batch = true; // true if next node will be first node in a batch
     int submit_node_idx = 0; // index to first node in a batch
 
-    // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
-    // Start with a smaller count to get work submitted right away, and increase it after each submit.
-    int nodes_per_submit = 20;
+    // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
+    // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
+    // (and scaled down based on model size, so smaller models submit earlier).
+    // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
+    int nodes_per_submit = 100;
     int submitted_nodes = 0;
     int submit_count = 0;
+    uint64_t mul_mat_bytes = 0;
+    uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
     for (int i = 0; i < cgraph->n_nodes; i++) {
         if (first_node_in_batch) {
             submit_node_idx = i;
         }
 
-        bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
+        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+            mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        }
+
+        bool submit = (submitted_nodes >= nodes_per_submit) ||
+                      (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
+                      (i == last_node);
 
         bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
 
@@ -8294,13 +8499,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
         if (submit) {
             first_node_in_batch = true;
             submitted_nodes = 0;
-            switch (submit_count) {
-                case 0:
-                    nodes_per_submit = 50;
-                    break;
-                default:
-                    nodes_per_submit = 100;
-                    break;
+            mul_mat_bytes = 0;
+            if (submit_count < 3) {
+                mul_mat_bytes_per_submit *= 2;
             }
             submit_count++;
         }
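Editorial note (outside the diff): the two hunks above replace the fixed 20/50/100-node submission schedule with a heuristic driven by accumulated weight-matrix bytes. A simplified, self-contained sketch of that decision (the 100-node and 100 MB / total-bytes-over-40 thresholds are copied from the code above; the graph sizes fed to it are hypothetical):

    #include <algorithm>
    #include <cstdint>

    struct submit_state {
        uint64_t bytes_per_submit = 0;   // mul_mat_bytes_per_submit
        uint64_t mat_bytes        = 0;   // weight bytes seen since the last submit
        int      nodes            = 0;   // nodes seen since the last submit
        int      submit_count     = 0;
    };

    // Returns true when the recorded work should be handed to the GPU.
    static bool should_submit(submit_state & s, uint64_t node_weight_bytes, bool last_node) {
        s.mat_bytes += node_weight_bytes;   // non-matmul nodes contribute 0 bytes
        s.nodes++;
        const bool submit = s.nodes >= 100 || s.mat_bytes >= s.bytes_per_submit || last_node;
        if (submit) {
            s.nodes = 0;
            s.mat_bytes = 0;
            if (s.submit_count < 3) {
                s.bytes_per_submit *= 2;    // first submits are small, later ones larger
            }
            s.submit_count++;
        }
        return submit;
    }

    int main() {
        const uint64_t total_mat_mul_bytes = 4ull * 1024 * 1024 * 1024;  // ~4 GiB of weights
        submit_state s;
        s.bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
        // Drive should_submit() once per graph node while recording command buffers.
        return should_submit(s, 50ull * 1000 * 1000, false) ? 1 : 0;
    }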
@@ -8651,6 +8852,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_NORM:
         case GGML_OP_GROUP_NORM:
         case GGML_OP_RMS_NORM:
+        case GGML_OP_L2_NORM:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_ADD:
         case GGML_OP_SUB:
@@ -8680,6 +8882,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_TIMESTEP_EMBEDDING:
         case GGML_OP_POOL_2D:
         case GGML_OP_RWKV_WKV6:
+        case GGML_OP_RWKV_WKV7:
         case GGML_OP_LEAKY_RELU:
         case GGML_OP_OPT_STEP_ADAMW:
             return true;
@@ -8826,7 +9029,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
     UNUSED(instance_extensions);
 }
 
-static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
     switch (props.vendorID) {
     case VK_VENDOR_ID_INTEL:
         // Intel drivers don't support coopmat properly yet
@@ -8834,10 +9037,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
     case VK_VENDOR_ID_AMD:
         if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
             // Workaround for AMD proprietary driver reporting support on all GPUs
-            const std::string name = props.deviceName;
-            return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
-                   name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
-                   name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
+            return arch == vk_device_architecture::AMD_RDNA3;
         }
         return true;
     default:
@@ -9067,6 +9267,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
         tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
     } else if (tensor->op == GGML_OP_SILU_BACK) {
         tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
+    } else if (tensor->op == GGML_OP_L2_NORM) {
+        const float eps = ((float *) tensor->op_params)[0];
+        tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
             tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
@@ -9186,6 +9389,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
     } else if (tensor->op == GGML_OP_RWKV_WKV6) {
         tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
                                       src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
+    } else if (tensor->op == GGML_OP_RWKV_WKV7) {
+        tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
+                                      src_clone[4], src_clone[5], src_clone[6]);
     } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
         src_clone[0]->flags = src0->flags;
         tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],