@fugood/llama.node 0.3.17 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (193) hide show
  1. package/CMakeLists.txt +3 -1
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-cuda/arm64/llama-node.node +0 -0
  7. package/bin/linux-cuda/x64/llama-node.node +0 -0
  8. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  9. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  10. package/bin/win32/arm64/llama-node.node +0 -0
  11. package/bin/win32/arm64/node.lib +0 -0
  12. package/bin/win32/x64/llama-node.node +0 -0
  13. package/bin/win32/x64/node.lib +0 -0
  14. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/arm64/node.lib +0 -0
  16. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  17. package/bin/win32-vulkan/x64/node.lib +0 -0
  18. package/lib/binding.ts +39 -2
  19. package/lib/index.js +132 -1
  20. package/lib/index.ts +203 -3
  21. package/package.json +2 -1
  22. package/src/EmbeddingWorker.cpp +1 -1
  23. package/src/LlamaCompletionWorker.cpp +366 -19
  24. package/src/LlamaCompletionWorker.h +30 -10
  25. package/src/LlamaContext.cpp +213 -5
  26. package/src/LlamaContext.h +12 -0
  27. package/src/common.hpp +15 -0
  28. package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
  29. package/src/llama.cpp/.github/workflows/build.yml +41 -762
  30. package/src/llama.cpp/.github/workflows/docker.yml +5 -2
  31. package/src/llama.cpp/.github/workflows/release.yml +716 -0
  32. package/src/llama.cpp/.github/workflows/server.yml +12 -12
  33. package/src/llama.cpp/CMakeLists.txt +5 -17
  34. package/src/llama.cpp/cmake/build-info.cmake +8 -2
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
  36. package/src/llama.cpp/common/CMakeLists.txt +31 -3
  37. package/src/llama.cpp/common/arg.cpp +48 -29
  38. package/src/llama.cpp/common/chat.cpp +128 -106
  39. package/src/llama.cpp/common/chat.h +2 -0
  40. package/src/llama.cpp/common/common.cpp +37 -1
  41. package/src/llama.cpp/common/common.h +18 -9
  42. package/src/llama.cpp/common/llguidance.cpp +1 -0
  43. package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
  44. package/src/llama.cpp/common/minja/minja.hpp +69 -36
  45. package/src/llama.cpp/common/regex-partial.cpp +204 -0
  46. package/src/llama.cpp/common/regex-partial.h +56 -0
  47. package/src/llama.cpp/common/sampling.cpp +57 -50
  48. package/src/llama.cpp/examples/CMakeLists.txt +2 -23
  49. package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
  50. package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
  51. package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
  52. package/src/llama.cpp/examples/training/finetune.cpp +96 -0
  53. package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
  54. package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
  55. package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
  56. package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
  57. package/src/llama.cpp/ggml/include/ggml.h +10 -7
  58. package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
  60. package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
  61. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
  62. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
  63. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
  64. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
  65. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
  66. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
  67. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
  68. package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
  69. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
  70. package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
  71. package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
  72. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
  73. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
  74. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
  75. package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
  76. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
  77. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
  78. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
  79. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
  80. package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
  81. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
  82. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
  83. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
  84. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
  85. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
  86. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
  87. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
  88. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
  89. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
  90. package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
  91. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
  92. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
  93. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
  94. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
  95. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
  96. package/src/llama.cpp/ggml/src/ggml.c +29 -20
  97. package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
  98. package/src/llama.cpp/include/llama.h +52 -11
  99. package/src/llama.cpp/requirements/requirements-all.txt +3 -3
  100. package/src/llama.cpp/scripts/xxd.cmake +1 -1
  101. package/src/llama.cpp/src/CMakeLists.txt +1 -0
  102. package/src/llama.cpp/src/llama-adapter.cpp +6 -0
  103. package/src/llama.cpp/src/llama-arch.cpp +3 -0
  104. package/src/llama.cpp/src/llama-batch.cpp +5 -1
  105. package/src/llama.cpp/src/llama-batch.h +2 -1
  106. package/src/llama.cpp/src/llama-chat.cpp +17 -7
  107. package/src/llama.cpp/src/llama-chat.h +1 -0
  108. package/src/llama.cpp/src/llama-context.cpp +389 -501
  109. package/src/llama.cpp/src/llama-context.h +44 -32
  110. package/src/llama.cpp/src/llama-cparams.h +1 -0
  111. package/src/llama.cpp/src/llama-graph.cpp +20 -38
  112. package/src/llama.cpp/src/llama-graph.h +12 -8
  113. package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
  114. package/src/llama.cpp/src/llama-kv-cache.h +271 -85
  115. package/src/llama.cpp/src/llama-memory.h +11 -1
  116. package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
  117. package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
  118. package/src/llama.cpp/src/llama-model-saver.h +37 -0
  119. package/src/llama.cpp/src/llama-model.cpp +316 -69
  120. package/src/llama.cpp/src/llama-model.h +8 -1
  121. package/src/llama.cpp/src/llama-quant.cpp +15 -13
  122. package/src/llama.cpp/src/llama-sampling.cpp +18 -6
  123. package/src/llama.cpp/src/llama-vocab.cpp +42 -4
  124. package/src/llama.cpp/src/llama-vocab.h +6 -0
  125. package/src/llama.cpp/src/llama.cpp +14 -0
  126. package/src/llama.cpp/tests/CMakeLists.txt +10 -2
  127. package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
  128. package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
  129. package/src/llama.cpp/tests/test-chat.cpp +3 -1
  130. package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
  131. package/src/llama.cpp/tests/test-opt.cpp +33 -21
  132. package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
  133. package/src/llama.cpp/tests/test-sampling.cpp +1 -1
  134. package/src/llama.cpp/tools/CMakeLists.txt +39 -0
  135. package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
  136. package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
  137. package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
  138. package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
  139. package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
  140. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
  141. package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
  142. package/src/llama.cpp/tools/mtmd/clip.h +99 -0
  143. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
  144. package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
  145. package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
  146. package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
  147. package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
  148. package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
  149. package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
  150. package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
  151. package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
  152. package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
  153. package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
  154. package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
  155. package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
  156. package/src/llama.cpp/examples/infill/infill.cpp +0 -590
  157. package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
  158. package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
  159. package/src/llama.cpp/examples/llava/clip.h +0 -135
  160. package/src/llama.cpp/examples/llava/llava.cpp +0 -586
  161. package/src/llama.cpp/examples/llava/llava.h +0 -49
  162. package/src/llama.cpp/examples/llava/mtmd.h +0 -168
  163. package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
  164. /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
  165. /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
  166. /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
  167. /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
  168. /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
  169. /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
  170. /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
  171. /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
  172. /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
  173. /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
  174. /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
  175. /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
  176. /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
  177. /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
  178. /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
  179. /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
  180. /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
  181. /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
  182. /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
  183. /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
  184. /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
  185. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
  186. /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
  187. /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
  188. /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
  189. /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
  190. /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
  191. /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
  192. /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
  193. /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
@@ -51,6 +51,24 @@
51
51
 
52
52
  #include "ggml-vulkan-shaders.hpp"
53
53
 
54
+ // remove this once it's more widely available in the SDK
55
+ #if !defined(VK_KHR_shader_bfloat16)
56
+
57
+ #define VK_KHR_shader_bfloat16 1
58
+ #define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
59
+ #define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
60
+ #define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
61
+ #define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
62
+
63
+ typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
64
+ VkStructureType sType;
65
+ void* pNext;
66
+ VkBool32 shaderBFloat16Type;
67
+ VkBool32 shaderBFloat16DotProduct;
68
+ VkBool32 shaderBFloat16CooperativeMatrix;
69
+ } VkPhysicalDeviceShaderBfloat16FeaturesKHR;
70
+ #endif
71
+
54
72
  #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
55
73
  #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
56
74
  static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
