@fugood/llama.node 0.3.2 → 0.3.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +7 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +17 -7
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +89 -27
- package/src/LlamaContext.h +2 -0
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +240 -168
- package/src/llama.cpp/.github/workflows/docker.yml +8 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +14 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -4
- package/src/llama.cpp/common/arg.cpp +986 -770
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +212 -351
- package/src/llama.cpp/common/common.h +204 -117
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +163 -121
- package/src/llama.cpp/common/sampling.h +41 -20
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +134 -161
- package/src/llama.cpp/examples/CMakeLists.txt +33 -14
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +19 -18
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +41 -87
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +263 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +83 -22
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
- package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +73 -114
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
- package/src/llama.cpp/examples/server/server.cpp +2073 -1339
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +354 -277
- package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/simple/simple.cpp +130 -94
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
- package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +159 -417
- package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
- package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +93 -52
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +4 -8
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +779 -194
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +55 -10
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +4317 -2979
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -38
- package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
- package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +62 -20
- package/src/llama.cpp/tests/test-sampling.cpp +163 -138
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
|
@@ -2,10 +2,11 @@
|
|
|
2
2
|
|
|
3
3
|
#include "arg.h"
|
|
4
4
|
#include "common.h"
|
|
5
|
-
#include "log.h"
|
|
6
|
-
#include "sampling.h"
|
|
7
5
|
#include "json-schema-to-grammar.h"
|
|
8
6
|
#include "llama.h"
|
|
7
|
+
#include "log.h"
|
|
8
|
+
#include "sampling.h"
|
|
9
|
+
#include "speculative.h"
|
|
9
10
|
|
|
10
11
|
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
|
11
12
|
#define JSON_ASSERT GGML_ASSERT
|
|
@@ -14,21 +15,7 @@
|
|
|
14
15
|
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
|
15
16
|
|
|
16
17
|
// auto generated files (update with ./deps.sh)
|
|
17
|
-
#include "
|
|
18
|
-
#include "style.css.hpp"
|
|
19
|
-
#include "theme-beeninorder.css.hpp"
|
|
20
|
-
#include "theme-ketivah.css.hpp"
|
|
21
|
-
#include "theme-mangotango.css.hpp"
|
|
22
|
-
#include "theme-playground.css.hpp"
|
|
23
|
-
#include "theme-polarnight.css.hpp"
|
|
24
|
-
#include "theme-snowstorm.css.hpp"
|
|
25
|
-
#include "index.html.hpp"
|
|
26
|
-
#include "index-new.html.hpp"
|
|
27
|
-
#include "index.js.hpp"
|
|
28
|
-
#include "completion.js.hpp"
|
|
29
|
-
#include "system-prompts.js.hpp"
|
|
30
|
-
#include "prompt-formats.js.hpp"
|
|
31
|
-
#include "json-schema-to-grammar.mjs.hpp"
|
|
18
|
+
#include "index.html.gz.hpp"
|
|
32
19
|
#include "loading.html.hpp"
|
|
33
20
|
|
|
34
21
|
#include <atomic>
|
|
@@ -43,31 +30,19 @@
|
|
|
43
30
|
#include <unordered_map>
|
|
44
31
|
#include <unordered_set>
|
|
45
32
|
|
|
46
|
-
#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
47
|
-
#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
48
|
-
#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
49
|
-
#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
|
|
50
|
-
|
|
51
|
-
#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
52
|
-
#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
53
|
-
#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
54
|
-
#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
55
|
-
|
|
56
|
-
#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
57
|
-
#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
58
|
-
#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
59
|
-
#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
|
|
60
|
-
|
|
61
33
|
using json = nlohmann::ordered_json;
|
|
62
34
|
|
|
63
35
|
enum stop_type {
|
|
64
|
-
|
|
65
|
-
|
|
36
|
+
STOP_TYPE_NONE,
|
|
37
|
+
STOP_TYPE_EOS,
|
|
38
|
+
STOP_TYPE_WORD,
|
|
39
|
+
STOP_TYPE_LIMIT,
|
|
66
40
|
};
|
|
67
41
|
|
|
68
42
|
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
|
|
69
43
|
enum slot_state {
|
|
70
44
|
SLOT_STATE_IDLE,
|
|
45
|
+
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
|
|
71
46
|
SLOT_STATE_PROCESSING_PROMPT,
|
|
72
47
|
SLOT_STATE_DONE_PROMPT,
|
|
73
48
|
SLOT_STATE_GENERATING,
|
|
@@ -80,6 +55,9 @@ enum server_state {
|
|
|
80
55
|
|
|
81
56
|
enum server_task_type {
|
|
82
57
|
SERVER_TASK_TYPE_COMPLETION,
|
|
58
|
+
SERVER_TASK_TYPE_EMBEDDING,
|
|
59
|
+
SERVER_TASK_TYPE_RERANK,
|
|
60
|
+
SERVER_TASK_TYPE_INFILL,
|
|
83
61
|
SERVER_TASK_TYPE_CANCEL,
|
|
84
62
|
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
|
85
63
|
SERVER_TASK_TYPE_METRICS,
|
|
@@ -89,21 +67,309 @@ enum server_task_type {
|
|
|
89
67
|
SERVER_TASK_TYPE_SET_LORA,
|
|
90
68
|
};
|
|
91
69
|
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
70
|
+
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
|
71
|
+
enum error_type {
|
|
72
|
+
ERROR_TYPE_INVALID_REQUEST,
|
|
73
|
+
ERROR_TYPE_AUTHENTICATION,
|
|
74
|
+
ERROR_TYPE_SERVER,
|
|
75
|
+
ERROR_TYPE_NOT_FOUND,
|
|
76
|
+
ERROR_TYPE_PERMISSION,
|
|
77
|
+
ERROR_TYPE_UNAVAILABLE, // custom error
|
|
78
|
+
ERROR_TYPE_NOT_SUPPORTED, // custom error
|
|
79
|
+
};
|
|
80
|
+
|
|
81
|
+
struct slot_params {
|
|
82
|
+
bool stream = true;
|
|
83
|
+
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
|
84
|
+
bool return_tokens = false;
|
|
85
|
+
|
|
86
|
+
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
|
87
|
+
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
|
88
|
+
int32_t n_predict = -1; // new tokens to predict
|
|
89
|
+
int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
|
|
90
|
+
|
|
91
|
+
int64_t t_max_prompt_ms = -1; // TODO: implement
|
|
92
|
+
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
|
93
|
+
|
|
94
|
+
std::vector<std::string> antiprompt;
|
|
95
|
+
bool timings_per_token = false;
|
|
96
|
+
bool post_sampling_probs = false;
|
|
97
|
+
bool ignore_eos = false;
|
|
98
|
+
|
|
99
|
+
struct common_params_sampling sampling;
|
|
100
|
+
struct common_params_speculative speculative;
|
|
101
|
+
|
|
102
|
+
// OAI-compat fields
|
|
103
|
+
bool verbose = false;
|
|
104
|
+
bool oaicompat = false;
|
|
105
|
+
bool oaicompat_chat = true;
|
|
106
|
+
std::string oaicompat_model;
|
|
107
|
+
std::string oaicompat_cmpl_id;
|
|
108
|
+
|
|
109
|
+
json to_json() const {
|
|
110
|
+
std::vector<std::string> samplers;
|
|
111
|
+
samplers.reserve(sampling.samplers.size());
|
|
112
|
+
for (const auto & sampler : sampling.samplers) {
|
|
113
|
+
samplers.emplace_back(common_sampler_type_to_str(sampler));
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
return json {
|
|
117
|
+
{"n_predict", n_predict}, // Server configured n_predict
|
|
118
|
+
{"seed", sampling.seed},
|
|
119
|
+
{"temperature", sampling.temp},
|
|
120
|
+
{"dynatemp_range", sampling.dynatemp_range},
|
|
121
|
+
{"dynatemp_exponent", sampling.dynatemp_exponent},
|
|
122
|
+
{"top_k", sampling.top_k},
|
|
123
|
+
{"top_p", sampling.top_p},
|
|
124
|
+
{"min_p", sampling.min_p},
|
|
125
|
+
{"xtc_probability", sampling.xtc_probability},
|
|
126
|
+
{"xtc_threshold", sampling.xtc_threshold},
|
|
127
|
+
{"typical_p", sampling.typ_p},
|
|
128
|
+
{"repeat_last_n", sampling.penalty_last_n},
|
|
129
|
+
{"repeat_penalty", sampling.penalty_repeat},
|
|
130
|
+
{"presence_penalty", sampling.penalty_present},
|
|
131
|
+
{"frequency_penalty", sampling.penalty_freq},
|
|
132
|
+
{"dry_multiplier", sampling.dry_multiplier},
|
|
133
|
+
{"dry_base", sampling.dry_base},
|
|
134
|
+
{"dry_allowed_length", sampling.dry_allowed_length},
|
|
135
|
+
{"dry_penalty_last_n", sampling.dry_penalty_last_n},
|
|
136
|
+
{"dry_sequence_breakers", sampling.dry_sequence_breakers},
|
|
137
|
+
{"mirostat", sampling.mirostat},
|
|
138
|
+
{"mirostat_tau", sampling.mirostat_tau},
|
|
139
|
+
{"mirostat_eta", sampling.mirostat_eta},
|
|
140
|
+
{"stop", antiprompt},
|
|
141
|
+
{"max_tokens", n_predict}, // User configured n_predict
|
|
142
|
+
{"n_keep", n_keep},
|
|
143
|
+
{"n_discard", n_discard},
|
|
144
|
+
{"ignore_eos", sampling.ignore_eos},
|
|
145
|
+
{"stream", stream},
|
|
146
|
+
{"logit_bias", format_logit_bias(sampling.logit_bias)},
|
|
147
|
+
{"n_probs", sampling.n_probs},
|
|
148
|
+
{"min_keep", sampling.min_keep},
|
|
149
|
+
{"grammar", sampling.grammar},
|
|
150
|
+
{"samplers", samplers},
|
|
151
|
+
{"speculative.n_max", speculative.n_max},
|
|
152
|
+
{"speculative.n_min", speculative.n_min},
|
|
153
|
+
{"speculative.p_min", speculative.p_min},
|
|
154
|
+
{"timings_per_token", timings_per_token},
|
|
155
|
+
{"post_sampling_probs", post_sampling_probs},
|
|
156
|
+
};
|
|
157
|
+
}
|
|
97
158
|
};
|
|
98
159
|
|
|
99
160
|
struct server_task {
|
|
100
|
-
int id
|
|
101
|
-
int
|
|
161
|
+
int id = -1; // to be filled by server_queue
|
|
162
|
+
int index = -1; // used when there are multiple prompts (batch request)
|
|
102
163
|
|
|
103
164
|
server_task_type type;
|
|
104
|
-
json data;
|
|
105
165
|
|
|
106
|
-
|
|
166
|
+
// used by SERVER_TASK_TYPE_CANCEL
|
|
167
|
+
int id_target = -1;
|
|
168
|
+
|
|
169
|
+
// used by SERVER_TASK_TYPE_INFERENCE
|
|
170
|
+
slot_params params;
|
|
171
|
+
llama_tokens prompt_tokens;
|
|
172
|
+
int id_selected_slot = -1;
|
|
173
|
+
|
|
174
|
+
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
|
|
175
|
+
struct slot_action {
|
|
176
|
+
int slot_id;
|
|
177
|
+
std::string filename;
|
|
178
|
+
std::string filepath;
|
|
179
|
+
};
|
|
180
|
+
slot_action slot_action;
|
|
181
|
+
|
|
182
|
+
// used by SERVER_TASK_TYPE_METRICS
|
|
183
|
+
bool metrics_reset_bucket = false;
|
|
184
|
+
|
|
185
|
+
server_task(server_task_type type) : type(type) {}
|
|
186
|
+
|
|
187
|
+
static slot_params params_from_json_cmpl(
|
|
188
|
+
const llama_model * model,
|
|
189
|
+
const llama_context * ctx,
|
|
190
|
+
const common_params & params_base,
|
|
191
|
+
const json & data) {
|
|
192
|
+
slot_params params;
|
|
193
|
+
|
|
194
|
+
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
|
|
195
|
+
slot_params defaults;
|
|
196
|
+
defaults.sampling = params_base.sampling;
|
|
197
|
+
defaults.speculative = params_base.speculative;
|
|
198
|
+
|
|
199
|
+
// enabling this will output extra debug information in the HTTP responses from the server
|
|
200
|
+
params.verbose = params_base.verbosity > 9;
|
|
201
|
+
params.timings_per_token = json_value(data, "timings_per_token", false);
|
|
202
|
+
|
|
203
|
+
params.stream = json_value(data, "stream", false);
|
|
204
|
+
params.cache_prompt = json_value(data, "cache_prompt", true);
|
|
205
|
+
params.return_tokens = json_value(data, "return_tokens", false);
|
|
206
|
+
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
|
|
207
|
+
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
|
|
208
|
+
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
|
|
209
|
+
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
|
210
|
+
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
|
211
|
+
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
|
212
|
+
|
|
213
|
+
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
|
214
|
+
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
|
215
|
+
params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
|
|
216
|
+
params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
|
|
217
|
+
params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
|
|
218
|
+
params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
|
|
219
|
+
params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
|
|
220
|
+
params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
|
|
221
|
+
params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
|
|
222
|
+
params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
|
|
223
|
+
params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
|
|
224
|
+
params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
|
|
225
|
+
params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
|
|
226
|
+
params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
|
|
227
|
+
params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
|
|
228
|
+
params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
|
|
229
|
+
params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
|
|
230
|
+
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
|
|
231
|
+
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
|
|
232
|
+
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
|
|
233
|
+
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
|
|
234
|
+
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
|
|
235
|
+
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
|
|
236
|
+
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
|
|
237
|
+
|
|
238
|
+
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
|
|
239
|
+
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
|
|
240
|
+
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
|
241
|
+
|
|
242
|
+
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
|
|
243
|
+
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
|
244
|
+
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
|
245
|
+
|
|
246
|
+
// TODO: add more sanity checks for the input parameters
|
|
247
|
+
|
|
248
|
+
if (params.sampling.penalty_last_n < -1) {
|
|
249
|
+
throw std::runtime_error("Error: repeat_last_n must be >= -1");
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
if (params.sampling.dry_penalty_last_n < -1) {
|
|
253
|
+
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
if (params.sampling.penalty_last_n == -1) {
|
|
257
|
+
// note: should be the slot's context and not the full context, but it's ok
|
|
258
|
+
params.sampling.penalty_last_n = llama_n_ctx(ctx);
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
if (params.sampling.dry_penalty_last_n == -1) {
|
|
262
|
+
params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
if (params.sampling.dry_base < 1.0f) {
|
|
266
|
+
params.sampling.dry_base = defaults.sampling.dry_base;
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
// sequence breakers for DRY
|
|
270
|
+
{
|
|
271
|
+
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
|
|
272
|
+
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
|
|
273
|
+
|
|
274
|
+
if (data.contains("dry_sequence_breakers")) {
|
|
275
|
+
params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
|
|
276
|
+
if (params.sampling.dry_sequence_breakers.empty()) {
|
|
277
|
+
throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
|
|
278
|
+
}
|
|
279
|
+
}
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
// process "json_schema" and "grammar"
|
|
283
|
+
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
|
284
|
+
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
|
285
|
+
}
|
|
286
|
+
if (data.contains("json_schema") && !data.contains("grammar")) {
|
|
287
|
+
try {
|
|
288
|
+
auto schema = json_value(data, "json_schema", json::object());
|
|
289
|
+
params.sampling.grammar = json_schema_to_grammar(schema);
|
|
290
|
+
} catch (const std::exception & e) {
|
|
291
|
+
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
|
292
|
+
}
|
|
293
|
+
} else {
|
|
294
|
+
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
{
|
|
298
|
+
params.sampling.logit_bias.clear();
|
|
299
|
+
params.ignore_eos = json_value(data, "ignore_eos", false);
|
|
300
|
+
|
|
301
|
+
const auto & logit_bias = data.find("logit_bias");
|
|
302
|
+
if (logit_bias != data.end() && logit_bias->is_array()) {
|
|
303
|
+
const int n_vocab = llama_n_vocab(model);
|
|
304
|
+
for (const auto & el : *logit_bias) {
|
|
305
|
+
// TODO: we may want to throw errors here, in case "el" is incorrect
|
|
306
|
+
if (el.is_array() && el.size() == 2) {
|
|
307
|
+
float bias;
|
|
308
|
+
if (el[1].is_number()) {
|
|
309
|
+
bias = el[1].get<float>();
|
|
310
|
+
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
|
|
311
|
+
bias = -INFINITY;
|
|
312
|
+
} else {
|
|
313
|
+
continue;
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
if (el[0].is_number_integer()) {
|
|
317
|
+
llama_token tok = el[0].get<llama_token>();
|
|
318
|
+
if (tok >= 0 && tok < n_vocab) {
|
|
319
|
+
params.sampling.logit_bias.push_back({tok, bias});
|
|
320
|
+
}
|
|
321
|
+
} else if (el[0].is_string()) {
|
|
322
|
+
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
|
|
323
|
+
for (auto tok : toks) {
|
|
324
|
+
params.sampling.logit_bias.push_back({tok, bias});
|
|
325
|
+
}
|
|
326
|
+
}
|
|
327
|
+
}
|
|
328
|
+
}
|
|
329
|
+
}
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
{
|
|
333
|
+
params.antiprompt.clear();
|
|
334
|
+
|
|
335
|
+
const auto & stop = data.find("stop");
|
|
336
|
+
if (stop != data.end() && stop->is_array()) {
|
|
337
|
+
for (const auto & word : *stop) {
|
|
338
|
+
if (!word.empty()) {
|
|
339
|
+
params.antiprompt.push_back(word);
|
|
340
|
+
}
|
|
341
|
+
}
|
|
342
|
+
}
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
{
|
|
346
|
+
const auto & samplers = data.find("samplers");
|
|
347
|
+
if (samplers != data.end()) {
|
|
348
|
+
if (samplers->is_array()) {
|
|
349
|
+
std::vector<std::string> sampler_names;
|
|
350
|
+
for (const auto & name : *samplers) {
|
|
351
|
+
if (name.is_string()) {
|
|
352
|
+
sampler_names.emplace_back(name);
|
|
353
|
+
}
|
|
354
|
+
}
|
|
355
|
+
params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
|
|
356
|
+
} else if (samplers->is_string()){
|
|
357
|
+
std::string sampler_string;
|
|
358
|
+
for (const auto & name : *samplers) {
|
|
359
|
+
sampler_string += name;
|
|
360
|
+
}
|
|
361
|
+
params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
|
|
362
|
+
}
|
|
363
|
+
} else {
|
|
364
|
+
params.sampling.samplers = defaults.sampling.samplers;
|
|
365
|
+
}
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
|
|
369
|
+
params.oaicompat_model = json_value(data, "model", model_name);
|
|
370
|
+
|
|
371
|
+
return params;
|
|
372
|
+
}
|
|
107
373
|
|
|
108
374
|
// utility function
|
|
109
375
|
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
|
@@ -115,33 +381,628 @@ struct server_task {
|
|
|
115
381
|
}
|
|
116
382
|
};
|
|
117
383
|
|
|
384
|
+
struct result_timings {
|
|
385
|
+
int32_t prompt_n = -1;
|
|
386
|
+
double prompt_ms;
|
|
387
|
+
double prompt_per_token_ms;
|
|
388
|
+
double prompt_per_second;
|
|
389
|
+
|
|
390
|
+
int32_t predicted_n = -1;
|
|
391
|
+
double predicted_ms;
|
|
392
|
+
double predicted_per_token_ms;
|
|
393
|
+
double predicted_per_second;
|
|
394
|
+
|
|
395
|
+
json to_json() const {
|
|
396
|
+
return {
|
|
397
|
+
{"prompt_n", prompt_n},
|
|
398
|
+
{"prompt_ms", prompt_ms},
|
|
399
|
+
{"prompt_per_token_ms", prompt_per_token_ms},
|
|
400
|
+
{"prompt_per_second", prompt_per_second},
|
|
401
|
+
|
|
402
|
+
{"predicted_n", predicted_n},
|
|
403
|
+
{"predicted_ms", predicted_ms},
|
|
404
|
+
{"predicted_per_token_ms", predicted_per_token_ms},
|
|
405
|
+
{"predicted_per_second", predicted_per_second},
|
|
406
|
+
};
|
|
407
|
+
}
|
|
408
|
+
};
|
|
409
|
+
|
|
118
410
|
struct server_task_result {
|
|
119
|
-
int id
|
|
411
|
+
int id = -1;
|
|
412
|
+
int id_slot = -1;
|
|
413
|
+
virtual bool is_error() {
|
|
414
|
+
// only used by server_task_result_error
|
|
415
|
+
return false;
|
|
416
|
+
}
|
|
417
|
+
virtual bool is_stop() {
|
|
418
|
+
// only used by server_task_result_cmpl_*
|
|
419
|
+
return false;
|
|
420
|
+
}
|
|
421
|
+
virtual int get_index() {
|
|
422
|
+
return -1;
|
|
423
|
+
}
|
|
424
|
+
virtual json to_json() = 0;
|
|
425
|
+
virtual ~server_task_result() = default;
|
|
426
|
+
};
|
|
120
427
|
|
|
121
|
-
|
|
428
|
+
// using shared_ptr for polymorphism of server_task_result
|
|
429
|
+
using server_task_result_ptr = std::unique_ptr<server_task_result>;
|
|
122
430
|
|
|
123
|
-
|
|
124
|
-
|
|
431
|
+
inline std::string stop_type_to_str(stop_type type) {
|
|
432
|
+
switch (type) {
|
|
433
|
+
case STOP_TYPE_EOS: return "eos";
|
|
434
|
+
case STOP_TYPE_WORD: return "word";
|
|
435
|
+
case STOP_TYPE_LIMIT: return "limit";
|
|
436
|
+
default: return "none";
|
|
437
|
+
}
|
|
438
|
+
}
|
|
439
|
+
|
|
440
|
+
struct completion_token_output {
|
|
441
|
+
llama_token tok;
|
|
442
|
+
float prob;
|
|
443
|
+
std::string text_to_send;
|
|
444
|
+
struct prob_info {
|
|
445
|
+
llama_token tok;
|
|
446
|
+
std::string txt;
|
|
447
|
+
float prob;
|
|
448
|
+
};
|
|
449
|
+
std::vector<prob_info> probs;
|
|
450
|
+
|
|
451
|
+
json to_json(bool post_sampling_probs) const {
|
|
452
|
+
json probs_for_token = json::array();
|
|
453
|
+
for (const auto & p : probs) {
|
|
454
|
+
std::string txt(p.txt);
|
|
455
|
+
txt.resize(validate_utf8(txt));
|
|
456
|
+
probs_for_token.push_back(json {
|
|
457
|
+
{"id", p.tok},
|
|
458
|
+
{"token", txt},
|
|
459
|
+
{"bytes", str_to_bytes(p.txt)},
|
|
460
|
+
{
|
|
461
|
+
post_sampling_probs ? "prob" : "logprob",
|
|
462
|
+
post_sampling_probs ? p.prob : logarithm(p.prob)
|
|
463
|
+
},
|
|
464
|
+
});
|
|
465
|
+
}
|
|
466
|
+
return probs_for_token;
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
|
|
470
|
+
json out = json::array();
|
|
471
|
+
for (const auto & p : probs) {
|
|
472
|
+
std::string txt(p.text_to_send);
|
|
473
|
+
txt.resize(validate_utf8(txt));
|
|
474
|
+
out.push_back(json {
|
|
475
|
+
{"id", p.tok},
|
|
476
|
+
{"token", txt},
|
|
477
|
+
{"bytes", str_to_bytes(p.text_to_send)},
|
|
478
|
+
{
|
|
479
|
+
post_sampling_probs ? "prob" : "logprob",
|
|
480
|
+
post_sampling_probs ? p.prob : logarithm(p.prob)
|
|
481
|
+
},
|
|
482
|
+
{
|
|
483
|
+
post_sampling_probs ? "top_probs" : "top_logprobs",
|
|
484
|
+
p.to_json(post_sampling_probs)
|
|
485
|
+
},
|
|
486
|
+
});
|
|
487
|
+
}
|
|
488
|
+
return out;
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
static float logarithm(float x) {
|
|
492
|
+
// nlohmann::json converts -inf to null, so we need to prevent that
|
|
493
|
+
return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
static std::vector<unsigned char> str_to_bytes(const std::string & str) {
|
|
497
|
+
std::vector<unsigned char> bytes;
|
|
498
|
+
for (unsigned char c : str) {
|
|
499
|
+
bytes.push_back(c);
|
|
500
|
+
}
|
|
501
|
+
return bytes;
|
|
502
|
+
}
|
|
125
503
|
};
|
|
126
504
|
|
|
127
|
-
struct
|
|
128
|
-
|
|
129
|
-
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
|
|
505
|
+
struct server_task_result_cmpl_final : server_task_result {
|
|
506
|
+
int index = 0;
|
|
130
507
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
int32_t n_predict = -1; // new tokens to predict
|
|
508
|
+
std::string content;
|
|
509
|
+
llama_tokens tokens;
|
|
134
510
|
|
|
135
|
-
|
|
511
|
+
bool stream;
|
|
512
|
+
result_timings timings;
|
|
513
|
+
std::string prompt;
|
|
514
|
+
|
|
515
|
+
bool truncated;
|
|
516
|
+
int32_t n_decoded;
|
|
517
|
+
int32_t n_prompt_tokens;
|
|
518
|
+
int32_t n_tokens_cached;
|
|
519
|
+
bool has_new_line;
|
|
520
|
+
std::string stopping_word;
|
|
521
|
+
stop_type stop = STOP_TYPE_NONE;
|
|
136
522
|
|
|
137
|
-
|
|
138
|
-
|
|
523
|
+
bool post_sampling_probs;
|
|
524
|
+
std::vector<completion_token_output> probs_output;
|
|
525
|
+
|
|
526
|
+
slot_params generation_params;
|
|
527
|
+
|
|
528
|
+
// OAI-compat fields
|
|
529
|
+
bool verbose = false;
|
|
530
|
+
bool oaicompat = false;
|
|
531
|
+
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
|
|
532
|
+
std::string oaicompat_model;
|
|
533
|
+
std::string oaicompat_cmpl_id;
|
|
534
|
+
|
|
535
|
+
virtual int get_index() override {
|
|
536
|
+
return index;
|
|
537
|
+
}
|
|
538
|
+
|
|
539
|
+
virtual bool is_stop() override {
|
|
540
|
+
return true; // in stream mode, final responses are considered stop
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
virtual json to_json() override {
|
|
544
|
+
return oaicompat
|
|
545
|
+
? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
|
|
546
|
+
: to_json_non_oaicompat();
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
json to_json_non_oaicompat() {
|
|
550
|
+
json res = json {
|
|
551
|
+
{"index", index},
|
|
552
|
+
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
|
553
|
+
{"tokens", stream ? llama_tokens {} : tokens},
|
|
554
|
+
{"id_slot", id_slot},
|
|
555
|
+
{"stop", true},
|
|
556
|
+
{"model", oaicompat_model},
|
|
557
|
+
{"tokens_predicted", n_decoded},
|
|
558
|
+
{"tokens_evaluated", n_prompt_tokens},
|
|
559
|
+
{"generation_settings", generation_params.to_json()},
|
|
560
|
+
{"prompt", prompt},
|
|
561
|
+
{"has_new_line", has_new_line},
|
|
562
|
+
{"truncated", truncated},
|
|
563
|
+
{"stop_type", stop_type_to_str(stop)},
|
|
564
|
+
{"stopping_word", stopping_word},
|
|
565
|
+
{"tokens_cached", n_tokens_cached},
|
|
566
|
+
{"timings", timings.to_json()},
|
|
567
|
+
};
|
|
568
|
+
if (!stream && !probs_output.empty()) {
|
|
569
|
+
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
|
570
|
+
}
|
|
571
|
+
return res;
|
|
572
|
+
}
|
|
573
|
+
|
|
574
|
+
json to_json_oaicompat_chat() {
|
|
575
|
+
std::string finish_reason = "length";
|
|
576
|
+
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
|
577
|
+
finish_reason = "stop";
|
|
578
|
+
}
|
|
579
|
+
|
|
580
|
+
json choice = json{
|
|
581
|
+
{"finish_reason", finish_reason},
|
|
582
|
+
{"index", 0},
|
|
583
|
+
{"message", json {
|
|
584
|
+
{"content", content},
|
|
585
|
+
{"role", "assistant"}
|
|
586
|
+
}
|
|
587
|
+
}};
|
|
588
|
+
|
|
589
|
+
if (!stream && probs_output.size() > 0) {
|
|
590
|
+
choice["logprobs"] = json{
|
|
591
|
+
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
|
592
|
+
};
|
|
593
|
+
}
|
|
594
|
+
|
|
595
|
+
std::time_t t = std::time(0);
|
|
596
|
+
|
|
597
|
+
json res = json {
|
|
598
|
+
{"choices", json::array({choice})},
|
|
599
|
+
{"created", t},
|
|
600
|
+
{"model", oaicompat_model},
|
|
601
|
+
{"object", "chat.completion"},
|
|
602
|
+
{"usage", json {
|
|
603
|
+
{"completion_tokens", n_decoded},
|
|
604
|
+
{"prompt_tokens", n_prompt_tokens},
|
|
605
|
+
{"total_tokens", n_decoded + n_prompt_tokens}
|
|
606
|
+
}},
|
|
607
|
+
{"id", oaicompat_cmpl_id}
|
|
608
|
+
};
|
|
609
|
+
|
|
610
|
+
// extra fields for debugging purposes
|
|
611
|
+
if (verbose) {
|
|
612
|
+
res["__verbose"] = to_json_non_oaicompat();
|
|
613
|
+
}
|
|
614
|
+
if (timings.prompt_n >= 0) {
|
|
615
|
+
res.push_back({"timings", timings.to_json()});
|
|
616
|
+
}
|
|
617
|
+
|
|
618
|
+
return res;
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
json to_json_oaicompat_chat_stream() {
|
|
622
|
+
std::time_t t = std::time(0);
|
|
623
|
+
std::string finish_reason = "length";
|
|
624
|
+
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
|
625
|
+
finish_reason = "stop";
|
|
626
|
+
}
|
|
627
|
+
|
|
628
|
+
json choice = json{
|
|
629
|
+
{"finish_reason", finish_reason},
|
|
630
|
+
{"index", 0},
|
|
631
|
+
{"delta", json::object()}
|
|
632
|
+
};
|
|
633
|
+
|
|
634
|
+
json ret = json {
|
|
635
|
+
{"choices", json::array({choice})},
|
|
636
|
+
{"created", t},
|
|
637
|
+
{"id", oaicompat_cmpl_id},
|
|
638
|
+
{"model", oaicompat_model},
|
|
639
|
+
{"object", "chat.completion.chunk"},
|
|
640
|
+
{"usage", json {
|
|
641
|
+
{"completion_tokens", n_decoded},
|
|
642
|
+
{"prompt_tokens", n_prompt_tokens},
|
|
643
|
+
{"total_tokens", n_decoded + n_prompt_tokens},
|
|
644
|
+
}},
|
|
645
|
+
};
|
|
646
|
+
|
|
647
|
+
if (timings.prompt_n >= 0) {
|
|
648
|
+
ret.push_back({"timings", timings.to_json()});
|
|
649
|
+
}
|
|
650
|
+
|
|
651
|
+
return ret;
|
|
652
|
+
}
|
|
653
|
+
};
|
|
654
|
+
|
|
655
|
+
struct server_task_result_cmpl_partial : server_task_result {
|
|
656
|
+
int index = 0;
|
|
657
|
+
|
|
658
|
+
std::string content;
|
|
659
|
+
llama_tokens tokens;
|
|
660
|
+
|
|
661
|
+
int32_t n_decoded;
|
|
662
|
+
int32_t n_prompt_tokens;
|
|
663
|
+
|
|
664
|
+
bool post_sampling_probs;
|
|
665
|
+
completion_token_output prob_output;
|
|
666
|
+
result_timings timings;
|
|
667
|
+
|
|
668
|
+
// OAI-compat fields
|
|
669
|
+
bool verbose = false;
|
|
670
|
+
bool oaicompat = false;
|
|
671
|
+
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
|
|
672
|
+
std::string oaicompat_model;
|
|
673
|
+
std::string oaicompat_cmpl_id;
|
|
674
|
+
|
|
675
|
+
virtual int get_index() override {
|
|
676
|
+
return index;
|
|
677
|
+
}
|
|
678
|
+
|
|
679
|
+
virtual bool is_stop() override {
|
|
680
|
+
return false; // in stream mode, partial responses are not considered stop
|
|
681
|
+
}
|
|
682
|
+
|
|
683
|
+
virtual json to_json() override {
|
|
684
|
+
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
|
685
|
+
}
|
|
686
|
+
|
|
687
|
+
json to_json_non_oaicompat() {
|
|
688
|
+
// non-OAI-compat JSON
|
|
689
|
+
json res = json {
|
|
690
|
+
{"index", index},
|
|
691
|
+
{"content", content},
|
|
692
|
+
{"tokens", tokens},
|
|
693
|
+
{"stop", false},
|
|
694
|
+
{"id_slot", id_slot},
|
|
695
|
+
{"tokens_predicted", n_decoded},
|
|
696
|
+
{"tokens_evaluated", n_prompt_tokens},
|
|
697
|
+
};
|
|
698
|
+
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
|
|
699
|
+
if (timings.prompt_n > 0) {
|
|
700
|
+
res.push_back({"timings", timings.to_json()});
|
|
701
|
+
}
|
|
702
|
+
if (!prob_output.probs.empty()) {
|
|
703
|
+
res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
|
|
704
|
+
}
|
|
705
|
+
return res;
|
|
706
|
+
}
|
|
707
|
+
|
|
708
|
+
json to_json_oaicompat() {
|
|
709
|
+
bool first = n_decoded == 0;
|
|
710
|
+
std::time_t t = std::time(0);
|
|
711
|
+
json choices;
|
|
712
|
+
|
|
713
|
+
if (first) {
|
|
714
|
+
if (content.empty()) {
|
|
715
|
+
choices = json::array({json{{"finish_reason", nullptr},
|
|
716
|
+
{"index", 0},
|
|
717
|
+
{"delta", json{{"role", "assistant"}}}}});
|
|
718
|
+
} else {
|
|
719
|
+
// We have to send this as two updates to conform to openai behavior
|
|
720
|
+
json initial_ret = json{{"choices", json::array({json{
|
|
721
|
+
{"finish_reason", nullptr},
|
|
722
|
+
{"index", 0},
|
|
723
|
+
{"delta", json{
|
|
724
|
+
{"role", "assistant"}
|
|
725
|
+
}}}})},
|
|
726
|
+
{"created", t},
|
|
727
|
+
{"id", oaicompat_cmpl_id},
|
|
728
|
+
{"model", oaicompat_model},
|
|
729
|
+
{"object", "chat.completion.chunk"}};
|
|
730
|
+
|
|
731
|
+
json second_ret = json{
|
|
732
|
+
{"choices", json::array({json{{"finish_reason", nullptr},
|
|
733
|
+
{"index", 0},
|
|
734
|
+
{"delta", json {
|
|
735
|
+
{"content", content}}}
|
|
736
|
+
}})},
|
|
737
|
+
{"created", t},
|
|
738
|
+
{"id", oaicompat_cmpl_id},
|
|
739
|
+
{"model", oaicompat_model},
|
|
740
|
+
{"object", "chat.completion.chunk"}};
|
|
741
|
+
|
|
742
|
+
return std::vector<json>({initial_ret, second_ret});
|
|
743
|
+
}
|
|
744
|
+
} else {
|
|
745
|
+
choices = json::array({json{
|
|
746
|
+
{"finish_reason", nullptr},
|
|
747
|
+
{"index", 0},
|
|
748
|
+
{"delta",
|
|
749
|
+
json {
|
|
750
|
+
{"content", content},
|
|
751
|
+
}},
|
|
752
|
+
}});
|
|
753
|
+
}
|
|
754
|
+
|
|
755
|
+
GGML_ASSERT(choices.size() >= 1);
|
|
756
|
+
|
|
757
|
+
if (prob_output.probs.size() > 0) {
|
|
758
|
+
choices[0]["logprobs"] = json{
|
|
759
|
+
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
|
760
|
+
};
|
|
761
|
+
}
|
|
762
|
+
|
|
763
|
+
json ret = json {
|
|
764
|
+
{"choices", choices},
|
|
765
|
+
{"created", t},
|
|
766
|
+
{"id", oaicompat_cmpl_id},
|
|
767
|
+
{"model", oaicompat_model},
|
|
768
|
+
{"object", "chat.completion.chunk"}
|
|
769
|
+
};
|
|
770
|
+
|
|
771
|
+
if (timings.prompt_n >= 0) {
|
|
772
|
+
ret.push_back({"timings", timings.to_json()});
|
|
773
|
+
}
|
|
774
|
+
|
|
775
|
+
return std::vector<json>({ret});
|
|
776
|
+
}
|
|
777
|
+
};
|
|
778
|
+
|
|
779
|
+
struct server_task_result_embd : server_task_result {
|
|
780
|
+
int index = 0;
|
|
781
|
+
std::vector<std::vector<float>> embedding;
|
|
782
|
+
|
|
783
|
+
int32_t n_tokens;
|
|
784
|
+
|
|
785
|
+
// OAI-compat fields
|
|
786
|
+
bool oaicompat = false;
|
|
787
|
+
|
|
788
|
+
virtual int get_index() override {
|
|
789
|
+
return index;
|
|
790
|
+
}
|
|
791
|
+
|
|
792
|
+
virtual json to_json() override {
|
|
793
|
+
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
|
794
|
+
}
|
|
795
|
+
|
|
796
|
+
json to_json_non_oaicompat() {
|
|
797
|
+
return json {
|
|
798
|
+
{"index", index},
|
|
799
|
+
{"embedding", embedding},
|
|
800
|
+
};
|
|
801
|
+
}
|
|
802
|
+
|
|
803
|
+
json to_json_oaicompat() {
|
|
804
|
+
return json {
|
|
805
|
+
{"index", index},
|
|
806
|
+
{"embedding", embedding[0]},
|
|
807
|
+
{"tokens_evaluated", n_tokens},
|
|
808
|
+
};
|
|
809
|
+
}
|
|
810
|
+
};
|
|
811
|
+
|
|
812
|
+
struct server_task_result_rerank : server_task_result {
|
|
813
|
+
int index = 0;
|
|
814
|
+
float score = -1e6;
|
|
815
|
+
|
|
816
|
+
int32_t n_tokens;
|
|
817
|
+
|
|
818
|
+
virtual int get_index() override {
|
|
819
|
+
return index;
|
|
820
|
+
}
|
|
821
|
+
|
|
822
|
+
virtual json to_json() override {
|
|
823
|
+
return json {
|
|
824
|
+
{"index", index},
|
|
825
|
+
{"score", score},
|
|
826
|
+
{"tokens_evaluated", n_tokens},
|
|
827
|
+
};
|
|
828
|
+
}
|
|
829
|
+
};
|
|
830
|
+
|
|
831
|
+
// this function maybe used outside of server_task_result_error
|
|
832
|
+
static json format_error_response(const std::string & message, const enum error_type type) {
|
|
833
|
+
std::string type_str;
|
|
834
|
+
int code = 500;
|
|
835
|
+
switch (type) {
|
|
836
|
+
case ERROR_TYPE_INVALID_REQUEST:
|
|
837
|
+
type_str = "invalid_request_error";
|
|
838
|
+
code = 400;
|
|
839
|
+
break;
|
|
840
|
+
case ERROR_TYPE_AUTHENTICATION:
|
|
841
|
+
type_str = "authentication_error";
|
|
842
|
+
code = 401;
|
|
843
|
+
break;
|
|
844
|
+
case ERROR_TYPE_NOT_FOUND:
|
|
845
|
+
+            type_str = "not_found_error";
+            code = 404;
+            break;
+        case ERROR_TYPE_SERVER:
+            type_str = "server_error";
+            code = 500;
+            break;
+        case ERROR_TYPE_PERMISSION:
+            type_str = "permission_error";
+            code = 403;
+            break;
+        case ERROR_TYPE_NOT_SUPPORTED:
+            type_str = "not_supported_error";
+            code = 501;
+            break;
+        case ERROR_TYPE_UNAVAILABLE:
+            type_str = "unavailable_error";
+            code = 503;
+            break;
+    }
+    return json {
+        {"code", code},
+        {"message", message},
+        {"type", type_str},
+    };
+}
+
+struct server_task_result_error : server_task_result {
+    int index = 0;
+    error_type err_type = ERROR_TYPE_SERVER;
+    std::string err_msg;
+
+    virtual bool is_error() override {
+        return true;
+    }
+
+    virtual json to_json() override {
+        return format_error_response(err_msg, err_type);
+    }
+};
+
+struct server_task_result_metrics : server_task_result {
+    int n_idle_slots;
+    int n_processing_slots;
+    int n_tasks_deferred;
+    int64_t t_start;
+
+    int32_t kv_cache_tokens_count;
+    int32_t kv_cache_used_cells;
+
+    // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
+    uint64_t n_prompt_tokens_processed_total = 0;
+    uint64_t t_prompt_processing_total = 0;
+    uint64_t n_tokens_predicted_total = 0;
+    uint64_t t_tokens_generation_total = 0;
+
+    uint64_t n_prompt_tokens_processed = 0;
+    uint64_t t_prompt_processing = 0;
+
+    uint64_t n_tokens_predicted = 0;
+    uint64_t t_tokens_generation = 0;
+
+    uint64_t n_decode_total = 0;
+    uint64_t n_busy_slots_total = 0;
+
+    // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
+    // therefore, we use json to temporarily store the slot.to_json() result
+    json slots_data = json::array();
+
+    virtual json to_json() override {
+        return json {
+            { "idle", n_idle_slots },
+            { "processing", n_processing_slots },
+            { "deferred", n_tasks_deferred },
+            { "t_start", t_start },
+
+            { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
+            { "t_tokens_generation_total", t_tokens_generation_total },
+            { "n_tokens_predicted_total", n_tokens_predicted_total },
+            { "t_prompt_processing_total", t_prompt_processing_total },
+
+            { "n_prompt_tokens_processed", n_prompt_tokens_processed },
+            { "t_prompt_processing", t_prompt_processing },
+            { "n_tokens_predicted", n_tokens_predicted },
+            { "t_tokens_generation", t_tokens_generation },
+
+            { "n_decode_total", n_decode_total },
+            { "n_busy_slots_total", n_busy_slots_total },
+
+            { "kv_cache_tokens_count", kv_cache_tokens_count },
+            { "kv_cache_used_cells", kv_cache_used_cells },
+
+            { "slots", slots_data },
+        };
+    }
+};
+
+struct server_task_result_slot_save_load : server_task_result {
+    std::string filename;
+    bool is_save; // true = save, false = load
+
+    size_t n_tokens;
+    size_t n_bytes;
+    double t_ms;
+
+    virtual json to_json() override {
+        if (is_save) {
+            return json {
+                { "id_slot", id_slot },
+                { "filename", filename },
+                { "n_saved", n_tokens },
+                { "n_written", n_bytes },
+                { "timings", {
+                    { "save_ms", t_ms }
+                }},
+            };
+        } else {
+            return json {
+                { "id_slot", id_slot },
+                { "filename", filename },
+                { "n_restored", n_tokens },
+                { "n_read", n_bytes },
+                { "timings", {
+                    { "restore_ms", t_ms }
+                }},
+            };
+        }
+    }
+};
+
+struct server_task_result_slot_erase : server_task_result {
+    size_t n_erased;
+
+    virtual json to_json() override {
+        return json {
+            { "id_slot", id_slot },
+            { "n_erased", n_erased },
+        };
+    }
+};
+
+struct server_task_result_apply_lora : server_task_result {
+    virtual json to_json() override {
+        return json {{ "success", true }};
+    }
 };
 
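For reference, the error payload built by `format_error_response()` above is a flat object with `code`, `message`, and `type`. Below is a minimal standalone sketch of producing and consuming that shape; it uses nlohmann::json (which the server already depends on), and the helper name `make_error` is hypothetical, not part of the diff.

```cpp
#include <nlohmann/json.hpp>
#include <iostream>
#include <string>

using json = nlohmann::json;

// hypothetical helper mirroring the payload shape built by format_error_response()
static json make_error(int code, const std::string & type, const std::string & message) {
    return json {
        {"code",    code},
        {"message", message},
        {"type",    type},
    };
}

int main() {
    const json err = make_error(404, "not_found_error", "slot not found");

    // a client can branch on the stable "type" string instead of parsing the message
    if (err.at("type") == "not_found_error") {
        std::cout << "HTTP " << err.at("code").get<int>() << ": "
                  << err.at("message").get<std::string>() << "\n";
    }
    return 0;
}
```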
 struct server_slot {
     int id;
     int id_task = -1;
 
+    // only used for completion/embedding/infill/rerank
+    server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
+
+    llama_batch batch_spec = {};
+
+    llama_context * ctx = nullptr;
+    llama_context * ctx_dft = nullptr;
+
+    common_speculative * spec = nullptr;
+
     // the index relative to completion multi-task request
     size_t index = 0;
 
@@ -160,54 +1021,44 @@ struct server_slot {
     int32_t i_batch = -1;
     int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
 
+    // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
     int32_t n_prompt_tokens = 0;
     int32_t n_prompt_tokens_processed = 0;
 
-
+    // input prompt tokens
+    llama_tokens prompt_tokens;
 
-
-    std::vector<llama_token> prompt_tokens;
+    size_t last_nl_pos = 0;
 
-    std::string
-
-    std::vector<completion_token_output> generated_token_probs;
+    std::string generated_text;
+    llama_tokens generated_tokens;
 
-
+    llama_tokens cache_tokens;
+
+    std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
+    bool has_new_line = false;
     bool truncated = false;
-
-    bool stopped_word = false;
-    bool stopped_limit = false;
-
-    bool oaicompat = false;
+    stop_type stop;
 
-    std::string oaicompat_model;
     std::string stopping_word;
 
     // sampling
     json json_schema;
 
-    struct
-    struct gpt_sampler * smpl = nullptr;
+    struct common_sampler * smpl = nullptr;
 
     llama_token sampled;
 
-    int32_t ga_i = 0; // group-attention state
-    int32_t ga_n = 1; // group-attention factor
-    int32_t ga_w = 512; // group-attention width
-
-    int32_t n_past_se = 0; // self-extend
-
     // stats
-    size_t n_sent_text
-    size_t n_sent_token_probs = 0;
+    size_t n_sent_text = 0; // number of sent text character
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
 
     double t_prompt_processing; // ms
-    double t_token_generation;
+    double t_token_generation; // ms
 
     std::function<void(int)> callback_on_release;
 
@@ -215,23 +1066,25 @@ struct server_slot {
         SLT_DBG(*this, "%s", "\n");
 
         n_prompt_tokens = 0;
+        last_nl_pos = 0;
         generated_text = "";
+        has_new_line = false;
         truncated = false;
-
-        stopped_word = false;
-        stopped_limit = false;
+        stop = STOP_TYPE_NONE;
         stopping_word = "";
         n_past = 0;
         n_sent_text = 0;
-
-        cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
-        ga_i = 0;
-        n_past_se = 0;
+        task_type = SERVER_TASK_TYPE_COMPLETION;
 
+        generated_tokens.clear();
         generated_token_probs.clear();
     }
 
-    bool
+    bool is_non_causal() const {
+        return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+    }
+
+    bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
         }
@@ -251,6 +1104,10 @@ struct server_slot {
         return state != SLOT_STATE_IDLE;
     }
 
+    bool can_speculate() const {
+        return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+    }
+
     void add_token(const completion_token_output & token) {
         if (!is_processing()) {
             SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -263,44 +1120,47 @@ struct server_slot {
         if (is_processing()) {
             SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
 
+            t_last_used = ggml_time_us();
             t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
             state = SLOT_STATE_IDLE;
             callback_on_release(id);
         }
     }
 
-
-
-
-
-
-
-
-
-
-
-
-
+    result_timings get_timings() const {
+        result_timings timings;
+        timings.prompt_n = n_prompt_tokens_processed;
+        timings.prompt_ms = t_prompt_processing;
+        timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
+        timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+
+        timings.predicted_n = n_decoded;
+        timings.predicted_ms = t_token_generation;
+        timings.predicted_per_token_ms = t_token_generation / n_decoded;
+        timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
+
+        return timings;
     }
 
-    size_t find_stopping_strings(const std::string & text, const size_t last_token_size,
+    size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         size_t stop_pos = std::string::npos;
 
         for (const std::string & word : params.antiprompt) {
             size_t pos;
 
-            if (
+            if (is_full_stop) {
                 const size_t tmp = word.size() + last_token_size;
                 const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
 
                 pos = text.find(word, from_pos);
             } else {
+                // otherwise, partial stop
                 pos = find_partial_stop_string(word, text);
             }
 
             if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
-                if (
-
+                if (is_full_stop) {
+                    stop = STOP_TYPE_WORD;
                     stopping_word = word;
                     has_next_token = false;
                 }
@@ -320,13 +1180,35 @@ struct server_slot {
 
         SLT_INF(*this,
                 "\n"
-                "
-                "
-                "
+                "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+                "       eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+                "      total time = %10.2f ms / %5d tokens\n",
                 t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
                 t_token_generation, n_decoded, t_gen, n_gen_second,
                 t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
     }
+
+    json to_json() const {
+        return json {
+            {"id", id},
+            {"id_task", id_task},
+            {"n_ctx", n_ctx},
+            {"speculative", can_speculate()},
+            {"is_processing", is_processing()},
+            {"non_causal", is_non_causal()},
+            {"params", params.to_json()},
+            {"prompt", common_detokenize(ctx, prompt_tokens)},
+            {"next_token",
+                {
+                    {"has_next_token", has_next_token},
+                    {"has_new_line", has_new_line},
+                    {"n_remain", n_remaining},
+                    {"n_decoded", n_decoded},
+                    {"stopping_word", stopping_word},
+                }
+            },
+        };
+    }
 };
 
 struct server_metrics {
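The new `is_full_stop` parameter of `find_stopping_strings()` above separates two cases: a stop word that is already fully present near the end of the generated text, and a partially generated stop word whose trailing bytes must be held back from streaming. A self-contained sketch of the partial-match idea follows; the function below is illustrative only and is not the server's `find_partial_stop_string()` implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <string>

// returns the position where a (possibly incomplete) occurrence of `stop`
// begins at the end of `text`, or std::string::npos if the tail of `text`
// cannot be the start of `stop`
static size_t find_partial_stop(const std::string & stop, const std::string & text) {
    if (stop.empty() || text.empty()) {
        return std::string::npos;
    }
    // try the longest possible overlap first
    const size_t max_overlap = std::min(stop.size() - 1, text.size());
    for (size_t len = max_overlap; len > 0; --len) {
        if (text.compare(text.size() - len, len, stop, 0, len) == 0) {
            return text.size() - len;
        }
    }
    return std::string::npos;
}

int main() {
    // "</s" could still become the stop word "</s>", so streaming should wait
    assert(find_partial_stop("</s>", "Hello world </s") == 12);
    // "Hello!" cannot be the beginning of "</s>", nothing is held back
    assert(find_partial_stop("</s>", "Hello!") == std::string::npos);
    return 0;
}
```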
@@ -393,15 +1275,13 @@ struct server_queue
     std::condition_variable condition_tasks;
 
     // callback functions
-    std::function<void(server_task
-    std::function<void(void)>
+    std::function<void(server_task)> callback_new_task;
+    std::function<void(void)> callback_update_slots;
 
     // Add a new task to the end of the queue
     int post(server_task task, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
-
-            task.id = id++;
-        }
+        GGML_ASSERT(task.id != -1);
         QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
         if (front) {
             queue_tasks.push_front(std::move(task));
@@ -446,7 +1326,7 @@ struct server_queue
     }
 
     // Register function to process a new task
-    void on_new_task(std::function<void(server_task
+    void on_new_task(std::function<void(server_task)> callback) {
         callback_new_task = std::move(callback);
     }
 
@@ -496,7 +1376,7 @@ struct server_queue
             lock.unlock();
 
             QUE_DBG("processing task, id = %d\n", task.id);
-            callback_new_task(task);
+            callback_new_task(std::move(task));
         }
 
         // all tasks in the current loop is processed, slots data is now ready
@@ -525,8 +1405,8 @@ struct server_response {
     // for keeping track of all tasks waiting for the result
     std::unordered_set<int> waiting_task_ids;
 
-    // the main result queue
-    std::vector<
+    // the main result queue (using ptr for polymorphism)
+    std::vector<server_task_result_ptr> queue_results;
 
     std::mutex mutex_results;
     std::condition_variable condition_results;
@@ -566,7 +1446,7 @@ struct server_response {
     }
 
     // This function blocks the thread until there is a response for one of the id_tasks
-
+    server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
         while (true) {
             std::unique_lock<std::mutex> lock(mutex_results);
             condition_results.wait(lock, [&]{
@@ -574,8 +1454,8 @@ struct server_response {
             });
 
             for (int i = 0; i < (int) queue_results.size(); i++) {
-                if (id_tasks.find(queue_results[i]
-
+                if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+                    server_task_result_ptr res = std::move(queue_results[i]);
                     queue_results.erase(queue_results.begin() + i);
                     return res;
                 }
@@ -586,21 +1466,21 @@ struct server_response {
     }
 
     // single-task version of recv()
-
+    server_task_result_ptr recv(int id_task) {
         std::unordered_set<int> id_tasks = {id_task};
         return recv(id_tasks);
     }
 
     // Send a new result to a waiting id_task
-    void send(
-        SRV_DBG("sending result for task id = %d\n", result
+    void send(server_task_result_ptr && result) {
+        SRV_DBG("sending result for task id = %d\n", result->id);
 
         std::unique_lock<std::mutex> lock(mutex_results);
         for (const auto & id_task : waiting_task_ids) {
-            if (result
-            SRV_DBG("task id = %d
+            if (result->id == id_task) {
+                SRV_DBG("task id = %d pushed to result queue\n", result->id);
 
-                queue_results.
+                queue_results.emplace_back(std::move(result));
                 condition_results.notify_all();
                 return;
             }
@@ -609,11 +1489,14 @@ struct server_response {
 };
 
 struct server_context {
+    common_params params_base;
+
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
-    std::vector<
+    std::vector<common_lora_adapter_container> loras;
 
-
+    llama_model * model_dft = nullptr;
+    llama_context_params cparams_dft;
 
     llama_batch batch = {};
 
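`server_task_result_ptr` above is a `std::unique_ptr` to a polymorphic result, which is why `send()` and `recv()` now transfer results with `std::move`. A minimal sketch of that ownership pattern, using illustrative stand-in types rather than the server's real structs:

```cpp
#include <iostream>
#include <memory>
#include <string>
#include <vector>

// minimal stand-ins for the server's polymorphic results (names are illustrative)
struct task_result {
    int id = -1;
    virtual ~task_result() = default;
    virtual std::string describe() const = 0;
};

struct error_result : task_result {
    std::string msg;
    std::string describe() const override { return "error: " + msg; }
};

using task_result_ptr = std::unique_ptr<task_result>;

int main() {
    std::vector<task_result_ptr> queue;

    auto res = std::make_unique<error_result>();
    res->id  = 7;
    res->msg = "failed to parse grammar";

    // unique_ptr is move-only: the queue takes ownership, `res` becomes null
    queue.emplace_back(std::move(res));

    // popping also transfers ownership out of the vector
    task_result_ptr front = std::move(queue.front());
    queue.erase(queue.begin());

    // virtual dispatch still works through the base-class pointer
    std::cout << front->id << " " << front->describe() << "\n";
    return 0;
}
```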
@@ -623,12 +1506,6 @@ struct server_context {
 
     int32_t n_ctx; // total context for all clients / slots
 
-    // system prompt
-    bool system_need_update = false;
-
-    std::string system_prompt;
-    std::vector<llama_token> system_tokens;
-
     // slots / clients
     std::vector<server_slot> slots;
     json default_generation_settings_for_props;
@@ -652,82 +1529,139 @@ struct server_context {
             model = nullptr;
         }
 
+        if (model_dft) {
+            llama_free_model(model_dft);
+            model_dft = nullptr;
+        }
+
         // Clear any sampling context
         for (server_slot & slot : slots) {
-
-
-
+            common_sampler_free(slot.smpl);
+            slot.smpl = nullptr;
+
+            llama_free(slot.ctx_dft);
+            slot.ctx_dft = nullptr;
+
+            common_speculative_free(slot.spec);
+            slot.spec = nullptr;
+
+            llama_batch_free(slot.batch_spec);
         }
 
         llama_batch_free(batch);
     }
 
-    bool load_model(const
-
+    bool load_model(const common_params & params) {
+        SRV_INF("loading model '%s'\n", params.model.c_str());
 
-
-        params.n_parallel += 1;
+        params_base = params;
 
-
+        common_init_result llama_init = common_init_from_params(params_base);
 
         model = llama_init.model;
         ctx = llama_init.context;
         loras = llama_init.lora_adapters;
 
-        params.n_parallel -= 1; // but be sneaky about it
-
         if (model == nullptr) {
-            SRV_ERR("failed to load model, '%s'\n",
+            SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
             return false;
         }
 
         n_ctx = llama_n_ctx(ctx);
 
         add_bos_token = llama_add_bos_token(model);
-        has_eos_token =
+        has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL;
+
+        if (!params_base.speculative.model.empty()) {
+            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
+
+            auto params_dft = params_base;
+
+            params_dft.devices = params_base.speculative.devices;
+            params_dft.model = params_base.speculative.model;
+            params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+            params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
+            params_dft.n_parallel = 1;
+
+            common_init_result llama_init_dft = common_init_from_params(params_dft);
+
+            model_dft = llama_init_dft.model;
+
+            if (model_dft == nullptr) {
+                SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
+                return false;
+            }
+
+            if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
+                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
+
+                llama_free (llama_init_dft.context);
+                llama_free_model(llama_init_dft.model);
+
+                return false;
+            }
+
+            const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
+
+            cparams_dft = common_context_params_to_llama(params_dft);
+            cparams_dft.n_batch = n_ctx_dft;
+
+            // force F16 KV cache for the draft model for extra performance
+            cparams_dft.type_k = GGML_TYPE_F16;
+            cparams_dft.type_v = GGML_TYPE_F16;
+
+            // the context is not needed - we will create one for each slot
+            llama_free(llama_init_dft.context);
+        }
 
         return true;
     }
 
     bool validate_model_chat_template() const {
-
-
-
-
-
+        std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+        std::string template_key = "tokenizer.chat_template";
+        int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+        if (res >= 0) {
+            llama_chat_message chat[] = {{"user", "test"}};
+            std::string tmpl = std::string(model_template.data(), model_template.size());
+            int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
+            return chat_res > 0;
+        }
+        return false;
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx /
+        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
 
-        SRV_INF("initializing slots, n_slots = %d\n",
+        SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
-        for (int i = 0; i <
+        for (int i = 0; i < params_base.n_parallel; i++) {
            server_slot slot;
 
            slot.id = i;
+            slot.ctx = ctx;
            slot.n_ctx = n_ctx_slot;
-            slot.n_predict =
+            slot.n_predict = params_base.n_predict;
 
-
-
-            const int ga_n = params.grp_attn_n;
-            const int ga_w = params.grp_attn_w;
+            if (model_dft) {
+                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
 
-
-
-
-
-
+                slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
+                if (slot.ctx_dft == nullptr) {
+                    SRV_ERR("%s", "failed to create draft context\n");
+                    return;
+                }
 
-
+                slot.spec = common_speculative_init(slot.ctx_dft);
+                if (slot.spec == nullptr) {
+                    SRV_ERR("%s", "failed to create speculator\n");
+                    return;
+                }
            }
 
-            slot
-            slot.ga_n = ga_n;
-            slot.ga_w = ga_w;
+            SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-            slot.
+            slot.params.sampling = params_base.sampling;
 
            slot.callback_on_release = [this](int) {
                queue_tasks.pop_deferred_task();
@@ -738,60 +1672,18 @@ struct server_context {
            slots.push_back(slot);
        }
 
-        default_generation_settings_for_props =
-        default_generation_settings_for_props["seed"] = -1;
+        default_generation_settings_for_props = slots[0].to_json();
 
        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
        // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
-        {
-            const int32_t n_batch = llama_n_batch(ctx);
-
-            // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
-        }
-
-        metrics.init();
-    }
-
-    std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
-        // TODO: currently, we tokenize using special tokens by default
-        // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
-        // but it's better compared to completely ignoring ChatML and other chat templates
-        const bool TMP_FORCE_SPECIAL = true;
-
-        // If `add_bos` is true, we only add BOS, when json_prompt is a string,
-        // or the first element of the json_prompt array is a string.
-        std::vector<llama_token> prompt_tokens;
-
-        if (json_prompt.is_array()) {
-            bool first = true;
-            for (const auto & p : json_prompt) {
-                if (p.is_string()) {
-                    auto s = p.template get<std::string>();
-
-                    std::vector<llama_token> p;
-                    if (first) {
-                        p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
-                        first = false;
-                    } else {
-                        p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
-                    }
-
-                    prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
-                } else {
-                    if (first) {
-                        first = false;
-                    }
+        {
+            const int32_t n_batch = llama_n_batch(ctx);
 
-
-
-            }
-        } else {
-            auto s = json_prompt.template get<std::string>();
-            prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
+            // only a single seq_id per token is needed
+            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
        }
 
-
+        metrics.init();
    }
 
    server_slot * get_slot_by_id(int id) {
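The draft-model setup above sizes the speculative context as `speculative.n_ctx == 0 ? n_ctx / n_parallel : speculative.n_ctx`, i.e. it falls back to the per-slot share of the main context when no explicit draft context size is configured. A small sketch of just that sizing rule, with plain integers rather than the llama.cpp parameter structs:

```cpp
#include <cassert>
#include <cstdint>

// sizing rule used above for the draft (speculative) context: if no explicit
// draft context size is given, use the per-slot share of the main context
static int32_t draft_n_ctx(int32_t n_ctx, int32_t n_parallel, int32_t spec_n_ctx) {
    return spec_n_ctx == 0 ? n_ctx / n_parallel : spec_n_ctx;
}

int main() {
    // an 8192-token main context shared by 4 slots -> 2048 tokens per draft context
    assert(draft_n_ctx(8192, 4, 0) == 2048);
    // an explicitly configured draft context size wins over the fallback
    assert(draft_n_ctx(8192, 4, 4096) == 4096);
    return 0;
}
```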
@@ -804,12 +1696,12 @@ struct server_context {
         return nullptr;
     }
 
-    server_slot * get_available_slot(const
+    server_slot * get_available_slot(const server_task & task) {
         server_slot * ret = nullptr;
 
         // find the slot that has at least n% prompt similarity
-        if (ret == nullptr && slot_prompt_similarity != 0.0f
-            int
+        if (ret == nullptr && slot_prompt_similarity != 0.0f) {
+            int lcs_len = 0;
             float similarity = 0;
 
             for (server_slot & slot : slots) {
@@ -818,32 +1710,27 @@ struct server_context {
                     continue;
                 }
 
-                // skip the slot if it does not contains
-                if (
+                // skip the slot if it does not contains cached tokens
+                if (slot.cache_tokens.empty()) {
                     continue;
                 }
 
-                // current slot's prompt
-
-
-                // length of the current slot's prompt
-                int slot_prompt_len = slot_prompt.size();
+                // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
+                int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
 
-                //
-
-
-                // fraction of the common substring length compared to the current slot's prompt length
-                similarity = static_cast<float>(lcp_len) / slot_prompt_len;
+                // fraction of the common subsequence length compared to the current slot's prompt length
+                float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
 
                 // select the current slot if the criteria match
-                if (
-
+                if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
+                    lcs_len = cur_lcs_len;
+                    similarity = cur_similarity;
                     ret = &slot;
                 }
             }
 
             if (ret != nullptr) {
-                SLT_DBG(*ret, "selected slot by
+                SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity);
             }
         }
 
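Slot reuse above is decided by the longest common subsequence (LCS) between a slot's cached tokens and the incoming prompt, via `common_lcs()`. A self-contained sketch of the quantity being compared, using a classic O(n·m) LCS-length DP; the real helper lives in llama.cpp's common library and may be implemented differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

using llama_token = int32_t;

// classic O(n*m) longest-common-subsequence length with two rolling rows
static size_t lcs_len(const std::vector<llama_token> & a, const std::vector<llama_token> & b) {
    const size_t n = a.size(), m = b.size();
    std::vector<size_t> prev(m + 1, 0), cur(m + 1, 0);
    for (size_t i = 1; i <= n; ++i) {
        for (size_t j = 1; j <= m; ++j) {
            cur[j] = (a[i - 1] == b[j - 1]) ? prev[j - 1] + 1 : std::max(prev[j], cur[j - 1]);
        }
        std::swap(prev, cur);
    }
    return prev[m];
}

int main() {
    // cached prompt vs. new prompt sharing a common prefix and a shared tail
    const std::vector<llama_token> cached = {1, 15, 99, 7, 7, 42};
    const std::vector<llama_token> prompt = {1, 15, 99, 3, 42};
    assert(lcs_len(cached, prompt) == 4); // 1, 15, 99, 42

    // similarity as used above: LCS length / cached-prompt length
    const float sim = float(lcs_len(cached, prompt)) / float(cached.size());
    assert(sim > 0.5f);
    return 0;
}
```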
@@ -872,65 +1759,14 @@ struct server_context {
|
|
|
872
1759
|
}
|
|
873
1760
|
|
|
874
1761
|
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
|
|
875
|
-
|
|
876
|
-
|
|
877
|
-
|
|
878
|
-
|
|
879
|
-
|
|
880
|
-
|
|
881
|
-
slot.oaicompat = true;
|
|
882
|
-
slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
|
|
883
|
-
} else {
|
|
884
|
-
slot.oaicompat = false;
|
|
885
|
-
slot.oaicompat_model = "";
|
|
886
|
-
}
|
|
887
|
-
|
|
888
|
-
slot.params.stream = json_value(data, "stream", false);
|
|
889
|
-
slot.params.cache_prompt = json_value(data, "cache_prompt", false);
|
|
890
|
-
slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
|
|
891
|
-
slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
|
|
892
|
-
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
|
893
|
-
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
|
894
|
-
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
|
895
|
-
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
|
896
|
-
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
|
897
|
-
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
|
898
|
-
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
|
899
|
-
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
|
|
900
|
-
slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
|
|
901
|
-
slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
|
|
902
|
-
slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
|
|
903
|
-
slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
|
|
904
|
-
slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
|
|
905
|
-
slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
|
906
|
-
slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
|
907
|
-
slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
|
|
908
|
-
slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
|
|
909
|
-
slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
|
|
910
|
-
slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
|
911
|
-
slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
|
912
|
-
|
|
913
|
-
// process "json_schema" and "grammar"
|
|
914
|
-
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
|
915
|
-
send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
|
|
916
|
-
return false;
|
|
917
|
-
}
|
|
918
|
-
if (data.contains("json_schema") && !data.contains("grammar")) {
|
|
919
|
-
try {
|
|
920
|
-
auto schema = json_value(data, "json_schema", json::object());
|
|
921
|
-
slot.sparams.grammar = json_schema_to_grammar(schema);
|
|
922
|
-
} catch (const std::exception & e) {
|
|
923
|
-
send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
|
|
924
|
-
return false;
|
|
925
|
-
}
|
|
926
|
-
} else {
|
|
927
|
-
slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
|
928
|
-
}
|
|
1762
|
+
slot.reset();
|
|
1763
|
+
slot.id_task = task.id;
|
|
1764
|
+
slot.index = task.index;
|
|
1765
|
+
slot.task_type = task.type;
|
|
1766
|
+
slot.params = std::move(task.params);
|
|
1767
|
+
slot.prompt_tokens = std::move(task.prompt_tokens);
|
|
929
1768
|
|
|
930
|
-
|
|
931
|
-
slot.params.cache_prompt = false;
|
|
932
|
-
SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
|
|
933
|
-
}
|
|
1769
|
+
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
|
|
934
1770
|
|
|
935
1771
|
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
|
936
1772
|
// Might be better to reject the request with a 400 ?
|
|
@@ -938,111 +1774,16 @@ struct server_context {
|
|
|
938
1774
|
SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
|
|
939
1775
|
}
|
|
940
1776
|
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
|
|
944
|
-
|
|
945
|
-
// get prompt
|
|
946
|
-
if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
|
|
947
|
-
const auto & prompt = data.find("prompt");
|
|
948
|
-
if (prompt == data.end()) {
|
|
949
|
-
send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
|
|
950
|
-
return false;
|
|
951
|
-
}
|
|
952
|
-
|
|
953
|
-
if ((prompt->is_string()) ||
|
|
954
|
-
(prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
|
|
955
|
-
(prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
|
|
956
|
-
slot.prompt = *prompt;
|
|
957
|
-
} else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
|
|
958
|
-
slot.prompt = prompt->at(0);
|
|
959
|
-
} else if (prompt->is_array() && prompt->size() > 1) {
|
|
960
|
-
// array of strings
|
|
961
|
-
for (const auto & el : *prompt) {
|
|
962
|
-
if (!el.is_string()) {
|
|
963
|
-
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
|
964
|
-
return false;
|
|
965
|
-
}
|
|
966
|
-
}
|
|
967
|
-
slot.prompt = *prompt;
|
|
968
|
-
} else {
|
|
969
|
-
send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
|
|
970
|
-
return false;
|
|
971
|
-
}
|
|
972
|
-
}
|
|
973
|
-
|
|
974
|
-
{
|
|
975
|
-
slot.sparams.logit_bias.clear();
|
|
976
|
-
|
|
977
|
-
if (json_value(data, "ignore_eos", false) && has_eos_token) {
|
|
978
|
-
slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
|
|
979
|
-
}
|
|
980
|
-
|
|
981
|
-
const auto & logit_bias = data.find("logit_bias");
|
|
982
|
-
if (logit_bias != data.end() && logit_bias->is_array()) {
|
|
983
|
-
const int n_vocab = llama_n_vocab(model);
|
|
984
|
-
for (const auto & el : *logit_bias) {
|
|
985
|
-
// TODO: we may want to throw errors here, in case "el" is incorrect
|
|
986
|
-
if (el.is_array() && el.size() == 2) {
|
|
987
|
-
float bias;
|
|
988
|
-
if (el[1].is_number()) {
|
|
989
|
-
bias = el[1].get<float>();
|
|
990
|
-
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
|
|
991
|
-
bias = -INFINITY;
|
|
992
|
-
} else {
|
|
993
|
-
continue;
|
|
994
|
-
}
|
|
995
|
-
|
|
996
|
-
if (el[0].is_number_integer()) {
|
|
997
|
-
llama_token tok = el[0].get<llama_token>();
|
|
998
|
-
if (tok >= 0 && tok < n_vocab) {
|
|
999
|
-
slot.sparams.logit_bias.push_back({tok, bias});
|
|
1000
|
-
}
|
|
1001
|
-
} else if (el[0].is_string()) {
|
|
1002
|
-
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
|
|
1003
|
-
for (auto tok : toks) {
|
|
1004
|
-
slot.sparams.logit_bias.push_back({tok, bias});
|
|
1005
|
-
}
|
|
1006
|
-
}
|
|
1007
|
-
}
|
|
1008
|
-
}
|
|
1009
|
-
}
|
|
1010
|
-
}
|
|
1011
|
-
|
|
1012
|
-
{
|
|
1013
|
-
slot.params.antiprompt.clear();
|
|
1014
|
-
|
|
1015
|
-
const auto & stop = data.find("stop");
|
|
1016
|
-
if (stop != data.end() && stop->is_array()) {
|
|
1017
|
-
for (const auto & word : *stop) {
|
|
1018
|
-
if (!word.empty()) {
|
|
1019
|
-
slot.params.antiprompt.push_back(word);
|
|
1020
|
-
}
|
|
1021
|
-
}
|
|
1022
|
-
}
|
|
1023
|
-
}
|
|
1024
|
-
|
|
1025
|
-
{
|
|
1026
|
-
const auto & samplers = data.find("samplers");
|
|
1027
|
-
if (samplers != data.end() && samplers->is_array()) {
|
|
1028
|
-
std::vector<std::string> sampler_names;
|
|
1029
|
-
for (const auto & name : *samplers) {
|
|
1030
|
-
if (name.is_string()) {
|
|
1031
|
-
sampler_names.emplace_back(name);
|
|
1032
|
-
}
|
|
1033
|
-
}
|
|
1034
|
-
slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
|
|
1035
|
-
} else {
|
|
1036
|
-
slot.sparams.samplers = default_sparams.samplers;
|
|
1037
|
-
}
|
|
1777
|
+
if (slot.params.ignore_eos && has_eos_token) {
|
|
1778
|
+
slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
|
|
1038
1779
|
}
|
|
1039
1780
|
|
|
1040
1781
|
{
|
|
1041
1782
|
if (slot.smpl != nullptr) {
|
|
1042
|
-
|
|
1783
|
+
common_sampler_free(slot.smpl);
|
|
1043
1784
|
}
|
|
1044
1785
|
|
|
1045
|
-
slot.smpl =
|
|
1786
|
+
slot.smpl = common_sampler_init(model, slot.params.sampling);
|
|
1046
1787
|
if (slot.smpl == nullptr) {
|
|
1047
1788
|
// for now, the only error that may happen here is invalid grammar
|
|
1048
1789
|
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
|
@@ -1050,8 +1791,13 @@ struct server_context {
|
|
|
1050
1791
|
}
|
|
1051
1792
|
}
|
|
1052
1793
|
|
|
1053
|
-
slot.
|
|
1054
|
-
|
|
1794
|
+
if (slot.ctx_dft) {
|
|
1795
|
+
llama_batch_free(slot.batch_spec);
|
|
1796
|
+
|
|
1797
|
+
slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
|
|
1798
|
+
}
|
|
1799
|
+
|
|
1800
|
+
slot.state = SLOT_STATE_STARTED;
|
|
1055
1801
|
|
|
1056
1802
|
SLT_INF(slot, "%s", "processing task\n");
|
|
1057
1803
|
|
|
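In `launch_slot_with_task()` above, `ignore_eos` is implemented by pushing `{llama_token_eos(model), -INFINITY}` into the sampling logit bias. A minimal sketch of why a `-INFINITY` bias bans a token; this uses plain arrays and made-up token ids, not the llama.cpp sampling API.

```cpp
#include <cassert>
#include <cmath>
#include <utility>
#include <vector>

int main() {
    std::vector<float> logits = {1.0f, 2.5f, 0.3f, 4.0f};      // one logit per token id (made up)
    const int eos_token = 3;
    const std::vector<std::pair<int, float>> logit_bias = {{eos_token, -INFINITY}};

    // adding -INFINITY to a token's logit makes its softmax probability exactly 0,
    // so the token can never be sampled
    for (const auto & [tok, bias] : logit_bias) {
        logits[tok] += bias;
    }

    assert(std::isinf(logits[eos_token]) && logits[eos_token] < 0.0f);

    // greedy argmax now falls on some other token
    int best = 0;
    for (int i = 1; i < (int) logits.size(); ++i) {
        if (logits[i] > logits[best]) best = i;
    }
    assert(best != eos_token);
    return 0;
}
```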
@@ -1066,107 +1812,40 @@ struct server_context {
|
|
|
1066
1812
|
clean_kv_cache = false;
|
|
1067
1813
|
}
|
|
1068
1814
|
|
|
1069
|
-
void system_prompt_update() {
|
|
1070
|
-
SRV_DBG("updating system prompt: '%s'\n", system_prompt.c_str());
|
|
1071
|
-
|
|
1072
|
-
kv_cache_clear();
|
|
1073
|
-
system_tokens.clear();
|
|
1074
|
-
|
|
1075
|
-
if (!system_prompt.empty()) {
|
|
1076
|
-
system_tokens = ::llama_tokenize(ctx, system_prompt, true);
|
|
1077
|
-
|
|
1078
|
-
const int32_t n_batch = llama_n_batch(ctx);
|
|
1079
|
-
const int32_t n_tokens_prompt = system_tokens.size();
|
|
1080
|
-
|
|
1081
|
-
for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
|
|
1082
|
-
const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
|
|
1083
|
-
|
|
1084
|
-
llama_batch_clear(batch);
|
|
1085
|
-
|
|
1086
|
-
for (int32_t j = 0; j < n_tokens; ++j) {
|
|
1087
|
-
llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
|
|
1088
|
-
}
|
|
1089
|
-
|
|
1090
|
-
if (llama_decode(ctx, batch) != 0) {
|
|
1091
|
-
SRV_ERR("%s", "llama_decode() failed\n");
|
|
1092
|
-
return;
|
|
1093
|
-
}
|
|
1094
|
-
}
|
|
1095
|
-
|
|
1096
|
-
// assign the system KV cache to all parallel sequences
|
|
1097
|
-
for (int32_t i = 1; i <= params.n_parallel; ++i) {
|
|
1098
|
-
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
|
|
1099
|
-
}
|
|
1100
|
-
}
|
|
1101
|
-
|
|
1102
|
-
system_need_update = false;
|
|
1103
|
-
}
|
|
1104
|
-
|
|
1105
|
-
bool system_prompt_set(const std::string & sys_prompt) {
|
|
1106
|
-
SRV_DBG("system prompt set: '%s'\n", system_prompt.c_str());
|
|
1107
|
-
|
|
1108
|
-
system_prompt = sys_prompt;
|
|
1109
|
-
|
|
1110
|
-
// release all slots
|
|
1111
|
-
for (server_slot & slot : slots) {
|
|
1112
|
-
slot.release();
|
|
1113
|
-
}
|
|
1114
|
-
|
|
1115
|
-
system_need_update = true;
|
|
1116
|
-
return true;
|
|
1117
|
-
}
|
|
1118
|
-
|
|
1119
1815
|
bool process_token(completion_token_output & result, server_slot & slot) {
|
|
1120
1816
|
// remember which tokens were sampled - used for repetition penalties during sampling
|
|
1121
|
-
const std::string token_str =
|
|
1817
|
+
const std::string token_str = result.text_to_send;
|
|
1122
1818
|
slot.sampled = result.tok;
|
|
1123
1819
|
|
|
1124
|
-
// search stop word and delete it
|
|
1125
1820
|
slot.generated_text += token_str;
|
|
1821
|
+
if (slot.params.return_tokens) {
|
|
1822
|
+
slot.generated_tokens.push_back(result.tok);
|
|
1823
|
+
}
|
|
1126
1824
|
slot.has_next_token = true;
|
|
1127
1825
|
|
|
1128
1826
|
// check if there is incomplete UTF-8 character at the end
|
|
1129
|
-
bool incomplete =
|
|
1130
|
-
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
|
|
1131
|
-
unsigned char c = slot.generated_text[slot.generated_text.size() - i];
|
|
1132
|
-
if ((c & 0xC0) == 0x80) {
|
|
1133
|
-
// continuation byte: 10xxxxxx
|
|
1134
|
-
continue;
|
|
1135
|
-
}
|
|
1136
|
-
if ((c & 0xE0) == 0xC0) {
|
|
1137
|
-
// 2-byte character: 110xxxxx ...
|
|
1138
|
-
incomplete = i < 2;
|
|
1139
|
-
} else if ((c & 0xF0) == 0xE0) {
|
|
1140
|
-
// 3-byte character: 1110xxxx ...
|
|
1141
|
-
incomplete = i < 3;
|
|
1142
|
-
} else if ((c & 0xF8) == 0xF0) {
|
|
1143
|
-
// 4-byte character: 11110xxx ...
|
|
1144
|
-
incomplete = i < 4;
|
|
1145
|
-
}
|
|
1146
|
-
// else 1-byte character or invalid byte
|
|
1147
|
-
break;
|
|
1148
|
-
}
|
|
1827
|
+
bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
|
|
1149
1828
|
|
|
1829
|
+
// search stop word and delete it
|
|
1150
1830
|
if (!incomplete) {
|
|
1151
1831
|
size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
|
1152
1832
|
|
|
1153
1833
|
const std::string str_test = slot.generated_text.substr(pos);
|
|
1154
|
-
bool
|
|
1834
|
+
bool send_text = true;
|
|
1155
1835
|
|
|
1156
|
-
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(),
|
|
1836
|
+
size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
|
|
1157
1837
|
if (stop_pos != std::string::npos) {
|
|
1158
|
-
is_stop_full = true;
|
|
1159
1838
|
slot.generated_text.erase(
|
|
1160
1839
|
slot.generated_text.begin() + pos + stop_pos,
|
|
1161
1840
|
slot.generated_text.end());
|
|
1162
1841
|
pos = std::min(slot.n_sent_text, slot.generated_text.size());
|
|
1163
|
-
} else {
|
|
1164
|
-
|
|
1165
|
-
|
|
1842
|
+
} else if (slot.has_next_token) {
|
|
1843
|
+
stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
|
|
1844
|
+
send_text = stop_pos == std::string::npos;
|
|
1166
1845
|
}
|
|
1167
1846
|
|
|
1168
1847
|
// check if there is any token to predict
|
|
1169
|
-
if (
|
|
1848
|
+
if (send_text) {
|
|
1170
1849
|
// no send the stop word in the response
|
|
1171
1850
|
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
|
1172
1851
|
slot.n_sent_text += result.text_to_send.size();
|
|
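`process_token()` above replaces the old inline UTF-8 tail check with a `validate_utf8()` call, so that an incomplete multi-byte character at the end of `generated_text` is not streamed out. A self-contained sketch of a "longest valid UTF-8 prefix" check of that kind; this is assumed behaviour for illustration, and the server's own helper may differ in details.

```cpp
#include <cassert>
#include <string>

// returns the length of the longest valid UTF-8 prefix of `s`; if the last
// character is incomplete, the returned length is shorter than s.size()
static size_t utf8_valid_prefix(const std::string & s) {
    size_t i = 0;
    while (i < s.size()) {
        const unsigned char c = s[i];
        size_t len = 1;
        if      ((c & 0x80) == 0x00) len = 1; // 0xxxxxxx
        else if ((c & 0xE0) == 0xC0) len = 2; // 110xxxxx
        else if ((c & 0xF0) == 0xE0) len = 3; // 1110xxxx
        else if ((c & 0xF8) == 0xF0) len = 4; // 11110xxx
        else return i;                        // invalid lead byte
        if (i + len > s.size()) return i;     // sequence runs past the end -> incomplete
        for (size_t k = 1; k < len; ++k) {
            if ((static_cast<unsigned char>(s[i + k]) & 0xC0) != 0x80) return i; // bad continuation byte
        }
        i += len;
    }
    return i;
}

int main() {
    assert(utf8_valid_prefix("hello") == 5);
    // "é" is 0xC3 0xA9; with only the first byte present, the valid prefix stops before it
    assert(utf8_valid_prefix(std::string("ab") + '\xC3') == 2);
    return 0;
}
```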
@@ -1184,24 +1863,74 @@ struct server_context {
|
|
|
1184
1863
|
}
|
|
1185
1864
|
|
|
1186
1865
|
// check the limits
|
|
1187
|
-
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(
|
|
1188
|
-
slot.
|
|
1866
|
+
if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
|
|
1867
|
+
slot.stop = STOP_TYPE_LIMIT;
|
|
1189
1868
|
slot.has_next_token = false;
|
|
1190
1869
|
|
|
1191
1870
|
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
|
|
1192
1871
|
}
|
|
1193
1872
|
|
|
1873
|
+
if (slot.has_new_line) {
|
|
1874
|
+
// if we have already seen a new line, we stop after a certain time limit
|
|
1875
|
+
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
|
|
1876
|
+
slot.stop = STOP_TYPE_LIMIT;
|
|
1877
|
+
slot.has_next_token = false;
|
|
1878
|
+
|
|
1879
|
+
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
|
|
1880
|
+
}
|
|
1881
|
+
|
|
1882
|
+
// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
|
|
1883
|
+
if (slot.params.n_indent > 0) {
|
|
1884
|
+
// check the current indentation
|
|
1885
|
+
// TODO: improve by not doing it more than once for each new line
|
|
1886
|
+
if (slot.last_nl_pos > 0) {
|
|
1887
|
+
size_t pos = slot.last_nl_pos;
|
|
1888
|
+
|
|
1889
|
+
int n_indent = 0;
|
|
1890
|
+
while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
|
|
1891
|
+
n_indent++;
|
|
1892
|
+
pos++;
|
|
1893
|
+
}
|
|
1894
|
+
|
|
1895
|
+
if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
|
|
1896
|
+
slot.stop = STOP_TYPE_LIMIT;
|
|
1897
|
+
slot.has_next_token = false;
|
|
1898
|
+
|
|
1899
|
+
// cut the last line
|
|
1900
|
+
slot.generated_text.erase(pos, std::string::npos);
|
|
1901
|
+
|
|
1902
|
+
SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
|
|
1903
|
+
}
|
|
1904
|
+
}
|
|
1905
|
+
|
|
1906
|
+
// find the next new line
|
|
1907
|
+
{
|
|
1908
|
+
const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
|
|
1909
|
+
|
|
1910
|
+
if (pos != std::string::npos) {
|
|
1911
|
+
slot.last_nl_pos = pos + 1;
|
|
1912
|
+
}
|
|
1913
|
+
}
|
|
1914
|
+
}
|
|
1915
|
+
}
|
|
1916
|
+
|
|
1917
|
+
// check if there is a new line in the generated text
|
|
1918
|
+
if (result.text_to_send.find('\n') != std::string::npos) {
|
|
1919
|
+
slot.has_new_line = true;
|
|
1920
|
+
}
|
|
1921
|
+
|
|
1194
1922
|
// if context shift is disabled, we stop when it reaches the context limit
|
|
1195
|
-
if (slot.
|
|
1923
|
+
if (slot.n_past >= slot.n_ctx) {
|
|
1196
1924
|
slot.truncated = true;
|
|
1197
|
-
slot.
|
|
1925
|
+
slot.stop = STOP_TYPE_LIMIT;
|
|
1198
1926
|
slot.has_next_token = false;
|
|
1199
1927
|
|
|
1200
|
-
SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n",
|
|
1928
|
+
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
|
|
1929
|
+
slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
|
|
1201
1930
|
}
|
|
1202
1931
|
|
|
1203
1932
|
if (llama_token_is_eog(model, result.tok)) {
|
|
1204
|
-
slot.
|
|
1933
|
+
slot.stop = STOP_TYPE_EOS;
|
|
1205
1934
|
slot.has_next_token = false;
|
|
1206
1935
|
|
|
1207
1936
|
SLT_DBG(slot, "%s", "stopped by EOS\n");
|
|
@@ -1209,63 +1938,69 @@ struct server_context {
|
|
|
1209
1938
|
|
|
1210
1939
|
const auto n_ctx_train = llama_n_ctx_train(model);
|
|
1211
1940
|
|
|
1212
|
-
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.
|
|
1941
|
+
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
|
|
1213
1942
|
slot.truncated = true;
|
|
1214
|
-
slot.
|
|
1943
|
+
slot.stop = STOP_TYPE_LIMIT;
|
|
1215
1944
|
slot.has_next_token = false; // stop prediction
|
|
1216
1945
|
|
|
1217
1946
|
SLT_WRN(slot,
|
|
1218
|
-
"n_predict (%d) is
|
|
1947
|
+
"n_predict (%d) is set for infinite generation. "
|
|
1219
1948
|
"Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
|
|
1220
1949
|
slot.params.n_predict, n_ctx_train);
|
|
1221
1950
|
}
|
|
1222
1951
|
|
|
1223
|
-
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
|
|
1952
|
+
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
|
|
1224
1953
|
|
|
1225
1954
|
return slot.has_next_token; // continue
|
|
1226
1955
|
}
|
|
1227
1956
|
|
|
1228
|
-
|
|
1229
|
-
|
|
1230
|
-
|
|
1231
|
-
|
|
1232
|
-
|
|
1233
|
-
|
|
1957
|
+
void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
|
|
1958
|
+
size_t n_probs = slot.params.sampling.n_probs;
|
|
1959
|
+
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
|
|
1960
|
+
if (post_sampling) {
|
|
1961
|
+
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
|
|
1962
|
+
const size_t max_probs = cur_p->size;
|
|
1963
|
+
|
|
1964
|
+
// set probability for sampled token
|
|
1965
|
+
for (size_t i = 0; i < max_probs; i++) {
|
|
1966
|
+
if (cur_p->data[i].id == result.tok) {
|
|
1967
|
+
result.prob = cur_p->data[i].p;
|
|
1968
|
+
break;
|
|
1969
|
+
}
|
|
1970
|
+
}
|
|
1234
1971
|
|
|
1235
|
-
|
|
1236
|
-
|
|
1237
|
-
|
|
1238
|
-
|
|
1239
|
-
|
|
1240
|
-
|
|
1241
|
-
|
|
1242
|
-
|
|
1243
|
-
|
|
1244
|
-
|
|
1245
|
-
|
|
1246
|
-
|
|
1247
|
-
|
|
1248
|
-
|
|
1249
|
-
{
|
|
1250
|
-
|
|
1251
|
-
|
|
1252
|
-
|
|
1253
|
-
|
|
1254
|
-
|
|
1255
|
-
|
|
1256
|
-
|
|
1257
|
-
|
|
1258
|
-
|
|
1259
|
-
{
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1266
|
-
|
|
1267
|
-
{"samplers", samplers},
|
|
1268
|
-
};
|
|
1972
|
+
// set probability for top n_probs tokens
|
|
1973
|
+
result.probs.reserve(max_probs);
|
|
1974
|
+
for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
|
|
1975
|
+
result.probs.push_back({
|
|
1976
|
+
cur_p->data[i].id,
|
|
1977
|
+
common_detokenize(ctx, {cur_p->data[i].id}, special),
|
|
1978
|
+
cur_p->data[i].p
|
|
1979
|
+
});
|
|
1980
|
+
}
|
|
1981
|
+
} else {
|
|
1982
|
+
// TODO: optimize this with min-p optimization
|
|
1983
|
+
std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
|
|
1984
|
+
|
|
1985
|
+
// set probability for sampled token
|
|
1986
|
+
for (size_t i = 0; i < n_vocab; i++) {
|
|
1987
|
+
// set probability for sampled token
|
|
1988
|
+
if (cur[i].id == result.tok) {
|
|
1989
|
+
result.prob = cur[i].p;
|
|
1990
|
+
break;
|
|
1991
|
+
}
|
|
1992
|
+
}
|
|
1993
|
+
|
|
1994
|
+
// set probability for top n_probs tokens
|
|
1995
|
+
result.probs.reserve(n_probs);
|
|
1996
|
+
for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
|
|
1997
|
+
result.probs.push_back({
|
|
1998
|
+
cur[i].id,
|
|
1999
|
+
common_detokenize(ctx, {cur[i].id}, special),
|
|
2000
|
+
cur[i].p
|
|
2001
|
+
});
|
|
2002
|
+
}
|
|
2003
|
+
}
|
|
1269
2004
|
}
|
|
1270
2005
|
|
|
1271
2006
|
void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
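The pre-sampling branch of `populate_token_probs()` above asks `get_token_probabilities(ctx, idx)` for per-token probabilities and keeps the top `n_probs` entries. A sketch of the underlying idea (softmax over raw logits, sort descending, truncate); the function below is illustrative and not the server's implementation.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

struct token_prob {
    int   id;
    float p;
};

// softmax the logits, sort by probability, keep the top n_probs entries
static std::vector<token_prob> top_probs(const std::vector<float> & logits, size_t n_probs) {
    std::vector<token_prob> out(logits.size());
    const float max_l = *std::max_element(logits.begin(), logits.end());
    float sum = 0.0f;
    for (size_t i = 0; i < logits.size(); ++i) {
        out[i] = { (int) i, std::exp(logits[i] - max_l) }; // subtract max for numerical stability
        sum += out[i].p;
    }
    for (auto & tp : out) {
        tp.p /= sum;
    }
    std::sort(out.begin(), out.end(), [](const token_prob & a, const token_prob & b) { return a.p > b.p; });
    out.resize(std::min(n_probs, out.size()));
    return out;
}

int main() {
    const std::vector<float> logits = {0.1f, 2.0f, -1.0f, 1.5f};
    for (const auto & tp : top_probs(logits, 2)) {
        std::printf("token %d -> %.3f\n", tp.id, tp.p);
    }
    return 0;
}
```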
@@ -1279,114 +2014,106 @@ struct server_context {
|
|
|
1279
2014
|
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
1280
2015
|
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
|
|
1281
2016
|
|
|
1282
|
-
|
|
1283
|
-
res
|
|
1284
|
-
res
|
|
1285
|
-
res
|
|
1286
|
-
res.data = format_error_response(error, type);
|
|
1287
|
-
|
|
1288
|
-
queue_results.send(res);
|
|
1289
|
-
}
|
|
1290
|
-
|
|
1291
|
-
void send_partial_response(server_slot & slot, completion_token_output tkn) {
|
|
1292
|
-
server_task_result res;
|
|
1293
|
-
res.id = slot.id_task;
|
|
1294
|
-
res.error = false;
|
|
1295
|
-
res.stop = false;
|
|
1296
|
-
res.data = json {
|
|
1297
|
-
{"content", tkn.text_to_send},
|
|
1298
|
-
{"stop", false},
|
|
1299
|
-
{"id_slot", slot.id},
|
|
1300
|
-
{"multimodal", false},
|
|
1301
|
-
{"index", slot.index},
|
|
1302
|
-
};
|
|
2017
|
+
auto res = std::make_unique<server_task_result_error>();
|
|
2018
|
+
res->id = id_task;
|
|
2019
|
+
res->err_type = type;
|
|
2020
|
+
res->err_msg = error;
|
|
1303
2021
|
|
|
1304
|
-
|
|
1305
|
-
|
|
1306
|
-
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
|
1307
|
-
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
|
|
2022
|
+
queue_results.send(std::move(res));
|
|
2023
|
+
}
|
|
1308
2024
|
|
|
1309
|
-
|
|
1310
|
-
|
|
1311
|
-
|
|
1312
|
-
|
|
1313
|
-
|
|
1314
|
-
|
|
1315
|
-
|
|
2025
|
+
void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
|
|
2026
|
+
auto res = std::make_unique<server_task_result_cmpl_partial>();
|
|
2027
|
+
|
|
2028
|
+
res->id = slot.id_task;
|
|
2029
|
+
res->index = slot.index;
|
|
2030
|
+
res->content = tkn.text_to_send;
|
|
2031
|
+
res->tokens = { tkn.tok };
|
|
2032
|
+
|
|
2033
|
+
res->n_decoded = slot.n_decoded;
|
|
2034
|
+
res->n_prompt_tokens = slot.n_prompt_tokens;
|
|
2035
|
+
res->post_sampling_probs = slot.params.post_sampling_probs;
|
|
1316
2036
|
|
|
1317
|
-
|
|
2037
|
+
res->verbose = slot.params.verbose;
|
|
2038
|
+
res->oaicompat = slot.params.oaicompat;
|
|
2039
|
+
res->oaicompat_chat = slot.params.oaicompat_chat;
|
|
2040
|
+
res->oaicompat_model = slot.params.oaicompat_model;
|
|
2041
|
+
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
|
2042
|
+
|
|
2043
|
+
// populate res.probs_output
|
|
2044
|
+
if (slot.params.sampling.n_probs > 0) {
|
|
2045
|
+
res->prob_output = tkn; // copy the token probs
|
|
1318
2046
|
}
|
|
1319
2047
|
|
|
1320
|
-
if
|
|
1321
|
-
|
|
1322
|
-
res
|
|
2048
|
+
// populate timings if this is final response or timings_per_token is enabled
|
|
2049
|
+
if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
|
|
2050
|
+
res->timings = slot.get_timings();
|
|
1323
2051
|
}
|
|
1324
2052
|
|
|
1325
|
-
queue_results.send(res);
|
|
2053
|
+
queue_results.send(std::move(res));
|
|
1326
2054
|
}
|
|
1327
2055
|
|
|
1328
|
-
void send_final_response(
|
|
1329
|
-
|
|
1330
|
-
res
|
|
1331
|
-
res
|
|
1332
|
-
|
|
1333
|
-
res
|
|
1334
|
-
|
|
1335
|
-
|
|
1336
|
-
|
|
1337
|
-
|
|
1338
|
-
|
|
1339
|
-
|
|
1340
|
-
|
|
1341
|
-
|
|
1342
|
-
|
|
1343
|
-
|
|
1344
|
-
|
|
1345
|
-
|
|
1346
|
-
|
|
1347
|
-
|
|
1348
|
-
|
|
1349
|
-
|
|
1350
|
-
|
|
1351
|
-
|
|
1352
|
-
|
|
1353
|
-
|
|
1354
|
-
|
|
1355
|
-
|
|
2056
|
+
void send_final_response(server_slot & slot) {
|
|
2057
|
+
auto res = std::make_unique<server_task_result_cmpl_final>();
|
|
2058
|
+
res->id = slot.id_task;
|
|
2059
|
+
res->id_slot = slot.id;
|
|
2060
|
+
|
|
2061
|
+
res->index = slot.index;
|
|
2062
|
+
res->content = slot.generated_text;
|
|
2063
|
+
res->tokens = slot.generated_tokens;
|
|
2064
|
+
res->timings = slot.get_timings();
|
|
2065
|
+
res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
|
|
2066
|
+
|
|
2067
|
+
res->truncated = slot.truncated;
|
|
2068
|
+
res->n_decoded = slot.n_decoded;
|
|
2069
|
+
res->n_prompt_tokens = slot.n_prompt_tokens;
|
|
2070
|
+
res->n_tokens_cached = slot.n_past;
|
|
2071
|
+
res->has_new_line = slot.has_new_line;
|
|
2072
|
+
res->stopping_word = slot.stopping_word;
|
|
2073
|
+
res->stop = slot.stop;
|
|
2074
|
+
res->post_sampling_probs = slot.params.post_sampling_probs;
|
|
2075
+ res->verbose = slot.params.verbose;
+ res->stream = slot.params.stream;
+ res->oaicompat = slot.params.oaicompat;
+ res->oaicompat_chat = slot.params.oaicompat_chat;
+ res->oaicompat_model = slot.params.oaicompat_model;
+ res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+
+ // populate res.probs_output
+ if (slot.params.sampling.n_probs > 0) {
+ if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
+ const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);

size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
-
+ res->probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin(),
slot.generated_token_probs.end() - safe_offset);
} else {
-
+ res->probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin(),
slot.generated_token_probs.end());
}
-
- res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
}

-
- res.data["oaicompat_token_ctr"] = slot.n_decoded;
- res.data["model"] = slot.oaicompat_model;
- }
+ res->generation_params = slot.params; // copy the parameters

- queue_results.send(res);
+ queue_results.send(std::move(res));
}

void send_embedding(const server_slot & slot, const llama_batch & batch) {
-
- res
- res
- res
+ auto res = std::make_unique<server_task_result_embd>();
+ res->id = slot.id_task;
+ res->index = slot.index;
+ res->n_tokens = slot.n_prompt_tokens;
+ res->oaicompat = slot.params.oaicompat;

const int n_embd = llama_n_embd(model);

std::vector<float> embd_res(n_embd, 0.0f);

for (int i = 0; i < batch.n_tokens; ++i) {
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
}

@@ -1398,35 +2125,33 @@ struct server_context {
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

- res.
- {"embedding", std::vector<float>(n_embd, 0.0f)},
- {"index", slot.index},
- };
-
+ res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
continue;
}

-
-
-
-
-
- }
+ // normalize only when there is pooling
+ // TODO: configurable
+ if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+ common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+ res->embedding.push_back(embd_res);
+ } else {
+ res->embedding.push_back({ embd, embd + n_embd });
+ }
}

SLT_DBG(slot, "%s", "sending embeddings\n");

- queue_results.send(res);
+ queue_results.send(std::move(res));
}

void send_rerank(const server_slot & slot, const llama_batch & batch) {
-
- res
- res
- res
+ auto res = std::make_unique<server_task_result_rerank>();
+ res->id = slot.id_task;
+ res->index = slot.index;
+ res->n_tokens = slot.n_prompt_tokens;

for (int i = 0; i < batch.n_tokens; ++i) {
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
continue;
}

@@ -1438,100 +2163,29 @@ struct server_context {
if (embd == NULL) {
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);

- res
- {"index", slot.index},
- {"score", -1e6},
- };
-
+ res->score = -1e6;
continue;
}

- res
- {"index", slot.index},
- {"score", embd[0]},
- };
+ res->score = embd[0];
}

- SLT_DBG(slot, "sending rerank result, res =
+ SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);

- queue_results.send(res);
+ queue_results.send(std::move(res));
}

//
// Functions to create new task(s) and receive result(s)
//

- std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
- std::vector<server_task> tasks;
- auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
- server_task task;
- task.id = queue_tasks.get_new_id();
- task.cmpl_type = cmpl_type;
- task.type = SERVER_TASK_TYPE_COMPLETION;
- if (replace_prompt) {
- task.data = task_data;
- task.data["prompt"] = std::move(prompt);
- } else {
- task.data = std::move(task_data);
- }
- tasks.push_back(std::move(task));
- };
-
- static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
- if (!data.contains("prompt")) {
- throw std::runtime_error(error_msg);
- }
-
- json prompt = data.at("prompt");
-
- // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
- if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
- data["index"] = 0;
- create_task(data, false, nullptr);
- }
- // otherwise, it's a multiple-prompt task, we break it into smaller tasks
- else if (prompt.is_array()) {
- std::vector<json> prompts = prompt;
- if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
- // prompts[0] is the question
- // the rest are the answers/documents
- SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
- for (size_t i = 1; i < prompts.size(); i++) {
- json qd;
- qd.push_back(prompts[0]);
- qd.push_back(prompts[i]);
- data["index"] = i - 1;
- create_task(data, true, qd);
- }
- } else {
- SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
- for (size_t i = 0; i < prompts.size(); i++) {
- const auto & e = prompts[i];
- if (e.is_string() || json_is_array_of_numbers(e)) {
- data["index"] = i;
- create_task(data, true, e);
- } else {
- throw std::runtime_error(error_msg);
- }
- }
- }
- }
- // invalid case
- else {
- throw std::runtime_error(error_msg);
- }
-
- return tasks;
- }
-
void cancel_tasks(const std::unordered_set<int> & id_tasks) {
std::vector<server_task> cancel_tasks;
cancel_tasks.reserve(id_tasks.size());
for (const auto & id_task : id_tasks) {
SRV_WRN("cancel task, id_task = %d\n", id_task);

- server_task task;
- task.type = SERVER_TASK_TYPE_CANCEL;
+ server_task task(SERVER_TASK_TYPE_CANCEL);
task.id_target = id_task;
cancel_tasks.push_back(task);
queue_results.remove_waiting_task_id(id_task);
@@ -1540,50 +2194,58 @@ struct server_context {
queue_tasks.post(cancel_tasks, true);
}

- // receive the results from task(s)
- void
+ // receive the results from task(s)
+ void receive_multi_results(
const std::unordered_set<int> & id_tasks,
- const std::function<void(std::vector<
+ const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
const std::function<void(json)> & error_handler) {
-
- std::vector<server_task_result> results(id_tasks.size());
+ std::vector<server_task_result_ptr> results(id_tasks.size());
for (size_t i = 0; i < id_tasks.size(); i++) {
-
+ server_task_result_ptr result = queue_results.recv(id_tasks);

- if (result
- error_handler(result
+ if (result->is_error()) {
+ error_handler(result->to_json());
cancel_tasks(id_tasks);
return;
}

-
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
+ );
+ const size_t idx = result->get_index();
GGML_ASSERT(idx < results.size() && "index out of range");
-
- results[idx] = result;
+ results[idx] = std::move(result);
}
result_handler(results);
}

- // receive the results from task(s)
+ // receive the results from task(s), in stream mode
void receive_cmpl_results_stream(
- const std::unordered_set<int> & id_tasks,
- std::function<bool(
- std::function<void(json)> & error_handler) {
+ const std::unordered_set<int> & id_tasks,
+ const std::function<bool(server_task_result_ptr&)> & result_handler,
+ const std::function<void(json)> & error_handler) {
size_t n_finished = 0;
while (true) {
-
-
+ server_task_result_ptr result = queue_results.recv(id_tasks);
+
+ if (result->is_error()) {
+ error_handler(result->to_json());
cancel_tasks(id_tasks);
-
+ return;
}

-
-
+ GGML_ASSERT(
+ dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+ || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+ );
+ if (!result_handler(result)) {
cancel_tasks(id_tasks);
break;
}

- if (result
+ if (result->is_stop()) {
if (++n_finished == id_tasks.size()) {
break;
}
@@ -1595,24 +2257,16 @@ struct server_context {
// Functions to process the task
//

- void process_single_task(
+ void process_single_task(server_task task) {
switch (task.type) {
case SERVER_TASK_TYPE_COMPLETION:
+ case SERVER_TASK_TYPE_INFILL:
+ case SERVER_TASK_TYPE_EMBEDDING:
+ case SERVER_TASK_TYPE_RERANK:
{
- const int id_slot =
-
- server_slot * slot;
-
- if (id_slot != -1) {
- slot = get_slot_by_id(id_slot);
- } else {
- std::string prompt;
- if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
- prompt = json_value(task.data, "prompt", std::string());
- }
+ const int id_slot = task.id_selected_slot;

-
- }
+ server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);

if (slot == nullptr) {
// if no slot is available, we defer this task for processing later
@@ -1627,22 +2281,6 @@ struct server_context {
break;
}

- if (task.data.contains("system_prompt")) {
- std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
- system_prompt_set(sys_prompt);
-
- for (server_slot & slot : slots) {
- slot.n_past = 0;
- slot.n_past_se = 0;
- }
- }
-
- slot->reset();
-
- slot->id_task = task.id;
- slot->cmpl_type = task.cmpl_type;
- slot->index = json_value(task.data, "index", 0);
-
if (!launch_slot_with_task(*slot, task)) {
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
break;
@@ -1670,68 +2308,50 @@ struct server_context {
int n_processing_slots = 0;

for (server_slot & slot : slots) {
- json slot_data =
-
-
- slot_data["state"] = slot.state;
- slot_data["prompt"] = slot.prompt;
- slot_data["next_token"] = {
- {"has_next_token", slot.has_next_token},
- {"n_remain", slot.n_remaining},
- {"n_decoded", slot.n_decoded},
- {"stopped_eos", slot.stopped_eos},
- {"stopped_word", slot.stopped_word},
- {"stopped_limit", slot.stopped_limit},
- {"stopping_word", slot.stopping_word},
- };
-
- if (slot_data["state"] == SLOT_STATE_IDLE) {
- n_idle_slots++;
- } else {
+ json slot_data = slot.to_json();
+
+ if (slot.is_processing()) {
n_processing_slots++;
+ } else {
+ n_idle_slots++;
}

slots_data.push_back(slot_data);
}
SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);

-
- res
- res
- res
- res
-
-
- { "deferred", queue_tasks.queue_tasks_deferred.size() },
- { "t_start", metrics.t_start},
-
- { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
- { "t_tokens_generation_total", metrics.t_tokens_generation_total},
- { "n_tokens_predicted_total", metrics.n_tokens_predicted_total},
- { "t_prompt_processing_total", metrics.t_prompt_processing_total},
+ auto res = std::make_unique<server_task_result_metrics>();
+ res->id = task.id;
+ res->slots_data = std::move(slots_data);
+ res->n_idle_slots = n_idle_slots;
+ res->n_processing_slots = n_processing_slots;
+ res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
+ res->t_start = metrics.t_start;

-
-
- { "n_tokens_predicted", metrics.n_tokens_predicted},
- { "t_tokens_generation", metrics.t_tokens_generation},
+ res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
+ res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);

-
-
+ res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
+ res->t_prompt_processing_total = metrics.t_prompt_processing_total;
+ res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
+ res->t_tokens_generation_total = metrics.t_tokens_generation_total;

-
-
+ res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
+ res->t_prompt_processing = metrics.t_prompt_processing;
+ res->n_tokens_predicted = metrics.n_tokens_predicted;
+ res->t_tokens_generation = metrics.t_tokens_generation;

-
-
+ res->n_decode_total = metrics.n_decode_total;
+ res->n_busy_slots_total = metrics.n_busy_slots_total;

- if (
+ if (task.metrics_reset_bucket) {
metrics.reset_bucket();
}
- queue_results.send(res);
+ queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SLOT_SAVE:
{
- int id_slot = task.
+ int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1747,32 +2367,27 @@ struct server_context {
const size_t token_count = slot->cache_tokens.size();
const int64_t t_start = ggml_time_us();

- std::string filename = task.
- std::string filepath = task.
+ std::string filename = task.slot_action.filename;
+ std::string filepath = task.slot_action.filepath;

- const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);

const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0;

-
-
-
-
-
-
-
-
-
- { "timings", {
- { "save_ms", t_save_ms }
- } }
- };
- queue_results.send(result);
+ auto res = std::make_unique<server_task_result_slot_save_load>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->filename = filename;
+ res->is_save = true;
+ res->n_tokens = token_count;
+ res->n_bytes = nwrite;
+ res->t_ms = t_save_ms;
+ queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SLOT_RESTORE:
{
- int id_slot = task.
+ int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1787,12 +2402,12 @@ struct server_context {

const int64_t t_start = ggml_time_us();

- std::string filename = task.
- std::string filepath = task.
+ std::string filename = task.slot_action.filename;
+ std::string filepath = task.slot_action.filepath;

slot->cache_tokens.resize(slot->n_ctx);
size_t token_count = 0;
- size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
if (nread == 0) {
slot->cache_tokens.resize(0);
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
@@ -1803,24 +2418,19 @@ struct server_context {
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;

-
-
-
-
-
-
-
-
-
- { "timings", {
- { "restore_ms", t_restore_ms }
- } }
- };
- queue_results.send(result);
+ auto res = std::make_unique<server_task_result_slot_save_load>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->filename = filename;
+ res->is_save = false;
+ res->n_tokens = token_count;
+ res->n_bytes = nread;
+ res->t_ms = t_restore_ms;
+ queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SLOT_ERASE:
{
- int id_slot = task.
+ int id_slot = task.slot_action.slot_id;
server_slot * slot = get_slot_by_id(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1835,37 +2445,26 @@ struct server_context {

// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
- llama_kv_cache_seq_rm(ctx, slot->id
+ llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
slot->cache_tokens.clear();

-
-
-
-
-
- { "id_slot", id_slot },
- { "n_erased", n_erased }
- };
- queue_results.send(result);
+ auto res = std::make_unique<server_task_result_slot_erase>();
+ res->id = task.id;
+ res->id_slot = id_slot;
+ res->n_erased = n_erased;
+ queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SET_LORA:
{
-
-
-
-
- result.error = false;
- result.data = json{{ "success", true }};
- queue_results.send(result);
+ common_lora_adapters_apply(ctx, loras);
+ auto res = std::make_unique<server_task_result_apply_lora>();
+ res->id = task.id;
+ queue_results.send(std::move(res));
} break;
}
}

void update_slots() {
- if (system_need_update) {
- system_prompt_update();
- }
-
// check if all slots are idle
{
bool all_idle = true;
@@ -1879,7 +2478,7 @@ struct server_context {

if (all_idle) {
SRV_INF("%s", "all slots are idle\n");
- if (
+ if (clean_kv_cache) {
kv_cache_clear();
}

@@ -1890,53 +2489,49 @@ struct server_context {
{
SRV_DBG("%s", "posting NEXT_RESPONSE\n");

- server_task task;
- task.
- task.id_target = -1;
-
+ server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
+ task.id = queue_tasks.get_new_id();
queue_tasks.post(task);
}

// apply context-shift if needed
// TODO: simplify and improve
for (server_slot & slot : slots) {
- if (slot.
- if (
-
-
-
-
-
-
- }
-
- // Shift context
- const int n_keep = slot.params.n_keep + add_bos_token;
- const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
- const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
+ if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
+ if (!params_base.ctx_shift) {
+ // this check is redundant (for good)
+ // we should never get here, because generation should already stopped in process_token()
+ slot.release();
+ send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+ continue;
+ }

-
+ // Shift context
+ const int n_keep = slot.params.n_keep + add_bos_token;
+ const int n_left = slot.n_past - n_keep;
+ const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);

-
- llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+ SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);

-
-
- slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
- }
+ llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
+ llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);

-
+ if (slot.params.cache_prompt) {
+ for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
+ slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
}

- slot.
-
- slot.truncated = true;
+ slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
}
+
+ slot.n_past -= n_discard;
+
+ slot.truncated = true;
}
}

// start populating the batch for this iteration
-
+ common_batch_clear(batch);

// frist, add sampled tokens from any ongoing sequences
for (auto & slot : slots) {
@@ -1946,11 +2541,7 @@ struct server_context {

slot.i_batch = batch.n_tokens;

-
-
- // TODO: we always have to take into account the "system_tokens"
- // this is not great and needs to be improved somehow
- llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
+ common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);

slot.n_past += 1;

@@ -1958,8 +2549,8 @@ struct server_context {
slot.cache_tokens.push_back(slot.sampled);
}

- SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d,
- slot.n_ctx, slot.n_past, (int)
+ SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
+ slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
}

// process in chunks of params.n_batch
@@ -1973,82 +2564,35 @@ struct server_context {
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

// next, batch any pending prompts without exceeding n_batch
- if (
+ if (params_base.cont_batching || batch.n_tokens == 0) {
for (auto & slot : slots) {
// this slot still has a prompt to be processed
- if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
+ if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
auto & prompt_tokens = slot.prompt_tokens;

- //
- if (
-
-
-
- slot.t_start_generation = 0;
-
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
- const bool add_bos = llama_add_bos_token(model);
- bool suff_rm_leading_spc = true;
- if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
- params.input_suffix.erase(0, 1);
- suff_rm_leading_spc = false;
- }
-
- auto prefix_tokens = tokenize(slot.params.input_prefix, false);
- auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
- const int space_token = 29871; // TODO: this should not be hardcoded
- if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
- suffix_tokens.erase(suffix_tokens.begin());
- }
-
- prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
- suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
-
- auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
- auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
- if (add_bos) {
- embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
- }
- embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
-
- const llama_token middle_token = llama_token_middle(model);
- if (middle_token >= 0) {
- embd_inp.push_back(middle_token);
- }
-
- prompt_tokens = embd_inp;
- } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
- // require slot.prompt to be array of 2 strings
- if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
- SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
- slot.release();
- send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
- continue;
- }
-
- // prompt: [BOS]query[EOS][SEP]doc[EOS]
- prompt_tokens.clear();
- prompt_tokens.push_back(llama_token_bos(model));
- {
- const auto part = tokenize(slot.prompt[0], false);
- prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
- }
- prompt_tokens.push_back(llama_token_eos(model));
- prompt_tokens.push_back(llama_token_sep(model));
- {
- const auto part = tokenize(slot.prompt[1], false);
- prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
- }
- prompt_tokens.push_back(llama_token_eos(model));
- } else {
- prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
- }
-
+ // TODO: maybe move branch to outside of this loop in the future
+ if (slot.state == SLOT_STATE_STARTED) {
+ slot.t_start_process_prompt = ggml_time_us();
+ slot.t_start_generation = 0;
+
slot.n_past = 0;
slot.n_prompt_tokens = prompt_tokens.size();
+ slot.state = SLOT_STATE_PROCESSING_PROMPT;
+
+ SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);

-
+ // print prompt tokens (for debugging)
+ if (1) {
+ // first 16 tokens (avoid flooding logs)
+ for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ }
+ } else {
+ // all
+ for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ }
+ }

// empty prompt passed -> release the slot and send empty response
if (prompt_tokens.empty()) {
@@ -2060,17 +2604,24 @@ struct server_context {
continue;
}

- if (slot.
- // this prompt is too large to process - discard it
+ if (slot.is_non_causal()) {
if (slot.n_prompt_tokens > n_ubatch) {
slot.release();
send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
continue;
}
+
+ if (slot.n_prompt_tokens > slot.n_ctx) {
+ slot.release();
+ send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
+ continue;
+ }
} else {
- if (!
+ if (!params_base.ctx_shift) {
// if context shift is disabled, we make sure prompt size is smaller than KV size
-
+ // TODO: there should be a separate parameter that control prompt truncation
+ // context shift should be applied only during the generation phase
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
slot.release();
send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
continue;
@@ -2081,14 +2632,14 @@ struct server_context {
}
slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);

- // if input prompt is too big, truncate it
- if (slot.
+ // if input prompt is too big, truncate it
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
const int n_left = slot.n_ctx - slot.params.n_keep;

const int n_block_size = n_left / 2;
const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;

-
+ llama_tokens new_tokens(
prompt_tokens.begin(),
prompt_tokens.begin() + slot.params.n_keep);

@@ -2107,20 +2658,52 @@ struct server_context {
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
}

-
+ if (slot.params.cache_prompt) {
+ // reuse any previously computed tokens that are common with the new prompt
+ slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);

-
-
-
-
- GGML_ASSERT(slot.ga_n == 1);
+ // reuse chunks from the cached prompt by shifting their KV cache in the new position
+ if (params_base.n_cache_reuse > 0) {
+ size_t head_c = slot.n_past; // cache
+ size_t head_p = slot.n_past; // current prompt

-
-
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
+
+ while (head_c < slot.cache_tokens.size() &&
+ head_p < prompt_tokens.size()) {
+
+ size_t n_match = 0;
+ while (head_c + n_match < slot.cache_tokens.size() &&
+ head_p + n_match < prompt_tokens.size() &&
+ slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {

-
-
-
+ n_match++;
+ }
+
+ if (n_match >= (size_t) params_base.n_cache_reuse) {
+ SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+ //for (size_t i = head_p; i < head_p + n_match; i++) {
+ // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ //}
+
+ const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+ llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
+ llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
+
+ for (size_t i = 0; i < n_match; i++) {
+ slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+ slot.n_past++;
+ }
+
+ head_c += n_match;
+ head_p += n_match;
+ } else {
+ head_c += 1;
+ }
+ }
+
+ SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
}
}
}
@@ -2130,16 +2713,13 @@ struct server_context {
SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);

slot.n_past--;
- if (slot.ga_i > 0) {
- slot.n_past_se--;
- }
}

slot.n_prompt_tokens_processed = 0;
}

// non-causal tasks require to fit the entire prompt in the physical batch
- if (slot.
+ if (slot.is_non_causal()) {
// cannot fit the prompt in the current batch - will try next iter
if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
continue;
@@ -2147,10 +2727,7 @@ struct server_context {
}

// check that we are in the right batch_type, if not defer the slot
-
- slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
- slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
-
+ int slot_type = slot.is_non_causal();
if (batch_type == -1) {
batch_type = slot_type;
} else if (batch_type != slot_type) {
@@ -2158,55 +2735,32 @@ struct server_context {
}

// keep only the common part
-
- if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+ if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
// could not partially delete (likely using a non-Transformer model)
- llama_kv_cache_seq_rm(ctx, slot.id
+ llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);

-
- if (p0 != 0) {
- // copy over the system prompt when there is one
- llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
- }
-
- // there is no common part left (except for the system prompt)
+ // there is no common part left
slot.n_past = 0;
- slot.n_past_se = 0;
- slot.ga_i = 0;
- // TODO: is the system prompt ever in the sampling context?
- gpt_sampler_reset(slot.smpl);
}

+ SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+
// remove the non-common part from the cache
slot.cache_tokens.resize(slot.n_past);

- SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
-
- int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
- int32_t ga_i = slot.ga_i;
- int32_t ga_n = slot.ga_n;
- int32_t ga_w = slot.ga_w;
-
// add prompt tokens for processing in the current batch
-
-
-
- while (slot_npast >= ga_i + ga_w) {
- const int bd = (ga_w/ga_n)*(ga_n - 1);
- slot_npast -= bd;
- ga_i += ga_w/ga_n;
- }
- }
+ while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+ // without pooling, we want to output the embeddings for all the tokens in the batch
+ const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;

-
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);

if (slot.params.cache_prompt) {
slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
}

slot.n_prompt_tokens_processed++;
-
+ slot.n_past++;
}

SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
@@ -2217,6 +2771,13 @@ struct server_context {

GGML_ASSERT(batch.n_tokens > 0);

+ common_sampler_reset(slot.smpl);
+
+ // Process all prompt tokens through sampler system
+ for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+ common_sampler_accept(slot.smpl, prompt_tokens[i], false);
+ }
+
// extract the logits only for the last token
batch.logits[batch.n_tokens - 1] = true;

@@ -2247,34 +2808,6 @@ struct server_context {
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

- for (auto & slot : slots) {
- if (slot.ga_n != 1) {
- // context extension via Self-Extend
- // TODO: simplify and/or abstract this
- while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
- const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
- const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
- const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
- SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
- SLT_DBG(slot, "div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
- SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
- llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
-
- slot.n_past_se -= bd;
-
- slot.ga_i += slot.ga_w / slot.ga_n;
-
- SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
- }
-
- slot.n_past_se += n_tokens;
- }
- }
-
llama_batch batch_view = {
n_tokens,
batch.token + i,
@@ -2283,7 +2816,6 @@ struct server_context {
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
- 0, 0, 0, // unused
};

const int ret = llama_decode(ctx, batch_view);
@@ -2315,7 +2847,7 @@ struct server_context {
}

if (slot.state == SLOT_STATE_DONE_PROMPT) {
- if (slot.
+ if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
// prompt evaluated for embedding
send_embedding(slot, batch_view);
slot.release();
@@ -2323,7 +2855,7 @@ struct server_context {
continue; // continue loop of slots
}

- if (slot.
+ if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
send_rerank(slot, batch_view);
slot.release();
slot.i_batch = -1;
@@ -2336,27 +2868,33 @@ struct server_context {
continue; // continue loop of slots
}

-
-
+ const int tok_idx = slot.i_batch - i;
+
+ llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);

-
+ slot.i_batch = -1;
+
+ common_sampler_accept(slot.smpl, id, true);

slot.n_decoded += 1;
+
+ const int64_t t_current = ggml_time_us();
+
if (slot.n_decoded == 1) {
- slot.t_start_generation =
+ slot.t_start_generation = t_current;
slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
metrics.on_prompt_eval(slot);
}

-
+ slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

-
+ completion_token_output result;
+ result.tok = id;
+ result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+ result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs

-
- result.
- cur_p->data[i].id,
- i >= cur_p->size ? 0.0f : cur_p->data[i].p,
- });
+ if (slot.params.sampling.n_probs > 0) {
+ populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
}

if (!process_token(result, slot)) {
@@ -2365,9 +2903,98 @@ struct server_context {
slot.print_timings();
send_final_response(slot);
metrics.on_prediction(slot);
+ continue;
+ }
+ }
+
+ // do speculative decoding
+ for (auto & slot : slots) {
+ if (!slot.is_processing() || !slot.can_speculate()) {
+ continue;
}

- slot.
+ if (slot.state != SLOT_STATE_GENERATING) {
+ continue;
+ }
+
+ // determine the max draft that fits the current slot state
+ int n_draft_max = slot.params.speculative.n_max;
+
+ // note: n_past is not yet increased for the `id` token sampled above
+ // also, need to leave space for 1 extra token to allow context shifts
+ n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+ if (slot.n_remaining > 0) {
+ n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+ }
+
+ SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+ if (n_draft_max < slot.params.speculative.n_min) {
+ SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+
+ continue;
+ }
+
+ llama_token id = slot.sampled;
+
+ struct common_speculative_params params_spec;
+ params_spec.n_draft = n_draft_max;
+ params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+ params_spec.p_min = slot.params.speculative.p_min;
+
+ llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+
+ // ignore small drafts
+ if (slot.params.speculative.n_min > (int) draft.size()) {
+ SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+
+ continue;
+ }
+
+ // construct the speculation batch
+ common_batch_clear(slot.batch_spec);
+ common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+ for (size_t i = 0; i < draft.size(); ++i) {
+ common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+ }
+
+ SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+ llama_decode(ctx, slot.batch_spec);
+
+ // the accepted tokens from the speculation
+ const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+ slot.n_past += ids.size();
+ slot.n_decoded += ids.size();
+
+ slot.cache_tokens.push_back(id);
+ slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
+ llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+ for (size_t i = 0; i < ids.size(); ++i) {
+ completion_token_output result;
+
+ result.tok = ids[i];
+ result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+ result.prob = 1.0f; // set later
+
+ // TODO: set result.probs
+
+ if (!process_token(result, slot)) {
+ // release slot because of stop condition
+ slot.release();
+ slot.print_timings();
+ send_final_response(slot);
+ metrics.on_prediction(slot);
+ break;
+ }
+ }
+
+ SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
}
}

@@ -2414,35 +3041,23 @@ inline void signal_handler(int signal) {

int main(int argc, char ** argv) {
// own arguments required by this example
-
+ common_params params;

- if (!
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
return 1;
}

-
-
- // enabling this will output extra debug information in the HTTP responses from the server
- // see format_final_response_oaicompat()
- const bool verbose = params.verbosity > 9;
+ common_init();

// struct that contains llama context and inference
server_context ctx_server;

- if (!params.system_prompt.empty()) {
- ctx_server.system_prompt_set(params.system_prompt);
- }
-
- if (params.model_alias == "unknown") {
- params.model_alias = params.model;
- }
-
llama_backend_init();
llama_numa_init(params.numa);

LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
LOG_INF("\n");
- LOG_INF("%s\n",
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
LOG_INF("\n");

std::unique_ptr<httplib::Server> svr;
@@ -2467,34 +3082,24 @@ int main(int argc, char ** argv) {
std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};

svr->set_default_headers({{"Server", "llama.cpp"}});
-
- // CORS preflight
- svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
- // Access-Control-Allow-Origin is already set by middleware
- res.set_header("Access-Control-Allow-Credentials", "true");
- res.set_header("Access-Control-Allow-Methods", "POST");
- res.set_header("Access-Control-Allow-Headers", "*");
- return res.set_content("", "text/html"); // blank response, no data
- });
-
svr->set_logger(log_server_request);

auto res_error = [](httplib::Response & res, const json & error_data) {
json final_response {{"error", error_data}};
- res.set_content(final_response
+ res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
res.status = json_value(error_data, "code", 500);
};

auto res_ok = [](httplib::Response & res, const json & data) {
- res.set_content(data
+ res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
res.status = 200;
};

- svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
+ svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
std::string message;
try {
std::rethrow_exception(ep);
- } catch (std::exception & e) {
+ } catch (const std::exception & e) {
message = e.what();
} catch (...) {
message = "Unknown Exception";
@@ -2536,20 +3141,10 @@ int main(int argc, char ** argv) {
//

auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
-
-
- "/
- "/
- "/completions",
- "/v1/completions",
- "/chat/completions",
- "/v1/chat/completions",
- "/infill",
- "/tokenize",
- "/detokenize",
- "/embedding",
- "/embeddings",
- "/v1/embeddings",
+ static const std::unordered_set<std::string> public_endpoints = {
+ "/health",
+ "/models",
+ "/v1/models",
};

// If API key is not set, skip validation
@@ -2557,8 +3152,8 @@ int main(int argc, char ** argv) {
return true;
}

- // If path is
- if (
+ // If path is public or is static file, skip validation
+ if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
return true;
}

@@ -2584,7 +3179,7 @@ int main(int argc, char ** argv) {
auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
server_state current_state = state.load();
if (current_state == SERVER_STATE_LOADING_MODEL) {
- auto tmp = string_split(req.path, '.');
+ auto tmp = string_split<std::string>(req.path, '.');
if (req.path == "/" || tmp.back() == "html") {
res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
res.status = 503;
@@ -2599,6 +3194,14 @@ int main(int argc, char ** argv) {
// register server middlewares
svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ // If this is OPTIONS request, skip validation because browsers don't include Authorization header
+ if (req.method == "OPTIONS") {
+ res.set_header("Access-Control-Allow-Credentials", "true");
+ res.set_header("Access-Control-Allow-Methods", "GET, POST");
+ res.set_header("Access-Control-Allow-Headers", "*");
+ res.set_content("", "text/html"); // blank response, no data
+ return httplib::Server::HandlerResponse::Handled; // skip further processing
+ }
if (!middleware_server_state(req, res)) {
return httplib::Server::HandlerResponse::Handled;
}
@@ -2620,32 +3223,38 @@ int main(int argc, char ** argv) {

const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
if (!params.endpoint_slots) {
- res_error(res, format_error_response("This server does not support slots endpoint. Start it
+ res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
return;
}

// request slots data using task queue
- server_task task;
+ server_task task(SERVER_TASK_TYPE_METRICS);
task.id = ctx_server.queue_tasks.get_new_id();
- task.type = SERVER_TASK_TYPE_METRICS;
-
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task, true); // high-priority task

// get the result
-
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);

+ if (result->is_error()) {
+ res_error(res, result->to_json());
+ return;
+ }
+
+ // TODO: get rid of this dynamic_cast
+ auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
+ GGML_ASSERT(res_metrics != nullptr);
+
// optionally return "fail_on_no_slot" error
- const int n_idle_slots = result.data.at("idle");
if (req.has_param("fail_on_no_slot")) {
- if (n_idle_slots == 0) {
+ if (res_metrics->n_idle_slots == 0) {
res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
return;
}
}

- res_ok(res,
+ res_ok(res, res_metrics->slots_data);
};

const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
@@ -2655,83 +3264,77 @@ int main(int argc, char ** argv) {
}

// request slots data using task queue
- server_task task;
+ server_task task(SERVER_TASK_TYPE_METRICS);
task.id = ctx_server.queue_tasks.get_new_id();
- task.
- task.type = SERVER_TASK_TYPE_METRICS;
- task.data.push_back({{"reset_bucket", true}});
+ task.metrics_reset_bucket = true;

ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task, true); // high-priority task

// get the result
-
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);

-
-
-
-
-
- const uint64_t n_tokens_predicted = data.at("n_tokens_predicted");
- const uint64_t t_tokens_generation = data.at("t_tokens_generation");
-
- const uint64_t n_decode_total = data.at("n_decode_total");
- const uint64_t n_busy_slots_total = data.at("n_busy_slots_total");
+ if (result->is_error()) {
+ res_error(res, result->to_json());
+ return;
+ }

-
+ // TODO: get rid of this dynamic_cast
+ auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
+ GGML_ASSERT(res_metrics != nullptr);

// metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
json all_metrics_def = json {
{"counter", {{
{"name", "prompt_tokens_total"},
{"help", "Number of prompt tokens processed."},
- {"value", (uint64_t)
+ {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total}
}, {
{"name", "prompt_seconds_total"},
{"help", "Prompt process time"},
- {"value", (uint64_t)
+ {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3}
}, {
{"name", "tokens_predicted_total"},
{"help", "Number of generation tokens processed."},
- {"value", (uint64_t)
+ {"value", (uint64_t) res_metrics->n_tokens_predicted_total}
}, {
{"name", "tokens_predicted_seconds_total"},
{"help", "Predict process time"},
- {"value", (uint64_t)
+ {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3}
}, {
{"name", "n_decode_total"},
{"help", "Total number of llama_decode() calls"},
- {"value", n_decode_total}
+ {"value", res_metrics->n_decode_total}
}, {
{"name", "n_busy_slots_per_decode"},
{"help", "Average number of busy slots per llama_decode() call"},
- {"value", (float) n_busy_slots_total / (float) n_decode_total}
+ {"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total}
}}},
{"gauge", {{
{"name", "prompt_tokens_seconds"},
{"help", "Average prompt throughput in tokens/s."},
- {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
+ {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.}
},{
{"name", "predicted_tokens_seconds"},
{"help", "Average generation throughput in tokens/s."},
- {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
+ {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
{"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
|
|
2719
3322
|
},{
|
|
2720
3323
|
{"name", "kv_cache_usage_ratio"},
|
|
2721
3324
|
{"help", "KV-cache usage. 1 means 100 percent usage."},
|
|
2722
|
-
{"value", 1. * kv_cache_used_cells / params.n_ctx}
|
|
3325
|
+
{"value", 1. * res_metrics->kv_cache_used_cells / params.n_ctx}
|
|
2723
3326
|
},{
|
|
2724
3327
|
{"name", "kv_cache_tokens"},
|
|
2725
3328
|
{"help", "KV-cache tokens."},
|
|
2726
|
-
{"value", (uint64_t)
|
|
3329
|
+
{"value", (uint64_t) res_metrics->kv_cache_tokens_count}
|
|
2727
3330
|
},{
|
|
2728
3331
|
{"name", "requests_processing"},
|
|
2729
3332
|
{"help", "Number of request processing."},
|
|
2730
|
-
{"value", (uint64_t)
|
|
3333
|
+
{"value", (uint64_t) res_metrics->n_processing_slots}
|
|
2731
3334
|
},{
|
|
2732
3335
|
{"name", "requests_deferred"},
|
|
2733
3336
|
{"help", "Number of request deferred."},
|
|
2734
|
-
{"value", (uint64_t)
|
|
3337
|
+
{"value", (uint64_t) res_metrics->n_tasks_deferred}
|
|
2735
3338
|
}}}
|
|
2736
3339
|
};
|
|
2737
3340
|
|
|
@@ -2752,8 +3355,7 @@ int main(int argc, char ** argv) {
 }
 }

-
-res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
+res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start));

 res.set_content(prometheus.str(), "text/plain; version=0.0.4");
 res.status = 200; // HTTP OK
@@ -2768,25 +3370,24 @@ int main(int argc, char ** argv) {
 }
 std::string filepath = params.slot_save_path + filename;

-server_task task;
-task.
-task.
-
-
-{ "filepath", filepath },
-};
+server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
+task.id = ctx_server.queue_tasks.get_new_id();
+task.slot_action.slot_id = id_slot;
+task.slot_action.filename = filename;
+task.slot_action.filepath = filepath;

-
-ctx_server.
+ctx_server.queue_results.add_waiting_task_id(task.id);
+ctx_server.queue_tasks.post(task);

-
-ctx_server.queue_results.remove_waiting_task_id(
+server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ctx_server.queue_results.remove_waiting_task_id(task.id);

-if (result
-res_error(res, result
-
-res_ok(res, result.data);
+if (result->is_error()) {
+res_error(res, result->to_json());
+return;
 }
+
+res_ok(res, result->to_json());
 };

 const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
@@ -2798,45 +3399,45 @@ int main(int argc, char ** argv) {
 }
 std::string filepath = params.slot_save_path + filename;

-server_task task;
-task.
-task.
-
-
-{ "filepath", filepath },
-};
+server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+task.id = ctx_server.queue_tasks.get_new_id();
+task.slot_action.slot_id = id_slot;
+task.slot_action.filename = filename;
+task.slot_action.filepath = filepath;

-
-ctx_server.
+ctx_server.queue_results.add_waiting_task_id(task.id);
+ctx_server.queue_tasks.post(task);

-
-ctx_server.queue_results.remove_waiting_task_id(
+server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ctx_server.queue_results.remove_waiting_task_id(task.id);

-if (result
-res_error(res, result
-
-res_ok(res, result.data);
+if (result->is_error()) {
+res_error(res, result->to_json());
+return;
 }
+
+GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
+res_ok(res, result->to_json());
 };

 const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
-server_task task;
-task.
-task.
-{ "id_slot", id_slot },
-};
+server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
+task.id = ctx_server.queue_tasks.get_new_id();
+task.slot_action.slot_id = id_slot;

-
-ctx_server.
+ctx_server.queue_results.add_waiting_task_id(task.id);
+ctx_server.queue_tasks.post(task);

-
-ctx_server.queue_results.remove_waiting_task_id(
+server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ctx_server.queue_results.remove_waiting_task_id(task.id);

-if (result
-res_error(res, result
-
-res_ok(res, result.data);
+if (result->is_error()) {
+res_error(res, result->to_json());
+return;
 }
+
+GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
+res_ok(res, result->to_json());
 };

 const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
@@ -2869,31 +3470,74 @@ int main(int argc, char ** argv) {
 };

 const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
-
-int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
-if (tlen > 0) {
-std::vector<char> curr_tmpl_buf(tlen + 1, 0);
-if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
-curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
-}
-}
+// this endpoint is publicly available, please only return what is safe to be exposed
 json data = {
-{ "system_prompt", ctx_server.system_prompt.c_str() },
 { "default_generation_settings", ctx_server.default_generation_settings_for_props },
-{ "total_slots", ctx_server.
-{ "
+{ "total_slots", ctx_server.params_base.n_parallel },
+{ "model_path", ctx_server.params_base.model },
+{ "chat_template", llama_get_chat_template(ctx_server.model) },
 };

 res_ok(res, data);
 };

-const auto
-if (ctx_server.
-res_error(res, format_error_response("This server does not support
+const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+if (!ctx_server.params_base.endpoint_props) {
+res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
+return;
+}
+
+json data = json::parse(req.body);
+
+// update any props here
+
+res_ok(res, {{ "success", true }});
+};
+
+// handle completion-like requests (completion, chat, infill)
+// we can optionally provide a custom format for partial results and final results
+const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
+server_task_type type,
+json & data,
+httplib::Response & res,
+bool oaicompat = false,
+bool oaicompat_chat = false) {
+GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
+
+if (ctx_server.params_base.embedding) {
+res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+return;
+}
+
+auto completion_id = gen_chatcmplid();
+std::vector<server_task> tasks;
+
+try {
+std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
+tasks.reserve(tokenized_prompts.size());
+for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+server_task task = server_task(type);
+
+task.id = ctx_server.queue_tasks.get_new_id();
+task.index = i;
+
+task.prompt_tokens = std::move(tokenized_prompts[i]);
+task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
+task.id_selected_slot = json_value(data, "id_slot", -1);
+
+// OAI-compat
+task.params.oaicompat = oaicompat;
+task.params.oaicompat_chat = oaicompat_chat;
+task.params.oaicompat_cmpl_id = completion_id;
+// oaicompat_model is already populated by params_from_json_cmpl
+
+tasks.push_back(task);
+}
+} catch (const std::exception & e) {
+res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
 return;
 }

-std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
 ctx_server.queue_results.add_waiting_tasks(tasks);
 ctx_server.queue_tasks.post(tasks);

@@ -2901,15 +3545,15 @@ int main(int argc, char ** argv) {
 const auto task_ids = server_task::get_list_id(tasks);

 if (!stream) {
-ctx_server.
+ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
 if (results.size() == 1) {
 // single result
-res_ok(res, results[0]
+res_ok(res, results[0]->to_json());
 } else {
 // multiple results (multitask)
 json arr = json::array();
-for (
-arr.push_back(res
+for (auto & res : results) {
+arr.push_back(res->to_json());
 }
 res_ok(res, arr);
 }
@@ -2919,12 +3563,26 @@ int main(int argc, char ** argv) {

 ctx_server.queue_results.remove_waiting_task_ids(task_ids);
 } else {
-const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
-ctx_server.receive_cmpl_results_stream(task_ids, [&](
-
+const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
+ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
+json res_json = result->to_json();
+if (res_json.is_array()) {
+for (const auto & res : res_json) {
+if (!server_sent_event(sink, "data", res)) {
+return false;
+}
+}
+return true;
+} else {
+return server_sent_event(sink, "data", res_json);
+}
 }, [&](const json & error_data) {
 server_sent_event(sink, "error", error_data);
 });
+if (oaicompat) {
+static const std::string ev_done = "data: [DONE]\n\n";
+sink.write(ev_done.data(), ev_done.size());
+}
 sink.done();
 return false;
 };
@@ -2939,72 +3597,102 @@ int main(int argc, char ** argv) {

 const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
 json data = json::parse(req.body);
-return handle_completions_generic(
-
-
-
-
-
+return handle_completions_generic(
+SERVER_TASK_TYPE_COMPLETION,
+data,
+res,
+/* oaicompat */ false,
+/* oaicompat_chat */ false);
 };

-
-
-
-
+const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+// check model compatibility
+std::string err;
+if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
+err += "prefix token is missing. ";
+}
+if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
+err += "suffix token is missing. ";
+}
+if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
+err += "middle token is missing. ";
+}
+if (!err.empty()) {
+res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
 return;
 }

-json data =
+json data = json::parse(req.body);

-
-
-
+// validate input
+if (data.contains("prompt") && !data.at("prompt").is_string()) {
+// prompt is optional
+res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+}

-
-
-
+if (!data.contains("input_prefix")) {
+res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
+}

-if (!
-
-
-json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
-res_ok(res, result_oai);
-}, [&](const json & error_data) {
-res_error(res, error_data);
-});
+if (!data.contains("input_suffix")) {
+res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
+}

-
-
-
-
-
-for (auto & event_data : result_array) {
-if (event_data.empty()) {
-continue; // skip the stop token
-}
-if (!server_sent_event(sink, "data", event_data)) {
-return false; // connection is closed
-}
-}
-return true; // ok
-}, [&](const json & error_data) {
-server_sent_event(sink, "error", error_data);
-});
-static const std::string ev_done = "data: [DONE]\n\n";
-sink.write(ev_done.data(), ev_done.size());
-sink.done();
-return true;
-};
+if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
+// input_extra is optional
+res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
+return;
+}

-
-
-}
+json input_extra = json_value(data, "input_extra", json::array());
+for (const auto & chunk : input_extra) {
+// { "text": string, "filename": string }
+if (!chunk.contains("text") || !chunk.at("text").is_string()) {
+res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
+return;
+}
+// filename is optional
+if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
+res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
+return;
+}
+}
+data["input_extra"] = input_extra; // default to empty array if it's not exist
+
+std::string prompt = json_value(data, "prompt", std::string());
+std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+data["prompt"] = format_infill(
+ctx_server.ctx,
+data.at("input_prefix"),
+data.at("input_suffix"),
+data.at("input_extra"),
+ctx_server.params_base.n_batch,
+ctx_server.params_base.n_predict,
+ctx_server.slots[0].n_ctx, // TODO: there should be a better way
+ctx_server.params_base.spm_infill,
+tokenized_prompts[0]
+);

-
+return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
+};
+
+const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+if (ctx_server.params_base.embedding) {
+res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+return;
 }
+
+json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
+return handle_completions_generic(
+SERVER_TASK_TYPE_COMPLETION,
+data,
+res,
+/* oaicompat */ true,
+/* oaicompat_chat */ true);
 };

-const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
+const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
 json models = {
 {"object", "list"},
 {"data", {
@@ -3018,7 +3706,7 @@ int main(int argc, char ** argv) {
 }}
 };

-res
+res_ok(res, models);
 };

 const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
@@ -3028,11 +3716,12 @@ int main(int argc, char ** argv) {
 if (body.count("content") != 0) {
 const bool add_special = json_value(body, "add_special", false);
 const bool with_pieces = json_value(body, "with_pieces", false);
-
+
+llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);

 if (with_pieces) {
 for (const auto& token : tokens) {
-std::string piece =
+std::string piece = common_token_to_piece(ctx_server.ctx, token);
 json piece_json;

 // Check if the piece is valid UTF-8
@@ -3065,7 +3754,7 @@ int main(int argc, char ** argv) {

 std::string content;
 if (body.count("tokens") != 0) {
-const
+const llama_tokens tokens = body.at("tokens");
 content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
 }

@@ -3073,42 +3762,63 @@ int main(int argc, char ** argv) {
 res_ok(res, data);
 };

-const auto
-
-
-
+const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
+const json body = json::parse(req.body);
+
+if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
 return;
 }
-const json body = json::parse(req.body);
-bool is_openai = false;

-//
+// for the shape of input/content, see tokenize_input_prompts()
 json prompt;
 if (body.count("input") != 0) {
-is_openai = true;
 prompt = body.at("input");
-} else if (body.
-
-prompt =
+} else if (body.contains("content")) {
+oaicompat = false;
+prompt = body.at("content");
 } else {
 res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
 return;
 }

+std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+for (const auto & tokens : tokenized_prompts) {
+// this check is necessary for models that do not add BOS token to the input
+if (tokens.empty()) {
+res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+return;
+}
+}
+
 // create and queue the task
 json responses = json::array();
 bool error = false;
 {
-std::vector<server_task> tasks
+std::vector<server_task> tasks;
+for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
+task.id = ctx_server.queue_tasks.get_new_id();
+task.index = i;
+task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+// OAI-compat
+task.params.oaicompat = oaicompat;
+
+tasks.push_back(task);
+}
+
 ctx_server.queue_results.add_waiting_tasks(tasks);
 ctx_server.queue_tasks.post(tasks);

 // get the result
 std::unordered_set<int> task_ids = server_task::get_list_id(tasks);

-ctx_server.
-for (
-
+ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+for (auto & res : results) {
+GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
+responses.push_back(res->to_json());
 }
 }, [&](const json & error_data) {
 res_error(res, error_data);
@@ -3123,17 +3833,24 @@ int main(int argc, char ** argv) {
 }

 // write JSON response
-json root =
-? format_embeddings_response_oaicompat(body, responses)
-: responses[0];
+json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
 res_ok(res, root);
 };

+const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+handle_embeddings_impl(req, res, false);
+};
+
+const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+handle_embeddings_impl(req, res, true);
+};
+
 const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-if (!ctx_server.
-res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
+res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
 return;
 }
+
 const json body = json::parse(req.body);

 // TODO: implement
@@ -3163,29 +3880,33 @@ int main(int argc, char ** argv) {
 return;
 }

-
-json prompt;
-prompt.push_back(query);
-for (const auto & doc : documents) {
-prompt.push_back(doc);
-}
-
-LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0];

 // create and queue the task
 json responses = json::array();
 bool error = false;
 {
-std::vector<server_task> tasks
+std::vector<server_task> tasks;
+std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
+tasks.reserve(tokenized_docs.size());
+for (size_t i = 0; i < tokenized_docs.size(); i++) {
+server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+task.id = ctx_server.queue_tasks.get_new_id();
+task.index = i;
+task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
+tasks.push_back(task);
+}
+
 ctx_server.queue_results.add_waiting_tasks(tasks);
 ctx_server.queue_tasks.post(tasks);

 // get the result
 std::unordered_set<int> task_ids = server_task::get_list_id(tasks);

-ctx_server.
-for (
-
+ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+for (auto & res : results) {
+GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
+responses.push_back(res->to_json());
 }
 }, [&](const json & error_data) {
 res_error(res, error_data);
@@ -3236,59 +3957,59 @@ int main(int argc, char ** argv) {
 }
 }

-server_task task;
-task.
-
-ctx_server.
+server_task task(SERVER_TASK_TYPE_SET_LORA);
+task.id = ctx_server.queue_tasks.get_new_id();
+ctx_server.queue_results.add_waiting_task_id(task.id);
+ctx_server.queue_tasks.post(task);

-
-ctx_server.queue_results.remove_waiting_task_id(
+server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+ctx_server.queue_results.remove_waiting_task_id(task.id);

-
-
-
+if (result->is_error()) {
+res_error(res, result->to_json());
+return;
+}

-
-
-res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
-return false;
-};
+GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
+res_ok(res, result->to_json());
 };

 //
 // Router
 //

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+if (!params.webui) {
+LOG_INF("Web UI is disabled\n");
+} else {
+// register static assets routes
+if (!params.public_path.empty()) {
+// Set the base directory for serving static files
+bool is_found = svr->set_mount_point("/", params.public_path);
+if (!is_found) {
+LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
+return 1;
+}
+} else {
+// using embedded static index.html
+svr->Get("/", [](const httplib::Request & req, httplib::Response & res) {
+if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
+res.set_content("Error: gzip is not supported by this browser", "text/plain");
+} else {
+res.set_header("Content-Encoding", "gzip");
+res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
+}
+return false;
+});
+}
+}

 // register API routes
-svr->Get ("/health", handle_health);
+svr->Get ("/health", handle_health); // public endpoint (no API key check)
 svr->Get ("/metrics", handle_metrics);
 svr->Get ("/props", handle_props);
-svr->
+svr->Post("/props", handle_props_change);
+svr->Get ("/models", handle_models); // public endpoint (no API key check)
+svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
 svr->Post("/completion", handle_completions); // legacy
 svr->Post("/completions", handle_completions);
 svr->Post("/v1/completions", handle_completions);
@@ -3297,7 +4018,7 @@ int main(int argc, char ** argv) {
 svr->Post("/infill", handle_infill);
 svr->Post("/embedding", handle_embeddings); // legacy
 svr->Post("/embeddings", handle_embeddings);
-svr->Post("/v1/embeddings",
+svr->Post("/v1/embeddings", handle_embeddings_oai);
 svr->Post("/rerank", handle_rerank);
 svr->Post("/reranking", handle_rerank);
 svr->Post("/v1/rerank", handle_rerank);
@@ -3327,8 +4048,18 @@ int main(int argc, char ** argv) {
 llama_backend_free();
 };

-// bind HTTP listen port
-
+// bind HTTP listen port
+bool was_bound = false;
+if (params.port == 0) {
+int bound_port = svr->bind_to_any_port(params.hostname);
+if ((was_bound = (bound_port >= 0))) {
+params.port = bound_port;
+}
+} else {
+was_bound = svr->bind_to_port(params.hostname, params.port);
+}
+
+if (!was_bound) {
 //LOG_ERROR("couldn't bind HTTP server socket", {
 // {"hostname", params.hostname},
 // {"port", params.port},
@@ -3337,6 +4068,8 @@ int main(int argc, char ** argv) {
 clean_up();
 return 1;
 }
+
+// run the HTTP server in a thread
 std::thread t([&]() { svr->listen_after_bind(); });
 svr->wait_until_ready();

@@ -3366,10 +4099,11 @@ int main(int argc, char ** argv) {
 }

 // print sample chat example to make it clear which template is used
-LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(),
+LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());

 ctx_server.queue_tasks.on_new_task(std::bind(
 &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
 ctx_server.queue_tasks.on_update_slots(std::bind(
 &server_context::update_slots, &ctx_server));

@@ -3377,7 +4111,7 @@ int main(int argc, char ** argv) {
 ctx_server.queue_tasks.terminate();
 };

-LOG_INF("%s: server is listening on
+LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);

 ctx_server.queue_tasks.start_loop();
