@fugood/llama.node 0.3.16 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +6 -1
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-cuda/arm64/llama-node.node +0 -0
- package/bin/linux-cuda/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +44 -2
- package/lib/index.js +132 -1
- package/lib/index.ts +203 -3
- package/package.json +2 -1
- package/src/EmbeddingWorker.cpp +1 -1
- package/src/LlamaCompletionWorker.cpp +374 -19
- package/src/LlamaCompletionWorker.h +31 -10
- package/src/LlamaContext.cpp +216 -7
- package/src/LlamaContext.h +12 -0
- package/src/common.hpp +15 -0
- package/src/llama.cpp/.github/workflows/build-linux-cross.yml +233 -0
- package/src/llama.cpp/.github/workflows/build.yml +89 -767
- package/src/llama.cpp/.github/workflows/docker.yml +9 -6
- package/src/llama.cpp/.github/workflows/release.yml +716 -0
- package/src/llama.cpp/.github/workflows/server.yml +19 -23
- package/src/llama.cpp/CMakeLists.txt +11 -1
- package/src/llama.cpp/cmake/build-info.cmake +8 -2
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +0 -6
- package/src/llama.cpp/common/CMakeLists.txt +35 -4
- package/src/llama.cpp/common/arg.cpp +844 -121
- package/src/llama.cpp/common/arg.h +9 -0
- package/src/llama.cpp/common/chat.cpp +129 -107
- package/src/llama.cpp/common/chat.h +2 -0
- package/src/llama.cpp/common/common.cpp +64 -518
- package/src/llama.cpp/common/common.h +35 -45
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +3 -0
- package/src/llama.cpp/common/llguidance.cpp +31 -47
- package/src/llama.cpp/common/minja/chat-template.hpp +23 -11
- package/src/llama.cpp/common/minja/minja.hpp +186 -127
- package/src/llama.cpp/common/regex-partial.cpp +204 -0
- package/src/llama.cpp/common/regex-partial.h +56 -0
- package/src/llama.cpp/common/sampling.cpp +60 -50
- package/src/llama.cpp/docs/build.md +122 -7
- package/src/llama.cpp/examples/CMakeLists.txt +2 -32
- package/src/llama.cpp/examples/batched/batched.cpp +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +9 -12
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +89 -15
- package/src/llama.cpp/examples/passkey/passkey.cpp +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +1 -1
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +1 -1
- package/src/llama.cpp/examples/sycl/build.sh +2 -2
- package/src/llama.cpp/examples/sycl/win-build-sycl.bat +2 -2
- package/src/llama.cpp/examples/training/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/training/finetune.cpp +96 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +35 -2
- package/src/llama.cpp/ggml/cmake/GitVars.cmake +22 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +4 -4
- package/src/llama.cpp/ggml/include/ggml-cpp.h +1 -1
- package/src/llama.cpp/ggml/include/ggml-cpu.h +5 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +47 -28
- package/src/llama.cpp/ggml/include/ggml-rpc.h +6 -1
- package/src/llama.cpp/ggml/include/ggml.h +76 -106
- package/src/llama.cpp/ggml/src/CMakeLists.txt +11 -8
- package/src/llama.cpp/ggml/src/ggml-alloc.c +4 -1
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +9 -5
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +0 -2
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +8 -4
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +5 -5
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +692 -1534
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.h +613 -122
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +135 -1
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +507 -137
- package/src/llama.cpp/ggml/src/ggml-common.h +12 -6
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +66 -33
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.cpp +158 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/binary-ops.h +16 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/common.h +72 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.cpp +896 -194
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +2 -21
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +1060 -410
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +1008 -13533
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +31 -16
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.cpp +90 -12
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kernels.h +47 -13
- package/src/llama.cpp/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +266 -72
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1034 -88
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.cpp +8796 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ops.h +110 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/simd-mappings.h +892 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.cpp +186 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/unary-ops.h +28 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.cpp +252 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/vec.h +802 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +23 -4
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +7 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +0 -4
- package/src/llama.cpp/ggml/src/ggml-impl.h +52 -18
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +106 -14
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +67 -119
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +1023 -262
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +368 -190
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -6
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +307 -40
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +125 -45
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +10 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.cpp +239 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/binbcast.hpp +39 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +0 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +9 -307
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +72 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +14 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +59 -21
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +7 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +79 -90
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +944 -438
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +22 -23
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +37 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.cpp +24 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/getrows.hpp +1 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +507 -411
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +84 -74
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +185 -89
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +37 -49
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.hpp +7 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +4 -14
- package/src/llama.cpp/ggml/src/ggml-sycl/quants.hpp +83 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +204 -118
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.hpp +1 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +128 -53
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +83 -49
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1278 -282
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +32 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +133 -30
- package/src/llama.cpp/ggml/src/ggml.c +170 -265
- package/src/llama.cpp/ggml/src/gguf.cpp +34 -33
- package/src/llama.cpp/include/llama.h +82 -22
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-llama4.gguf.out +46 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-pixtral.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-all.txt +5 -3
- package/src/llama.cpp/requirements/requirements-gguf_editor_gui.txt +3 -0
- package/src/llama.cpp/scripts/xxd.cmake +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +4 -2
- package/src/llama.cpp/src/llama-adapter.cpp +43 -1
- package/src/llama.cpp/src/llama-arch.cpp +163 -17
- package/src/llama.cpp/src/llama-arch.h +16 -0
- package/src/llama.cpp/src/llama-batch.cpp +5 -1
- package/src/llama.cpp/src/llama-batch.h +2 -1
- package/src/llama.cpp/src/llama-chat.cpp +91 -16
- package/src/llama.cpp/src/llama-chat.h +7 -2
- package/src/llama.cpp/src/llama-context.cpp +479 -575
- package/src/llama.cpp/src/llama-context.h +44 -33
- package/src/llama.cpp/src/llama-cparams.h +1 -0
- package/src/llama.cpp/src/llama-graph.cpp +209 -157
- package/src/llama.cpp/src/llama-graph.h +38 -14
- package/src/llama.cpp/src/llama-hparams.h +13 -0
- package/src/llama.cpp/src/llama-kv-cache.cpp +1604 -543
- package/src/llama.cpp/src/llama-kv-cache.h +283 -171
- package/src/llama.cpp/src/llama-memory.h +12 -2
- package/src/llama.cpp/src/llama-mmap.cpp +1 -1
- package/src/llama.cpp/src/llama-model-loader.cpp +34 -20
- package/src/llama.cpp/src/llama-model-loader.h +5 -3
- package/src/llama.cpp/src/llama-model-saver.cpp +281 -0
- package/src/llama.cpp/src/llama-model-saver.h +37 -0
- package/src/llama.cpp/src/llama-model.cpp +1803 -330
- package/src/llama.cpp/src/llama-model.h +21 -2
- package/src/llama.cpp/src/llama-quant.cpp +33 -10
- package/src/llama.cpp/src/llama-sampling.cpp +25 -7
- package/src/llama.cpp/src/llama-vocab.cpp +86 -10
- package/src/llama.cpp/src/llama-vocab.h +6 -0
- package/src/llama.cpp/src/llama.cpp +15 -1
- package/src/llama.cpp/tests/CMakeLists.txt +52 -31
- package/src/llama.cpp/tests/test-arg-parser.cpp +51 -4
- package/src/llama.cpp/tests/test-backend-ops.cpp +189 -90
- package/src/llama.cpp/tests/test-chat-template.cpp +26 -6
- package/src/llama.cpp/tests/test-chat.cpp +15 -3
- package/src/llama.cpp/{examples/gbnf-validator/gbnf-validator.cpp → tests/test-gbnf-validator.cpp} +2 -2
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -2
- package/src/llama.cpp/tests/test-grammar-llguidance.cpp +63 -2
- package/src/llama.cpp/tests/test-grammar-parser.cpp +3 -1
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -1
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -1
- package/src/llama.cpp/tests/test-mtmd-c-api.c +63 -0
- package/src/llama.cpp/tests/test-opt.cpp +33 -21
- package/src/llama.cpp/{examples/quantize-stats/quantize-stats.cpp → tests/test-quantize-stats.cpp} +3 -1
- package/src/llama.cpp/tests/test-regex-partial.cpp +288 -0
- package/src/llama.cpp/tests/test-sampling.cpp +1 -1
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +2 -1
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +2 -1
- package/src/llama.cpp/tools/CMakeLists.txt +39 -0
- package/src/llama.cpp/{examples → tools}/batched-bench/batched-bench.cpp +3 -3
- package/src/llama.cpp/{examples → tools}/export-lora/export-lora.cpp +1 -1
- package/src/llama.cpp/{examples → tools}/gguf-split/gguf-split.cpp +15 -16
- package/src/llama.cpp/{examples → tools}/imatrix/imatrix.cpp +11 -9
- package/src/llama.cpp/{examples → tools}/llama-bench/llama-bench.cpp +623 -274
- package/src/llama.cpp/{examples → tools}/main/main.cpp +22 -14
- package/src/llama.cpp/tools/mtmd/CMakeLists.txt +47 -0
- package/src/llama.cpp/tools/mtmd/clip-impl.h +365 -0
- package/src/llama.cpp/tools/mtmd/clip.cpp +3646 -0
- package/src/llama.cpp/tools/mtmd/clip.h +99 -0
- package/src/llama.cpp/tools/mtmd/deprecation-warning.cpp +22 -0
- package/src/llama.cpp/tools/mtmd/mtmd-cli.cpp +370 -0
- package/src/llama.cpp/tools/mtmd/mtmd-helper.cpp +310 -0
- package/src/llama.cpp/tools/mtmd/mtmd.cpp +678 -0
- package/src/llama.cpp/tools/mtmd/mtmd.h +331 -0
- package/src/llama.cpp/{examples → tools}/perplexity/perplexity.cpp +21 -5
- package/src/llama.cpp/{examples → tools}/quantize/quantize.cpp +53 -3
- package/src/llama.cpp/tools/rpc/CMakeLists.txt +4 -0
- package/src/llama.cpp/tools/rpc/rpc-server.cpp +322 -0
- package/src/llama.cpp/tools/run/CMakeLists.txt +16 -0
- package/src/llama.cpp/{examples → tools}/run/run.cpp +30 -30
- package/src/llama.cpp/{examples → tools}/server/CMakeLists.txt +2 -1
- package/src/llama.cpp/{examples → tools}/server/httplib.h +313 -247
- package/src/llama.cpp/{examples → tools}/server/server.cpp +529 -215
- package/src/llama.cpp/{examples → tools}/server/utils.hpp +427 -6
- package/src/llama.cpp/{examples → tools}/tts/tts.cpp +6 -9
- package/src/llama.cpp/cmake/arm64-windows-msvc.cmake +0 -6
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/infill/infill.cpp +0 -590
- package/src/llama.cpp/examples/llava/CMakeLists.txt +0 -66
- package/src/llama.cpp/examples/llava/android/build_64.sh +0 -8
- package/src/llama.cpp/examples/llava/clip-quantize-cli.cpp +0 -59
- package/src/llama.cpp/examples/llava/clip.cpp +0 -3206
- package/src/llama.cpp/examples/llava/clip.h +0 -118
- package/src/llama.cpp/examples/llava/gemma3-cli.cpp +0 -341
- package/src/llama.cpp/examples/llava/llava-cli.cpp +0 -332
- package/src/llama.cpp/examples/llava/llava.cpp +0 -574
- package/src/llama.cpp/examples/llava/llava.h +0 -49
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +0 -354
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +0 -584
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/rpc/CMakeLists.txt +0 -2
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +0 -171
- package/src/llama.cpp/examples/run/CMakeLists.txt +0 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +0 -30
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +0 -234
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +0 -197
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +0 -190
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +0 -204
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q8_0.cpp +0 -191
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +0 -218
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +0 -216
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +0 -295
- /package/src/llama.cpp/{examples → tools}/batched-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/completions.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/cvector-generator.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/mean.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/negative.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/pca.hpp +0 -0
- /package/src/llama.cpp/{examples → tools}/cvector-generator/positive.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/export-lora/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/gguf-split/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/imatrix/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/llama-bench/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/main/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples/llava → tools/mtmd}/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/perplexity/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/quantize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/run/linenoise.cpp/linenoise.h +0 -0
- /package/src/llama.cpp/{examples → tools}/server/bench/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/server/tests/requirements.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/CMakeLists.txt +0 -0
- /package/src/llama.cpp/{examples → tools}/tokenize/tokenize.cpp +0 -0
- /package/src/llama.cpp/{examples → tools}/tts/CMakeLists.txt +0 -0
|
@@ -6,7 +6,6 @@
|
|
|
6
6
|
#include "llama-model.h"
|
|
7
7
|
#include "llama-kv-cache.h"
|
|
8
8
|
|
|
9
|
-
#include <cassert>
|
|
10
9
|
#include <cstring>
|
|
11
10
|
#include <stdexcept>
|
|
12
11
|
#include <cinttypes>
|
|
@@ -94,6 +93,7 @@ llama_context::llama_context(
|
|
|
94
93
|
}
|
|
95
94
|
|
|
96
95
|
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
|
96
|
+
cparams.op_offload = params.op_offload;
|
|
97
97
|
|
|
98
98
|
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
|
99
99
|
|
|
@@ -113,12 +113,10 @@ llama_context::llama_context(
|
|
|
113
113
|
}
|
|
114
114
|
|
|
115
115
|
if (n_ctx_per_seq > hparams.n_ctx_train) {
|
|
116
|
-
LLAMA_LOG_WARN("%s:
|
|
116
|
+
LLAMA_LOG_WARN("%s: n_ctx_per_seq (%u) > n_ctx_train (%u) -- possible training context overflow\n",
|
|
117
117
|
__func__, n_ctx_per_seq, hparams.n_ctx_train);
|
|
118
118
|
}
|
|
119
119
|
|
|
120
|
-
logits_all = params.logits_all;
|
|
121
|
-
|
|
122
120
|
if (!hparams.vocab_only) {
|
|
123
121
|
// GPU backends
|
|
124
122
|
for (auto * dev : model.devices) {
|
|
@@ -176,44 +174,13 @@ llama_context::llama_context(
|
|
|
176
174
|
}
|
|
177
175
|
|
|
178
176
|
// init the memory module
|
|
179
|
-
// TODO: for now, always create a unified KV cache
|
|
180
177
|
if (!hparams.vocab_only) {
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams));
|
|
186
|
-
|
|
187
|
-
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
|
188
|
-
|
|
189
|
-
uint32_t kv_size = cparams.n_ctx;
|
|
190
|
-
ggml_type type_k = params.type_k;
|
|
191
|
-
ggml_type type_v = params.type_v;
|
|
192
|
-
|
|
193
|
-
if (llama_model_is_recurrent(&model)) {
|
|
194
|
-
// Mamba needs at least as many KV cells as there are sequences kept at any time
|
|
195
|
-
kv_size = std::max((uint32_t) 1, params.n_seq_max);
|
|
196
|
-
// it's probably best to keep as much precision as possible for the states
|
|
197
|
-
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
|
|
198
|
-
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states
|
|
199
|
-
}
|
|
200
|
-
|
|
201
|
-
GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0);
|
|
202
|
-
GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0);
|
|
203
|
-
|
|
204
|
-
if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) {
|
|
205
|
-
throw std::runtime_error("failed to initialize self-attention cache");
|
|
206
|
-
}
|
|
207
|
-
|
|
208
|
-
{
|
|
209
|
-
const size_t memory_size_k = kv_self->size_k_bytes();
|
|
210
|
-
const size_t memory_size_v = kv_self->size_v_bytes();
|
|
178
|
+
llama_memory_params params_mem = {
|
|
179
|
+
/*.type_k =*/ params.type_k,
|
|
180
|
+
/*.type_v =*/ params.type_v,
|
|
181
|
+
};
|
|
211
182
|
|
|
212
|
-
|
|
213
|
-
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
|
|
214
|
-
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
|
|
215
|
-
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
|
|
216
|
-
}
|
|
183
|
+
memory.reset(model.create_memory(params_mem, cparams));
|
|
217
184
|
}
|
|
218
185
|
|
|
219
186
|
// init backends
|
|
@@ -255,7 +222,8 @@ llama_context::llama_context(
|
|
|
255
222
|
model.n_devices() > 1 &&
|
|
256
223
|
model.params.n_gpu_layers > (int) model.hparams.n_layer &&
|
|
257
224
|
model.params.split_mode == LLAMA_SPLIT_MODE_LAYER &&
|
|
258
|
-
cparams.offload_kqv
|
|
225
|
+
cparams.offload_kqv &&
|
|
226
|
+
!model.has_tensor_overrides();
|
|
259
227
|
|
|
260
228
|
// pipeline parallelism requires support for async compute and events in all devices
|
|
261
229
|
if (pipeline_parallel) {
|
|
@@ -276,7 +244,7 @@ llama_context::llama_context(
|
|
|
276
244
|
}
|
|
277
245
|
}
|
|
278
246
|
|
|
279
|
-
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel));
|
|
247
|
+
sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
|
|
280
248
|
|
|
281
249
|
if (pipeline_parallel) {
|
|
282
250
|
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
|
|
@@ -284,7 +252,7 @@ llama_context::llama_context(
|
|
|
284
252
|
}
|
|
285
253
|
|
|
286
254
|
// reserve worst-case graph
|
|
287
|
-
if (!hparams.vocab_only) {
|
|
255
|
+
if (!hparams.vocab_only && memory) {
|
|
288
256
|
const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
|
|
289
257
|
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
290
258
|
|
|
@@ -294,10 +262,7 @@ llama_context::llama_context(
|
|
|
294
262
|
// TODO: something cleaner
|
|
295
263
|
const auto n_outputs_save = n_outputs;
|
|
296
264
|
|
|
297
|
-
|
|
298
|
-
n_outputs = n_tokens;
|
|
299
|
-
|
|
300
|
-
LLAMA_LOG_DEBUG("%s: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
265
|
+
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
|
301
266
|
|
|
302
267
|
int n_splits_pp = -1;
|
|
303
268
|
int n_nodes_pp = -1;
|
|
@@ -306,15 +271,24 @@ llama_context::llama_context(
|
|
|
306
271
|
int n_nodes_tg = -1;
|
|
307
272
|
|
|
308
273
|
// simulate full KV cache
|
|
309
|
-
kv_self
|
|
274
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
275
|
+
|
|
276
|
+
kv_self->set_full();
|
|
310
277
|
|
|
311
278
|
cross.v_embd.clear();
|
|
312
279
|
|
|
313
280
|
// reserve pp graph first so that buffers are only allocated once
|
|
314
281
|
{
|
|
315
282
|
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
283
|
+
|
|
284
|
+
// max number of outputs
|
|
285
|
+
n_outputs = ubatch_pp.n_tokens;
|
|
286
|
+
|
|
287
|
+
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
|
288
|
+
|
|
316
289
|
auto * gf = graph_init();
|
|
317
290
|
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
|
291
|
+
|
|
318
292
|
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
319
293
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
320
294
|
}
|
|
@@ -326,11 +300,18 @@ llama_context::llama_context(
|
|
|
326
300
|
// reserve with tg graph to get the number of splits and nodes
|
|
327
301
|
{
|
|
328
302
|
llama_ubatch ubatch_tg = { true, 1, 1, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
303
|
+
|
|
304
|
+
n_outputs = ubatch_tg.n_tokens;
|
|
305
|
+
|
|
306
|
+
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_tg.n_tokens, ubatch_tg.n_seqs);
|
|
307
|
+
|
|
329
308
|
auto * gf = graph_init();
|
|
330
309
|
graph_build(ctx_compute.get(), gf, ubatch_tg, LLM_GRAPH_TYPE_DEFAULT);
|
|
310
|
+
|
|
331
311
|
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
332
312
|
throw std::runtime_error("failed to allocate compute tg buffers");
|
|
333
313
|
}
|
|
314
|
+
|
|
334
315
|
n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
|
|
335
316
|
n_nodes_tg = ggml_graph_n_nodes(gf);
|
|
336
317
|
}
|
|
@@ -338,8 +319,14 @@ llama_context::llama_context(
|
|
|
338
319
|
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
|
339
320
|
{
|
|
340
321
|
llama_ubatch ubatch_pp = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
322
|
+
|
|
323
|
+
n_outputs = ubatch_pp.n_tokens;
|
|
324
|
+
|
|
325
|
+
LLAMA_LOG_DEBUG("%s: reserving graph for n_tokens = %d, n_seqs = %d\n", __func__, ubatch_pp.n_tokens, ubatch_pp.n_seqs);
|
|
326
|
+
|
|
341
327
|
auto * gf = graph_init();
|
|
342
328
|
graph_build(ctx_compute.get(), gf, ubatch_pp, LLM_GRAPH_TYPE_DEFAULT);
|
|
329
|
+
|
|
343
330
|
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
|
|
344
331
|
throw std::runtime_error("failed to allocate compute pp buffers");
|
|
345
332
|
}
|
|
@@ -372,7 +359,9 @@ llama_context::llama_context(
|
|
|
372
359
|
}
|
|
373
360
|
}
|
|
374
361
|
|
|
375
|
-
llama_context::~llama_context()
|
|
362
|
+
llama_context::~llama_context() {
|
|
363
|
+
ggml_opt_free(opt_ctx);
|
|
364
|
+
}
|
|
376
365
|
|
|
377
366
|
void llama_context::synchronize() {
|
|
378
367
|
ggml_backend_sched_synchronize(sched.get());
|
|
@@ -408,6 +397,18 @@ const llama_model & llama_context::get_model() const {
|
|
|
408
397
|
return model;
|
|
409
398
|
}
|
|
410
399
|
|
|
400
|
+
const llama_cparams & llama_context::get_cparams() const {
|
|
401
|
+
return cparams;
|
|
402
|
+
}
|
|
403
|
+
|
|
404
|
+
ggml_backend_sched_t llama_context::get_sched() const {
|
|
405
|
+
return sched.get();
|
|
406
|
+
}
|
|
407
|
+
|
|
408
|
+
ggml_context * llama_context::get_ctx_compute() const {
|
|
409
|
+
return ctx_compute.get();
|
|
410
|
+
}
|
|
411
|
+
|
|
411
412
|
uint32_t llama_context::n_ctx() const {
|
|
412
413
|
return cparams.n_ctx;
|
|
413
414
|
}
|
|
@@ -437,345 +438,21 @@ uint32_t llama_context::n_threads_batch() const {
|
|
|
437
438
|
}
|
|
438
439
|
|
|
439
440
|
llama_kv_cache * llama_context::get_kv_self() {
|
|
440
|
-
|
|
441
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
442
|
+
return kv_self;
|
|
441
443
|
}
|
|
442
444
|
|
|
443
445
|
const llama_kv_cache * llama_context::get_kv_self() const {
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
ggml_tensor * llama_context::build_rope_shift(
|
|
448
|
-
ggml_context * ctx0,
|
|
449
|
-
ggml_tensor * cur,
|
|
450
|
-
ggml_tensor * shift,
|
|
451
|
-
ggml_tensor * factors,
|
|
452
|
-
float freq_base,
|
|
453
|
-
float freq_scale,
|
|
454
|
-
ggml_backend_buffer * bbuf) const {
|
|
455
|
-
const auto & n_ctx_orig = cparams.n_ctx_orig_yarn;
|
|
456
|
-
|
|
457
|
-
const auto & yarn_ext_factor = cparams.yarn_ext_factor;
|
|
458
|
-
const auto & yarn_attn_factor = cparams.yarn_attn_factor;
|
|
459
|
-
const auto & yarn_beta_fast = cparams.yarn_beta_fast;
|
|
460
|
-
const auto & yarn_beta_slow = cparams.yarn_beta_slow;
|
|
461
|
-
|
|
462
|
-
const auto & hparams = model.hparams;
|
|
463
|
-
|
|
464
|
-
const auto & n_rot = hparams.n_rot;
|
|
465
|
-
const auto & rope_type = hparams.rope_type;
|
|
466
|
-
|
|
467
|
-
ggml_tensor * tmp;
|
|
468
|
-
|
|
469
|
-
if (ggml_is_quantized(cur->type)) {
|
|
470
|
-
// dequantize to f32 -> RoPE -> quantize back
|
|
471
|
-
tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32);
|
|
472
|
-
|
|
473
|
-
if (bbuf) {
|
|
474
|
-
for (const auto & backend : backends) {
|
|
475
|
-
// Figure out which backend KV cache belongs to
|
|
476
|
-
if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) {
|
|
477
|
-
ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get());
|
|
478
|
-
break;
|
|
479
|
-
}
|
|
480
|
-
}
|
|
481
|
-
}
|
|
482
|
-
|
|
483
|
-
tmp = ggml_rope_ext_inplace(ctx0, tmp,
|
|
484
|
-
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
485
|
-
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
|
486
|
-
|
|
487
|
-
tmp = ggml_cpy(ctx0, tmp, cur);
|
|
488
|
-
} else {
|
|
489
|
-
// we rotate only the first n_rot dimensions
|
|
490
|
-
tmp = ggml_rope_ext_inplace(ctx0, cur,
|
|
491
|
-
shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
|
492
|
-
yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow);
|
|
493
|
-
}
|
|
494
|
-
|
|
495
|
-
return tmp;
|
|
496
|
-
}
|
|
497
|
-
|
|
498
|
-
class llm_graph_input_k_shift : public llm_graph_input_i {
|
|
499
|
-
public:
|
|
500
|
-
llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {}
|
|
501
|
-
virtual ~llm_graph_input_k_shift() = default;
|
|
502
|
-
|
|
503
|
-
void set_input(const llama_ubatch * ubatch) override;
|
|
504
|
-
|
|
505
|
-
ggml_tensor * k_shift; // I32 [kv_size]
|
|
506
|
-
|
|
507
|
-
const llama_kv_cache_unified * kv_self;
|
|
508
|
-
};
|
|
509
|
-
|
|
510
|
-
void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) {
|
|
511
|
-
GGML_UNUSED(ubatch);
|
|
512
|
-
|
|
513
|
-
if (k_shift) {
|
|
514
|
-
assert(ggml_backend_buffer_is_host(k_shift->buffer));
|
|
515
|
-
|
|
516
|
-
int32_t * data = (int32_t *) k_shift->data;
|
|
517
|
-
|
|
518
|
-
for (uint32_t i = 0; i < kv_self->size; ++i) {
|
|
519
|
-
data[i] = kv_self->cells[i].delta;
|
|
520
|
-
}
|
|
521
|
-
}
|
|
522
|
-
}
|
|
523
|
-
|
|
524
|
-
llm_graph_result_ptr llama_context::build_kv_self_shift(
|
|
525
|
-
ggml_context * ctx0,
|
|
526
|
-
ggml_cgraph * gf) const {
|
|
527
|
-
auto res = std::make_unique<llm_graph_result>();
|
|
528
|
-
|
|
529
|
-
const auto & hparams = model.hparams;
|
|
530
|
-
|
|
531
|
-
const auto & n_layer = hparams.n_layer;
|
|
532
|
-
|
|
533
|
-
const auto & n_embd_head_k = hparams.n_embd_head_k;
|
|
534
|
-
//const auto & n_embd_head_v = hparams.n_embd_head_v;
|
|
535
|
-
|
|
536
|
-
//GGML_ASSERT(kv_self->size == n_ctx);
|
|
537
|
-
|
|
538
|
-
auto inp = std::make_unique<llm_graph_input_k_shift>(kv_self.get());
|
|
539
|
-
|
|
540
|
-
inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx);
|
|
541
|
-
ggml_set_input(inp->k_shift);
|
|
542
|
-
|
|
543
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
544
|
-
const int64_t n_head_kv = hparams.n_head_kv(il);
|
|
545
|
-
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
546
|
-
|
|
547
|
-
const bool is_swa = hparams.is_swa(il);
|
|
548
|
-
|
|
549
|
-
// note: the swa rope params could become part of the cparams in the future
|
|
550
|
-
// if we decide to make them configurable, like the non-sliding ones
|
|
551
|
-
const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base;
|
|
552
|
-
const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
|
|
553
|
-
|
|
554
|
-
ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);
|
|
555
|
-
|
|
556
|
-
ggml_tensor * k =
|
|
557
|
-
ggml_view_3d(ctx0, kv_self->k_l[il],
|
|
558
|
-
n_embd_head_k, n_head_kv, kv_self->size,
|
|
559
|
-
ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k),
|
|
560
|
-
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
|
561
|
-
0);
|
|
562
|
-
|
|
563
|
-
ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, kv_self->k_l[il]->buffer);
|
|
564
|
-
|
|
565
|
-
ggml_build_forward_expand(gf, cur);
|
|
566
|
-
}
|
|
567
|
-
|
|
568
|
-
res->add_input(std::move(inp));
|
|
569
|
-
|
|
570
|
-
return res;
|
|
571
|
-
}
|
|
572
|
-
|
|
573
|
-
llm_graph_result_ptr llama_context::build_kv_self_defrag(
|
|
574
|
-
ggml_context * ctx0,
|
|
575
|
-
ggml_cgraph * gf) const {
|
|
576
|
-
auto res = std::make_unique<llm_graph_result>();
|
|
577
|
-
|
|
578
|
-
const auto & hparams = model.hparams;
|
|
579
|
-
|
|
580
|
-
const auto & ids = kv_self->defrag_info.ids;
|
|
581
|
-
|
|
582
|
-
#if 0
|
|
583
|
-
// CPU defrag
|
|
584
|
-
//
|
|
585
|
-
// TODO: optimizations are possible:
|
|
586
|
-
// - multiple threads
|
|
587
|
-
// - avoid copying to the host memory when already there
|
|
588
|
-
//
|
|
589
|
-
// likely not worth the effort, as we have ggml_graph based defrag
|
|
590
|
-
//
|
|
591
|
-
|
|
592
|
-
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
|
593
|
-
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
|
594
|
-
|
|
595
|
-
const uint32_t kv_size = size;
|
|
596
|
-
|
|
597
|
-
std::vector<uint8_t> buf_k;
|
|
598
|
-
std::vector<uint8_t> buf_v;
|
|
599
|
-
|
|
600
|
-
for (uint32_t il = 0; il < n_layer; ++il) {
|
|
601
|
-
const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa);
|
|
602
|
-
const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size);
|
|
603
|
-
|
|
604
|
-
const size_t v_size_el = ggml_type_size(v_l[il]->type);
|
|
605
|
-
const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size);
|
|
606
|
-
|
|
607
|
-
buf_k.resize(k_size);
|
|
608
|
-
buf_v.resize(v_size);
|
|
609
|
-
|
|
610
|
-
ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size());
|
|
611
|
-
ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size());
|
|
612
|
-
|
|
613
|
-
// batch move [i, i+nm) to [id, id+nm)
|
|
614
|
-
// note: cells can move only to a lower index
|
|
615
|
-
for (uint32_t i = 0; i < n_kv; ++i) {
|
|
616
|
-
const uint32_t id = ids[i];
|
|
617
|
-
|
|
618
|
-
if (i == id || id == n_kv) {
|
|
619
|
-
continue;
|
|
620
|
-
}
|
|
621
|
-
|
|
622
|
-
uint32_t nm = 1;
|
|
623
|
-
|
|
624
|
-
while (i + nm < n_kv && ids[i + nm] == id + nm) {
|
|
625
|
-
nm++;
|
|
626
|
-
}
|
|
627
|
-
|
|
628
|
-
// move keys
|
|
629
|
-
{
|
|
630
|
-
const int64_t os = i*k_size_row;
|
|
631
|
-
const int64_t od = id*k_size_row;
|
|
632
|
-
|
|
633
|
-
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
|
|
634
|
-
}
|
|
635
|
-
|
|
636
|
-
// move values (note: they are transposed)
|
|
637
|
-
{
|
|
638
|
-
const int64_t os = i;
|
|
639
|
-
const int64_t od = id;
|
|
640
|
-
|
|
641
|
-
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
|
642
|
-
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
|
|
643
|
-
}
|
|
644
|
-
}
|
|
645
|
-
|
|
646
|
-
i += nm - 1;
|
|
647
|
-
}
|
|
648
|
-
|
|
649
|
-
ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size());
|
|
650
|
-
ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size());
|
|
651
|
-
}
|
|
652
|
-
#else
|
|
653
|
-
for (uint32_t i = 0; i < ids.size(); ++i) {
|
|
654
|
-
const uint32_t id = ids[i];
|
|
655
|
-
|
|
656
|
-
if (i == id || id == ids.size()) {
|
|
657
|
-
continue;
|
|
658
|
-
}
|
|
659
|
-
|
|
660
|
-
uint32_t nm = 1;
|
|
661
|
-
|
|
662
|
-
while (i + nm < ids.size() && ids[i + nm] == id + nm) {
|
|
663
|
-
nm++;
|
|
664
|
-
}
|
|
665
|
-
|
|
666
|
-
for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT
|
|
667
|
-
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
|
668
|
-
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
|
669
|
-
|
|
670
|
-
ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il],
|
|
671
|
-
n_embd_k_gqa, nm,
|
|
672
|
-
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
|
673
|
-
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i));
|
|
674
|
-
|
|
675
|
-
ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il],
|
|
676
|
-
n_embd_k_gqa, nm,
|
|
677
|
-
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa),
|
|
678
|
-
ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id));
|
|
679
|
-
|
|
680
|
-
ggml_tensor * view_v_src;
|
|
681
|
-
ggml_tensor * view_v_dst;
|
|
682
|
-
|
|
683
|
-
if (cparams.flash_attn) {
|
|
684
|
-
// NOTE: the V cache is not transposed when using flash attention
|
|
685
|
-
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
|
|
686
|
-
n_embd_v_gqa, nm,
|
|
687
|
-
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
|
688
|
-
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i));
|
|
689
|
-
|
|
690
|
-
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
|
|
691
|
-
n_embd_v_gqa, nm,
|
|
692
|
-
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa),
|
|
693
|
-
ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id));
|
|
694
|
-
} else {
|
|
695
|
-
view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il],
|
|
696
|
-
nm, n_embd_v_gqa,
|
|
697
|
-
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
|
|
698
|
-
ggml_row_size(kv_self->v_l[il]->type, i));
|
|
699
|
-
|
|
700
|
-
view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il],
|
|
701
|
-
nm, n_embd_v_gqa,
|
|
702
|
-
ggml_row_size(kv_self->v_l[il]->type, kv_self->size),
|
|
703
|
-
ggml_row_size(kv_self->v_l[il]->type, id));
|
|
704
|
-
}
|
|
705
|
-
|
|
706
|
-
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst));
|
|
707
|
-
ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst));
|
|
708
|
-
}
|
|
709
|
-
|
|
710
|
-
i += nm - 1;
|
|
711
|
-
}
|
|
712
|
-
|
|
713
|
-
//LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes);
|
|
714
|
-
#endif
|
|
715
|
-
|
|
716
|
-
return res;
|
|
446
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
447
|
+
return kv_self;
|
|
717
448
|
}
|
|
718
449
|
|
|
719
450
|
void llama_context::kv_self_update() {
|
|
720
|
-
auto & kv = kv_self;
|
|
721
|
-
|
|
722
451
|
bool need_reserve = false;
|
|
723
452
|
|
|
724
|
-
|
|
725
|
-
if (!kv->get_can_shift()) {
|
|
726
|
-
GGML_ABORT("The current context does not support K-shift");
|
|
727
|
-
}
|
|
728
|
-
|
|
729
|
-
LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__);
|
|
730
|
-
|
|
731
|
-
// apply K-shift if needed
|
|
732
|
-
if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) {
|
|
733
|
-
ggml_backend_sched_reset(sched.get());
|
|
734
|
-
|
|
735
|
-
auto * gf = graph_init();
|
|
736
|
-
|
|
737
|
-
auto res = build_kv_self_shift(ctx_compute.get(), gf);
|
|
738
|
-
|
|
739
|
-
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
|
740
|
-
|
|
741
|
-
res->set_inputs(nullptr);
|
|
742
|
-
|
|
743
|
-
graph_compute(gf, false);
|
|
744
|
-
|
|
745
|
-
need_reserve = true;
|
|
746
|
-
}
|
|
747
|
-
|
|
748
|
-
{
|
|
749
|
-
kv->has_shift = false;
|
|
750
|
-
|
|
751
|
-
for (uint32_t i = 0; i < kv->size; ++i) {
|
|
752
|
-
kv->cells[i].delta = 0;
|
|
753
|
-
}
|
|
754
|
-
}
|
|
755
|
-
}
|
|
756
|
-
|
|
757
|
-
// defragment the KV cache if needed
|
|
758
|
-
if (kv->do_defrag) {
|
|
759
|
-
LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__);
|
|
760
|
-
|
|
761
|
-
if (kv->defrag_prepare(graph_max_nodes())) {
|
|
762
|
-
ggml_backend_sched_reset(sched.get());
|
|
763
|
-
|
|
764
|
-
auto * gf = graph_init();
|
|
765
|
-
|
|
766
|
-
auto res = build_kv_self_defrag(ctx_compute.get(), gf);
|
|
767
|
-
|
|
768
|
-
ggml_backend_sched_alloc_graph(sched.get(), gf);
|
|
769
|
-
|
|
770
|
-
res->set_inputs(nullptr);
|
|
453
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
771
454
|
|
|
772
|
-
|
|
773
|
-
|
|
774
|
-
need_reserve = true;
|
|
775
|
-
}
|
|
776
|
-
|
|
777
|
-
kv->do_defrag = false;
|
|
778
|
-
}
|
|
455
|
+
need_reserve = kv_self->update(*this);
|
|
779
456
|
|
|
780
457
|
// reserve a worst case graph if needed
|
|
781
458
|
if (need_reserve) {
|
|
@@ -786,7 +463,7 @@ void llama_context::kv_self_update() {
|
|
|
786
463
|
uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
|
787
464
|
|
|
788
465
|
// simulate full KV cache
|
|
789
|
-
kv_self->
|
|
466
|
+
kv_self->set_full();
|
|
790
467
|
|
|
791
468
|
llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
|
|
792
469
|
llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr};
|
|
@@ -807,9 +484,6 @@ enum llama_pooling_type llama_context::pooling_type() const {
|
|
|
807
484
|
}
|
|
808
485
|
|
|
809
486
|
float * llama_context::get_logits() {
|
|
810
|
-
// reorder logits for backward compatibility
|
|
811
|
-
output_reorder();
|
|
812
|
-
|
|
813
487
|
return logits;
|
|
814
488
|
}
|
|
815
489
|
|
|
@@ -852,9 +526,6 @@ float * llama_context::get_logits_ith(int32_t i) {
|
|
|
852
526
|
}
|
|
853
527
|
|
|
854
528
|
float * llama_context::get_embeddings() {
|
|
855
|
-
// reorder embeddings for backward compatibility
|
|
856
|
-
output_reorder();
|
|
857
|
-
|
|
858
529
|
return embd;
|
|
859
530
|
}
|
|
860
531
|
|
|
@@ -1006,8 +677,8 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
1006
677
|
}
|
|
1007
678
|
|
|
1008
679
|
// temporary allocate memory for the input batch if needed
|
|
1009
|
-
//
|
|
1010
|
-
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 :
|
|
680
|
+
// note: during encode, we always pass the full sequence starting from pos = 0
|
|
681
|
+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0);
|
|
1011
682
|
|
|
1012
683
|
const llama_batch & batch = batch_allocr.batch;
|
|
1013
684
|
const int32_t n_tokens = batch.n_tokens;
|
|
@@ -1032,11 +703,13 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
1032
703
|
t_compute_start_us = ggml_time_us();
|
|
1033
704
|
}
|
|
1034
705
|
|
|
706
|
+
embd_seq.clear();
|
|
707
|
+
|
|
1035
708
|
n_queued_tokens += n_tokens;
|
|
1036
709
|
|
|
1037
710
|
const int64_t n_embd = hparams.n_embd;
|
|
1038
711
|
|
|
1039
|
-
sbatch
|
|
712
|
+
llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true);
|
|
1040
713
|
|
|
1041
714
|
const llama_ubatch ubatch = sbatch.split_simple(n_tokens);
|
|
1042
715
|
|
|
@@ -1093,12 +766,12 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
1093
766
|
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
|
|
1094
767
|
GGML_ASSERT(backend_embd != nullptr);
|
|
1095
768
|
|
|
1096
|
-
GGML_ASSERT(embd != nullptr);
|
|
1097
|
-
|
|
1098
769
|
switch (cparams.pooling_type) {
|
|
1099
770
|
case LLAMA_POOLING_TYPE_NONE:
|
|
1100
771
|
{
|
|
1101
772
|
// extract token embeddings
|
|
773
|
+
GGML_ASSERT(embd != nullptr);
|
|
774
|
+
|
|
1102
775
|
GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
|
|
1103
776
|
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
|
|
1104
777
|
} break;
|
|
@@ -1123,11 +796,18 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
1123
796
|
} break;
|
|
1124
797
|
case LLAMA_POOLING_TYPE_RANK:
|
|
1125
798
|
{
|
|
1126
|
-
//
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
799
|
+
// extract the rerank score - a single float per sequence
|
|
800
|
+
auto & embd_seq_out = embd_seq;
|
|
801
|
+
|
|
802
|
+
for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
|
|
803
|
+
const llama_seq_id seq_id = ubatch.seq_id[s][0];
|
|
804
|
+
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
|
|
805
|
+
continue;
|
|
806
|
+
}
|
|
807
|
+
embd_seq_out[seq_id].resize(1);
|
|
808
|
+
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
|
|
809
|
+
}
|
|
810
|
+
} break;
|
|
1131
811
|
case LLAMA_POOLING_TYPE_UNSPECIFIED:
|
|
1132
812
|
{
|
|
1133
813
|
GGML_ABORT("unknown pooling type");
|
|
@@ -1165,14 +845,21 @@ int llama_context::encode(llama_batch & inp_batch) {
|
|
|
1165
845
|
}
|
|
1166
846
|
|
|
1167
847
|
int llama_context::decode(llama_batch & inp_batch) {
|
|
848
|
+
if (!memory) {
|
|
849
|
+
LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
|
|
850
|
+
return encode(inp_batch);
|
|
851
|
+
}
|
|
852
|
+
|
|
1168
853
|
if (inp_batch.n_tokens == 0) {
|
|
1169
854
|
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
|
|
1170
855
|
return -1;
|
|
1171
856
|
}
|
|
1172
857
|
|
|
858
|
+
llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
|
|
859
|
+
|
|
1173
860
|
// temporary allocate memory for the input batch if needed
|
|
1174
|
-
// TODO: this is incorrect for multiple sequences because
|
|
1175
|
-
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->
|
|
861
|
+
// TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
|
|
862
|
+
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
|
|
1176
863
|
|
|
1177
864
|
const llama_batch & batch = batch_allocr.batch;
|
|
1178
865
|
|
|
@@ -1184,33 +871,7 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1184
871
|
const int64_t n_tokens_all = batch.n_tokens;
|
|
1185
872
|
const int64_t n_embd = hparams.n_embd;
|
|
1186
873
|
|
|
1187
|
-
|
|
1188
|
-
class batch_guard {
|
|
1189
|
-
public:
|
|
1190
|
-
batch_guard(llama_kv_cache_unified & kv_self) : kv_slot_restorer(kv_self) {
|
|
1191
|
-
}
|
|
1192
|
-
|
|
1193
|
-
~batch_guard() {
|
|
1194
|
-
if (!is_done) {
|
|
1195
|
-
kv_slot_restorer.restore();
|
|
1196
|
-
}
|
|
1197
|
-
}
|
|
1198
|
-
|
|
1199
|
-
void done() {
|
|
1200
|
-
is_done = true;
|
|
1201
|
-
}
|
|
1202
|
-
|
|
1203
|
-
void save(const llama_kv_cache_slot_info & slot_info) {
|
|
1204
|
-
kv_slot_restorer.save(slot_info);
|
|
1205
|
-
}
|
|
1206
|
-
|
|
1207
|
-
private:
|
|
1208
|
-
bool is_done = false;
|
|
1209
|
-
|
|
1210
|
-
llama_kv_slot_restorer kv_slot_restorer;
|
|
1211
|
-
};
|
|
1212
|
-
|
|
1213
|
-
batch_guard bg(*kv_self);
|
|
874
|
+
llama_kv_cache_guard kv_guard(kv_self);
|
|
1214
875
|
|
|
1215
876
|
GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
|
|
1216
877
|
|
|
@@ -1244,18 +905,14 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1244
905
|
for (uint32_t i = 0; i < n_tokens_all; ++i) {
|
|
1245
906
|
n_outputs_all += batch.logits[i] != 0;
|
|
1246
907
|
}
|
|
1247
|
-
} else if (
|
|
908
|
+
} else if (embd_pooled) {
|
|
1248
909
|
n_outputs_all = n_tokens_all;
|
|
1249
910
|
} else {
|
|
1250
911
|
// keep last output only
|
|
1251
912
|
n_outputs_all = 1;
|
|
1252
913
|
}
|
|
1253
914
|
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
sbatch.from_batch(batch, n_embd,
|
|
1257
|
-
/* simple_split */ !kv_self->recurrent,
|
|
1258
|
-
/* logits_all */ logits_all);
|
|
915
|
+
llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all);
|
|
1259
916
|
|
|
1260
917
|
// reserve output buffer
|
|
1261
918
|
if (output_reserve(n_outputs_all) < n_outputs_all) {
|
|
@@ -1263,25 +920,13 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1263
920
|
return -2;
|
|
1264
921
|
};
|
|
1265
922
|
|
|
923
|
+
// handle any pending defrags/shifts
|
|
924
|
+
kv_self_update();
|
|
925
|
+
|
|
1266
926
|
int64_t n_outputs_prev = 0;
|
|
1267
927
|
|
|
1268
928
|
while (sbatch.n_tokens > 0) {
|
|
1269
|
-
llama_ubatch ubatch =
|
|
1270
|
-
|
|
1271
|
-
const auto & n_ubatch = cparams.n_ubatch;
|
|
1272
|
-
|
|
1273
|
-
if (kv_self->recurrent) {
|
|
1274
|
-
if (embd_pooled) {
|
|
1275
|
-
// Pooled embeddings cannot be split across ubatches (yet)
|
|
1276
|
-
ubatch = sbatch.split_seq(cparams.n_ubatch);
|
|
1277
|
-
} else {
|
|
1278
|
-
// recurrent model architectures are easier to implement
|
|
1279
|
-
// with equal-length sequences
|
|
1280
|
-
ubatch = sbatch.split_equal(cparams.n_ubatch);
|
|
1281
|
-
}
|
|
1282
|
-
} else {
|
|
1283
|
-
ubatch = sbatch.split_simple(n_ubatch);
|
|
1284
|
-
}
|
|
929
|
+
llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
|
|
1285
930
|
|
|
1286
931
|
// count the outputs in this u_batch
|
|
1287
932
|
{
|
|
@@ -1300,35 +945,13 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1300
945
|
n_outputs = n_outputs_new;
|
|
1301
946
|
}
|
|
1302
947
|
|
|
1303
|
-
//
|
|
1304
|
-
if (
|
|
1305
|
-
|
|
948
|
+
// find KV slot
|
|
949
|
+
if (!kv_self->find_slot(ubatch)) {
|
|
950
|
+
LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
|
|
1306
951
|
|
|
1307
|
-
|
|
1308
|
-
// better to start searching from the beginning of the cache, hoping to fill it
|
|
1309
|
-
if (kv_self->head > kv_self->used + 2*ubatch.n_tokens) {
|
|
1310
|
-
kv_self->head = 0;
|
|
1311
|
-
}
|
|
1312
|
-
|
|
1313
|
-
const auto slot_info = kv_self->find_slot(ubatch);
|
|
1314
|
-
if (!slot_info) {
|
|
1315
|
-
LLAMA_LOG_ERROR("%s: failed to prepare ubatch\n", __func__);
|
|
1316
|
-
return -3;
|
|
1317
|
-
}
|
|
1318
|
-
|
|
1319
|
-
bg.save(slot_info);
|
|
1320
|
-
|
|
1321
|
-
if (!kv_self->recurrent) {
|
|
1322
|
-
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
|
1323
|
-
// after enough generations, the benefit from this heuristic disappears
|
|
1324
|
-
// if we start defragmenting the cache, the benefit from this will be more important
|
|
1325
|
-
const uint32_t pad = kv_self->get_padding(cparams);
|
|
1326
|
-
kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad)));
|
|
1327
|
-
}
|
|
952
|
+
return 1;
|
|
1328
953
|
}
|
|
1329
954
|
|
|
1330
|
-
//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head);
|
|
1331
|
-
|
|
1332
955
|
ggml_backend_sched_reset(sched.get());
|
|
1333
956
|
ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data);
|
|
1334
957
|
|
|
@@ -1354,16 +977,6 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1354
977
|
}
|
|
1355
978
|
}
|
|
1356
979
|
|
|
1357
|
-
// update the kv ring buffer
|
|
1358
|
-
{
|
|
1359
|
-
kv_self->head += ubatch.n_tokens;
|
|
1360
|
-
|
|
1361
|
-
// Ensure kv cache head points to a valid index.
|
|
1362
|
-
if (kv_self->head >= kv_self->size) {
|
|
1363
|
-
kv_self->head = 0;
|
|
1364
|
-
}
|
|
1365
|
-
}
|
|
1366
|
-
|
|
1367
980
|
// plot the computation graph in dot format (for debugging purposes)
|
|
1368
981
|
//if (n_past%100 == 0) {
|
|
1369
982
|
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
|
@@ -1450,45 +1063,70 @@ int llama_context::decode(llama_batch & inp_batch) {
|
|
|
1450
1063
|
}
|
|
1451
1064
|
|
|
1452
1065
|
// finalize the batch processing
|
|
1453
|
-
|
|
1066
|
+
kv_guard.commit();
|
|
1067
|
+
|
|
1068
|
+
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
|
1069
|
+
n_outputs = n_outputs_all;
|
|
1454
1070
|
|
|
1455
1071
|
// set output mappings
|
|
1456
1072
|
{
|
|
1457
1073
|
bool sorted_output = true;
|
|
1458
1074
|
|
|
1459
|
-
|
|
1075
|
+
auto & out_ids = sbatch.out_ids;
|
|
1076
|
+
|
|
1077
|
+
GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all);
|
|
1460
1078
|
|
|
1461
1079
|
for (int64_t i = 0; i < n_outputs_all; ++i) {
|
|
1462
|
-
int64_t out_id =
|
|
1080
|
+
int64_t out_id = out_ids[i];
|
|
1463
1081
|
output_ids[out_id] = i;
|
|
1464
1082
|
if (out_id != i) {
|
|
1465
1083
|
sorted_output = false;
|
|
1466
1084
|
}
|
|
1467
1085
|
}
|
|
1468
1086
|
|
|
1469
|
-
|
|
1470
|
-
|
|
1087
|
+
// make the outputs have the same order they had in the user-provided batch
|
|
1088
|
+
// note: this is mostly relevant for recurrent models atm
|
|
1089
|
+
if (!sorted_output) {
|
|
1090
|
+
const uint32_t n_vocab = model.vocab.n_tokens();
|
|
1091
|
+
const uint32_t n_embd = model.hparams.n_embd;
|
|
1092
|
+
|
|
1093
|
+
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
|
1094
|
+
|
|
1095
|
+
// TODO: is there something more efficient which also minimizes swaps?
|
|
1096
|
+
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
|
1097
|
+
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
|
1098
|
+
int32_t j_min = i;
|
|
1099
|
+
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
|
1100
|
+
if (out_ids[j] < out_ids[j_min]) {
|
|
1101
|
+
j_min = j;
|
|
1102
|
+
}
|
|
1103
|
+
}
|
|
1104
|
+
if (j_min == i) { continue; }
|
|
1105
|
+
std::swap(out_ids[i], out_ids[j_min]);
|
|
1106
|
+
if (logits_size > 0) {
|
|
1107
|
+
for (uint32_t k = 0; k < n_vocab; k++) {
|
|
1108
|
+
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
|
|
1109
|
+
}
|
|
1110
|
+
}
|
|
1111
|
+
if (embd_size > 0) {
|
|
1112
|
+
for (uint32_t k = 0; k < n_embd; k++) {
|
|
1113
|
+
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
|
|
1114
|
+
}
|
|
1115
|
+
}
|
|
1116
|
+
}
|
|
1117
|
+
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
1118
|
+
for (int32_t i = 0; i < n_outputs; ++i) {
|
|
1119
|
+
output_ids[out_ids[i]] = i;
|
|
1120
|
+
}
|
|
1471
1121
|
}
|
|
1472
1122
|
}
|
|
1473
1123
|
|
|
1474
|
-
// set to total number of outputs in the batch, for use in llama_get_logits_ith
|
|
1475
|
-
n_outputs = n_outputs_all;
|
|
1476
|
-
|
|
1477
1124
|
// wait for the computation to finish (automatically done when obtaining the model output)
|
|
1478
1125
|
//synchronize();
|
|
1479
1126
|
|
|
1480
1127
|
// decide if we need to defrag the kv cache
|
|
1481
|
-
if (cparams.
|
|
1482
|
-
|
|
1483
|
-
// - count the padding towards the number of used tokens
|
|
1484
|
-
const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f;
|
|
1485
|
-
|
|
1486
|
-
// queue defragmentation for next llama_kv_cache_update
|
|
1487
|
-
if (fragmentation > cparams.defrag_thold) {
|
|
1488
|
-
LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation);
|
|
1489
|
-
|
|
1490
|
-
kv_self->defrag();
|
|
1491
|
-
}
|
|
1128
|
+
if (cparams.defrag_thold > 0.0f) {
|
|
1129
|
+
kv_self->defrag_sched(cparams.defrag_thold);
|
|
1492
1130
|
}
|
|
1493
1131
|
|
|
1494
1132
|
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
|
@@ -1568,52 +1206,12 @@ int32_t llama_context::output_reserve(int32_t n_outputs) {
|
|
|
1568
1206
|
// set all ids as invalid (negative)
|
|
1569
1207
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
1570
1208
|
|
|
1571
|
-
ggml_backend_buffer_clear(buf_output.get(), 0);
|
|
1572
|
-
|
|
1573
1209
|
this->n_outputs = 0;
|
|
1574
1210
|
this->n_outputs_max = n_outputs_max;
|
|
1575
1211
|
|
|
1576
1212
|
return n_outputs_max;
|
|
1577
1213
|
}
|
|
1578
1214
|
|
|
1579
|
-
void llama_context::output_reorder() {
|
|
1580
|
-
auto & out_ids = sbatch.out_ids;
|
|
1581
|
-
if (!out_ids.empty()) {
|
|
1582
|
-
const uint32_t n_vocab = model.vocab.n_tokens();
|
|
1583
|
-
const uint32_t n_embd = model.hparams.n_embd;
|
|
1584
|
-
|
|
1585
|
-
GGML_ASSERT((size_t) n_outputs == out_ids.size());
|
|
1586
|
-
|
|
1587
|
-
// TODO: is there something more efficient which also minimizes swaps?
|
|
1588
|
-
// selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort)
|
|
1589
|
-
for (int32_t i = 0; i < n_outputs - 1; ++i) {
|
|
1590
|
-
int32_t j_min = i;
|
|
1591
|
-
for (int32_t j = i + 1; j < n_outputs; ++j) {
|
|
1592
|
-
if (out_ids[j] < out_ids[j_min]) {
|
|
1593
|
-
j_min = j;
|
|
1594
|
-
}
|
|
1595
|
-
}
|
|
1596
|
-
if (j_min == i) { continue; }
|
|
1597
|
-
std::swap(out_ids[i], out_ids[j_min]);
|
|
1598
|
-
if (logits_size > 0) {
|
|
1599
|
-
for (uint32_t k = 0; k < n_vocab; k++) {
|
|
1600
|
-
std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]);
|
|
1601
|
-
}
|
|
1602
|
-
}
|
|
1603
|
-
if (embd_size > 0) {
|
|
1604
|
-
for (uint32_t k = 0; k < n_embd; k++) {
|
|
1605
|
-
std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]);
|
|
1606
|
-
}
|
|
1607
|
-
}
|
|
1608
|
-
}
|
|
1609
|
-
std::fill(output_ids.begin(), output_ids.end(), -1);
|
|
1610
|
-
for (int32_t i = 0; i < n_outputs; ++i) {
|
|
1611
|
-
output_ids[out_ids[i]] = i;
|
|
1612
|
-
}
|
|
1613
|
-
out_ids.clear();
|
|
1614
|
-
}
|
|
1615
|
-
}
|
|
1616
|
-
|
|
1617
1215
|
//
|
|
1618
1216
|
// graph
|
|
1619
1217
|
//
|
|
@@ -1650,7 +1248,7 @@ llm_graph_result_ptr llama_context::graph_build(
         /*.backend_cpu =*/ backend_cpu,
         /*.cvec =*/ &cvec,
         /*.loras =*/ &loras,
-        /*.memory =*/
+        /*.memory =*/ memory.get(),
         /*.cross =*/ &cross,
         /*.n_outputs =*/ n_outputs,
         /*.cb =*/ graph_get_cb(),
@@ -2054,8 +1652,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
     {
         LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__);
 
-        output_reorder();
-
        const auto n_outputs = this->n_outputs;
        const auto & output_ids = this->output_ids;
 
@@ -2108,8 +1704,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) {
         }
     }
 
