@fugood/llama.node 0.3.17 → 0.4.0
This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- package/CMakeLists.txt +3 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +39 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +366 -19
- package/src/LlamaCompletionWorker.h +30 -10
- package/src/LlamaContext.cpp +213 -5
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +133 -24
- package/src/llama.cpp/.github/workflows/build.yml +41 -762
- package/src/llama.cpp/.github/workflows/docker.yml +5 -2
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +12 -12
- package/src/llama.cpp/CMakeLists.txt +5 -17
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +31 -3
- package/src/llama.cpp/common/arg.cpp +48 -29
- package/src/llama.cpp/common/chat.cpp +128 -106
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +37 -1
- package/src/llama.cpp/common/common.h +18 -9
- package/src/llama.cpp/common/llguidance.cpp +1 -0
- package/src/llama.cpp/common/minja/chat-template.hpp +9 -5
- package/src/llama.cpp/common/minja/minja.hpp +69 -36
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +57 -50
- package/src/llama.cpp/examples/CMakeLists.txt +2 -23
- package/src/llama.cpp/examples/embedding/embedding.cpp +2 -11
- package/src/llama.cpp/examples/parallel/parallel.cpp +86 -14
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +27 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml.h +10 -7
- package/src/llama.cpp/ggml/src/CMakeLists.txt +1 -1
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +20 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +306 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +4 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +29 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +88 -5
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +264 -69
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +501 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +0 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +0 -6
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +36 -11
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +0 -2
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +41 -27
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +29 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +9 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +121 -232
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +7 -15
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +0 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +338 -166
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +81 -70
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +657 -193
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +20 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +123 -29
- package/src/llama.cpp/ggml/src/ggml.c +29 -20
- package/src/llama.cpp/ggml/src/gguf.cpp +33 -33
- package/src/llama.cpp/include/llama.h +52 -11
- package/src/llama.cpp/requirements/requirements-all.txt +3 -3
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +1 -0
- package/src/llama.cpp/src/llama-adapter.cpp +6 -0
- package/src/llama.cpp/src/llama-arch.cpp +3 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +17 -7
- package/src/llama.cpp/src/llama-chat.h +1 -0
- package/src/llama.cpp/src/llama-context.cpp +389 -501
- package/src/llama.cpp/src/llama-context.h +44 -32
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +20 -38
- package/src/llama.cpp/src/llama-graph.h +12 -8
- package/src/llama.cpp/src/llama-kv-cache.cpp +1503 -389
- package/src/llama.cpp/src/llama-kv-cache.h +271 -85
- package/src/llama.cpp/src/llama-memory.h +11 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +24 -15
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +316 -69
- package/src/llama.cpp/src/llama-model.h +8 -1
- package/src/llama.cpp/src/llama-quant.cpp +15 -13
- package/src/llama.cpp/src/llama-sampling.cpp +18 -6
- package/src/llama.cpp/src/llama-vocab.cpp +42 -4
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +14 -0
- package/src/llama.cpp/tests/CMakeLists.txt +10 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +107 -47
- package/src/llama.cpp/tests/test-chat-template.cpp +10 -11
- package/src/llama.cpp/tests/test-chat.cpp +3 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +2 -2
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +495 -348
- package/src/llama.cpp/{examples → tools}/main/main.cpp +6 -9
- package/src/llama.cpp/{examples/llava → tools/mtmd}/CMakeLists.txt +1 -35
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip-impl.h +25 -5
- package/src/llama.cpp/{examples/llava → tools/mtmd}/clip.cpp +1440 -1349
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd-cli.cpp +70 -44
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/{examples/llava → tools/mtmd}/mtmd.cpp +251 -281
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +4 -2
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +13 -76
- package/src/llama.cpp/{examples → tools}/rpc/rpc-server.cpp +70 -74
- package/src/llama.cpp/{examples → tools}/run/run.cpp +18 -4
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/server.cpp +291 -76
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +377 -5
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.h +0 -135
- package/src/llama.cpp/examples/llava/llava.cpp +0 -586
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/mtmd.h +0 -168
- package/src/llama.cpp/examples/llava/qwen2vl-test.cpp +0 -636
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/deprecation-warning.cpp +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/rpc/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/httplib.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/tts.cpp +0 -0
package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp

@@ -51,6 +51,24 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
+// remove this once it's more widely available in the SDK
+#if !defined(VK_KHR_shader_bfloat16)
+
+#define VK_KHR_shader_bfloat16 1
+#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
+#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
+#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
+#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
+
+typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
+    VkStructureType sType;
+    void* pNext;
+    VkBool32 shaderBFloat16Type;
+    VkBool32 shaderBFloat16DotProduct;
+    VkBool32 shaderBFloat16CooperativeMatrix;
+} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
+#endif
+
 #define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }

@@ -257,6 +275,7 @@ struct vk_device_struct {
     bool prefer_host_memory;
     bool float_controls_rte_fp16;
     bool subgroup_add;
+    bool subgroup_shuffle;
 
     bool integer_dot_product;
 

@@ -266,8 +285,12 @@ struct vk_device_struct {
     bool subgroup_require_full_support;
 
     bool coopmat_support;
-    bool coopmat_acc_f32_support;
-    bool coopmat_acc_f16_support;
+    bool coopmat_acc_f32_support {};
+    bool coopmat_acc_f16_support {};
+    bool coopmat_bf16_support {};
+    bool coopmat_support_16x16x16_f16acc {};
+    bool coopmat_support_16x16x16_f32acc {};
+    bool coopmat1_fa_support {};
     uint32_t coopmat_m;
     uint32_t coopmat_n;
     uint32_t coopmat_k;

@@ -293,6 +316,7 @@ struct vk_device_struct {
 
     vk_matmul_pipeline pipeline_matmul_f32 {};
     vk_matmul_pipeline pipeline_matmul_f32_f16 {};
+    vk_matmul_pipeline pipeline_matmul_bf16 {};
     vk_matmul_pipeline2 pipeline_matmul_f16;
     vk_matmul_pipeline2 pipeline_matmul_f16_f32;
 

@@ -301,6 +325,7 @@ struct vk_device_struct {
     vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
 
     vk_matmul_pipeline pipeline_matmul_id_f32 {};
+    vk_matmul_pipeline pipeline_matmul_id_bf16 {};
     vk_matmul_pipeline2 pipeline_matmul_id_f16;
     vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
 

@@ -319,11 +344,17 @@ struct vk_device_struct {
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_acc_f32;
-
-
-    vk_pipeline
-    vk_pipeline
-    vk_pipeline
+
+    // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
+    vk_pipeline pipeline_add[2][2][2];
+    vk_pipeline pipeline_add_norepeat[2][2][2];
+    vk_pipeline pipeline_sub[2][2][2];
+    vk_pipeline pipeline_sub_norepeat[2][2][2];
+    vk_pipeline pipeline_mul[2][2][2];
+    vk_pipeline pipeline_mul_norepeat[2][2][2];
+    vk_pipeline pipeline_div[2][2][2];
+    vk_pipeline pipeline_div_norepeat[2][2][2];
+
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
     vk_pipeline pipeline_upscale_f32;
     vk_pipeline pipeline_scale_f32;

@@ -333,8 +364,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
-    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
-    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
+    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
+    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_norm_f32;

@@ -342,14 +373,17 @@ struct vk_device_struct {
     vk_pipeline pipeline_rms_norm_f32;
    vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
-
-
-    vk_pipeline
-    vk_pipeline
-    vk_pipeline
+
+    // [src/dst 0=fp32,1=fp16]
+    vk_pipeline pipeline_gelu[2];
+    vk_pipeline pipeline_gelu_quick[2];
+    vk_pipeline pipeline_silu[2];
+    vk_pipeline pipeline_relu[2];
+    vk_pipeline pipeline_tanh[2];
+    vk_pipeline pipeline_sigmoid[2];
+
     vk_pipeline pipeline_leaky_relu_f32;
-    vk_pipeline
-    vk_pipeline pipeline_sigmoid_f32;
+    vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;

@@ -368,14 +402,31 @@ struct vk_device_struct {
     vk_pipeline pipeline_rwkv_wkv6_f32;
     vk_pipeline pipeline_rwkv_wkv7_f32;
     vk_pipeline pipeline_opt_step_adamw_f32;
+    vk_pipeline pipeline_conv2d_dw_whcn_f32;
+    vk_pipeline pipeline_conv2d_dw_cwhn_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
+    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
+
+    vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
+    vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
+
     vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D112[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
     vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
+
     vk_pipeline pipeline_flash_attn_split_k_reduce;
 
     std::unordered_map<std::string, vk_pipeline_ref> pipelines;

@@ -680,6 +731,24 @@ struct vk_op_rwkv_wkv7_push_constants {
     uint32_t H;
 };
 
+struct vk_op_conv2d_dw_push_constants {
+    uint32_t ne;
+    uint32_t batches;
+    uint32_t channels;
+    uint32_t dst_w;
+    uint32_t dst_h;
+    uint32_t src_w;
+    uint32_t src_h;
+    uint32_t knl_w;
+    uint32_t knl_h;
+    int32_t stride_x;
+    int32_t stride_y;
+    int32_t pad_x;
+    int32_t pad_y;
+    int32_t dilation_x;
+    int32_t dilation_y;
+};
+
 struct vk_op_upscale_push_constants {
     uint32_t ne; uint32_t a_offset; uint32_t d_offset;
     uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;

@@ -1529,15 +1598,56 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
     );
 }
 
+enum FaCodePath {
+    FA_SCALAR,
+    FA_COOPMAT1,
+    FA_COOPMAT2,
+};
+
 // number of rows/cols for flash attention shader
 static constexpr uint32_t flash_attention_num_small_rows = 32;
-static
+static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
+static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
+
+// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
+// 128 threads split into four subgroups, each subgroup does 1/4
+// of the Bc dimension.
+static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
+static constexpr uint32_t scalar_flash_attention_Bc = 64;
+static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
+
+static uint32_t get_fa_num_small_rows(FaCodePath path) {
+    if (path == FA_COOPMAT2) {
+        return flash_attention_num_small_rows;
+    } else {
+        return scalar_flash_attention_num_small_rows;
+    }
+}
+
+static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
     GGML_UNUSED(clamp);
 
+    if (path == FA_SCALAR) {
+        if (small_rows) {
+            return {scalar_flash_attention_num_small_rows, 64};
+        } else {
+            return {scalar_flash_attention_num_large_rows, 32};
+        }
+    }
+
+    if (path == FA_COOPMAT1) {
+        if (small_rows) {
+            return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
+        } else {
+            return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
+        }
+    }
+
     // small rows, large cols
     if (small_rows) {
-        return {
+        return {get_fa_num_small_rows(FA_COOPMAT2), 32};
     }
+
     // small cols to reduce register count
     if (ggml_is_quantized(type) || D == 256) {
         return {64, 32};

@@ -1582,7 +1692,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
     const uint32_t warps = warptile[0] / warptile[10];
 
     const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
-    const uint32_t mmid_row_ids = mul_mat_id ?
+    const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
     const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
 
     const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;

@@ -1791,6 +1901,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
     if (!device->pipeline_matmul_id_f32) {
         device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
     }
+    if (!device->pipeline_matmul_bf16) {
+        device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+    if (!device->pipeline_matmul_id_bf16) {
+        device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
 
     std::vector<std::future<void>> compiles;
     auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,

@@ -1826,65 +1942,75 @@ static void ggml_vk_load_shaders(vk_device& device) {
             parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
     };
 
-
-
-
-    auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
-        return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
-    };
+    auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
+        return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
+    };
 
-
-
-
-
-
-
-
-
-
-
+    auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
+        // For large number of rows, 128 invocations seems to work best.
+        // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
+        // can't use 256 for D==80.
+        // For scalar, use 128 (arbitrary)
+        uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
+            ? scalar_flash_attention_workgroup_size
+            : ((small_rows && (D % 32) == 0) ? 256 : 128);
+        auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
+
+        // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
+        // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
+        const uint32_t D_lsb = D ^ (D & (D-1));
+        uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
+
+        // mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+        GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
+        return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
+    };
 
-#define CREATE_FA2(TYPE, NAMELC, D) \
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
-        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
-
-#define CREATE_FA(TYPE, NAMELC) \
-        CREATE_FA2(TYPE, NAMELC, 64) \
-        CREATE_FA2(TYPE, NAMELC, 80) \
-        CREATE_FA2(TYPE, NAMELC, 96) \
-        CREATE_FA2(TYPE, NAMELC, 112) \
-        CREATE_FA2(TYPE, NAMELC, 128) \
-        CREATE_FA2(TYPE, NAMELC, 256)
-
-
-
-
-
-
-        CREATE_FA(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+        ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
+
+#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
+        CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
+
+    CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
+    CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
+    CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
+#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
+    if (device->coopmat1_fa_support) {
+        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
+    }
+#endif
+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    if (device->coopmat2) {
+        CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
+        CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
+    }
+#endif
+#undef CREATE_FA2
 #undef CREATE_FA
 
+#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
+    if (device->coopmat2) {
+
     // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
     ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \

@@ -1900,6 +2026,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
 
     CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    if (device->coopmat_bf16_support) {
+        CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
+    }
+#endif
     CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
     CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
     CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)

@@ -1921,6 +2052,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
     CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    if (device->coopmat_bf16_support) {
+        CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
+    }
+#endif
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
     CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)

@@ -1949,17 +2085,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // Create 6 variants, {s,m,l}x{unaligned,aligned}
 #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
     if (device->mul_mat ## ID ## _l[TYPE]) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ##
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
     if (device->mul_mat ## ID ## _m[TYPE]) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ##
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
     if (device->mul_mat ## ID ## _s[TYPE]) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ##
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
     if (device->mul_mat ## ID ## _l[TYPE]) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ##
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
     if (device->mul_mat ## ID ## _m[TYPE]) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ##
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
     if (device->mul_mat ## ID ## _s[TYPE]) \
-        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ##
+        ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
 
     // Create 2 variants, {f16,f32} accumulator
 #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \

@@ -1974,6 +2110,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    if (device->coopmat_bf16_support) {
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
+    }
+#endif
 
     if (device->coopmat_acc_f16_support) {
         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

@@ -2022,6 +2163,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+    if (device->coopmat_bf16_support) {
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
+    }
+#endif
 
     if (device->coopmat_acc_f16_support) {
         CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);

@@ -2104,6 +2250,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

@@ -2139,6 +2287,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);

@@ -2191,6 +2341,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );

@@ -2226,6 +2378,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
 
+    CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+
     CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);

@@ -2246,8 +2400,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
     CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
     CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
-#undef CREATE_MM
     }
+    // reusing CREATE_MM from the fp32 path
+    if ((device->coopmat2 || device->coopmat_support)
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        && !device->coopmat_bf16_support
+#endif
+    ) {
+        // use scalar tile sizes
+        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+        m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
+        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
+
+        l_wg_denoms = {128, 128, 1 };
+        m_wg_denoms = { 64, 64, 1 };
+        s_wg_denoms = { 32, 32, 1 };
+
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
+    }
+#undef CREATE_MM
 
     // mul mat vec
 

@@ -2266,6 +2438,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);

@@ -2288,6 +2461,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
+        ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
         ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);

@@ -2311,6 +2485,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
     ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);

@@ -2356,6 +2531,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     // get_rows
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

@@ -2373,6 +2549,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);

@@ -2399,7 +2576,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
             ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
         }
     }
-    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3,
+    ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);

@@ -2410,10 +2587,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2415
2594
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2416
2595
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2596
|
+
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2597
|
+
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2598
|
+
|
|
2417
2599
|
if (device->float_controls_rte_fp16) {
|
|
2418
2600
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
|
2419
2601
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
|
@@ -2437,20 +2619,32 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);

-
-
-
-
+auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
+std::string s;
+s += std::string(src0_f16 ? "_f16" : "_f32");
+s += std::string(src1_f16 ? "_f16" : "_f32");
+s += std::string(dst_f16 ? "_f16" : "_f32");
+return s;
+};
+
+#define CREATE_BINARY(name, namemod, spec) \
+for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
+ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
+#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
+"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
+
+CREATE_BINARY(add, , {0})
+CREATE_BINARY(add, _norepeat, {1})
+CREATE_BINARY(sub, , {0})
+CREATE_BINARY(sub, _norepeat, {1})
+CREATE_BINARY(mul, , {0})
+CREATE_BINARY(mul, _norepeat, {1})
+CREATE_BINARY(div, , {0})
+CREATE_BINARY(div, _norepeat, {1})
+#undef CREATE_BINARY

 ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);

-ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
-ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
-ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
-ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
-ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
-ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
-
 ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
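Note on the CREATE_BINARY macro added above: each invocation expands into a triple loop that creates eight pipeline variants for one binary op, indexed by whether src0, src1, and dst are F16. A minimal sketch of a single iteration of that generated loop (illustrative only; the _len/_data tables come from the generated shader headers), for CREATE_BINARY(add, , {0}) at s0=0, s1=1, d=0:

    ggml_vk_create_pipeline(device, device->pipeline_add[0][1][0],
        "add" + get_suffix(0, 1, 0),   // yields the name "add_f32_f16_f32"
        add_len[0][1][0], add_data[0][1][0],
        "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);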
@@ -2470,14 +2664,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
 ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

-
-ggml_vk_create_pipeline(device, device->
-ggml_vk_create_pipeline(device, device->
-
-
+#define CREATE_UNARY(name) \
+ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
+ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
+CREATE_UNARY(gelu)
+CREATE_UNARY(gelu_quick)
+CREATE_UNARY(silu)
+CREATE_UNARY(relu)
+CREATE_UNARY(tanh)
+CREATE_UNARY(sigmoid)
+#undef CREATE_UNARY
+
 ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-ggml_vk_create_pipeline(device, device->
-ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

 ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);

@@ -2529,6 +2729,9 @@ static void ggml_vk_load_shaders(vk_device& device) {

 ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

+ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+
 for (auto &c : compiles) {
 c.wait();
 }
@@ -2578,6 +2781,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 bool coopmat2_support = false;
 device->coopmat_support = false;
 device->integer_dot_product = false;
+bool bfloat16_support = false;

 for (const auto& properties : ext_props) {
 if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
@@ -2608,6 +2812,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
 device->integer_dot_product = true;
 #endif
+} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
+!getenv("GGML_VK_DISABLE_BFLOAT16")) {
+bfloat16_support = true;
 }
 }

@@ -2700,6 +2907,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
 (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);

+device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
+(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
+
 const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;

 device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
@@ -2794,6 +3004,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
 }
 #endif

+#if defined(VK_KHR_shader_bfloat16)
+VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
+bfloat16_features.pNext = nullptr;
+bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
+if (bfloat16_support) {
+last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
+last_struct = (VkBaseOutStructure *)&bfloat16_features;
+device_extensions.push_back("VK_KHR_shader_bfloat16");
+}
+#endif
+
 VkPhysicalDeviceMaintenance4Features maint4_features {};
 maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
 if (maintenance4_support) {
@@ -2832,6 +3053,11 @@ static vk_device ggml_vk_get_device(size_t idx) {

 #if defined(VK_KHR_cooperative_matrix)
 device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
+
+// coopmat1 fa shader currently assumes 32 invocations per subgroup
+device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
+device->subgroup_size_control && device->subgroup_min_size <= 32 &&
+device->subgroup_max_size >= 32;
 #endif

 if (coopmat2_support) {
@@ -2966,6 +3192,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 // Only enable if shape is identical
 device->coopmat_acc_f32_support = true;
 }
+if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
+device->coopmat_support_16x16x16_f32acc = true;
+}
 } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
 (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
 // coopmat sizes not set yet
@@ -2978,6 +3207,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
 // Only enable if shape is identical
 device->coopmat_acc_f16_support = true;
 }
+if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
+device->coopmat_support_16x16x16_f16acc = true;
+}
 }
 } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
 (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
@@ -2991,6 +3223,25 @@ static vk_device ggml_vk_get_device(size_t idx) {
 device->coopmat_int_n = prop.NSize;
 device->coopmat_int_k = prop.KSize;
 }
+#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
+if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
+prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
+prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
+(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
+) {
+// coopmat sizes not set yet
+if (device->coopmat_m == 0) {
+device->coopmat_bf16_support = true;
+device->coopmat_m = prop.MSize;
+device->coopmat_n = prop.NSize;
+device->coopmat_k = prop.KSize;
+} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
+// Only enable if shape is identical
+device->coopmat_bf16_support = true;
+}
+}
+#endif
 }

 if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
@@ -2998,11 +3249,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
 GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
 device->coopmat_support = false;
 }
+if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
+device->coopmat_bf16_support = false;
+}
 }

 if (device->coopmat_support) {
 device_extensions.push_back("VK_KHR_cooperative_matrix");
 }
+#if defined(VK_KHR_shader_bfloat16)
+if (device->coopmat_bf16_support) {
+device_extensions.push_back("VK_KHR_shader_bfloat16");
+}
+#endif
 #endif
 device->name = GGML_VK_NAME + std::to_string(idx);

@@ -3459,6 +3718,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
 return ctx->device->pipeline_matmul_f32_f16;
 }
+if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+return ctx->device->pipeline_matmul_bf16;
+}
 if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
 return ctx->device->pipeline_matmul_f16_f32.f16acc;
@@ -3530,6 +3792,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
 switch (a_type) {
 case GGML_TYPE_F32:
 case GGML_TYPE_F16:
+case GGML_TYPE_BF16:
 case GGML_TYPE_Q4_0:
 case GGML_TYPE_Q4_1:
 case GGML_TYPE_Q5_0:
@@ -3562,6 +3825,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
 if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
 return ctx->device->pipeline_matmul_id_f32;
 }
+if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
+return ctx->device->pipeline_matmul_id_bf16;
+}
 if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
 return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
@@ -3615,6 +3881,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
 switch (a_type) {
 case GGML_TYPE_F32:
 case GGML_TYPE_F16:
+case GGML_TYPE_BF16:
 case GGML_TYPE_Q4_0:
 case GGML_TYPE_Q4_1:
 case GGML_TYPE_Q5_0:
@@ -4350,6 +4617,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
 return ctx->device->pipeline_cpy_f16_f16;
 }
 }
+if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
+if (contig) {
+return ctx->device->pipeline_contig_cpy_f16_f32;
+} else {
+return ctx->device->pipeline_cpy_f16_f32;
+}
+}
+if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
+if (contig) {
+return ctx->device->pipeline_contig_cpy_f32_bf16;
+} else {
+return ctx->device->pipeline_cpy_f32_bf16;
+}
+}
 if (src->type == GGML_TYPE_F32) {
 switch (to) {
 case GGML_TYPE_Q4_0:
@@ -4477,8 +4758,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
 const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
 !ggml_vk_dim01_contiguous(src0);
 const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
 !ggml_vk_dim01_contiguous(src1);

+// If src0 is BF16, try to use a BF16 x BF16 multiply
+ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
+
 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;

 bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
@@ -4488,25 +4773,25 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub

 if (mmp == nullptr) {
 // Fall back to f16 dequant mul mat
-mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ?
+mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
 quantize_y = false;
 }

 const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-const bool qy_needs_dequant = !quantize_y && ((src1->type !=
+const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);

 if (qx_needs_dequant) {
 // Fall back to dequant + f16 mulmat
-mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx,
+mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
 }

 // Not implemented
 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT

-const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ?
+const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
 const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;

-vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ?
+vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));

 // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
 uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
@@ -4527,12 +4812,12 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
 vk_pipeline to_q8_1 = nullptr;

 if (x_non_contig) {
-to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr,
+to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
 } else {
 to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
 }
 if (y_non_contig) {
-to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr,
+to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
 } else {
 to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
 }
@@ -4949,6 +5234,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
 const uint64_t nb01 = src0->nb[1];
 const uint64_t nb02 = src0->nb[2];

+const uint64_t nb12 = src1->nb[2];
+
 // const uint64_t ne10 = src1->ne[0];
 const uint64_t ne11 = src1->ne[1];
 const uint64_t ne12 = src1->ne[2];
@@ -4974,6 +5261,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con

 const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
 const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
+const uint32_t channel_stride_y = nb12 / sizeof(float);

 const uint64_t qx_sz = ggml_nbytes(src0);
 const uint64_t qy_sz = ggml_nbytes(src1);
@@ -5004,7 +5292,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
 const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;

 // compute
-const std::array<uint32_t,
+const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
 ggml_vk_sync_buffers(subctx);
 ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
 { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
@@ -5029,7 +5317,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
 // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
 // when ne12 and ne13 are one.
 } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
-(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
+(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
 ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
 } else {
 ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
@@ -5056,7 +5344,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&

 const uint64_t nei0 = ids->ne[0];
 const uint64_t nei1 = ids->ne[1];
-GGML_ASSERT(nei0 * nei1 <=
+GGML_ASSERT(nei0 * nei1 <= 4096);

 const uint32_t nbi1 = ids->nb[1];
 const uint32_t nbi2 = ids->nb[2];
@@ -5097,27 +5385,31 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
 !ggml_vk_dim01_contiguous(src0);
 const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
 !ggml_vk_dim01_contiguous(src1);

+// If src0 is BF16, try to use a BF16 x BF16 multiply
+ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
+
 const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;

-vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ?
+vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);

 const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
-const bool qy_needs_dequant = (src1->type !=
+const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;

 if (qx_needs_dequant) {
 // Fall back to dequant + f16 mulmat
-mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx,
+mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
 }

 // Not implemented
 GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT

-const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ?
+const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
 const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;

-vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ?
+vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);

 // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
 uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
@@ -5136,12 +5428,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
 vk_pipeline to_fp16_vk_1 = nullptr;

 if (x_non_contig) {
-to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr,
+to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
 } else {
 to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
 }
 if (y_non_contig) {
-to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr,
+to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
 } else {
 to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
 }
@@ -5451,6 +5743,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
 }
 }

+static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
+// Needs to be kept up to date on shader changes
+const uint32_t wg_size = scalar_flash_attention_workgroup_size;
+const uint32_t Br = scalar_flash_attention_num_large_rows;
+const uint32_t Bc = scalar_flash_attention_Bc;
+
+const uint32_t acctype = f32acc ? 4 : 2;
+const uint32_t f16vec4 = 8;
+
+const uint32_t tmpsh = wg_size * sizeof(float);
+const uint32_t tmpshv4 = wg_size * 4 * acctype;
+
+const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
+
+const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
+const uint32_t sfsh = Bc * sfshstride * acctype;
+
+const uint32_t kshstride = D / 4 + 2;
+const uint32_t ksh = Bc * kshstride * f16vec4;
+
+const uint32_t slope = Br * sizeof(float);
+
+const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
+const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
+
+VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
+
+return supported;
+}
+
 static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
 VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
 std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
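The helper added above mirrors the coopmat1 flash-attention shader's shared-memory budget: it sums the tmpsh, tmpshv4, Qf, sfsh, ksh, and slope allocations and compares the total against maxComputeSharedMemorySize. A minimal standalone sketch of the same arithmetic, with placeholder values for the shader constants (wg_size, Br, Bc below are assumptions for illustration, not values taken from this diff):

    #include <cstdint>
    #include <cstdio>

    // Placeholder shader constants, for illustration only.
    static const uint32_t wg_size = 128;
    static const uint32_t Br = 32;
    static const uint32_t Bc = 64;

    // Same bookkeeping as the helper in the diff: bytes of shared memory needed
    // for head size D with f32 or f16 accumulators.
    static uint32_t fa_coopmat_shmem_bytes(uint32_t D, bool f32acc) {
        const uint32_t acctype = f32acc ? 4 : 2;   // bytes per accumulator element
        const uint32_t f16vec4 = 8;                // bytes per f16vec4
        const uint32_t tmpsh      = wg_size * sizeof(float);
        const uint32_t tmpshv4    = wg_size * 4 * acctype;
        const uint32_t Qf         = Br * (D / 4 + 2) * f16vec4;
        const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
        const uint32_t sfsh       = Bc * sfshstride * acctype;
        const uint32_t kshstride  = D / 4 + 2;
        const uint32_t ksh        = Bc * kshstride * f16vec4;
        const uint32_t slope      = Br * sizeof(float);
        return tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
    }

    int main() {
        // e.g. D=128 with f32 accumulators; compare the result against the device limit.
        std::printf("%u bytes\n", fa_coopmat_shmem_bytes(128, true));
    }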
@@ -5501,36 +5823,19 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 assert(q->type == GGML_TYPE_F32);
 assert(k->type == v->type);

-
-
-bool f32acc = dst->op_params[3] == GGML_PREC_F32;
-bool small_rows = N <= flash_attention_num_small_rows;
-switch (D) {
-case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
-case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
-case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
-case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
-case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
-case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
-default:
-assert(!"unsupported D value");
-return;
-}
-assert(pipelines);
-
-const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
-const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
-const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
+FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
+ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;

-
-
-
+if (path == FA_COOPMAT1) {
+const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
+(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);

-
-GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
+const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);

-
-
+if (!coopmat_shape_supported || !coopmat_shmem_supported) {
+path = FA_SCALAR;
+}
+}

 uint32_t gqa_ratio = 1;
 uint32_t qk_ratio = neq2 / nek2;
@@ -5538,7 +5843,23 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 uint32_t workgroups_y = (uint32_t)neq2;
 uint32_t workgroups_z = (uint32_t)neq3;

-
+// For scalar/coopmat1 FA, we can use the "large" size to accommodate qga.
+// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
+uint32_t max_gqa;
+switch (path) {
+case FA_SCALAR:
+case FA_COOPMAT1:
+// We may switch from coopmat1 to scalar, so use the scalar limit for both
+max_gqa = scalar_flash_attention_num_large_rows;
+break;
+case FA_COOPMAT2:
+max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
+break;
+default:
+GGML_ASSERT(0);
+}
+
+if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
 qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
 // grouped query attention - make the N dimension equal to gqa_ratio, reduce
 // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
@@ -5548,11 +5869,89 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
 workgroups_y /= N;
 }

+vk_pipeline *pipelines;
+bool small_rows = N <= get_fa_num_small_rows(path);
+
+// coopmat1 does not actually support "small rows" (it needs 16 rows).
+// So use scalar instead.
+if (small_rows && path == FA_COOPMAT1) {
+path = FA_SCALAR;
+}
+
+// scalar is faster than coopmat2 when N==1
+if (N == 1 && path == FA_COOPMAT2) {
+path = FA_SCALAR;
+}
+
+bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
+
+switch (path) {
+case FA_SCALAR:
+switch (D) {
+case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
+case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
+case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
+case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
+case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
+case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
+default:
+GGML_ASSERT(!"unsupported D value");
+return;
+}
+break;
+case FA_COOPMAT1:
+switch (D) {
+case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
+case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
+case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
+case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
+case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
+case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
+default:
+GGML_ASSERT(!"unsupported D value");
+return;
+}
+break;
+case FA_COOPMAT2:
+switch (D) {
+case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
+case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
+case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
+case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
+case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
+case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
+default:
+GGML_ASSERT(!"unsupported D value");
+return;
+}
+break;
+default:
+GGML_ASSERT(0);
+}
+assert(pipelines);
+
+const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
+const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
+const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
+
+bool aligned = (KV % pipelines[1]->align) == 0 &&
+// the "aligned" shader variant will forcibly align strides, for performance
+(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
+
+// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
+GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
+
+vk_pipeline pipeline = pipelines[aligned];
+assert(pipeline);
+
 uint32_t split_kv = KV;
 uint32_t split_k = 1;

+// Use a placeholder core count if one isn't available. split_k is a big help for perf.
+const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
+
 // Try to use split_k when KV is large enough to be worth the overhead
-if (workgroups_x == 1 &&
+if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
 // Try to run two workgroups per SM.
 split_k = ctx->device->shader_core_count * 2 / workgroups_y;
 if (split_k > 1) {
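As a worked example of the split_k sizing visible above (the "two workgroups per SM" heuristic), under the assumption of a hypothetical device reporting 32 shader cores with workgroups_y == 8:

    const uint32_t shader_core_count = 32;                       // assumed device value
    const uint32_t workgroups_y      = 8;
    uint32_t split_k = shader_core_count * 2 / workgroups_y;     // = 8, so KV is split 8 ways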
@@ -5722,26 +6121,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
 }
 return nullptr;
 case GGML_OP_ADD:
-
-
+case GGML_OP_SUB:
+case GGML_OP_MUL:
+case GGML_OP_DIV:
+if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+(src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
+(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
+return nullptr;
+}
+switch (op) {
+case GGML_OP_ADD:
+{
+auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
 }
-
-
+case GGML_OP_SUB:
+{
+auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
+return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
 }
-
-
-
-return
+case GGML_OP_MUL:
+{
+auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
+return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
 }
-
-
-
-return
+case GGML_OP_DIV:
+{
+auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
+return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
 }
-
-
-if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
+default:
+break;
 }
 return nullptr;
 case GGML_OP_CONCAT:
@@ -5835,37 +6245,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
 }
 return nullptr;
 case GGML_OP_UNARY:
+if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
+(src0->type != dst->type)) {
+return nullptr;
+}
+
 switch (ggml_get_unary_op(dst)) {
 case GGML_UNARY_OP_SILU:
-
-return ctx->device->pipeline_silu_f32;
-}
-break;
+return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
 case GGML_UNARY_OP_GELU:
-
-return ctx->device->pipeline_gelu_f32;
-}
-break;
+return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
 case GGML_UNARY_OP_GELU_QUICK:
-
-return ctx->device->pipeline_gelu_quick_f32;
-}
-break;
+return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
 case GGML_UNARY_OP_RELU:
-
-return ctx->device->pipeline_relu_f32;
-}
-break;
+return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
 case GGML_UNARY_OP_TANH:
-
-return ctx->device->pipeline_tanh_f32;
-}
-break;
+return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
 case GGML_UNARY_OP_SIGMOID:
-
-return ctx->device->pipeline_sigmoid_f32;
-}
-break;
+return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
 default:
 break;
 }
@@ -5988,6 +6386,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
 return ctx->device->pipeline_leaky_relu_f32;
 }
 return nullptr;
+case GGML_OP_CONV_2D_DW:
+if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+if (ggml_is_contiguous(src1)) {
+return ctx->device->pipeline_conv2d_dw_whcn_f32;
+} else if (ggml_is_contiguous_channels(src1)) {
+return ctx->device->pipeline_conv2d_dw_cwhn_f32;
+}
+}
+return nullptr;
 default:
 return nullptr;
 }
@@ -6014,6 +6421,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
 case GGML_OP_REPEAT_BACK:
 case GGML_OP_ROPE:
 case GGML_OP_RMS_NORM:
+case GGML_OP_CONV_2D_DW:
 return true;
 default:
 return false;
@@ -6310,6 +6718,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
 case GGML_OP_CONCAT:
 case GGML_OP_UPSCALE:
 case GGML_OP_UNARY:
+case GGML_OP_CONV_2D_DW:
 {
 const uint32_t ne = ggml_nelements(dst);
 if (ne > 262144) {
@@ -7096,6 +7505,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
 }, dryrun);
 }

+static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+vk_op_conv2d_dw_push_constants p{};
+p.ne = ggml_nelements(dst);
+p.channels = dst->ne[2];
+p.batches = dst->ne[3];
+p.dst_w = dst->ne[0];
+p.dst_h = dst->ne[1];
+p.src_w = src1->ne[0];
+p.src_h = src1->ne[1];
+p.knl_w = src0->ne[0];
+p.knl_h = src0->ne[1];
+p.stride_x = dst->op_params[0];
+p.stride_y = dst->op_params[1];
+p.pad_x = dst->op_params[2];
+p.pad_y = dst->op_params[3];
+p.dilation_x = dst->op_params[4];
+p.dilation_y = dst->op_params[5];
+
+GGML_ASSERT(src0->ne[3] == p.channels);
+GGML_ASSERT(src1->ne[3] == p.batches);
+
+ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
+}
+
 static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
 const float * op_params = (const float *)dst->op_params;
 ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
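The push constants in ggml_vk_conv_2d_dw above are filled straight from the tensor shapes and op_params of the GGML_OP_CONV_2D_DW node. For a depthwise convolution the destination width/height follow the standard convolution output-size relation; a minimal sketch of that bookkeeping (generic formula, not code from this package):

    #include <cstdint>
    #include <cstdio>

    // Standard convolution output size: out = (in + 2*pad - dilation*(knl-1) - 1) / stride + 1
    static int64_t conv_out_size(int64_t in, int64_t knl, int64_t stride, int64_t pad, int64_t dilation) {
        return (in + 2 * pad - dilation * (knl - 1) - 1) / stride + 1;
    }

    int main() {
        // e.g. a 64x64 input with a 3x3 depthwise kernel, stride 1, pad 1, dilation 1 -> 64x64 output
        std::printf("dst_w = %lld\n", (long long) conv_out_size(64, 3, 1, 1, 1));
    }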
@@ -8116,6 +8549,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 case GGML_OP_IM2COL:
 case GGML_OP_TIMESTEP_EMBEDDING:
 case GGML_OP_POOL_2D:
+case GGML_OP_CONV_2D_DW:
 case GGML_OP_RWKV_WKV6:
 case GGML_OP_RWKV_WKV7:
 case GGML_OP_LEAKY_RELU:
@@ -8179,6 +8613,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 case GGML_OP_IM2COL:
 case GGML_OP_TIMESTEP_EMBEDDING:
 case GGML_OP_POOL_2D:
+case GGML_OP_CONV_2D_DW:
 case GGML_OP_LEAKY_RELU:
 {
 // These operations all go through ggml_vk_op_f32, so short-circuit and
@@ -8352,6 +8787,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 case GGML_OP_POOL_2D:
 ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);

+break;
+case GGML_OP_CONV_2D_DW:
+ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
+
 break;
 case GGML_OP_LEAKY_RELU:
 ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
@@ -8473,6 +8912,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
 case GGML_OP_IM2COL:
 case GGML_OP_TIMESTEP_EMBEDDING:
 case GGML_OP_POOL_2D:
+case GGML_OP_CONV_2D_DW:
 case GGML_OP_RWKV_WKV6:
 case GGML_OP_RWKV_WKV7:
 case GGML_OP_LEAKY_RELU:
@@ -9209,7 +9649,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 case GGML_UNARY_OP_RELU:
 case GGML_UNARY_OP_TANH:
 case GGML_UNARY_OP_SIGMOID:
-return ggml_is_contiguous(op->src[0]) &&
+return ggml_is_contiguous(op->src[0]) &&
+(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+(op->src[0]->type == op->type);
 default:
 return false;
 }
@@ -9227,6 +9670,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 switch (src0_type) {
 case GGML_TYPE_F32:
 case GGML_TYPE_F16:
+case GGML_TYPE_BF16:
 case GGML_TYPE_Q4_0:
 case GGML_TYPE_Q4_1:
 case GGML_TYPE_Q5_0:
@@ -9262,19 +9706,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 if (a->ne[3] != b->ne[3]) {
 return false;
 }
-if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
+if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
 !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
 return false;
 }
+if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
+// We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
+// So don't support this combination for now.
+return false;
+}

 return true;
 } break;
 case GGML_OP_FLASH_ATTN_EXT:
 {
 ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
-
-
-}
+auto device = ggml_vk_get_device(ctx->device);
+bool coopmat2 = device->coopmat2;
 switch (op->src[0]->ne[0]) {
 case 64:
 case 80:
@@ -9282,7 +9730,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 case 112:
 case 128:
 case 256:
-case 575: // DeepSeek MLA
 break;
 default:
 return false;
@@ -9308,10 +9755,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 switch (op->src[1]->type) {
 case GGML_TYPE_F16:
 case GGML_TYPE_Q4_0:
+case GGML_TYPE_Q8_0:
+// supported in scalar and coopmat2 paths
+break;
 case GGML_TYPE_Q4_1:
 case GGML_TYPE_Q5_0:
 case GGML_TYPE_Q5_1:
-case GGML_TYPE_Q8_0:
 // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
 //case GGML_TYPE_Q2_K:
 //case GGML_TYPE_Q3_K:
@@ -9327,10 +9776,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 //case GGML_TYPE_IQ3_S:
 //case GGML_TYPE_IQ4_XS:
 case GGML_TYPE_IQ4_NL:
+// currently supported only in coopmat2 path
+if (!coopmat2) {
+return false;
+}
 break;
 default:
 return false;
 }
+if (!coopmat2 && !device->subgroup_shuffle) {
+// scalar FA uses subgroupShuffle
+return false;
+}
 return true;
 }
 case GGML_OP_GET_ROWS:
@@ -9338,6 +9795,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 switch (op->src[0]->type) {
 case GGML_TYPE_F32:
 case GGML_TYPE_F16:
+case GGML_TYPE_BF16:
 case GGML_TYPE_Q4_0:
 case GGML_TYPE_Q4_1:
 case GGML_TYPE_Q5_0:
@@ -9368,6 +9826,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 switch (src1_type) {
 case GGML_TYPE_F32:
 case GGML_TYPE_F16:
+case GGML_TYPE_BF16:
 case GGML_TYPE_Q4_0:
 case GGML_TYPE_Q4_1:
 case GGML_TYPE_Q5_0:
@@ -9381,6 +9840,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 }
 if (src1_type == GGML_TYPE_F32) {
 switch (src0_type) {
+case GGML_TYPE_F16:
 case GGML_TYPE_Q4_0:
 case GGML_TYPE_Q4_1:
 case GGML_TYPE_Q5_0:
@@ -9419,6 +9879,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 case GGML_OP_SUB:
 case GGML_OP_MUL:
 case GGML_OP_DIV:
+return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
+(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
 case GGML_OP_SILU_BACK:
 case GGML_OP_RMS_NORM_BACK:
 case GGML_OP_SQR:
@@ -9442,6 +9905,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
 case GGML_OP_COUNT_EQUAL:
 case GGML_OP_IM2COL:
 case GGML_OP_TIMESTEP_EMBEDDING:
+case GGML_OP_CONV_2D_DW:
 case GGML_OP_POOL_2D:
 case GGML_OP_RWKV_WKV6:
 case GGML_OP_RWKV_WKV7: