@fugood/llama.node 0.3.2 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +7 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +17 -7
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +89 -27
- package/src/LlamaContext.h +2 -0
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +240 -168
- package/src/llama.cpp/.github/workflows/docker.yml +8 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +14 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -4
- package/src/llama.cpp/common/arg.cpp +986 -770
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +212 -351
- package/src/llama.cpp/common/common.h +204 -117
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +163 -121
- package/src/llama.cpp/common/sampling.h +41 -20
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +134 -161
- package/src/llama.cpp/examples/CMakeLists.txt +33 -14
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +19 -18
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +41 -87
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +263 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +83 -22
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
- package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +73 -114
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
- package/src/llama.cpp/examples/server/server.cpp +2073 -1339
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +354 -277
- package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/simple/simple.cpp +130 -94
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
- package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +159 -417
- package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +93 -52
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +4 -8
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +779 -194
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +55 -10
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +4317 -2979
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -38
- package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
- package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +62 -20
- package/src/llama.cpp/tests/test-sampling.cpp +163 -138
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
|
@@ -8,7 +8,6 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
|
|
|
8
8
|
|
|
9
9
|
const int nthreads = item_ct1.get_local_range(2);
|
|
10
10
|
const int nwarps = nthreads / WARP_SIZE;
|
|
11
|
-
assert(nwarps % WARP_SIZE == 0);
|
|
12
11
|
sycl::float2 mean_var = sycl::float2(0.f, 0.f);
|
|
13
12
|
|
|
14
13
|
for (int col = tid; col < ncols; col += block_size) {
|
|
@@ -32,7 +31,7 @@ static void norm_f32(const float* x, float* dst, const int ncols, const float ep
|
|
|
32
31
|
*/
|
|
33
32
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
34
33
|
mean_var = 0.f;
|
|
35
|
-
|
|
34
|
+
size_t nreduce = nwarps / WARP_SIZE;
|
|
36
35
|
for (size_t i = 0; i < nreduce; i += 1)
|
|
37
36
|
{
|
|
38
37
|
mean_var += s_sum[lane_id + i * WARP_SIZE];
|
|
@@ -55,9 +54,8 @@ static void group_norm_f32(const float* x, float* dst, const int group_size, con
|
|
|
55
54
|
int end = start + group_size;
|
|
56
55
|
const int nthreads = item_ct1.get_local_range(2);
|
|
57
56
|
const int nwarps = nthreads / WARP_SIZE;
|
|
58
|
-
assert(nwarps % WARP_SIZE == 0);
|
|
59
57
|
start += item_ct1.get_local_id(2);
|
|
60
|
-
|
|
58
|
+
size_t nreduce = nwarps / WARP_SIZE;
|
|
61
59
|
|
|
62
60
|
if (end >= ne_elements) {
|
|
63
61
|
end = ne_elements;
|
|
@@ -144,7 +142,6 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
|
|
|
144
142
|
const int tid = item_ct1.get_local_id(2);
|
|
145
143
|
const int nthreads = item_ct1.get_local_range(2);
|
|
146
144
|
const int nwarps = nthreads / WARP_SIZE;
|
|
147
|
-
assert(nwarps % WARP_SIZE == 0);
|
|
148
145
|
float tmp = 0.0f; // partial sum for thread in warp
|
|
149
146
|
|
|
150
147
|
for (int col = tid; col < ncols; col += block_size) {
|
|
@@ -166,7 +163,7 @@ static void rms_norm_f32(const float* x, float* dst, const int ncols, const floa
|
|
|
166
163
|
converged control flow. You may need to adjust the code.
|
|
167
164
|
*/
|
|
168
165
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
169
|
-
|
|
166
|
+
size_t nreduce = nwarps / WARP_SIZE;
|
|
170
167
|
tmp = 0.f;
|
|
171
168
|
for (size_t i = 0; i < nreduce; i += 1)
|
|
172
169
|
{
|
|
@@ -202,6 +199,7 @@ static void norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
202
199
|
}
|
|
203
200
|
else {
|
|
204
201
|
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
|
202
|
+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
|
205
203
|
const sycl::range<3> block_dims(1, 1, work_group_size);
|
|
206
204
|
/*
|
|
207
205
|
DPCT1049:17: The work-group size passed to the SYCL kernel may exceed
|
|
@@ -244,6 +242,7 @@ static void group_norm_f32_sycl(const float* x, float* dst,
|
|
|
244
242
|
}
|
|
245
243
|
else {
|
|
246
244
|
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
|
245
|
+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
|
247
246
|
const sycl::range<3> block_dims(1, 1, work_group_size);
|
|
248
247
|
/*
|
|
249
248
|
DPCT1049:18: The work-group size passed to the SYCL kernel may exceed
|
|
@@ -290,6 +289,7 @@ static void rms_norm_f32_sycl(const float* x, float* dst, const int ncols,
|
|
|
290
289
|
}
|
|
291
290
|
else {
|
|
292
291
|
const int work_group_size = ggml_sycl_info().max_work_group_sizes[device];
|
|
292
|
+
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
|
|
293
293
|
const sycl::range<3> block_dims(1, 1, work_group_size);
|
|
294
294
|
/*
|
|
295
295
|
DPCT1049:19: The work-group size passed to the SYCL kernel may exceed
|
|
@@ -352,6 +352,7 @@ void ggml_sycl_op_group_norm(ggml_backend_sycl_context& ctx, const ggml_tensor*
|
|
|
352
352
|
(void)src1;
|
|
353
353
|
(void)dst;
|
|
354
354
|
(void)src1_dd;
|
|
355
|
+
GGML_UNUSED(ctx);
|
|
355
356
|
}
|
|
356
357
|
|
|
357
358
|
void ggml_sycl_op_rms_norm(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
#include <sycl/sycl.hpp>
|
|
2
|
+
#include <oneapi/mkl.hpp>
|
|
3
|
+
#include "outprod.hpp"
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
|
7
|
+
const ggml_tensor* src1, ggml_tensor* dst) {
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
|
11
|
+
GGML_ASSERT(src1->type == GGML_TYPE_F32);
|
|
12
|
+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
|
|
13
|
+
GGML_ASSERT(ggml_is_contiguous(src0));
|
|
14
|
+
GGML_ASSERT(ggml_is_contiguous(dst));
|
|
15
|
+
|
|
16
|
+
GGML_TENSOR_BINARY_OP_LOCALS
|
|
17
|
+
|
|
18
|
+
// Get SYCL queue
|
|
19
|
+
dpct::queue_ptr stream = ctx.stream();
|
|
20
|
+
|
|
21
|
+
// Dimension checks
|
|
22
|
+
GGML_ASSERT(ne01 == ne11); // Inner dimensions must match
|
|
23
|
+
GGML_ASSERT(ne0 == ne00); // Output rows match src0 rows
|
|
24
|
+
GGML_ASSERT(ne1 == ne10); // Output cols match src1 cols
|
|
25
|
+
|
|
26
|
+
// Get data pointers
|
|
27
|
+
const float* src0_d = (const float*)src0->data;
|
|
28
|
+
const float* src1_d = (const float*)src1->data;
|
|
29
|
+
float* dst_d = (float*)dst->data;
|
|
30
|
+
|
|
31
|
+
// GEMM parameters
|
|
32
|
+
const float alpha = 1.0f;
|
|
33
|
+
const float beta = 0.0f;
|
|
34
|
+
|
|
35
|
+
// Handle transposition of src1
|
|
36
|
+
const bool src1_T = ggml_is_transposed(src1);
|
|
37
|
+
const oneapi::mkl::transpose src1_op =
|
|
38
|
+
src1_T ? oneapi::mkl::transpose::nontrans : oneapi::mkl::transpose::trans;
|
|
39
|
+
const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
|
|
40
|
+
|
|
41
|
+
try {
|
|
42
|
+
// Perform matrix multiplication using oneMKL GEMM
|
|
43
|
+
#ifdef GGML_SYCL_NVIDIA
|
|
44
|
+
oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector<oneapi::mkl::backend::cublas>{ *stream },
|
|
45
|
+
oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha, src0_d,
|
|
46
|
+
ne00, src1_d, ldb, beta, dst_d, ne0);
|
|
47
|
+
#else
|
|
48
|
+
oneapi::mkl::blas::column_major::gemm(*stream, oneapi::mkl::transpose::nontrans, src1_op, ne0, ne1, ne01, alpha,
|
|
49
|
+
src0_d, ne00, src1_d, ldb, beta, dst_d, ne0);
|
|
50
|
+
#endif
|
|
51
|
+
}
|
|
52
|
+
catch (sycl::exception const& exc) {
|
|
53
|
+
std::cerr << exc.what() << std::endl;
|
|
54
|
+
GGML_ASSERT(false);
|
|
55
|
+
}
|
|
56
|
+
}
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
#ifndef GGML_SYCL_OUTPROD_HPP
|
|
2
|
+
#define GGML_SYCL_OUTPROD_HPP
|
|
3
|
+
|
|
4
|
+
#include "common.hpp"
|
|
5
|
+
|
|
6
|
+
void ggml_sycl_op_out_prod(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
|
|
7
|
+
const ggml_tensor* src1, ggml_tensor* dst);
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
#endif // GGML_SYCL_OUTPROD_HPP
|
|
11
|
+
|
|
@@ -25,6 +25,11 @@
|
|
|
25
25
|
#define SYCL_RELU_BLOCK_SIZE 256
|
|
26
26
|
#define SYCL_HARDSIGMOID_BLOCK_SIZE 256
|
|
27
27
|
#define SYCL_HARDSWISH_BLOCK_SIZE 256
|
|
28
|
+
#define SYCL_EXP_BLOCK_SIZE 256
|
|
29
|
+
#define SYCL_NEG_BLOCK_SIZE 256
|
|
30
|
+
#define SYCL_SIGMOID_BLOCK_SIZE 256
|
|
31
|
+
#define SYCL_SQRT_BLOCK_SIZE 256
|
|
32
|
+
#define SYCL_SIN_BLOCK_SIZE 256
|
|
28
33
|
#define SYCL_SQR_BLOCK_SIZE 256
|
|
29
34
|
#define SYCL_CPY_BLOCK_SIZE 32
|
|
30
35
|
#define SYCL_SCALE_BLOCK_SIZE 256
|
|
@@ -41,6 +46,7 @@
|
|
|
41
46
|
#define SYCL_ACC_BLOCK_SIZE 256
|
|
42
47
|
#define SYCL_IM2COL_BLOCK_SIZE 256
|
|
43
48
|
#define SYCL_POOL2D_BLOCK_SIZE 256
|
|
49
|
+
#define SYCL_ARGMAX_BLOCK_SIZE 256
|
|
44
50
|
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
|
|
45
51
|
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
|
|
46
52
|
|
|
@@ -16,7 +16,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
|
|
16
16
|
const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
|
|
17
17
|
const int nthreads = block_size;
|
|
18
18
|
const int nwarps = nthreads / WARP_SIZE;
|
|
19
|
-
|
|
19
|
+
size_t nreduce = nwarps / WARP_SIZE;
|
|
20
20
|
float slope = 1.0f;
|
|
21
21
|
|
|
22
22
|
// ALiBi
|
|
@@ -53,8 +53,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
|
|
53
53
|
if (block_size > WARP_SIZE) {
|
|
54
54
|
if (warp_id == 0) {
|
|
55
55
|
buf[lane_id] = -INFINITY;
|
|
56
|
-
for (size_t i = 1; i < nreduce; i += 1)
|
|
56
|
+
for (size_t i = 1; i < nreduce; i += 1) {
|
|
57
57
|
buf[lane_id + i * WARP_SIZE] = -INFINITY;
|
|
58
|
+
}
|
|
58
59
|
}
|
|
59
60
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
60
61
|
|
|
@@ -63,8 +64,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
|
|
63
64
|
}
|
|
64
65
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
65
66
|
max_val = buf[lane_id];
|
|
66
|
-
for (size_t i = 1; i < nreduce; i += 1)
|
|
67
|
-
{
|
|
67
|
+
for (size_t i = 1; i < nreduce; i += 1) {
|
|
68
68
|
max_val = std::max(max_val, buf[lane_id + i * WARP_SIZE]);
|
|
69
69
|
}
|
|
70
70
|
max_val = warp_reduce_max(max_val, item_ct1);
|
|
@@ -89,8 +89,9 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
|
|
89
89
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
90
90
|
if (warp_id == 0) {
|
|
91
91
|
buf[lane_id] = 0.f;
|
|
92
|
-
for (size_t i = 1; i < nreduce; i += 1)
|
|
92
|
+
for (size_t i = 1; i < nreduce; i += 1) {
|
|
93
93
|
buf[lane_id + i * WARP_SIZE] = 0.f;
|
|
94
|
+
}
|
|
94
95
|
}
|
|
95
96
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
96
97
|
|
|
@@ -100,8 +101,7 @@ static void soft_max_f32(const float * x, const float * mask, float * dst, const
|
|
|
100
101
|
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
101
102
|
|
|
102
103
|
tmp = buf[lane_id];
|
|
103
|
-
for (size_t i = 1; i < nreduce; i += 1)
|
|
104
|
-
{
|
|
104
|
+
for (size_t i = 1; i < nreduce; i += 1) {
|
|
105
105
|
tmp += buf[lane_id + i * WARP_SIZE];
|
|
106
106
|
}
|
|
107
107
|
tmp = warp_reduce_sum(tmp, item_ct1);
|
|
@@ -968,8 +968,8 @@ vec_dot_iq3_xxs_q8_1(const void *__restrict__ vbq,
|
|
|
968
968
|
grid1[0] ^ signs[0], signs[0], std::minus<>());
|
|
969
969
|
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
|
|
970
970
|
grid2[0] ^ signs[1], signs[1], std::minus<>());
|
|
971
|
-
sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
|
|
972
|
-
sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
|
|
971
|
+
sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
|
|
972
|
+
sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
|
|
973
973
|
q8 += 8;
|
|
974
974
|
aux32 >>= 7;
|
|
975
975
|
}
|
|
@@ -1009,8 +1009,8 @@ vec_dot_iq3_s_q8_1(const void *__restrict__ vbq,
|
|
|
1009
1009
|
grid1[0] ^ signs0, signs0, std::minus<>());
|
|
1010
1010
|
const int grid_h = dpct::vectorized_binary<sycl::uchar4>(
|
|
1011
1011
|
grid2[0] ^ signs1, signs1, std::minus<>());
|
|
1012
|
-
sumi = dpct::dp4a(grid_l, *((int *)q8 + 0), sumi);
|
|
1013
|
-
sumi = dpct::dp4a(grid_h, *((int *)q8 + 1), sumi);
|
|
1012
|
+
sumi = dpct::dp4a(grid_l, *((const int *)q8 + 0), sumi);
|
|
1013
|
+
sumi = dpct::dp4a(grid_h, *((const int *)q8 + 1), sumi);
|
|
1014
1014
|
q8 += 8;
|
|
1015
1015
|
}
|
|
1016
1016
|
const float d =
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
#include <sycl/sycl.hpp>
|
|
2
|
+
#include "wkv6.hpp"
|
|
3
|
+
|
|
4
|
+
constexpr int WKV_BLOCK_SIZE = 64; // Matching CUDA_WKV_BLOCK_SIZE
|
|
5
|
+
|
|
6
|
+
// Helper function for the main kernel
|
|
7
|
+
static void rwkv_wkv_f32_kernel(
|
|
8
|
+
const int B, const int T, const int C, const int H,
|
|
9
|
+
const float* k, const float* v, const float* r,
|
|
10
|
+
const float* tf, const float* td, const float* s,
|
|
11
|
+
float* dst, const sycl::nd_item<3>& item_ct1, float* shared_mem) {
|
|
12
|
+
|
|
13
|
+
const int tid = item_ct1.get_local_id(2);
|
|
14
|
+
const int bid = item_ct1.get_group(2);
|
|
15
|
+
|
|
16
|
+
const int head_size = WKV_BLOCK_SIZE;
|
|
17
|
+
const int batch_i = bid / H;
|
|
18
|
+
const int head_i = bid % H;
|
|
19
|
+
const int state_size = C * head_size;
|
|
20
|
+
const int n_seq_tokens = T / B;
|
|
21
|
+
|
|
22
|
+
// Set up shared memory pointers
|
|
23
|
+
float* _k = shared_mem;
|
|
24
|
+
float* _r = _k + head_size;
|
|
25
|
+
float* _tf = _r + head_size;
|
|
26
|
+
float* _td = _tf + head_size;
|
|
27
|
+
|
|
28
|
+
// Local state array
|
|
29
|
+
float state[WKV_BLOCK_SIZE];
|
|
30
|
+
|
|
31
|
+
// Load initial state
|
|
32
|
+
#pragma unroll
|
|
33
|
+
for (int i = 0; i < head_size; i++) {
|
|
34
|
+
state[i] = s[batch_i * state_size + head_i * head_size * head_size + i * head_size + tid];
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
// Sync threads before shared memory operations
|
|
38
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
39
|
+
|
|
40
|
+
// Load time-mixing parameters
|
|
41
|
+
_tf[tid] = tf[head_i * head_size + tid];
|
|
42
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
43
|
+
|
|
44
|
+
// Main sequence processing loop
|
|
45
|
+
for (int t = batch_i * n_seq_tokens * C + head_i * head_size + tid;
|
|
46
|
+
t < (batch_i + 1) * n_seq_tokens * C + head_i * head_size + tid;
|
|
47
|
+
t += C) {
|
|
48
|
+
|
|
49
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
50
|
+
|
|
51
|
+
// Load current timestep data to shared memory
|
|
52
|
+
_k[tid] = k[t];
|
|
53
|
+
_r[tid] = r[t];
|
|
54
|
+
_td[tid] = td[t];
|
|
55
|
+
|
|
56
|
+
item_ct1.barrier(sycl::access::fence_space::local_space);
|
|
57
|
+
|
|
58
|
+
const float _v = v[t];
|
|
59
|
+
float y = 0;
|
|
60
|
+
|
|
61
|
+
// Process in chunks of 4 for better vectorization
|
|
62
|
+
sycl::float4 k4, r4, tf4, td4, s4;
|
|
63
|
+
#pragma unroll
|
|
64
|
+
for (int j = 0; j < head_size; j += 4) {
|
|
65
|
+
// Load data in vec4 chunks
|
|
66
|
+
k4 = sycl::float4(_k[j], _k[j+1], _k[j+2], _k[j+3]);
|
|
67
|
+
r4 = sycl::float4(_r[j], _r[j+1], _r[j+2], _r[j+3]);
|
|
68
|
+
tf4 = sycl::float4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]);
|
|
69
|
+
td4 = sycl::float4(_td[j], _td[j+1], _td[j+2], _td[j+3]);
|
|
70
|
+
s4 = sycl::float4(state[j], state[j+1], state[j+2], state[j+3]);
|
|
71
|
+
|
|
72
|
+
// Compute key-value product
|
|
73
|
+
sycl::float4 kv4 = k4 * _v;
|
|
74
|
+
|
|
75
|
+
// Accumulate weighted sum
|
|
76
|
+
y += sycl::dot(r4, tf4 * kv4 + s4);
|
|
77
|
+
|
|
78
|
+
// Update state
|
|
79
|
+
s4 = s4 * td4 + kv4;
|
|
80
|
+
|
|
81
|
+
// Store updated state
|
|
82
|
+
state[j] = s4.x();
|
|
83
|
+
state[j+1] = s4.y();
|
|
84
|
+
state[j+2] = s4.z();
|
|
85
|
+
state[j+3] = s4.w();
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
dst[t] = y;
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
// Save final state
|
|
92
|
+
#pragma unroll
|
|
93
|
+
for (int i = 0; i < head_size; i++) {
|
|
94
|
+
dst[T * C + batch_i * state_size + head_i * head_size * head_size + i * head_size + tid] = state[i];
|
|
95
|
+
}
|
|
96
|
+
}
|
|
97
|
+
|
|
98
|
+
void ggml_sycl_op_rwkv_wkv6(ggml_backend_sycl_context& ctx, const ggml_tensor* src0,
    const ggml_tensor* src1, ggml_tensor* dst) {
    // RWKV6 wkv operator (SYCL backend). All real operands come from dst->src:
    //   src[0]=k, src[1]=v, src[2]=r, src[3]=tf (time_first),
    //   src[4]=td (time_decay), src[5]=s (initial state).
    // The src0/src1 parameters exist only to match the backend op signature.
    const float* k_d  = (const float*)dst->src[0]->data;
    const float* v_d  = (const float*)dst->src[1]->data;
    const float* r_d  = (const float*)dst->src[2]->data;
    const float* tf_d = (const float*)dst->src[3]->data;
    const float* td_d = (const float*)dst->src[4]->data;
    const float* s_d  = (const float*)dst->src[5]->data;
    float* dst_d = (float*)dst->data;

    const int64_t B = dst->src[5]->ne[1];  // batch size (taken from the state tensor)
    const int64_t T = dst->src[0]->ne[3];  // sequence length
    const int64_t C = dst->ne[0];          // total channels
    const int64_t H = dst->src[0]->ne[2];  // number of heads; head size = C / H

    GGML_ASSERT(dst->src[5]->type == GGML_TYPE_F32);
    GGML_ASSERT(C % H == 0);
    GGML_ASSERT(C / H == WKV_BLOCK_SIZE); // The current sycl kernel is designed for RWKV6, HEAD_SIZE == 64

    dpct::queue_ptr stream = ctx.stream();

    // Shared memory holds one head-sized row for each of k, r, tf and td.
    // NOTE: sycl::local_accessor takes its range in *elements*, not bytes;
    // passing a byte count here would over-allocate local memory by a factor
    // of sizeof(float).
    const size_t shared_mem_floats = WKV_BLOCK_SIZE * 4; // k, r, tf, td
    sycl::range<3> block_dims(1, 1, C / H);  // one work-item per channel within a head
    sycl::range<3> grid_dims(1, 1, B * H);   // one work-group per (batch, head) pair

    // Submit the kernel
    stream->submit([&](sycl::handler& cgh) {
        sycl::local_accessor<float, 1> shared_mem_acc(sycl::range<1>(shared_mem_floats), cgh);

        cgh.parallel_for(
            sycl::nd_range<3>(grid_dims * block_dims, block_dims),
            [=](sycl::nd_item<3> item_ct1) {
                rwkv_wkv_f32_kernel(
                    B, T, C, H, k_d, v_d, r_d, tf_d, td_d, s_d, dst_d,
                    item_ct1, shared_mem_acc.get_pointer()
                );
            });
    });

    GGML_UNUSED(src0);
    GGML_UNUSED(src1);
}
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
#include "ggml-threading.h"
|
|
2
|
+
#include <mutex>
|
|
3
|
+
|
|
4
|
+
// Global mutex backing ggml's critical-section API. Non-recursive: a thread
// that calls ggml_critical_section_start() twice without an intervening
// ggml_critical_section_end() will deadlock.
std::mutex ggml_critical_section_mutex;

// Enter the global critical section. Blocks until the lock is acquired.
void ggml_critical_section_start() {
    ggml_critical_section_mutex.lock();
}

// Leave the global critical section. Must only be called by the thread that
// currently holds the lock.
// (Parameter list unified with ggml_critical_section_start(): the original
// mixed `()` and `(void)`, which are identical in C++ but inconsistent.)
void ggml_critical_section_end() {
    ggml_critical_section_mutex.unlock();
}
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
find_package(Vulkan COMPONENTS glslc REQUIRED)

if (Vulkan_FOUND)
    message(STATUS "Vulkan found")

    ggml_add_backend_library(ggml-vulkan
                             ggml-vulkan.cpp
                             ../../include/ggml-vulkan.h
                            )

    # Compile a test shader to determine whether GL_NV_cooperative_matrix2 is supported.
    # If it's not, there will be an error to stderr.
    # If it's supported, set a define to indicate that we should compile those shaders
    execute_process(COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/test_coopmat2_support.comp"
                    OUTPUT_VARIABLE glslc_output
                    ERROR_VARIABLE glslc_error)

    # Quote the expansion: glslc's stderr may contain semicolons, which would
    # split an unquoted ${glslc_error} into multiple if() arguments.
    if ("${glslc_error}" MATCHES ".*extension not supported: GL_NV_cooperative_matrix2.*")
        message(STATUS "GL_NV_cooperative_matrix2 not supported by glslc")
    else()
        message(STATUS "GL_NV_cooperative_matrix2 supported by glslc")
        # Directory-scoped on purpose: this define must also reach the
        # vulkan-shaders subdirectory added below (the shader generator is
        # built with it), so target_compile_definitions would not suffice.
        add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
    endif()

    target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan)
    target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR})

    # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build
    # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector
    if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
        add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0)
    endif()

    # Optional debug/validation toggles: each GGML_VULKAN_<feature> cache
    # option maps 1:1 onto a compile definition of the same name.
    foreach (feature CHECK_RESULTS DEBUG MEMORY_DEBUG SHADER_DEBUG_INFO PERF VALIDATE RUN_TESTS)
        if (GGML_VULKAN_${feature})
            add_compile_definitions(GGML_VULKAN_${feature})
        endif()
    endforeach()

    add_subdirectory(vulkan-shaders)

    set (_ggml_vk_genshaders_cmd vulkan-shaders-gen)
    set (_ggml_vk_header     ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp)
    set (_ggml_vk_source     ${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.cpp)
    set (_ggml_vk_input_dir  ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders)
    set (_ggml_vk_output_dir ${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv)

    # Glob is only used for the DEPENDS list of the custom command below;
    # note that adding a new .comp file still requires a re-configure.
    file(GLOB _ggml_vk_shader_deps "${_ggml_vk_input_dir}/*.comp")

    # Generate the embedded shader header/source from the .comp files.
    # VERBATIM ensures command-argument escaping is platform-independent.
    add_custom_command(
        OUTPUT ${_ggml_vk_header}
               ${_ggml_vk_source}

        COMMAND ${_ggml_vk_genshaders_cmd}
            --glslc      ${Vulkan_GLSLC_EXECUTABLE}
            --input-dir  ${_ggml_vk_input_dir}
            --output-dir ${_ggml_vk_output_dir}
            --target-hpp ${_ggml_vk_header}
            --target-cpp ${_ggml_vk_source}
            --no-clean

        DEPENDS ${_ggml_vk_shader_deps} ${_ggml_vk_genshaders_cmd}
        COMMENT "Generate vulkan shaders"
        VERBATIM
    )

    target_sources(ggml-vulkan PRIVATE ${_ggml_vk_source} ${_ggml_vk_header})

else()
    # Unreachable in practice: find_package() above is REQUIRED, so a missing
    # Vulkan SDK aborts configuration before this branch. Kept for safety in
    # case REQUIRED is ever relaxed.
    message(WARNING "Vulkan not found")
endif()
|