@@ -257,6 +275,7 @@ struct vk_device_struct {
257
275
  bool prefer_host_memory;
258
276
  bool float_controls_rte_fp16;
259
277
  bool subgroup_add;
278
+ bool subgroup_shuffle;
260
279
 
261
280
  bool integer_dot_product;
262
281
 
@@ -266,8 +285,12 @@ struct vk_device_struct {
266
285
  bool subgroup_require_full_support;
267
286
 
268
287
  bool coopmat_support;
269
- bool coopmat_acc_f32_support;
270
- bool coopmat_acc_f16_support;
288
+ bool coopmat_acc_f32_support {};
289
+ bool coopmat_acc_f16_support {};
290
+ bool coopmat_bf16_support {};
291
+ bool coopmat_support_16x16x16_f16acc {};
292
+ bool coopmat_support_16x16x16_f32acc {};
293
+ bool coopmat1_fa_support {};
271
294
  uint32_t coopmat_m;
272
295
  uint32_t coopmat_n;
273
296
  uint32_t coopmat_k;
@@ -293,6 +316,7 @@ struct vk_device_struct {
293
316
 
294
317
  vk_matmul_pipeline pipeline_matmul_f32 {};
295
318
  vk_matmul_pipeline pipeline_matmul_f32_f16 {};
319
+ vk_matmul_pipeline pipeline_matmul_bf16 {};
296
320
  vk_matmul_pipeline2 pipeline_matmul_f16;
297
321
  vk_matmul_pipeline2 pipeline_matmul_f16_f32;
298
322
 
@@ -301,6 +325,7 @@ struct vk_device_struct {
301
325
  vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
302
326
 
303
327
  vk_matmul_pipeline pipeline_matmul_id_f32 {};
328
+ vk_matmul_pipeline pipeline_matmul_id_bf16 {};
304
329
  vk_matmul_pipeline2 pipeline_matmul_id_f16;
305
330
  vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
306
331
 
@@ -319,11 +344,17 @@ struct vk_device_struct {
319
344
  vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
320
345
  vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
321
346
  vk_pipeline pipeline_acc_f32;
322
- vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
323
- vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
324
- vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
325
- vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
326
- vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
347
+
348
+ // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
349
+ vk_pipeline pipeline_add[2][2][2];
350
+ vk_pipeline pipeline_add_norepeat[2][2][2];
351
+ vk_pipeline pipeline_sub[2][2][2];
352
+ vk_pipeline pipeline_sub_norepeat[2][2][2];
353
+ vk_pipeline pipeline_mul[2][2][2];
354
+ vk_pipeline pipeline_mul_norepeat[2][2][2];
355
+ vk_pipeline pipeline_div[2][2][2];
356
+ vk_pipeline pipeline_div_norepeat[2][2][2];
357
+
327
358
  vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
328
359
  vk_pipeline pipeline_upscale_f32;
329
360
  vk_pipeline pipeline_scale_f32;
@@ -333,8 +364,8 @@ struct vk_device_struct {
333
364
  vk_pipeline pipeline_clamp_f32;
334
365
  vk_pipeline pipeline_pad_f32;
335
366
  vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
336
- vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
337
- vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
367
+ vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
368
+ vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
338
369
  vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
339
370
  vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
340
371
  vk_pipeline pipeline_norm_f32;
@@ -342,14 +373,17 @@ struct vk_device_struct {
342
373
  vk_pipeline pipeline_rms_norm_f32;
343
374
  vk_pipeline pipeline_rms_norm_back_f32;
344
375
  vk_pipeline pipeline_l2_norm_f32;
345
- vk_pipeline pipeline_gelu_f32;
346
- vk_pipeline pipeline_gelu_quick_f32;
347
- vk_pipeline pipeline_silu_f32;
348
- vk_pipeline pipeline_silu_back_f32;
349
- vk_pipeline pipeline_relu_f32;
376
+
377
+ // [src/dst 0=fp32,1=fp16]
378
+ vk_pipeline pipeline_gelu[2];
379
+ vk_pipeline pipeline_gelu_quick[2];
380
+ vk_pipeline pipeline_silu[2];
381
+ vk_pipeline pipeline_relu[2];
382
+ vk_pipeline pipeline_tanh[2];
383
+ vk_pipeline pipeline_sigmoid[2];
384
+
350
385
  vk_pipeline pipeline_leaky_relu_f32;
351
- vk_pipeline pipeline_tanh_f32;
352
- vk_pipeline pipeline_sigmoid_f32;
386
+ vk_pipeline pipeline_silu_back_f32;
353
387
  vk_pipeline pipeline_diag_mask_inf_f32;
354
388
  vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
355
389
  vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
@@ -368,14 +402,31 @@ struct vk_device_struct {
368
402
  vk_pipeline pipeline_rwkv_wkv6_f32;
369
403
  vk_pipeline pipeline_rwkv_wkv7_f32;
370
404
  vk_pipeline pipeline_opt_step_adamw_f32;
405
+ vk_pipeline pipeline_conv2d_dw_whcn_f32;
406
+ vk_pipeline pipeline_conv2d_dw_cwhn_f32;
371
407
 
372
408
  // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
409
+ vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
410
+ vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
411
+ vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
412
+ vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
413
+ vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
414
+ vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
415
+
416
+ vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
417
+ vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
418
+ vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
419
+ vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
420
+ vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
421
+ vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
422
+
373
423
  vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
374
424
  vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
375
425
  vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
376
426
  vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
377
427
  vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
378
428
  vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
429
+
379
430
  vk_pipeline pipeline_flash_attn_split_k_reduce;
380
431
 
381
432
  std::unordered_map<std::string, vk_pipeline_ref> pipelines;
@@ -680,6 +731,24 @@ struct vk_op_rwkv_wkv7_push_constants {
680
731
  uint32_t H;
681
732
  };
682
733
 
734
+ struct vk_op_conv2d_dw_push_constants {
735
+ uint32_t ne;
736
+ uint32_t batches;
737
+ uint32_t channels;
738
+ uint32_t dst_w;
739
+ uint32_t dst_h;
740
+ uint32_t src_w;
741
+ uint32_t src_h;
742
+ uint32_t knl_w;
743
+ uint32_t knl_h;
744
+ int32_t stride_x;
745
+ int32_t stride_y;
746
+ int32_t pad_x;
747
+ int32_t pad_y;
748
+ int32_t dilation_x;
749
+ int32_t dilation_y;
750
+ };
751
+
683
752
  struct vk_op_upscale_push_constants {
684
753
  uint32_t ne; uint32_t a_offset; uint32_t d_offset;
685
754
  uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
@@ -1529,15 +1598,56 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
1529
1598
  );
1530
1599
  }
1531
1600
 
1601
+ enum FaCodePath {
1602
+ FA_SCALAR,
1603
+ FA_COOPMAT1,
1604
+ FA_COOPMAT2,
1605
+ };
1606
+
1532
1607
  // number of rows/cols for flash attention shader
1533
1608
  static constexpr uint32_t flash_attention_num_small_rows = 32;
1534
- static std::array<uint32_t, 2> fa_rows_cols(uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
1609
+ static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
1610
+ static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
1611
+
1612
+ // The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
1613
+ // 128 threads split into four subgroups, each subgroup does 1/4
1614
+ // of the Bc dimension.
1615
+ static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
1616
+ static constexpr uint32_t scalar_flash_attention_Bc = 64;
1617
+ static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
1618
+
1619
+ static uint32_t get_fa_num_small_rows(FaCodePath path) {
1620
+ if (path == FA_COOPMAT2) {
1621
+ return flash_attention_num_small_rows;
1622
+ } else {
1623
+ return scalar_flash_attention_num_small_rows;
1624
+ }
1625
+ }
1626
+
1627
+ static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
1535
1628
  GGML_UNUSED(clamp);
1536
1629
 
1630
+ if (path == FA_SCALAR) {
1631
+ if (small_rows) {
1632
+ return {scalar_flash_attention_num_small_rows, 64};
1633
+ } else {
1634
+ return {scalar_flash_attention_num_large_rows, 32};
1635
+ }
1636
+ }
1637
+
1638
+ if (path == FA_COOPMAT1) {
1639
+ if (small_rows) {
1640
+ return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
1641
+ } else {
1642
+ return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
1643
+ }
1644
+ }
1645
+
1537
1646
  // small rows, large cols
1538
1647
  if (small_rows) {
1539
- return {flash_attention_num_small_rows, 64};
1648
+ return {get_fa_num_small_rows(FA_COOPMAT2), 32};
1540
1649
  }
1650
+
1541
1651
  // small cols to reduce register count
1542
1652
  if (ggml_is_quantized(type) || D == 256) {
1543
1653
  return {64, 32};
@@ -1582,7 +1692,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
1582
1692
  const uint32_t warps = warptile[0] / warptile[10];
1583
1693
 
1584
1694
  const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
1585
- const uint32_t mmid_row_ids = mul_mat_id ? 3072 * sizeof(uint32_t) : 0;
1695
+ const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
1586
1696
  const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
1587
1697
 
1588
1698
  const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
@@ -1791,6 +1901,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
1791
1901
  if (!device->pipeline_matmul_id_f32) {
1792
1902
  device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
1793
1903
  }
1904
+ if (!device->pipeline_matmul_bf16) {
1905
+ device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
1906
+ }
1907
+ if (!device->pipeline_matmul_id_bf16) {
1908
+ device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
1909
+ }
1794
1910
 
1795
1911
  std::vector<std::future<void>> compiles;
1796
1912
  auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
@@ -1826,65 +1942,75 @@ static void ggml_vk_load_shaders(vk_device& device) {
1826
1942
  parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
1827
1943
  };
1828
1944
 
1829
- #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
1830
- if (device->coopmat2) {
1831
-
1832
- auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
1833
- return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
1834
- };
1945
+ auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
1946
+ return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
1947
+ };
1835
1948
 
1836
- auto const &fa_spec_constants = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
1837
- // For large number of rows, 128 invocations seems to work best.
1838
- // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
1839
- // can't use 256 for D==80.
1840
- uint32_t wg_size = (small_rows && (D % 32) == 0) ? 256 : 128;
1841
- auto rows_cols = fa_rows_cols(D, clamp, type, small_rows);
1842
- // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
1843
- GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
1844
- return {wg_size, rows_cols[0], rows_cols[1], (D), clamp};
1845
- };
1949
+ auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
1950
+ // For large number of rows, 128 invocations seems to work best.
1951
+ // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
1952
+ // can't use 256 for D==80.
1953
+ // For scalar, use 128 (arbitrary)
1954
+ uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
1955
+ ? scalar_flash_attention_workgroup_size
1956
+ : ((small_rows && (D % 32) == 0) ? 256 : 128);
1957
+ auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
1958
+
1959
+ // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
1960
+ // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
1961
+ const uint32_t D_lsb = D ^ (D & (D-1));
1962
+ uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
1963
+
1964
+ // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
1965
+ GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
1966
+ return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
1967
+ };
1846
1968
 
1847
- #define CREATE_FA2(TYPE, NAMELC, D) \
1848
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
1849
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
1850
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,false), fa_spec_constants(D,1,TYPE,false), 1); \
1851
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,false), fa_spec_constants(D,0,TYPE,false), fa_rows_cols(D,0,TYPE,false)[1]); \
1852
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
1853
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_len, flash_attn_f32_f16_ ## NAMELC ## _f16acc_cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
1854
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,1,TYPE,true), fa_spec_constants(D,1,TYPE,true), 1); \
1855
- ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _cm2_len, flash_attn_f32_f16_ ## NAMELC ## _cm2_data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(D,0,TYPE,true), fa_spec_constants(D,0,TYPE,true), fa_rows_cols(D,0,TYPE,true)[1]); \
1856
-
1857
- #define CREATE_FA(TYPE, NAMELC) \
1858
- CREATE_FA2(TYPE, NAMELC, 64) \
1859
- CREATE_FA2(TYPE, NAMELC, 80) \
1860
- CREATE_FA2(TYPE, NAMELC, 96) \
1861
- CREATE_FA2(TYPE, NAMELC, 112) \
1862
- CREATE_FA2(TYPE, NAMELC, 128) \
1863
- CREATE_FA2(TYPE, NAMELC, 256)
1864
-
1865
- CREATE_FA(GGML_TYPE_F16, f16)
1866
- CREATE_FA(GGML_TYPE_Q4_0, q4_0)
1867
- CREATE_FA(GGML_TYPE_Q4_1, q4_1)
1868
- CREATE_FA(GGML_TYPE_Q5_0, q5_0)
1869
- CREATE_FA(GGML_TYPE_Q5_1, q5_1)
1870
- CREATE_FA(GGML_TYPE_Q8_0, q8_0)
1871
- // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
1872
- //CREATE_FA(GGML_TYPE_Q2_K, q2_k)
1873
- //CREATE_FA(GGML_TYPE_Q3_K, q3_k)
1874
- //CREATE_FA(GGML_TYPE_Q4_K, q4_k)
1875
- //CREATE_FA(GGML_TYPE_Q5_K, q5_k)
1876
- //CREATE_FA(GGML_TYPE_Q6_K, q6_k)
1877
- //CREATE_FA(GGML_TYPE_IQ1_S, iq1_s)
1878
- //CREATE_FA(GGML_TYPE_IQ1_M, iq1_m)
1879
- //CREATE_FA(GGML_TYPE_IQ2_XXS, iq2_xxs)
1880
- //CREATE_FA(GGML_TYPE_IQ2_XS, iq2_xs)
1881
- //CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
1882
- //CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
1883
- //CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
1884
- //CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs)
1885
- CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
1969
+ #define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
1970
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
1971
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
1972
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
1973
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
1974
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
1975
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
1976
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
1977
+ ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
1978
+
1979
+ #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
1980
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
1981
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
1982
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
1983
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
1984
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
1985
+ CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
1986
+
1987
+ CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
1988
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
1989
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
1990
+ #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
1991
+ if (device->coopmat1_fa_support) {
1992
+ CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
1993
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
1994
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
1995
+ }
1996
+ #endif
1997
+ #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
1998
+ if (device->coopmat2) {
1999
+ CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
2000
+ CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
2001
+ CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
2002
+ CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
2003
+ CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
2004
+ CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
2005
+ CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
2006
+ }
2007
+ #endif
2008
+ #undef CREATE_FA2
1886
2009
  #undef CREATE_FA
