@fugood/llama.node 0.3.2 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +7 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +17 -7
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +89 -27
- package/src/LlamaContext.h +2 -0
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +240 -168
- package/src/llama.cpp/.github/workflows/docker.yml +8 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +14 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -4
- package/src/llama.cpp/common/arg.cpp +986 -770
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +212 -351
- package/src/llama.cpp/common/common.h +204 -117
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +163 -121
- package/src/llama.cpp/common/sampling.h +41 -20
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +134 -161
- package/src/llama.cpp/examples/CMakeLists.txt +33 -14
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +19 -18
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +41 -87
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +263 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +83 -22
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
- package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +73 -114
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
- package/src/llama.cpp/examples/server/server.cpp +2073 -1339
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +354 -277
- package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/simple/simple.cpp +130 -94
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
- package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +159 -417
- package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +93 -52
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +4 -8
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +779 -194
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +55 -10
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +4317 -2979
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -38
- package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
- package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +62 -20
- package/src/llama.cpp/tests/test-sampling.cpp +163 -138
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/examples/speculative/speculative.cpp

@@ -12,7 +12,7 @@
 #include <string>
 #include <vector>
 
-#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE
+#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
 
 struct seq_draft {
@@ -26,22 +26,27 @@ struct seq_draft {
     std::vector<llama_token> tokens;
     std::vector<std::vector<llama_token_data>> dists;
 
-    struct
+    struct common_sampler * smpl = nullptr;
 };
 
 int main(int argc, char ** argv) {
-
+    common_params params;
 
     // needed to get candidate probs even for temp <= 0.0
-    params.
+    params.sampling.n_probs = 128;
 
-    if (!
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
     }
 
-
+    if (params.n_predict < -1) {
+        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+        return 1;
+    }
 
-
+    common_init();
+
+    if (params.speculative.model.empty()) {
         LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }
@@ -50,9 +55,9 @@ int main(int argc, char ** argv) {
     const int n_seq_dft = params.n_parallel;
 
     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
-    const float
+    const float p_draft_split = params.speculative.p_split;
 
-    std::default_random_engine rng(params.
+    std::default_random_engine rng(params.sampling.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sampling.seed);
     std::uniform_real_distribution<> u_dist;
 
     // init llama.cpp
@@ -66,19 +71,20 @@
     llama_context * ctx_dft = NULL;
 
     // load the target model
-
+    common_init_result llama_init_tgt = common_init_from_params(params);
     model_tgt = llama_init_tgt.model;
     ctx_tgt = llama_init_tgt.context;
 
     // load the draft model
-    params.
-    params.
-
-
+    params.devices = params.speculative.devices;
+    params.model = params.speculative.model;
+    params.n_gpu_layers = params.speculative.n_gpu_layers;
+    if (params.speculative.cpuparams.n_threads > 0) {
+        params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
     }
 
-    params.cpuparams_batch.n_threads = params.
-
+    params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+    common_init_result llama_init_dft = common_init_from_params(params);
     model_dft = llama_init_dft.model;
     ctx_dft = llama_init_dft.context;
 
@@ -124,8 +130,8 @@
         if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
             LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
             LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
-
-
+                    common_token_to_piece(ctx_tgt, i).c_str(),
+                    common_token_to_piece(ctx_dft, i).c_str());
             return 1;
         }
     }
@@ -134,7 +140,7 @@
 
     // Tokenize the prompt
     std::vector<llama_token> inp;
-    inp =
+    inp = common_tokenize(ctx_tgt, params.prompt, true, true);
 
     const int max_context_size = llama_n_ctx(ctx_tgt);
     const int max_tokens_list_size = max_context_size - 4;
@@ -147,7 +153,7 @@
     LOG("\n\n");
 
     for (auto id : inp) {
-        LOG("%s",
+        LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
     }
 
     const int n_input = inp.size();
@@ -155,9 +161,9 @@
     const auto t_enc_start = ggml_time_us();
 
     // eval the prompt with both models
-    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1
-    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1
-    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input
+    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
+    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
+    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));
 
     const auto t_enc_end = ggml_time_us();
 
@@ -165,7 +171,7 @@
     //GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
 
     // how many tokens to draft each time
-    int n_draft = params.
+    int n_draft = params.speculative.n_max;
 
     int n_predict = 0;
     int n_drafted = 0;
@@ -178,20 +184,18 @@
     bool has_eos = false;
 
     // target model sampling context (reuse the llama_context's sampling instance)
-    struct
-
-    struct llama_sampler * softmax = llama_sampler_init_softmax();
+    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
 
     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);
 
     for (int s = 0; s < n_seq_dft; ++s) {
-        // allocate
-        drafts[s].smpl =
+        // allocate llama_sampler for each draft sequence
+        drafts[s].smpl = common_sampler_init(model_dft, params.sampling);
     }
 
-    llama_batch batch_dft = llama_batch_init(
-    llama_batch batch_tgt = llama_batch_init(
+    llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
+    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);
 
     const auto t_dec_start = ggml_time_us();
 
@@ -227,11 +231,11 @@
             // for stochastic sampling, attempt to match the token with the drafted tokens
             {
                 bool accept = false;
-                if (params.
+                if (params.sampling.temp > 0) {
                     // stochastic verification
-
+                    common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
 
-                    auto & dist_tgt = *
+                    auto & dist_tgt = *common_sampler_get_candidates(smpl);
 
                     float p_tgt = 0.0f;
                     float p_dft = 0.0f;
@@ -264,11 +268,12 @@
                         for (size_t i = 0; i < dist_tgt.size; i++) {
                             if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
                                 p_tgt = dist_tgt.data[i].p;
+                                break;
                             }
+                        }
+                        for (size_t i = 0; i < dist_dft.size; i++) {
                             if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
                                 p_dft = dist_dft.data[i].p;
-                            }
-                            if (p_tgt && p_dft) {
                                 break;
                             }
                         }
@@ -277,13 +282,13 @@
                             s_keep = s;
                             accept = true;
                             token_id = drafts[s].tokens[i_dft];
-                            token_str =
-
+                            token_str = common_token_to_piece(ctx_tgt, token_id);
+                            common_sampler_accept(smpl, token_id, true);
 
                             LOG_DBG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
                             break;
                         } else {
-                            LOG_DBG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft],
+                            LOG_DBG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], common_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
                             drafts[s].active = false;
 
                             // calculate residual probability
@@ -349,19 +354,19 @@
                         const int idx = dist(rng);
 
                         token_id = dist_tgt.data[idx].id;
-
-                        token_str =
+                        common_sampler_accept(smpl, token_id, true);
+                        token_str = common_token_to_piece(ctx_tgt, token_id);
                     }
                 } else {
                     // greedy verification
 
                     // sample from the target model
                     LOG_DBG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
-                    token_id =
+                    token_id = common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
 
-
+                    common_sampler_accept(smpl, token_id, true);
 
-                    token_str =
+                    token_str = common_token_to_piece(ctx_tgt, token_id);
 
                 for (int s = 0; s < n_seq_dft; ++s) {
                     if (!drafts[s].active) {
@@ -431,8 +436,8 @@
         drafts[0].dists.push_back(std::vector<llama_token_data>());
         drafts[0].i_batch_tgt.push_back(0);
 
-
-
+        common_batch_clear(batch_dft);
+        common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
 
         llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
         // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
@@ -441,14 +446,14 @@
         ++n_past_dft;
     }
 
-    if (n_predict > params.n_predict || has_eos) {
+    if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
         break;
     }
 
     if (drafts[0].smpl) {
-
+        common_sampler_free(drafts[0].smpl);
    }
-    drafts[0].smpl =
+    drafts[0].smpl = common_sampler_clone(smpl);
 
    int n_seq_cur = 1;
    int n_past_cur = n_past_dft;
@@ -461,8 +466,8 @@
    drafts[0].drafting = true;
    drafts[0].i_batch_dft = 0;
 
-
-
+    common_batch_clear(batch_tgt);
+    common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
 
    // sample n_draft tokens from the draft model using tree-based sampling
    for (int i = 0; i < n_draft; ++i) {
@@ -477,20 +482,20 @@
            continue;
        }
 
-
+        common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
 
-        const auto * cur_p =
+        const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);
 
        for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
            LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                k, s, i, cur_p->data[k].id, cur_p->data[k].p,
+                k, s, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
        }
 
        std::vector<int> sa(1, s);
 
        // attempt to split the branch if the probability is high enough
        for (int f = 1; f < 8; ++f) {
-            if (n_seq_cur < n_seq_dft && cur_p->data[f].p >
+            if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_draft_split) {
                LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);
 
                llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
@@ -518,9 +523,9 @@
            drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
 
            if (drafts[n_seq_cur].smpl) {
-
+                common_sampler_free(drafts[n_seq_cur].smpl);
            }
-            drafts[n_seq_cur].smpl =
+            drafts[n_seq_cur].smpl = common_sampler_clone(drafts[s].smpl);
 
            sa.push_back(n_seq_cur);
 
@@ -536,7 +541,7 @@
 
        const int s = sa[is];
 
-
+        common_sampler_accept(drafts[s].smpl, id, true);
 
        drafts[s].tokens.push_back(id);
        // save cur_p.data into drafts[s].dists
@@ -545,12 +550,12 @@
        // add unique drafted tokens to the target batch
        drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
 
-
+        common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
 
        // add the token to the batch for batched decoding with the draft model
        drafts[s].i_batch_dft = batch_dft.n_tokens;
 
-
+        common_batch_add(batch_dft, id, n_past_cur, { s }, true);
 
        if (batch_tgt.n_tokens > n_draft) {
            drafts[s].drafting = false;
@@ -617,14 +622,13 @@
 
    LOG_INF("\n");
    LOG_INF("target:\n\n");
-
+    common_perf_print(ctx_tgt, smpl);
 
-
+    common_sampler_free(smpl);
    for (int s = 0; s < n_seq_dft; ++s) {
-
+        common_sampler_free(drafts[s].smpl);
    }
 
-    llama_sampler_free(softmax);
    llama_batch_free(batch_dft);
 
    llama_free(ctx_tgt);
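
The speculative.cpp changes above are, for the most part, a migration to the renamed common_* helpers (common_params, common_init_from_params, common_sampler_*, common_batch_*, common_token_to_piece). As a quick orientation, the following is a minimal, hypothetical sketch of the common_sampler lifecycle assembled only from calls visible in the added lines above; the wrapper function name generate_n and the surrounding model/context setup are assumptions, not part of this diff.

    // minimal sketch, assuming ctx/model come from common_init_from_params and
    // params was filled by common_params_parse; generate_n is an illustrative name
    #include "common.h"
    #include "sampling.h"
    #include "llama.h"

    static void generate_n(llama_context * ctx, llama_model * model, const common_params & params, int n) {
        struct common_sampler * smpl = common_sampler_init(model, params.sampling);

        for (int i = 0; i < n; ++i) {
            // sample from the logits of the last decoded token (idx = -1)
            llama_token id = common_sampler_sample(smpl, ctx, -1);

            // let the sampler update its internal state (penalties, grammar, ...)
            common_sampler_accept(smpl, id, true);

            if (llama_token_is_eog(model, id)) {
                break;
            }

            // feed the accepted token back to obtain logits for the next step
            llama_decode(ctx, llama_batch_get_one(&id, 1));
        }

        common_perf_print(ctx, smpl);
        common_sampler_free(smpl);
    }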
package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp

@@ -0,0 +1,265 @@
+#include "arg.h"
+#include "common.h"
+#include "sampling.h"
+#include "speculative.h"
+#include "log.h"
+#include "llama.h"
+
+#include <cstdio>
+#include <cstring>
+#include <string>
+#include <vector>
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
+        return 1;
+    }
+
+    if (params.n_predict < -1) {
+        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+        return 1;
+    }
+
+    common_init();
+
+    if (params.speculative.model.empty()) {
+        LOG_ERR("%s: --model-draft is required\n", __func__);
+        return 1;
+    }
+
+    // init llama.cpp
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    llama_model * model_tgt = NULL;
+    llama_model * model_dft = NULL;
+
+    llama_context * ctx_tgt = NULL;
+    llama_context * ctx_dft = NULL;
+
+    // load the target model
+    common_init_result llama_init_tgt = common_init_from_params(params);
+
+    model_tgt = llama_init_tgt.model;
+    ctx_tgt   = llama_init_tgt.context;
+
+    // load the draft model
+    params.devices      = params.speculative.devices;
+    params.model        = params.speculative.model;
+    params.n_ctx        = params.speculative.n_ctx;
+    params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
+    params.n_gpu_layers = params.speculative.n_gpu_layers;
+
+    if (params.speculative.cpuparams.n_threads > 0) {
+        params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
+    }
+
+    params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+    common_init_result llama_init_dft = common_init_from_params(params);
+
+    model_dft = llama_init_dft.model;
+    ctx_dft   = llama_init_dft.context;
+
+    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
+        return 1;
+    }
+
+    // Tokenize the prompt
+    std::vector<llama_token> inp;
+    inp = common_tokenize(ctx_tgt, params.prompt, true, true);
+
+    if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+
+        return 1;
+    }
+
+    if (llama_n_batch(ctx_tgt) < (uint32_t) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
+
+        return 1;
+    }
+
+    LOG("\n\n");
+
+    for (auto id : inp) {
+        LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
+    }
+
+    // how many tokens to draft each time
+    int n_draft     = params.speculative.n_max;
+    int n_draft_min = params.speculative.n_min;
+
+    float p_min = params.speculative.p_min;
+
+    int n_predict = 0;
+    int n_drafted = 0;
+    int n_accept  = 0;
+
+    // used to determine end of generation
+    bool has_eos = false;
+
+    // ================================================
+    // everything until here is standard initialization
+    // the relevant stuff for speculative decoding starts here
+
+    const auto t_enc_start = ggml_time_us();
+
+    // target model sampling context
+    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
+
+    // eval the prompt
+    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
+
+    // note: keep the last token separate!
+    llama_token id_last = inp.back();
+
+    // all tokens currently in the target context
+    llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
+
+    int n_past = inp.size() - 1;
+
+    // init the speculator
+    struct common_speculative_params params_spec;
+    params_spec.n_draft = n_draft;
+    params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
+    params_spec.p_min   = p_min;
+
+    struct common_speculative * spec = common_speculative_init(ctx_dft);
+
+    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
+
+    const auto t_enc_end = ggml_time_us();
+
+    const auto t_dec_start = ggml_time_us();
+
+    while (true) {
+        // optionally, generate draft tokens that can be appended to the target batch
+        //
+        // this is the most important part of the speculation. the more probable tokens that are provided here
+        // the better the performance will be. in theory, this computation can be performed asynchronously and even
+        // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
+        // from a cache or lookup tables.
+        //
+        llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
+
+        //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
+
+        // always have a token to evaluate from before - id_last
+        common_batch_clear(batch_tgt);
+        common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true);
+
+        // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
+        {
+            // do not waste time on small drafts
+            if (draft.size() < (size_t) n_draft_min) {
+                draft.clear();
+            }
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+            }
+
+            //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
+
+            llama_decode(ctx_tgt, batch_tgt);
+        }
+
+        // sample from the full target batch and return the accepted tokens based on the target sampler
+        //
+        // for each token to be accepted, the sampler would have to sample that same token
+        // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
+        // available logits from the batch and sample the next token until we run out of logits or the sampler
+        // disagrees with the draft
+        //
+        const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
+
+        //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
+
+        GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
+
+        n_past    += ids.size() - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
+        n_accept  += ids.size() - 1;
+        n_predict += ids.size();
+
+        // process the accepted tokens and update contexts
+        //
+        // this is the standard token post-processing that we normally do
+        // in this case, we do it for a group of accepted tokens at once
+        //
+        for (size_t i = 0; i < ids.size(); ++i) {
+            prompt_tgt.push_back(id_last);
+
+            id_last = ids[i];
+
+            if (llama_token_is_eog(model_tgt, id_last)) {
+                has_eos = true;
+                break;
+            }
+
+            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
+
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
+            }
+        }
+
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
+
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
+        }
+    }
+
+    auto t_dec_end = ggml_time_us();
+
+    const int n_input = inp.size();
+
+    LOG("\n\n");
+
+    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
+
+    LOG_INF("\n");
+    LOG_INF("n_draft   = %d\n", n_draft);
+    LOG_INF("n_predict = %d\n", n_predict);
+    LOG_INF("n_drafted = %d\n", n_drafted);
+    LOG_INF("n_accept  = %d\n", n_accept);
+    LOG_INF("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+    LOG_INF("\n");
+    LOG_INF("draft:\n\n");
+
+    llama_perf_context_print(ctx_dft);
+
+    LOG_INF("\n");
+    LOG_INF("target:\n\n");
+    common_perf_print(ctx_tgt, smpl);
+
+    common_sampler_free(smpl);
+    common_speculative_free(spec);
+
+    llama_free(ctx_tgt);
+    llama_free_model(model_tgt);
+
+    llama_free(ctx_dft);
+    llama_free_model(model_dft);
+
+    llama_backend_free();
+
+    LOG("\n\n");
+
+    return 0;
+}
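
The new example above is built on the common_speculative helper added in common/speculative.cpp and common/speculative.h (see the file list). The following condensed sketch isolates just its draft/verify loop; the speculate_loop wrapper, its parameter list, and the omitted bookkeeping (statistics, colored output, n_predict limit) are assumptions for illustration, not part of the diff.

    // condensed, hypothetical sketch of the loop above; assumes ctx_tgt/ctx_dft,
    // a common_sampler * smpl, the tokenized prompt and params_spec are prepared
    // exactly as in the full example
    #include "common.h"
    #include "sampling.h"
    #include "speculative.h"
    #include "llama.h"

    static void speculate_loop(llama_context * ctx_tgt, llama_context * ctx_dft,
                               common_sampler * smpl, common_speculative_params params_spec,
                               llama_tokens prompt_tgt, llama_token id_last, int n_past) {
        struct common_speculative * spec = common_speculative_init(ctx_dft);
        llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);

        while (true) {
            // let the draft model propose a probable continuation of id_last
            llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);

            // evaluate [id_last, draft...] with the target model in a single batch
            common_batch_clear(batch_tgt);
            common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true);
            for (size_t i = 0; i < draft.size(); ++i) {
                common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
            }
            llama_decode(ctx_tgt, batch_tgt);

            // keep the longest draft prefix the target sampler agrees with (always >= 1 token)
            const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
            n_past += ids.size() - 1;

            bool done = false;
            for (size_t i = 0; i < ids.size(); ++i) {
                prompt_tgt.push_back(id_last);
                id_last = ids[i];
                if (llama_token_is_eog(llama_get_model(ctx_tgt), id_last)) {
                    done = true;
                    break;
                }
            }

            // drop any rejected draft tokens from the target KV cache
            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);

            if (done) {
                break;
            }
        }

        llama_batch_free(batch_tgt);
        common_speculative_free(spec);
    }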
package/src/llama.cpp/examples/tokenize/CMakeLists.txt

@@ -2,4 +2,4 @@ set(TARGET llama-tokenize)
 add_executable(${TARGET} tokenize.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
package/src/llama.cpp/examples/tokenize/tokenize.cpp

@@ -365,7 +365,7 @@ int main(int raw_argc, char ** raw_argv) {
     const bool parse_special = !no_parse_special;
 
     std::vector<llama_token> tokens;
-    tokens =
+    tokens = common_tokenize(model, prompt, add_bos, parse_special);
 
     if (printing_ids) {
         printf("[");
@@ -380,7 +380,7 @@ int main(int raw_argc, char ** raw_argv) {
         } else {
             bool invalid_utf8 = false;
             printf("%6d -> '", tokens[i]);
-            write_utf8_cstr_to_stdout(
+            write_utf8_cstr_to_stdout(common_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8);
             if (invalid_utf8) {
                 printf("' (utf-8 decode failure)\n");
             } else {
@@ -394,7 +394,7 @@ int main(int raw_argc, char ** raw_argv) {
     }
 
     if (show_token_count) {
-        printf("Total number of tokens: %
+        printf("Total number of tokens: %zu\n", tokens.size());
     }
     // silence valgrind
     llama_free(ctx);