-
-
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+    if (kv_self != nullptr) {
+        LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__);
+        kv_self->state_write(io);
+    }
 
     return io.n_bytes();
 }
@@ -2192,8 +1792,13 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
         }
     }
 
-
-
+    if (memory) {
+        LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__);
+
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+        kv_self->state_read(io);
+    }
 
     return io.n_bytes();
 }
@@ -2201,7 +1806,11 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
 size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
 
-
+    if (memory) {
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+        kv_self->state_write(io, seq_id);
+    }
 
     return io.n_bytes();
 }
@@ -2209,7 +1818,11 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) {
 size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) {
     GGML_UNUSED(seq_id);
 
-
+    if (memory) {
+        llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+        kv_self->state_read(io, seq_id);
+    }
 
     return io.n_bytes();
 }
@@ -2237,6 +1850,215 @@ void llama_context::perf_reset() {
     t_p_eval_us = n_p_eval = 0;
 }
 
+//
+// training
+//
+
+static void llama_set_param(struct ggml_tensor * tensor, llama_opt_param_filter param_filter, void * userdata) {
+    if (!tensor || tensor->type != GGML_TYPE_F32) {
+        return;
+    }
+    if (!param_filter(tensor, userdata)) {
+        return;
+    }
+    if (strcmp(tensor->name, "token_embd.weight") == 0) {
+        return; // FIXME
+    }
+    if (strcmp(tensor->name, "rope_freqs.weight") == 0) {
+        return; // FIXME
+    }
+    ggml_set_param(tensor);
+}
+
+void llama_context::opt_init(struct llama_model * model, struct llama_opt_params lopt_params) {
+    GGML_ASSERT(!opt_ctx);
+    model->hparams.n_ctx_train = lopt_params.n_ctx_train > 0 ? lopt_params.n_ctx_train : n_ctx();
+    const uint32_t n_batch = std::min(this->n_batch(), model->hparams.n_ctx_train);
+    const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+    GGML_ASSERT(model->hparams.n_ctx_train % n_batch == 0);
+    GGML_ASSERT(n_batch % n_ubatch == 0);
+
+    ggml_opt_params opt_params = ggml_opt_default_params(sched.get(), GGML_OPT_LOSS_TYPE_CROSS_ENTROPY);
+    opt_params.opt_period = n_batch / n_ubatch;
+    opt_params.get_opt_pars = lopt_params.get_opt_pars;
+    opt_params.get_opt_pars_ud = lopt_params.get_opt_pars_ud;
+
+    opt_ctx = ggml_opt_init(opt_params);
+
+    llama_opt_param_filter param_filter = lopt_params.param_filter;
+    void * param_filter_ud = lopt_params.param_filter_ud;
+
+    //llama_set_param(model->tok_embd, param_filter, param_filter_ud); // FIXME
+    llama_set_param(model->type_embd, param_filter, param_filter_ud);
+    llama_set_param(model->pos_embd, param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm, param_filter, param_filter_ud);
+    llama_set_param(model->tok_norm_b, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_b, param_filter, param_filter_ud);
+    llama_set_param(model->output, param_filter, param_filter_ud);
+    llama_set_param(model->output_b, param_filter, param_filter_ud);
+    llama_set_param(model->output_norm_enc, param_filter, param_filter_ud);
+    llama_set_param(model->cls, param_filter, param_filter_ud);
+    llama_set_param(model->cls_b, param_filter, param_filter_ud);
+    llama_set_param(model->cls_out, param_filter, param_filter_ud);
+    llama_set_param(model->cls_out_b, param_filter, param_filter_ud);
+
+    for (struct llama_layer & layer : model->layers) {
+        for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) {
+            llama_set_param(reinterpret_cast<struct ggml_tensor **>(&layer)[i], param_filter, param_filter_ud);
+        }
+    }
+}
+
+void llama_context::opt_epoch_iter(
+        ggml_opt_dataset_t dataset,
+        ggml_opt_result_t result,
+        const std::vector<llama_token> & tokens,
+        const std::vector<llama_token> & labels_sparse,
+        llama_batch & batch,
+        ggml_opt_epoch_callback callback,
+        bool train,
+        int64_t idata_in_loop,
+        int64_t ndata_in_loop,
+        int64_t t_loop_start) {
+    GGML_ASSERT(opt_ctx);
+    const uint32_t n_ctx = llama_model_n_ctx_train(&model);
+    const uint32_t n_batch = std::min(this->n_batch(), n_ctx);
+    const uint32_t n_ubatch = std::min(this->n_ubatch(), n_batch);
+
+    llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
+
+    kv_self->clear();
+    llama_kv_cache_guard kv_guard(kv_self);
+
+    for (uint32_t pos_ctx = 0; pos_ctx < n_ctx; pos_ctx += n_batch) {
+        batch.n_tokens = n_batch;
+        for (uint32_t pos_batch = 0; pos_batch < n_batch; ++pos_batch) {
+            batch.token   [pos_batch]    = tokens[pos_ctx + pos_batch];
+            batch.pos     [pos_batch]    = pos_ctx + pos_batch;
+            batch.n_seq_id[pos_batch]    = 1;
+            batch.seq_id  [pos_batch][0] = 0;
+            batch.logits  [pos_batch]    = true;
+        }
+
+        const auto n_tokens_all = batch.n_tokens;
+
+        n_queued_tokens += n_tokens_all;
+
+        // this indicates we are doing pooled embedding, so we ignore batch.logits and output all tokens
+        const bool embd_pooled = cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE;
+
+        embd_seq.clear();
+
+        int64_t n_outputs_all = n_tokens_all;
+
+        llama_sbatch sbatch = kv_self->sbatch_init(batch, /*logits_all =*/ true);
+
+        // reserve output buffer
+        if (output_reserve(n_outputs_all) < n_outputs_all) {
+            LLAMA_LOG_ERROR("%s: could not reserve space for batch with %" PRId64 " outputs\n", __func__, n_outputs_all);
+            GGML_ABORT("TODO: handle this error");
+        };
+
+        for (uint32_t pos_batch = 0; pos_batch < n_batch; pos_batch += n_ubatch) {
+            llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled);
+
+            n_outputs = ubatch.n_tokens;
+
+            // TODO: not sure if this is needed
+            if (!kv_self->find_slot(ubatch)) {
+                LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens);
+
+                GGML_ABORT("TODO: handle this error");
+            }
+
+            auto * gf = graph_init();
+            auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT);
+
+            struct ggml_context * ctx_compute_opt;
+            {
+                const size_t size_gf = ggml_graph_size(gf);
+                const size_t size_meta = 4*size_gf*ggml_tensor_overhead() + 2*ggml_graph_overhead_custom(size_gf, /*grads = */ true);
+                struct ggml_init_params params = {
+                    /*.mem_size =*/ size_meta,
+                    /*.mem_buffer =*/ nullptr,
+                    /*.no_alloc =*/ true,
+                };
+                ctx_compute_opt = ggml_init(params);
+            }
+            ggml_opt_prepare_alloc(opt_ctx, ctx_compute_opt, gf, res->get_tokens(), res->get_logits());
+            ggml_opt_alloc(opt_ctx, train);
+            res->set_inputs(&ubatch);
+            {
+                struct ggml_tensor * labels = ggml_opt_labels(opt_ctx);
+                GGML_ASSERT(labels->ne[1] == n_ubatch);
+                ggml_set_zero(labels);
+                const float onef = 1.0f;
+                for (uint32_t pos_ubatch = 0; pos_ubatch < n_ubatch; ++pos_ubatch) {
+                    const uint32_t ilabel = pos_ctx + pos_batch + pos_ubatch;
+                    GGML_ASSERT(labels_sparse[ilabel] < labels->ne[0]);
+                    ggml_backend_tensor_set(labels, &onef, (pos_ubatch*labels->ne[0] + labels_sparse[ilabel])*sizeof(float), sizeof(float));
+                }
+            }
+            ggml_opt_eval(opt_ctx, result);
+            if (callback) {
+                callback(train, opt_ctx, dataset, result, idata_in_loop + (pos_ctx + pos_batch)/n_ubatch + 1, ndata_in_loop, t_loop_start);
+            }
+            ggml_free(ctx_compute_opt);
+        }
+    }
+
+    kv_guard.commit();
+}
+
+void llama_context::opt_epoch(
+        ggml_opt_dataset_t dataset,
+        ggml_opt_result_t result_train,
+        ggml_opt_result_t result_eval,
+        int64_t idata_split,
+        ggml_opt_epoch_callback callback_train,
+        ggml_opt_epoch_callback callback_eval) {
+    const uint32_t n_ctx = this->n_ctx();
+    const uint32_t n_batch = std::min(cparams.n_batch, n_ctx);
+    const uint32_t n_ubatch = std::min(cparams.n_ubatch, n_batch);
+    const int64_t ndata = ggml_opt_dataset_ndata(dataset);
+
+    GGML_ASSERT(idata_split >= 0);
+    GGML_ASSERT(idata_split <= ndata);
+
+    const uint32_t ubatch_per_ctx = n_ctx / n_ubatch;
+
+    struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
+    std::vector<llama_token> tokens(n_ctx);
+    std::vector<llama_token> labels_sparse(n_ctx);
+
+    int64_t idata = 0;
+
+    int64_t t_loop_start = ggml_time_us();
+    int64_t ndata_in_loop = idata_split*ubatch_per_ctx;
+    for (; idata < idata_split; ++idata) {
+        constexpr bool train = true;
+        const int64_t idata_in_loop = idata*ubatch_per_ctx;
+
+        ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+        opt_epoch_iter(dataset, result_train, tokens, labels_sparse, batch,
+            callback_train, train, idata_in_loop, ndata_in_loop, t_loop_start);
+    }
+
+    t_loop_start = ggml_time_us();
+    ndata_in_loop = (ndata - idata_split)*ubatch_per_ctx;
+    for (; idata < ndata; ++idata) {
+        constexpr bool train = false;
+        const int64_t idata_in_loop = (idata - idata_split)*ubatch_per_ctx;
+
+        ggml_opt_dataset_get_batch_host(dataset, tokens.data(), n_ctx*sizeof(llama_token), labels_sparse.data(), idata);
+        opt_epoch_iter(dataset, result_eval, tokens, labels_sparse, batch,
+            callback_eval, train, idata_in_loop, ndata_in_loop, t_loop_start);
+    }
+
+    llama_batch_free(batch);
+}
+
 //
 // interface implementation
 //