1887
2010
 
2011
+ #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
2012
+ if (device->coopmat2) {
2013
+
1888
2014
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1889
2015
  #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1890
2016
  ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
@@ -1900,6 +2026,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
1900
2026
  CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
1901
2027
 
1902
2028
  CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
2029
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
2030
+ if (device->coopmat_bf16_support) {
2031
+ CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
2032
+ }
2033
+ #endif
1903
2034
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1904
2035
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1905
2036
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1921,6 +2052,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
1921
2052
  CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
1922
2053
 
1923
2054
  CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
2055
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
2056
+ if (device->coopmat_bf16_support) {
2057
+ CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
2058
+ }
2059
+ #endif
1924
2060
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1925
2061
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
1926
2062
  CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
@@ -1949,17 +2085,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
1949
2085
  // Create 6 variants, {s,m,l}x{unaligned,aligned}
1950
2086
  #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
1951
2087
  if (device->mul_mat ## ID ## _l[TYPE]) \
1952
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
2088
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
1953
2089
  if (device->mul_mat ## ID ## _m[TYPE]) \
1954
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
2090
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
1955
2091
  if (device->mul_mat ## ID ## _s[TYPE]) \
1956
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _coopmat_len, NAMELC ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
2092
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
1957
2093
  if (device->mul_mat ## ID ## _l[TYPE]) \
1958
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
2094
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
1959
2095
  if (device->mul_mat ## ID ## _m[TYPE]) \
1960
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
2096
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
1961
2097
  if (device->mul_mat ## ID ## _s[TYPE]) \
1962
- ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _coopmat_len, NAMELC ## _aligned ## F16ACC ## _coopmat_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
2098
+ ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
1963
2099
 
1964
2100
  // Create 2 variants, {f16,f32} accumulator
1965
2101
  #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
@@ -1974,6 +2110,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
1974
2110
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1975
2111
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
1976
2112
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2113
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
2114
+ if (device->coopmat_bf16_support) {
2115
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
2116
+ }
2117
+ #endif
1977
2118
 
1978
2119
  if (device->coopmat_acc_f16_support) {
1979
2120
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2022,6 +2163,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
2022
2163
  CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2023
2164
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2024
2165
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2166
+ #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
2167
+ if (device->coopmat_bf16_support) {
2168
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2169
+ }
2170
+ #endif
2025
2171
 
2026
2172
  if (device->coopmat_acc_f16_support) {
2027
2173
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2104,6 +2250,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2104
2250
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2105
2251
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2106
2252
 
2253
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2254
+
2107
2255
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2108
2256
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2109
2257
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2139,6 +2287,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2139
2287
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2140
2288
  CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2141
2289
 
2290
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2291
+
2142
2292
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2143
2293
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2144
2294
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2191,6 +2341,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2191
2341
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2192
2342
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2193
2343
 
2344
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2345
+
2194
2346
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2195
2347
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
2196
2348
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@@ -2226,6 +2378,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
2226
2378
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2227
2379
  CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
2228
2380
 
2381
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2382
+
2229
2383
  CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2230
2384
  CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2231
2385
  CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@@ -2246,8 +2400,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
2246
2400
  CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2247
2401
  CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2248
2402
  CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
2249
- #undef CREATE_MM
2250
2403
  }
2404
+ // reusing CREATE_MM from the fp32 path
2405
+ if ((device->coopmat2 || device->coopmat_support)
2406
+ #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
2407
+ && !device->coopmat_bf16_support
2408
+ #endif
2409
+ ) {
2410
+ // use scalar tile sizes
2411
+ l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
2412
+ m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
2413
+ s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
2414
+
2415
+ l_wg_denoms = {128, 128, 1 };
2416
+ m_wg_denoms = { 64, 64, 1 };
2417
+ s_wg_denoms = { 32, 32, 1 };
2418
+
2419
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
2420
+ CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
2421
+ }
2422
+ #undef CREATE_MM
2251
2423
 
2252
2424
  // mul mat vec
2253
2425
 
@@ -2266,6 +2438,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2266
2438
  for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
2267
2439
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2268
2440
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2441
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2269
2442
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2270
2443
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2271
2444
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
@@ -2288,6 +2461,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2288
2461
 
2289
2462
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2290
2463
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2464
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
2291
2465
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2292
2466
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
2293
2467
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
@@ -2311,6 +2485,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2311
2485
 
2312
2486
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
2313
2487
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
2488
+ ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
2314
2489
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
2315
2490
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
2316
2491
  ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
@@ -2356,6 +2531,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2356
2531
  // get_rows
2357
2532
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2358
2533
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2534
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2359
2535
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2360
2536
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2361
2537
  ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2373,6 +2549,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2373
2549
 
2374
2550
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2375
2551
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2552
+ ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
2376
2553
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2377
2554
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
2378
2555
  ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
@@ -2399,7 +2576,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
2399
2576
  ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
2400
2577
  }
2401
2578
  }
2402
- ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 7 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
2579
+ ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
2403
2580
 
2404
2581
  ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
2405
2582
  ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
@@ -2410,10 +2587,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
2410
2587
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2411
2588
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2412
2589
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2590
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2591
+ ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2413
2592
 
2414
2593
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2415
2594
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2416
2595
  ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2596
+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2597
+ ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2598
+
2417
2599
  if (device->float_controls_rte_fp16) {
2418
2600
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
2419
2601
  ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
@@ -2437,20 +2619,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
2437
2619
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
2438
2620
  ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
2439
2621
 
2440
- ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2441
- ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2442
- ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2443
- ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2622
+ auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
2623
+ std::string s;
2624
+ s += std::string(src0_f16 ? "_f16" : "_f32");
2625
+ s += std::string(src1_f16 ? "_f16" : "_f32");
2626
+ s += std::string(dst_f16 ? "_f16" : "_f32");
2627
+ return s;
2628
+ };
2629
+
2630
+ #define CREATE_BINARY(name, namemod, spec) \
2631
+ for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
2632
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
2633
+ #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
2634
+ "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
2635
+
2636
+ CREATE_BINARY(add, , {0})
2637
+ CREATE_BINARY(add, _norepeat, {1})
2638
+ CREATE_BINARY(sub, , {0})
2639
+ CREATE_BINARY(sub, _norepeat, {1})
2640
+ CREATE_BINARY(mul, , {0})
2641
+ CREATE_BINARY(mul, _norepeat, {1})
2642
+ CREATE_BINARY(div, , {0})
2643
+ CREATE_BINARY(div, _norepeat, {1})
2644
+ #undef CREATE_BINARY
2444
2645
 
2445
2646
  ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2446
2647
 
2447
- ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2448
- ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2449
- ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2450
- ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2451
- ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
2452
- ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
2453
-
2454
2648
  ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2455
2649
  ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
2456
2650
  ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -2470,14 +2664,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
2470
2664
  ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2471
2665
  ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
2472
2666
 
2473
- ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2474
- ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2475
- ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2476
- ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2477
- ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2667
+ #define CREATE_UNARY(name) \
2668
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
2669
+ ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2670
+
2671
+ CREATE_UNARY(gelu)
2672
+ CREATE_UNARY(gelu_quick)
2673
+ CREATE_UNARY(silu)
2674
+ CREATE_UNARY(relu)
2675
+ CREATE_UNARY(tanh)
2676
+ CREATE_UNARY(sigmoid)
2677
+ #undef CREATE_UNARY
2678
+
2478
2679
  ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2479
- ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2480
- ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2680
+ ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2481
2681
 
2482
2682
  ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
2483
2683
 
@@ -2529,6 +2729,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
2529
2729
 
2530
2730
  ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
2531
2731
 
2732
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
2733
+ ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
2734
+
2532
2735
  for (auto &c : compiles) {
2533
2736
  c.wait();
2534
2737
  }
@@ -2578,6 +2781,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2578
2781
  bool coopmat2_support = false;
2579
2782
  device->coopmat_support = false;
2580
2783
  device->integer_dot_product = false;
2784
+ bool bfloat16_support = false;
2581
2785
 
2582
2786
  for (const auto& properties : ext_props) {
2583
2787
  if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -2608,6 +2812,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
2608
2812
  !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
2609
2813
  device->integer_dot_product = true;
2610
2814
  #endif
2815
+ } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
2816
+ !getenv("GGML_VK_DISABLE_BFLOAT16")) {
2817
+ bfloat16_support = true;
2611
2818
  }
2612
2819
  }
2613
2820
 
@@ -2700,6 +2907,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
2700
2907
  device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
2701
2908
  (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
2702
2909
 
2910
+ device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
2911
+ (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
2912
+
2703
2913
  const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
2704
2914
 
2705
2915
  device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -2794,6 +3004,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
2794
3004
  }
2795
3005
  #endif
2796
3006
 
3007
+ #if defined(VK_KHR_shader_bfloat16)
3008
+ VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
3009
+ bfloat16_features.pNext = nullptr;
3010
+ bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
3011
+ if (bfloat16_support) {
3012
+ last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
3013
+ last_struct = (VkBaseOutStructure *)&bfloat16_features;
3014
+ device_extensions.push_back("VK_KHR_shader_bfloat16");
3015
+ }
3016
+ #endif
3017
+
2797
3018
  VkPhysicalDeviceMaintenance4Features maint4_features {};
2798
3019
  maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
2799
3020
  if (maintenance4_support) {
@@ -2832,6 +3053,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
2832
3053
 
2833
3054
  #if defined(VK_KHR_cooperative_matrix)
2834
3055
  device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
3056
+
3057
+ // coopmat1 fa shader currently assumes 32 invocations per subgroup
3058
+ device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
3059
+ device->subgroup_size_control && device->subgroup_min_size <= 32 &&
3060
+ device->subgroup_max_size >= 32;
2835
3061
  #endif
2836
3062
 
2837
3063
  if (coopmat2_support) {
@@ -2966,6 +3192,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
2966
3192
  // Only enable if shape is identical
2967
3193
  device->coopmat_acc_f32_support = true;
2968
3194
  }
3195
+ if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
3196
+ device->coopmat_support_16x16x16_f32acc = true;
3197
+ }
2969
3198
  } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
2970
3199
  (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
2971
3200
  // coopmat sizes not set yet
@@ -2978,6 +3207,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
2978
3207
  // Only enable if shape is identical
2979
3208
  device->coopmat_acc_f16_support = true;
2980
3209
  }
3210
+ if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
3211
+ device->coopmat_support_16x16x16_f16acc = true;
3212
+ }
2981
3213
  }
2982
3214
  } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
2983
3215
  (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
@@ -2991,6 +3223,25 @@ static vk_device ggml_vk_get_device(size_t idx) {
2991
3223
  device->coopmat_int_n = prop.NSize;
2992
3224
  device->coopmat_int_k = prop.KSize;
2993
3225
  }
3226
+ #if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
3227
+ if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
3228
+ prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
3229
+ prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
3230
+ prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
3231
+ (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
3232
+ ) {
3233
+ // coopmat sizes not set yet
3234
+ if (device->coopmat_m == 0) {
3235
+ device->coopmat_bf16_support = true;
3236
+ device->coopmat_m = prop.MSize;
3237
+ device->coopmat_n = prop.NSize;
3238
+ device->coopmat_k = prop.KSize;
3239
+ } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
3240
+ // Only enable if shape is identical
3241
+ device->coopmat_bf16_support = true;
3242
+ }
3243
+ }
3244
+ #endif
2994
3245
  }
