@fugood/llama.node 0.3.13 → 0.3.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (184)
  1. package/bin/darwin/arm64/llama-node.node +0 -0
  2. package/bin/darwin/x64/llama-node.node +0 -0
  3. package/bin/linux/arm64/llama-node.node +0 -0
  4. package/bin/linux/x64/llama-node.node +0 -0
  5. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  6. package/bin/linux-cuda/x64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  9. package/bin/win32/arm64/llama-node.node +0 -0
  10. package/bin/win32/arm64/node.lib +0 -0
  11. package/bin/win32/x64/llama-node.node +0 -0
  12. package/bin/win32/x64/node.lib +0 -0
  13. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  14. package/bin/win32-vulkan/arm64/node.lib +0 -0
  15. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  16. package/bin/win32-vulkan/x64/node.lib +0 -0
  17. package/lib/binding.ts +1 -1
  18. package/package.json +1 -1
  19. package/src/LlamaContext.cpp +98 -76
  20. package/src/LlamaContext.h +1 -1
  21. package/src/common.hpp +1 -2
  22. package/src/llama.cpp/.github/workflows/build.yml +89 -10
  23. package/src/llama.cpp/.github/workflows/server.yml +2 -0
  24. package/src/llama.cpp/CMakeLists.txt +9 -1
  25. package/src/llama.cpp/cmake/common.cmake +2 -0
  26. package/src/llama.cpp/common/CMakeLists.txt +3 -3
  27. package/src/llama.cpp/common/arg.cpp +132 -13
  28. package/src/llama.cpp/common/chat.cpp +960 -266
  29. package/src/llama.cpp/common/chat.h +135 -0
  30. package/src/llama.cpp/common/common.cpp +33 -174
  31. package/src/llama.cpp/common/common.h +27 -67
  32. package/src/llama.cpp/common/json-schema-to-grammar.cpp +4 -5
  33. package/src/llama.cpp/common/json-schema-to-grammar.h +0 -1
  34. package/src/llama.cpp/common/{minja.hpp → minja/minja.hpp} +37 -5
  35. package/src/llama.cpp/common/ngram-cache.cpp +1 -0
  36. package/src/llama.cpp/common/sampling.cpp +45 -7
  37. package/src/llama.cpp/common/speculative.cpp +10 -9
  38. package/src/llama.cpp/common/speculative.h +1 -1
  39. package/src/llama.cpp/docs/build.md +45 -7
  40. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +2 -2
  41. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +4 -2
  42. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -1
  43. package/src/llama.cpp/examples/export-lora/export-lora.cpp +4 -2
  44. package/src/llama.cpp/examples/gritlm/gritlm.cpp +2 -2
  45. package/src/llama.cpp/examples/imatrix/imatrix.cpp +3 -4
  46. package/src/llama.cpp/examples/infill/infill.cpp +2 -2
  47. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +2 -2
  48. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +5 -5
  49. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  50. package/src/llama.cpp/examples/llava/clip.cpp +373 -107
  51. package/src/llama.cpp/examples/llava/clip.h +19 -3
  52. package/src/llama.cpp/examples/llava/gemma3-cli.cpp +341 -0
  53. package/src/llama.cpp/examples/llava/llava.cpp +4 -2
  54. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +30 -11
  55. package/src/llama.cpp/examples/lookahead/lookahead.cpp +7 -6
  56. package/src/llama.cpp/examples/lookup/lookup.cpp +1 -1
  57. package/src/llama.cpp/examples/main/main.cpp +79 -34
  58. package/src/llama.cpp/examples/parallel/parallel.cpp +6 -5
  59. package/src/llama.cpp/examples/passkey/passkey.cpp +15 -14
  60. package/src/llama.cpp/examples/perplexity/perplexity.cpp +6 -6
  61. package/src/llama.cpp/examples/quantize/quantize.cpp +1 -0
  62. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -2
  63. package/src/llama.cpp/examples/retrieval/retrieval.cpp +1 -1
  64. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.cpp +882 -237
  65. package/src/llama.cpp/examples/run/linenoise.cpp/linenoise.h +35 -26
  66. package/src/llama.cpp/examples/run/run.cpp +196 -108
  67. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +2 -2
  68. package/src/llama.cpp/examples/server/server.cpp +113 -101
  69. package/src/llama.cpp/examples/server/utils.hpp +94 -105
  70. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +2 -2
  71. package/src/llama.cpp/examples/speculative/speculative.cpp +14 -14
  72. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
  73. package/src/llama.cpp/examples/sycl/run-llama2.sh +2 -2
  74. package/src/llama.cpp/examples/tts/tts.cpp +263 -151
  75. package/src/llama.cpp/ggml/CMakeLists.txt +14 -1
  76. package/src/llama.cpp/ggml/cmake/common.cmake +26 -0
  77. package/src/llama.cpp/ggml/include/ggml-alloc.h +1 -1
  78. package/src/llama.cpp/ggml/include/ggml-backend.h +3 -3
  79. package/src/llama.cpp/ggml/include/ggml-cpu.h +3 -0
  80. package/src/llama.cpp/ggml/include/ggml.h +29 -1
  81. package/src/llama.cpp/ggml/src/CMakeLists.txt +15 -34
  82. package/src/llama.cpp/ggml/src/ggml-alloc.c +24 -15
  83. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +1 -1
  84. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +58 -54
  85. package/src/llama.cpp/ggml/src/ggml-backend.cpp +10 -8
  86. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +6 -2
  87. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +3 -7
  88. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +3 -5
  89. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +139 -16
  90. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +2 -1
  91. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +4 -0
  92. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +2 -1
  93. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +151 -0
  94. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1546 -387
  95. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1645 -113
  96. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +22 -0
  97. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +259 -0
  98. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +61 -0
  99. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +288 -0
  100. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.h +17 -0
  101. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +15 -2
  102. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +2 -1
  103. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +3 -1
  104. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +14 -0
  105. package/src/llama.cpp/ggml/src/ggml-impl.h +1 -1
  106. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +4 -5
  107. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +242 -0
  108. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +6 -6
  109. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +1 -0
  110. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +315 -138
  111. package/src/llama.cpp/ggml/src/ggml-quants.c +114 -114
  112. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +2 -1
  113. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +5 -0
  114. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +17 -0
  116. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +117 -36
  117. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +33 -4
  118. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +2 -2
  119. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.cpp +701 -0
  120. package/src/llama.cpp/ggml/src/ggml-sycl/cpy.hpp +11 -0
  121. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +55 -0
  122. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +147 -16
  123. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +40 -40
  124. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +307 -0
  125. package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +23 -0
  126. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +262 -746
  127. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +0 -1
  128. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -78
  129. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +114 -6
  130. package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +6 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +4 -1
  132. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.cpp +13 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/sycl_hw.hpp +23 -0
  134. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.cpp +305 -0
  135. package/src/llama.cpp/ggml/src/ggml-sycl/wkv.hpp +10 -0
  136. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +498 -188
  137. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +0 -4
  138. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +16 -3
  139. package/src/llama.cpp/ggml/src/ggml.c +93 -5
  140. package/src/llama.cpp/include/llama.h +105 -27
  141. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.inp +112 -0
  142. package/src/llama.cpp/models/ggml-vocab-gpt-4o.gguf.out +46 -0
  143. package/src/llama.cpp/requirements/requirements-all.txt +1 -0
  144. package/src/llama.cpp/requirements/requirements-tool_bench.txt +12 -0
  145. package/src/llama.cpp/requirements.txt +1 -0
  146. package/src/llama.cpp/src/CMakeLists.txt +5 -2
  147. package/src/llama.cpp/src/llama-adapter.cpp +19 -20
  148. package/src/llama.cpp/src/llama-adapter.h +11 -9
  149. package/src/llama.cpp/src/llama-arch.cpp +123 -16
  150. package/src/llama.cpp/src/llama-arch.h +19 -0
  151. package/src/llama.cpp/src/llama-batch.h +2 -2
  152. package/src/llama.cpp/src/llama-chat.cpp +1 -0
  153. package/src/llama.cpp/src/llama-context.cpp +2253 -1222
  154. package/src/llama.cpp/src/llama-context.h +214 -77
  155. package/src/llama.cpp/src/llama-cparams.h +1 -0
  156. package/src/llama.cpp/src/llama-grammar.cpp +182 -182
  157. package/src/llama.cpp/src/llama-grammar.h +12 -3
  158. package/src/llama.cpp/src/llama-graph.cpp +1662 -0
  159. package/src/llama.cpp/src/llama-graph.h +574 -0
  160. package/src/llama.cpp/src/llama-hparams.cpp +8 -0
  161. package/src/llama.cpp/src/llama-hparams.h +9 -0
  162. package/src/llama.cpp/src/llama-io.cpp +15 -0
  163. package/src/llama.cpp/src/llama-io.h +35 -0
  164. package/src/llama.cpp/src/llama-kv-cache.cpp +1006 -291
  165. package/src/llama.cpp/src/llama-kv-cache.h +178 -109
  166. package/src/llama.cpp/src/llama-memory.cpp +1 -0
  167. package/src/llama.cpp/src/llama-memory.h +21 -0
  168. package/src/llama.cpp/src/llama-mmap.cpp +11 -1
  169. package/src/llama.cpp/src/llama-model.cpp +8230 -122
  170. package/src/llama.cpp/src/llama-model.h +34 -1
  171. package/src/llama.cpp/src/llama-quant.cpp +10 -1
  172. package/src/llama.cpp/src/llama-sampling.cpp +43 -10
  173. package/src/llama.cpp/src/llama-vocab.cpp +12 -0
  174. package/src/llama.cpp/src/llama.cpp +51 -9837
  175. package/src/llama.cpp/tests/test-backend-ops.cpp +247 -112
  176. package/src/llama.cpp/tests/test-chat-template.cpp +32 -22
  177. package/src/llama.cpp/tests/test-chat.cpp +593 -395
  178. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +63 -63
  179. package/src/llama.cpp/tests/test-quantize-fns.cpp +1 -9
  180. package/src/llama.cpp/Sources/llama/llama.h +0 -4
  181. package/src/llama.cpp/common/chat.hpp +0 -55
  182. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +0 -143
  183. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +0 -9
  184. package/src/llama.cpp/common/{chat-template.hpp → minja/chat-template.hpp} +0 -0
package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -29,6 +29,7 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
+#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
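The new ROUNDUP_POW2 macro rounds M up to the next multiple of N with a bit mask, which is only valid when N is a power of two; the existing CEIL_DIV stays the general-purpose rounding helper. A worked check of the arithmetic (illustrative, not part of the diff):

    // ROUNDUP_POW2(37, 16) = (37 + 15) & ~15 = 52 & ~0xF = 48
    // CEIL_DIV(37, 16)     = (37 + 15) / 16  = 3, and 3 * 16 = 48 again
    static_assert(((37 + 16 - 1) & ~(16 - 1)) == 48, "ROUNDUP_POW2(37, 16)");
    static_assert(((37 + 16 - 1) / 16) == 3, "CEIL_DIV(37, 16)");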
@@ -149,6 +150,66 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf);
 
 static constexpr uint32_t mul_mat_vec_max_cols = 8;
 
+enum vk_device_architecture {
+    OTHER,
+    AMD_GCN,
+    AMD_RDNA1,
+    AMD_RDNA2,
+    AMD_RDNA3,
+};
+
+static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) {
+    vk::PhysicalDeviceProperties props = device.getProperties();
+
+    if (props.vendorID == VK_VENDOR_ID_AMD) {
+        const std::vector<vk::ExtensionProperties> ext_props = device.enumerateDeviceExtensionProperties();
+
+        bool amd_shader_core_properties = false;
+        bool integer_dot_product = false;
+        bool subgroup_size_control = false;
+
+        for (const auto& properties : ext_props) {
+            if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) {
+                amd_shader_core_properties = true;
+            } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) {
+                integer_dot_product = true;
+            } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) {
+                subgroup_size_control = true;
+            }
+        }
+
+        if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) {
+            return vk_device_architecture::OTHER;
+        }
+
+        vk::PhysicalDeviceProperties2 props2;
+        vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd;
+        vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props;
+        vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
+
+        props2.pNext = &shader_core_props_amd;
+        shader_core_props_amd.pNext = &integer_dot_props;
+        integer_dot_props.pNext = &subgroup_size_control_props;
+
+        device.getProperties2(&props2);
+
+        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) {
+            return vk_device_architecture::AMD_GCN;
+        }
+        if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) {
+            // RDNA
+            if (shader_core_props_amd.wavefrontsPerSimd == 20) {
+                return vk_device_architecture::AMD_RDNA1;
+            }
+            if (integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) {
+                return vk_device_architecture::AMD_RDNA3;
+            }
+            return vk_device_architecture::AMD_RDNA2;
+        }
+    }
+    return vk_device_architecture::OTHER;
+}
+
 struct vk_device_struct {
     std::mutex mutex;
 
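get_device_architecture distinguishes AMD GPU generations purely from Vulkan-reported properties, with no device-ID tables. The decision logic above, summarized as a table (derived directly from the code):

    // vendor != AMD, or any of the three extensions missing            -> OTHER
    // minSubgroupSize == 64 && maxSubgroupSize == 64 (wave64 only)     -> AMD_GCN
    // minSubgroupSize == 32 && maxSubgroupSize == 64 (wave32/wave64):
    //     wavefrontsPerSimd == 20                                      -> AMD_RDNA1
    //     4x8-bit packed mixed-signedness dot product accelerated      -> AMD_RDNA3
    //     otherwise                                                    -> AMD_RDNA2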
@@ -161,6 +222,7 @@ struct vk_device_struct {
     bool pipeline_robustness;
     vk::Device device;
     uint32_t vendor_id;
+    vk_device_architecture architecture;
     vk_queue compute_queue;
     vk_queue transfer_queue;
     bool single_queue;
@@ -241,15 +303,20 @@
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
+    vk_pipeline pipeline_rms_norm_back_f32;
+    vk_pipeline pipeline_l2_norm_f32;
     vk_pipeline pipeline_gelu_f32;
     vk_pipeline pipeline_gelu_quick_f32;
     vk_pipeline pipeline_silu_f32;
+    vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_relu_f32;
     vk_pipeline pipeline_leaky_relu_f32;
     vk_pipeline pipeline_tanh_f32;
+    vk_pipeline pipeline_sigmoid_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
+    vk_pipeline pipeline_soft_max_back_f32;
     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16;
     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16;
     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16;
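The *_back additions are backward (gradient) pipelines, while l2_norm and sigmoid are new forward kernels. For orientation, the identity a SiLU backward kernel has to apply is plain calculus; a minimal host-side sketch (the shader source itself is not part of this excerpt):

    #include <cmath>

    // silu(x) = x * sigmoid(x), so silu'(x) = s * (1 + x * (1 - s)) where s = sigmoid(x)
    float silu_backward(float x, float dy) {
        const float s = 1.0f / (1.0f + std::exp(-x)); // sigmoid(x)
        return dy * s * (1.0f + x * (1.0f - s));      // upstream gradient times silu'(x)
    }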
@@ -262,6 +329,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_timestep_embedding_f32;
     vk_pipeline pipeline_pool2d_f32;
     vk_pipeline pipeline_rwkv_wkv6_f32;
+    vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
@@ -364,6 +432,7 @@ struct vk_mat_mat_push_constants {
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t k_split;
     uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -376,6 +445,7 @@ struct vk_mat_mat_id_push_constants {
     uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
     uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d;
     uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11;
+    uint32_t padded_N;
 };
 struct vk_mat_vec_id_push_constants {
     uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d;
@@ -504,6 +574,7 @@ struct vk_op_rope_push_constants {
     uint32_t s1;
     uint32_t s2;
     int32_t sections[4];
+    uint32_t is_back;
 };
 
 struct vk_op_soft_max_push_constants {
@@ -560,6 +631,13 @@ struct vk_op_rwkv_wkv6_push_constants {
     uint32_t H;
 };
 
+struct vk_op_rwkv_wkv7_push_constants {
+    uint32_t B;
+    uint32_t T;
+    uint32_t C;
+    uint32_t H;
+};
+
 // Allow pre-recording command buffers
 struct vk_staging_memcpy {
     vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
@@ -1440,6 +1518,73 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     return supported;
 }
 
+struct GpuPipelineConfig {
+    // GPU architecture identifier.
+    // Example: vk_device_architecture::AMD_GCN
+    vk_device_architecture arch;
+
+    // Mapping of pipeline names to their specific subgroup sizes.
+    // Example: {"soft_max_f32", 64}
+    std::unordered_map<std::string, uint32_t> pipelines;
+
+    // Default subgroup size for this GPU.
+    // Defaults to 0 if not explicitly provided.
+    uint32_t default_subgroup_size = 0;
+};
+
+// Pipeline configuration for RDNA1 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+    {"argmax", 64}, {"mul_mat_vec", 64},
+    {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
+};
+
+// Pipeline configuration for RDNA2 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+};
+
+static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
+
+// Define configurations for different GPUs.
+static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
+    {
+        vk_device_architecture::AMD_RDNA1,
+        {
+            rdna1_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+    {
+        vk_device_architecture::AMD_RDNA2,
+        {
+            rdna2_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+};
+
+static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
+    for (const auto &config : gpu_pipeline_configs) {
+        if (config.arch == arch) {
+            auto pipIt = config.pipelines.find(pipeline_name);
+            if (pipIt != config.pipelines.end()) {
+                return pipIt->second;
+            }
+            std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
+            std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
+                      [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
+            for (const auto &entry : sorted_pipelines) {
+                if (pipeline_name.find(entry.first) != std::string::npos) {
+                    return entry.second;
+                }
+            }
+            return config.default_subgroup_size;
+        }
+    }
+    return 0; // If no matching configuration is found
+}
+
 static void ggml_vk_load_shaders(vk_device& device) {
     VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
 
@@ -1461,36 +1606,36 @@ static void ggml_vk_load_shaders(vk_device& device) {
     uint32_t l_align, m_align, s_align;
     if (device->coopmat2) {
         // spec constants and tile sizes for non-quant matmul/matmul_id
-        l_warptile = { 256, 128, 256, 64 };
-        m_warptile = { 256, 128, 128, 64 };
-        s_warptile = { 128, 64, 64, 64 };
+        l_warptile = { 256, 128, 256, 64, 1 };
+        m_warptile = { 256, 128, 128, 64, 0 };
+        s_warptile = { 128, 64, 64, 64, 0 };
         l_wg_denoms = {128, 256, 1 };
         m_wg_denoms = {128, 128, 1 };
         s_wg_denoms = { 64, 64, 1 };
 
         // spec constants and tile sizes for quant matmul (non-Qi_K)
-        l_warptile_mmq = { 256, 128, 256, 64 };
-        m_warptile_mmq = { 256, 128, 128, 64 };
-        s_warptile_mmq = { 256, 128, 128, 64 };
+        l_warptile_mmq = { 256, 128, 256, 64, 1 };
+        m_warptile_mmq = { 256, 128, 128, 64, 1 };
+        s_warptile_mmq = { 256, 32, 64, 128, 0 };
         l_mmq_wg_denoms = { 128, 256, 1 };
         m_mmq_wg_denoms = { 128, 128, 1 };
-        s_mmq_wg_denoms = { 128, 128, 1 };
+        s_mmq_wg_denoms = { 32, 64, 1 };
 
         // spec constants and tile sizes for quant matmul (Qi_K)
-        l_warptile_mmq_k = { 256, 128, 512, 16 };
-        m_warptile_mmq_k = { 256, 128, 256, 16 };
-        s_warptile_mmq_k = { 256, 32, 128, 64 };
-        l_mmq_wg_denoms_k = { 128, 512, 1 };
-        m_mmq_wg_denoms_k = { 128, 256, 1 };
-        s_mmq_wg_denoms_k = { 32, 128, 1 };
+        l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
+        m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
+        s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
+        l_mmq_wg_denoms_k = { 64, 128, 1 };
+        m_mmq_wg_denoms_k = { 32, 64, 1 };
+        s_mmq_wg_denoms_k = { 32, 32, 1 };
 
         // spec constants and tile sizes for quant matmul_id
-        l_warptile_mmqid = { 256, 128, 128, 16 };
-        m_warptile_mmqid = { 256, 128, 64, 16 };
-        s_warptile_mmqid = { 256, 64, 64, 16 };
-        l_mmqid_wg_denoms = { 128, 128, 1 };
+        l_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        m_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        s_warptile_mmqid = { 256, 128, 64, 16, 0 };
+        l_mmqid_wg_denoms = { 128, 64, 1 };
         m_mmqid_wg_denoms = { 128, 64, 1 };
-        s_mmqid_wg_denoms = { 64, 64, 1 };
+        s_mmqid_wg_denoms = { 128, 64, 1 };
 
         l_align = 128;
         m_align = 64;
@@ -1566,6 +1711,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
         uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
 
+        if (!require_full_subgroups && required_subgroup_size == 0) {
+            required_subgroup_size = get_subgroup_size(name, device->architecture);
+        }
+
         if (!pipeline) {
             pipeline = std::make_shared<vk_pipeline_struct>();
             pipeline->name = name;
@@ -1987,6 +2136,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         }
     } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
         rm_stdq = 2;
+    uint32_t rm_iq = 2 * rm_kq;
 
     for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -2001,15 +2151,15 @@
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f32_f32_len, mul_mat_vec_iq1_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f32_f32_len, mul_mat_vec_iq1_m_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f32_f32_len, mul_mat_vec_iq2_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f32_f32_len, mul_mat_vec_iq2_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
 
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@@ -2023,15 +2173,15 @@
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_s_f16_f32_len, mul_mat_vec_iq1_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq1_m_f16_f32_len, mul_mat_vec_iq1_m_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xxs_f16_f32_len, mul_mat_vec_iq2_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_xs_f16_f32_len, mul_mat_vec_iq2_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq, i+1}, 1, true);
     }
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
@@ -2046,15 +2196,15 @@
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
-    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
 
     // dequant shaders
     ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@@ -2121,6 +2271,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@@ -2180,9 +2332,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
@@ -2190,6 +2344,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -2229,6 +2384,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
 
+    ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
+
     ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     for (auto &c : compiles) {
@@ -2237,7 +2394,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     device->need_compiles = false;
 }
 
-static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props);
+static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch);
 
 static vk_device ggml_vk_get_device(size_t idx) {
     VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")");
@@ -2266,6 +2423,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
     device->physical_device = physical_devices[dev_num];
     const std::vector<vk::ExtensionProperties> ext_props = device->physical_device.enumerateDeviceExtensionProperties();
 
+    device->architecture = get_device_architecture(device->physical_device);
+
     const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY");
     device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr;
 
@@ -2278,7 +2437,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
     bool coopmat2_support = false;
     device->coopmat_support = false;
 
-    // Check if maintenance4 is supported
     for (const auto& properties : ext_props) {
         if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
             maintenance4_support = true;
@@ -2366,13 +2524,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
     if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
         device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
-#if defined(_WIN32)
-    } else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
+    } else {
         // Limit batching of allocations to 1GB by default to avoid fragmentation issues
         device->suballocation_block_size = 1024*1024*1024;
-#endif
-    } else {
-        device->suballocation_block_size = device->max_memory_allocation_size;
     }
     device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
 
@@ -2391,7 +2545,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
     device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
 
-    if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props)) {
+    if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) {
         device->coopmat_support = false;
     }
 
@@ -2769,7 +2923,10 @@ static void ggml_vk_print_gpu_info(size_t idx) {
     subgroup_props.pNext = &driver_props;
     physical_device.getProperties2(&props2);
 
-    const size_t subgroup_size = subgroup_props.subgroupSize;
+    vk_device_architecture arch = get_device_architecture(physical_device);
+    uint32_t default_subgroup_size = get_subgroup_size("", arch);
+    const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
+
     const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
 
     bool fp16_storage = false;
@@ -2795,7 +2952,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
         }
     }
 
-    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props)) {
+    const vk_device_architecture device_architecture = get_device_architecture(physical_device);
+
+    if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
        coopmat_support = false;
     }
 
@@ -3840,10 +3999,14 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
     VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
 
     if (ctx->device->coopmat2) {
-        if ((ctx->device->mul_mat_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
+        // Use large shader when the N dimension is greater than the medium shader's tile size
+        uint32_t crossover_large = mmp->m->wg_denoms[1];
+        if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) {
             return aligned ? mmp->a_l : mmp->l;
         }
-        if ((ctx->device->mul_mat_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_s[src0_type]) {
+        // Use medium shader when the N dimension is greater than the small shader's tile size
+        uint32_t crossover_medium = mmp->s->wg_denoms[1];
+        if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) {
             return aligned ? mmp->a_m : mmp->m;
         }
         return aligned ? mmp->a_s : mmp->s;
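Tile selection in the coopmat2 path now asks whether N clears the next smaller shader's N tile instead of requiring m and n to divide the tile denominators exactly; the padded_N push constant added earlier is what lets a tile run partially filled. A worked example with the non-quant tiles above (wg_denoms: l = {128, 256, 1}, m = {128, 128, 1}, s = {64, 64, 1}):

    // n = 512: 512 > 128 (medium N tile)           -> large pipeline
    // n = 100: 100 <= 128, but 100 > 64 (small N)  -> medium pipeline
    // n =  16: 16 <= 64                            -> small pipeline
    // Old rule: n = 512 with m = 100 failed the (m % 128 == 0) divisibility
    // test and fell through to a smaller, slower tile.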
@@ -3868,18 +4031,19 @@ static void ggml_vk_matmul(
         vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer,
         uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
         uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
-        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3) {
+        uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
+        uint32_t padded_n) {
     VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
     ggml_vk_sync_buffers(subctx);
     if (split_k == 1) {
-        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3 };
+        const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, sizeof(vk_mat_mat_push_constants), &pc, { m, n, batch });
         return;
     }
 
     GGML_ASSERT(batch_stride_d == m * n);
 
-    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3 };
+    const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, CEIL_DIV(k, split_k), ne02, ne12, broadcast2, broadcast3, padded_n };
     // Make sure enough workgroups get assigned for split k to work
     ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, sizeof(vk_mat_mat_push_constants), &pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch });
     ggml_vk_sync_buffers(subctx);
@@ -3888,13 +4052,17 @@ static void ggml_vk_matmul(
3888
4052
  }
3889
4053
 
3890
4054
  static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, bool aligned, ggml_type src0_type) {
3891
- VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
4055
+ VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
3892
4056
 
3893
4057
  if (ctx->device->coopmat2) {
3894
- if ((ctx->device->mul_mat_id_l[src0_type] && (m % mmp->l->wg_denoms[0]) == 0 && (n % mmp->l->wg_denoms[1]) == 0) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
4058
+ // Use large shader when the N dimension is greater than the medium shader's tile size
4059
+ uint32_t crossover_large = mmp->m->wg_denoms[1];
4060
+ if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) {
3895
4061
  return aligned ? mmp->a_l : mmp->l;
3896
4062
  }
3897
- if ((ctx->device->mul_mat_id_m[src0_type] && (m % mmp->m->wg_denoms[0]) == 0 && (n % mmp->m->wg_denoms[1]) == 0) || !ctx->device->mul_mat_id_s[src0_type]) {
4063
+ // Use medium shader when the N dimension is greater than the small shader's tile size
4064
+ uint32_t crossover_medium = mmp->s->wg_denoms[1];
4065
+ if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) {
3898
4066
  return aligned ? mmp->a_m : mmp->m;
3899
4067
  }
3900
4068
  return aligned ? mmp->a_s : mmp->s;
@@ -3919,14 +4087,15 @@ static void ggml_vk_matmul_id(
3919
4087
  vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids,
3920
4088
  uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d,
3921
4089
  uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
3922
- uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11) {
4090
+ uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11,
4091
+ uint32_t padded_n) {
3923
4092
  VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " <<
3924
4093
  "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " <<
3925
4094
  "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " <<
3926
4095
  "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")");
3927
4096
  ggml_vk_sync_buffers(subctx);
3928
4097
  const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d,
3929
- nei0, nei1, nbi1, ne11 };
4098
+ nei0, nei1, nbi1, ne11, padded_n };
3930
4099
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, sizeof(vk_mat_mat_id_push_constants), &pc, { m, nei1, n_as });
3931
4100
  }
3932
4101
 
@@ -4088,15 +4257,17 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4088
4257
  // Not implemented
4089
4258
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
4090
4259
 
4091
- const int x_ne = ne01 * ne00;
4092
- const int y_ne = ne11 * ne10;
4093
- const int d_ne = ne11 * ne01;
4094
-
4095
4260
  const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
4096
4261
  const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
4097
4262
 
4098
4263
  vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
4099
4264
 
4265
+ // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4266
+ uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
4267
+ const int x_ne = ne01 * ne00;
4268
+ const int y_ne = padded_n * ne10;
4269
+ const int d_ne = ne11 * ne01;
4270
+
4100
4271
  const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, pipeline);
4101
4272
 
4102
4273
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
@@ -4183,7 +4354,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
  }
  if (qy_needs_dequant) {
  d_Y = ctx->prealloc_y;
- GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
+ GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
  } else {
  d_Y = d_Qy;
  y_buf_offset = qy_buf_offset;
@@ -4219,7 +4390,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
  { d_D, d_buf_offset, d_sz * ne12 * ne13 }, { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k },
  ne01, ne11, ne10,
  ne10, ne10, ne01, stride_batch_x, stride_batch_y, ne20*ne21,
- split_k, ne12*ne13, ne02, ne12, r2, r3
+ split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n
  ); // NOLINT
  }

@@ -4670,15 +4841,17 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  // Not implemented
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT

- const uint64_t x_ne = ne01 * ne00;
- const uint64_t y_ne = ne11 * ne10;
- const uint64_t d_ne = ne21 * ne20;
-
  const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;

  vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);

+ // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
+ uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
+ const uint64_t x_ne = ne01 * ne00;
+ const uint64_t y_ne = padded_n * ne10;
+ const uint64_t d_ne = ne21 * ne20;
+
  const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
  const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
  const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
@@ -4760,7 +4933,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  }
  if (qy_needs_dequant) {
  d_Y = ctx->prealloc_y;
- GGML_ASSERT(d_Y->size >= y_sz * ne02 * ne03);
+ GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
  } else {
  d_Y = d_Qy;
  y_buf_offset = qy_buf_offset;
@@ -4797,7 +4970,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
  { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz },
  ne01, ne21, ne10, ne10, ne10, ne01,
  stride_batch_x, stride_batch_y, ne20*ne21,
- n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11
+ n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n
  ); // NOLINT
  }

@@ -5283,6 +5456,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  case GGML_OP_CONT:
  case GGML_OP_DUP:
  return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
+ case GGML_OP_SILU_BACK:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_silu_back_f32;
+ }
+ return nullptr;
  case GGML_OP_NORM:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_norm_f32;
@@ -5298,6 +5476,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_rms_norm_f32;
  }
  return nullptr;
+ case GGML_OP_RMS_NORM_BACK:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_rms_norm_back_f32;
+ }
+ return nullptr;
+ case GGML_OP_L2_NORM:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_l2_norm_f32;
+ }
+ return nullptr;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(dst)) {
  case GGML_UNARY_OP_SILU:
@@ -5325,6 +5513,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_tanh_f32;
  }
  break;
+ case GGML_UNARY_OP_SIGMOID:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_sigmoid_f32;
+ }
+ break;
  default:
  break;
  }
@@ -5344,7 +5537,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16;
  }
  return nullptr;
+ case GGML_OP_SOFT_MAX_BACK:
+ if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_soft_max_back_f32;
+ }
+ return nullptr;
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  {
  const int mode = ((const int32_t *) dst->op_params)[2];
  const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
@@ -5426,6 +5625,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
  return ctx->device->pipeline_rwkv_wkv6_f32;
  }
  return nullptr;
+ case GGML_OP_RWKV_WKV7:
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+ return ctx->device->pipeline_rwkv_wkv7_f32;
+ }
+ return nullptr;
  case GGML_OP_OPT_STEP_ADAMW:
  if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
  return ctx->device->pipeline_opt_step_adamw_f32;
@@ -5672,7 +5876,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  switch (op) {
  case GGML_OP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_SUM_ROWS:
  case GGML_OP_ARGMAX:
  {
@@ -5696,6 +5903,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
  } break;
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 };
  break;
  case GGML_OP_GET_ROWS:
@@ -5791,7 +5999,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co

  ggml_vk_sync_buffers(subctx);
  ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
- } else if (op == GGML_OP_ROPE) {
+ } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) {
  // Empty src2 is possible in rope, but the shader needs a buffer
  vk_subbuffer subbuf_z;
  if (use_src2) {
@@ -5919,23 +6127,17 @@ static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const
  }, dryrun);
  }

- static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, bool dryrun = false) {
- const ggml_tensor * k = dst->src[0];
- const ggml_tensor * v = dst->src[1];
- const ggml_tensor * r = dst->src[2];
- const ggml_tensor * tf = dst->src[3];
- const ggml_tensor * td = dst->src[4];
- const ggml_tensor * state = dst->src[5];
-
- GGML_ASSERT(!ggml_is_quantized(k->type));
- GGML_ASSERT(!ggml_is_quantized(v->type));
- GGML_ASSERT(!ggml_is_quantized(r->type));
- GGML_ASSERT(!ggml_is_quantized(tf->type));
- GGML_ASSERT(!ggml_is_quantized(td->type));
- GGML_ASSERT(!ggml_is_quantized(state->type));
+ static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) {
+ GGML_ASSERT(version == 6 || version == 7);
+ int num_srcs = version == 6 ? 6 : 7;
+
+ for (int i = 0; i < num_srcs; i++) {
+ GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
+ }
+
  GGML_ASSERT(dst->buffer != nullptr);

- vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, k, v, r, dst, GGML_OP_RWKV_WKV6);
+ vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
  GGML_ASSERT(pipeline != nullptr);

  if (dryrun) {
@@ -5944,89 +6146,73 @@ static void ggml_vk_op_f32_rwkv6(ggml_backend_vk_context * ctx, vk_context& subc
  }

  ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
- ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context;
- ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context;
- ggml_backend_vk_buffer_context * r_buf_ctx = (ggml_backend_vk_buffer_context *)r->buffer->context;
- ggml_backend_vk_buffer_context * tf_buf_ctx = (ggml_backend_vk_buffer_context *)tf->buffer->context;
- ggml_backend_vk_buffer_context * td_buf_ctx = (ggml_backend_vk_buffer_context *)td->buffer->context;
- ggml_backend_vk_buffer_context * state_buf_ctx = (ggml_backend_vk_buffer_context *)state->buffer->context;
+ ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+ for (int i = 0; i < num_srcs; i++) {
+ src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context;
+ }

  ggml_vk_sync_buffers(subctx);

- vk_buffer d_D = nullptr, d_K = nullptr, d_V = nullptr, d_R = nullptr, d_TF = nullptr, d_TD = nullptr, d_State = nullptr;
- size_t k_offset = 0, v_offset = 0, r_offset = 0, tf_offset = 0, td_offset = 0, state_offset = 0, dst_offset = 0;
- bool K_uma = false, V_uma = false, R_uma = false, TF_uma = false, TD_uma = false, STATE_uma = false, DST_uma = false;
+ vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr };
+ size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 };
+ bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false };

  if (ctx->device->uma) {
- ggml_vk_host_get(ctx->device, k->data, d_K, k_offset);
- ggml_vk_host_get(ctx->device, v->data, d_V, v_offset);
- ggml_vk_host_get(ctx->device, r->data, d_R, r_offset);
- ggml_vk_host_get(ctx->device, tf->data, d_TF, tf_offset);
- ggml_vk_host_get(ctx->device, td->data, d_TD, td_offset);
- ggml_vk_host_get(ctx->device, state->data, d_State, state_offset);
- ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+ for (int i = 0; i < num_srcs; i++) {
+ ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]);
+ srcs_uma[i] = d_srcs[i] != nullptr;
+ }

- K_uma = d_K != nullptr;
- V_uma = d_V != nullptr;
- R_uma = d_R != nullptr;
- TF_uma = d_TF != nullptr;
- TD_uma = d_TD != nullptr;
- STATE_uma = d_State != nullptr;
- DST_uma = d_D != nullptr;
+ ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset);
+ dst_uma = d_D != nullptr;
  }

- if (!K_uma) {
- d_K = k_buf_ctx->dev_buffer;
- k_offset = vk_tensor_offset(k) + k->view_offs;
- }
- if (!V_uma) {
- d_V = v_buf_ctx->dev_buffer;
- v_offset = vk_tensor_offset(v) + v->view_offs;
- }
- if (!R_uma) {
- d_R = r_buf_ctx->dev_buffer;
- r_offset = vk_tensor_offset(r) + r->view_offs;
- }
- if (!TF_uma) {
- d_TF = tf_buf_ctx->dev_buffer;
- tf_offset = vk_tensor_offset(tf) + tf->view_offs;
- }
- if (!TD_uma) {
- d_TD = td_buf_ctx->dev_buffer;
- td_offset = vk_tensor_offset(td) + td->view_offs;
- }
- if (!STATE_uma) {
- d_State = state_buf_ctx->dev_buffer;
- state_offset = vk_tensor_offset(state) + state->view_offs;
+ uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 };
+ for (int i = 0; i < num_srcs; i++) {
+ src_sizes[i] = ggml_nbytes(dst->src[i]);
+ if (!srcs_uma[i]) {
+ d_srcs[i] = src_buf_ctxs[i]->dev_buffer;
+ src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs;
+ }
  }
- if (!DST_uma) {
+
+ const uint64_t dst_size = ggml_nbytes(dst);
+ if (!dst_uma) {
  d_D = dst_buf_ctx->dev_buffer;
  dst_offset = vk_tensor_offset(dst) + dst->view_offs;
  }

- const uint64_t k_size = ggml_nbytes(k);
- const uint64_t v_size = ggml_nbytes(v);
- const uint64_t r_size = ggml_nbytes(r);
- const uint64_t tf_size = ggml_nbytes(tf);
- const uint64_t td_size = ggml_nbytes(td);
- const uint64_t state_size = ggml_nbytes(state);
- const uint64_t dst_size = ggml_nbytes(dst);
-
  std::array<uint32_t, 3> elements = {
  (uint32_t)(pc.B * pc.H),
  1,
  1
  };

- ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
- vk_subbuffer{ d_K, k_offset, k_size },
- vk_subbuffer{ d_V, v_offset, v_size },
- vk_subbuffer{ d_R, r_offset, r_size },
- vk_subbuffer{ d_TF, tf_offset, tf_size },
- vk_subbuffer{ d_TD, td_offset, td_size },
- vk_subbuffer{ d_State, state_offset, state_size },
- vk_subbuffer{ d_D, dst_offset, dst_size }
- }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+ if (version == 6) {
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+ vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+ vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+ vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+ vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+ vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+ vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+ vk_subbuffer{ d_D, dst_offset, dst_size }
+ }, sizeof(vk_op_rwkv_wkv6_push_constants), &pc, elements);
+ } else if (version == 7) {
+ ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {
+ vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] },
+ vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] },
+ vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] },
+ vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] },
+ vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] },
+ vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] },
+ vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] },
+ vk_subbuffer{ d_D, dst_offset, dst_size }
+ }, sizeof(vk_op_rwkv_wkv7_push_constants), &pc, elements);
+ } else {
+ // shouldn't happen
+ GGML_ASSERT(false);
+ }
  }

  static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
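The refactor above folds the six hand-named WKV6 tensors (k, v, r, tf, td, state) into index-based arrays so the same helper can also serve WKV7's seven sources; only the descriptor list and push-constant type differ per version. A small sketch of the pattern, with Buffer standing in for vk_subbuffer (illustrative types, not from the diff):

    #include <cstdint>
    #include <vector>

    struct Buffer { uint64_t offset; uint64_t size; };

    // Build the binding list generically: the first num_srcs source buffers
    // in order, followed by the destination. num_srcs is 6 for WKV6 and 7
    // for WKV7.
    static std::vector<Buffer> build_bindings(const Buffer (&srcs)[7], int num_srcs, const Buffer & dst) {
        std::vector<Buffer> bindings(srcs, srcs + num_srcs);
        bindings.push_back(dst);
        return bindings;
    }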
@@ -6035,7 +6221,7 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
  const size_t n_heads = dst->src[0]->ne[1];
  const size_t n_seqs = dst->src[5]->ne[1];

- ggml_vk_op_f32_rwkv6(
+ ggml_vk_op_f32_wkv(
  ctx, subctx, dst,
  {
  (uint32_t)n_seqs,
@@ -6043,6 +6229,26 @@ static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx,
  (uint32_t)n_embed,
  (uint32_t)n_heads,
  },
+ 6,
+ dryrun
+ );
+ }
+
+ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) {
+ const size_t seq_length = dst->src[0]->ne[2];
+ const size_t n_embed = dst->ne[0];
+ const size_t n_heads = dst->src[0]->ne[1];
+ const size_t n_seqs = dst->src[6]->ne[1];
+
+ ggml_vk_op_f32_wkv(
+ ctx, subctx, dst,
+ {
+ (uint32_t)n_seqs,
+ (uint32_t)seq_length,
+ (uint32_t)n_embed,
+ (uint32_t)n_heads,
+ },
+ 7,
  dryrun
  );
  }
@@ -6313,6 +6519,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
  }, dryrun);
  }

+ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
+ }
+
  static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
  float * op_params = (float *)dst->op_params;

@@ -6335,6 +6545,16 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx,
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
  }

+ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ }
+
+ static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun);
+ }
+
  static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
  }
@@ -6370,7 +6590,12 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
  }, dryrun);
  }

- static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
+ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+ float * op_params = (float *)dst->op_params;
+ ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+ }
+
+ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
  const int n_dims = ((int32_t *) dst->op_params)[1];
  const int mode = ((int32_t *) dst->op_params)[2];
  // const int n_ctx = ((int32_t *) dst->op_params)[3];
@@ -6398,7 +6623,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons
  (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
  freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale,
  src2 != nullptr, (uint32_t)src0->ne[2], s1, s2,
- sections[0], sections[1], sections[2], sections[3],
+ sections[0], sections[1], sections[2], sections[3], backprop
  }, dryrun);
  }

@@ -6719,7 +6944,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
  ctx, subctx, p, ggml_vk_subbuffer(d_X), ggml_vk_subbuffer(d_Y), ggml_vk_subbuffer(d_D), ggml_vk_subbuffer(ctx->prealloc_split_k),
  m, n, k,
  k, k, m, k*m, k*n, m*n,
- split_k, batch, batch, batch, 1, 1
+ split_k, batch, batch, batch, 1, 1, n
  );
  }
  ggml_vk_ctx_end(subctx);
@@ -7064,7 +7289,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
  ctx, subctx, p, ggml_vk_subbuffer(qx_buf), ggml_vk_subbuffer(y_buf), ggml_vk_subbuffer(d_buf), ggml_vk_subbuffer(ctx->prealloc_split_k),
  m, n, k,
  k, k, m, k*m, k*n, m*n,
- split_k, batch, batch, batch, 1, 1
+ split_k, batch, batch, batch, 1, 1, n
  );
  }
  ggml_vk_ctx_end(subctx);
@@ -7295,6 +7520,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_SIGMOID:
  break;
  default:
  return false;
@@ -7319,12 +7545,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  case GGML_OP_DUP:
+ case GGML_OP_SILU_BACK:
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_MUL_MAT:
  case GGML_OP_MUL_MAT_ID:
  case GGML_OP_ARGSORT:
@@ -7336,6 +7567,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_POOL_2D:
  case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_FLASH_ATTN_EXT:
  case GGML_OP_OPT_STEP_ADAMW:
@@ -7377,13 +7609,18 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  case GGML_OP_DUP:
+ case GGML_OP_SILU_BACK:
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
  case GGML_OP_UNARY:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_ARGSORT:
  case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
@@ -7475,6 +7712,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_DUP:
  ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun);

+ break;
+ case GGML_OP_SILU_BACK:
+ ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_NORM:
  ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun);
@@ -7487,6 +7728,14 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_RMS_NORM:
  ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);

+ break;
+ case GGML_OP_RMS_NORM_BACK:
+ ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
+ break;
+ case GGML_OP_L2_NORM:
+ ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun);
+
  break;
  case GGML_OP_UNARY:
  switch (ggml_get_unary_op(node)) {
@@ -7495,6 +7744,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_SIGMOID:
  ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun);
  break;
  default:
@@ -7508,9 +7758,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
  case GGML_OP_SOFT_MAX:
  ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun);

+ break;
+ case GGML_OP_SOFT_MAX_BACK:
+ ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun);
+
  break;
  case GGML_OP_ROPE:
- ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun);
+ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun);
+
+ break;
+ case GGML_OP_ROPE_BACK:
+ ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun);

  break;
  case GGML_OP_ARGSORT:
@@ -7568,6 +7826,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod

  break;

+ case GGML_OP_RWKV_WKV7:
+ ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun);
+
+ break;
+
  case GGML_OP_OPT_STEP_ADAMW:
  ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);

@@ -7636,12 +7899,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_OP_CPY:
  case GGML_OP_CONT:
  case GGML_OP_DUP:
+ case GGML_OP_SILU_BACK:
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_RMS_NORM_BACK:
+ case GGML_OP_L2_NORM:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_RESHAPE:
  case GGML_OP_VIEW:
  case GGML_OP_PERMUTE:
@@ -7656,6 +7924,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_POOL_2D:
  case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_REPEAT:
  case GGML_OP_REPEAT_BACK:
@@ -7670,6 +7939,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
  case GGML_UNARY_OP_GELU_QUICK:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
+ case GGML_UNARY_OP_SIGMOID:
  buf = tensor->buffer;
  break;
  default:
@@ -7844,11 +8114,12 @@ static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) {
  UNUSED(buffer);
  }

- static void ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
+ static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
  VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")");
  if (tensor->view_src != nullptr) {
  GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft);
  }
+ return GGML_STATUS_SUCCESS;
  }

  static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
@@ -8165,8 +8436,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
  ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;

+ uint64_t total_mat_mul_bytes = 0;
  for (int i = 0; i < cgraph->n_nodes; i++) {
  ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
+ if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+ total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+ }
  }
  if (ctx->device->need_compiles) {
  ggml_vk_load_shaders(ctx->device);
@@ -8187,17 +8462,27 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  bool first_node_in_batch = true; // true if next node will be first node in a batch
  int submit_node_idx = 0; // index to first node in a batch

- // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
- // Start with a smaller count to get work submitted right away, and increase it after each submit.
- int nodes_per_submit = 20;
+ // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
+ // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
+ // (and scaled down based on model size, so smaller models submit earlier).
+ // Also submit at least every 100 nodes, in case there are workloads without as much matmul.
+ int nodes_per_submit = 100;
  int submitted_nodes = 0;
  int submit_count = 0;
+ uint64_t mul_mat_bytes = 0;
+ uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
  for (int i = 0; i < cgraph->n_nodes; i++) {
  if (first_node_in_batch) {
  submit_node_idx = i;
  }

- bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
+ if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+ mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+ }
+
+ bool submit = (submitted_nodes >= nodes_per_submit) ||
+ (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
+ (i == last_node);

  bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);

@@ -8214,13 +8499,9 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
  if (submit) {
  first_node_in_batch = true;
  submitted_nodes = 0;
- switch (submit_count) {
- case 0:
- nodes_per_submit = 50;
- break;
- default:
- nodes_per_submit = 100;
- break;
+ mul_mat_bytes = 0;
+ if (submit_count < 3) {
+ mul_mat_bytes_per_submit *= 2;
  }
  submit_count++;
  }
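The fixed node-count cadence above becomes a byte budget: roughly 100 MB of weight data per submit, scaled down to total/40 for small models, doubled after each of the first three submits, with a 100-node fallback for graphs that are light on matmul. A self-contained sketch of that cadence (plan_submits and weight_bytes are illustrative names, not from the diff):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Given the weight-matrix byte count of each node (0 for non-matmul
    // nodes), return the node indices after which a batch would be submitted.
    static std::vector<int> plan_submits(const std::vector<uint64_t> & weight_bytes) {
        uint64_t total = 0;
        for (uint64_t b : weight_bytes) total += b;

        // ~100 MB per submit, scaled down for small models.
        uint64_t budget = std::min<uint64_t>(100ull * 1000 * 1000, total / 40);
        uint64_t acc = 0;
        int nodes = 0, submits = 0;
        std::vector<int> points;
        const int n = (int) weight_bytes.size();
        for (int i = 0; i < n; i++) {
            acc += weight_bytes[i];
            nodes++;
            if (nodes >= 100 || acc >= budget || i == n - 1) {
                points.push_back(i);
                acc = 0;
                nodes = 0;
                if (submits < 3) budget *= 2; // ramp up after the first few submits
                submits++;
            }
        }
        return points;
    }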
@@ -8371,7 +8652,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_UNARY_OP_SILU:
  case GGML_UNARY_OP_RELU:
  case GGML_UNARY_OP_TANH:
- return ggml_is_contiguous(op->src[0]);
+ case GGML_UNARY_OP_SIGMOID:
+ return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
  default:
  return false;
  }
@@ -8560,6 +8842,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_OP_REPEAT_BACK:
  return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32;
  case GGML_OP_ROPE:
+ case GGML_OP_ROPE_BACK:
  case GGML_OP_NONE:
  case GGML_OP_RESHAPE:
  case GGML_OP_VIEW:
@@ -8569,22 +8852,27 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_OP_NORM:
  case GGML_OP_GROUP_NORM:
  case GGML_OP_RMS_NORM:
+ case GGML_OP_L2_NORM:
  return ggml_is_contiguous(op->src[0]);
  case GGML_OP_ADD:
- case GGML_OP_ACC:
  case GGML_OP_SUB:
  case GGML_OP_MUL:
  case GGML_OP_DIV:
- case GGML_OP_CONCAT:
- case GGML_OP_UPSCALE:
- case GGML_OP_SCALE:
+ case GGML_OP_SILU_BACK:
+ case GGML_OP_RMS_NORM_BACK:
  case GGML_OP_SQR:
  case GGML_OP_SIN:
  case GGML_OP_COS:
  case GGML_OP_CLAMP:
+ return op->src[0]->type == GGML_TYPE_F32;
+ case GGML_OP_ACC:
+ case GGML_OP_CONCAT:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_SCALE:
  case GGML_OP_PAD:
  case GGML_OP_DIAG_MASK_INF:
  case GGML_OP_SOFT_MAX:
+ case GGML_OP_SOFT_MAX_BACK:
  case GGML_OP_ARGSORT:
  case GGML_OP_SUM:
  case GGML_OP_SUM_ROWS:
@@ -8594,6 +8882,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
  case GGML_OP_TIMESTEP_EMBEDDING:
  case GGML_OP_POOL_2D:
  case GGML_OP_RWKV_WKV6:
+ case GGML_OP_RWKV_WKV7:
  case GGML_OP_LEAKY_RELU:
  case GGML_OP_OPT_STEP_ADAMW:
  return true;
@@ -8740,7 +9029,7 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve
  UNUSED(instance_extensions);
  }

- static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props) {
+ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) {
  switch (props.vendorID) {
  case VK_VENDOR_ID_INTEL:
  // Intel drivers don't support coopmat properly yet
@@ -8748,10 +9037,7 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope
  case VK_VENDOR_ID_AMD:
  if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) {
  // Workaround for AMD proprietary driver reporting support on all GPUs
- const std::string name = props.deviceName;
- return name.rfind("AMD Radeon RX 7", 0) == 0 || name.rfind("AMD Radeon(TM) RX 7", 0) == 0 || // RDNA 3 consumer GPUs
- name.rfind("AMD Radeon PRO W7", 0) == 0 || name.rfind("AMD Radeon(TM) PRO W7", 0) == 0 || // RDNA 3 workstation GPUs
- name.rfind("AMD Radeon 7", 0) == 0 || name.rfind("AMD Radeon(TM) 7", 0) == 0; // RDNA 3 APUs
+ return arch == vk_device_architecture::AMD_RDNA3;
  }
  return true;
  default:
@@ -8976,15 +9262,25 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]);
  } else if (tensor->op == GGML_OP_RMS_NORM) {
  tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
+ } else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
+ const float eps = ((float *) tensor->op_params)[0];
+ tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps);
+ } else if (tensor->op == GGML_OP_SILU_BACK) {
+ tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]);
+ } else if (tensor->op == GGML_OP_L2_NORM) {
+ const float eps = ((float *) tensor->op_params)[0];
+ tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
  } else if (tensor->op == GGML_OP_SOFT_MAX) {
  if (src1 != nullptr) {
  tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
  } else {
  tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
  }
+ } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
+ tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
  } else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
  tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params);
- } else if (tensor->op == GGML_OP_ROPE) {
+ } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
  const int n_dims = ((int32_t *) tensor->op_params)[1];
  const int mode = ((int32_t *) tensor->op_params)[2];
  //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3];
@@ -8997,9 +9293,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  const float beta_slow = ((float *) tensor->op_params)[10];
  if (mode & GGML_ROPE_TYPE_MROPE) {
  int32_t *sections = ((int32_t *) tensor->op_params) + 11;
- tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ if (tensor->op == GGML_OP_ROPE) {
+ tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ } else {
+ tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ }
  } else {
- tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ if (tensor->op == GGML_OP_ROPE) {
+ tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ } else {
+ tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow);
+ }
  }
  } else if (tensor->op == GGML_OP_UNARY) {
  switch (ggml_get_unary_op(tensor)) {
@@ -9018,6 +9322,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  case GGML_UNARY_OP_TANH:
  tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]);
  break;
+ case GGML_UNARY_OP_SIGMOID:
+ tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]);
+ break;
  default:
  std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
  GGML_ABORT("fatal error");
@@ -9082,6 +9389,9 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
  } else if (tensor->op == GGML_OP_RWKV_WKV6) {
  tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1],
  src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
+ } else if (tensor->op == GGML_OP_RWKV_WKV7) {
+ tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
+ src_clone[4], src_clone[5], src_clone[6]);
  } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
  src_clone[0]->flags = src0->flags;
  tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],