@@ -2264,13 +2086,13 @@ llama_context_params llama_context_default_params() {
         /*.cb_eval_user_data =*/ nullptr,
         /*.type_k =*/ GGML_TYPE_F16,
         /*.type_v =*/ GGML_TYPE_F16,
-        /*.
+        /*.abort_callback =*/ nullptr,
+        /*.abort_callback_data =*/ nullptr,
         /*.embeddings =*/ false,
         /*.offload_kqv =*/ true,
         /*.flash_attn =*/ false,
         /*.no_perf =*/ true,
-        /*.
-        /*.abort_callback_data =*/ nullptr,
+        /*.op_offload =*/ true,
     };
 
     return result;
@@ -2299,11 +2121,6 @@ llama_context * llama_init_from_model(
         params.flash_attn = false;
     }
 
-    if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
-        LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
     if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
@@ -2504,7 +2321,12 @@ int32_t llama_get_kv_cache_token_count(const llama_context * ctx) {
 }
 
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
-
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->get_n_tokens();
 }
 
 // deprecated
@@ -2513,7 +2335,12 @@ int32_t llama_get_kv_cache_used_cells(const llama_context * ctx) {
 }
 
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
-
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->get_used_cells();
 }
 
 // deprecated
@@ -2522,7 +2349,12 @@ void llama_kv_cache_clear(llama_context * ctx) {
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
-
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    kv->clear();
 }
 
 // deprecated