2995
3246
 
2996
3247
  if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
@@ -2998,11 +3249,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
2998
3249
  GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
2999
3250
  device->coopmat_support = false;
3000
3251
  }
3252
+ if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
3253
+ device->coopmat_bf16_support = false;
3254
+ }
3001
3255
  }
3002
3256
 
3003
3257
  if (device->coopmat_support) {
3004
3258
  device_extensions.push_back("VK_KHR_cooperative_matrix");
3005
3259
  }
3260
+ #if defined(VK_KHR_shader_bfloat16)
3261
+ if (device->coopmat_bf16_support) {
3262
+ device_extensions.push_back("VK_KHR_shader_bfloat16");
3263
+ }
3264
+ #endif
3006
3265
  #endif
3007
3266
  device->name = GGML_VK_NAME + std::to_string(idx);
3008
3267
 
@@ -3459,6 +3718,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
3459
3718
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
3460
3719
  return ctx->device->pipeline_matmul_f32_f16;
3461
3720
  }
3721
+ if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
3722
+ return ctx->device->pipeline_matmul_bf16;
3723
+ }
3462
3724
  if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
3463
3725
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
3464
3726
  return ctx->device->pipeline_matmul_f16_f32.f16acc;
@@ -3530,6 +3792,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
3530
3792
  switch (a_type) {
3531
3793
  case GGML_TYPE_F32:
3532
3794
  case GGML_TYPE_F16:
3795
+ case GGML_TYPE_BF16:
3533
3796
  case GGML_TYPE_Q4_0:
3534
3797
  case GGML_TYPE_Q4_1:
3535
3798
  case GGML_TYPE_Q5_0:
@@ -3562,6 +3825,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
3562
3825
  if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
3563
3826
  return ctx->device->pipeline_matmul_id_f32;
3564
3827
  }
