@fugood/llama.node 0.3.16 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
|
@@ -24,13 +24,54 @@
|
|
|
24
24
|
#include <future>
|
|
25
25
|
#include <thread>
|
|
26
26
|
|
|
27
|
+
#if defined(_MSC_VER)
|
|
28
|
+
# define NOMINMAX 1
|
|
29
|
+
# include <windows.h>
|
|
30
|
+
# define YIELD() YieldProcessor()
|
|
31
|
+
#elif defined(__clang__) || defined(__GNUC__)
|
|
32
|
+
# if defined(__x86_64__) ||defined(__i386__)
|
|
33
|
+
# include <immintrin.h>
|
|
34
|
+
# define YIELD() _mm_pause()
|
|
35
|
+
# elif defined(__arm__) || defined(__aarch64__)
|
|
36
|
+
# if defined(__clang__)
|
|
37
|
+
# include <arm_acle.h>
|
|
38
|
+
# define YIELD() __yield()
|
|
39
|
+
# else
|
|
40
|
+
# define YIELD() asm volatile("yield")
|
|
41
|
+
# endif
|
|
42
|
+
# endif
|
|
43
|
+
#endif
|
|
44
|
+
|
|
45
|
+
#if !defined(YIELD)
|
|
46
|
+
#define YIELD()
|
|
47
|
+
#endif
|
|
48
|
+
|
|
27
49
|
#include "ggml-impl.h"
|
|
28
50
|
#include "ggml-backend-impl.h"
|
|
29
51
|
|
|
30
52
|
#include "ggml-vulkan-shaders.hpp"
|
|
31
53
|
|
|
54
|
+
// remove this once it's more widely available in the SDK
|
|
55
|
+
#if !defined(VK_KHR_shader_bfloat16)
|
|
56
|
+
|
|
57
|
+
#define VK_KHR_shader_bfloat16 1
|
|
58
|
+
#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1
|
|
59
|
+
#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16"
|
|
60
|
+
#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000)
|
|
61
|
+
#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000)
|
|
62
|
+
|
|
63
|
+
typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR {
|
|
64
|
+
VkStructureType sType;
|
|
65
|
+
void* pNext;
|
|
66
|
+
VkBool32 shaderBFloat16Type;
|
|
67
|
+
VkBool32 shaderBFloat16DotProduct;
|
|
68
|
+
VkBool32 shaderBFloat16CooperativeMatrix;
|
|
69
|
+
} VkPhysicalDeviceShaderBfloat16FeaturesKHR;
|
|
70
|
+
#endif
|
|
71
|
+
|
|
32
72
|
#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1))
|
|
33
73
|
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
|
74
|
+
static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
|
|
34
75
|
|
|
35
76
|
#define VK_VENDOR_ID_AMD 0x1002
|
|
36
77
|
#define VK_VENDOR_ID_APPLE 0x106b
|
|
@@ -223,6 +264,7 @@ struct vk_device_struct {
|
|
|
223
264
|
bool pipeline_robustness;
|
|
224
265
|
vk::Device device;
|
|
225
266
|
uint32_t vendor_id;
|
|
267
|
+
vk::DriverId driver_id;
|
|
226
268
|
vk_device_architecture architecture;
|
|
227
269
|
vk_queue compute_queue;
|
|
228
270
|
vk_queue transfer_queue;
|
|
@@ -233,6 +275,9 @@ struct vk_device_struct {
|
|
|
233
275
|
bool prefer_host_memory;
|
|
234
276
|
bool float_controls_rte_fp16;
|
|
235
277
|
bool subgroup_add;
|
|
278
|
+
bool subgroup_shuffle;
|
|
279
|
+
|
|
280
|
+
bool integer_dot_product;
|
|
236
281
|
|
|
237
282
|
bool subgroup_size_control;
|
|
238
283
|
uint32_t subgroup_min_size;
|
|
@@ -240,11 +285,21 @@ struct vk_device_struct {
|
|
|
240
285
|
bool subgroup_require_full_support;
|
|
241
286
|
|
|
242
287
|
bool coopmat_support;
|
|
243
|
-
bool coopmat_acc_f32_support;
|
|
244
|
-
bool coopmat_acc_f16_support;
|
|
288
|
+
bool coopmat_acc_f32_support {};
|
|
289
|
+
bool coopmat_acc_f16_support {};
|
|
290
|
+
bool coopmat_bf16_support {};
|
|
291
|
+
bool coopmat_support_16x16x16_f16acc {};
|
|
292
|
+
bool coopmat_support_16x16x16_f32acc {};
|
|
293
|
+
bool coopmat1_fa_support {};
|
|
245
294
|
uint32_t coopmat_m;
|
|
246
295
|
uint32_t coopmat_n;
|
|
247
296
|
uint32_t coopmat_k;
|
|
297
|
+
|
|
298
|
+
bool coopmat_int_support;
|
|
299
|
+
uint32_t coopmat_int_m;
|
|
300
|
+
uint32_t coopmat_int_n;
|
|
301
|
+
uint32_t coopmat_int_k;
|
|
302
|
+
|
|
248
303
|
bool coopmat2;
|
|
249
304
|
|
|
250
305
|
size_t idx;
|
|
@@ -261,19 +316,24 @@ struct vk_device_struct {
|
|
|
261
316
|
|
|
262
317
|
vk_matmul_pipeline pipeline_matmul_f32 {};
|
|
263
318
|
vk_matmul_pipeline pipeline_matmul_f32_f16 {};
|
|
319
|
+
vk_matmul_pipeline pipeline_matmul_bf16 {};
|
|
264
320
|
vk_matmul_pipeline2 pipeline_matmul_f16;
|
|
265
321
|
vk_matmul_pipeline2 pipeline_matmul_f16_f32;
|
|
266
|
-
vk_pipeline pipeline_matmul_split_k_reduce;
|
|
267
322
|
|
|
268
|
-
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
|
|
269
323
|
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT];
|
|
324
|
+
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT];
|
|
325
|
+
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT];
|
|
270
326
|
|
|
271
327
|
vk_matmul_pipeline pipeline_matmul_id_f32 {};
|
|
328
|
+
vk_matmul_pipeline pipeline_matmul_id_bf16 {};
|
|
272
329
|
vk_matmul_pipeline2 pipeline_matmul_id_f16;
|
|
273
330
|
vk_matmul_pipeline2 pipeline_matmul_id_f16_f32;
|
|
274
331
|
|
|
275
332
|
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
|
276
333
|
|
|
334
|
+
vk_pipeline pipeline_matmul_split_k_reduce;
|
|
335
|
+
vk_pipeline pipeline_quantize_q8_1;
|
|
336
|
+
|
|
277
337
|
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
|
|
278
338
|
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
|
279
339
|
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
|
@@ -284,11 +344,17 @@ struct vk_device_struct {
|
|
|
284
344
|
vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
|
|
285
345
|
vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
|
|
286
346
|
vk_pipeline pipeline_acc_f32;
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
vk_pipeline
|
|
290
|
-
vk_pipeline
|
|
291
|
-
vk_pipeline
|
|
347
|
+
|
|
348
|
+
// [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
|
|
349
|
+
vk_pipeline pipeline_add[2][2][2];
|
|
350
|
+
vk_pipeline pipeline_add_norepeat[2][2][2];
|
|
351
|
+
vk_pipeline pipeline_sub[2][2][2];
|
|
352
|
+
vk_pipeline pipeline_sub_norepeat[2][2][2];
|
|
353
|
+
vk_pipeline pipeline_mul[2][2][2];
|
|
354
|
+
vk_pipeline pipeline_mul_norepeat[2][2][2];
|
|
355
|
+
vk_pipeline pipeline_div[2][2][2];
|
|
356
|
+
vk_pipeline pipeline_div_norepeat[2][2][2];
|
|
357
|
+
|
|
292
358
|
vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
|
|
293
359
|
vk_pipeline pipeline_upscale_f32;
|
|
294
360
|
vk_pipeline pipeline_scale_f32;
|
|
@@ -298,8 +364,8 @@ struct vk_device_struct {
|
|
|
298
364
|
vk_pipeline pipeline_clamp_f32;
|
|
299
365
|
vk_pipeline pipeline_pad_f32;
|
|
300
366
|
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
|
|
301
|
-
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16;
|
|
302
|
-
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16;
|
|
367
|
+
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
|
|
368
|
+
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
|
|
303
369
|
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
|
|
304
370
|
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
|
|
305
371
|
vk_pipeline pipeline_norm_f32;
|
|
@@ -307,14 +373,17 @@ struct vk_device_struct {
|
|
|
307
373
|
vk_pipeline pipeline_rms_norm_f32;
|
|
308
374
|
vk_pipeline pipeline_rms_norm_back_f32;
|
|
309
375
|
vk_pipeline pipeline_l2_norm_f32;
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
vk_pipeline
|
|
313
|
-
vk_pipeline
|
|
314
|
-
vk_pipeline
|
|
376
|
+
|
|
377
|
+
// [src/dst 0=fp32,1=fp16]
|
|
378
|
+
vk_pipeline pipeline_gelu[2];
|
|
379
|
+
vk_pipeline pipeline_gelu_quick[2];
|
|
380
|
+
vk_pipeline pipeline_silu[2];
|
|
381
|
+
vk_pipeline pipeline_relu[2];
|
|
382
|
+
vk_pipeline pipeline_tanh[2];
|
|
383
|
+
vk_pipeline pipeline_sigmoid[2];
|
|
384
|
+
|
|
315
385
|
vk_pipeline pipeline_leaky_relu_f32;
|
|
316
|
-
vk_pipeline
|
|
317
|
-
vk_pipeline pipeline_sigmoid_f32;
|
|
386
|
+
vk_pipeline pipeline_silu_back_f32;
|
|
318
387
|
vk_pipeline pipeline_diag_mask_inf_f32;
|
|
319
388
|
vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
|
|
320
389
|
vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
|
|
@@ -333,8 +402,24 @@ struct vk_device_struct {
|
|
|
333
402
|
vk_pipeline pipeline_rwkv_wkv6_f32;
|
|
334
403
|
vk_pipeline pipeline_rwkv_wkv7_f32;
|
|
335
404
|
vk_pipeline pipeline_opt_step_adamw_f32;
|
|
405
|
+
vk_pipeline pipeline_conv2d_dw_whcn_f32;
|
|
406
|
+
vk_pipeline pipeline_conv2d_dw_cwhn_f32;
|
|
336
407
|
|
|
337
408
|
// [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
|
|
409
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
410
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
411
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
412
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
413
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
414
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm2[GGML_TYPE_COUNT][2][2][2];
|
|
415
|
+
|
|
416
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D64_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
417
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D80_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
418
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D96_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
419
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D112_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
420
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D128_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
421
|
+
vk_pipeline pipeline_flash_attn_f32_f16_D256_cm1[GGML_TYPE_COUNT][2][2][2];
|
|
422
|
+
|
|
338
423
|
vk_pipeline pipeline_flash_attn_f32_f16_D64[GGML_TYPE_COUNT][2][2][2];
|
|
339
424
|
vk_pipeline pipeline_flash_attn_f32_f16_D80[GGML_TYPE_COUNT][2][2][2];
|
|
340
425
|
vk_pipeline pipeline_flash_attn_f32_f16_D96[GGML_TYPE_COUNT][2][2][2];
|
|
@@ -342,6 +427,8 @@ struct vk_device_struct {
|
|
|
342
427
|
vk_pipeline pipeline_flash_attn_f32_f16_D128[GGML_TYPE_COUNT][2][2][2];
|
|
343
428
|
vk_pipeline pipeline_flash_attn_f32_f16_D256[GGML_TYPE_COUNT][2][2][2];
|
|
344
429
|
|
|
430
|
+
vk_pipeline pipeline_flash_attn_split_k_reduce;
|
|
431
|
+
|
|
345
432
|
std::unordered_map<std::string, vk_pipeline_ref> pipelines;
|
|
346
433
|
std::unordered_map<std::string, uint64_t> pipeline_descriptor_set_requirements;
|
|
347
434
|
|
|
@@ -490,6 +577,10 @@ struct vk_flash_attn_push_constants {
|
|
|
490
577
|
uint32_t n_head_log2;
|
|
491
578
|
float m0;
|
|
492
579
|
float m1;
|
|
580
|
+
|
|
581
|
+
uint32_t gqa_ratio;
|
|
582
|
+
uint32_t split_kv;
|
|
583
|
+
uint32_t k_num;
|
|
493
584
|
};
|
|
494
585
|
|
|
495
586
|
struct vk_op_push_constants {
|
|
@@ -640,13 +731,22 @@ struct vk_op_rwkv_wkv7_push_constants {
|
|
|
640
731
|
uint32_t H;
|
|
641
732
|
};
|
|
642
733
|
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
734
|
+
struct vk_op_conv2d_dw_push_constants {
|
|
735
|
+
uint32_t ne;
|
|
736
|
+
uint32_t batches;
|
|
737
|
+
uint32_t channels;
|
|
738
|
+
uint32_t dst_w;
|
|
739
|
+
uint32_t dst_h;
|
|
740
|
+
uint32_t src_w;
|
|
741
|
+
uint32_t src_h;
|
|
742
|
+
uint32_t knl_w;
|
|
743
|
+
uint32_t knl_h;
|
|
744
|
+
int32_t stride_x;
|
|
745
|
+
int32_t stride_y;
|
|
746
|
+
int32_t pad_x;
|
|
747
|
+
int32_t pad_y;
|
|
748
|
+
int32_t dilation_x;
|
|
749
|
+
int32_t dilation_y;
|
|
650
750
|
};
|
|
651
751
|
|
|
652
752
|
struct vk_op_upscale_push_constants {
|
|
@@ -656,6 +756,15 @@ struct vk_op_upscale_push_constants {
|
|
|
656
756
|
float sf0; float sf1; float sf2; float sf3;
|
|
657
757
|
};
|
|
658
758
|
|
|
759
|
+
// Allow pre-recording command buffers
|
|
760
|
+
struct vk_staging_memcpy {
|
|
761
|
+
vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {}
|
|
762
|
+
|
|
763
|
+
void * dst;
|
|
764
|
+
const void * src;
|
|
765
|
+
size_t n;
|
|
766
|
+
};
|
|
767
|
+
|
|
659
768
|
struct vk_context_struct {
|
|
660
769
|
vk_submission * s;
|
|
661
770
|
std::vector<vk_sequence> seqs;
|
|
@@ -770,7 +879,8 @@ struct ggml_backend_vk_context {
|
|
|
770
879
|
ggml_vk_garbage_collector gc;
|
|
771
880
|
size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k;
|
|
772
881
|
vk_buffer prealloc_x, prealloc_y, prealloc_split_k;
|
|
773
|
-
vk::Fence fence;
|
|
882
|
+
vk::Fence fence, almost_ready_fence;
|
|
883
|
+
bool almost_ready_fence_pending {};
|
|
774
884
|
|
|
775
885
|
vk_buffer buffer_pool[MAX_VK_BUFFERS];
|
|
776
886
|
|
|
@@ -861,6 +971,39 @@ typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
861
971
|
|
|
862
972
|
static void ggml_backend_vk_free(ggml_backend_t backend);
|
|
863
973
|
|
|
974
|
+
// Wait for ctx->fence to be signaled.
|
|
975
|
+
static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
|
|
976
|
+
// Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
|
|
977
|
+
// during this wait.
|
|
978
|
+
if (ctx->almost_ready_fence_pending) {
|
|
979
|
+
VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
|
|
980
|
+
ctx->device->device.resetFences({ ctx->almost_ready_fence });
|
|
981
|
+
ctx->almost_ready_fence_pending = false;
|
|
982
|
+
}
|
|
983
|
+
|
|
984
|
+
// Spin (w/pause) waiting for the graph to finish executing.
|
|
985
|
+
vk::Result result;
|
|
986
|
+
while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) {
|
|
987
|
+
if (result != vk::Result::eNotReady) {
|
|
988
|
+
fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__);
|
|
989
|
+
exit(1);
|
|
990
|
+
}
|
|
991
|
+
for (uint32_t i = 0; i < 100; ++i) {
|
|
992
|
+
YIELD();
|
|
993
|
+
YIELD();
|
|
994
|
+
YIELD();
|
|
995
|
+
YIELD();
|
|
996
|
+
YIELD();
|
|
997
|
+
YIELD();
|
|
998
|
+
YIELD();
|
|
999
|
+
YIELD();
|
|
1000
|
+
YIELD();
|
|
1001
|
+
YIELD();
|
|
1002
|
+
}
|
|
1003
|
+
}
|
|
1004
|
+
ctx->device->device.resetFences({ ctx->fence });
|
|
1005
|
+
}
|
|
1006
|
+
|
|
864
1007
|
// variables to track number of compiles in progress
|
|
865
1008
|
static uint32_t compile_count = 0;
|
|
866
1009
|
static std::mutex compile_count_mutex;
|
|
@@ -1455,15 +1598,56 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector<vk::Event>&& events
|
|
|
1455
1598
|
);
|
|
1456
1599
|
}
|
|
1457
1600
|
|
|
1601
|
+
enum FaCodePath {
|
|
1602
|
+
FA_SCALAR,
|
|
1603
|
+
FA_COOPMAT1,
|
|
1604
|
+
FA_COOPMAT2,
|
|
1605
|
+
};
|
|
1606
|
+
|
|
1458
1607
|
// number of rows/cols for flash attention shader
|
|
1459
1608
|
static constexpr uint32_t flash_attention_num_small_rows = 32;
|
|
1460
|
-
static
|
|
1609
|
+
static constexpr uint32_t scalar_flash_attention_num_small_rows = 1;
|
|
1610
|
+
static constexpr uint32_t scalar_flash_attention_num_large_rows = 8;
|
|
1611
|
+
|
|
1612
|
+
// The FA coopmat1 shader assumes 16x16x16 matrix multiply support.
|
|
1613
|
+
// 128 threads split into four subgroups, each subgroup does 1/4
|
|
1614
|
+
// of the Bc dimension.
|
|
1615
|
+
static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16;
|
|
1616
|
+
static constexpr uint32_t scalar_flash_attention_Bc = 64;
|
|
1617
|
+
static constexpr uint32_t scalar_flash_attention_workgroup_size = 128;
|
|
1618
|
+
|
|
1619
|
+
static uint32_t get_fa_num_small_rows(FaCodePath path) {
|
|
1620
|
+
if (path == FA_COOPMAT2) {
|
|
1621
|
+
return flash_attention_num_small_rows;
|
|
1622
|
+
} else {
|
|
1623
|
+
return scalar_flash_attention_num_small_rows;
|
|
1624
|
+
}
|
|
1625
|
+
}
|
|
1626
|
+
|
|
1627
|
+
static std::array<uint32_t, 2> fa_rows_cols(FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) {
|
|
1461
1628
|
GGML_UNUSED(clamp);
|
|
1462
1629
|
|
|
1630
|
+
if (path == FA_SCALAR) {
|
|
1631
|
+
if (small_rows) {
|
|
1632
|
+
return {scalar_flash_attention_num_small_rows, 64};
|
|
1633
|
+
} else {
|
|
1634
|
+
return {scalar_flash_attention_num_large_rows, 32};
|
|
1635
|
+
}
|
|
1636
|
+
}
|
|
1637
|
+
|
|
1638
|
+
if (path == FA_COOPMAT1) {
|
|
1639
|
+
if (small_rows) {
|
|
1640
|
+
return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc};
|
|
1641
|
+
} else {
|
|
1642
|
+
return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc};
|
|
1643
|
+
}
|
|
1644
|
+
}
|
|
1645
|
+
|
|
1463
1646
|
// small rows, large cols
|
|
1464
1647
|
if (small_rows) {
|
|
1465
|
-
return {
|
|
1648
|
+
return {get_fa_num_small_rows(FA_COOPMAT2), 32};
|
|
1466
1649
|
}
|
|
1650
|
+
|
|
1467
1651
|
// small cols to reduce register count
|
|
1468
1652
|
if (ggml_is_quantized(type) || D == 256) {
|
|
1469
1653
|
return {64, 32};
|
|
@@ -1508,7 +1692,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
|
|
|
1508
1692
|
const uint32_t warps = warptile[0] / warptile[10];
|
|
1509
1693
|
|
|
1510
1694
|
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
|
|
1511
|
-
const uint32_t mmid_row_ids = mul_mat_id ?
|
|
1695
|
+
const uint32_t mmid_row_ids = mul_mat_id ? 4096 * sizeof(uint32_t) : 0;
|
|
1512
1696
|
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
|
|
1513
1697
|
|
|
1514
1698
|
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size;
|
|
@@ -1598,6 +1782,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1598
1782
|
// mulmat
|
|
1599
1783
|
std::vector<uint32_t> l_warptile, m_warptile, s_warptile,
|
|
1600
1784
|
l_warptile_mmq, m_warptile_mmq, s_warptile_mmq,
|
|
1785
|
+
l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int,
|
|
1601
1786
|
l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k,
|
|
1602
1787
|
l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid;
|
|
1603
1788
|
std::array<uint32_t, 3> l_wg_denoms, m_wg_denoms, s_wg_denoms,
|
|
@@ -1662,6 +1847,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1662
1847
|
m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
|
|
1663
1848
|
s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };
|
|
1664
1849
|
|
|
1850
|
+
l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
|
1851
|
+
m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
|
|
1852
|
+
s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };
|
|
1853
|
+
|
|
1854
|
+
// chip specific tuning
|
|
1855
|
+
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
|
|
1856
|
+
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
|
|
1857
|
+
}
|
|
1858
|
+
|
|
1665
1859
|
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
|
|
1666
1860
|
m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 };
|
|
1667
1861
|
s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 };
|
|
@@ -1707,6 +1901,12 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1707
1901
|
if (!device->pipeline_matmul_id_f32) {
|
|
1708
1902
|
device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1709
1903
|
}
|
|
1904
|
+
if (!device->pipeline_matmul_bf16) {
|
|
1905
|
+
device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1906
|
+
}
|
|
1907
|
+
if (!device->pipeline_matmul_id_bf16) {
|
|
1908
|
+
device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
|
|
1909
|
+
}
|
|
1710
1910
|
|
|
1711
1911
|
std::vector<std::future<void>> compiles;
|
|
1712
1912
|
auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const std::string &entrypoint,
|
|
@@ -1742,63 +1942,75 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1742
1942
|
parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
|
|
1743
1943
|
};
|
|
1744
1944
|
|
|
1745
|
-
|
|
1746
|
-
|
|
1747
|
-
|
|
1748
|
-
auto const &fa_wg_denoms = [&](uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
|
1749
|
-
return {fa_rows_cols(D, clamp, type, small_rows)[0], 1, 1};
|
|
1750
|
-
};
|
|
1945
|
+
auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::array<uint32_t, 3> {
|
|
1946
|
+
return {fa_rows_cols(path, D, clamp, type, small_rows)[0], 1, 1};
|
|
1947
|
+
};
|
|
1751
1948
|
|
|
1752
|
-
|
|
1753
|
-
|
|
1754
|
-
|
|
1755
|
-
|
|
1756
|
-
|
|
1757
|
-
|
|
1758
|
-
|
|
1759
|
-
|
|
1949
|
+
auto const &fa_spec_constants = [&](FaCodePath path, uint32_t D, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector<uint32_t> {
|
|
1950
|
+
// For large number of rows, 128 invocations seems to work best.
|
|
1951
|
+
// For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we
|
|
1952
|
+
// can't use 256 for D==80.
|
|
1953
|
+
// For scalar, use 128 (arbitrary)
|
|
1954
|
+
uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1)
|
|
1955
|
+
? scalar_flash_attention_workgroup_size
|
|
1956
|
+
: ((small_rows && (D % 32) == 0) ? 256 : 128);
|
|
1957
|
+
auto rows_cols = fa_rows_cols(path, D, clamp, type, small_rows);
|
|
1958
|
+
|
|
1959
|
+
// D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it.
|
|
1960
|
+
// D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader.
|
|
1961
|
+
const uint32_t D_lsb = D ^ (D & (D-1));
|
|
1962
|
+
uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4);
|
|
1963
|
+
|
|
1964
|
+
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
|
1965
|
+
GGML_ASSERT((GGML_KQ_MASK_PAD % rows_cols[0]) == 0);
|
|
1966
|
+
return {wg_size, rows_cols[0], rows_cols[1], (D), clamp, D_split};
|
|
1967
|
+
};
|
|
1760
1968
|
|
|
1761
|
-
#define CREATE_FA2(TYPE, NAMELC, D) \
|
|
1762
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
|
|
1763
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
|
|
1764
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
|
|
1765
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
|
|
1766
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
|
|
1767
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
|
|
1768
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
|
|
1769
|
-
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC, flash_attn_f32_f16_ ## NAMELC ##
|
|
1770
|
-
|
|
1771
|
-
#define CREATE_FA(TYPE, NAMELC) \
|
|
1772
|
-
CREATE_FA2(TYPE, NAMELC, 64) \
|
|
1773
|
-
CREATE_FA2(TYPE, NAMELC, 80) \
|
|
1774
|
-
CREATE_FA2(TYPE, NAMELC, 96) \
|
|
1775
|
-
CREATE_FA2(TYPE, NAMELC, 112) \
|
|
1776
|
-
CREATE_FA2(TYPE, NAMELC, 128) \
|
|
1777
|
-
CREATE_FA2(TYPE, NAMELC, 256)
|
|
1778
|
-
|
|
1779
|
-
|
|
1780
|
-
|
|
1781
|
-
|
|
1782
|
-
|
|
1783
|
-
|
|
1784
|
-
CREATE_FA(
|
|
1785
|
-
|
|
1786
|
-
|
|
1787
|
-
|
|
1788
|
-
|
|
1789
|
-
|
|
1790
|
-
|
|
1791
|
-
|
|
1792
|
-
|
|
1793
|
-
|
|
1794
|
-
|
|
1795
|
-
|
|
1796
|
-
|
|
1797
|
-
|
|
1798
|
-
|
|
1799
|
-
|
|
1969
|
+
#define CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, D) \
|
|
1970
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][0], "flash_attn_f32_f16_D" #D "_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1971
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][0][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1972
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][0], "flash_attn_f32_f16_D" #D "_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,false), fa_spec_constants(FAPATH, D,1,TYPE,false), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1973
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][0][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,false), fa_spec_constants(FAPATH, D,0,TYPE,false), fa_rows_cols(FAPATH,D,0,TYPE,false)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1974
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][0], "flash_attn_f32_f16_D" #D "_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1975
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][0][1][1], "flash_attn_f32_f16_D" #D "_aligned_f16acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1976
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][0], "flash_attn_f32_f16_D" #D "_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,1,TYPE,true), fa_spec_constants(FAPATH, D,1,TYPE,true), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1977
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_f32_f16_D ## D ## SUFFIX[TYPE][1][1][1], "flash_attn_f32_f16_D" #D "_aligned_f32acc_smallrows" #NAMELC #SUFFIX, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 5, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, D,0,TYPE,true), fa_spec_constants(FAPATH, D,0,TYPE,true), fa_rows_cols(FAPATH,D,0,TYPE,true)[1], true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \
|
|
1978
|
+
|
|
1979
|
+
#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \
|
|
1980
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 64) \
|
|
1981
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 80) \
|
|
1982
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 96) \
|
|
1983
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 112) \
|
|
1984
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 128) \
|
|
1985
|
+
CREATE_FA2(TYPE, NAMELC, FAPATH, SUFFIX, 256)
|
|
1986
|
+
|
|
1987
|
+
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
|
|
1988
|
+
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
|
|
1989
|
+
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
|
|
1990
|
+
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
1991
|
+
if (device->coopmat1_fa_support) {
|
|
1992
|
+
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
|
|
1993
|
+
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
|
|
1994
|
+
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
|
|
1995
|
+
}
|
|
1996
|
+
#endif
|
|
1997
|
+
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
1998
|
+
if (device->coopmat2) {
|
|
1999
|
+
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
|
|
2000
|
+
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
|
|
2001
|
+
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
|
|
2002
|
+
CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2)
|
|
2003
|
+
CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2)
|
|
2004
|
+
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2)
|
|
2005
|
+
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2)
|
|
2006
|
+
}
|
|
2007
|
+
#endif
|
|
2008
|
+
#undef CREATE_FA2
|
|
1800
2009
|
#undef CREATE_FA
|
|
1801
2010
|
|
|
2011
|
+
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
|
|
2012
|
+
if (device->coopmat2) {
|
|
2013
|
+
|
|
1802
2014
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
1803
2015
|
#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
|
1804
2016
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
@@ -1814,6 +2026,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1814
2026
|
CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \
|
|
1815
2027
|
|
|
1816
2028
|
CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
|
2029
|
+
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
2030
|
+
if (device->coopmat_bf16_support) {
|
|
2031
|
+
CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
|
|
2032
|
+
}
|
|
2033
|
+
#endif
|
|
1817
2034
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
1818
2035
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
1819
2036
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
@@ -1835,6 +2052,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1835
2052
|
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
|
|
1836
2053
|
|
|
1837
2054
|
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
|
2055
|
+
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
2056
|
+
if (device->coopmat_bf16_support) {
|
|
2057
|
+
CREATE_MM(pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
|
|
2058
|
+
}
|
|
2059
|
+
#endif
|
|
1838
2060
|
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1839
2061
|
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
1840
2062
|
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
|
|
@@ -1863,17 +2085,17 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1863
2085
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
1864
2086
|
#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
1865
2087
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
1866
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ##
|
|
2088
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \
|
|
1867
2089
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
1868
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ##
|
|
2090
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \
|
|
1869
2091
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
1870
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ##
|
|
2092
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \
|
|
1871
2093
|
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
1872
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ##
|
|
2094
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \
|
|
1873
2095
|
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
1874
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ##
|
|
2096
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \
|
|
1875
2097
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
1876
|
-
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ##
|
|
2098
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \
|
|
1877
2099
|
|
|
1878
2100
|
// Create 2 variants, {f16,f32} accumulator
|
|
1879
2101
|
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
@@ -1888,6 +2110,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1888
2110
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1889
2111
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
1890
2112
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2113
|
+
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
2114
|
+
if (device->coopmat_bf16_support) {
|
|
2115
|
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, )
|
|
2116
|
+
}
|
|
2117
|
+
#endif
|
|
1891
2118
|
|
|
1892
2119
|
if (device->coopmat_acc_f16_support) {
|
|
1893
2120
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
@@ -1936,6 +2163,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
1936
2163
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1937
2164
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
1938
2165
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
2166
|
+
#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
2167
|
+
if (device->coopmat_bf16_support) {
|
|
2168
|
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
2169
|
+
}
|
|
2170
|
+
#endif
|
|
1939
2171
|
|
|
1940
2172
|
if (device->coopmat_acc_f16_support) {
|
|
1941
2173
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
@@ -2000,6 +2232,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2000
2232
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
2001
2233
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
|
2002
2234
|
|
|
2235
|
+
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
2236
|
+
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
2237
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
2238
|
+
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
2239
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
2240
|
+
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
2241
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
2242
|
+
|
|
2003
2243
|
// Create 2 variants, {f16,f32} accumulator
|
|
2004
2244
|
#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
2005
2245
|
CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
@@ -2010,6 +2250,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2010
2250
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2011
2251
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2012
2252
|
|
|
2253
|
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2254
|
+
|
|
2013
2255
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2014
2256
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2015
2257
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
@@ -2031,10 +2273,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2031
2273
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2032
2274
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2033
2275
|
|
|
2276
|
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
2277
|
+
if (device->integer_dot_product) {
|
|
2278
|
+
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2279
|
+
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2280
|
+
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2281
|
+
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f16acc, matmul_q5_1_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2282
|
+
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f16acc, matmul_q8_0_q8_1, _f16acc, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2283
|
+
}
|
|
2284
|
+
#endif
|
|
2285
|
+
|
|
2034
2286
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
2035
2287
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
2036
2288
|
CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
2037
2289
|
|
|
2290
|
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
|
2291
|
+
|
|
2038
2292
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
2039
2293
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
2040
2294
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
@@ -2056,6 +2310,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2056
2310
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
2057
2311
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
2058
2312
|
#undef CREATE_MM2
|
|
2313
|
+
#undef CREATE_MMQ
|
|
2059
2314
|
#undef CREATE_MM
|
|
2060
2315
|
} else {
|
|
2061
2316
|
// Create 6 variants, {s,m,l}x{unaligned,aligned}
|
|
@@ -2073,11 +2328,21 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2073
2328
|
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
2074
2329
|
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \
|
|
2075
2330
|
|
|
2331
|
+
#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \
|
|
2332
|
+
if (device->mul_mat ## ID ## _l[TYPE]) \
|
|
2333
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \
|
|
2334
|
+
if (device->mul_mat ## ID ## _m[TYPE]) \
|
|
2335
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \
|
|
2336
|
+
if (device->mul_mat ## ID ## _s[TYPE]) \
|
|
2337
|
+
ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \
|
|
2338
|
+
|
|
2076
2339
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2077
2340
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2078
2341
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2079
2342
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2080
2343
|
|
|
2344
|
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2345
|
+
|
|
2081
2346
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2082
2347
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2083
2348
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
@@ -2099,10 +2364,22 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2099
2364
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2100
2365
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
|
|
2101
2366
|
|
|
2367
|
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
2368
|
+
if (device->integer_dot_product) {
|
|
2369
|
+
CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2370
|
+
CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2371
|
+
CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2372
|
+
CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2373
|
+
CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, , mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, );
|
|
2374
|
+
}
|
|
2375
|
+
#endif
|
|
2376
|
+
|
|
2102
2377
|
CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
2103
2378
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
2104
2379
|
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id);
|
|
2105
2380
|
|
|
2381
|
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
|
2382
|
+
|
|
2106
2383
|
CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
2107
2384
|
CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
2108
2385
|
CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
@@ -2123,8 +2400,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2123
2400
|
CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
2124
2401
|
CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
2125
2402
|
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
|
|
2126
|
-
#undef CREATE_MM
|
|
2127
2403
|
}
|
|
2404
|
+
// reusing CREATE_MM from the fp32 path
|
|
2405
|
+
if ((device->coopmat2 || device->coopmat_support)
|
|
2406
|
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
2407
|
+
&& !device->coopmat_bf16_support
|
|
2408
|
+
#endif
|
|
2409
|
+
) {
|
|
2410
|
+
// use scalar tile sizes
|
|
2411
|
+
l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
|
|
2412
|
+
m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
|
|
2413
|
+
s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
|
|
2414
|
+
|
|
2415
|
+
l_wg_denoms = {128, 128, 1 };
|
|
2416
|
+
m_wg_denoms = { 64, 64, 1 };
|
|
2417
|
+
s_wg_denoms = { 32, 32, 1 };
|
|
2418
|
+
|
|
2419
|
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, );
|
|
2420
|
+
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id);
|
|
2421
|
+
}
|
|
2422
|
+
#undef CREATE_MM
|
|
2128
2423
|
|
|
2129
2424
|
// mul mat vec
|
|
2130
2425
|
|
|
@@ -2132,7 +2427,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2132
2427
|
uint32_t rm_stdq = 1;
|
|
2133
2428
|
uint32_t rm_kq = 2;
|
|
2134
2429
|
if (device->vendor_id == VK_VENDOR_ID_AMD) {
|
|
2135
|
-
if (device->
|
|
2430
|
+
if (device->architecture == AMD_GCN) {
|
|
2136
2431
|
rm_stdq = 2;
|
|
2137
2432
|
rm_kq = 4;
|
|
2138
2433
|
}
|
|
@@ -2143,6 +2438,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2143
2438
|
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
|
2144
2439
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
|
2145
2440
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
|
2441
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f32_f32_len, mul_mat_vec_bf16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
|
2146
2442
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
|
2147
2443
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
|
2148
2444
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
|
@@ -2165,6 +2461,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2165
2461
|
|
|
2166
2462
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
|
2167
2463
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
|
2464
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32_"+std::to_string(i+1), mul_mat_vec_bf16_f16_f32_len, mul_mat_vec_bf16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
|
2168
2465
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
|
2169
2466
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
|
2170
2467
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
|
@@ -2188,6 +2485,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2188
2485
|
|
|
2189
2486
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
2190
2487
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
2488
|
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
|
2191
2489
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
|
2192
2490
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
|
2193
2491
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
|
@@ -2233,6 +2531,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2233
2531
|
// get_rows
|
|
2234
2532
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
2235
2533
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
2534
|
+
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
2236
2535
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
2237
2536
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
2238
2537
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
@@ -2250,6 +2549,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2250
2549
|
|
|
2251
2550
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
2252
2551
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
2552
|
+
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
|
|
2253
2553
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
2254
2554
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
2255
2555
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
@@ -2266,6 +2566,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2266
2566
|
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
|
|
2267
2567
|
|
|
2268
2568
|
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
|
|
2569
|
+
ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 2, 3 * sizeof(uint32_t), {1, 1, 1}, {}, 1, true);
|
|
2570
|
+
ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1);
|
|
2269
2571
|
|
|
2270
2572
|
for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) {
|
|
2271
2573
|
if (device->subgroup_add && device->subgroup_require_full_support) {
|
|
@@ -2274,21 +2576,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2274
2576
|
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true);
|
|
2275
2577
|
}
|
|
2276
2578
|
}
|
|
2277
|
-
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3,
|
|
2579
|
+
ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 9 * sizeof(uint32_t), {1, 1, 1}, {}, 1);
|
|
2278
2580
|
|
|
2279
2581
|
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2280
2582
|
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2281
|
-
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(
|
|
2583
|
+
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
|
|
2282
2584
|
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2283
2585
|
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
|
|
2284
2586
|
|
|
2285
2587
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2286
2588
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2287
2589
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2590
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2591
|
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2288
2592
|
|
|
2289
2593
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2290
2594
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2291
2595
|
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2596
|
+
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2597
|
+
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2598
|
+
|
|
2292
2599
|
if (device->float_controls_rte_fp16) {
|
|
2293
2600
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
|
|
2294
2601
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
|
|
@@ -2312,19 +2619,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2312
2619
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
|
|
2313
2620
|
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
|
|
2314
2621
|
|
|
2315
|
-
|
|
2316
|
-
|
|
2317
|
-
|
|
2318
|
-
|
|
2622
|
+
auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
|
|
2623
|
+
std::string s;
|
|
2624
|
+
s += std::string(src0_f16 ? "_f16" : "_f32");
|
|
2625
|
+
s += std::string(src1_f16 ? "_f16" : "_f32");
|
|
2626
|
+
s += std::string(dst_f16 ? "_f16" : "_f32");
|
|
2627
|
+
return s;
|
|
2628
|
+
};
|
|
2319
2629
|
|
|
2320
|
-
|
|
2630
|
+
#define CREATE_BINARY(name, namemod, spec) \
|
|
2631
|
+
for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
|
|
2632
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
|
|
2633
|
+
#name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
|
|
2634
|
+
"main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
|
|
2635
|
+
|
|
2636
|
+
CREATE_BINARY(add, , {0})
|
|
2637
|
+
CREATE_BINARY(add, _norepeat, {1})
|
|
2638
|
+
CREATE_BINARY(sub, , {0})
|
|
2639
|
+
CREATE_BINARY(sub, _norepeat, {1})
|
|
2640
|
+
CREATE_BINARY(mul, , {0})
|
|
2641
|
+
CREATE_BINARY(mul, _norepeat, {1})
|
|
2642
|
+
CREATE_BINARY(div, , {0})
|
|
2643
|
+
CREATE_BINARY(div, _norepeat, {1})
|
|
2644
|
+
#undef CREATE_BINARY
|
|
2321
2645
|
|
|
2322
|
-
ggml_vk_create_pipeline(device, device->
|
|
2323
|
-
ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
|
|
2324
|
-
ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
|
|
2325
|
-
ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
|
|
2326
|
-
ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
|
|
2327
|
-
ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
|
|
2646
|
+
ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
2328
2647
|
|
|
2329
2648
|
ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
2330
2649
|
ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
|
|
@@ -2345,14 +2664,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2345
2664
|
ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2346
2665
|
ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
|
|
2347
2666
|
|
|
2348
|
-
|
|
2349
|
-
ggml_vk_create_pipeline(device, device->
|
|
2350
|
-
ggml_vk_create_pipeline(device, device->
|
|
2351
|
-
|
|
2352
|
-
|
|
2667
|
+
#define CREATE_UNARY(name) \
|
|
2668
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \
|
|
2669
|
+
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2670
|
+
|
|
2671
|
+
CREATE_UNARY(gelu)
|
|
2672
|
+
CREATE_UNARY(gelu_quick)
|
|
2673
|
+
CREATE_UNARY(silu)
|
|
2674
|
+
CREATE_UNARY(relu)
|
|
2675
|
+
CREATE_UNARY(tanh)
|
|
2676
|
+
CREATE_UNARY(sigmoid)
|
|
2677
|
+
#undef CREATE_UNARY
|
|
2678
|
+
|
|
2353
2679
|
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2354
|
-
ggml_vk_create_pipeline(device, device->
|
|
2355
|
-
ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2680
|
+
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2356
2681
|
|
|
2357
2682
|
ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
|
|
2358
2683
|
|
|
@@ -2404,6 +2729,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
|
|
2404
2729
|
|
|
2405
2730
|
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
|
2406
2731
|
|
|
2732
|
+
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
2733
|
+
ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
|
|
2734
|
+
|
|
2407
2735
|
for (auto &c : compiles) {
|
|
2408
2736
|
c.wait();
|
|
2409
2737
|
}
|
|
@@ -2452,6 +2780,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2452
2780
|
bool pipeline_robustness = false;
|
|
2453
2781
|
bool coopmat2_support = false;
|
|
2454
2782
|
device->coopmat_support = false;
|
|
2783
|
+
device->integer_dot_product = false;
|
|
2784
|
+
bool bfloat16_support = false;
|
|
2455
2785
|
|
|
2456
2786
|
for (const auto& properties : ext_props) {
|
|
2457
2787
|
if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) {
|
|
@@ -2477,6 +2807,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2477
2807
|
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
|
2478
2808
|
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
|
2479
2809
|
coopmat2_support = true;
|
|
2810
|
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
2811
|
+
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
|
2812
|
+
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
|
2813
|
+
device->integer_dot_product = true;
|
|
2814
|
+
#endif
|
|
2815
|
+
} else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 &&
|
|
2816
|
+
!getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
|
2817
|
+
bfloat16_support = true;
|
|
2480
2818
|
}
|
|
2481
2819
|
}
|
|
2482
2820
|
|
|
@@ -2490,6 +2828,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2490
2828
|
vk::PhysicalDeviceVulkan11Properties vk11_props;
|
|
2491
2829
|
vk::PhysicalDeviceVulkan12Properties vk12_props;
|
|
2492
2830
|
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
|
|
2831
|
+
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
|
|
2493
2832
|
|
|
2494
2833
|
props2.pNext = &props3;
|
|
2495
2834
|
props3.pNext = &subgroup_props;
|
|
@@ -2524,9 +2863,15 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2524
2863
|
}
|
|
2525
2864
|
#endif
|
|
2526
2865
|
|
|
2866
|
+
if (device->integer_dot_product) {
|
|
2867
|
+
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
|
2868
|
+
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
|
2869
|
+
}
|
|
2870
|
+
|
|
2527
2871
|
device->physical_device.getProperties2(&props2);
|
|
2528
2872
|
device->properties = props2.properties;
|
|
2529
2873
|
device->vendor_id = device->properties.vendorID;
|
|
2874
|
+
device->driver_id = driver_props.driverID;
|
|
2530
2875
|
|
|
2531
2876
|
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
|
|
2532
2877
|
|
|
@@ -2562,6 +2907,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2562
2907
|
device->subgroup_add = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
2563
2908
|
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
|
|
2564
2909
|
|
|
2910
|
+
device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
|
|
2911
|
+
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
|
|
2912
|
+
|
|
2565
2913
|
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
|
|
2566
2914
|
|
|
2567
2915
|
device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
|
@@ -2570,6 +2918,8 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2570
2918
|
device->coopmat_support = false;
|
|
2571
2919
|
}
|
|
2572
2920
|
|
|
2921
|
+
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
|
|
2922
|
+
|
|
2573
2923
|
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
|
|
2574
2924
|
|
|
2575
2925
|
// Try to find a non-graphics compute queue and transfer-focused queues
|
|
@@ -2654,6 +3004,17 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2654
3004
|
}
|
|
2655
3005
|
#endif
|
|
2656
3006
|
|
|
3007
|
+
#if defined(VK_KHR_shader_bfloat16)
|
|
3008
|
+
VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {};
|
|
3009
|
+
bfloat16_features.pNext = nullptr;
|
|
3010
|
+
bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR;
|
|
3011
|
+
if (bfloat16_support) {
|
|
3012
|
+
last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features;
|
|
3013
|
+
last_struct = (VkBaseOutStructure *)&bfloat16_features;
|
|
3014
|
+
device_extensions.push_back("VK_KHR_shader_bfloat16");
|
|
3015
|
+
}
|
|
3016
|
+
#endif
|
|
3017
|
+
|
|
2657
3018
|
VkPhysicalDeviceMaintenance4Features maint4_features {};
|
|
2658
3019
|
maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
|
|
2659
3020
|
if (maintenance4_support) {
|
|
@@ -2662,6 +3023,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2662
3023
|
device_extensions.push_back("VK_KHR_maintenance4");
|
|
2663
3024
|
}
|
|
2664
3025
|
|
|
3026
|
+
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
|
|
3027
|
+
shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
|
|
3028
|
+
if (device->integer_dot_product) {
|
|
3029
|
+
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
|
3030
|
+
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
|
3031
|
+
device_extensions.push_back("VK_KHR_shader_integer_dot_product");
|
|
3032
|
+
}
|
|
3033
|
+
|
|
2665
3034
|
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
|
|
2666
3035
|
|
|
2667
3036
|
device->fp16 = device->fp16 && vk12_features.shaderFloat16;
|
|
@@ -2684,6 +3053,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2684
3053
|
|
|
2685
3054
|
#if defined(VK_KHR_cooperative_matrix)
|
|
2686
3055
|
device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix;
|
|
3056
|
+
|
|
3057
|
+
// coopmat1 fa shader currently assumes 32 invocations per subgroup
|
|
3058
|
+
device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support &&
|
|
3059
|
+
device->subgroup_size_control && device->subgroup_min_size <= 32 &&
|
|
3060
|
+
device->subgroup_max_size >= 32;
|
|
2687
3061
|
#endif
|
|
2688
3062
|
|
|
2689
3063
|
if (coopmat2_support) {
|
|
@@ -2818,6 +3192,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2818
3192
|
// Only enable if shape is identical
|
|
2819
3193
|
device->coopmat_acc_f32_support = true;
|
|
2820
3194
|
}
|
|
3195
|
+
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
|
|
3196
|
+
device->coopmat_support_16x16x16_f32acc = true;
|
|
3197
|
+
}
|
|
2821
3198
|
} else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 &&
|
|
2822
3199
|
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) {
|
|
2823
3200
|
// coopmat sizes not set yet
|
|
@@ -2830,8 +3207,41 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2830
3207
|
// Only enable if shape is identical
|
|
2831
3208
|
device->coopmat_acc_f16_support = true;
|
|
2832
3209
|
}
|
|
3210
|
+
if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) {
|
|
3211
|
+
device->coopmat_support_16x16x16_f16acc = true;
|
|
3212
|
+
}
|
|
3213
|
+
}
|
|
3214
|
+
} else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 &&
|
|
3215
|
+
(vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 &&
|
|
3216
|
+
(vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 &&
|
|
3217
|
+
(vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eSint32 &&
|
|
3218
|
+
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup &&
|
|
3219
|
+
device->coopmat_int_m == 0
|
|
3220
|
+
) {
|
|
3221
|
+
device->coopmat_int_support = true;
|
|
3222
|
+
device->coopmat_int_m = prop.MSize;
|
|
3223
|
+
device->coopmat_int_n = prop.NSize;
|
|
3224
|
+
device->coopmat_int_k = prop.KSize;
|
|
3225
|
+
}
|
|
3226
|
+
#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT)
|
|
3227
|
+
if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
|
3228
|
+
prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR &&
|
|
3229
|
+
prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
|
3230
|
+
prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR &&
|
|
3231
|
+
(vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup
|
|
3232
|
+
) {
|
|
3233
|
+
// coopmat sizes not set yet
|
|
3234
|
+
if (device->coopmat_m == 0) {
|
|
3235
|
+
device->coopmat_bf16_support = true;
|
|
3236
|
+
device->coopmat_m = prop.MSize;
|
|
3237
|
+
device->coopmat_n = prop.NSize;
|
|
3238
|
+
device->coopmat_k = prop.KSize;
|
|
3239
|
+
} else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) {
|
|
3240
|
+
// Only enable if shape is identical
|
|
3241
|
+
device->coopmat_bf16_support = true;
|
|
2833
3242
|
}
|
|
2834
3243
|
}
|
|
3244
|
+
#endif
|
|
2835
3245
|
}
|
|
2836
3246
|
|
|
2837
3247
|
if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) {
|
|
@@ -2839,11 +3249,19 @@ static vk_device ggml_vk_get_device(size_t idx) {
|
|
|
2839
3249
|
GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n");
|
|
2840
3250
|
device->coopmat_support = false;
|
|
2841
3251
|
}
|
|
3252
|
+
if (getenv("GGML_VK_DISABLE_BFLOAT16")) {
|
|
3253
|
+
device->coopmat_bf16_support = false;
|
|
3254
|
+
}
|
|
2842
3255
|
}
|
|
2843
3256
|
|
|
2844
3257
|
if (device->coopmat_support) {
|
|
2845
3258
|
device_extensions.push_back("VK_KHR_cooperative_matrix");
|
|
2846
3259
|
}
|
|
3260
|
+
#if defined(VK_KHR_shader_bfloat16)
|
|
3261
|
+
if (device->coopmat_bf16_support) {
|
|
3262
|
+
device_extensions.push_back("VK_KHR_shader_bfloat16");
|
|
3263
|
+
}
|
|
3264
|
+
#endif
|
|
2847
3265
|
#endif
|
|
2848
3266
|
device->name = GGML_VK_NAME + std::to_string(idx);
|
|
2849
3267
|
|
|
@@ -2935,25 +3353,11 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
2935
3353
|
vk::PhysicalDevice physical_device = devices[dev_num];
|
|
2936
3354
|
std::vector<vk::ExtensionProperties> ext_props = physical_device.enumerateDeviceExtensionProperties();
|
|
2937
3355
|
|
|
2938
|
-
vk::PhysicalDeviceProperties2 props2;
|
|
2939
|
-
vk::PhysicalDeviceMaintenance3Properties props3;
|
|
2940
|
-
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
|
2941
|
-
vk::PhysicalDeviceDriverProperties driver_props;
|
|
2942
|
-
props2.pNext = &props3;
|
|
2943
|
-
props3.pNext = &subgroup_props;
|
|
2944
|
-
subgroup_props.pNext = &driver_props;
|
|
2945
|
-
physical_device.getProperties2(&props2);
|
|
2946
|
-
|
|
2947
|
-
vk_device_architecture arch = get_device_architecture(physical_device);
|
|
2948
|
-
uint32_t default_subgroup_size = get_subgroup_size("", arch);
|
|
2949
|
-
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
|
|
2950
|
-
|
|
2951
|
-
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
|
2952
|
-
|
|
2953
3356
|
bool fp16_storage = false;
|
|
2954
3357
|
bool fp16_compute = false;
|
|
2955
3358
|
bool coopmat_support = false;
|
|
2956
3359
|
bool coopmat2_support = false;
|
|
3360
|
+
bool integer_dot_product = false;
|
|
2957
3361
|
|
|
2958
3362
|
for (auto properties : ext_props) {
|
|
2959
3363
|
if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) {
|
|
@@ -2969,27 +3373,44 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
2969
3373
|
} else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 &&
|
|
2970
3374
|
!getenv("GGML_VK_DISABLE_COOPMAT2")) {
|
|
2971
3375
|
coopmat2_support = true;
|
|
3376
|
+
#endif
|
|
3377
|
+
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
|
|
3378
|
+
} else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 &&
|
|
3379
|
+
!getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) {
|
|
3380
|
+
integer_dot_product = true;
|
|
2972
3381
|
#endif
|
|
2973
3382
|
}
|
|
2974
3383
|
}
|
|
2975
3384
|
|
|
2976
3385
|
const vk_device_architecture device_architecture = get_device_architecture(physical_device);
|
|
2977
3386
|
|
|
2978
|
-
if (!ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture)) {
|
|
2979
|
-
coopmat_support = false;
|
|
2980
|
-
}
|
|
2981
|
-
|
|
2982
3387
|
const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16");
|
|
2983
3388
|
bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr;
|
|
2984
3389
|
|
|
2985
3390
|
bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
|
|
2986
3391
|
|
|
2987
|
-
vk::
|
|
3392
|
+
vk::PhysicalDeviceProperties2 props2;
|
|
3393
|
+
vk::PhysicalDeviceMaintenance3Properties props3;
|
|
3394
|
+
vk::PhysicalDeviceSubgroupProperties subgroup_props;
|
|
3395
|
+
vk::PhysicalDeviceDriverProperties driver_props;
|
|
3396
|
+
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
|
|
3397
|
+
props2.pNext = &props3;
|
|
3398
|
+
props3.pNext = &subgroup_props;
|
|
3399
|
+
subgroup_props.pNext = &driver_props;
|
|
3400
|
+
|
|
3401
|
+
// Pointer to the last chain element
|
|
3402
|
+
VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props;
|
|
3403
|
+
|
|
3404
|
+
if (integer_dot_product) {
|
|
3405
|
+
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
|
3406
|
+
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
|
|
3407
|
+
}
|
|
3408
|
+
|
|
3409
|
+
physical_device.getProperties2(&props2);
|
|
2988
3410
|
|
|
2989
3411
|
VkPhysicalDeviceFeatures2 device_features2;
|
|
2990
3412
|
device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;
|
|
2991
3413
|
device_features2.pNext = nullptr;
|
|
2992
|
-
device_features2.features = (VkPhysicalDeviceFeatures)device_features;
|
|
2993
3414
|
|
|
2994
3415
|
VkPhysicalDeviceVulkan11Features vk11_features;
|
|
2995
3416
|
vk11_features.pNext = nullptr;
|
|
@@ -3002,7 +3423,7 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
3002
3423
|
vk11_features.pNext = &vk12_features;
|
|
3003
3424
|
|
|
3004
3425
|
// Pointer to the last chain element
|
|
3005
|
-
|
|
3426
|
+
last_struct = (VkBaseOutStructure *)&vk12_features;
|
|
3006
3427
|
|
|
3007
3428
|
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
3008
3429
|
VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features;
|
|
@@ -3014,20 +3435,39 @@ static void ggml_vk_print_gpu_info(size_t idx) {
|
|
|
3014
3435
|
last_struct->pNext = (VkBaseOutStructure *)&coopmat_features;
|
|
3015
3436
|
last_struct = (VkBaseOutStructure *)&coopmat_features;
|
|
3016
3437
|
}
|
|
3438
|
+
#endif
|
|
3439
|
+
|
|
3440
|
+
VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {};
|
|
3441
|
+
shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR;
|
|
3442
|
+
if (integer_dot_product) {
|
|
3443
|
+
last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
|
3444
|
+
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features;
|
|
3445
|
+
}
|
|
3017
3446
|
|
|
3018
3447
|
vkGetPhysicalDeviceFeatures2(physical_device, &device_features2);
|
|
3019
3448
|
|
|
3020
3449
|
fp16 = fp16 && vk12_features.shaderFloat16;
|
|
3021
3450
|
|
|
3022
|
-
|
|
3451
|
+
uint32_t default_subgroup_size = get_subgroup_size("", device_architecture);
|
|
3452
|
+
const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize;
|
|
3453
|
+
const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
|
|
3454
|
+
|
|
3455
|
+
integer_dot_product = integer_dot_product
|
|
3456
|
+
&& shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated
|
|
3457
|
+
&& shader_integer_dot_product_features.shaderIntegerDotProduct;
|
|
3458
|
+
|
|
3459
|
+
coopmat_support = coopmat_support
|
|
3460
|
+
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
|
|
3461
|
+
&& coopmat_features.cooperativeMatrix
|
|
3023
3462
|
#endif
|
|
3463
|
+
&& ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture);
|
|
3024
3464
|
|
|
3025
3465
|
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
|
|
3026
3466
|
|
|
3027
3467
|
std::string device_name = props2.properties.deviceName.data();
|
|
3028
|
-
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
|
|
3468
|
+
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n",
|
|
3029
3469
|
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
|
|
3030
|
-
props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
|
|
3470
|
+
props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str());
|
|
3031
3471
|
|
|
3032
3472
|
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
|
|
3033
3473
|
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
|
|
@@ -3229,6 +3669,7 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
|
|
|
3229
3669
|
ctx->prealloc_size_split_k = 0;
|
|
3230
3670
|
|
|
3231
3671
|
ctx->fence = ctx->device->device.createFence({});
|
|
3672
|
+
ctx->almost_ready_fence = ctx->device->device.createFence({});
|
|
3232
3673
|
|
|
3233
3674
|
#ifdef GGML_VULKAN_CHECK_RESULTS
|
|
3234
3675
|
const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS");
|
|
@@ -3277,6 +3718,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
|
3277
3718
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
|
|
3278
3719
|
return ctx->device->pipeline_matmul_f32_f16;
|
|
3279
3720
|
}
|
|
3721
|
+
if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
|
|
3722
|
+
return ctx->device->pipeline_matmul_bf16;
|
|
3723
|
+
}
|
|
3280
3724
|
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
|
3281
3725
|
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
|
3282
3726
|
return ctx->device->pipeline_matmul_f16_f32.f16acc;
|
|
@@ -3293,6 +3737,17 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
|
|
3293
3737
|
}
|
|
3294
3738
|
}
|
|
3295
3739
|
|
|
3740
|
+
// MMQ
|
|
3741
|
+
if (src1_type == GGML_TYPE_Q8_1) {
|
|
3742
|
+
vk_matmul_pipeline pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc;
|
|
3743
|
+
|
|
3744
|
+
if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) {
|
|
3745
|
+
return nullptr;
|
|
3746
|
+
}
|
|
3747
|
+
|
|
3748
|
+
return pipelines;
|
|
3749
|
+
}
|
|
3750
|
+
|
|
3296
3751
|
if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) {
|
|
3297
3752
|
return nullptr;
|
|
3298
3753
|
}
|
|
@@ -3337,6 +3792,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
|
|
3337
3792
|
switch (a_type) {
|
|
3338
3793
|
case GGML_TYPE_F32:
|
|
3339
3794
|
case GGML_TYPE_F16:
|
|
3795
|
+
case GGML_TYPE_BF16:
|
|
3340
3796
|
case GGML_TYPE_Q4_0:
|
|
3341
3797
|
case GGML_TYPE_Q4_1:
|
|
3342
3798
|
case GGML_TYPE_Q5_0:
|
|
@@ -3369,6 +3825,9 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
|
|
|
3369
3825
|
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
|
|
3370
3826
|
return ctx->device->pipeline_matmul_id_f32;
|
|
3371
3827
|
}
|
|
3828
|
+
if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) {
|
|
3829
|
+
return ctx->device->pipeline_matmul_id_bf16;
|
|
3830
|
+
}
|
|
3372
3831
|
if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) {
|
|
3373
3832
|
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) {
|
|
3374
3833
|
return ctx->device->pipeline_matmul_id_f16_f32.f16acc;
|
|
@@ -3422,6 +3881,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
|
|
|
3422
3881
|
switch (a_type) {
|
|
3423
3882
|
case GGML_TYPE_F32:
|
|
3424
3883
|
case GGML_TYPE_F16:
|
|
3884
|
+
case GGML_TYPE_BF16:
|
|
3425
3885
|
case GGML_TYPE_Q4_0:
|
|
3426
3886
|
case GGML_TYPE_Q4_1:
|
|
3427
3887
|
case GGML_TYPE_Q5_0:
|
|
@@ -3585,8 +4045,6 @@ static vk_submission ggml_vk_begin_submission(vk_device& device, vk_queue& q, bo
|
|
|
3585
4045
|
return s;
|
|
3586
4046
|
}
|
|
3587
4047
|
|
|
3588
|
-
|
|
3589
|
-
|
|
3590
4048
|
static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, size_t push_constant_size, const void* push_constants, std::array<uint32_t, 3> elements) {
|
|
3591
4049
|
const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]);
|
|
3592
4050
|
const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]);
|
|
@@ -4010,14 +4468,20 @@ static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, int m, int
|
|
|
4010
4468
|
if (split_k == 3) {
|
|
4011
4469
|
split_k = 2;
|
|
4012
4470
|
}
|
|
4471
|
+
if (ctx->device->coopmat2) {
|
|
4472
|
+
// coopmat2 shader expects splits to be aligned to 256
|
|
4473
|
+
while (split_k > 1 && ((k / split_k) % 256) != 0) {
|
|
4474
|
+
split_k /= 2;
|
|
4475
|
+
}
|
|
4476
|
+
}
|
|
4013
4477
|
}
|
|
4014
4478
|
}
|
|
4015
4479
|
|
|
4016
4480
|
return split_k;
|
|
4017
4481
|
}
|
|
4018
4482
|
|
|
4019
|
-
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp,
|
|
4020
|
-
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
|
|
4483
|
+
static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) {
|
|
4484
|
+
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
|
4021
4485
|
|
|
4022
4486
|
if (ctx->device->coopmat2) {
|
|
4023
4487
|
// Use large shader when the N dimension is greater than the medium shader's tile size
|
|
@@ -4042,9 +4506,9 @@ static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx,
|
|
|
4042
4506
|
return aligned ? mmp->a_l : mmp->l;
|
|
4043
4507
|
}
|
|
4044
4508
|
|
|
4045
|
-
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) {
|
|
4046
|
-
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")");
|
|
4047
|
-
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type)->align;
|
|
4509
|
+
static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) {
|
|
4510
|
+
VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")");
|
|
4511
|
+
return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align;
|
|
4048
4512
|
}
|
|
4049
4513
|
|
|
4050
4514
|
static void ggml_vk_matmul(
|
|
@@ -4054,7 +4518,7 @@ static void ggml_vk_matmul(
|
|
|
4054
4518
|
uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d,
|
|
4055
4519
|
uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3,
|
|
4056
4520
|
uint32_t padded_n) {
|
|
4057
|
-
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ")");
|
|
4521
|
+
VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")");
|
|
4058
4522
|
ggml_vk_sync_buffers(subctx);
|
|
4059
4523
|
if (split_k == 1) {
|
|
4060
4524
|
const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n };
|
|
@@ -4072,7 +4536,7 @@ static void ggml_vk_matmul(
|
|
|
4072
4536
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2.size() * sizeof(uint32_t), pc2.data(), { m * n * batch, 1, 1 });
|
|
4073
4537
|
}
|
|
4074
4538
|
|
|
4075
|
-
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp,
|
|
4539
|
+
static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) {
|
|
4076
4540
|
VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")");
|
|
4077
4541
|
|
|
4078
4542
|
if (ctx->device->coopmat2) {
|
|
@@ -4153,6 +4617,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
4153
4617
|
return ctx->device->pipeline_cpy_f16_f16;
|
|
4154
4618
|
}
|
|
4155
4619
|
}
|
|
4620
|
+
if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
|
|
4621
|
+
if (contig) {
|
|
4622
|
+
return ctx->device->pipeline_contig_cpy_f16_f32;
|
|
4623
|
+
} else {
|
|
4624
|
+
return ctx->device->pipeline_cpy_f16_f32;
|
|
4625
|
+
}
|
|
4626
|
+
}
|
|
4627
|
+
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
|
|
4628
|
+
if (contig) {
|
|
4629
|
+
return ctx->device->pipeline_contig_cpy_f32_bf16;
|
|
4630
|
+
} else {
|
|
4631
|
+
return ctx->device->pipeline_cpy_f32_bf16;
|
|
4632
|
+
}
|
|
4633
|
+
}
|
|
4156
4634
|
if (src->type == GGML_TYPE_F32) {
|
|
4157
4635
|
switch (to) {
|
|
4158
4636
|
case GGML_TYPE_Q4_0:
|
|
@@ -4214,6 +4692,25 @@ static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context&
|
|
|
4214
4692
|
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(vk_op_unary_push_constants), &pc, elements);
|
|
4215
4693
|
}
|
|
4216
4694
|
|
|
4695
|
+
static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type) {
|
|
4696
|
+
switch(type) {
|
|
4697
|
+
case GGML_TYPE_Q8_1:
|
|
4698
|
+
return ctx->device->pipeline_quantize_q8_1;
|
|
4699
|
+
default:
|
|
4700
|
+
std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl;
|
|
4701
|
+
GGML_ABORT("fatal error");
|
|
4702
|
+
}
|
|
4703
|
+
}
|
|
4704
|
+
|
|
4705
|
+
static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne) {
|
|
4706
|
+
VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")");
|
|
4707
|
+
|
|
4708
|
+
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
|
|
4709
|
+
|
|
4710
|
+
ggml_vk_sync_buffers(subctx);
|
|
4711
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, sizeof(uint32_t), &ne, { ne, 1, 1 });
|
|
4712
|
+
}
|
|
4713
|
+
|
|
4217
4714
|
static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
4218
4715
|
VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
|
4219
4716
|
std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
|
|
@@ -4261,30 +4758,43 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4261
4758
|
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
|
4262
4759
|
!ggml_vk_dim01_contiguous(src0);
|
|
4263
4760
|
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
|
4761
|
+
(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
|
|
4264
4762
|
!ggml_vk_dim01_contiguous(src1);
|
|
4265
4763
|
|
|
4764
|
+
// If src0 is BF16, try to use a BF16 x BF16 multiply
|
|
4765
|
+
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
|
|
4766
|
+
|
|
4266
4767
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
|
4267
4768
|
|
|
4268
|
-
|
|
4769
|
+
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0;
|
|
4770
|
+
|
|
4771
|
+
// Check for mmq first
|
|
4772
|
+
vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr;
|
|
4773
|
+
|
|
4774
|
+
if (mmp == nullptr) {
|
|
4775
|
+
// Fall back to f16 dequant mul mat
|
|
4776
|
+
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
|
|
4777
|
+
quantize_y = false;
|
|
4778
|
+
}
|
|
4269
4779
|
|
|
4270
4780
|
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
|
4271
|
-
const bool qy_needs_dequant = (src1->type !=
|
|
4781
|
+
const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig);
|
|
4272
4782
|
|
|
4273
4783
|
if (qx_needs_dequant) {
|
|
4274
4784
|
// Fall back to dequant + f16 mulmat
|
|
4275
|
-
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx,
|
|
4785
|
+
mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
|
|
4276
4786
|
}
|
|
4277
4787
|
|
|
4278
4788
|
// Not implemented
|
|
4279
4789
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
|
4280
4790
|
|
|
4281
|
-
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ?
|
|
4282
|
-
const bool aligned = ne10 == kpad && ne01 > 8 && ne11 > 8;
|
|
4791
|
+
const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)));
|
|
4792
|
+
const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8;
|
|
4283
4793
|
|
|
4284
|
-
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ?
|
|
4794
|
+
vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type));
|
|
4285
4795
|
|
|
4286
4796
|
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
|
4287
|
-
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
|
|
4797
|
+
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11;
|
|
4288
4798
|
const int x_ne = ne01 * ne00;
|
|
4289
4799
|
const int y_ne = padded_n * ne10;
|
|
4290
4800
|
const int d_ne = ne11 * ne01;
|
|
@@ -4294,25 +4804,30 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4294
4804
|
const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type);
|
|
4295
4805
|
const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type);
|
|
4296
4806
|
const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne;
|
|
4297
|
-
const uint64_t y_sz = y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
|
|
4807
|
+
const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
|
|
4298
4808
|
const uint64_t d_sz = sizeof(float) * d_ne;
|
|
4299
4809
|
|
|
4300
4810
|
vk_pipeline to_fp16_vk_0 = nullptr;
|
|
4301
4811
|
vk_pipeline to_fp16_vk_1 = nullptr;
|
|
4812
|
+
vk_pipeline to_q8_1 = nullptr;
|
|
4302
4813
|
|
|
4303
4814
|
if (x_non_contig) {
|
|
4304
|
-
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr,
|
|
4815
|
+
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
|
|
4305
4816
|
} else {
|
|
4306
4817
|
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
|
4307
4818
|
}
|
|
4308
4819
|
if (y_non_contig) {
|
|
4309
|
-
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr,
|
|
4820
|
+
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
|
|
4310
4821
|
} else {
|
|
4311
4822
|
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
|
4312
4823
|
}
|
|
4313
4824
|
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
|
|
4314
4825
|
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
|
4315
4826
|
|
|
4827
|
+
if (quantize_y) {
|
|
4828
|
+
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
|
|
4829
|
+
}
|
|
4830
|
+
|
|
4316
4831
|
if (dryrun) {
|
|
4317
4832
|
const uint64_t x_sz_upd = x_sz * ne02 * ne03;
|
|
4318
4833
|
const uint64_t y_sz_upd = y_sz * ne12 * ne13;
|
|
@@ -4326,7 +4841,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4326
4841
|
if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) {
|
|
4327
4842
|
ctx->prealloc_size_x = x_sz_upd;
|
|
4328
4843
|
}
|
|
4329
|
-
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) {
|
|
4844
|
+
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) {
|
|
4330
4845
|
ctx->prealloc_size_y = y_sz_upd;
|
|
4331
4846
|
}
|
|
4332
4847
|
if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) {
|
|
@@ -4341,6 +4856,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4341
4856
|
if (qy_needs_dequant) {
|
|
4342
4857
|
ggml_pipeline_request_descriptor_sets(ctx->device, to_fp16_vk_1, 1);
|
|
4343
4858
|
}
|
|
4859
|
+
if (quantize_y) {
|
|
4860
|
+
ggml_pipeline_request_descriptor_sets(ctx->device, to_q8_1, 1);
|
|
4861
|
+
}
|
|
4344
4862
|
if (split_k > 1) {
|
|
4345
4863
|
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_matmul_split_k_reduce, 1);
|
|
4346
4864
|
}
|
|
@@ -4376,6 +4894,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4376
4894
|
if (qy_needs_dequant) {
|
|
4377
4895
|
d_Y = ctx->prealloc_y;
|
|
4378
4896
|
GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13);
|
|
4897
|
+
} else if (quantize_y) {
|
|
4898
|
+
d_Y = ctx->prealloc_y;
|
|
4899
|
+
GGML_ASSERT(d_Y->size >= y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1));
|
|
4379
4900
|
} else {
|
|
4380
4901
|
d_Y = d_Qy;
|
|
4381
4902
|
y_buf_offset = qy_buf_offset;
|
|
@@ -4392,6 +4913,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4392
4913
|
if (y_non_contig) {
|
|
4393
4914
|
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
|
4394
4915
|
}
|
|
4916
|
+
if (quantize_y) {
|
|
4917
|
+
ggml_vk_quantize_q8_1(ctx, subctx, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE }, y_ne * ne12 * ne13);
|
|
4918
|
+
}
|
|
4395
4919
|
|
|
4396
4920
|
uint32_t stride_batch_x = ne00*ne01;
|
|
4397
4921
|
uint32_t stride_batch_y = ne10*ne11;
|
|
@@ -4400,7 +4924,7 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
|
|
|
4400
4924
|
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
|
|
4401
4925
|
}
|
|
4402
4926
|
|
|
4403
|
-
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) {
|
|
4927
|
+
if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) {
|
|
4404
4928
|
stride_batch_y = src1->nb[0] / ggml_type_size(src1->type);
|
|
4405
4929
|
}
|
|
4406
4930
|
|
|
@@ -4710,6 +5234,8 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|
|
4710
5234
|
const uint64_t nb01 = src0->nb[1];
|
|
4711
5235
|
const uint64_t nb02 = src0->nb[2];
|
|
4712
5236
|
|
|
5237
|
+
const uint64_t nb12 = src1->nb[2];
|
|
5238
|
+
|
|
4713
5239
|
// const uint64_t ne10 = src1->ne[0];
|
|
4714
5240
|
const uint64_t ne11 = src1->ne[1];
|
|
4715
5241
|
const uint64_t ne12 = src1->ne[2];
|
|
@@ -4735,6 +5261,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|
|
4735
5261
|
|
|
4736
5262
|
const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
|
|
4737
5263
|
const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
|
|
5264
|
+
const uint32_t channel_stride_y = nb12 / sizeof(float);
|
|
4738
5265
|
|
|
4739
5266
|
const uint64_t qx_sz = ggml_nbytes(src0);
|
|
4740
5267
|
const uint64_t qy_sz = ggml_nbytes(src1);
|
|
@@ -4765,7 +5292,7 @@ static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_con
|
|
|
4765
5292
|
const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
|
|
4766
5293
|
|
|
4767
5294
|
// compute
|
|
4768
|
-
const std::array<uint32_t,
|
|
5295
|
+
const std::array<uint32_t, 9> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
|
|
4769
5296
|
ggml_vk_sync_buffers(subctx);
|
|
4770
5297
|
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
|
|
4771
5298
|
{ vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, 7 * sizeof(uint32_t), &pc, { 1, (uint32_t)ne01, (uint32_t)ne12 });
|
|
@@ -4790,7 +5317,7 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
4790
5317
|
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
|
4791
5318
|
// when ne12 and ne13 are one.
|
|
4792
5319
|
} else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
|
|
4793
|
-
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
|
5320
|
+
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) {
|
|
4794
5321
|
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
|
4795
5322
|
} else {
|
|
4796
5323
|
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
|
@@ -4817,7 +5344,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
4817
5344
|
|
|
4818
5345
|
const uint64_t nei0 = ids->ne[0];
|
|
4819
5346
|
const uint64_t nei1 = ids->ne[1];
|
|
4820
|
-
GGML_ASSERT(nei0 * nei1 <=
|
|
5347
|
+
GGML_ASSERT(nei0 * nei1 <= 4096);
|
|
4821
5348
|
|
|
4822
5349
|
const uint32_t nbi1 = ids->nb[1];
|
|
4823
5350
|
const uint32_t nbi2 = ids->nb[2];
|
|
@@ -4858,27 +5385,31 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
4858
5385
|
const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
|
|
4859
5386
|
!ggml_vk_dim01_contiguous(src0);
|
|
4860
5387
|
const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
|
|
5388
|
+
(src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) ||
|
|
4861
5389
|
!ggml_vk_dim01_contiguous(src1);
|
|
4862
5390
|
|
|
5391
|
+
// If src0 is BF16, try to use a BF16 x BF16 multiply
|
|
5392
|
+
ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16;
|
|
5393
|
+
|
|
4863
5394
|
const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
|
|
4864
5395
|
|
|
4865
|
-
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ?
|
|
5396
|
+
vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]);
|
|
4866
5397
|
|
|
4867
5398
|
const bool qx_needs_dequant = mmp == nullptr || x_non_contig;
|
|
4868
|
-
const bool qy_needs_dequant = (src1->type !=
|
|
5399
|
+
const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig;
|
|
4869
5400
|
|
|
4870
5401
|
if (qx_needs_dequant) {
|
|
4871
5402
|
// Fall back to dequant + f16 mulmat
|
|
4872
|
-
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx,
|
|
5403
|
+
mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]);
|
|
4873
5404
|
}
|
|
4874
5405
|
|
|
4875
5406
|
// Not implemented
|
|
4876
5407
|
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
|
|
4877
5408
|
|
|
4878
|
-
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ?
|
|
5409
|
+
const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type));
|
|
4879
5410
|
const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8;
|
|
4880
5411
|
|
|
4881
|
-
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ?
|
|
5412
|
+
vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type);
|
|
4882
5413
|
|
|
4883
5414
|
// Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking
|
|
4884
5415
|
uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11;
|
|
@@ -4897,12 +5428,12 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
|
|
4897
5428
|
vk_pipeline to_fp16_vk_1 = nullptr;
|
|
4898
5429
|
|
|
4899
5430
|
if (x_non_contig) {
|
|
4900
|
-
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr,
|
|
5431
|
+
to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type);
|
|
4901
5432
|
} else {
|
|
4902
5433
|
to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type);
|
|
4903
5434
|
}
|
|
4904
5435
|
if (y_non_contig) {
|
|
4905
|
-
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr,
|
|
5436
|
+
to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type);
|
|
4906
5437
|
} else {
|
|
4907
5438
|
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
|
4908
5439
|
}
|
|
@@ -5212,6 +5743,36 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5212
5743
|
}
|
|
5213
5744
|
}
|
|
5214
5745
|
|
|
5746
|
+
static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t D, bool f32acc) {
|
|
5747
|
+
// Needs to be kept up to date on shader changes
|
|
5748
|
+
const uint32_t wg_size = scalar_flash_attention_workgroup_size;
|
|
5749
|
+
const uint32_t Br = scalar_flash_attention_num_large_rows;
|
|
5750
|
+
const uint32_t Bc = scalar_flash_attention_Bc;
|
|
5751
|
+
|
|
5752
|
+
const uint32_t acctype = f32acc ? 4 : 2;
|
|
5753
|
+
const uint32_t f16vec4 = 8;
|
|
5754
|
+
|
|
5755
|
+
const uint32_t tmpsh = wg_size * sizeof(float);
|
|
5756
|
+
const uint32_t tmpshv4 = wg_size * 4 * acctype;
|
|
5757
|
+
|
|
5758
|
+
const uint32_t Qf = Br * (D / 4 + 2) * f16vec4;
|
|
5759
|
+
|
|
5760
|
+
const uint32_t sfshstride = (D <= 128) ? (Br + 8) : Br;
|
|
5761
|
+
const uint32_t sfsh = Bc * sfshstride * acctype;
|
|
5762
|
+
|
|
5763
|
+
const uint32_t kshstride = D / 4 + 2;
|
|
5764
|
+
const uint32_t ksh = Bc * kshstride * f16vec4;
|
|
5765
|
+
|
|
5766
|
+
const uint32_t slope = Br * sizeof(float);
|
|
5767
|
+
|
|
5768
|
+
const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope;
|
|
5769
|
+
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
|
|
5770
|
+
|
|
5771
|
+
VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(D=" << D << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported);
|
|
5772
|
+
|
|
5773
|
+
return supported;
|
|
5774
|
+
}
|
|
5775
|
+
|
|
5215
5776
|
static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, ggml_tensor * dst, bool dryrun = false) {
|
|
5216
5777
|
VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3];
|
|
5217
5778
|
std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3];
|
|
@@ -5232,7 +5793,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5232
5793
|
const uint32_t nbm1 = mask ? mask->nb[1] : 0;
|
|
5233
5794
|
|
|
5234
5795
|
const uint32_t D = neq0;
|
|
5235
|
-
|
|
5796
|
+
uint32_t N = neq1;
|
|
5236
5797
|
const uint32_t KV = nek1;
|
|
5237
5798
|
|
|
5238
5799
|
GGML_ASSERT(ne0 == D);
|
|
@@ -5262,20 +5823,110 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5262
5823
|
assert(q->type == GGML_TYPE_F32);
|
|
5263
5824
|
assert(k->type == v->type);
|
|
5264
5825
|
|
|
5826
|
+
FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 :
|
|
5827
|
+
ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR;
|
|
5828
|
+
|
|
5829
|
+
if (path == FA_COOPMAT1) {
|
|
5830
|
+
const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) ||
|
|
5831
|
+
(dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc);
|
|
5832
|
+
|
|
5833
|
+
const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, D, dst->op_params[3] == GGML_PREC_F32);
|
|
5834
|
+
|
|
5835
|
+
if (!coopmat_shape_supported || !coopmat_shmem_supported) {
|
|
5836
|
+
path = FA_SCALAR;
|
|
5837
|
+
}
|
|
5838
|
+
}
|
|
5839
|
+
|
|
5840
|
+
uint32_t gqa_ratio = 1;
|
|
5841
|
+
uint32_t qk_ratio = neq2 / nek2;
|
|
5842
|
+
uint32_t workgroups_x = (uint32_t)neq1;
|
|
5843
|
+
uint32_t workgroups_y = (uint32_t)neq2;
|
|
5844
|
+
uint32_t workgroups_z = (uint32_t)neq3;
|
|
5845
|
+
|
|
5846
|
+
// For scalar/coopmat1 FA, we can use the "large" size to accommodate gqa.
|
|
5847
|
+
// For coopmat2 FA, we always use the small size (which is still pretty large for gqa).
|
|
5848
|
+
uint32_t max_gqa;
|
|
5849
|
+
switch (path) {
|
|
5850
|
+
case FA_SCALAR:
|
|
5851
|
+
case FA_COOPMAT1:
|
|
5852
|
+
// We may switch from coopmat1 to scalar, so use the scalar limit for both
|
|
5853
|
+
max_gqa = scalar_flash_attention_num_large_rows;
|
|
5854
|
+
break;
|
|
5855
|
+
case FA_COOPMAT2:
|
|
5856
|
+
max_gqa = get_fa_num_small_rows(FA_COOPMAT2);
|
|
5857
|
+
break;
|
|
5858
|
+
default:
|
|
5859
|
+
GGML_ASSERT(0);
|
|
5860
|
+
}
|
|
5861
|
+
|
|
5862
|
+
if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa &&
|
|
5863
|
+
qk_ratio * nek2 == neq2 && nek2 == nev2 && neq3 == 1 && nek3 == 1 && nev3 == 1) {
|
|
5864
|
+
// grouped query attention - make the N dimension equal to gqa_ratio, reduce
|
|
5865
|
+
// workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1
|
|
5866
|
+
// and change addressing calculations to index Q's dimension 2.
|
|
5867
|
+
gqa_ratio = qk_ratio;
|
|
5868
|
+
N = gqa_ratio;
|
|
5869
|
+
workgroups_y /= N;
|
|
5870
|
+
}
|
|
5871
|
+
|
|
5265
5872
|
vk_pipeline *pipelines;
|
|
5266
|
-
|
|
5267
|
-
|
|
5268
|
-
|
|
5269
|
-
|
|
5270
|
-
|
|
5271
|
-
|
|
5272
|
-
|
|
5273
|
-
|
|
5274
|
-
|
|
5275
|
-
|
|
5873
|
+
bool small_rows = N <= get_fa_num_small_rows(path);
|
|
5874
|
+
|
|
5875
|
+
// coopmat1 does not actually support "small rows" (it needs 16 rows).
|
|
5876
|
+
// So use scalar instead.
|
|
5877
|
+
if (small_rows && path == FA_COOPMAT1) {
|
|
5878
|
+
path = FA_SCALAR;
|
|
5879
|
+
}
|
|
5880
|
+
|
|
5881
|
+
// scalar is faster than coopmat2 when N==1
|
|
5882
|
+
if (N == 1 && path == FA_COOPMAT2) {
|
|
5883
|
+
path = FA_SCALAR;
|
|
5884
|
+
}
|
|
5885
|
+
|
|
5886
|
+
bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32;
|
|
5887
|
+
|
|
5888
|
+
switch (path) {
|
|
5889
|
+
case FA_SCALAR:
|
|
5890
|
+
switch (D) {
|
|
5891
|
+
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64[k->type][f32acc][small_rows][0]; break;
|
|
5892
|
+
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80[k->type][f32acc][small_rows][0]; break;
|
|
5893
|
+
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96[k->type][f32acc][small_rows][0]; break;
|
|
5894
|
+
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112[k->type][f32acc][small_rows][0]; break;
|
|
5895
|
+
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128[k->type][f32acc][small_rows][0]; break;
|
|
5896
|
+
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256[k->type][f32acc][small_rows][0]; break;
|
|
5897
|
+
default:
|
|
5898
|
+
GGML_ASSERT(!"unsupported D value");
|
|
5899
|
+
return;
|
|
5900
|
+
}
|
|
5901
|
+
break;
|
|
5902
|
+
case FA_COOPMAT1:
|
|
5903
|
+
switch (D) {
|
|
5904
|
+
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5905
|
+
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5906
|
+
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5907
|
+
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5908
|
+
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5909
|
+
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm1[k->type][f32acc][small_rows][0]; break;
|
|
5910
|
+
default:
|
|
5911
|
+
GGML_ASSERT(!"unsupported D value");
|
|
5912
|
+
return;
|
|
5913
|
+
}
|
|
5914
|
+
break;
|
|
5915
|
+
case FA_COOPMAT2:
|
|
5916
|
+
switch (D) {
|
|
5917
|
+
case 64: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D64_cm2[k->type][f32acc][small_rows][0]; break;
|
|
5918
|
+
case 80: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D80_cm2[k->type][f32acc][small_rows][0]; break;
|
|
5919
|
+
case 96: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D96_cm2[k->type][f32acc][small_rows][0]; break;
|
|
5920
|
+
case 112: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D112_cm2[k->type][f32acc][small_rows][0]; break;
|
|
5921
|
+
case 128: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D128_cm2[k->type][f32acc][small_rows][0]; break;
|
|
5922
|
+
case 256: pipelines = &ctx->device->pipeline_flash_attn_f32_f16_D256_cm2[k->type][f32acc][small_rows][0]; break;
|
|
5923
|
+
default:
|
|
5924
|
+
GGML_ASSERT(!"unsupported D value");
|
|
5925
|
+
return;
|
|
5926
|
+
}
|
|
5927
|
+
break;
|
|
5276
5928
|
default:
|
|
5277
|
-
|
|
5278
|
-
return;
|
|
5929
|
+
GGML_ASSERT(0);
|
|
5279
5930
|
}
|
|
5280
5931
|
assert(pipelines);
|
|
5281
5932
|
|
|
@@ -5287,12 +5938,47 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5287
5938
|
// the "aligned" shader variant will forcibly align strides, for performance
|
|
5288
5939
|
(q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0;
|
|
5289
5940
|
|
|
5941
|
+
// mask dim1 is padded to 64, we rely on this to avoid clamping mask loads
|
|
5942
|
+
GGML_ASSERT((nem1 % GGML_KQ_MASK_PAD) == 0);
|
|
5943
|
+
|
|
5290
5944
|
vk_pipeline pipeline = pipelines[aligned];
|
|
5291
5945
|
assert(pipeline);
|
|
5292
5946
|
|
|
5947
|
+
uint32_t split_kv = KV;
|
|
5948
|
+
uint32_t split_k = 1;
|
|
5949
|
+
|
|
5950
|
+
// Use a placeholder core count if one isn't available. split_k is a big help for perf.
|
|
5951
|
+
const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16;
|
|
5952
|
+
|
|
5953
|
+
// Try to use split_k when KV is large enough to be worth the overhead
|
|
5954
|
+
if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
|
|
5955
|
+
// Try to run two workgroups per SM.
|
|
5956
|
+
split_k = ctx->device->shader_core_count * 2 / workgroups_y;
|
|
5957
|
+
if (split_k > 1) {
|
|
5958
|
+
// Try to evenly split KV into split_k chunks, but it needs to be a multiple
|
|
5959
|
+
// of "align", so recompute split_k based on that.
|
|
5960
|
+
split_kv = ROUNDUP_POW2(KV / split_k, pipelines[1]->align);
|
|
5961
|
+
split_k = CEIL_DIV(KV, split_kv);
|
|
5962
|
+
workgroups_x = split_k;
|
|
5963
|
+
}
|
|
5964
|
+
}
|
|
5965
|
+
|
|
5966
|
+
// Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
|
|
5967
|
+
// and the per-row m and L values (ne1 rows).
|
|
5968
|
+
const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
|
|
5969
|
+
if (split_k_size > ctx->device->max_memory_allocation_size) {
|
|
5970
|
+
GGML_ABORT("Requested preallocation size is too large");
|
|
5971
|
+
}
|
|
5972
|
+
if (ctx->prealloc_size_split_k < split_k_size) {
|
|
5973
|
+
ctx->prealloc_size_split_k = split_k_size;
|
|
5974
|
+
}
|
|
5975
|
+
|
|
5293
5976
|
if (dryrun) {
|
|
5294
5977
|
// Request descriptor sets
|
|
5295
5978
|
ggml_pipeline_request_descriptor_sets(ctx->device, pipeline, 1);
|
|
5979
|
+
if (split_k > 1) {
|
|
5980
|
+
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_flash_attn_split_k_reduce, 1);
|
|
5981
|
+
}
|
|
5296
5982
|
return;
|
|
5297
5983
|
}
|
|
5298
5984
|
|
|
@@ -5313,8 +5999,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5313
5999
|
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
|
5314
6000
|
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
|
5315
6001
|
|
|
5316
|
-
ggml_vk_sync_buffers(subctx);
|
|
5317
|
-
|
|
5318
6002
|
vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr;
|
|
5319
6003
|
size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0;
|
|
5320
6004
|
|
|
@@ -5379,16 +6063,45 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
5379
6063
|
v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
|
|
5380
6064
|
nbm1,
|
|
5381
6065
|
scale, max_bias, logit_softcap,
|
|
5382
|
-
mask != nullptr, n_head_log2, m0, m1
|
|
5383
|
-
|
|
5384
|
-
|
|
5385
|
-
|
|
5386
|
-
|
|
5387
|
-
|
|
5388
|
-
|
|
5389
|
-
|
|
5390
|
-
|
|
5391
|
-
|
|
6066
|
+
mask != nullptr, n_head_log2, m0, m1,
|
|
6067
|
+
gqa_ratio, split_kv, split_k };
|
|
6068
|
+
|
|
6069
|
+
ggml_vk_sync_buffers(subctx);
|
|
6070
|
+
|
|
6071
|
+
if (split_k > 1) {
|
|
6072
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
6073
|
+
{
|
|
6074
|
+
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
|
|
6075
|
+
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
|
|
6076
|
+
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
|
|
6077
|
+
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
|
|
6078
|
+
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
|
6079
|
+
},
|
|
6080
|
+
// We only use split_k when group query attention is enabled, which means
|
|
6081
|
+
// there's no more than one tile of rows (i.e. workgroups_x would have been
|
|
6082
|
+
// one). We reuse workgroups_x to mean the number of splits, so we need to
|
|
6083
|
+
// cancel out the divide by wg_denoms[0].
|
|
6084
|
+
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
|
|
6085
|
+
|
|
6086
|
+
ggml_vk_sync_buffers(subctx);
|
|
6087
|
+
const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
|
|
6088
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
|
|
6089
|
+
{
|
|
6090
|
+
vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
|
|
6091
|
+
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
|
6092
|
+
},
|
|
6093
|
+
pc2.size() * uint32_t{sizeof(uint32_t)}, pc2.data(), { (uint32_t)ne1, 1, 1 });
|
|
6094
|
+
} else {
|
|
6095
|
+
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
|
|
6096
|
+
{
|
|
6097
|
+
vk_subbuffer{d_Q, q_buf_offset, VK_WHOLE_SIZE},
|
|
6098
|
+
vk_subbuffer{d_K, k_buf_offset, VK_WHOLE_SIZE},
|
|
6099
|
+
vk_subbuffer{d_V, v_buf_offset, VK_WHOLE_SIZE},
|
|
6100
|
+
vk_subbuffer{d_M, m_buf_offset, VK_WHOLE_SIZE},
|
|
6101
|
+
vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
|
|
6102
|
+
},
|
|
6103
|
+
sizeof(vk_flash_attn_push_constants), &pc, { workgroups_x, workgroups_y, workgroups_z });
|
|
6104
|
+
}
|
|
5392
6105
|
}
|
|
5393
6106
|
|
|
5394
6107
|
static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
|
|
@@ -5408,26 +6121,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
5408
6121
|
}
|
|
5409
6122
|
return nullptr;
|
|
5410
6123
|
case GGML_OP_ADD:
|
|
5411
|
-
|
|
5412
|
-
|
|
6124
|
+
case GGML_OP_SUB:
|
|
6125
|
+
case GGML_OP_MUL:
|
|
6126
|
+
case GGML_OP_DIV:
|
|
6127
|
+
if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
|
|
6128
|
+
(src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
|
|
6129
|
+
(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
|
|
6130
|
+
return nullptr;
|
|
5413
6131
|
}
|
|
5414
|
-
|
|
5415
|
-
|
|
6132
|
+
switch (op) {
|
|
6133
|
+
case GGML_OP_ADD:
|
|
6134
|
+
{
|
|
6135
|
+
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
|
|
6136
|
+
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
|
|
5416
6137
|
}
|
|
5417
|
-
|
|
5418
|
-
|
|
5419
|
-
|
|
5420
|
-
return
|
|
6138
|
+
case GGML_OP_SUB:
|
|
6139
|
+
{
|
|
6140
|
+
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
|
|
6141
|
+
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
|
|
5421
6142
|
}
|
|
5422
|
-
|
|
5423
|
-
|
|
5424
|
-
|
|
5425
|
-
return
|
|
6143
|
+
case GGML_OP_MUL:
|
|
6144
|
+
{
|
|
6145
|
+
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
|
|
6146
|
+
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
|
|
5426
6147
|
}
|
|
5427
|
-
|
|
5428
|
-
|
|
5429
|
-
|
|
5430
|
-
return
|
|
6148
|
+
case GGML_OP_DIV:
|
|
6149
|
+
{
|
|
6150
|
+
auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
|
|
6151
|
+
return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
|
|
6152
|
+
}
|
|
6153
|
+
default:
|
|
6154
|
+
break;
|
|
5431
6155
|
}
|
|
5432
6156
|
return nullptr;
|
|
5433
6157
|
case GGML_OP_CONCAT:
|
|
@@ -5442,7 +6166,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
5442
6166
|
}
|
|
5443
6167
|
return nullptr;
|
|
5444
6168
|
case GGML_OP_UPSCALE:
|
|
5445
|
-
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6169
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && dst->op_params[0] == GGML_SCALE_MODE_NEAREST) {
|
|
5446
6170
|
return ctx->device->pipeline_upscale_f32;
|
|
5447
6171
|
}
|
|
5448
6172
|
return nullptr;
|
|
@@ -5521,37 +6245,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
5521
6245
|
}
|
|
5522
6246
|
return nullptr;
|
|
5523
6247
|
case GGML_OP_UNARY:
|
|
6248
|
+
if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
|
|
6249
|
+
(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
|
|
6250
|
+
(src0->type != dst->type)) {
|
|
6251
|
+
return nullptr;
|
|
6252
|
+
}
|
|
6253
|
+
|
|
5524
6254
|
switch (ggml_get_unary_op(dst)) {
|
|
5525
6255
|
case GGML_UNARY_OP_SILU:
|
|
5526
|
-
|
|
5527
|
-
return ctx->device->pipeline_silu_f32;
|
|
5528
|
-
}
|
|
5529
|
-
break;
|
|
6256
|
+
return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
|
|
5530
6257
|
case GGML_UNARY_OP_GELU:
|
|
5531
|
-
|
|
5532
|
-
return ctx->device->pipeline_gelu_f32;
|
|
5533
|
-
}
|
|
5534
|
-
break;
|
|
6258
|
+
return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
|
|
5535
6259
|
case GGML_UNARY_OP_GELU_QUICK:
|
|
5536
|
-
|
|
5537
|
-
return ctx->device->pipeline_gelu_quick_f32;
|
|
5538
|
-
}
|
|
5539
|
-
break;
|
|
6260
|
+
return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
|
|
5540
6261
|
case GGML_UNARY_OP_RELU:
|
|
5541
|
-
|
|
5542
|
-
return ctx->device->pipeline_relu_f32;
|
|
5543
|
-
}
|
|
5544
|
-
break;
|
|
6262
|
+
return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
|
|
5545
6263
|
case GGML_UNARY_OP_TANH:
|
|
5546
|
-
|
|
5547
|
-
return ctx->device->pipeline_tanh_f32;
|
|
5548
|
-
}
|
|
5549
|
-
break;
|
|
6264
|
+
return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
|
|
5550
6265
|
case GGML_UNARY_OP_SIGMOID:
|
|
5551
|
-
|
|
5552
|
-
return ctx->device->pipeline_sigmoid_f32;
|
|
5553
|
-
}
|
|
5554
|
-
break;
|
|
6266
|
+
return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
|
|
5555
6267
|
default:
|
|
5556
6268
|
break;
|
|
5557
6269
|
}
|
|
@@ -5674,6 +6386,15 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
|
|
5674
6386
|
return ctx->device->pipeline_leaky_relu_f32;
|
|
5675
6387
|
}
|
|
5676
6388
|
return nullptr;
|
|
6389
|
+
case GGML_OP_CONV_2D_DW:
|
|
6390
|
+
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
|
6391
|
+
if (ggml_is_contiguous(src1)) {
|
|
6392
|
+
return ctx->device->pipeline_conv2d_dw_whcn_f32;
|
|
6393
|
+
} else if (ggml_is_contiguous_channels(src1)) {
|
|
6394
|
+
return ctx->device->pipeline_conv2d_dw_cwhn_f32;
|
|
6395
|
+
}
|
|
6396
|
+
}
|
|
6397
|
+
return nullptr;
|
|
5677
6398
|
default:
|
|
5678
6399
|
return nullptr;
|
|
5679
6400
|
}
|
|
@@ -5699,6 +6420,8 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
|
|
5699
6420
|
case GGML_OP_REPEAT:
|
|
5700
6421
|
case GGML_OP_REPEAT_BACK:
|
|
5701
6422
|
case GGML_OP_ROPE:
|
|
6423
|
+
case GGML_OP_RMS_NORM:
|
|
6424
|
+
case GGML_OP_CONV_2D_DW:
|
|
5702
6425
|
return true;
|
|
5703
6426
|
default:
|
|
5704
6427
|
return false;
|
|
@@ -5909,7 +6632,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
5909
6632
|
|
|
5910
6633
|
switch (op) {
|
|
5911
6634
|
case GGML_OP_NORM:
|
|
5912
|
-
case GGML_OP_RMS_NORM:
|
|
5913
6635
|
case GGML_OP_RMS_NORM_BACK:
|
|
5914
6636
|
case GGML_OP_L2_NORM:
|
|
5915
6637
|
case GGML_OP_SOFT_MAX:
|
|
@@ -5926,6 +6648,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
5926
6648
|
elements = { nr, 1, 1 };
|
|
5927
6649
|
}
|
|
5928
6650
|
} break;
|
|
6651
|
+
case GGML_OP_RMS_NORM:
|
|
6652
|
+
elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 };
|
|
6653
|
+
break;
|
|
6654
|
+
|
|
5929
6655
|
case GGML_OP_SUM:
|
|
5930
6656
|
// We use GGML_OP_SUM_ROWS with 1 row.
|
|
5931
6657
|
elements = { 1, 1, 1 };
|
|
@@ -5992,6 +6718,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
|
|
5992
6718
|
case GGML_OP_CONCAT:
|
|
5993
6719
|
case GGML_OP_UPSCALE:
|
|
5994
6720
|
case GGML_OP_UNARY:
|
|
6721
|
+
case GGML_OP_CONV_2D_DW:
|
|
5995
6722
|
{
|
|
5996
6723
|
const uint32_t ne = ggml_nelements(dst);
|
|
5997
6724
|
if (ne > 262144) {
|
|
@@ -6576,7 +7303,17 @@ static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx
|
|
|
6576
7303
|
|
|
6577
7304
|
static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
6578
7305
|
float * op_params = (float *)dst->op_params;
|
|
6579
|
-
|
|
7306
|
+
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
|
7307
|
+
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
|
7308
|
+
|
|
7309
|
+
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
|
|
7310
|
+
(uint32_t)ggml_nelements(src0),
|
|
7311
|
+
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
|
7312
|
+
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
|
7313
|
+
0,
|
|
7314
|
+
op_params[0], 0.0f,
|
|
7315
|
+
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
|
7316
|
+
}, dryrun);
|
|
6580
7317
|
}
|
|
6581
7318
|
|
|
6582
7319
|
static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
@@ -6768,6 +7505,30 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
|
|
6768
7505
|
}, dryrun);
|
|
6769
7506
|
}
|
|
6770
7507
|
|
|
7508
|
+
static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
|
7509
|
+
vk_op_conv2d_dw_push_constants p{};
|
|
7510
|
+
p.ne = ggml_nelements(dst);
|
|
7511
|
+
p.channels = dst->ne[2];
|
|
7512
|
+
p.batches = dst->ne[3];
|
|
7513
|
+
p.dst_w = dst->ne[0];
|
|
7514
|
+
p.dst_h = dst->ne[1];
|
|
7515
|
+
p.src_w = src1->ne[0];
|
|
7516
|
+
p.src_h = src1->ne[1];
|
|
7517
|
+
p.knl_w = src0->ne[0];
|
|
7518
|
+
p.knl_h = src0->ne[1];
|
|
7519
|
+
p.stride_x = dst->op_params[0];
|
|
7520
|
+
p.stride_y = dst->op_params[1];
|
|
7521
|
+
p.pad_x = dst->op_params[2];
|
|
7522
|
+
p.pad_y = dst->op_params[3];
|
|
7523
|
+
p.dilation_x = dst->op_params[4];
|
|
7524
|
+
p.dilation_y = dst->op_params[5];
|
|
7525
|
+
|
|
7526
|
+
GGML_ASSERT(src0->ne[3] == p.channels);
|
|
7527
|
+
GGML_ASSERT(src1->ne[3] == p.batches);
|
|
7528
|
+
|
|
7529
|
+
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun);
|
|
7530
|
+
}
|
|
7531
|
+
|
|
6771
7532
|
static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
|
6772
7533
|
const float * op_params = (const float *)dst->op_params;
|
|
6773
7534
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun);
|
|
@@ -6929,6 +7690,10 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
|
|
|
6929
7690
|
}
|
|
6930
7691
|
}
|
|
6931
7692
|
|
|
7693
|
+
if (ctx->device->need_compiles) {
|
|
7694
|
+
ggml_vk_load_shaders(ctx->device);
|
|
7695
|
+
}
|
|
7696
|
+
|
|
6932
7697
|
ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
|
6933
7698
|
|
|
6934
7699
|
vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
|
@@ -7177,6 +7942,10 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
|
|
7177
7942
|
|
|
7178
7943
|
ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
|
|
7179
7944
|
|
|
7945
|
+
if (ctx->device->need_compiles) {
|
|
7946
|
+
ggml_vk_load_shaders(ctx->device);
|
|
7947
|
+
}
|
|
7948
|
+
|
|
7180
7949
|
ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
|
7181
7950
|
|
|
7182
7951
|
ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz);
|
|
@@ -7236,66 +8005,198 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
|
|
|
7236
8005
|
free(x_chk);
|
|
7237
8006
|
}
|
|
7238
8007
|
|
|
7239
|
-
|
|
8008
|
+
// This does not work without ggml q8_1 quantization support
|
|
8009
|
+
//
|
|
8010
|
+
// typedef uint16_t ggml_half;
|
|
8011
|
+
// typedef uint32_t ggml_half2;
|
|
8012
|
+
//
|
|
8013
|
+
// #define QK8_1 32
|
|
8014
|
+
// typedef struct {
|
|
8015
|
+
// union {
|
|
8016
|
+
// struct {
|
|
8017
|
+
// ggml_half d; // delta
|
|
8018
|
+
// ggml_half s; // d * sum(qs[i])
|
|
8019
|
+
// } GGML_COMMON_AGGR_S;
|
|
8020
|
+
// ggml_half2 ds;
|
|
8021
|
+
// } GGML_COMMON_AGGR_U;
|
|
8022
|
+
// int8_t qs[QK8_1]; // quants
|
|
8023
|
+
// } block_q8_1;
|
|
8024
|
+
//
|
|
8025
|
+
// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) {
|
|
8026
|
+
// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")");
|
|
8027
|
+
// GGML_ASSERT(quant == GGML_TYPE_Q8_1);
|
|
8028
|
+
//
|
|
8029
|
+
// const size_t x_sz = sizeof(float) * ne;
|
|
8030
|
+
// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
|
|
8031
|
+
// float * x = (float *) malloc(x_sz);
|
|
8032
|
+
// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
|
|
8033
|
+
// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
|
|
8034
|
+
// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
|
8035
|
+
// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
|
8036
|
+
//
|
|
8037
|
+
// for (size_t i = 0; i < ne; i++) {
|
|
8038
|
+
// x[i] = rand() / (float)RAND_MAX;
|
|
8039
|
+
// }
|
|
8040
|
+
//
|
|
8041
|
+
// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant);
|
|
8042
|
+
//
|
|
8043
|
+
// ggml_pipeline_request_descriptor_sets(ctx->device, p, 1);
|
|
8044
|
+
//
|
|
8045
|
+
// if (ctx->device->need_compiles) {
|
|
8046
|
+
// ggml_vk_load_shaders(ctx->device);
|
|
8047
|
+
// }
|
|
8048
|
+
//
|
|
8049
|
+
// ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
|
8050
|
+
//
|
|
8051
|
+
// ggml_vk_buffer_write(x_buf, 0, x, x_sz);
|
|
8052
|
+
//
|
|
8053
|
+
// vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
|
8054
|
+
// ggml_vk_ctx_begin(ctx->device, subctx);
|
|
8055
|
+
// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(x_buf), ggml_vk_subbuffer(qx_buf), ne);
|
|
8056
|
+
// ggml_vk_ctx_end(subctx);
|
|
8057
|
+
//
|
|
8058
|
+
// auto begin = std::chrono::high_resolution_clock::now();
|
|
8059
|
+
//
|
|
8060
|
+
// ggml_vk_submit(subctx, ctx->fence);
|
|
8061
|
+
// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences");
|
|
8062
|
+
// ctx->device->device.resetFences({ ctx->fence });
|
|
8063
|
+
//
|
|
8064
|
+
// auto end = std::chrono::high_resolution_clock::now();
|
|
8065
|
+
//
|
|
8066
|
+
// double ms_quant = std::chrono::duration_cast<std::chrono::microseconds>(end-begin).count() / 1000.0;
|
|
8067
|
+
// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz);
|
|
8068
|
+
//
|
|
8069
|
+
// ggml_vk_quantize_data(x, qx_res, ne, quant);
|
|
8070
|
+
//
|
|
8071
|
+
// int first_err = -1;
|
|
8072
|
+
//
|
|
8073
|
+
// for (size_t i = 0; i < ne / 32; i++) {
|
|
8074
|
+
// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d));
|
|
8075
|
+
//
|
|
8076
|
+
// if (first_err < 0 && error > 0.1) {
|
|
8077
|
+
// first_err = i;
|
|
8078
|
+
// }
|
|
8079
|
+
//
|
|
8080
|
+
// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s));
|
|
8081
|
+
//
|
|
8082
|
+
// if (first_err < 0 && error > 0.1) {
|
|
8083
|
+
// first_err = i;
|
|
8084
|
+
// }
|
|
8085
|
+
//
|
|
8086
|
+
// for (size_t j = 0; j < 32; j++) {
|
|
8087
|
+
// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]);
|
|
8088
|
+
//
|
|
8089
|
+
// if (first_err < 0 && error > 1) {
|
|
8090
|
+
// first_err = i;
|
|
8091
|
+
// }
|
|
8092
|
+
// }
|
|
8093
|
+
// }
|
|
8094
|
+
//
|
|
8095
|
+
// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? "CORRECT" : "INCORRECT") << std::endl;
|
|
8096
|
+
//
|
|
8097
|
+
// if (first_err != -1) {
|
|
8098
|
+
// std::cerr << "first_error = " << first_err << std::endl;
|
|
8099
|
+
// std::cerr << "Actual result: " << std::endl << std::endl;
|
|
8100
|
+
// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
|
|
8101
|
+
// for (size_t j = 0; j < 32; j++) {
|
|
8102
|
+
// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " ";
|
|
8103
|
+
// }
|
|
8104
|
+
// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl;
|
|
8105
|
+
// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " ";
|
|
8106
|
+
// for (size_t j = 0; j < 32; j++) {
|
|
8107
|
+
// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " ";
|
|
8108
|
+
// }
|
|
8109
|
+
// std::cerr << std::endl;
|
|
8110
|
+
// }
|
|
8111
|
+
//
|
|
8112
|
+
// ggml_vk_destroy_buffer(x_buf);
|
|
8113
|
+
// ggml_vk_destroy_buffer(qx_buf);
|
|
8114
|
+
//
|
|
8115
|
+
// free(x);
|
|
8116
|
+
// free(qx);
|
|
8117
|
+
// free(qx_res);
|
|
8118
|
+
// }
|
|
8119
|
+
|
|
8120
|
+
static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) {
|
|
7240
8121
|
VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")");
|
|
7241
8122
|
const size_t x_ne = m * k * batch;
|
|
7242
8123
|
const size_t y_ne = k * n * batch;
|
|
7243
8124
|
const size_t d_ne = m * n * batch;
|
|
7244
8125
|
|
|
8126
|
+
vk_matmul_pipeline2 * pipelines;
|
|
8127
|
+
|
|
8128
|
+
if (mmq) {
|
|
8129
|
+
pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1;
|
|
8130
|
+
} else {
|
|
8131
|
+
pipelines = ctx->device->pipeline_dequant_mul_mat_mat;
|
|
8132
|
+
}
|
|
8133
|
+
|
|
8134
|
+
const bool fp16acc = ctx->device->fp16;
|
|
8135
|
+
|
|
7245
8136
|
vk_pipeline p;
|
|
7246
8137
|
std::string shname;
|
|
7247
8138
|
if (shader_size == 0) {
|
|
7248
|
-
p =
|
|
8139
|
+
p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s;
|
|
7249
8140
|
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S";
|
|
7250
8141
|
} else if (shader_size == 1) {
|
|
7251
|
-
p =
|
|
8142
|
+
p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m;
|
|
7252
8143
|
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M";
|
|
7253
8144
|
} else if (shader_size == 2) {
|
|
7254
|
-
p =
|
|
8145
|
+
p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l;
|
|
7255
8146
|
shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L";
|
|
7256
8147
|
} else {
|
|
7257
8148
|
GGML_ASSERT(0);
|
|
7258
8149
|
}
|
|
7259
8150
|
|
|
7260
|
-
const size_t kpad = ggml_vk_align_size(k, p->align);
|
|
8151
|
+
const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align);
|
|
7261
8152
|
|
|
7262
|
-
if (k != kpad) {
|
|
8153
|
+
if (mmq || k != kpad) {
|
|
7263
8154
|
if (shader_size == 0) {
|
|
7264
|
-
p =
|
|
8155
|
+
p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s;
|
|
7265
8156
|
shname = std::string(ggml_type_name(quant)) + "_S";
|
|
7266
8157
|
} else if (shader_size == 1) {
|
|
7267
|
-
p =
|
|
8158
|
+
p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m;
|
|
7268
8159
|
shname = std::string(ggml_type_name(quant)) + "_M";
|
|
7269
8160
|
} else if (shader_size == 2) {
|
|
7270
|
-
p =
|
|
8161
|
+
p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l;
|
|
7271
8162
|
shname = std::string(ggml_type_name(quant)) + "_L";
|
|
7272
8163
|
} else {
|
|
7273
8164
|
GGML_ASSERT(0);
|
|
7274
8165
|
}
|
|
7275
8166
|
}
|
|
7276
8167
|
|
|
8168
|
+
if (p == nullptr) {
|
|
8169
|
+
std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl;
|
|
8170
|
+
return;
|
|
8171
|
+
}
|
|
8172
|
+
|
|
7277
8173
|
const size_t x_sz = sizeof(float) * x_ne;
|
|
7278
8174
|
const size_t y_sz = sizeof(float) * y_ne;
|
|
7279
8175
|
const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant);
|
|
8176
|
+
const size_t qy_sz = mmq ? y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz;
|
|
7280
8177
|
const size_t d_sz = sizeof(float) * d_ne;
|
|
7281
8178
|
float * x = (float *) malloc(x_sz);
|
|
7282
8179
|
float * y = (float *) malloc(y_sz);
|
|
7283
8180
|
void * qx = malloc(qx_sz);
|
|
7284
8181
|
vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
|
7285
8182
|
vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
|
8183
|
+
vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
|
7286
8184
|
vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
|
7287
8185
|
float * d = (float *) malloc(d_sz);
|
|
7288
8186
|
float * d_chk = (float *) malloc(d_sz);
|
|
7289
8187
|
|
|
7290
8188
|
for (size_t i = 0; i < x_ne; i++) {
|
|
7291
8189
|
x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
|
8190
|
+
// x[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
|
8191
|
+
// x[i] = i % k;
|
|
7292
8192
|
}
|
|
7293
8193
|
|
|
7294
8194
|
ggml_vk_quantize_data(x, qx, x_ne, quant);
|
|
7295
8195
|
|
|
7296
8196
|
for (size_t i = 0; i < y_ne; i++) {
|
|
7297
|
-
|
|
7298
|
-
y[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
|
8197
|
+
y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f;
|
|
8198
|
+
// y[i] = (i % k == i / k) ? 1.0f : 0.0f;
|
|
8199
|
+
// y[i] = i % k;
|
|
7299
8200
|
}
|
|
7300
8201
|
|
|
7301
8202
|
ggml_pipeline_request_descriptor_sets(ctx->device, p, num_it);
|
|
@@ -7310,6 +8211,13 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
7310
8211
|
ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
|
|
7311
8212
|
}
|
|
7312
8213
|
}
|
|
8214
|
+
if (mmq) {
|
|
8215
|
+
ggml_pipeline_request_descriptor_sets(ctx->device, ctx->device->pipeline_quantize_q8_1, num_it);
|
|
8216
|
+
}
|
|
8217
|
+
|
|
8218
|
+
if (ctx->device->need_compiles) {
|
|
8219
|
+
ggml_vk_load_shaders(ctx->device);
|
|
8220
|
+
}
|
|
7313
8221
|
|
|
7314
8222
|
ggml_pipeline_allocate_descriptor_sets(ctx->device);
|
|
7315
8223
|
|
|
@@ -7318,13 +8226,25 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
7318
8226
|
|
|
7319
8227
|
vk_context subctx = ggml_vk_create_context(ctx, ctx->device->compute_queue);
|
|
7320
8228
|
ggml_vk_ctx_begin(ctx->device, subctx);
|
|
7321
|
-
|
|
7322
|
-
|
|
7323
|
-
ctx, subctx,
|
|
7324
|
-
|
|
7325
|
-
|
|
7326
|
-
|
|
7327
|
-
|
|
8229
|
+
if (mmq) {
|
|
8230
|
+
for (size_t i = 0; i < num_it; i++) {
|
|
8231
|
+
ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne);
|
|
8232
|
+
ggml_vk_matmul(
|
|
8233
|
+
ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
|
|
8234
|
+
m, n, k,
|
|
8235
|
+
k, k, m, k*m, k*n, m*n,
|
|
8236
|
+
split_k, batch, batch, batch, 1, 1, n
|
|
8237
|
+
);
|
|
8238
|
+
}
|
|
8239
|
+
} else {
|
|
8240
|
+
for (size_t i = 0; i < num_it; i++) {
|
|
8241
|
+
ggml_vk_matmul(
|
|
8242
|
+
ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k },
|
|
8243
|
+
m, n, k,
|
|
8244
|
+
k, k, m, k*m, k*n, m*n,
|
|
8245
|
+
split_k, batch, batch, batch, 1, 1, n
|
|
8246
|
+
);
|
|
8247
|
+
}
|
|
7328
8248
|
}
|
|
7329
8249
|
ggml_vk_ctx_end(subctx);
|
|
7330
8250
|
|
|
@@ -7382,7 +8302,11 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
7382
8302
|
|
|
7383
8303
|
double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0);
|
|
7384
8304
|
|
|
7385
|
-
std::cerr << "TEST
|
|
8305
|
+
std::cerr << "TEST dequant matmul " << shname;
|
|
8306
|
+
if (mmq) {
|
|
8307
|
+
std::cerr << " mmq";
|
|
8308
|
+
}
|
|
8309
|
+
std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl;
|
|
7386
8310
|
|
|
7387
8311
|
if (avg_err > 0.01 || std::isnan(avg_err)) {
|
|
7388
8312
|
std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl;
|
|
@@ -7392,6 +8316,12 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
7392
8316
|
std::cerr << "Expected result: " << std::endl << std::endl;
|
|
7393
8317
|
ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b);
|
|
7394
8318
|
|
|
8319
|
+
std::cerr << "src0: " << std::endl << std::endl;
|
|
8320
|
+
ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b);
|
|
8321
|
+
std::cerr << std::endl;
|
|
8322
|
+
std::cerr << "src1: " << std::endl << std::endl;
|
|
8323
|
+
ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b);
|
|
8324
|
+
|
|
7395
8325
|
if (split_k > 1) {
|
|
7396
8326
|
float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k);
|
|
7397
8327
|
ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k);
|
|
@@ -7414,6 +8344,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
|
|
|
7414
8344
|
|
|
7415
8345
|
ggml_vk_destroy_buffer(qx_buf);
|
|
7416
8346
|
ggml_vk_destroy_buffer(y_buf);
|
|
8347
|
+
ggml_vk_destroy_buffer(qy_buf);
|
|
7417
8348
|
ggml_vk_destroy_buffer(d_buf);
|
|
7418
8349
|
|
|
7419
8350
|
free(x);
|
|
@@ -7448,6 +8379,24 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|
|
7448
8379
|
};
|
|
7449
8380
|
const size_t num_it = 100;
|
|
7450
8381
|
|
|
8382
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0);
|
|
8383
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0);
|
|
8384
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0);
|
|
8385
|
+
|
|
8386
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true);
|
|
8387
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true);
|
|
8388
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true);
|
|
8389
|
+
|
|
8390
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0);
|
|
8391
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0);
|
|
8392
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0);
|
|
8393
|
+
|
|
8394
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true);
|
|
8395
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true);
|
|
8396
|
+
ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true);
|
|
8397
|
+
|
|
8398
|
+
abort();
|
|
8399
|
+
|
|
7451
8400
|
for (size_t i = 0; i < vals.size(); i += 3) {
|
|
7452
8401
|
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0);
|
|
7453
8402
|
ggml_vk_test_matmul<ggml_fp16_t, float>(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1);
|
|
@@ -7522,11 +8471,11 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
|
|
|
7522
8471
|
}
|
|
7523
8472
|
}
|
|
7524
8473
|
|
|
7525
|
-
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence);
|
|
8474
|
+
static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready);
|
|
7526
8475
|
|
|
7527
8476
|
// Returns true if node has enqueued work into the queue, false otherwise
|
|
7528
8477
|
// If submit is true the current all operations queued so far are being submitted to Vulkan to overlap cmdlist creation and GPU execution.
|
|
7529
|
-
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool submit){
|
|
8478
|
+
static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
|
|
7530
8479
|
if (ggml_is_empty(node) || !node->buffer) {
|
|
7531
8480
|
return false;
|
|
7532
8481
|
}
|
|
@@ -7600,6 +8549,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
7600
8549
|
case GGML_OP_IM2COL:
|
|
7601
8550
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
7602
8551
|
case GGML_OP_POOL_2D:
|
|
8552
|
+
case GGML_OP_CONV_2D_DW:
|
|
7603
8553
|
case GGML_OP_RWKV_WKV6:
|
|
7604
8554
|
case GGML_OP_RWKV_WKV7:
|
|
7605
8555
|
case GGML_OP_LEAKY_RELU:
|
|
@@ -7663,6 +8613,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
7663
8613
|
case GGML_OP_IM2COL:
|
|
7664
8614
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
7665
8615
|
case GGML_OP_POOL_2D:
|
|
8616
|
+
case GGML_OP_CONV_2D_DW:
|
|
7666
8617
|
case GGML_OP_LEAKY_RELU:
|
|
7667
8618
|
{
|
|
7668
8619
|
// These operations all go through ggml_vk_op_f32, so short-circuit and
|
|
@@ -7836,6 +8787,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
7836
8787
|
case GGML_OP_POOL_2D:
|
|
7837
8788
|
ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun);
|
|
7838
8789
|
|
|
8790
|
+
break;
|
|
8791
|
+
case GGML_OP_CONV_2D_DW:
|
|
8792
|
+
ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun);
|
|
8793
|
+
|
|
7839
8794
|
break;
|
|
7840
8795
|
case GGML_OP_LEAKY_RELU:
|
|
7841
8796
|
ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun);
|
|
@@ -7898,7 +8853,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
7898
8853
|
|
|
7899
8854
|
ctx->compute_ctx.reset();
|
|
7900
8855
|
|
|
7901
|
-
bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false);
|
|
8856
|
+
bool ok = ggml_vk_compute_forward(ctx, node_begin, node_idx_begin, false, almost_ready);
|
|
7902
8857
|
if (!ok) {
|
|
7903
8858
|
if (node->op == GGML_OP_UNARY) {
|
|
7904
8859
|
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
|
|
@@ -7912,7 +8867,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
|
|
|
7912
8867
|
return true;
|
|
7913
8868
|
}
|
|
7914
8869
|
|
|
7915
|
-
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true){
|
|
8870
|
+
static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) {
|
|
7916
8871
|
ggml_backend_buffer * buf = nullptr;
|
|
7917
8872
|
|
|
7918
8873
|
switch (tensor->op) {
|
|
@@ -7957,6 +8912,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
7957
8912
|
case GGML_OP_IM2COL:
|
|
7958
8913
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
7959
8914
|
case GGML_OP_POOL_2D:
|
|
8915
|
+
case GGML_OP_CONV_2D_DW:
|
|
7960
8916
|
case GGML_OP_RWKV_WKV6:
|
|
7961
8917
|
case GGML_OP_RWKV_WKV7:
|
|
7962
8918
|
case GGML_OP_LEAKY_RELU:
|
|
@@ -8015,12 +8971,15 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
|
|
8015
8971
|
memcpy(cpy.dst, cpy.src, cpy.n);
|
|
8016
8972
|
}
|
|
8017
8973
|
|
|
8018
|
-
|
|
8974
|
+
if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) {
|
|
8975
|
+
ggml_vk_submit(subctx, ctx->almost_ready_fence);
|
|
8976
|
+
ctx->almost_ready_fence_pending = true;
|
|
8977
|
+
} else {
|
|
8978
|
+
ggml_vk_submit(subctx, use_fence ? ctx->fence : vk::Fence{});
|
|
8979
|
+
}
|
|
8019
8980
|
|
|
8020
8981
|
if (use_fence) {
|
|
8021
|
-
|
|
8022
|
-
|
|
8023
|
-
ctx->device->device.resetFences({ ctx->fence });
|
|
8982
|
+
ggml_vk_wait_for_fence(ctx);
|
|
8024
8983
|
}
|
|
8025
8984
|
#ifdef GGML_VULKAN_CHECK_RESULTS
|
|
8026
8985
|
ggml_vk_check_results_1(tensor);
|
|
@@ -8106,6 +9065,7 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
|
|
|
8106
9065
|
ctx->gc.events.clear();
|
|
8107
9066
|
|
|
8108
9067
|
ctx->device->device.destroyFence(ctx->fence);
|
|
9068
|
+
ctx->device->device.destroyFence(ctx->almost_ready_fence);
|
|
8109
9069
|
}
|
|
8110
9070
|
|
|
8111
9071
|
static int ggml_vk_get_device_count() {
|
|
@@ -8452,8 +9412,7 @@ static void ggml_backend_vk_synchronize(ggml_backend_t backend) {
|
|
|
8452
9412
|
}
|
|
8453
9413
|
|
|
8454
9414
|
ggml_vk_submit(transfer_ctx, ctx->fence);
|
|
8455
|
-
|
|
8456
|
-
ctx->device->device.resetFences({ ctx->fence });
|
|
9415
|
+
ggml_vk_wait_for_fence(ctx);
|
|
8457
9416
|
|
|
8458
9417
|
for (auto& cpy : transfer_ctx->out_memcpys) {
|
|
8459
9418
|
memcpy(cpy.dst, cpy.src, cpy.n);
|
|
@@ -8472,7 +9431,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
8472
9431
|
|
|
8473
9432
|
uint64_t total_mat_mul_bytes = 0;
|
|
8474
9433
|
for (int i = 0; i < cgraph->n_nodes; i++) {
|
|
8475
|
-
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
|
|
9434
|
+
ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
|
|
8476
9435
|
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
|
|
8477
9436
|
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
|
8478
9437
|
}
|
|
@@ -8514,11 +9473,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
8514
9473
|
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
|
|
8515
9474
|
}
|
|
8516
9475
|
|
|
9476
|
+
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
|
|
9477
|
+
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
|
|
8517
9478
|
bool submit = (submitted_nodes >= nodes_per_submit) ||
|
|
8518
9479
|
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
|
|
8519
|
-
(i == last_node)
|
|
9480
|
+
(i == last_node) ||
|
|
9481
|
+
(almost_ready && !ctx->almost_ready_fence_pending);
|
|
8520
9482
|
|
|
8521
|
-
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
|
|
9483
|
+
bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
|
|
8522
9484
|
|
|
8523
9485
|
if (enqueued) {
|
|
8524
9486
|
++submitted_nodes;
|
|
@@ -8530,7 +9492,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
|
|
|
8530
9492
|
#endif
|
|
8531
9493
|
}
|
|
8532
9494
|
|
|
8533
|
-
if (submit) {
|
|
9495
|
+
if (submit && enqueued) {
|
|
8534
9496
|
first_node_in_batch = true;
|
|
8535
9497
|
submitted_nodes = 0;
|
|
8536
9498
|
mul_mat_bytes = 0;
|
|
@@ -8687,7 +9649,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8687
9649
|
case GGML_UNARY_OP_RELU:
|
|
8688
9650
|
case GGML_UNARY_OP_TANH:
|
|
8689
9651
|
case GGML_UNARY_OP_SIGMOID:
|
|
8690
|
-
return ggml_is_contiguous(op->src[0]) &&
|
|
9652
|
+
return ggml_is_contiguous(op->src[0]) &&
|
|
9653
|
+
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
|
9654
|
+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
|
9655
|
+
(op->src[0]->type == op->type);
|
|
8691
9656
|
default:
|
|
8692
9657
|
return false;
|
|
8693
9658
|
}
|
|
@@ -8705,6 +9670,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8705
9670
|
switch (src0_type) {
|
|
8706
9671
|
case GGML_TYPE_F32:
|
|
8707
9672
|
case GGML_TYPE_F16:
|
|
9673
|
+
case GGML_TYPE_BF16:
|
|
8708
9674
|
case GGML_TYPE_Q4_0:
|
|
8709
9675
|
case GGML_TYPE_Q4_1:
|
|
8710
9676
|
case GGML_TYPE_Q5_0:
|
|
@@ -8740,19 +9706,23 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8740
9706
|
if (a->ne[3] != b->ne[3]) {
|
|
8741
9707
|
return false;
|
|
8742
9708
|
}
|
|
8743
|
-
if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) ||
|
|
9709
|
+
if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_BF16) ||
|
|
8744
9710
|
!(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) {
|
|
8745
9711
|
return false;
|
|
8746
9712
|
}
|
|
9713
|
+
if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) {
|
|
9714
|
+
// We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader.
|
|
9715
|
+
// So don't support this combination for now.
|
|
9716
|
+
return false;
|
|
9717
|
+
}
|
|
8747
9718
|
|
|
8748
9719
|
return true;
|
|
8749
9720
|
} break;
|
|
8750
9721
|
case GGML_OP_FLASH_ATTN_EXT:
|
|
8751
9722
|
{
|
|
8752
9723
|
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
|
|
8753
|
-
|
|
8754
|
-
|
|
8755
|
-
}
|
|
9724
|
+
auto device = ggml_vk_get_device(ctx->device);
|
|
9725
|
+
bool coopmat2 = device->coopmat2;
|
|
8756
9726
|
switch (op->src[0]->ne[0]) {
|
|
8757
9727
|
case 64:
|
|
8758
9728
|
case 80:
|
|
@@ -8764,6 +9734,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8764
9734
|
default:
|
|
8765
9735
|
return false;
|
|
8766
9736
|
}
|
|
9737
|
+
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
|
|
9738
|
+
// different head sizes of K and V are not supported yet
|
|
9739
|
+
return false;
|
|
9740
|
+
}
|
|
8767
9741
|
if (op->src[0]->type != GGML_TYPE_F32) {
|
|
8768
9742
|
return false;
|
|
8769
9743
|
}
|
|
@@ -8781,10 +9755,12 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8781
9755
|
switch (op->src[1]->type) {
|
|
8782
9756
|
case GGML_TYPE_F16:
|
|
8783
9757
|
case GGML_TYPE_Q4_0:
|
|
9758
|
+
case GGML_TYPE_Q8_0:
|
|
9759
|
+
// supported in scalar and coopmat2 paths
|
|
9760
|
+
break;
|
|
8784
9761
|
case GGML_TYPE_Q4_1:
|
|
8785
9762
|
case GGML_TYPE_Q5_0:
|
|
8786
9763
|
case GGML_TYPE_Q5_1:
|
|
8787
|
-
case GGML_TYPE_Q8_0:
|
|
8788
9764
|
// K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently
|
|
8789
9765
|
//case GGML_TYPE_Q2_K:
|
|
8790
9766
|
//case GGML_TYPE_Q3_K:
|
|
@@ -8800,10 +9776,18 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8800
9776
|
//case GGML_TYPE_IQ3_S:
|
|
8801
9777
|
//case GGML_TYPE_IQ4_XS:
|
|
8802
9778
|
case GGML_TYPE_IQ4_NL:
|
|
9779
|
+
// currently supported only in coopmat2 path
|
|
9780
|
+
if (!coopmat2) {
|
|
9781
|
+
return false;
|
|
9782
|
+
}
|
|
8803
9783
|
break;
|
|
8804
9784
|
default:
|
|
8805
9785
|
return false;
|
|
8806
9786
|
}
|
|
9787
|
+
if (!coopmat2 && !device->subgroup_shuffle) {
|
|
9788
|
+
// scalar FA uses subgroupShuffle
|
|
9789
|
+
return false;
|
|
9790
|
+
}
|
|
8807
9791
|
return true;
|
|
8808
9792
|
}
|
|
8809
9793
|
case GGML_OP_GET_ROWS:
|
|
@@ -8811,6 +9795,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8811
9795
|
switch (op->src[0]->type) {
|
|
8812
9796
|
case GGML_TYPE_F32:
|
|
8813
9797
|
case GGML_TYPE_F16:
|
|
9798
|
+
case GGML_TYPE_BF16:
|
|
8814
9799
|
case GGML_TYPE_Q4_0:
|
|
8815
9800
|
case GGML_TYPE_Q4_1:
|
|
8816
9801
|
case GGML_TYPE_Q5_0:
|
|
@@ -8841,6 +9826,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8841
9826
|
switch (src1_type) {
|
|
8842
9827
|
case GGML_TYPE_F32:
|
|
8843
9828
|
case GGML_TYPE_F16:
|
|
9829
|
+
case GGML_TYPE_BF16:
|
|
8844
9830
|
case GGML_TYPE_Q4_0:
|
|
8845
9831
|
case GGML_TYPE_Q4_1:
|
|
8846
9832
|
case GGML_TYPE_Q5_0:
|
|
@@ -8854,6 +9840,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8854
9840
|
}
|
|
8855
9841
|
if (src1_type == GGML_TYPE_F32) {
|
|
8856
9842
|
switch (src0_type) {
|
|
9843
|
+
case GGML_TYPE_F16:
|
|
8857
9844
|
case GGML_TYPE_Q4_0:
|
|
8858
9845
|
case GGML_TYPE_Q4_1:
|
|
8859
9846
|
case GGML_TYPE_Q5_0:
|
|
@@ -8882,16 +9869,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8882
9869
|
case GGML_OP_VIEW:
|
|
8883
9870
|
case GGML_OP_PERMUTE:
|
|
8884
9871
|
case GGML_OP_TRANSPOSE:
|
|
9872
|
+
case GGML_OP_RMS_NORM:
|
|
8885
9873
|
return true;
|
|
8886
9874
|
case GGML_OP_NORM:
|
|
8887
9875
|
case GGML_OP_GROUP_NORM:
|
|
8888
|
-
case GGML_OP_RMS_NORM:
|
|
8889
9876
|
case GGML_OP_L2_NORM:
|
|
8890
9877
|
return ggml_is_contiguous(op->src[0]);
|
|
8891
9878
|
case GGML_OP_ADD:
|
|
8892
9879
|
case GGML_OP_SUB:
|
|
8893
9880
|
case GGML_OP_MUL:
|
|
8894
9881
|
case GGML_OP_DIV:
|
|
9882
|
+
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
|
9883
|
+
(op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
|
|
9884
|
+
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
|
|
8895
9885
|
case GGML_OP_SILU_BACK:
|
|
8896
9886
|
case GGML_OP_RMS_NORM_BACK:
|
|
8897
9887
|
case GGML_OP_SQR:
|
|
@@ -8899,9 +9889,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8899
9889
|
case GGML_OP_COS:
|
|
8900
9890
|
case GGML_OP_CLAMP:
|
|
8901
9891
|
return op->src[0]->type == GGML_TYPE_F32;
|
|
9892
|
+
case GGML_OP_UPSCALE:
|
|
9893
|
+
return op->op_params[0] == GGML_SCALE_MODE_NEAREST;
|
|
8902
9894
|
case GGML_OP_ACC:
|
|
8903
9895
|
case GGML_OP_CONCAT:
|
|
8904
|
-
case GGML_OP_UPSCALE:
|
|
8905
9896
|
case GGML_OP_SCALE:
|
|
8906
9897
|
case GGML_OP_PAD:
|
|
8907
9898
|
case GGML_OP_DIAG_MASK_INF:
|
|
@@ -8914,6 +9905,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
|
|
8914
9905
|
case GGML_OP_COUNT_EQUAL:
|
|
8915
9906
|
case GGML_OP_IM2COL:
|
|
8916
9907
|
case GGML_OP_TIMESTEP_EMBEDDING:
|
|
9908
|
+
case GGML_OP_CONV_2D_DW:
|
|
8917
9909
|
case GGML_OP_POOL_2D:
|
|
8918
9910
|
case GGML_OP_RWKV_WKV6:
|
|
8919
9911
|
case GGML_OP_RWKV_WKV7:
|
|
@@ -9254,7 +10246,7 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
9254
10246
|
}
|
|
9255
10247
|
|
|
9256
10248
|
if (tensor->op == GGML_OP_FLASH_ATTN_EXT) {
|
|
9257
|
-
const float *params = (const float *)tensor->op_params;
|
|
10249
|
+
const float * params = (const float *)tensor->op_params;
|
|
9258
10250
|
tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]);
|
|
9259
10251
|
} else if (tensor->op == GGML_OP_MUL_MAT) {
|
|
9260
10252
|
tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]);
|
|
@@ -9269,9 +10261,10 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
9269
10261
|
} else if (tensor->op == GGML_OP_CONCAT) {
|
|
9270
10262
|
tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params);
|
|
9271
10263
|
} else if (tensor->op == GGML_OP_UPSCALE) {
|
|
9272
|
-
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]);
|
|
10264
|
+
tensor_clone = ggml_upscale_ext(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->op_params[0], tensor->op_params[1], (ggml_scale_mode) tensor->op_params[0]);
|
|
9273
10265
|
} else if (tensor->op == GGML_OP_SCALE) {
|
|
9274
|
-
|
|
10266
|
+
const float * params = (const float *)tensor->op_params;
|
|
10267
|
+
tensor_clone = ggml_scale(ggml_ctx, src_clone[0], params[0]);
|
|
9275
10268
|
} else if (tensor->op == GGML_OP_SQR) {
|
|
9276
10269
|
tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]);
|
|
9277
10270
|
} else if (tensor->op == GGML_OP_SIN) {
|
|
@@ -9279,7 +10272,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
9279
10272
|
} else if (tensor->op == GGML_OP_COS) {
|
|
9280
10273
|
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
|
|
9281
10274
|
} else if (tensor->op == GGML_OP_CLAMP) {
|
|
9282
|
-
|
|
10275
|
+
const float * params = (const float *)tensor->op_params;
|
|
10276
|
+
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
|
|
9283
10277
|
} else if (tensor->op == GGML_OP_PAD) {
|
|
9284
10278
|
tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]);
|
|
9285
10279
|
} else if (tensor->op == GGML_OP_REPEAT) {
|
|
@@ -9293,7 +10287,8 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
9293
10287
|
} else if (tensor->op == GGML_OP_NORM) {
|
|
9294
10288
|
tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
|
9295
10289
|
} else if (tensor->op == GGML_OP_GROUP_NORM) {
|
|
9296
|
-
|
|
10290
|
+
const float * float_params = (const float *)tensor->op_params;
|
|
10291
|
+
tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]);
|
|
9297
10292
|
} else if (tensor->op == GGML_OP_RMS_NORM) {
|
|
9298
10293
|
tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
|
|
9299
10294
|
} else if (tensor->op == GGML_OP_RMS_NORM_BACK) {
|
|
@@ -9306,14 +10301,15 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
|
|
9306
10301
|
tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps);
|
|
9307
10302
|
} else if (tensor->op == GGML_OP_SOFT_MAX) {
|
|
9308
10303
|
if (src1 != nullptr) {
|
|
9309
|
-
|
|
10304
|
+
const float * params = (const float *)tensor->op_params;
|
|
10305
|
+
tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]);
|
|
9310
10306
|
} else {
|
|
9311
10307
|
tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]);
|
|
9312
10308
|
}
|
|
9313
10309
|
} else if (tensor->op == GGML_OP_SOFT_MAX_BACK) {
|
|
9314
10310
|
tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
|
|
9315
10311
|
} else if (tensor->op == GGML_OP_DIAG_MASK_INF) {
|
|
9316
|
-
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0],
|
|
10312
|
+
tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]);
|
|
9317
10313
|
} else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) {
|
|
9318
10314
|
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
|
9319
10315
|
const int mode = ((int32_t *) tensor->op_params)[2];
|