@@ -2539,7 +2371,12 @@ bool llama_kv_self_seq_rm(
         llama_seq_id seq_id,
         llama_pos p0,
         llama_pos p1) {
-
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return true;
+    }
+
+    return kv->seq_rm(seq_id, p0, p1);
 }
 
 // deprecated
@@ -2549,7 +2386,7 @@ void llama_kv_cache_seq_cp(
         llama_seq_id seq_id_dst,
         llama_pos p0,
         llama_pos p1) {
-
+    llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1);
 }
 
 void llama_kv_self_seq_cp(
@@ -2558,18 +2395,28 @@ void llama_kv_self_seq_cp(
         llama_seq_id seq_id_dst,
         llama_pos p0,
         llama_pos p1) {
-
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    kv->seq_cp(seq_id_src, seq_id_dst, p0, p1);
 }
 
 // deprecated
 void llama_kv_cache_seq_keep(
        llama_context * ctx,
        llama_seq_id seq_id) {
-
+    llama_kv_self_seq_keep(ctx, seq_id);
 }
 
 void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) {
-
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    kv->seq_keep(seq_id);
 }
 
 // deprecated
@@ -2579,7 +2426,7 @@ void llama_kv_cache_seq_add(
         llama_pos p0,
         llama_pos p1,
         llama_pos delta) {
-
+    llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta);
 }
 
 void llama_kv_self_seq_add(
@@ -2588,7 +2435,12 @@ void llama_kv_self_seq_add(
         llama_pos p0,
         llama_pos p1,
         llama_pos delta) {
-
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    kv->seq_add(seq_id, p0, p1, delta);
 }
 
 // deprecated
@@ -2598,7 +2450,7 @@ void llama_kv_cache_seq_div(
         llama_pos p0,
         llama_pos p1,
         int d) {
-
+    llama_kv_self_seq_div(ctx, seq_id, p0, p1, d);
 }
 
 void llama_kv_self_seq_div(
@@ -2607,7 +2459,12 @@ void llama_kv_self_seq_div(
         llama_pos p0,
         llama_pos p1,
         int d) {
-
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    kv->seq_div(seq_id, p0, p1, d);
 }
 
 // deprecated
@@ -2616,16 +2473,27 @@ llama_pos llama_kv_cache_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
 }
 
 llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) {
-
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return 0;
+    }
+
+    return kv->seq_pos_max(seq_id);
 }
 
 // deprecated
 void llama_kv_cache_defrag(llama_context * ctx) {
-
+    llama_kv_self_defrag(ctx);
 }
 
 void llama_kv_self_defrag(llama_context * ctx) {
-
+    auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return;
+    }
+
+    // force defrag
+    kv->defrag_sched(-1.0f);
 }
 
 // deprecated
