@fugood/llama.node 0.3.2 → 0.3.4
This diff shows the published contents of two package versions as they appear in their respective public registries. It is provided for informational purposes only.
- package/CMakeLists.txt +7 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +17 -7
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +89 -27
- package/src/LlamaContext.h +2 -0
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +240 -168
- package/src/llama.cpp/.github/workflows/docker.yml +8 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +14 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -4
- package/src/llama.cpp/common/arg.cpp +986 -770
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +212 -351
- package/src/llama.cpp/common/common.h +204 -117
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +163 -121
- package/src/llama.cpp/common/sampling.h +41 -20
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +134 -161
- package/src/llama.cpp/examples/CMakeLists.txt +33 -14
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +19 -18
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +41 -87
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +263 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +83 -22
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
- package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +73 -114
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
- package/src/llama.cpp/examples/server/server.cpp +2073 -1339
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +354 -277
- package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/simple/simple.cpp +130 -94
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
- package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +159 -417
- package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +93 -52
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +4 -8
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +779 -194
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +55 -10
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +4317 -2979
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -38
- package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
- package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +62 -20
- package/src/llama.cpp/tests/test-sampling.cpp +163 -138
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
--- a/package/src/llama.cpp/ggml/src/ggml-kompute.cpp
+++ b/package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp
@@ -20,6 +20,7 @@
 #include "shaderop_mul_mat_q8_0.h"
 #include "shaderop_mul_mat_q4_0.h"
 #include "shaderop_mul_mat_q4_1.h"
+#include "shaderop_mul_mat_q4_k.h"
 #include "shaderop_mul_mat_q6_k.h"
 #include "shaderop_mul_mat_mat_f32.h"
 #include "shaderop_getrows_f32.h"
@@ -27,8 +28,10 @@
 #include "shaderop_getrows_q4_0.h"
 #include "shaderop_getrows_q4_1.h"
 #include "shaderop_getrows_q6_k.h"
-#include "
-#include "
+#include "shaderop_rope_norm_f16.h"
+#include "shaderop_rope_norm_f32.h"
+#include "shaderop_rope_neox_f16.h"
+#include "shaderop_rope_neox_f32.h"
 #include "shaderop_cpy_f16_f16.h"
 #include "shaderop_cpy_f16_f32.h"
 #include "shaderop_cpy_f32_f16.h"
@@ -42,6 +45,7 @@
 #include <cstring>
 #include <iostream>
 #include <memory>
+#include <mutex>
 #include <stdexcept>
 #include <string>
 #include <unordered_map>
@@ -273,18 +277,9 @@ static std::vector<ggml_vk_device> ggml_vk_available_devices_internal(size_t mem
     return results;
 }

-
-ggml_vk_device
-
-    *count = devices.size();
-    if (devices.empty()) {
-        return nullptr;
-    }
-
-    size_t nbytes = sizeof (ggml_vk_device) * (devices.size());
-    auto * arr = static_cast<ggml_vk_device *>(malloc(nbytes));
-    memcpy(arr, devices.data(), nbytes);
-    return arr;
+static std::vector<ggml_vk_device>& ggml_vk_available_devices() {
+    static std::vector<ggml_vk_device> devices = ggml_vk_available_devices_internal(0);
+    return devices;
 }

 static void ggml_vk_filterByVendor(std::vector<ggml_vk_device>& devices, const std::string& targetVendor) {
@@ -341,7 +336,7 @@ ggml_vk_device ggml_vk_current_device() {
     if (!komputeManager()->hasDevice())
         return ggml_vk_device();

-    auto devices =
+    auto devices = ggml_vk_available_devices();
     ggml_vk_filterByName(devices, komputeManager()->physicalDevice()->getProperties().deviceName.data());
     GGML_ASSERT(!devices.empty());
     return devices.front();
@@ -352,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t
     std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
         vk::DescriptorPoolSize(
             vk::DescriptorType::eStorageBuffer,
-
+            4 * size // Descriptor count is number of possible tensors to pass into an algorithm
         )
     };

@@ -795,7 +790,8 @@ static void ggml_vk_soft_max(
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
    int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
-   float scale
+   float scale, float max_bias, float m0, float m1,
+   uint32_t n_head_log2
 ) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
                                             kp::shader_data::op_softmax_comp_spv_len);
@@ -803,12 +799,14 @@
    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00, ne01, ne02;
-       float scale;
+       float scale, max_bias, m0, m1;
+       uint32_t n_head_log2;
        int32_t mask;
    } pushConsts {
        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00, ne01, ne02,
-       scale,
+       scale, max_bias, m0, m1,
+       n_head_log2,
        bool(inB)
    };

@@ -918,9 +916,9 @@ static void ggml_vk_mul_mat_f16(
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
    int32_t ne00, int32_t ne01, int32_t ne02,
-   uint32_t nb00, uint32_t nb01, uint32_t nb02,
+   uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
-   uint32_t nb10, uint32_t nb11, uint32_t nb12,
+   uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13,
    int32_t ne0, int32_t ne1,
    uint32_t r2, uint32_t r3
 ) {
@@ -930,17 +928,17 @@ static void ggml_vk_mul_mat_f16(
    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
        int32_t ne00, ne01, ne02;
-       uint32_t nb00, nb01, nb02;
+       uint32_t nb00, nb01, nb02, nb03;
        int32_t ne10, ne11, ne12;
-       uint32_t nb10, nb11, nb12;
+       uint32_t nb10, nb11, nb12, nb13;
        int32_t ne0, ne1;
        uint32_t r2, r3;
    } pushConsts {
        safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00, ne01, ne02,
-       nb00, nb01, nb02,
+       nb00, nb01, nb02, nb03,
        ne10, ne11, ne12,
-       nb10, nb11, nb12,
+       nb10, nb11, nb12, nb13,
        ne0, ne1,
        r2, r3
    };
@@ -1020,6 +1018,8 @@ static void ggml_vk_mul_mat_impl(
    int32_t ne00, int32_t ne01, int32_t ne02,
    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
    int32_t ne0, int32_t ne1,
+   uint32_t nb01, uint32_t nb02, uint32_t nb03,
+   uint32_t nb11, uint32_t nb12, uint32_t nb13,
    uint32_t r2, uint32_t r3
 ) {
    struct PushConstants {
@@ -1027,19 +1027,23 @@ static void ggml_vk_mul_mat_impl(
        int32_t ne00, ne01, ne02;
        int32_t ne10, ne12;
        int32_t ne0, ne1;
+       uint32_t nb01, nb02, nb03;
+       uint32_t nb11, nb12, nb13;
        uint32_t r2, r3;
    } pushConsts {
        safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
        ne00, ne01, ne02,
        ne10, ne12,
        ne0, ne1,
+       nb01, nb02, nb03,
+       nb11, nb12, nb13,
        r2, r3
    };

    auto name = std::string(__func__) + "_" + suffix;
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(name)) {
-       const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
+       const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8;
        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(name);
@@ -1075,34 +1079,84 @@ static void ggml_vk_mul_mat_q8_0(Args&&... args) {
    ggml_vk_mul_mat_impl(spirv, "q8_0", 1/*We access blocks unaligned*/, std::forward<Args>(args)...);
 }

+static void ggml_vk_mul_mat_q4_k(
+   kp::Sequence& seq,
+   const std::shared_ptr<kp::Tensor>& inA,
+   const std::shared_ptr<kp::Tensor>& inB,
+   const std::shared_ptr<kp::Tensor>& out,
+   uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+   int32_t ne00, int32_t ne01, int32_t ne02,
+   int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+   int32_t ne0, int32_t ne1,
+   uint32_t nb01, uint32_t nb02, uint32_t nb03,
+   uint32_t nb11, uint32_t nb12, uint32_t nb13,
+   uint32_t r2, uint32_t r3
+) {
+   const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
+                                            kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
+
+   struct PushConstants {
+       uint32_t inAOff, inBOff, outOff;
+       int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
+       uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
+       uint32_t r2, r3;
+   } pushConsts {
+       inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
+       ne00, ne10, ne0, ne1, ne01, ne02, ne12,
+       nb01, nb02, nb03, nb11, nb12, nb13,
+       r2, r3
+   };
+
+   std::shared_ptr<kp::Algorithm> s_algo = nullptr;
+   if (!komputeManager()->hasAlgorithm(__func__)) {
+       s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)}, {}, {pushConsts});
+   } else {
+       s_algo = komputeManager()->getAlgorithm(__func__);
+       s_algo->setTensors({inA, inB, out});
+       s_algo->setWorkgroup({unsigned((ne01 + 3)/4), unsigned(ne11), unsigned(ne12) * unsigned(ne13)});
+       s_algo->setPushConstants<PushConstants>({pushConsts});
+       s_algo->updateDescriptors(s_kompute_context->pool.get());
+   }
+   seq.record<kp::OpAlgoDispatch>(s_algo);
+}
+
 static void ggml_vk_mul_mat_q6_k(
    kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& inA,
    const std::shared_ptr<kp::Tensor>& inB,
    const std::shared_ptr<kp::Tensor>& out,
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
-   int32_t ne00, int32_t
-   int32_t
+   int32_t ne00, int32_t ne01, int32_t ne02,
+   int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
+   int32_t ne0, int32_t ne1,
+   uint32_t nb01, uint32_t nb02, uint32_t nb03,
+   uint32_t nb11, uint32_t nb12, uint32_t nb13,
+   uint32_t r2, uint32_t r3
 ) {
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
                                             kp::shader_data::op_mul_mat_q6_k_comp_spv_len);

    struct PushConstants {
        uint32_t inAOff, inBOff, outOff;
-       int32_t ne00, ne10, ne0, ne1, ne01,
+       int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
+       uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
+       uint32_t r2, r3;
    } pushConsts {
        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
-       ne00, ne10, ne0, ne1, ne01, ne12
+       ne00, ne10, ne0, ne1, ne01, ne02, ne12,
+       nb01, nb02, nb03, nb11, nb12, nb13,
+       r2, r3
    };

    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(__func__)) {
-       const uint32_t local_x =
-
+       const uint32_t local_x = 2;
+       const uint32_t local_y = ggml_vk_current_device().subgroupSize;
+       s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts});
    } else {
        s_algo = komputeManager()->getAlgorithm(__func__);
        s_algo->setTensors({inA, inB, out});
-       s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
+       s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
    }
@@ -1190,10 +1244,11 @@ static void ggml_vk_rope(
    kp::Sequence& seq,
    const std::shared_ptr<kp::Tensor>& inA,
    const std::shared_ptr<kp::Tensor>& inB,
+   const std::shared_ptr<kp::Tensor>& inC,
    const std::shared_ptr<kp::Tensor>& out,
-   uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
+   uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
-   float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
+   float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
    int32_t ne01, int32_t ne02, int32_t ne03,
    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
    int32_t ne0,
@@ -1201,11 +1256,17 @@
 ) {
    GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);

-   static const auto
-       kp::shader_data::
+   static const auto spirv_norm_f16 = getSpirvShader(
+       kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
+   );
+   static const auto spirv_norm_f32 = getSpirvShader(
+       kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
    );
-   static const auto
-       kp::shader_data::
+   static const auto spirv_neox_f16 = getSpirvShader(
+       kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
+   );
+   static const auto spirv_neox_f32 = getSpirvShader(
+       kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
    );

    int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
@@ -1220,32 +1281,40 @@ static void ggml_vk_rope(
    GGML_ASSERT(nb0 % type_size == 0);

    struct PushConstants {
-       uint32_t inAOff, inBOff, outOff;
+       uint32_t inAOff, inBOff, inCOff, outOff;
        int32_t n_dims, mode, n_ctx_orig;
-       float freq_base, freq_scale
+       float freq_base, freq_scale;
+       bool has_freq_factors;
+       float ext_factor, attn_factor, beta_fast, beta_slow;
        uint32_t nb00, nb01, nb02, nb03;
        int32_t ne0;
        uint32_t nb0, nb1, nb2, nb3;
    } pushConsts {
-       safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
+       safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
        n_dims, mode, n_ctx_orig,
-       freq_base, freq_scale,
+       freq_base, freq_scale,
+       has_freq_factors,
+       ext_factor, attn_factor, beta_fast, beta_slow,
        nb00, nb01, nb02, nb03,
        ne0,
        nb0, nb1, nb2, nb3
    };

-   auto
+   auto & inC_ = inC ? inC : inA;
+   const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
+   const bool is_f16 = src0t == GGML_TYPE_F16;
+
+   auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
    if (!komputeManager()->hasAlgorithm(name)) {
+       auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
        s_algo = komputeManager()->algorithm<float, PushConstants>(
-           name, s_kompute_context->pool.get(), {inA, inB, out},
-           src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
+           name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
            {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
        );
    } else {
        s_algo = komputeManager()->getAlgorithm(name);
-       s_algo->setTensors({inA, inB, out});
+       s_algo->setTensors({inA, inB, inC_, out});
        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
        s_algo->setPushConstants<PushConstants>({pushConsts});
        s_algo->updateDescriptors(s_kompute_context->pool.get());
@@ -1323,22 +1392,16 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
    ggml_vk_cpy(spirv, 2, 4, std::forward<Args>(args)...);
 }

-static bool
-
-       case GGML_TYPE_F16:
-       case GGML_TYPE_F32:
-       case GGML_TYPE_Q4_0:
-       case GGML_TYPE_Q4_1:
-           break;
-       default:
-           return false;
-   }
-
+static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+   int64_t n = ggml_nelements(op);
    switch (op->op) {
        case GGML_OP_UNARY:
+           if (n % 4 != 0) return false;
            switch (ggml_get_unary_op(op)) {
-               case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_GELU:
+                   if (n % 8 != 0) return false;
+                   // fall through
+               case GGML_UNARY_OP_RELU:
                case GGML_UNARY_OP_SILU:
                    return ggml_is_contiguous(op->src[0]);
                default:
@@ -1356,8 +1419,18 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
        case GGML_OP_SOFT_MAX:
        case GGML_OP_RMS_NORM:
        case GGML_OP_NORM:
-       case GGML_OP_ROPE:
            return true;
+       case GGML_OP_ROPE:
+           {
+               const int mode = ((const int32_t *) op->op_params)[2];
+               if (mode & GGML_ROPE_TYPE_MROPE) {
+                   return false;
+               }
+               if (mode & GGML_ROPE_TYPE_VISION) {
+                   return false;
+               }
+               return true;
+           }
        case GGML_OP_DUP:
        case GGML_OP_CPY:
        case GGML_OP_CONT:
@@ -1396,12 +1469,13 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {

            switch (op->src[0]->type) {
                case GGML_TYPE_F32:
-               case GGML_TYPE_Q6_K:
                    return op->ne[3] == 1;
+               case GGML_TYPE_Q6_K:
                case GGML_TYPE_F16:
                case GGML_TYPE_Q8_0:
                case GGML_TYPE_Q4_0:
                case GGML_TYPE_Q4_1:
+               case GGML_TYPE_Q4_K:
                    return true;
                default:
                    ;
@@ -1410,6 +1484,8 @@ static bool ggml_vk_supports_op(const struct ggml_tensor * op) {
            ;
    }
    return false;
+
+   GGML_UNUSED(dev);
 }

 static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml_cgraph * gf) {
@@ -1458,11 +1534,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml

        any_commands_recorded = true;

-       if (!ggml_vk_supports_op(dst)) {
-           fprintf(stderr, "%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
-           GGML_ABORT("unsupported op");
-       }
-
        const int32_t ne00 = src0 ? src0->ne[0] : 0;
        const int32_t ne01 = src0 ? src0->ne[1] : 0;
        const int32_t ne02 = src0 ? src0->ne[2] : 0;
@@ -1500,9 +1571,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
        const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
        uint32_t off_src0 = 0;
        uint32_t off_src1 = 0;
+       uint32_t off_src2 = 0;
        uint32_t off_dst  = 0;
        const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
        const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
+       const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
        const std::shared_ptr<kp::Tensor>& id_dst  = dst  ? ggml_vk_get_tensor(dst,  &off_dst)  : nullTensor;

        switch (dst->op) {
@@ -1578,11 +1651,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5021")
                    GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);

-
-
-
+                   const int64_t nrows_x = ggml_nrows(src0);
+                   const int64_t nrows_y = src0->ne[1];
+
+                   const uint32_t n_head = nrows_x/nrows_y;
+                   const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

-
+                   const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                   const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+                   ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
                } break;
            case GGML_OP_DIAG_MASK_INF:
                {
@@ -1634,32 +1712,44 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                        case GGML_TYPE_F16:
                            ggml_vk_mul_mat_f16(
                                seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                               ne00, ne01, ne02, nb00, nb01, nb02,
+                               ne00, ne01, ne02, nb00, nb01, nb02, nb03,
+                               ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
                                ne0, ne1, r2, r3
                            );
                            break;
                        case GGML_TYPE_Q8_0:
                            ggml_vk_mul_mat_q8_0(
                                seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                               ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                               ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                               nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                            );
                            break;
                        case GGML_TYPE_Q4_0:
                            ggml_vk_mul_mat_q4_0(
                                seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                               ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                               ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                               nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                            );
                            break;
                        case GGML_TYPE_Q4_1:
                            ggml_vk_mul_mat_q4_1(
                                seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                               ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                               ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                               nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
+                           );
+                           break;
+                       case GGML_TYPE_Q4_K:
+                           ggml_vk_mul_mat_q4_k(
+                               seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
+                               ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                               nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                            );
                            break;
                        case GGML_TYPE_Q6_K:
                            ggml_vk_mul_mat_q6_k(
                                seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
-                               ne00, ne10,
+                               ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
+                               nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
                            );
                            break;
                        default: {
@@ -1688,13 +1778,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                } break;
            case GGML_OP_ROPE:
                {
-#pragma message("TODO: implement phi3 frequency factors support")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
-                   GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
-
-#pragma message("TODO: update rope NORM mode to match NEOX mode")
-#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
-
                    GGML_ASSERT(ne10 == ne02);
                    GGML_ASSERT(src0t == dstt);
                    // const int n_past = ((int32_t *) dst->op_params)[0];
@@ -1703,6 +1786,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                    // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
                    const int n_ctx_orig = ((int32_t *) dst->op_params)[4];

+                   const bool has_freq_factors = dst->src[2] != nullptr;
+
                    float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
                    memcpy(&freq_base,  (int32_t *) dst->op_params + 5, sizeof(float));
                    memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
@@ -1711,8 +1796,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                    memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
                    memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
                    ggml_vk_rope(
-                       seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
-                       freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
+                       seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
+                       freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
                        ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
                    );
                } break;
@@ -1820,11 +1905,6 @@ static void ggml_backend_kompute_device_unref(ggml_backend_buffer_type_t buft) {
    }
 }

-static const char * ggml_backend_kompute_buffer_get_name(ggml_backend_buffer_t buffer) {
-   auto * ctx = static_cast<ggml_backend_kompute_buffer_type_context *>(buffer->buft->context);
-   return ctx->name.c_str();
-}
-
 static void ggml_backend_kompute_buffer_free_buffer(ggml_backend_buffer_t buffer) {
    auto * memory = (ggml_vk_memory *)buffer->context;
    if (ggml_vk_has_device()) {
@@ -1868,7 +1948,6 @@ static void ggml_backend_kompute_buffer_clear(ggml_backend_buffer_t buffer, uint
 }

 static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = {
-   /* .get_name    = */ ggml_backend_kompute_buffer_get_name,
    /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer,
    /* .get_base    = */ ggml_backend_kompute_buffer_get_base,
    /* .init_tensor = */ NULL,
@@ -1913,25 +1992,31 @@ static ggml_backend_buffer_type_i ggml_backend_kompute_buffer_type_interface = {
 };

 ggml_backend_buffer_type_t ggml_backend_kompute_buffer_type(int device) {
-   static std::
-
-
-
-
-
-
-
-
-
-
+   static std::mutex mutex;
+   std::lock_guard<std::mutex> lock(mutex);
+
+   auto devices = ggml_vk_available_devices();
+   int32_t device_count = (int32_t) devices.size();
+   GGML_ASSERT(device < device_count);
+   GGML_ASSERT(devices.size() <= GGML_KOMPUTE_MAX_DEVICES);
+
+   static ggml_backend_buffer_type
+       ggml_backend_kompute_buffer_types[GGML_KOMPUTE_MAX_DEVICES];
+
+   static bool ggml_backend_kompute_buffer_type_initialized = false;
+
+   if (!ggml_backend_kompute_buffer_type_initialized) {
+       for (int32_t i = 0; i < device_count; i++) {
+           ggml_backend_kompute_buffer_types[i] = {
+               /* .iface   = */ ggml_backend_kompute_buffer_type_interface,
+               /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), i),
+               /* .context = */ new ggml_backend_kompute_buffer_type_context{ i, devices[i].bufferAlignment, devices[i].maxAlloc },
+           };
        }
-
-   }
+       ggml_backend_kompute_buffer_type_initialized = true;
+   }

-
-       return device == static_cast<ggml_backend_kompute_buffer_type_context *>(t.context)->device;
-   });
-   return it < bufts.end() ? &*it : nullptr;
+   return &ggml_backend_kompute_buffer_types[device];
 }

 // backend
@@ -1953,31 +2038,15 @@ static void ggml_backend_kompute_free(ggml_backend_t backend) {
    delete backend;
 }

-static ggml_backend_buffer_type_t ggml_backend_kompute_get_default_buffer_type(ggml_backend_t backend) {
-   auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
-   return ggml_backend_kompute_buffer_type(ctx->device);
-}
-
 static ggml_status ggml_backend_kompute_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
    auto * ctx = static_cast<ggml_kompute_context *>(backend->context);
    ggml_vk_graph_compute(ctx, cgraph);
    return GGML_STATUS_SUCCESS;
 }

-static bool ggml_backend_kompute_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-   GGML_UNUSED(backend);
-   return ggml_vk_supports_op(op);
-}
-
-static bool ggml_backend_kompute_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-   GGML_UNUSED(backend);
-   return buft->iface.get_name == ggml_backend_kompute_buffer_type_get_name;
-}
-
 static struct ggml_backend_i kompute_backend_i = {
    /* .get_name = */ ggml_backend_kompute_name,
    /* .free     = */ ggml_backend_kompute_free,
-   /* .get_default_buffer_type = */ ggml_backend_kompute_get_default_buffer_type,
    /* .set_tensor_async = */ NULL,
    /* .get_tensor_async = */ NULL,
    /* .cpy_tensor_async = */ NULL,
@@ -1987,9 +2056,6 @@ static struct ggml_backend_i kompute_backend_i = {
    /* .graph_plan_update  = */ NULL,
    /* .graph_plan_compute = */ NULL,
    /* .graph_compute      = */ ggml_backend_kompute_graph_compute,
-   /* .supports_op   = */ ggml_backend_kompute_supports_op,
-   /* .supports_buft = */ ggml_backend_kompute_supports_buft,
-   /* .offload_op    = */ NULL,
    /* .event_record = */ NULL,
    /* .event_wait   = */ NULL,
 };
@@ -2006,7 +2072,7 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
    ggml_backend_t kompute_backend = new ggml_backend {
        /* .guid      = */ ggml_backend_kompute_guid(),
        /* .interface = */ kompute_backend_i,
-       /* .device    = */
+       /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_kompute_reg(), device),
        /* .context   = */ s_kompute_context,
    };

@@ -2016,3 +2082,170 @@ ggml_backend_t ggml_backend_kompute_init(int device) {
 bool ggml_backend_is_kompute(ggml_backend_t backend) {
    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_kompute_guid());
 }
+
+static size_t ggml_backend_kompute_get_device_count() {
+   auto devices = ggml_vk_available_devices();
+   return devices.size();
+}
+
+static void ggml_backend_kompute_get_device_description(int device, char * description, size_t description_size) {
+   auto devices = ggml_vk_available_devices();
+   GGML_ASSERT((size_t) device < devices.size());
+   snprintf(description, description_size, "%s", devices[device].name);
+}
+
+static void ggml_backend_kompute_get_device_memory(int device, size_t * free, size_t * total) {
+   auto devices = ggml_vk_available_devices();
+   GGML_ASSERT((size_t) device < devices.size());
+   *total = devices[device].heapSize;
+   *free  = devices[device].heapSize;
+}
+
+//////////////////////////
+
+struct ggml_backend_kompute_device_context {
+   int device;
+   std::string name;
+   std::string description;
+};
+
+static const char * ggml_backend_kompute_device_get_name(ggml_backend_dev_t dev) {
+   ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+   return ctx->name.c_str();
+}
+
+static const char * ggml_backend_kompute_device_get_description(ggml_backend_dev_t dev) {
+   ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+   return ctx->description.c_str();
+}
+
+static void ggml_backend_kompute_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+   ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+   ggml_backend_kompute_get_device_memory(ctx->device, free, total);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_kompute_device_get_buffer_type(ggml_backend_dev_t dev) {
+   ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+   return ggml_backend_kompute_buffer_type(ctx->device);
+}
+
+static bool ggml_backend_kompute_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+   if (buft->iface.get_name != ggml_backend_kompute_buffer_type_get_name) {
+       return false;
+   }
+
+   ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+   ggml_backend_kompute_buffer_type_context * buft_ctx = (ggml_backend_kompute_buffer_type_context *)buft->context;
+
+   return buft_ctx->device == ctx->device;
+}
+
+static enum ggml_backend_dev_type ggml_backend_kompute_device_get_type(ggml_backend_dev_t dev) {
+   GGML_UNUSED(dev);
+   return GGML_BACKEND_DEVICE_TYPE_GPU;
+}
+
+static void ggml_backend_kompute_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+   props->name        = ggml_backend_kompute_device_get_name(dev);
+   props->description = ggml_backend_kompute_device_get_description(dev);
+   props->type        = ggml_backend_kompute_device_get_type(dev);
+   ggml_backend_kompute_device_get_memory(dev, &props->memory_free, &props->memory_total);
+   props->caps = {
+       /* async                 = */ false,
+       /* host_buffer           = */ false,
+       /* .buffer_from_host_ptr = */ false,
+       /* events                = */ false,
+   };
+}
+
+static ggml_backend_t ggml_backend_kompute_device_init(ggml_backend_dev_t dev, const char * params) {
+   GGML_UNUSED(params);
+   ggml_backend_kompute_device_context * ctx = (ggml_backend_kompute_device_context *)dev->context;
+   return ggml_backend_kompute_init(ctx->device);
+}
+
+static bool ggml_backend_kompute_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+   const int min_batch_size = 32;
+
+   return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+          (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
+
+   GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_kompute_device_i = {
+   /* .get_name             = */ ggml_backend_kompute_device_get_name,
+   /* .get_description      = */ ggml_backend_kompute_device_get_description,
+   /* .get_memory           = */ ggml_backend_kompute_device_get_memory,
+   /* .get_type             = */ ggml_backend_kompute_device_get_type,
+   /* .get_props            = */ ggml_backend_kompute_device_get_props,
+   /* .init_backend         = */ ggml_backend_kompute_device_init,
+   /* .get_buffer_type      = */ ggml_backend_kompute_device_get_buffer_type,
+   /* .get_host_buffer_type = */ NULL,
+   /* .buffer_from_host_ptr = */ NULL,
+   /* .supports_op          = */ ggml_backend_kompute_device_supports_op,
+   /* .supports_buft        = */ ggml_backend_kompute_device_supports_buft,
+   /* .offload_op           = */ ggml_backend_kompute_device_offload_op,
+   /* .event_new            = */ NULL,
+   /* .event_free           = */ NULL,
+   /* .event_synchronize    = */ NULL,
+};
+
+static const char * ggml_backend_kompute_reg_get_name(ggml_backend_reg_t reg) {
+   GGML_UNUSED(reg);
+   return "Kompute";
+}
+
+static size_t ggml_backend_kompute_reg_get_device_count(ggml_backend_reg_t reg) {
+   GGML_UNUSED(reg);
+   return ggml_backend_kompute_get_device_count();
+}
+
+static ggml_backend_dev_t ggml_backend_kompute_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+   static std::vector<ggml_backend_dev_t> devices;
+
+   static bool initialized = false;
+
+   {
+       static std::mutex mutex;
+       std::lock_guard<std::mutex> lock(mutex);
+       if (!initialized) {
+           for (size_t i = 0; i < ggml_backend_kompute_get_device_count(); i++) {
+               ggml_backend_kompute_device_context * ctx = new ggml_backend_kompute_device_context;
+               char desc[256];
+               ggml_backend_kompute_get_device_description(i, desc, sizeof(desc));
+               ctx->device = i;
+               ctx->name = "Kompute" + std::to_string(i);
+               ctx->description = desc;
+               devices.push_back(new ggml_backend_device {
+                   /* .iface   = */ ggml_backend_kompute_device_i,
+                   /* .reg     = */ reg,
+                   /* .context = */ ctx,
+               });
+           }
+           initialized = true;
+       }
+   }
+
+   GGML_ASSERT(device < devices.size());
+   return devices[device];
+}
+
+static const struct ggml_backend_reg_i ggml_backend_kompute_reg_i = {
+   /* .get_name         = */ ggml_backend_kompute_reg_get_name,
+   /* .get_device_count = */ ggml_backend_kompute_reg_get_device_count,
+   /* .get_device       = */ ggml_backend_kompute_reg_get_device,
+   /* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_kompute_reg() {
+   static ggml_backend_reg reg = {
+       /* .api_version = */ GGML_BACKEND_API_VERSION,
+       /* .iface       = */ ggml_backend_kompute_reg_i,
+       /* .context     = */ nullptr,
+   };
+
+   return &reg;
+}
+
+GGML_BACKEND_DL_IMPL(ggml_backend_kompute_reg)
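
For context on the Kompute changes above: the backend now registers itself through the ggml backend registry (ggml_backend_kompute_reg) and exposes supports_op/supports_buft/offload_op at the device level instead of on the backend object. Below is a minimal sketch of how a consumer of the vendored llama.cpp might enumerate Kompute devices and bring one up through that registry. This is an illustration only, not code shipped in @fugood/llama.node; it assumes the ggml-backend.h registry helpers that accompany the API version shown in the diff (ggml_backend_reg_dev_count, ggml_backend_dev_name, ggml_backend_dev_description, ggml_backend_dev_init, ggml_backend_free).

// Hypothetical usage sketch: enumerate the devices exposed by the new
// Kompute registry entry point and initialize a backend on the first one.
#include "ggml-backend.h"
#include "ggml-kompute.h"
#include <cstdio>

int main() {
    // Entry point added in this diff; returns the static registry singleton.
    ggml_backend_reg_t reg = ggml_backend_kompute_reg();

    // One ggml_backend_dev_t per Vulkan device reported by ggml_vk_available_devices().
    size_t n_dev = ggml_backend_reg_dev_count(reg);
    for (size_t i = 0; i < n_dev; i++) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);
        printf("Kompute device %zu: %s (%s)\n", i,
               ggml_backend_dev_name(dev),          // "Kompute0", "Kompute1", ...
               ggml_backend_dev_description(dev));  // underlying Vulkan device name
    }

    if (n_dev > 0) {
        // Routed through the device interface; equivalent to ggml_backend_kompute_init(0).
        ggml_backend_t backend = ggml_backend_dev_init(ggml_backend_reg_dev_get(reg, 0), /*params=*/nullptr);
        // ... build a ggml graph and run it with ggml_backend_graph_compute(backend, graph) ...
        ggml_backend_free(backend);
    }
    return 0;
}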