@fugood/llama.node 0.3.16 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
Diff body below: package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp (+266 -72). Removed lines whose content the diff viewer did not capture are shown as bare "-" markers.

@@ -3,7 +3,9 @@
 //
 #include <arm_neon.h>
 #include <assert.h>
+#include <atomic>
 #include <cfloat>
+#include <stdexcept>
 #include <stdint.h>
 #include <string.h>
 #if defined(__linux__)
@@ -34,8 +36,9 @@
 #include "ggml-common.h"
 
 struct ggml_kleidiai_context {
+    cpu_feature features;
     ggml_kleidiai_kernels * kernels;
-} static ctx = { NULL };
+} static ctx = { CPU_FEATURE_NONE, NULL };
 
 static void init_kleidiai_context(void) {
 
@@ -47,18 +50,18 @@ static void init_kleidiai_context(void) {
         const char *env_var = getenv("GGML_KLEIDIAI_SME");
         int sme_enabled = 0;
 
-
-
-
+        ctx.features = (ggml_cpu_has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
+                       (ggml_cpu_has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
+                       (ggml_cpu_has_sve()         ? CPU_FEATURE_SVE     : CPU_FEATURE_NONE);
 
         if (env_var) {
             sme_enabled = atoi(env_var);
         }
 
         if (sme_enabled != 0) {
-            features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
+            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
-        ctx.kernels =
+        ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
     }
     ggml_critical_section_end();
 }
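The hunk above caches the detected CPU features in the global context so kernel tables can be selected once per process. A minimal standalone sketch of the same flag-mask pattern; the probe functions and enum values here are stand-ins for ggml_cpu_has_*() and the real cpu_feature definitions, not the actual ones:

#include <cstdio>

// Stand-ins for ggml_cpu_has_dotprod()/_matmul_int8()/_sve(); the real
// functions probe the CPU at runtime.
static bool has_dotprod()     { return true;  }
static bool has_matmul_int8() { return true;  }
static bool has_sve()         { return false; }

// Assumed power-of-two flag values so features can be OR-combined.
enum cpu_feature {
    CPU_FEATURE_NONE    = 0,
    CPU_FEATURE_DOTPROD = 1,
    CPU_FEATURE_I8MM    = 2,
    CPU_FEATURE_SVE     = 4,
    CPU_FEATURE_SME     = 8,
};

inline cpu_feature operator|(cpu_feature a, cpu_feature b) {
    return static_cast<cpu_feature>(static_cast<int>(a) | static_cast<int>(b));
}

int main() {
    // Mirrors how ctx.features is assembled in init_kleidiai_context().
    cpu_feature features = (has_dotprod()     ? CPU_FEATURE_DOTPROD : CPU_FEATURE_NONE) |
                           (has_matmul_int8() ? CPU_FEATURE_I8MM    : CPU_FEATURE_NONE) |
                           (has_sve()         ? CPU_FEATURE_SVE     : CPU_FEATURE_NONE);
    std::printf("feature mask: 0x%x\n", static_cast<unsigned>(features)); // 0x3 with the stubs above
}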
@@ -68,96 +71,275 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) {
     return tensor->ne[dim];
 }
 
+template<typename Ret, typename Variant, typename... Args>
+static Ret variant_call(const Variant & var, Args&&... args) {
+    return std::visit([&](auto&& func) -> Ret {
+        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
+            return func(std::forward<Args>(args)...);
+        } else {
+            throw std::runtime_error("Invalid function type in variant_call");
+        }
+    }, var);
+}
+
 namespace ggml::cpu::kleidiai {
+
+static size_t round_down(size_t x, size_t y) {
+    return y == 0 ? x : x - (x % y);
+}
+
+static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint16_t * src, size_t rhs_stride) {
+    size_t src_stride = rhs_stride / sizeof(uint16_t);
+    size_t dst_stride = n;
+
+    for (size_t k_idx = 0; k_idx < k; ++k_idx) {
+        for (size_t n_idx = 0; n_idx < n; ++n_idx) {
+            uint16_t v = *(src + k_idx + n_idx * src_stride);
+            *(dst + n_idx + k_idx * dst_stride) = kai_cast_f32_f16(v);
+        }
+    }
+}
+
 class tensor_traits : public ggml::cpu::tensor_traits {
     bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
-
-
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
+        GGML_ASSERT(kernels);
+        kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
 
         size_t k = op->src[0]->ne[0];
+        size_t n = op->src[0]->ne[1];
         size_t m = op->src[1]->ne[1];
 
         size_t mr = kernel->get_mr();
         size_t kr = kernel->get_kr();
         size_t sr = kernel->get_sr();
 
-
+        if (kernels->rhs_type == GGML_TYPE_Q4_0) {
+            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr);
+        } else if (kernels->rhs_type == GGML_TYPE_F16) {
+            size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) +
+                   variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
+                   k * n * sizeof(float) + n * sizeof(float);
+        } else {
+            GGML_ASSERT(false);
+        }
 
         return true;
     }
 
+
     bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override {
         if (dst->op == GGML_OP_MUL_MAT) {
-
-
+            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+                return compute_forward_q4_0(params, dst);
+            } else if (dst->src[0]->type == GGML_TYPE_F16) {
+                return compute_forward_kv_cache(params, dst);
+            }
+        }
+        return false;
+    }
 
-
+    bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
+        static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
 
-
-
-        lhs_packing_info * lhs_info = &ctx.kernels->lhs_info;
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
 
-
+        GGML_TENSOR_BINARY_OP_LOCALS
 
-
-
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        GGML_ASSERT(kernels);
 
-
-
-        const size_t n = ne01;
+        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        GGML_ASSERT(kernel);
 
-
-
-        const size_t n_start = ith * num_n_per_thread;
+        const int nth = params->nth;
+        const int ith = params->ith;
 
-
-
-
-
+        const int64_t lhs_batch_size0 = ne12;
+        const int64_t rhs_batch_size0 = ne02;
+        const int64_t batch_size = rhs_batch_size0;
+
+        const int64_t r = lhs_batch_size0 / rhs_batch_size0;
+
+        const int64_t m = ne11 * r;
+        const int64_t n = ne01;
+        const int64_t k = ne00;
+
+        const size_t lhs_stride = src1->nb[1];
+        const size_t rhs_stride = src0->nb[1];
+        const size_t dst_stride = dst->nb[1];
+
+        const int64_t mr = static_cast<int64_t>(kernel->get_mr());
+        const int64_t nr = static_cast<int64_t>(kernel->get_nr());
+        const int64_t kr = static_cast<int64_t>(kernel->get_kr());
+        const int64_t sr = static_cast<int64_t>(kernel->get_sr());
+
+        const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr);
+        const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
+        const size_t kxn_size = k * n * sizeof(float);
+        const size_t bias_size = n * sizeof(float);
+
+        const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
+        GGML_ASSERT(wsize_required <= params->wsize);
+
+        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
+        uint8_t * rhs_packed = lhs_packed + lhs_packed_size;
+        uint8_t * rhs_kxn = rhs_packed + rhs_packed_size;
+        uint8_t * bias = rhs_kxn + kxn_size;
+
+        for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
+            const uint8_t * lhs_batch = static_cast<const uint8_t *>(src1->data) + batch_idx * m * lhs_stride;
+            const uint8_t * rhs_batch = static_cast<const uint8_t *>(src0->data) + batch_idx * n * rhs_stride;
+            uint8_t * dst_batch = static_cast<uint8_t *>(dst->data) + batch_idx * m * dst_stride;
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            // LHS packing
+            {
+                const int64_t m_roundup_mr = kai_roundup(m, mr);
+                const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth);
+
+                if (ith < num_threads) {
+                    const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr);
+                    const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0;
+
+                    const int64_t m_start = ith * num_m_per_thread0;
+                    const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
+
+                    const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
+                    const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr);
+
+                    const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
+                    void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
+
+                    variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
+                }
             }
 
-
-
-
-
-
-
+            // RHS packing
+            if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
+                // First thread to reach this point handles RHS packing
+                memset(bias, 0, n * sizeof(float));
+                transpose_f32kxn_f16nxk(n, k, reinterpret_cast<float *>(rhs_kxn),
+                                        reinterpret_cast<const uint16_t *>(rhs_batch), rhs_stride);
 
-
+                variant_call<void>(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float),
+                                   rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr);
             }
 
             ggml_barrier(params->threadpool);
 
-
-
-
-
-
-
-
-
-
-
-
-
+            first_to_arrive.clear(std::memory_order_release);
+
+            // Perform the matmul
+            {
+                const int64_t m_to_process = m;
+                const int64_t m_start = 0;
+
+                const int64_t n_step = static_cast<int64_t>(kernel->get_n_step());
+                const int64_t num_threads = KAI_MIN(n / n_step, nth);
+
+                if (ith < num_threads) {
+                    const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step);
+                    const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0;
+
+                    const int64_t n_start = ith * num_n_per_thread0;
+                    const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0;
+
+                    const size_t lhs_packed_offset = variant_call<size_t>(kernel->get_lhs_offset, m_start, k);
+                    const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k);
+                    const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride);
+
+                    const void * lhs_ptr = lhs_packed + lhs_packed_offset;
+                    const void * rhs_ptr = rhs_packed + rhs_packed_offset;
+                    float * dst_ptr = reinterpret_cast<float *>(dst_batch + dst_offset);
+
+                    variant_call<void>(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+                }
+            }
+
+            if (batch_idx != batch_size - 1) {
+                // This barrier is necessary when the batch size is larger than 1. While processing a batch,
+                // the work data buffer (params->wdata) is used as temporary storage which means that only
+                // a single batch can be processed at any given time. No barrier is needed for the last
+                // batch since GGML inserts a barrier between the execution of every operator.
+                ggml_barrier(params->threadpool);
+            }
         }
-
+
+        return true;
+    }
+
+    bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        GGML_ASSERT(kernels);
+
+        kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = &kernels->lhs_info;
+
+        GGML_ASSERT(kernel);
+
+        const int ith = params->ith;
+        const int nth = params->nth;
+
+        const size_t k = ne00;
+        const size_t m = ne11;
+        const size_t n = ne01;
+
+        size_t mr = kernel->get_mr();
+        size_t kr = kernel->get_kr();
+        size_t sr = kernel->get_sr();
+
+        const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
+        uint8_t * lhs_packed = (uint8_t*)params->wdata;
+        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+        const size_t n_step = kernel->get_n_step();
+        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+        const size_t n_start = ith * num_n_per_thread;
+
+        size_t n_to_process = num_n_per_thread;
+        if ((n_start + n_to_process) > n) {
+            n_to_process = n - n_start;
+        }
+
+        // Calculate number of columns to be processed per thread
+        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
+        const size_t m_start = ith * num_m_per_thread;
+        size_t m_to_process = num_m_per_thread;
+        if ((m_start + m_to_process) > m) {
+            m_to_process = m - m_start;
+        }
+
+        if (m_start < m) {
+            // Transform LHS
+            const size_t src_stride = src1->nb[1];
+            const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
+            const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr);
+            void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+            variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        // Perform the operation
+        const size_t dst_stride = dst->nb[1];
+        const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr);
+        const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0);
+        const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
+        const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+        const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
+        float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+        variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+                           sizeof(float), -FLT_MAX, FLT_MAX);
+
+        return true;
     }
 
 public:
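The new variant_call helper above is the backbone of this hunk: packing and kernel entry points now live in std::variants of function-pointer types (the Q4_0 and F16 paths take different argument lists), and a single call site dispatches to whichever alternative matches. A self-contained sketch of the technique; the two signatures below are illustrative stand-ins, not the real kai_* getters:

#include <cstddef>
#include <cstdio>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <variant>

// Same helper as in the diff: invoke whichever callable the variant holds,
// but only if it is callable with the supplied arguments.
template<typename Ret, typename Variant, typename... Args>
static Ret variant_call(const Variant & var, Args&&... args) {
    return std::visit([&](auto&& func) -> Ret {
        if constexpr (std::is_invocable_r_v<Ret, decltype(func), Args...>) {
            return func(std::forward<Args>(args)...);
        } else {
            throw std::runtime_error("Invalid function type in variant_call");
        }
    }, var);
}

// Hypothetical packed-size signatures: the q4 flavor takes a block length,
// the f16 flavor does not.
using packed_size_q4  = size_t (*)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
using packed_size_f16 = size_t (*)(size_t m, size_t k, size_t mr, size_t kr, size_t sr);
using packed_size_fn  = std::variant<packed_size_q4, packed_size_f16>;

static size_t q4_size(size_t m, size_t k, size_t bl, size_t, size_t, size_t) { return m * k / bl; }
static size_t f16_size(size_t m, size_t k, size_t, size_t, size_t)           { return m * k * 2;  }

int main() {
    packed_size_fn fn = f16_size;
    // One call site serves both alternatives; only the signature that
    // matches the argument list is ever invoked.
    std::printf("%zu\n", variant_call<size_t>(fn, 4, 128, 4, 16, 1)); // 1024
    fn = q4_size;
    std::printf("%zu\n", variant_call<size_t>(fn, 4, 128, 32, 4, 16, 1)); // 16
}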
@@ -170,13 +352,13 @@ public:
         size_t sr = ctx.kernels->gemm.get_sr();
 
 #ifndef NDEBUG
-        const size_t repacked_size = ctx.kernels->rhs_info.packed_size
+        const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
         GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
 #endif
         struct kai_rhs_pack_qs4cxs1s0_param params;
         params.lhs_zero_point = 1;
         params.rhs_zero_point = 8;
-        ctx.kernels->rhs_info.pack_func
+        variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params);
 
         return 0;
 
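compute_forward_kv_cache above uses a static std::atomic_flag so that, inside a data-parallel region, exactly one thread performs the shared RHS transpose-and-pack while its siblings skip ahead to the barrier that publishes the result. A standalone sketch of that first-to-arrive election, with std::barrier standing in for ggml_barrier (assumes C++20; thread count and printout are illustrative):

#include <atomic>
#include <barrier>   // C++20
#include <cstdio>
#include <thread>
#include <vector>

constexpr int kThreads = 4;

static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
static std::barrier sync_point(kThreads);   // stands in for ggml_barrier()

static void worker(int ith) {
    // Only the first thread to arrive does the one-time shared setup,
    // mirroring the RHS packing in compute_forward_kv_cache.
    if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) {
        std::printf("thread %d packs the shared buffer\n", ith);
    }

    sync_point.arrive_and_wait();   // setup is now visible to every thread

    // ... each thread now works on its own slice ...

    first_to_arrive.clear(std::memory_order_release);   // re-arm; clearing is idempotent
}

int main() {
    std::vector<std::thread> threads;
    for (int i = 0; i < kThreads; ++i) {
        threads.emplace_back(worker, i);
    }
    for (auto & t : threads) {
        t.join();
    }
}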
@@ -190,7 +372,7 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
 }
 } // namespace ggml::cpu::kleidiai
 
-
+static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
     tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
 
     GGML_UNUSED(buffer);
@@ -239,12 +421,11 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
-        if (
-
-
-
-
-        ) {
+        if (op->op == GGML_OP_MUL_MAT &&
+            op->src[0]->type == GGML_TYPE_Q4_0 &&
+            op->src[0]->buffer &&
+            (ggml_n_dims(op->src[0]) == 2) &&
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
@@ -261,6 +442,19 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
         if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
             return (ggml::cpu::tensor_traits *) op->src[0]->extra;
         }
+        else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
+                 op->src[0]->op == GGML_OP_VIEW &&
+                 (op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
+                 op->src[1]->ne[1] > 1) {
+            if ((op->src[0]->nb[0] != 2) ||
+                (op->src[1]->nb[0] != 4) ||
+                (op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
+                (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
+                return nullptr;
+            }
+
+            return ggml::cpu::kleidiai::get_tensor_traits(NULL, NULL);
+        }
     }
     return nullptr;
 }