@@ -2634,7 +2502,12 @@ bool llama_kv_cache_can_shift(const llama_context * ctx) {
 }
 
 bool llama_kv_self_can_shift(const llama_context * ctx) {
-
+    const auto * kv = ctx->get_kv_self();
+    if (!kv) {
+        return false;
+    }
+
+    return kv->get_can_shift();
 }
 
 // deprecated
@@ -2804,3 +2677,34 @@ void llama_perf_context_print(const llama_context * ctx) {
 void llama_perf_context_reset(llama_context * ctx) {
     ctx->perf_reset();
 }
+
+//
+// training
+//
+
+bool llama_opt_param_filter_all(const struct ggml_tensor * tensor, void * userdata) {
+    GGML_UNUSED(tensor);
+    GGML_UNUSED(userdata);
+    return true;
+}
+
+void llama_opt_init(struct llama_context * ctx, struct llama_model * model, struct llama_opt_params lopt_params) {
+    ctx->opt_init(model, lopt_params);
+}
+
+void llama_opt_epoch(
+        struct llama_context * ctx,
+        ggml_opt_dataset_t dataset,
+        ggml_opt_result_t result_train,
+        ggml_opt_result_t result_eval,
+        int64_t idata_split,
+        ggml_opt_epoch_callback callback_train,
+        ggml_opt_epoch_callback callback_eval) {
+    ctx->opt_epoch(
+        dataset,
+        result_train,
+        result_eval,
+        idata_split,
+        callback_train,
+        callback_eval);
+}