3828
+ if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
3829
+ return ctx->device->pipeline_matmul_id_bf16;
3830
+ }
3565
3831
  if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
3566
3832
  if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
3567
3833
  return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
@@ -3615,6 +3881,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
3615
3881
  switch (a_type) {
3616
3882
  case GGML_TYPE_F32:
3617
3883
  case GGML_TYPE_F16:
3884
+ case GGML_TYPE_BF16:
3618
3885
  case GGML_TYPE_Q4_0:
3619
3886
  case GGML_TYPE_Q4_1:
3620
3887
  case GGML_TYPE_Q5_0:
@@ -4350,6 +4617,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
4350
4617
  return ctx->device->pipeline_cpy_f16_f16;
4351
4618
  }
4352
4619
  }
4620
+ if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
4621
+ if (contig) {
4622
+ return ctx->device->pipeline_contig_cpy_f16_f32;
4623
+ } else {
4624
+ return ctx->device->pipeline_cpy_f16_f32;
4625
+ }
4626
+ }
4627
+ if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
4628
+ if (contig) {
4629
+ return ctx->device->pipeline_contig_cpy_f32_bf16;
4630
+ } else {
4631
+ return ctx->device->pipeline_cpy_f32_bf16;
4632
+ }
4633
+ }
4353
4634
  if (src->type == GGML_TYPE_F32) {
4354
4635
  switch (to) {
4355
4636
  case GGML_TYPE_Q4_0:
@@ -4477,8 +4758,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4477
4758
  const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
4478
4759
  !ggml_vk_dim01_contiguous(src0);
4479
4760
  const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
4761
+ (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
4480
4762
  !ggml_vk_dim01_contiguous(src1);
4481
4763
 
4764
+ // If src0 is BF16, try to use a BF16 x BF16 multiply
4765
+ ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
4766
+
4482
4767
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
4483
4768
 
4484
4769
  bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
@@ -4488,25 +4773,25 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4488
4773
 
4489
4774
  if (mmp == nullptr) {
4490
4775
  // Fall back to f16 dequant mul mat
4491
- mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
4776
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
4492
4777
  quantize_y = false;
4493
4778
  }
4494
4779
 
4495
4780
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
4496
- const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig);
4781
+ const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
4497
4782
 
4498
4783
  if (qx_needs_dequant) {
4499
4784
  // Fall back to dequant + f16 mulmat
4500
- mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
4785
+ mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
4501
4786
  }
4502
4787
 
4503
4788
  // Not implemented
4504
4789
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
4505
4790
 
4506
- const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
4791
+ const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
4507
4792
  const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
4508
4793
 
4509
- vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
4794
+ vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
4510
4795
 
4511
4796
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
4512
4797
  uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
@@ -4527,12 +4812,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
4527
4812
  vk_pipeline to_q8_1 = nullptr;
4528
4813
 
4529
4814
  if (x_non_contig) {
4530
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
4815
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
4531
4816
  } else {
4532
4817
  to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
4533
4818
  }
4534
4819
  if (y_non_contig) {
4535
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
4820
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
4536
4821
  } else {
4537
4822
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
4538
4823
  }
@@ -4949,6 +5234,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
4949
5234
  const uint64_t nb01 = src0->nb[1];
4950
5235
  const uint64_t nb02 = src0->nb[2];
4951
5236
 
5237
+ const uint64_t nb12 = src1->nb[2];
5238
+
4952
5239
  // const uint64_t ne10 = src1->ne[0];
4953
5240
  const uint64_t ne11 = src1->ne[1];
4954
5241
  const uint64_t ne12 = src1->ne[2];
@@ -4974,6 +5261,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
4974
5261
 
4975
5262
  const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
4976
5263
  const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
5264
+ const uint32_t channel_stride_y = nb12 / sizeof(float);
4977
5265
 
4978
5266
  const uint64_t qx_sz = ggml_nbytes(src0);
4979
5267
  const uint64_t qy_sz = ggml_nbytes(src1);
@@ -5004,7 +5292,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
5004
5292
  const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
5005
5293
 
5006
5294
  // compute
5007
- const std::array<uint32_t, 7> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, (uint32_t)(ne12 / ne02), (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
5295
+ const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
5008
5296
  ggml_vk_sync_buffers(subctx);
5009
5297
  ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
5010
5298
  { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
@@ -5029,7 +5317,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
5029
5317
  // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
5030
5318
  // when ne12 and ne13 are one.
5031
5319
  } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
5032
- (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
5320
+ (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
5033
5321
  ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
5034
5322
  } else {
5035
5323
  ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -5056,7 +5344,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
5056
5344
 
5057
5345
  const uint64_t nei0 = ids->ne[0];
5058
5346
  const uint64_t nei1 = ids->ne[1];
5059
- GGML_ASSERT(nei0 * nei1 <= 3072);
5347
+ GGML_ASSERT(nei0 * nei1 <= 4096);
5060
5348
 
5061
5349
  const uint32_t nbi1 = ids->nb[1];
5062
5350
  const uint32_t nbi2 = ids->nb[2];
@@ -5097,27 +5385,31 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
5097
5385
  const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
5098
5386
  !ggml_vk_dim01_contiguous(src0);
5099
5387
  const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
5388
+ (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
5100
5389
  !ggml_vk_dim01_contiguous(src1);
5101
5390
 
5391
+ // If src0 is BF16, try to use a BF16 x BF16 multiply
5392
+ ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
5393
+
5102
5394
  const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
5103
5395
 
5104
- vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? GGML_TYPE_F16 : src1->type, (ggml_prec)dst->op_params[0]);
5396
+ vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
5105
5397
 
5106
5398
  const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
5107
- const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
5399
+ const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
5108
5400
 
5109
5401
  if (qx_needs_dequant) {
5110
5402
  // Fall back to dequant + f16 mulmat
5111
- mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
5403
+ mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
5112
5404
  }
5113
5405
 
5114
5406
  // Not implemented
5115
5407
  GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
5116
5408
 
5117
- const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? GGML_TYPE_F16 : src0->type));
5409
+ const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
5118
5410
  const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
5119
5411
 
5120
- vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? GGML_TYPE_F16 : src0->type);
5412
+ vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
5121
5413
 
5122
5414
  // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
5123
5415
  uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
@@ -5136,12 +5428,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
5136
5428
  vk_pipeline to_fp16_vk_1 = nullptr;
5137
5429
 
5138
5430
  if (x_non_contig) {
5139
- to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, GGML_TYPE_F16);
5431
+ to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
5140
5432
  } else {
5141
5433
  to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
5142
5434
  }
5143
5435
  if (y_non_contig) {
5144
- to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, GGML_TYPE_F16);
5436
+ to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
5145
5437
  } else {
5146
5438
  to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
5147
5439
  }
