@fugood/llama.node 0.3.16 → 0.4.0
This diff shows the changes between publicly available package versions as they were released to a supported registry. It is provided for informational purposes only.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0

package/src/llama.cpp/src/llama-graph.cpp +209 -157

@@ -55,7 +55,37 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) {
     if (ubatch->pos && pos) {
         const int64_t n_tokens = ubatch->n_tokens;

-
+        if (ubatch->token && n_pos_per_embd == 4) {
+            // in case we're using M-RoPE with text tokens, convert the 1D positions to 4D
+            // the 3 first dims are the same, and 4th dim is all 0
+            std::vector<llama_pos> pos_data(n_tokens*n_pos_per_embd);
+            // copy the first dimension
+            for (int i = 0; i < n_tokens; ++i) {
+                pos_data[               i] = ubatch->pos[i];
+                pos_data[    n_tokens + i] = ubatch->pos[i];
+                pos_data[2 * n_tokens + i] = ubatch->pos[i];
+                pos_data[3 * n_tokens + i] = 0; // 4th dim is 0
+            }
+            ggml_backend_tensor_set(pos, pos_data.data(), 0, pos_data.size()*ggml_element_size(pos));
+        } else {
+            ggml_backend_tensor_set(pos, ubatch->pos, 0, n_tokens*n_pos_per_embd*ggml_element_size(pos));
+        }
+    }
+}
+
+void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
+    if (ubatch->pos && attn_scale) {
+        const int64_t n_tokens = ubatch->n_tokens;
+
+        std::vector<float> attn_scale_data(n_tokens, 0.0f);
+        for (int i = 0; i < n_tokens; ++i) {
+            const float pos = ubatch->pos[i];
+            attn_scale_data[i] = std::log(
+                std::floor((pos + 1.0f) / n_attn_temp_floor_scale) + 1.0
+            ) * f_attn_temp_scale + 1.0;
+        }
+
+        ggml_backend_tensor_set(attn_scale, attn_scale_data.data(), 0, n_tokens*ggml_element_size(attn_scale));
     }
 }

@@ -254,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) {

         // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
         for (uint32_t i = 0; i < n_kv; ++i) {
-
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            // prevent out-of-bound sources
-            if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) {
-                kv_cell.src = cell_id;
-            }
-
-            data[i] = kv_cell.src;
-
-            // TODO: do not mutate the KV cache
-            // ensure copy only happens once
-            if (kv_cell.src != (int32_t) cell_id) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_copy(i);
         }
     }
 }

@@ -287,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) {

         // clear unused states
         for (int i = 0; i < n_kv; ++i) {
-
-
-            //////////////////////////////////////////////
-            // TODO: this should not mutate the KV cache !
-            llama_kv_cell & kv_cell = const_cast<class llama_kv_cache_unified *>(kv_self)->cells[i];
-
-            data[i] = (float) (kv_cell.src >= 0);
-
-            // only clear once
-            if (kv_cell.src < 0) {
-                kv_cell.src = cell_id;
-            }
+            data[i] = kv_self->s_mask(i);
         }
     }
 }

@@ -402,120 +404,94 @@ void llm_graph_input_attn_no_cache::set_input(const llama_ubatch * ubatch) {

 void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) {
     if (self_kq_mask || self_kq_mask_swa) {
-
-
-
-
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-
-            float * data = nullptr;
-            float * data_swa = nullptr;
-
-            if (self_kq_mask) {
-                GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-                data = (float *) self_kq_mask->data;
-            }
+        const int64_t n_kv = kv_self->n;
+        const int64_t n_tokens = ubatch->n_tokens;
+        const int64_t n_seq_tokens = ubatch->n_seq_tokens;
+        const int64_t n_seqs = ubatch->n_seqs;

-
-
-                data_swa = (float *) self_kq_mask_swa->data;
-            }
+        float * data = nullptr;
+        float * data_swa = nullptr;

-
-
-
-
-            for (int s = 0; s < n_seqs; ++s) {
-                const llama_seq_id seq_id = ubatch->seq_id[s][0];
+        if (self_kq_mask) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
+            data = (float *) self_kq_mask->data;
+        }

-
-
+        if (self_kq_mask_swa) {
+            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask_swa->buffer));
+            data_swa = (float *) self_kq_mask_swa->data;
+        }

-
-
-
-
+        // Use only the previous KV cells of the correct sequence for each token of the ubatch.
+        // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
+        // Example with a cache of 10 tokens, 2 tokens populated in cache and 3 tokens in batch:
+        //   Causal mask:
+        //      xxx-------
+        //      xxxx------
+        //      xxxxx-----
+        //   Non-causal mask:
+        //      xxxxx-----
+        //      xxxxx-----
+        //      xxxxx-----
+        // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615
+        for (int h = 0; h < 1; ++h) {
+            for (int s = 0; s < n_seqs; ++s) {
+                const llama_seq_id seq_id = ubatch->seq_id[s][0];
+
+                for (int j = 0; j < n_seq_tokens; ++j) {
+                    const llama_pos pos = ubatch->pos[s*n_seq_tokens + j];
+                    for (int i = 0; i < n_kv; ++i) {
+                        float f;
+                        // mask the token if:
+                        if (!kv_self->cells[i].has_seq_id(seq_id) // not the correct sequence
+                            || (cparams.causal_attn && kv_self->cells[i].pos > pos) // for causal, mask future tokens
+                        ) {
+                            f = -INFINITY;
+                        } else {
+                            if (hparams.use_alibi) {
+                                f = -std::abs(kv_self->cells[i].pos - pos);
                             } else {
-
-                                f = -std::abs(kv_self->cells[i].pos - pos);
-                            } else {
-                                f = 0.0f;
-                            }
+                                f = 0.0f;
                             }
+                        }

-
-
-
+                        if (data) {
+                            data[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
+                        }

-
-
+                        // may need to cut off old tokens for sliding window
+                        // TODO @ngxson : we are currently re-using the swa logic to store the chunked mask, we should rename SWA to something more generic like "aux mask"
+                        if (data_swa) {
+                            if (hparams.n_attn_chunk) {
+                                llama_pos pos_chunk_start = (pos / hparams.n_attn_chunk) * hparams.n_attn_chunk;
+                                if (kv_self->cells[i].pos < pos_chunk_start || pos < pos_chunk_start) {
+                                    f = -INFINITY;
+                                }
+                            } else {
                                 if (pos - kv_self->cells[i].pos >= (int32_t)hparams.n_swa) {
                                     f = -INFINITY;
                                 }
-                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                             }
+                            data_swa[h*(n_kv*n_tokens) + s*(n_kv*n_seq_tokens) + j*n_kv + i] = f;
                         }
                     }
                 }
+            }

-
-
-
-
-
-            }
-            }
-
-            if (data_swa) {
-                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
-                    for (int j = 0; j < n_kv; ++j) {
-                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
-                    }
+            // mask padded tokens
+            if (data) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }
-        } else {
-            const int64_t n_tokens = ubatch->n_tokens;
-            const int64_t n_seq_tokens = ubatch->n_seq_tokens;
-            const int64_t n_seqs = ubatch->n_seqs;
-            // when using kv cache, the mask needs to match the kv cache size
-            const int64_t n_stride = n_tokens;
-
-            GGML_ASSERT(ggml_backend_buffer_is_host(self_kq_mask->buffer));
-
-            float * data = (float *) self_kq_mask->data;
-
-            for (int h = 0; h < 1; ++h) {
-                for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = ubatch->seq_id[s1][0];
-
-                    for (int j = 0; j < n_seq_tokens; ++j) {
-                        const int32_t tj = s1*n_seq_tokens + j;
-
-                        for (int s0 = 0; s0 < n_seqs; ++s0) {
-                            for (int i = 0; i < n_seq_tokens; ++i) {
-                                const int32_t ti = s0*n_seq_tokens + i;
-                                float f = -INFINITY;
-
-                                for (int s = 0; s < ubatch->n_seq_id[s0]; ++s) {
-                                    if (ubatch->seq_id[s0][s] == seq_id) {
-                                        if (hparams.use_alibi) {
-                                            f = -std::abs(ubatch->pos[ti] - ubatch->pos[tj]);
-                                        } else {
-                                            f = 0.0f;
-                                        }
-                                        break;
-                                    }
-                                }
-
-                                data[h*(n_tokens*n_tokens) + tj*n_stride + ti] = f;
-                            }
-                        }

-
-
-
+            // mask padded tokens
+            if (data_swa) {
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data_swa[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
                     }
                 }
             }

@@ -602,7 +578,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) :
     res (std::make_unique<llm_graph_result>()) {
 }

-int64_t llm_graph_context::
+int64_t llm_graph_context::n_pos_per_embd() const {
     return arch == LLM_ARCH_QWEN2VL ? 4 : 1;
 }

@@ -806,13 +782,17 @@ ggml_tensor * llm_graph_context::build_ffn(
             } break;
     }

-    if (type_gate == LLM_FFN_PAR) {
+    if (gate && type_gate == LLM_FFN_PAR) {
         cur = ggml_mul(ctx0, cur, tmp);
         cb(cur, "ffn_gate_par", il);
     }

     if (down) {
         cur = build_lora_mm(down, cur);
+        if (arch == LLM_ARCH_GLM4) {
+            // GLM4 seems to have numerical issues with half-precision accumulators
+            ggml_mul_mat_set_prec(cur, GGML_PREC_F32);
+        }
     }

     if (down_b) {

@@ -846,8 +826,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         float w_scale,
         llama_expert_gating_func_type gating_op,
         int il) const {
-    int64_t n_embd
-    int64_t n_tokens = cur->ne[1];
+    const int64_t n_embd = cur->ne[0];
+    const int64_t n_tokens = cur->ne[1];
+    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN

     ggml_tensor * logits = build_lora_mm(gate_inp, cur); // [n_expert, n_tokens]
     cb(logits, "ffn_moe_logits", il);

@@ -875,6 +856,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(selection_probs, "ffn_moe_probs_biased", il);
     }

+    // llama4 doesn't have exp_probs_b, and sigmoid is only used after top_k
+    // see: https://github.com/meta-llama/llama-models/blob/699a02993512fb36936b1b0741e13c06790bcf98/models/llama4/moe.py#L183-L198
+    if (arch == LLM_ARCH_LLAMA4) {
+        selection_probs = logits;
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);

@@ -901,34 +888,53 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
     }

     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
+
+    if (weight_before_ffn) {
+        // TODO: this is a workaround as we don't yet have a repeat op that takes custom dim (ggml_repeat_4d)
+        ggml_tensor * repeated = ggml_new_tensor_3d(ctx0, cur->type, n_embd, n_expert_used, n_tokens);
+        repeated = ggml_repeat(ctx0, cur, repeated); // [n_embd, n_expert_used, n_tokens]
+        cur = ggml_mul(ctx0, repeated, weights);
+        cb(cur, "ffn_moe_weighted", il);
+    }
+
     ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
     cb(up, "ffn_moe_up", il);

-    ggml_tensor *
-
+    ggml_tensor * experts = nullptr;
+    if (gate_exps) {
+        cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate", il);
+    } else {
+        cur = up;
+    }

     switch (type_op) {
         case LLM_FFN_SILU:
             {
-
-                cb(
+                cur = ggml_silu(ctx0, cur);
+                cb(cur, "ffn_moe_silu", il);
             } break;
         case LLM_FFN_GELU:
             {
-
-                cb(
+                cur = ggml_gelu(ctx0, cur);
+                cb(cur, "ffn_moe_gelu", il);
             } break;
         default:
             GGML_ABORT("fatal error");
     }

-
-
+    if (gate_exps) {
+        cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
+        cb(cur, "ffn_moe_gate_par", il);
+    }

-
+    experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
     cb(experts, "ffn_moe_down", il);

-
+    if (!weight_before_ffn) {
+        experts = ggml_mul(ctx0, experts, weights);
+        cb(cur, "ffn_moe_weighted", il);
+    }

     // aggregate experts
     ggml_tensor * moe_out = nullptr;

@@ -948,6 +954,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         moe_out = ggml_cont(ctx0, moe_out);
     }

+    cb(moe_out, "ffn_moe_out", il);
+
     return moe_out;
 }

@@ -963,6 +971,7 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
         inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
         //cb(inp->tokens, "inp_tokens", -1);
         ggml_set_input(inp->tokens);
+        res->t_tokens = inp->tokens;

         cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);

@@ -1003,11 +1012,25 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 }

 ggml_tensor * llm_graph_context::build_inp_pos() const {
-    auto inp = std::make_unique<llm_graph_input_pos>(
+    auto inp = std::make_unique<llm_graph_input_pos>(n_pos_per_embd());

     auto & cur = inp->pos;

-    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*
+    cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_embd());
+    ggml_set_input(cur);
+
+    res->add_input(std::move(inp));
+
+    return cur;
+}
+
+ggml_tensor * llm_graph_context::build_inp_attn_scale() const {
+    auto inp = std::make_unique<llm_graph_input_attn_temp>(hparams.n_attn_temp_floor_scale, hparams.f_attn_temp_scale);
+
+    auto & cur = inp->attn_scale;
+
+    // this need to be 1x1xN for broadcasting
+    cur = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, 1, 1, n_tokens);
     ggml_set_input(cur);

     res->add_input(std::move(inp));

@@ -1055,7 +1078,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const {
 }

 ggml_tensor * llm_graph_context::build_inp_s_copy() const {
-    const
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

     auto inp = std::make_unique<llm_graph_input_s_copy>(kv_self);

@@ -1072,7 +1095,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const {
 }

 ggml_tensor * llm_graph_context::build_inp_s_mask() const {
-    const
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

     auto inp = std::make_unique<llm_graph_input_s_mask>(kv_self);

@@ -1164,6 +1187,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
         ggml_tensor * v,
         ggml_tensor * kq_b,
         ggml_tensor * kq_mask,
+        ggml_tensor * v_mla,
         bool v_trans,
         float kq_scale) const {
     //const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);

@@ -1175,8 +1199,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
     //const auto & n_embd_head_k = hparams.n_embd_head_k;
     //const auto & n_embd_head_v = hparams.n_embd_head_v;

-    const auto n_embd_head_v = v_trans ? v->ne[1] : v->ne[0];
-
     const auto n_tokens = q->ne[1];
     const auto n_head = q->ne[2];
     const auto n_kv = k->ne[1];

@@ -1191,12 +1213,37 @@ ggml_tensor * llm_graph_context::build_attn_mha(
             v = ggml_transpose(ctx0, v);
         }

+        // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
+        if (k->type == GGML_TYPE_F32) {
+            k = ggml_cast(ctx0, k, GGML_TYPE_F16);
+        }
+
+        if (v->type == GGML_TYPE_F32) {
+            v = ggml_cast(ctx0, v, GGML_TYPE_F16);
+        }
+
         cur = ggml_flash_attn_ext(ctx0, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

         ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);

-
+        if (v_mla) {
+#if 0
+            // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens.
+            // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient.
+            cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+#else
+            // It's preferable to do the calculation as a matrix-matrix multiplication with n_tokens in dimension 1.
+            // The permutations are noops and only change how the tensor data is interpreted.
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_mul_mat(ctx0, v_mla, cur);
+            cur = ggml_permute(ctx0, cur, 0, 2, 1, 3);
+            cur = ggml_cont(ctx0, cur); // Needed because ggml_reshape_2d expects contiguous inputs.
+#endif
+        }
+
+        cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
     } else {
         ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);

@@ -1234,9 +1281,14 @@ ggml_tensor * llm_graph_context::build_attn_mha(

         ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq);

-
+        // for MLA with the absorption optimization, we need to "decompress" from MQA back to MHA
+        if (v_mla) {
+            kqv = ggml_mul_mat(ctx0, v_mla, kqv);
+        }
+
+        cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);

-        cur = ggml_cont_2d(ctx0,
+        cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);

         if (!cparams.offload_kqv) {
             // all nodes between the KV store and the attention output are run on the CPU

@@ -1271,6 +1323,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     GGML_UNUSED(n_tokens);

@@ -1292,7 +1345,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);

     cb(cur, "kqv_out", il);

@@ -1346,6 +1399,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered

@@ -1366,8 +1420,6 @@ ggml_tensor * llm_graph_context::build_attn(

     // store to KV cache
     {
-        GGML_ASSERT(!kv_self->recurrent);
-
         const auto kv_head = kv_self->head;

         GGML_ASSERT(kv_self->size == n_ctx);

@@ -1431,7 +1483,7 @@ ggml_tensor * llm_graph_context::build_attn(
                 ggml_element_size(kv_self->v_l[il])*n_ctx*n_embd_head_v,
                 0);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_trans, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, v_trans, kq_scale);
     cb(cur, "kqv_out", il);

     if (wo) {

@@ -1471,6 +1523,7 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_tensor * k_cur,
         ggml_tensor * v_cur,
         ggml_tensor * kq_b,
+        ggml_tensor * v_mla,
         float kq_scale,
         int il) const {
     // these nodes are added to the graph together so that they are not reordered

@@ -1490,7 +1543,7 @@ ggml_tensor * llm_graph_context::build_attn(
     ggml_tensor * v = ggml_permute(ctx0, v_cur, 0, 2, 1, 3);
     //cb(k, "v", il);

-    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, false, kq_scale);
+    ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, false, kq_scale);

     cb(cur, "kqv_out", il);

@@ -1516,7 +1569,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state(
         ggml_tensor * state_mask,
         int32_t n_state,
         int32_t n_seqs) const {
-    const
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

     const auto n_kv = kv_self->n;
     const auto kv_head = kv_self->head;

@@ -1548,7 +1601,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load(
         ggml_tensor * state_mask,
         const llama_ubatch & ubatch,
         int il) const {
-    const
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

     const auto token_shift_count = hparams.token_shift_count;

@@ -1569,7 +1622,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store(
         ggml_tensor * token_shift,
         const llama_ubatch & ubatch,
         int il) const {
-    const
+    const llama_kv_cache_recurrent * kv_self = static_cast<const llama_kv_cache_recurrent *>(memory);

     const auto token_shift_count = hparams.token_shift_count;
     const auto n_embd = hparams.n_embd;

@@ -1659,4 +1712,3 @@ void llm_graph_context::build_pooling(

     ggml_build_forward_expand(gf, cur);
 }
-
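
The first llama-graph.cpp hunk above expands 1-D token positions into the 4-D, section-major layout that M-RoPE expects: the position is replicated into the first three sections and the fourth section is zeroed. As a standalone illustration only (a minimal sketch, not code from the package; the helper name expand_mrope_positions is hypothetical, and only the llama_pos typedef from llama.h is assumed):

    #include <cstdint>
    #include <vector>

    using llama_pos = int32_t; // same typedef as in llama.h

    // Mirrors the transform in llm_graph_input_pos::set_input() for plain text
    // tokens under M-RoPE: output layout is section-major,
    // [sec0: n_tokens][sec1: n_tokens][sec2: n_tokens][sec3: n_tokens zeros].
    static std::vector<llama_pos> expand_mrope_positions(const llama_pos * pos, int n_tokens) {
        std::vector<llama_pos> pos_data(4 * n_tokens);
        for (int i = 0; i < n_tokens; ++i) {
            pos_data[               i] = pos[i]; // section 0
            pos_data[    n_tokens + i] = pos[i]; // section 1
            pos_data[2 * n_tokens + i] = pos[i]; // section 2
            pos_data[3 * n_tokens + i] = 0;      // section 3 is all 0 for text
        }
        return pos_data;
    }

The section-major layout matters because the tensor is created as a single 1-D array of n_tokens*n_pos_per_embd elements (see build_inp_pos() above), so each RoPE section reads a contiguous run of n_tokens positions.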