@@ -5451,6 +5743,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
5451
5743
  }
5452
5744
  }
5453
5745
 
5746
+ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
5747
+ // Estimates the shared-memory footprint (in bytes) of the coopmat1 flash-attention shader for head size D and the chosen accumulator precision (f32acc), and returns whether it fits in the device's maxComputeSharedMemorySize. The caller falls back to the scalar path when this returns false. The buffer sizes below mirror the shader's shared arrays and need to be kept up to date on shader changes.
5748
+ const uint32_t wg_size = scalar_flash_attention_workgroup_size; // NOTE(review): reuses the scalar FA workgroup/tile constants; confirm they match what the coopmat1 pipelines are built with
5749
+ const uint32_t Br = scalar_flash_attention_num_large_rows; // Q rows per tile (sizes Qf and slope)
5750
+ const uint32_t Bc = scalar_flash_attention_Bc; // K/V columns per tile (sizes sfsh and ksh)
5751
+ // Element sizes in bytes: accumulator (f32 = 4, f16 = 2) and one f16vec4 (4 x 2 bytes).
5752
+ const uint32_t acctype = f32acc ? 4 : 2;
5753
+ const uint32_t f16vec4 = 8;
5754
+ // Per-invocation scratch: one float and one 4-wide accumulator vector for each workgroup invocation.
5755
+ const uint32_t tmpsh = wg_size * sizeof(float);
5756
+ const uint32_t tmpshv4 = wg_size * 4 * acctype;
5757
+ // Q tile: Br rows of D/4 f16vec4s; the +2 looks like row padding (NOTE(review): confirm against the shader's Qf declaration).
5758
+ const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
5759
+ // Score tile: Bc x sfshstride accumulator elements; the stride gets +8 padding only when D <= 128.
5760
+ const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
5761
+ const uint32_t sfsh = Bc * sfshstride * acctype;
5762
+ // K tile: Bc rows of D/4 f16vec4s, using the same +2 row padding as Qf.
5763
+ const uint32_t kshstride = D / 4 + 2;
5764
+ const uint32_t ksh = Bc * kshstride * f16vec4;
5765
+ // One float per Q row (presumably ALiBi slopes, going by the name; confirm against the shader).
5766
+ const uint32_t slope = Br * sizeof(float);
5767
+ // Sum all shared buffers and compare with the device's compute shared-memory limit.
5768
+ const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
5769
+ const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
5770
+
5771
+ VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
5772
+
5773
+ return supported;
5774
+ }
5775
+
5454
5776
  static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
5455
5777
  VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
5456
5778
  std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
@@ -5501,36 +5823,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5501
5823
  assert(q->type == GGML_TYPE_F32);
5502
5824
  assert(k->type == v->type);
5503
5825
 
5504
- vk_pipeline *pipelines;
5505
- // XXX TODO other backends may be changing accumulator precision to default to f32 soon
5506
- bool f32acc = dst->op_params[3] == GGML_PREC_F32;
5507
- bool small_rows = N <= flash_attention_num_small_rows;
5508
- switch (D) {
5509
- case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
5510
- case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
5511
- case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
5512
- case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
5513
- case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
5514
- case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
5515
- default:
5516
- assert(!"unsupported D value");
5517
- return;
5518
- }
5519
- assert(pipelines);
5520
-
5521
- const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
5522
- const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
5523
- const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
5826
+ FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
5827
+ ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
5524
5828
 
5525
- bool aligned = (KV % pipelines[1]->align) == 0 &&
5526
- // the "aligned" shader variant will forcibly align strides, for performance
5527
- (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
5829
+ if (path == FA_COOPMAT1) {
5830
+ const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
5831
+ (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
5528
5832
 
5529
- // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
5530
- GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
5833
+ const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
5531
5834
 
5532
- vk_pipeline pipeline = pipelines[aligned];
5533
- assert(pipeline);
5835
+ if (!coopmat_shape_supported || !coopmat_shmem_supported) {
5836
+ path = FA_SCALAR;
5837
+ }
5838
+ }
5534
5839
 
5535
5840
  uint32_t gqa_ratio = 1;
5536
5841
  uint32_t qk_ratio = neq2 / nek2;
@@ -5538,7 +5843,23 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5538
5843
  uint32_t workgroups_y = (uint32_t)neq2;
5539
5844
  uint32_t workgroups_z = (uint32_t)neq3;
5540
5845
 
5541
- if (N == 1 && qk_ratio > 1 && gqa_ratio <= flash_attention_num_small_rows &&
5846
+ // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
5847
+ // For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
5848
+ uint32_t max_gqa;
5849
+ switch (path) {
5850
+ case FA_SCALAR:
5851
+ case FA_COOPMAT1:
5852
+ // We may switch from coopmat1 to scalar, so use the scalar limit for both
5853
+ max_gqa = scalar_flash_attention_num_large_rows;
5854
+ break;
5855
+ case FA_COOPMAT2:
5856
+ max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
5857
+ break;
5858
+ default:
5859
+ GGML_ASSERT(0);
5860
+ }
5861
+
5862
+ if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
5542
5863
  qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
5543
5864
  // grouped query attention - make the N dimension equal to gqa_ratio, reduce
5544
5865
  // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
@@ -5548,11 +5869,89 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
5548
5869
  workgroups_y /= N;
5549
5870
  }
5550
5871
 
5872
+ vk_pipeline *pipelines;
5873
+ bool small_rows = N <= get_fa_num_small_rows(path);
5874
+
5875
+ // coopmat1 does not actually support "small rows" (it needs 16 rows).
5876
+ // So use scalar instead.
5877
+ if (small_rows && path == FA_COOPMAT1) {
5878
+ path = FA_SCALAR;
5879
+ }
5880
+
5881
+ // scalar is faster than coopmat2 when N==1
5882
+ if (N == 1 && path == FA_COOPMAT2) {
5883
+ path = FA_SCALAR;
5884
+ }
5885
+
5886
+ bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
5887
+
5888
+ switch (path) {
5889
+ case FA_SCALAR:
5890
+ switch (D) {
5891
+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
5892
+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
5893
+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
5894
+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
5895
+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
5896
+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
5897
+ default:
5898
+ GGML_ASSERT(!"unsupported D value");
5899
+ return;
5900
+ }
5901
+ break;
5902
+ case FA_COOPMAT1:
5903
+ switch (D) {
5904
+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
5905
+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
5906
+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
5907
+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
5908
+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
5909
+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
5910
+ default:
5911
+ GGML_ASSERT(!"unsupported D value");
5912
+ return;
5913
+ }
5914
+ break;
5915
+ case FA_COOPMAT2:
5916
+ switch (D) {
5917
+ case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
5918
+ case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
5919
+ case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
5920
+ case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
5921
+ case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
5922
+ case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
5923
+ default:
5924
+ GGML_ASSERT(!"unsupported D value");
5925
+ return;
5926
+ }
5927
+ break;
5928
+ default:
5929
+ GGML_ASSERT(0);
5930
+ }
5931
+ assert(pipelines);
5932
+
5933
+ const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
5934
+ const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
5935
+ const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
5936
+
5937
+ bool aligned = (KV % pipelines[1]->align) == 0 &&
5938
+ // the "aligned" shader variant will forcibly align strides, for performance
5939
+ (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
5940
+
5941
+ // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
5942
+ GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
5943
+
5944
+ vk_pipeline pipeline = pipelines[aligned];
5945
+ assert(pipeline);
5946
+
5551
5947
  uint32_t split_kv = KV;
5552
5948
  uint32_t split_k = 1;
5553
5949
 
5950
+ // Use a placeholder core count if one isn't available. split_k is a big help for perf.
5951
+ const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
5952
+
5554
5953
  // Try to use split_k when KV is large enough to be worth the overhead
5555
- if (workgroups_x == 1 && ctx->device->shader_core_count > 0 && KV >= 512) {
5954
+ if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
5556
5955
  // Try to run two workgroups per SM.
5557
5956
  split_k = ctx->device->shader_core_count * 2 / workgroups_y;
5558
5957
  if (split_k > 1) {
@@ -5722,26 +6121,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5722
6121
  }
5723
6122
  return nullptr;
5724
6123
  case GGML_OP_ADD:
5725
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5726
- return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
6124
+ case GGML_OP_SUB:
6125
+ case GGML_OP_MUL:
6126
+ case GGML_OP_DIV:
6127
+ if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
6128
+ (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
6129
+ (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
6130
+ return nullptr;
6131
+ }
6132
+ switch (op) {
6133
+ case GGML_OP_ADD:
6134
+ {
6135
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
6136
+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
5727
6137
  }
5728
- if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
5729
- return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
6138
+ case GGML_OP_SUB:
6139
+ {
6140
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
6141
+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
5730
6142
  }
5731
- return nullptr;
5732
- case GGML_OP_SUB:
5733
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5734
- return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
6143
+ case GGML_OP_MUL:
6144
+ {
6145
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
6146
+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
5735
6147
  }
5736
- return nullptr;
5737
- case GGML_OP_MUL:
5738
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5739
- return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
6148
+ case GGML_OP_DIV:
6149
+ {
6150
+ auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
6151
+ return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
5740
6152
  }
5741
- return nullptr;
5742
- case GGML_OP_DIV:
5743
- if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5744
- return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
6153
+ default:
6154
+ break;
5745
6155
  }
5746
6156
  return nullptr;
5747
6157
  case GGML_OP_CONCAT:
@@ -5835,37 +6245,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5835
6245
  }
5836
6246
  return nullptr;
5837
6247
  case GGML_OP_UNARY:
6248
+ if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
6249
+ (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
6250
+ (src0->type != dst->type)) {
6251
+ return nullptr;
6252
+ }
6253
+
5838
6254
  switch (ggml_get_unary_op(dst)) {
5839
6255
  case GGML_UNARY_OP_SILU:
5840
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5841
- return ctx->device->pipeline_silu_f32;
5842
- }
5843
- break;
6256
+ return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
5844
6257
  case GGML_UNARY_OP_GELU:
5845
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5846
- return ctx->device->pipeline_gelu_f32;
5847
- }
5848
- break;
6258
+ return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
5849
6259
  case GGML_UNARY_OP_GELU_QUICK:
5850
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5851
- return ctx->device->pipeline_gelu_quick_f32;
5852
- }
5853
- break;
6260
+ return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
5854
6261
  case GGML_UNARY_OP_RELU:
5855
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5856
- return ctx->device->pipeline_relu_f32;
5857
- }
5858
- break;
6262
+ return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
5859
6263
  case GGML_UNARY_OP_TANH:
5860
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5861
- return ctx->device->pipeline_tanh_f32;
5862
- }
5863
- break;
6264
+ return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
5864
6265
  case GGML_UNARY_OP_SIGMOID:
5865
- if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
5866
- return ctx->device->pipeline_sigmoid_f32;
5867
- }
5868
- break;
6266
+ return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
5869
6267
  default:
5870
6268
  break;
5871
6269
  }
@@ -5988,6 +6386,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
5988
6386
  return ctx->device->pipeline_leaky_relu_f32;
5989
6387
  }
5990
6388
  return nullptr;
6389
+ case GGML_OP_CONV_2D_DW:
6390
+ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
6391
+ if (ggml_is_contiguous(src1)) {
6392
+ return ctx->device->pipeline_conv2d_dw_whcn_f32;
6393
+ } else if (ggml_is_contiguous_channels(src1)) {
6394
+ return ctx->device->pipeline_conv2d_dw_cwhn_f32;
6395
+ }
6396
+ }
6397
+ return nullptr;
5991
6398
  default:
5992
6399
  return nullptr;
5993
6400
  }
@@ -6014,6 +6421,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
6014
6421
  case GGML_OP_REPEAT_BACK:
6015
6422
  case GGML_OP_ROPE:
6016
6423
  case GGML_OP_RMS_NORM:
6424
+ case GGML_OP_CONV_2D_DW:
6017
6425
  return true;
6018
6426
  default:
6019
6427
  return false;
@@ -6310,6 +6718,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
6310
6718
  case GGML_OP_CONCAT:
6311
6719
  case GGML_OP_UPSCALE:
6312
6720
  case GGML_OP_UNARY:
6721
+ case GGML_OP_CONV_2D_DW:
6313
6722
  {
6314
6723
  const uint32_t ne = ggml_nelements(dst);
6315
6724
  if (ne > 262144) {
@@ -7096,6 +7505,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
7096
7505
  }, dryrun);
7097
7506
  }
7098
7507
 
7508
+ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
7509
+ vk_op_conv2d_dw_push_constants p{};
7510
+ p.ne = ggml_nelements(dst);
7511
+ p.channels = dst->ne[2];
7512
+ p.batches = dst->ne[3];
7513
+ p.dst_w = dst->ne[0];
7514
+ p.dst_h = dst->ne[1];
7515
+ p.src_w = src1->ne[0];
7516
+ p.src_h = src1->ne[1];
7517
+ p.knl_w = src0->ne[0];
7518
+ p.knl_h = src0->ne[1];
7519
+ p.stride_x = dst->op_params[0];
7520
+ p.stride_y = dst->op_params[1];
7521
+ p.pad_x = dst->op_params[2];
7522
+ p.pad_y = dst->op_params[3];
7523
+ p.dilation_x = dst->op_params[4];
7524
+ p.dilation_y = dst->op_params[5];
7525
+
7526
+ GGML_ASSERT(src0->ne[3] == p.channels);
7527
+ GGML_ASSERT(src1->ne[3] == p.batches);
7528
+
7529
+ ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
7530
+ }
7531
+
7099
7532
  static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
7100
7533
  const float * op_params = (const float *)dst->op_params;
7101
7534
  ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
@@ -8116,6 +8549,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8116
8549
  case GGML_OP_IM2COL:
8117
8550
  case GGML_OP_TIMESTEP_EMBEDDING:
8118
8551
  case GGML_OP_POOL_2D:
8552
+ case GGML_OP_CONV_2D_DW:
8119
8553
  case GGML_OP_RWKV_WKV6:
8120
8554
  case GGML_OP_RWKV_WKV7:
8121
8555
  case GGML_OP_LEAKY_RELU:
@@ -8179,6 +8613,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8179
8613
  case GGML_OP_IM2COL:
8180
8614
  case GGML_OP_TIMESTEP_EMBEDDING:
8181
8615
  case GGML_OP_POOL_2D:
8616
+ case GGML_OP_CONV_2D_DW:
8182
8617
  case GGML_OP_LEAKY_RELU:
8183
8618
  {
8184
8619
  // These operations all go through ggml_vk_op_f32, so short-circuit and
@@ -8352,6 +8787,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
8352
8787
  case GGML_OP_POOL_2D:
8353
8788
  ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
8354
8789
 
8790
+ break;
8791
+ case GGML_OP_CONV_2D_DW:
8792
+ ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
8793
+
8355
8794
  break;
8356
8795
  case GGML_OP_LEAKY_RELU:
8357
8796
  ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@@ -8473,6 +8912,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
8473
8912
  case GGML_OP_IM2COL:
8474
8913
  case GGML_OP_TIMESTEP_EMBEDDING:
8475
8914
  case GGML_OP_POOL_2D:
8915
+ case GGML_OP_CONV_2D_DW:
8476
8916
  case GGML_OP_RWKV_WKV6:
8477
8917
  case GGML_OP_RWKV_WKV7:
8478
8918
  case GGML_OP_LEAKY_RELU:
@@ -9209,7 +9649,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9209
9649
  case GGML_UNARY_OP_RELU:
9210
9650
  case GGML_UNARY_OP_TANH:
9211
9651
  case GGML_UNARY_OP_SIGMOID:
9212
- return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
9652
+ return ggml_is_contiguous(op->src[0]) &&
9653
+ (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
9654
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
9655
+ (op->src[0]->type == op->type);
9213
9656
  default:
9214
9657
  return false;
9215
9658
  }
@@ -9227,6 +9670,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9227
9670
  switch (src0_type) {
9228
9671
  case GGML_TYPE_F32:
9229
9672
  case GGML_TYPE_F16:
9673
+ case GGML_TYPE_BF16:
9230
9674
  case GGML_TYPE_Q4_0:
9231
9675
  case GGML_TYPE_Q4_1:
9232
9676
  case GGML_TYPE_Q5_0:
@@ -9262,19 +9706,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9262
9706
  if (a->ne[3] != b->ne[3]) {
9263
9707
  return false;
9264
9708
  }
9265
- if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
9709
+ if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
9266
9710
  !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
9267
9711
  return false;
9268
9712
  }
9713
+ if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
9714
+ // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
9715
+ // So don't support this combination for now.
9716
+ return false;
9717
+ }
9269
9718
 
9270
9719
  return true;
9271
9720
  } break;
9272
9721
  case GGML_OP_FLASH_ATTN_EXT:
9273
9722
  {
9274
9723
  ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
9275
- if (!ggml_vk_get_device(ctx->device)->coopmat2) {
9276
- return false;
9277
- }
9724
+ auto device = ggml_vk_get_device(ctx->device);
9725
+ bool coopmat2 = device->coopmat2;
9278
9726
  switch (op->src[0]->ne[0]) {
9279
9727
  case 64:
9280
9728
  case 80:
@@ -9282,7 +9730,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9282
9730
  case 112:
9283
9731
  case 128:
9284
9732
  case 256:
9285
- case 575: // DeepSeek MLA
9286
9733
  break;
9287
9734
  default:
9288
9735
  return false;
@@ -9308,10 +9755,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9308
9755
  switch (op->src[1]->type) {
9309
9756
  case GGML_TYPE_F16:
9310
9757
  case GGML_TYPE_Q4_0:
9758
+ case GGML_TYPE_Q8_0:
9759
+ // supported in scalar and coopmat2 paths
9760
+ break;
9311
9761
  case GGML_TYPE_Q4_1:
9312
9762
  case GGML_TYPE_Q5_0:
9313
9763
  case GGML_TYPE_Q5_1:
9314
- case GGML_TYPE_Q8_0:
9315
9764
  // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
9316
9765
  //case GGML_TYPE_Q2_K:
9317
9766
  //case GGML_TYPE_Q3_K:
@@ -9327,10 +9776,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9327
9776
  //case GGML_TYPE_IQ3_S:
9328
9777
  //case GGML_TYPE_IQ4_XS:
9329
9778
  case GGML_TYPE_IQ4_NL:
9779
+ // currently supported only in coopmat2 path
9780
+ if (!coopmat2) {
9781
+ return false;
9782
+ }
9330
9783
  break;
9331
9784
  default:
9332
9785
  return false;
9333
9786
  }
9787
+ if (!coopmat2 && !device->subgroup_shuffle) {
9788
+ // scalar FA uses subgroupShuffle
9789
+ return false;
9790
+ }
9334
9791
  return true;
9335
9792
  }
9336
9793
  case GGML_OP_GET_ROWS:
@@ -9338,6 +9795,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9338
9795
  switch (op->src[0]->type) {
9339
9796
  case GGML_TYPE_F32:
9340
9797
  case GGML_TYPE_F16:
9798
+ case GGML_TYPE_BF16:
9341
9799
  case GGML_TYPE_Q4_0:
9342
9800
  case GGML_TYPE_Q4_1:
9343
9801
  case GGML_TYPE_Q5_0:
@@ -9368,6 +9826,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9368
9826
  switch (src1_type) {
9369
9827
  case GGML_TYPE_F32:
9370
9828
  case GGML_TYPE_F16:
9829
+ case GGML_TYPE_BF16:
9371
9830
  case GGML_TYPE_Q4_0:
9372
9831
  case GGML_TYPE_Q4_1:
9373
9832
  case GGML_TYPE_Q5_0:
@@ -9381,6 +9840,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9381
9840
  }
9382
9841
  if (src1_type == GGML_TYPE_F32) {
9383
9842
  switch (src0_type) {
9843
+ case GGML_TYPE_F16:
9384
9844
  case GGML_TYPE_Q4_0:
9385
9845
  case GGML_TYPE_Q4_1:
9386
9846
  case GGML_TYPE_Q5_0:
@@ -9419,6 +9879,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9419
9879
  case GGML_OP_SUB:
9420
9880
  case GGML_OP_MUL:
9421
9881
  case GGML_OP_DIV:
9882
+ return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
9883
+ (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
9884
+ (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
9422
9885
  case GGML_OP_SILU_BACK:
9423
9886
  case GGML_OP_RMS_NORM_BACK:
9424
9887
  case GGML_OP_SQR:
@@ -9442,6 +9905,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
9442
9905
  case GGML_OP_COUNT_EQUAL:
9443
9906
  case GGML_OP_IM2COL:
9444
9907
  case GGML_OP_TIMESTEP_EMBEDDING:
9908
+ case GGML_OP_CONV_2D_DW:
9445
9909
  case GGML_OP_POOL_2D:
9446
9910
  case GGML_OP_RWKV_WKV6:
9447
9911
  case GGML_OP_RWKV_WKV7: