@fugood/llama.node 0.3.1 → 0.3.3
This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
- package/CMakeLists.txt +1 -8
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +4 -2
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +10 -10
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +14 -17
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +5 -4
- package/src/llama.cpp/.github/workflows/build.yml +137 -29
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +46 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +26 -11
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +10 -10
- package/src/llama.cpp/common/arg.cpp +2041 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +523 -1861
- package/src/llama.cpp/common/common.h +234 -106
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +39 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +356 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/docs/build.md +72 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +49 -65
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
- package/src/llama.cpp/examples/infill/infill.cpp +131 -192
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +686 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
- package/src/llama.cpp/examples/llava/llava.cpp +146 -26
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
- package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
- package/src/llama.cpp/examples/main/main.cpp +216 -313
- package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
- package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
- package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
- package/src/llama.cpp/examples/server/server.cpp +1347 -1531
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +396 -107
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +132 -106
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
- package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +272 -505
- package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
- package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
- package/src/llama.cpp/include/llama.h +296 -285
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
- package/src/llama.cpp/src/llama-sampling.h +39 -47
- package/src/llama.cpp/src/llama-vocab.cpp +390 -127
- package/src/llama.cpp/src/llama-vocab.h +60 -20
- package/src/llama.cpp/src/llama.cpp +6215 -3263
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +4 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
- package/src/llama.cpp/tests/test-barrier.cpp +94 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +2 -1
- package/src/llama.cpp/tests/test-sampling.cpp +226 -142
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/common/train.cpp +0 -1513
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
|
@@ -1,100 +1,100 @@
|
|
|
1
1
|
#include "utils.hpp"
|
|
2
2
|
|
|
3
|
+
#include "arg.h"
|
|
3
4
|
#include "common.h"
|
|
5
|
+
#include "log.h"
|
|
6
|
+
#include "sampling.h"
|
|
4
7
|
#include "json-schema-to-grammar.h"
|
|
5
8
|
#include "llama.h"
|
|
6
|
-
#include "grammar-parser.h"
|
|
7
9
|
|
|
8
|
-
#ifndef NDEBUG
|
|
9
|
-
// crash the server in debug mode, otherwise send an http 500 error
|
|
10
|
-
#define CPPHTTPLIB_NO_EXCEPTIONS 1
|
|
11
|
-
#endif
|
|
12
|
-
// increase max payload length to allow use of larger context size
|
|
13
|
-
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
|
|
14
|
-
#include "httplib.h"
|
|
15
10
|
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
|
16
11
|
#define JSON_ASSERT GGML_ASSERT
|
|
17
12
|
#include "json.hpp"
|
|
13
|
+
// mime type for sending response
|
|
14
|
+
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
|
18
15
|
|
|
19
16
|
// auto generated files (update with ./deps.sh)
|
|
20
|
-
#include "colorthemes.css.hpp"
|
|
21
|
-
#include "style.css.hpp"
|
|
22
|
-
#include "theme-beeninorder.css.hpp"
|
|
23
|
-
#include "theme-ketivah.css.hpp"
|
|
24
|
-
#include "theme-mangotango.css.hpp"
|
|
25
|
-
#include "theme-playground.css.hpp"
|
|
26
|
-
#include "theme-polarnight.css.hpp"
|
|
27
|
-
#include "theme-snowstorm.css.hpp"
|
|
28
17
|
#include "index.html.hpp"
|
|
29
|
-
#include "index-new.html.hpp"
|
|
30
|
-
#include "index.js.hpp"
|
|
31
18
|
#include "completion.js.hpp"
|
|
32
|
-
#include "
|
|
33
|
-
#include "
|
|
34
|
-
#include "
|
|
19
|
+
#include "loading.html.hpp"
|
|
20
|
+
#include "deps_daisyui.min.css.hpp"
|
|
21
|
+
#include "deps_markdown-it.js.hpp"
|
|
22
|
+
#include "deps_tailwindcss.js.hpp"
|
|
23
|
+
#include "deps_vue.esm-browser.js.hpp"
|
|
35
24
|
|
|
36
25
|
#include <atomic>
|
|
37
|
-
#include <chrono>
|
|
38
26
|
#include <condition_variable>
|
|
39
27
|
#include <cstddef>
|
|
40
|
-
#include <
|
|
28
|
+
#include <cinttypes>
|
|
29
|
+
#include <deque>
|
|
30
|
+
#include <memory>
|
|
41
31
|
#include <mutex>
|
|
42
|
-
#include <thread>
|
|
43
32
|
#include <signal.h>
|
|
44
|
-
#include <
|
|
33
|
+
#include <thread>
|
|
34
|
+
#include <unordered_map>
|
|
35
|
+
#include <unordered_set>
|
|
45
36
|
|
|
46
37
|
using json = nlohmann::ordered_json;
|
|
47
38
|
|
|
48
|
-
bool server_verbose = false;
|
|
49
|
-
bool server_log_json = true;
|
|
50
|
-
|
|
51
39
|
enum stop_type {
|
|
52
40
|
STOP_TYPE_FULL,
|
|
53
41
|
STOP_TYPE_PARTIAL,
|
|
54
42
|
};
|
|
55
43
|
|
|
44
|
+
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
|
|
56
45
|
enum slot_state {
|
|
57
46
|
SLOT_STATE_IDLE,
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
SLOT_COMMAND_NONE,
|
|
63
|
-
SLOT_COMMAND_LOAD_PROMPT,
|
|
64
|
-
SLOT_COMMAND_RELEASE,
|
|
47
|
+
SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
|
|
48
|
+
SLOT_STATE_PROCESSING_PROMPT,
|
|
49
|
+
SLOT_STATE_DONE_PROMPT,
|
|
50
|
+
SLOT_STATE_GENERATING,
|
|
65
51
|
};
|
|
66
52
|
|
|
67
53
|
enum server_state {
|
|
68
54
|
SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
|
|
69
55
|
SERVER_STATE_READY, // Server is ready and model is loaded
|
|
70
|
-
SERVER_STATE_ERROR // An error occurred, load_model failed
|
|
71
56
|
};
|
|
72
57
|
|
|
73
58
|
enum server_task_type {
|
|
74
|
-
|
|
59
|
+
SERVER_TASK_TYPE_INFERENCE,
|
|
75
60
|
SERVER_TASK_TYPE_CANCEL,
|
|
76
61
|
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
|
77
62
|
SERVER_TASK_TYPE_METRICS,
|
|
78
63
|
SERVER_TASK_TYPE_SLOT_SAVE,
|
|
79
64
|
SERVER_TASK_TYPE_SLOT_RESTORE,
|
|
80
65
|
SERVER_TASK_TYPE_SLOT_ERASE,
|
|
66
|
+
SERVER_TASK_TYPE_SET_LORA,
|
|
67
|
+
};
|
|
68
|
+
|
|
69
|
+
enum server_task_inf_type {
|
|
70
|
+
SERVER_TASK_INF_TYPE_COMPLETION,
|
|
71
|
+
SERVER_TASK_INF_TYPE_EMBEDDING,
|
|
72
|
+
SERVER_TASK_INF_TYPE_RERANK,
|
|
73
|
+
SERVER_TASK_INF_TYPE_INFILL,
|
|
81
74
|
};
|
|
82
75
|
|
|
83
76
|
struct server_task {
|
|
84
77
|
int id = -1; // to be filled by server_queue
|
|
85
|
-
int
|
|
86
|
-
int id_target = -1;
|
|
78
|
+
int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
|
|
87
79
|
|
|
80
|
+
llama_tokens prompt_tokens;
|
|
88
81
|
server_task_type type;
|
|
89
82
|
json data;
|
|
90
83
|
|
|
91
|
-
|
|
92
|
-
|
|
84
|
+
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
|
85
|
+
|
|
86
|
+
// utility function
|
|
87
|
+
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
|
88
|
+
std::unordered_set<int> ids(tasks.size());
|
|
89
|
+
for (size_t i = 0; i < tasks.size(); i++) {
|
|
90
|
+
ids.insert(tasks[i].id);
|
|
91
|
+
}
|
|
92
|
+
return ids;
|
|
93
|
+
}
|
|
93
94
|
};
|
|
94
95
|
|
|
95
96
|
struct server_task_result {
|
|
96
97
|
int id = -1;
|
|
97
|
-
int id_multi = -1;
|
|
98
98
|
|
|
99
99
|
json data;
|
|
100
100
|
|
|
@@ -102,36 +102,37 @@ struct server_task_result {
|
|
|
102
102
|
bool error;
|
|
103
103
|
};
|
|
104
104
|
|
|
105
|
-
struct
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
std::vector<server_task_result> results;
|
|
105
|
+
struct server_static_file {
|
|
106
|
+
const unsigned char * data;
|
|
107
|
+
unsigned int size;
|
|
108
|
+
const char * mime_type;
|
|
110
109
|
};
|
|
111
110
|
|
|
112
111
|
struct slot_params {
|
|
113
112
|
bool stream = true;
|
|
114
113
|
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
|
|
115
114
|
|
|
116
|
-
int32_t
|
|
117
|
-
int32_t
|
|
118
|
-
int32_t
|
|
115
|
+
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
|
116
|
+
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
|
117
|
+
int32_t n_predict = -1; // new tokens to predict
|
|
118
|
+
int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
|
|
119
119
|
|
|
120
|
-
|
|
120
|
+
int64_t t_max_prompt_ms = -1; // TODO: implement
|
|
121
|
+
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
|
121
122
|
|
|
122
|
-
|
|
123
|
-
json input_suffix;
|
|
123
|
+
std::vector<std::string> antiprompt;
|
|
124
124
|
};
|
|
125
125
|
|
|
126
126
|
struct server_slot {
|
|
127
127
|
int id;
|
|
128
128
|
int id_task = -1;
|
|
129
|
-
|
|
129
|
+
|
|
130
|
+
// the index relative to completion multi-task request
|
|
131
|
+
size_t index = 0;
|
|
130
132
|
|
|
131
133
|
struct slot_params params;
|
|
132
134
|
|
|
133
135
|
slot_state state = SLOT_STATE_IDLE;
|
|
134
|
-
slot_command command = SLOT_COMMAND_NONE;
|
|
135
136
|
|
|
136
137
|
// used to determine the slot that has been used the longest
|
|
137
138
|
int64_t t_last_used = -1;
|
|
@@ -144,21 +145,23 @@ struct server_slot {
|
|
|
144
145
|
int32_t i_batch = -1;
|
|
145
146
|
int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
|
|
146
147
|
|
|
148
|
+
// n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
|
|
147
149
|
int32_t n_prompt_tokens = 0;
|
|
148
150
|
int32_t n_prompt_tokens_processed = 0;
|
|
149
151
|
|
|
150
|
-
|
|
152
|
+
// input prompt tokens
|
|
153
|
+
llama_tokens prompt_tokens;
|
|
151
154
|
|
|
152
|
-
|
|
153
|
-
std::vector<llama_token> prompt_tokens;
|
|
155
|
+
size_t last_nl_pos = 0;
|
|
154
156
|
|
|
155
157
|
std::string generated_text;
|
|
156
|
-
|
|
158
|
+
llama_tokens cache_tokens;
|
|
157
159
|
std::vector<completion_token_output> generated_token_probs;
|
|
158
160
|
|
|
159
|
-
|
|
160
|
-
|
|
161
|
+
server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
|
162
|
+
|
|
161
163
|
bool has_next_token = true;
|
|
164
|
+
bool has_new_line = false;
|
|
162
165
|
bool truncated = false;
|
|
163
166
|
bool stopped_eos = false;
|
|
164
167
|
bool stopped_word = false;
|
|
@@ -170,30 +173,32 @@ struct server_slot {
|
|
|
170
173
|
std::string stopping_word;
|
|
171
174
|
|
|
172
175
|
// sampling
|
|
173
|
-
llama_token sampled;
|
|
174
|
-
struct llama_sampling_params sparams;
|
|
175
|
-
llama_sampling_context * ctx_sampling = nullptr;
|
|
176
176
|
json json_schema;
|
|
177
177
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
int32_t ga_w = 512; // group-attention width
|
|
178
|
+
struct common_sampler_params sparams;
|
|
179
|
+
struct common_sampler * smpl = nullptr;
|
|
181
180
|
|
|
182
|
-
|
|
181
|
+
llama_token sampled;
|
|
183
182
|
|
|
184
183
|
// stats
|
|
185
|
-
size_t n_sent_text
|
|
184
|
+
size_t n_sent_text = 0; // number of sent text character
|
|
186
185
|
size_t n_sent_token_probs = 0;
|
|
187
186
|
|
|
188
187
|
int64_t t_start_process_prompt;
|
|
189
188
|
int64_t t_start_generation;
|
|
190
189
|
|
|
191
190
|
double t_prompt_processing; // ms
|
|
192
|
-
double t_token_generation;
|
|
191
|
+
double t_token_generation; // ms
|
|
192
|
+
|
|
193
|
+
std::function<void(int)> callback_on_release;
|
|
193
194
|
|
|
194
195
|
void reset() {
|
|
196
|
+
SLT_DBG(*this, "%s", "\n");
|
|
197
|
+
|
|
195
198
|
n_prompt_tokens = 0;
|
|
199
|
+
last_nl_pos = 0;
|
|
196
200
|
generated_text = "";
|
|
201
|
+
has_new_line = false;
|
|
197
202
|
truncated = false;
|
|
198
203
|
stopped_eos = false;
|
|
199
204
|
stopped_word = false;
|
|
@@ -202,14 +207,12 @@ struct server_slot {
|
|
|
202
207
|
n_past = 0;
|
|
203
208
|
n_sent_text = 0;
|
|
204
209
|
n_sent_token_probs = 0;
|
|
205
|
-
|
|
206
|
-
ga_i = 0;
|
|
207
|
-
n_past_se = 0;
|
|
210
|
+
inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
|
|
208
211
|
|
|
209
212
|
generated_token_probs.clear();
|
|
210
213
|
}
|
|
211
214
|
|
|
212
|
-
bool has_budget(
|
|
215
|
+
bool has_budget(common_params &global_params) {
|
|
213
216
|
if (params.n_predict == -1 && global_params.n_predict == -1) {
|
|
214
217
|
return true; // limitless
|
|
215
218
|
}
|
|
@@ -225,25 +228,26 @@ struct server_slot {
|
|
|
225
228
|
return n_remaining > 0; // no budget
|
|
226
229
|
}
|
|
227
230
|
|
|
228
|
-
bool available() const {
|
|
229
|
-
return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
|
|
230
|
-
}
|
|
231
|
-
|
|
232
231
|
bool is_processing() const {
|
|
233
|
-
return
|
|
232
|
+
return state != SLOT_STATE_IDLE;
|
|
234
233
|
}
|
|
235
234
|
|
|
236
|
-
void
|
|
237
|
-
if (
|
|
235
|
+
void add_token(const completion_token_output & token) {
|
|
236
|
+
if (!is_processing()) {
|
|
237
|
+
SLT_WRN(*this, "%s", "slot is not processing\n");
|
|
238
238
|
return;
|
|
239
239
|
}
|
|
240
240
|
generated_token_probs.push_back(token);
|
|
241
241
|
}
|
|
242
242
|
|
|
243
243
|
void release() {
|
|
244
|
-
if (
|
|
244
|
+
if (is_processing()) {
|
|
245
|
+
SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
|
|
246
|
+
|
|
247
|
+
t_last_used = ggml_time_us();
|
|
245
248
|
t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
|
|
246
|
-
|
|
249
|
+
state = SLOT_STATE_IDLE;
|
|
250
|
+
callback_on_release(id);
|
|
247
251
|
}
|
|
248
252
|
}
|
|
249
253
|
|
|
@@ -290,49 +294,20 @@ struct server_slot {
|
|
|
290
294
|
}
|
|
291
295
|
|
|
292
296
|
void print_timings() const {
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
double
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
{"t_token", t_token},
|
|
308
|
-
{"n_tokens_second", n_tokens_second},
|
|
309
|
-
});
|
|
310
|
-
|
|
311
|
-
t_token = t_token_generation / n_decoded;
|
|
312
|
-
n_tokens_second = 1e3 / t_token_generation * n_decoded;
|
|
313
|
-
|
|
314
|
-
snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
|
|
315
|
-
t_token_generation, n_decoded,
|
|
316
|
-
t_token, n_tokens_second);
|
|
317
|
-
|
|
318
|
-
LOG_INFO(buffer, {
|
|
319
|
-
{"id_slot", id},
|
|
320
|
-
{"id_task", id_task},
|
|
321
|
-
{"t_token_generation", t_token_generation},
|
|
322
|
-
{"n_decoded", n_decoded},
|
|
323
|
-
{"t_token", t_token},
|
|
324
|
-
{"n_tokens_second", n_tokens_second},
|
|
325
|
-
});
|
|
326
|
-
|
|
327
|
-
snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
|
|
328
|
-
|
|
329
|
-
LOG_INFO(buffer, {
|
|
330
|
-
{"id_slot", id},
|
|
331
|
-
{"id_task", id_task},
|
|
332
|
-
{"t_prompt_processing", t_prompt_processing},
|
|
333
|
-
{"t_token_generation", t_token_generation},
|
|
334
|
-
{"t_total", t_prompt_processing + t_token_generation},
|
|
335
|
-
});
|
|
297
|
+
const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
|
|
298
|
+
const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
|
|
299
|
+
|
|
300
|
+
const double t_gen = t_token_generation / n_decoded;
|
|
301
|
+
const double n_gen_second = 1e3 / t_token_generation * n_decoded;
|
|
302
|
+
|
|
303
|
+
SLT_INF(*this,
|
|
304
|
+
"\n"
|
|
305
|
+
"\rprompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
|
|
306
|
+
"\r eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
|
|
307
|
+
"\r total time = %10.2f ms / %5d tokens\n",
|
|
308
|
+
t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
|
|
309
|
+
t_token_generation, n_decoded, t_gen, n_gen_second,
|
|
310
|
+
t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
|
|
336
311
|
}
|
|
337
312
|
};
|
|
338
313
|
|
|
@@ -350,6 +325,9 @@ struct server_metrics {
|
|
|
350
325
|
uint64_t n_tokens_predicted = 0;
|
|
351
326
|
uint64_t t_tokens_generation = 0;
|
|
352
327
|
|
|
328
|
+
uint64_t n_decode_total = 0;
|
|
329
|
+
uint64_t n_busy_slots_total = 0;
|
|
330
|
+
|
|
353
331
|
void init() {
|
|
354
332
|
t_start = ggml_time_us();
|
|
355
333
|
}
|
|
@@ -368,6 +346,15 @@ struct server_metrics {
|
|
|
368
346
|
t_tokens_generation_total += slot.t_token_generation;
|
|
369
347
|
}
|
|
370
348
|
|
|
349
|
+
void on_decoded(const std::vector<server_slot> & slots) {
|
|
350
|
+
n_decode_total++;
|
|
351
|
+
for (const auto & slot : slots) {
|
|
352
|
+
if (slot.is_processing()) {
|
|
353
|
+
n_busy_slots_total++;
|
|
354
|
+
}
|
|
355
|
+
}
|
|
356
|
+
}
|
|
357
|
+
|
|
371
358
|
void reset_bucket() {
|
|
372
359
|
n_prompt_tokens_processed = 0;
|
|
373
360
|
t_prompt_processing = 0;
|
|
@@ -381,68 +368,83 @@ struct server_queue {
|
|
|
381
368
|
bool running;
|
|
382
369
|
|
|
383
370
|
// queues
|
|
384
|
-
std::
|
|
385
|
-
std::
|
|
386
|
-
|
|
387
|
-
std::vector<server_task_multi> queue_multitasks;
|
|
371
|
+
std::deque<server_task> queue_tasks;
|
|
372
|
+
std::deque<server_task> queue_tasks_deferred;
|
|
388
373
|
|
|
389
374
|
std::mutex mutex_tasks;
|
|
390
375
|
std::condition_variable condition_tasks;
|
|
391
376
|
|
|
392
377
|
// callback functions
|
|
393
|
-
std::function<void(server_task
|
|
394
|
-
std::function<void(
|
|
395
|
-
std::function<void(void)> callback_update_slots;
|
|
378
|
+
std::function<void(server_task)> callback_new_task;
|
|
379
|
+
std::function<void(void)> callback_update_slots;
|
|
396
380
|
|
|
397
381
|
// Add a new task to the end of the queue
|
|
398
|
-
int post(server_task task) {
|
|
382
|
+
int post(server_task task, bool front = false) {
|
|
399
383
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
400
384
|
if (task.id == -1) {
|
|
401
385
|
task.id = id++;
|
|
402
|
-
LOG_VERBOSE("new task id", {{"new_id", task.id}});
|
|
403
386
|
}
|
|
404
|
-
|
|
387
|
+
QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
|
|
388
|
+
if (front) {
|
|
389
|
+
queue_tasks.push_front(std::move(task));
|
|
390
|
+
} else {
|
|
391
|
+
queue_tasks.push_back(std::move(task));
|
|
392
|
+
}
|
|
405
393
|
condition_tasks.notify_one();
|
|
406
394
|
return task.id;
|
|
407
395
|
}
|
|
408
396
|
|
|
397
|
+
// multi-task version of post()
|
|
398
|
+
int post(std::vector<server_task> & tasks, bool front = false) {
|
|
399
|
+
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
400
|
+
for (auto & task : tasks) {
|
|
401
|
+
if (task.id == -1) {
|
|
402
|
+
task.id = id++;
|
|
403
|
+
}
|
|
404
|
+
QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
|
|
405
|
+
if (front) {
|
|
406
|
+
queue_tasks.push_front(std::move(task));
|
|
407
|
+
} else {
|
|
408
|
+
queue_tasks.push_back(std::move(task));
|
|
409
|
+
}
|
|
410
|
+
}
|
|
411
|
+
condition_tasks.notify_one();
|
|
412
|
+
return 0;
|
|
413
|
+
}
|
|
414
|
+
|
|
409
415
|
// Add a new task, but defer until one slot is available
|
|
410
416
|
void defer(server_task task) {
|
|
411
417
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
418
|
+
QUE_DBG("defer task, id = %d\n", task.id);
|
|
412
419
|
queue_tasks_deferred.push_back(std::move(task));
|
|
420
|
+
condition_tasks.notify_one();
|
|
413
421
|
}
|
|
414
422
|
|
|
415
|
-
// Get the next id for creating
|
|
423
|
+
// Get the next id for creating a new task
|
|
416
424
|
int get_new_id() {
|
|
417
425
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
418
426
|
int new_id = id++;
|
|
419
|
-
LOG_VERBOSE("new task id", {{"new_id", new_id}});
|
|
420
427
|
return new_id;
|
|
421
428
|
}
|
|
422
429
|
|
|
423
430
|
// Register function to process a new task
|
|
424
|
-
void on_new_task(std::function<void(server_task
|
|
431
|
+
void on_new_task(std::function<void(server_task)> callback) {
|
|
425
432
|
callback_new_task = std::move(callback);
|
|
426
433
|
}
|
|
427
434
|
|
|
428
|
-
// Register function to process a multitask when it is finished
|
|
429
|
-
void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
|
|
430
|
-
callback_finish_multitask = std::move(callback);
|
|
431
|
-
}
|
|
432
|
-
|
|
433
435
|
// Register the function to be called when all slots data is ready to be processed
|
|
434
436
|
void on_update_slots(std::function<void(void)> callback) {
|
|
435
437
|
callback_update_slots = std::move(callback);
|
|
436
438
|
}
|
|
437
439
|
|
|
438
|
-
// Call when the state of one slot is changed
|
|
439
|
-
void
|
|
440
|
-
// move deferred tasks back to main loop
|
|
440
|
+
// Call when the state of one slot is changed, it will move one task from deferred to main queue
|
|
441
|
+
void pop_deferred_task() {
|
|
441
442
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
442
|
-
|
|
443
|
-
queue_tasks.
|
|
443
|
+
if (!queue_tasks_deferred.empty()) {
|
|
444
|
+
queue_tasks.emplace_back(std::move(queue_tasks_deferred.front()));
|
|
445
|
+
queue_tasks_deferred.pop_front();
|
|
444
446
|
}
|
|
445
|
-
|
|
447
|
+
condition_tasks.notify_one();
|
|
446
448
|
}
|
|
447
449
|
|
|
448
450
|
// end the start_loop routine
|
|
@@ -463,7 +465,7 @@ struct server_queue {
|
|
|
463
465
|
running = true;
|
|
464
466
|
|
|
465
467
|
while (true) {
|
|
466
|
-
|
|
468
|
+
QUE_DBG("%s", "processing new tasks\n");
|
|
467
469
|
|
|
468
470
|
while (true) {
|
|
469
471
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
@@ -472,39 +474,24 @@ struct server_queue {
|
|
|
472
474
|
break;
|
|
473
475
|
}
|
|
474
476
|
server_task task = queue_tasks.front();
|
|
475
|
-
queue_tasks.
|
|
477
|
+
queue_tasks.pop_front();
|
|
476
478
|
lock.unlock();
|
|
477
|
-
LOG_VERBOSE("callback_new_task", {{"id_task", task.id}});
|
|
478
|
-
callback_new_task(task);
|
|
479
|
-
}
|
|
480
479
|
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
// check if we have any finished multitasks
|
|
484
|
-
auto queue_iterator = queue_multitasks.begin();
|
|
485
|
-
while (queue_iterator != queue_multitasks.end()) {
|
|
486
|
-
if (queue_iterator->subtasks_remaining.empty()) {
|
|
487
|
-
// all subtasks done == multitask is done
|
|
488
|
-
server_task_multi current_multitask = *queue_iterator;
|
|
489
|
-
callback_finish_multitask(current_multitask);
|
|
490
|
-
// remove this multitask
|
|
491
|
-
queue_iterator = queue_multitasks.erase(queue_iterator);
|
|
492
|
-
} else {
|
|
493
|
-
++queue_iterator;
|
|
494
|
-
}
|
|
480
|
+
QUE_DBG("processing task, id = %d\n", task.id);
|
|
481
|
+
callback_new_task(std::move(task));
|
|
495
482
|
}
|
|
496
483
|
|
|
497
484
|
// all tasks in the current loop is processed, slots data is now ready
|
|
498
|
-
|
|
485
|
+
QUE_DBG("%s", "update slots\n");
|
|
499
486
|
|
|
500
487
|
callback_update_slots();
|
|
501
488
|
|
|
502
|
-
|
|
489
|
+
QUE_DBG("%s", "waiting for new tasks\n");
|
|
503
490
|
{
|
|
504
491
|
std::unique_lock<std::mutex> lock(mutex_tasks);
|
|
505
492
|
if (queue_tasks.empty()) {
|
|
506
493
|
if (!running) {
|
|
507
|
-
|
|
494
|
+
QUE_DBG("%s", "terminate\n");
|
|
508
495
|
return;
|
|
509
496
|
}
|
|
510
497
|
condition_tasks.wait(lock, [&]{
|
|
@@ -514,38 +501,11 @@ struct server_queue {
|
|
|
514
501
|
}
|
|
515
502
|
}
|
|
516
503
|
}
|
|
517
|
-
|
|
518
|
-
//
|
|
519
|
-
// functions to manage multitasks
|
|
520
|
-
//
|
|
521
|
-
|
|
522
|
-
// add a multitask by specifying the id of all subtask (subtask is a server_task)
|
|
523
|
-
void add_multitask(int id_multi, std::vector<int> & sub_ids) {
|
|
524
|
-
std::lock_guard<std::mutex> lock(mutex_tasks);
|
|
525
|
-
server_task_multi multi;
|
|
526
|
-
multi.id = id_multi;
|
|
527
|
-
std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
|
|
528
|
-
queue_multitasks.push_back(multi);
|
|
529
|
-
}
|
|
530
|
-
|
|
531
|
-
// updatethe remaining subtasks, while appending results to multitask
|
|
532
|
-
void update_multitask(int id_multi, int id_sub, server_task_result & result) {
|
|
533
|
-
std::lock_guard<std::mutex> lock(mutex_tasks);
|
|
534
|
-
for (auto & multitask : queue_multitasks) {
|
|
535
|
-
if (multitask.id == id_multi) {
|
|
536
|
-
multitask.subtasks_remaining.erase(id_sub);
|
|
537
|
-
multitask.results.push_back(result);
|
|
538
|
-
}
|
|
539
|
-
}
|
|
540
|
-
}
|
|
541
504
|
};
|
|
542
505
|
|
|
543
506
|
struct server_response {
|
|
544
|
-
typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
|
|
545
|
-
callback_multitask_t callback_update_multitask;
|
|
546
|
-
|
|
547
507
|
// for keeping track of all tasks waiting for the result
|
|
548
|
-
std::
|
|
508
|
+
std::unordered_set<int> waiting_task_ids;
|
|
549
509
|
|
|
550
510
|
// the main result queue
|
|
551
511
|
std::vector<server_task_result> queue_results;
|
|
@@ -555,22 +515,40 @@ struct server_response {
|
|
|
555
515
|
|
|
556
516
|
// add the id_task to the list of tasks waiting for response
|
|
557
517
|
void add_waiting_task_id(int id_task) {
|
|
558
|
-
|
|
518
|
+
SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
|
|
559
519
|
|
|
560
520
|
std::unique_lock<std::mutex> lock(mutex_results);
|
|
561
521
|
waiting_task_ids.insert(id_task);
|
|
562
522
|
}
|
|
563
523
|
|
|
524
|
+
void add_waiting_tasks(const std::vector<server_task> & tasks) {
|
|
525
|
+
std::unique_lock<std::mutex> lock(mutex_results);
|
|
526
|
+
|
|
527
|
+
for (const auto & task : tasks) {
|
|
528
|
+
SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
|
|
529
|
+
waiting_task_ids.insert(task.id);
|
|
530
|
+
}
|
|
531
|
+
}
|
|
532
|
+
|
|
564
533
|
// when the request is finished, we can remove task associated with it
|
|
565
534
|
void remove_waiting_task_id(int id_task) {
|
|
566
|
-
|
|
535
|
+
SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
|
|
567
536
|
|
|
568
537
|
std::unique_lock<std::mutex> lock(mutex_results);
|
|
569
538
|
waiting_task_ids.erase(id_task);
|
|
570
539
|
}
|
|
571
540
|
|
|
572
|
-
|
|
573
|
-
|
|
541
|
+
void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
|
|
542
|
+
std::unique_lock<std::mutex> lock(mutex_results);
|
|
543
|
+
|
|
544
|
+
for (const auto & id_task : id_tasks) {
|
|
545
|
+
SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
|
|
546
|
+
waiting_task_ids.erase(id_task);
|
|
547
|
+
}
|
|
548
|
+
}
|
|
549
|
+
|
|
550
|
+
// This function blocks the thread until there is a response for one of the id_tasks
|
|
551
|
+
server_task_result recv(const std::unordered_set<int> & id_tasks) {
|
|
574
552
|
while (true) {
|
|
575
553
|
std::unique_lock<std::mutex> lock(mutex_results);
|
|
576
554
|
condition_results.wait(lock, [&]{
|
|
@@ -578,8 +556,7 @@ struct server_response {
|
|
|
578
556
|
});
|
|
579
557
|
|
|
580
558
|
for (int i = 0; i < (int) queue_results.size(); i++) {
|
|
581
|
-
if (queue_results[i].id
|
|
582
|
-
assert(queue_results[i].id_multi == -1);
|
|
559
|
+
if (id_tasks.find(queue_results[i].id) != id_tasks.end()) {
|
|
583
560
|
server_task_result res = queue_results[i];
|
|
584
561
|
queue_results.erase(queue_results.begin() + i);
|
|
585
562
|
return res;
|
|
@@ -590,28 +567,22 @@ struct server_response {
|
|
|
590
567
|
// should never reach here
|
|
591
568
|
}
|
|
592
569
|
|
|
593
|
-
//
|
|
594
|
-
|
|
595
|
-
|
|
570
|
+
// single-task version of recv()
|
|
571
|
+
server_task_result recv(int id_task) {
|
|
572
|
+
std::unordered_set<int> id_tasks = {id_task};
|
|
573
|
+
return recv(id_tasks);
|
|
596
574
|
}
|
|
597
575
|
|
|
598
576
|
// Send a new result to a waiting id_task
|
|
599
|
-
void send(server_task_result result) {
|
|
600
|
-
|
|
577
|
+
void send(server_task_result & result) {
|
|
578
|
+
SRV_DBG("sending result for task id = %d\n", result.id);
|
|
601
579
|
|
|
602
580
|
std::unique_lock<std::mutex> lock(mutex_results);
|
|
603
581
|
for (const auto & id_task : waiting_task_ids) {
|
|
604
|
-
// LOG_TEE("waiting task id %i \n", id_task);
|
|
605
|
-
// for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
|
|
606
|
-
if (result.id_multi == id_task) {
|
|
607
|
-
LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}});
|
|
608
|
-
callback_update_multitask(id_task, result.id, result);
|
|
609
|
-
continue;
|
|
610
|
-
}
|
|
611
|
-
|
|
612
582
|
if (result.id == id_task) {
|
|
613
|
-
|
|
614
|
-
|
|
583
|
+
SRV_DBG("task id = %d moved to result queue\n", result.id);
|
|
584
|
+
|
|
585
|
+
queue_results.push_back(std::move(result));
|
|
615
586
|
condition_results.notify_all();
|
|
616
587
|
return;
|
|
617
588
|
}
|
|
@@ -622,22 +593,18 @@ struct server_response {
|
|
|
622
593
|
struct server_context {
|
|
623
594
|
llama_model * model = nullptr;
|
|
624
595
|
llama_context * ctx = nullptr;
|
|
596
|
+
std::vector<common_lora_adapter_container> loras;
|
|
625
597
|
|
|
626
|
-
|
|
598
|
+
common_params params;
|
|
627
599
|
|
|
628
|
-
llama_batch batch;
|
|
600
|
+
llama_batch batch = {};
|
|
629
601
|
|
|
630
602
|
bool clean_kv_cache = true;
|
|
631
603
|
bool add_bos_token = true;
|
|
604
|
+
bool has_eos_token = false;
|
|
632
605
|
|
|
633
606
|
int32_t n_ctx; // total context for all clients / slots
|
|
634
607
|
|
|
635
|
-
// system prompt
|
|
636
|
-
bool system_need_update = false;
|
|
637
|
-
|
|
638
|
-
std::string system_prompt;
|
|
639
|
-
std::vector<llama_token> system_tokens;
|
|
640
|
-
|
|
641
608
|
// slots / clients
|
|
642
609
|
std::vector<server_slot> slots;
|
|
643
610
|
json default_generation_settings_for_props;
|
|
@@ -663,47 +630,53 @@ struct server_context {
|
|
|
663
630
|
|
|
664
631
|
// Clear any sampling context
|
|
665
632
|
for (server_slot & slot : slots) {
|
|
666
|
-
if (slot.
|
|
667
|
-
|
|
633
|
+
if (slot.smpl != nullptr) {
|
|
634
|
+
common_sampler_free(slot.smpl);
|
|
668
635
|
}
|
|
669
636
|
}
|
|
670
637
|
|
|
671
638
|
llama_batch_free(batch);
|
|
672
639
|
}
|
|
673
640
|
|
|
674
|
-
bool load_model(const
|
|
641
|
+
bool load_model(const common_params & params_) {
|
|
675
642
|
params = params_;
|
|
676
643
|
|
|
677
|
-
|
|
678
|
-
|
|
644
|
+
common_init_result llama_init = common_init_from_params(params);
|
|
645
|
+
|
|
646
|
+
model = llama_init.model;
|
|
647
|
+
ctx = llama_init.context;
|
|
648
|
+
loras = llama_init.lora_adapters;
|
|
679
649
|
|
|
680
|
-
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
|
681
|
-
params.n_parallel -= 1; // but be sneaky about it
|
|
682
650
|
if (model == nullptr) {
|
|
683
|
-
|
|
651
|
+
SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
|
|
684
652
|
return false;
|
|
685
653
|
}
|
|
686
654
|
|
|
687
655
|
n_ctx = llama_n_ctx(ctx);
|
|
688
656
|
|
|
689
|
-
add_bos_token =
|
|
690
|
-
|
|
657
|
+
add_bos_token = llama_add_bos_token(model);
|
|
658
|
+
has_eos_token = !llama_add_eos_token(model);
|
|
691
659
|
|
|
692
660
|
return true;
|
|
693
661
|
}
|
|
694
662
|
|
|
695
663
|
bool validate_model_chat_template() const {
|
|
696
|
-
|
|
697
|
-
|
|
698
|
-
|
|
699
|
-
|
|
700
|
-
|
|
664
|
+
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
|
|
665
|
+
std::string template_key = "tokenizer.chat_template";
|
|
666
|
+
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
|
|
667
|
+
if (res >= 0) {
|
|
668
|
+
llama_chat_message chat[] = {{"user", "test"}};
|
|
669
|
+
std::string tmpl = std::string(model_template.data(), model_template.size());
|
|
670
|
+
int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
|
671
|
+
return chat_res > 0;
|
|
672
|
+
}
|
|
673
|
+
return false;
|
|
701
674
|
}
|
|
702
675
|
|
|
703
676
|
void init() {
|
|
704
677
|
const int32_t n_ctx_slot = n_ctx / params.n_parallel;
|
|
705
678
|
|
|
706
|
-
|
|
679
|
+
SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);
|
|
707
680
|
|
|
708
681
|
for (int i = 0; i < params.n_parallel; i++) {
|
|
709
682
|
server_slot slot;
|
|
@@ -712,33 +685,14 @@ struct server_context {
|
|
|
712
685
|
slot.n_ctx = n_ctx_slot;
|
|
713
686
|
slot.n_predict = params.n_predict;
|
|
714
687
|
|
|
715
|
-
|
|
716
|
-
{"id_slot", slot.id},
|
|
717
|
-
{"n_ctx_slot", slot.n_ctx}
|
|
718
|
-
});
|
|
719
|
-
|
|
720
|
-
const int ga_n = params.grp_attn_n;
|
|
721
|
-
const int ga_w = params.grp_attn_w;
|
|
722
|
-
|
|
723
|
-
if (ga_n != 1) {
|
|
724
|
-
GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
|
|
725
|
-
GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
|
|
726
|
-
//GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
|
|
727
|
-
//GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
|
|
728
|
-
|
|
729
|
-
LOG_INFO("slot self-extend", {
|
|
730
|
-
{"id_slot", slot.id},
|
|
731
|
-
{"ga_n", ga_n},
|
|
732
|
-
{"ga_w", ga_w}
|
|
733
|
-
});
|
|
734
|
-
}
|
|
735
|
-
|
|
736
|
-
slot.ga_i = 0;
|
|
737
|
-
slot.ga_n = ga_n;
|
|
738
|
-
slot.ga_w = ga_w;
|
|
688
|
+
SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
|
|
739
689
|
|
|
740
690
|
slot.sparams = params.sparams;
|
|
741
691
|
|
|
692
|
+
slot.callback_on_release = [this](int) {
|
|
693
|
+
queue_tasks.pop_deferred_task();
|
|
694
|
+
};
|
|
695
|
+
|
|
742
696
|
slot.reset();
|
|
743
697
|
|
|
744
698
|
slots.push_back(slot);
|
|
@@ -747,59 +701,18 @@ struct server_context {
|
|
|
747
701
|
default_generation_settings_for_props = get_formated_generation(slots.front());
|
|
748
702
|
default_generation_settings_for_props["seed"] = -1;
|
|
749
703
|
|
|
750
|
-
// the update_slots() logic will always submit a maximum of n_batch tokens
|
|
704
|
+
// the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
|
|
751
705
|
// note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
|
|
752
706
|
{
|
|
753
707
|
const int32_t n_batch = llama_n_batch(ctx);
|
|
754
708
|
|
|
755
709
|
// only a single seq_id per token is needed
|
|
756
|
-
batch = llama_batch_init(n_batch, 0, 1);
|
|
710
|
+
batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
|
|
757
711
|
}
|
|
758
712
|
|
|
759
713
|
metrics.init();
|
|
760
714
|
}
|
|
761
715
|
|
|
762
|
-
std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
|
|
763
|
-
// TODO: currently, we tokenize using special tokens by default
|
|
764
|
-
// this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
|
|
765
|
-
// but it's better compared to completely ignoring ChatML and other chat templates
|
|
766
|
-
const bool TMP_FORCE_SPECIAL = true;
|
|
767
|
-
|
|
768
|
-
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
|
|
769
|
-
// or the first element of the json_prompt array is a string.
|
|
770
|
-
std::vector<llama_token> prompt_tokens;
|
|
771
|
-
|
|
772
|
-
if (json_prompt.is_array()) {
|
|
773
|
-
bool first = true;
|
|
774
|
-
for (const auto & p : json_prompt) {
|
|
775
|
-
if (p.is_string()) {
|
|
776
|
-
auto s = p.template get<std::string>();
|
|
777
|
-
|
|
778
|
-
std::vector<llama_token> p;
|
|
779
|
-
if (first) {
|
|
780
|
-
p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
|
|
781
|
-
first = false;
|
|
782
|
-
} else {
|
|
783
|
-
p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
|
|
784
|
-
}
|
|
785
|
-
|
|
786
|
-
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
|
|
787
|
-
} else {
|
|
788
|
-
if (first) {
|
|
789
|
-
first = false;
|
|
790
|
-
}
|
|
791
|
-
|
|
792
|
-
prompt_tokens.push_back(p.template get<llama_token>());
|
|
793
|
-
}
|
|
794
|
-
}
|
|
795
|
-
} else {
|
|
796
|
-
auto s = json_prompt.template get<std::string>();
|
|
797
|
-
prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
|
|
798
|
-
}
|
|
799
|
-
|
|
800
|
-
return prompt_tokens;
|
|
801
|
-
}
|
|
802
|
-
|
|
803
716
|
server_slot * get_slot_by_id(int id) {
|
|
804
717
|
for (server_slot & slot : slots) {
|
|
805
718
|
if (slot.id == id) {
|
|
@@ -810,50 +723,41 @@ struct server_context {
|
|
|
810
723
|
return nullptr;
|
|
811
724
|
}
|
|
812
725
|
|
|
813
|
-
server_slot * get_available_slot(const
|
|
726
|
+
server_slot * get_available_slot(const server_task & task) {
|
|
814
727
|
server_slot * ret = nullptr;
|
|
815
728
|
|
|
816
729
|
// find the slot that has at least n% prompt similarity
|
|
817
|
-
if (ret == nullptr && slot_prompt_similarity != 0.0f
|
|
818
|
-
int
|
|
730
|
+
if (ret == nullptr && slot_prompt_similarity != 0.0f) {
|
|
731
|
+
int lcs_len = 0;
|
|
819
732
|
float similarity = 0;
|
|
820
733
|
|
|
821
734
|
for (server_slot & slot : slots) {
|
|
822
735
|
// skip the slot if it is not available
|
|
823
|
-
if (
|
|
736
|
+
if (slot.is_processing()) {
|
|
824
737
|
continue;
|
|
825
738
|
}
|
|
826
739
|
|
|
827
|
-
// skip the slot if it does not contains
|
|
828
|
-
if (
|
|
740
|
+
// skip the slot if it does not contains cached tokens
|
|
741
|
+
if (slot.cache_tokens.empty()) {
|
|
829
742
|
continue;
|
|
830
743
|
}
|
|
831
744
|
|
|
832
|
-
// current slot's prompt
|
|
833
|
-
|
|
834
|
-
|
|
835
|
-
// length of the current slot's prompt
|
|
836
|
-
int slot_prompt_len = slot_prompt.size();
|
|
745
|
+
// length of the Longest Common Subsequence between the current slot's prompt and the input prompt
|
|
746
|
+
int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
|
|
837
747
|
|
|
838
|
-
//
|
|
839
|
-
|
|
840
|
-
|
|
841
|
-
// fraction of the common substring length compared to the current slot's prompt length
|
|
842
|
-
similarity = static_cast<float>(lcp_len) / slot_prompt_len;
|
|
748
|
+
// fraction of the common subsequence length compared to the current slot's prompt length
|
|
749
|
+
float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
|
|
843
750
|
|
|
844
751
|
// select the current slot if the criteria match
|
|
845
|
-
if (
|
|
846
|
-
|
|
752
|
+
if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
|
|
753
|
+
lcs_len = cur_lcs_len;
|
|
754
|
+
similarity = cur_similarity;
|
|
847
755
|
ret = &slot;
|
|
848
756
|
}
|
|
849
757
|
}
|
|
850
758
|
|
|
851
759
|
if (ret != nullptr) {
|
|
852
|
-
|
|
853
|
-
{"id_slot", ret->id},
|
|
854
|
-
{"max_lcp_len", max_lcp_len},
|
|
855
|
-
{"similarity", similarity},
|
|
856
|
-
});
|
|
760
|
+
SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity);
|
|
857
761
|
}
|
|
858
762
|
}
|
|
859
763
|
|
|
@@ -862,7 +766,7 @@ struct server_context {
|
|
|
862
766
|
int64_t t_last = ggml_time_us();
|
|
863
767
|
for (server_slot & slot : slots) {
|
|
864
768
|
// skip the slot if it is not available
|
|
865
|
-
if (
|
|
769
|
+
if (slot.is_processing()) {
|
|
866
770
|
continue;
|
|
867
771
|
}
|
|
868
772
|
|
|
@@ -874,10 +778,7 @@ struct server_context {
|
|
|
874
778
|
}
|
|
875
779
|
|
|
876
780
|
if (ret != nullptr) {
|
|
877
|
-
|
|
878
|
-
{"id_slot", ret->id},
|
|
879
|
-
{"t_last", t_last},
|
|
880
|
-
});
|
|
781
|
+
SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last);
|
|
881
782
|
}
|
|
882
783
|
}
|
|
883
784
|
|
|
@@ -887,8 +788,8 @@ struct server_context {
|
|
|
887
788
|
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
|
|
888
789
|
slot_params default_params;
|
|
889
790
|
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
|
|
890
|
-
|
|
891
|
-
auto & data = task.data;
|
|
791
|
+
auto default_sparams = params.sparams;
|
|
792
|
+
const auto & data = task.data;
|
|
892
793
|
|
|
893
794
|
if (data.count("__oaicompat") != 0) {
|
|
894
795
|
slot.oaicompat = true;
|
|
@@ -898,133 +799,86 @@ struct server_context {
  slot.oaicompat_model = "";
  }

- slot.params.stream
- slot.params.cache_prompt
- slot.params.n_predict
- slot.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.sparams.
- slot.
- slot.
- slot.sparams.
- slot.sparams.
- slot.sparams.
+ slot.params.stream = json_value(data, "stream", false);
+ slot.params.cache_prompt = json_value(data, "cache_prompt", false);
+ slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
+ slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
+ slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
+ slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
+ slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
+ slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
+ slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
+ slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
+ slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
+ slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
+ slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
+ slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
+ slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
+ slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
+ slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
+ slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
+ slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
+ slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
+ slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
+ slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
+ slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
+ slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
+ slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
+ slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
+ slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
+ slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
+ slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
+ slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
+ //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
+ slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
+
+ if (slot.sparams.dry_base < 1.0f)
+ {
+ slot.sparams.dry_base = default_sparams.dry_base;
+ }
+
+ // sequence breakers for DRY
+ {
+ // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+ // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+ if (data.contains("dry_sequence_breakers")) {
+ slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+ if (slot.sparams.dry_sequence_breakers.empty()) {
+ send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
+ }
+ }

  // process "json_schema" and "grammar"
  if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
  send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
  return false;
- }
+ }
+ if (data.contains("json_schema") && !data.contains("grammar")) {
  try {
- auto schema
- slot.sparams.grammar
+ auto schema = json_value(data, "json_schema", json::object());
+ slot.sparams.grammar = json_schema_to_grammar(schema);
  } catch (const std::exception & e) {
  send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
  return false;
  }
  } else {
- slot.sparams.grammar
- }
-
- if (slot.params.cache_prompt && slot.ga_n != 1) {
- LOG_WARNING("cache_prompt is not supported with group-attention", {});
- slot.params.cache_prompt = false;
+ slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
  }

  if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
  // Might be better to reject the request with a 400 ?
- LOG_WARNING("Max tokens to predict exceeds server configuration", {
- {"params.n_predict", slot.params.n_predict},
- {"slot.n_predict", slot.n_predict},
- });
  slot.params.n_predict = slot.n_predict;
-
-
- // infill
- slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
- slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
-
- // get prompt
- if (!task.infill) {
- const auto & prompt = data.find("prompt");
- if (prompt == data.end()) {
- send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
- return false;
- }
-
- if ((prompt->is_string()) ||
- (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
- (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
- slot.prompt = *prompt;
- } else {
- send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
- return false;
- }
- }
-
- // penalize user-provided tokens
- {
- slot.sparams.penalty_prompt_tokens.clear();
- slot.sparams.use_penalty_prompt_tokens = false;
-
- const auto & penalty_prompt = data.find("penalty_prompt");
-
- if (penalty_prompt != data.end()) {
- if (penalty_prompt->is_string()) {
- const auto penalty_prompt_string = penalty_prompt->get<std::string>();
- slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
-
- if (slot.params.n_predict > 0) {
- slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
- }
- slot.sparams.use_penalty_prompt_tokens = true;
-
- LOG_VERBOSE("penalty_prompt_tokens", {
- {"id_slot", slot.id},
- {"tokens", slot.sparams.penalty_prompt_tokens},
- });
- }
- else if (penalty_prompt->is_array()) {
- const auto n_tokens = penalty_prompt->size();
- slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
-
- const int n_vocab = llama_n_vocab(model);
- for (const auto & penalty_token : *penalty_prompt) {
- if (penalty_token.is_number_integer()) {
- const auto tok = penalty_token.get<llama_token>();
- if (tok >= 0 && tok < n_vocab) {
- slot.sparams.penalty_prompt_tokens.push_back(tok);
- }
- }
- }
- slot.sparams.use_penalty_prompt_tokens = true;
-
- LOG_VERBOSE("penalty_prompt_tokens", {
- {"id_slot", slot.id},
- {"tokens", slot.sparams.penalty_prompt_tokens},
- });
- }
- }
+ SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
  }

  {
  slot.sparams.logit_bias.clear();

- if (json_value(data, "ignore_eos", false)) {
- slot.sparams.logit_bias
+ if (json_value(data, "ignore_eos", false) && has_eos_token) {
+ slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
  }

  const auto & logit_bias = data.find("logit_bias");
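
The added block above fills every per-request parameter from the JSON body via json_value(data, "field", default). The real helper lives elsewhere in the server sources and its exact behavior is not shown here; as a rough, hypothetical sketch of the pattern (assuming nlohmann::json, which the server uses for its json type), it works roughly like this:

#include <nlohmann/json.hpp>
#include <string>
using json = nlohmann::json;

// Sketch only: a typed lookup with a fallback default. The server's actual
// json_value helper may differ (e.g. it may log type mismatches).
template <typename T>
static T json_value_sketch(const json & body, const std::string & key, const T & default_value) {
    if (body.contains(key) && !body.at(key).is_null()) {
        try {
            return body.at(key).get<T>();   // field present and convertible
        } catch (const json::exception &) {
            return default_value;           // wrong type -> fall back to default
        }
    }
    return default_value;                   // field absent -> default
}

This also explains the nested call for n_predict in the hunk: "n_predict" is honored first, otherwise the OpenAI-style "max_tokens" is used, and only then the server default.
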
@@ -1045,12 +899,12 @@ struct server_context {
  if (el[0].is_number_integer()) {
  llama_token tok = el[0].get<llama_token>();
  if (tok >= 0 && tok < n_vocab) {
- slot.sparams.logit_bias
+ slot.sparams.logit_bias.push_back({tok, bias});
  }
  } else if (el[0].is_string()) {
- auto toks =
+ auto toks = common_tokenize(model, el[0].get<std::string>(), false);
  for (auto tok : toks) {
- slot.sparams.logit_bias
+ slot.sparams.logit_bias.push_back({tok, bias});
  }
  }
  }
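
The hunk above rebuilds logit_bias handling around plain {token, bias} pairs: integer keys are range-checked against the vocabulary, while string keys are first run through common_tokenize and the same bias is attached to every resulting token. A self-contained sketch of that shape, using standard types instead of the server's internal ones (the tokenizer below is a stand-in, not the real API):

#include <cstdint>
#include <nlohmann/json.hpp>
#include <string>
#include <utility>
#include <vector>
using json = nlohmann::json;
using llama_token = int32_t;

// Stand-in for common_tokenize(model, text, /*add_special=*/false); it just
// maps each byte to a fake token id so the sketch compiles on its own.
static std::vector<llama_token> tokenize_stub(const std::string & text) {
    std::vector<llama_token> toks;
    for (unsigned char c : text) toks.push_back((llama_token) c);
    return toks;
}

// Collect (token, bias) pairs from entries shaped like [token_id_or_string, bias].
static std::vector<std::pair<llama_token, float>> collect_logit_bias(const json & entries, int n_vocab) {
    std::vector<std::pair<llama_token, float>> out;
    for (const auto & el : entries) {
        if (!el.is_array() || el.size() != 2 || !el[1].is_number()) {
            continue; // ignore malformed entries
        }
        const float bias = el[1].get<float>();
        if (el[0].is_number_integer()) {
            const llama_token tok = el[0].get<llama_token>();
            if (tok >= 0 && tok < n_vocab) {
                out.push_back({tok, bias});
            }
        } else if (el[0].is_string()) {
            for (llama_token tok : tokenize_stub(el[0].get<std::string>())) {
                out.push_back({tok, bias}); // same bias for every piece of the string
            }
        }
    }
    return out;
}
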
@@ -1072,128 +926,65 @@ struct server_context {
  }

  {
- const auto &
- if (
-
-
-
-
+ const auto & samplers = data.find("samplers");
+ if (samplers != data.end()) {
+ if (samplers->is_array()) {
+ std::vector<std::string> sampler_names;
+ for (const auto & name : *samplers) {
+ if (name.is_string()) {
+ sampler_names.emplace_back(name);
+ }
  }
+ slot.sparams.samplers = common_sampler_types_from_names(sampler_names, false);
+ } else if (samplers->is_string()){
+ std::string sampler_string;
+ for (const auto & name : *samplers) {
+ sampler_string += name;
+ }
+ slot.sparams.samplers = common_sampler_types_from_chars(sampler_string);
  }
- slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
  } else {
- slot.sparams.
+ slot.sparams.samplers = default_sparams.samplers;
  }
  }

  {
- if (slot.
-
+ if (slot.smpl != nullptr) {
+ common_sampler_free(slot.smpl);
  }
-
-
+
+ slot.smpl = common_sampler_init(model, slot.sparams);
+ if (slot.smpl == nullptr) {
  // for now, the only error that may happen here is invalid grammar
  send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
  return false;
  }
  }

- slot.
- slot.prompt_tokens.clear();
+ slot.state = SLOT_STATE_STARTED;

-
- {"id_slot", slot.id},
- {"id_task", slot.id_task},
- });
+ SLT_INF(slot, "%s", "processing task\n");

  return true;
  }

  void kv_cache_clear() {
-
+ SRV_DBG("%s", "clearing KV cache\n");

  // clear the entire KV cache
  llama_kv_cache_clear(ctx);
  clean_kv_cache = false;
  }

- void system_prompt_update() {
- LOG_VERBOSE("system prompt update", {
- {"system_prompt", system_prompt},
- });
-
- kv_cache_clear();
- system_tokens.clear();
-
- if (!system_prompt.empty()) {
- system_tokens = ::llama_tokenize(ctx, system_prompt, true);
-
- llama_batch_clear(batch);
-
- for (int i = 0; i < (int)system_tokens.size(); ++i) {
- llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
- }
-
- const int32_t n_batch = llama_n_batch(ctx);
-
- for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
- const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
- llama_batch batch_view = {
- n_tokens,
- batch.token + i,
- nullptr,
- batch.pos + i,
- batch.n_seq_id + i,
- batch.seq_id + i,
- batch.logits + i,
- 0, 0, 0, // unused
- };
-
- if (llama_decode(ctx, batch_view) != 0) {
- LOG_ERROR("llama_decode() failed", {});
- return;
- }
- }
-
- // assign the system KV cache to all parallel sequences
- for (int32_t i = 1; i <= params.n_parallel; ++i) {
- llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
- }
- }
-
- system_need_update = false;
- }
-
- bool system_prompt_set(const std::string & sys_prompt) {
- system_prompt = sys_prompt;
-
- LOG_VERBOSE("system prompt process", {
- {"system_prompt", system_prompt},
- });
-
- // release all slots
- for (server_slot & slot : slots) {
- slot.release();
- }
-
- system_need_update = true;
- return true;
- }
-
  bool process_token(completion_token_output & result, server_slot & slot) {
  // remember which tokens were sampled - used for repetition penalties during sampling
- const std::string token_str =
+ const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
  slot.sampled = result.tok;

  // search stop word and delete it
  slot.generated_text += token_str;
  slot.has_next_token = true;

- if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
- // we can change penalty_prompt_tokens because it is always created from scratch each request
- slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
- }
-
  // check if there is incomplete UTF-8 character at the end
  bool incomplete = false;
  for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
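
The loop that begins above scans the last few bytes of the generated text to see whether it ends in the middle of a multi-byte UTF-8 sequence, so a partial character is never streamed to the client. A self-contained sketch of that kind of check (an illustration, not the server's exact logic):

#include <string>

// Detect whether a byte string ends inside a multi-byte UTF-8 sequence,
// in which case the trailing partial character should be held back.
static bool ends_with_incomplete_utf8(const std::string & s) {
    for (size_t i = 1; i <= 4 && i <= s.size(); ++i) {
        const unsigned char c = s[s.size() - i];
        if ((c & 0xC0) == 0x80) {
            continue;                       // continuation byte, keep scanning backwards
        }
        // c is a lead byte (or plain ASCII); how many bytes does its sequence need?
        size_t need = 1;
        if      ((c & 0xE0) == 0xC0) need = 2;
        else if ((c & 0xF0) == 0xE0) need = 3;
        else if ((c & 0xF8) == 0xF0) need = 4;
        return need > i;                    // incomplete if fewer bytes follow than required
    }
    return false;
}
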
@@ -1220,29 +1011,28 @@ struct server_context {
  size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

  const std::string str_test = slot.generated_text.substr(pos);
- bool
+ bool send_text = true;

  size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
  if (stop_pos != std::string::npos) {
- is_stop_full = true;
  slot.generated_text.erase(
  slot.generated_text.begin() + pos + stop_pos,
  slot.generated_text.end());
  pos = std::min(slot.n_sent_text, slot.generated_text.size());
- } else {
- is_stop_full = false;
+ } else if (slot.has_next_token) {
  stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
+ send_text = stop_pos == std::string::npos;
  }

  // check if there is any token to predict
- if (
+ if (send_text) {
  // no send the stop word in the response
  result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
  slot.n_sent_text += result.text_to_send.size();
  // add the token to slot queue and cache
  }

- slot.
+ slot.add_token(result);
  if (slot.params.stream) {
  send_partial_response(slot, result);
  }
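
The rewritten logic above only sends text once it can no longer be the beginning of a stop word: a full match truncates the output, while a partial match at the very end (STOP_TYPE_PARTIAL) clears send_text until more tokens arrive and resolve it. A sketch of what such a partial stop-string scan can look like, with a hypothetical helper name and signature rather than the server's find_stopping_strings:

#include <algorithm>
#include <string>
#include <vector>

// Returns the position of the earliest suffix of `text` that is a proper
// prefix of any stop word, or std::string::npos if there is none.
static size_t find_partial_stop(const std::string & text, const std::vector<std::string> & stop_words) {
    size_t best = std::string::npos;
    for (const auto & word : stop_words) {
        if (word.empty()) {
            continue;
        }
        // try progressively shorter tails of `text` against prefixes of `word`
        const size_t max_len = std::min(word.size() - 1, text.size());
        for (size_t len = max_len; len > 0; --len) {
            if (text.compare(text.size() - len, len, word, 0, len) == 0) {
                best = std::min(best, text.size() - len);
                break;
            }
        }
    }
    return best;
}

If such a match exists at the tail, the generated text is held back for now; otherwise it is safe to stream.
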
@@ -1257,124 +1047,155 @@ struct server_context {
|
|
|
1257
1047
|
slot.stopped_limit = true;
|
|
1258
1048
|
slot.has_next_token = false;
|
|
1259
1049
|
|
|
1260
|
-
|
|
1261
|
-
|
|
1262
|
-
|
|
1263
|
-
|
|
1264
|
-
|
|
1265
|
-
|
|
1050
|
+
SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
|
|
1051
|
+
}
|
|
1052
|
+
|
|
1053
|
+
if (slot.has_new_line) {
|
|
1054
|
+
// if we have already seen a new line, we stop after a certain time limit
|
|
1055
|
+
if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
|
|
1056
|
+
slot.stopped_limit = true;
|
|
1057
|
+
slot.has_next_token = false;
|
|
1058
|
+
|
|
1059
|
+
SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
|
|
1060
|
+
}
|
|
1061
|
+
|
|
1062
|
+
// require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
|
|
1063
|
+
if (slot.params.n_indent > 0) {
|
|
1064
|
+
// check the current indentation
|
|
1065
|
+
// TODO: improve by not doing it more than once for each new line
|
|
1066
|
+
if (slot.last_nl_pos > 0) {
|
|
1067
|
+
size_t pos = slot.last_nl_pos;
|
|
1068
|
+
|
|
1069
|
+
int n_indent = 0;
|
|
1070
|
+
while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
|
|
1071
|
+
n_indent++;
|
|
1072
|
+
pos++;
|
|
1073
|
+
}
|
|
1074
|
+
|
|
1075
|
+
if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
|
|
1076
|
+
slot.stopped_limit = true;
|
|
1077
|
+
slot.has_next_token = false;
|
|
1078
|
+
|
|
1079
|
+
// cut the last line
|
|
1080
|
+
slot.generated_text.erase(pos, std::string::npos);
|
|
1081
|
+
|
|
1082
|
+
SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
|
|
1083
|
+
}
|
|
1084
|
+
}
|
|
1085
|
+
|
|
1086
|
+
// find the next new line
|
|
1087
|
+
{
|
|
1088
|
+
const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
|
|
1089
|
+
|
|
1090
|
+
if (pos != std::string::npos) {
|
|
1091
|
+
slot.last_nl_pos = pos + 1;
|
|
1092
|
+
}
|
|
1093
|
+
}
|
|
1094
|
+
}
|
|
1095
|
+
}
|
|
1096
|
+
|
|
1097
|
+
// check if there is a new line in the generated text
|
|
1098
|
+
if (result.text_to_send.find('\n') != std::string::npos) {
|
|
1099
|
+
slot.has_new_line = true;
|
|
1100
|
+
}
|
|
1101
|
+
|
|
1102
|
+
// if context shift is disabled, we stop when it reaches the context limit
|
|
1103
|
+
if (slot.n_past >= slot.n_ctx) {
|
|
1104
|
+
slot.truncated = true;
|
|
1105
|
+
slot.stopped_limit = true;
|
|
1106
|
+
slot.has_next_token = false;
|
|
1107
|
+
|
|
1108
|
+
SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
|
|
1109
|
+
slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
|
|
1266
1110
|
}
|
|
1267
1111
|
|
|
1268
1112
|
if (llama_token_is_eog(model, result.tok)) {
|
|
1269
1113
|
slot.stopped_eos = true;
|
|
1270
1114
|
slot.has_next_token = false;
|
|
1271
1115
|
|
|
1272
|
-
|
|
1273
|
-
}
|
|
1274
|
-
|
|
1275
|
-
auto n_ctx_train = llama_n_ctx_train(model);
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
LOG_WARNING("n_predict is not set and self-context extend is disabled."
|
|
1279
|
-
" Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
|
|
1280
|
-
{ "id_slot", slot.id },
|
|
1281
|
-
{ "params.n_predict", slot.params.n_predict },
|
|
1282
|
-
{ "slot.n_prompt_tokens", slot.n_prompt_tokens },
|
|
1283
|
-
{ "slot.n_decoded", slot.n_decoded },
|
|
1284
|
-
{ "slot.n_predict", slot.n_predict },
|
|
1285
|
-
{ "n_slots", params.n_parallel },
|
|
1286
|
-
{ "slot.n_ctx", slot.n_ctx },
|
|
1287
|
-
{ "n_ctx", n_ctx },
|
|
1288
|
-
{ "n_ctx_train", n_ctx_train },
|
|
1289
|
-
{ "ga_n", slot.ga_n },
|
|
1290
|
-
});
|
|
1116
|
+
SLT_DBG(slot, "%s", "stopped by EOS\n");
|
|
1117
|
+
}
|
|
1118
|
+
|
|
1119
|
+
const auto n_ctx_train = llama_n_ctx_train(model);
|
|
1120
|
+
|
|
1121
|
+
if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
|
|
1291
1122
|
slot.truncated = true;
|
|
1292
1123
|
slot.stopped_limit = true;
|
|
1293
1124
|
slot.has_next_token = false; // stop prediction
|
|
1125
|
+
|
|
1126
|
+
SLT_WRN(slot,
|
|
1127
|
+
"n_predict (%d) is set for infinite generation. "
|
|
1128
|
+
"Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
|
|
1129
|
+
slot.params.n_predict, n_ctx_train);
|
|
1294
1130
|
}
|
|
1295
1131
|
|
|
1296
|
-
|
|
1297
|
-
{"id_slot", slot.id},
|
|
1298
|
-
{"id_task", slot.id_task},
|
|
1299
|
-
{"token", result.tok},
|
|
1300
|
-
{"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
|
|
1301
|
-
{"has_next_token", slot.has_next_token},
|
|
1302
|
-
{"n_remain", slot.n_remaining},
|
|
1303
|
-
{"n_decoded", slot.n_decoded},
|
|
1304
|
-
{"stopped_eos", slot.stopped_eos},
|
|
1305
|
-
{"stopped_word", slot.stopped_word},
|
|
1306
|
-
{"stopped_limit", slot.stopped_limit},
|
|
1307
|
-
{"stopping_word", slot.stopping_word},
|
|
1308
|
-
});
|
|
1132
|
+
SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
|
|
1309
1133
|
|
|
1310
1134
|
return slot.has_next_token; // continue
|
|
1311
1135
|
}
|
|
1312
1136
|
|
|
1313
1137
|
json get_formated_generation(const server_slot & slot) const {
|
|
1314
|
-
|
|
1315
|
-
|
|
1316
|
-
|
|
1317
|
-
|
|
1318
|
-
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
|
|
1319
|
-
for (const auto & sampler_type : slot.sparams.samplers_sequence) {
|
|
1320
|
-
samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
|
|
1138
|
+
std::vector<std::string> samplers;
|
|
1139
|
+
samplers.reserve(slot.sparams.samplers.size());
|
|
1140
|
+
for (const auto & sampler : slot.sparams.samplers) {
|
|
1141
|
+
samplers.emplace_back(common_sampler_type_to_str(sampler));
|
|
1321
1142
|
}
|
|
1322
1143
|
|
|
1323
1144
|
return json {
|
|
1324
1145
|
{"n_ctx", slot.n_ctx},
|
|
1325
|
-
{"n_predict", slot.n_predict},
|
|
1146
|
+
{"n_predict", slot.n_predict}, // Server configured n_predict
|
|
1326
1147
|
{"model", params.model_alias},
|
|
1327
1148
|
{"seed", slot.sparams.seed},
|
|
1149
|
+
{"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
|
|
1328
1150
|
{"temperature", slot.sparams.temp},
|
|
1329
1151
|
{"dynatemp_range", slot.sparams.dynatemp_range},
|
|
1330
1152
|
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
|
|
1331
1153
|
{"top_k", slot.sparams.top_k},
|
|
1332
1154
|
{"top_p", slot.sparams.top_p},
|
|
1333
1155
|
{"min_p", slot.sparams.min_p},
|
|
1334
|
-
{"
|
|
1335
|
-
{"
|
|
1156
|
+
{"xtc_probability", slot.sparams.xtc_probability},
|
|
1157
|
+
{"xtc_threshold", slot.sparams.xtc_threshold},
|
|
1158
|
+
{"typical_p", slot.sparams.typ_p},
|
|
1336
1159
|
{"repeat_last_n", slot.sparams.penalty_last_n},
|
|
1337
1160
|
{"repeat_penalty", slot.sparams.penalty_repeat},
|
|
1338
1161
|
{"presence_penalty", slot.sparams.penalty_present},
|
|
1339
1162
|
{"frequency_penalty", slot.sparams.penalty_freq},
|
|
1340
|
-
{"
|
|
1341
|
-
{"
|
|
1163
|
+
{"dry_multiplier", slot.sparams.dry_multiplier},
|
|
1164
|
+
{"dry_base", slot.sparams.dry_base},
|
|
1165
|
+
{"dry_allowed_length", slot.sparams.dry_allowed_length},
|
|
1166
|
+
{"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
|
|
1167
|
+
{"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
|
|
1342
1168
|
{"mirostat", slot.sparams.mirostat},
|
|
1343
1169
|
{"mirostat_tau", slot.sparams.mirostat_tau},
|
|
1344
1170
|
{"mirostat_eta", slot.sparams.mirostat_eta},
|
|
1345
1171
|
{"penalize_nl", slot.sparams.penalize_nl},
|
|
1346
1172
|
{"stop", slot.params.antiprompt},
|
|
1347
|
-
{"
|
|
1173
|
+
{"max_tokens", slot.params.n_predict}, // User configured n_predict
|
|
1348
1174
|
{"n_keep", slot.params.n_keep},
|
|
1349
1175
|
{"n_discard", slot.params.n_discard},
|
|
1350
|
-
{"ignore_eos", ignore_eos},
|
|
1176
|
+
{"ignore_eos", slot.sparams.ignore_eos},
|
|
1351
1177
|
{"stream", slot.params.stream},
|
|
1352
|
-
|
|
1178
|
+
//{"logit_bias", slot.sparams.logit_bias},
|
|
1353
1179
|
{"n_probs", slot.sparams.n_probs},
|
|
1354
1180
|
{"min_keep", slot.sparams.min_keep},
|
|
1355
1181
|
{"grammar", slot.sparams.grammar},
|
|
1356
|
-
{"samplers",
|
|
1182
|
+
{"samplers", samplers},
|
|
1357
1183
|
};
|
|
1358
1184
|
}
|
|
1359
1185
|
|
|
1360
1186
|
void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
1361
|
-
send_error(task.id,
|
|
1187
|
+
send_error(task.id, error, type);
|
|
1362
1188
|
}
|
|
1363
1189
|
|
|
1364
1190
|
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
1365
|
-
send_error(slot.id_task,
|
|
1191
|
+
send_error(slot.id_task, error, type);
|
|
1366
1192
|
}
|
|
1367
1193
|
|
|
1368
|
-
void send_error(const int id_task, const
|
|
1369
|
-
|
|
1370
|
-
{"id_multi", id_multi},
|
|
1371
|
-
{"id_task", id_task},
|
|
1372
|
-
{"error", error},
|
|
1373
|
-
});
|
|
1194
|
+
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
1195
|
+
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
|
|
1374
1196
|
|
|
1375
1197
|
server_task_result res;
|
|
1376
1198
|
res.id = id_task;
|
|
1377
|
-
res.id_multi = id_multi;
|
|
1378
1199
|
res.stop = false;
|
|
1379
1200
|
res.error = true;
|
|
1380
1201
|
res.data = format_error_response(error, type);
|
|
@@ -1385,18 +1206,18 @@ struct server_context {
|
|
|
1385
1206
|
void send_partial_response(server_slot & slot, completion_token_output tkn) {
|
|
1386
1207
|
server_task_result res;
|
|
1387
1208
|
res.id = slot.id_task;
|
|
1388
|
-
res.id_multi = slot.id_multi;
|
|
1389
1209
|
res.error = false;
|
|
1390
1210
|
res.stop = false;
|
|
1391
1211
|
res.data = json {
|
|
1392
1212
|
{"content", tkn.text_to_send},
|
|
1393
1213
|
{"stop", false},
|
|
1394
1214
|
{"id_slot", slot.id},
|
|
1395
|
-
{"multimodal", false}
|
|
1215
|
+
{"multimodal", false},
|
|
1216
|
+
{"index", slot.index},
|
|
1396
1217
|
};
|
|
1397
1218
|
|
|
1398
1219
|
if (slot.sparams.n_probs > 0) {
|
|
1399
|
-
const
|
|
1220
|
+
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
|
1400
1221
|
const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
|
|
1401
1222
|
const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
|
|
1402
1223
|
|
|
@@ -1422,7 +1243,6 @@ struct server_context {
|
|
|
1422
1243
|
void send_final_response(const server_slot & slot) {
|
|
1423
1244
|
server_task_result res;
|
|
1424
1245
|
res.id = slot.id_task;
|
|
1425
|
-
res.id_multi = slot.id_multi;
|
|
1426
1246
|
res.error = false;
|
|
1427
1247
|
res.stop = true;
|
|
1428
1248
|
res.data = json {
|
|
@@ -1433,20 +1253,22 @@ struct server_context {
|
|
|
1433
1253
|
{"tokens_predicted", slot.n_decoded},
|
|
1434
1254
|
{"tokens_evaluated", slot.n_prompt_tokens},
|
|
1435
1255
|
{"generation_settings", get_formated_generation(slot)},
|
|
1436
|
-
{"prompt", slot.
|
|
1256
|
+
{"prompt", common_detokenize(ctx, slot.prompt_tokens)},
|
|
1257
|
+
{"has_new_line", slot.has_new_line},
|
|
1437
1258
|
{"truncated", slot.truncated},
|
|
1438
1259
|
{"stopped_eos", slot.stopped_eos},
|
|
1439
1260
|
{"stopped_word", slot.stopped_word},
|
|
1440
1261
|
{"stopped_limit", slot.stopped_limit},
|
|
1441
1262
|
{"stopping_word", slot.stopping_word},
|
|
1442
1263
|
{"tokens_cached", slot.n_past},
|
|
1443
|
-
{"timings", slot.get_formated_timings()}
|
|
1264
|
+
{"timings", slot.get_formated_timings()},
|
|
1265
|
+
{"index", slot.index},
|
|
1444
1266
|
};
|
|
1445
1267
|
|
|
1446
1268
|
if (slot.sparams.n_probs > 0) {
|
|
1447
1269
|
std::vector<completion_token_output> probs;
|
|
1448
1270
|
if (!slot.params.stream && slot.stopped_word) {
|
|
1449
|
-
const
|
|
1271
|
+
const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
|
|
1450
1272
|
|
|
1451
1273
|
size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
|
|
1452
1274
|
probs = std::vector<completion_token_output>(
|
|
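
When generation stopped on a stop word and streaming is off, the code above tokenizes the stop word and drops that many trailing entries from generated_token_probs, so the reported per-token probabilities line up with the text that is actually returned (the stop word itself is stripped). A small sketch of that trim, with generic types in place of the server's completion_token_output:

#include <algorithm>
#include <vector>

// Drop up to n_stop_word_tokens entries from the end of the probability list.
template <typename T>
static std::vector<T> drop_stop_word_tail(const std::vector<T> & probs, size_t n_stop_word_tokens) {
    const size_t n_drop = std::min(probs.size(), n_stop_word_tokens);
    return std::vector<T>(probs.begin(), probs.begin() + (probs.size() - n_drop));
}
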
@@ -1471,17 +1293,16 @@ struct server_context {
|
|
|
1471
1293
|
|
|
1472
1294
|
void send_embedding(const server_slot & slot, const llama_batch & batch) {
|
|
1473
1295
|
server_task_result res;
|
|
1474
|
-
res.id
|
|
1475
|
-
res.
|
|
1476
|
-
res.
|
|
1477
|
-
res.stop = true;
|
|
1296
|
+
res.id = slot.id_task;
|
|
1297
|
+
res.error = false;
|
|
1298
|
+
res.stop = true;
|
|
1478
1299
|
|
|
1479
1300
|
const int n_embd = llama_n_embd(model);
|
|
1480
1301
|
|
|
1481
1302
|
std::vector<float> embd_res(n_embd, 0.0f);
|
|
1482
1303
|
|
|
1483
1304
|
for (int i = 0; i < batch.n_tokens; ++i) {
|
|
1484
|
-
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id
|
|
1305
|
+
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
|
1485
1306
|
continue;
|
|
1486
1307
|
}
|
|
1487
1308
|
|
|
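
Before an embedding is returned, the hunk above passes it through common_embd_normalize into embd_res. Assuming this performs the usual L2 normalization (an assumption; the helper may also support other norms), the operation is equivalent to:

#include <cmath>
#include <vector>

// L2-normalize an n-dimensional embedding; returns all zeros if the norm is zero.
static std::vector<float> l2_normalize(const float * src, int n) {
    double sum = 0.0;
    for (int i = 0; i < n; ++i) {
        sum += (double) src[i] * src[i];
    }
    const double norm = std::sqrt(sum);
    std::vector<float> out(n);
    for (int i = 0; i < n; ++i) {
        out[i] = norm > 0.0 ? (float) (src[i] / norm) : 0.0f;
    }
    return out;
}
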
@@ -1491,150 +1312,239 @@ struct server_context {
|
|
|
1491
1312
|
}
|
|
1492
1313
|
|
|
1493
1314
|
if (embd == NULL) {
|
|
1494
|
-
|
|
1495
|
-
{"token", batch.token [i]},
|
|
1496
|
-
{"seq_id", batch.seq_id[i][0]}
|
|
1497
|
-
});
|
|
1315
|
+
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
|
1498
1316
|
|
|
1499
1317
|
res.data = json {
|
|
1500
1318
|
{"embedding", std::vector<float>(n_embd, 0.0f)},
|
|
1319
|
+
{"index", slot.index},
|
|
1501
1320
|
};
|
|
1502
1321
|
|
|
1503
1322
|
continue;
|
|
1504
1323
|
}
|
|
1505
1324
|
|
|
1506
|
-
|
|
1325
|
+
common_embd_normalize(embd, embd_res.data(), n_embd);
|
|
1507
1326
|
|
|
1508
1327
|
res.data = json {
|
|
1509
1328
|
{"embedding", embd_res},
|
|
1329
|
+
{"index", slot.index},
|
|
1510
1330
|
};
|
|
1511
1331
|
}
|
|
1512
1332
|
|
|
1333
|
+
SLT_DBG(slot, "%s", "sending embeddings\n");
|
|
1334
|
+
|
|
1513
1335
|
queue_results.send(res);
|
|
1514
1336
|
}
|
|
1515
1337
|
|
|
1516
|
-
void
|
|
1517
|
-
|
|
1518
|
-
|
|
1519
|
-
|
|
1520
|
-
|
|
1521
|
-
|
|
1522
|
-
|
|
1523
|
-
|
|
1524
|
-
|
|
1525
|
-
|
|
1526
|
-
// when a completion task's prompt array is not a singleton, we split it into multiple requests
|
|
1527
|
-
// otherwise, it's a single-prompt task, we actually queue it
|
|
1528
|
-
// if there's numbers in the prompt array it will be treated as an array of tokens
|
|
1529
|
-
if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
|
|
1530
|
-
bool numbers = false;
|
|
1531
|
-
for (const auto & e : task.data.at("prompt")) {
|
|
1532
|
-
if (e.is_number()) {
|
|
1533
|
-
numbers = true;
|
|
1534
|
-
break;
|
|
1535
|
-
}
|
|
1338
|
+
void send_rerank(const server_slot & slot, const llama_batch & batch) {
|
|
1339
|
+
server_task_result res;
|
|
1340
|
+
res.id = slot.id_task;
|
|
1341
|
+
res.error = false;
|
|
1342
|
+
res.stop = true;
|
|
1343
|
+
|
|
1344
|
+
for (int i = 0; i < batch.n_tokens; ++i) {
|
|
1345
|
+
if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
|
|
1346
|
+
continue;
|
|
1536
1347
|
}
|
|
1537
1348
|
|
|
1538
|
-
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
// if there are numbers, it needs to be treated like a single prompt,
|
|
1542
|
-
// queue_tasks handles a mix of strings and numbers just fine.
|
|
1543
|
-
if (numbers) {
|
|
1544
|
-
queue_tasks.post(task);
|
|
1545
|
-
} else {
|
|
1546
|
-
split_multiprompt_task(id_task, task);
|
|
1349
|
+
const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
|
|
1350
|
+
if (embd == NULL) {
|
|
1351
|
+
embd = llama_get_embeddings_ith(ctx, i);
|
|
1547
1352
|
}
|
|
1548
|
-
} else {
|
|
1549
|
-
queue_tasks.post(task);
|
|
1550
|
-
}
|
|
1551
|
-
}
|
|
1552
1353
|
|
|
1553
|
-
|
|
1554
|
-
|
|
1555
|
-
task.type = SERVER_TASK_TYPE_CANCEL;
|
|
1556
|
-
task.id_target = id_task;
|
|
1354
|
+
if (embd == NULL) {
|
|
1355
|
+
SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
|
|
1557
1356
|
|
|
1558
|
-
|
|
1559
|
-
|
|
1357
|
+
res.data = json {
|
|
1358
|
+
{"index", slot.index},
|
|
1359
|
+
{"score", -1e6},
|
|
1360
|
+
};
|
|
1560
1361
|
|
|
1561
|
-
|
|
1562
|
-
|
|
1563
|
-
if (prompt_count <= 1) {
|
|
1564
|
-
send_error(multiprompt_task, "error while handling multiple prompts");
|
|
1565
|
-
return;
|
|
1566
|
-
}
|
|
1362
|
+
continue;
|
|
1363
|
+
}
|
|
1567
1364
|
|
|
1568
|
-
|
|
1569
|
-
|
|
1570
|
-
|
|
1571
|
-
|
|
1365
|
+
res.data = json {
|
|
1366
|
+
{"index", slot.index},
|
|
1367
|
+
{"score", embd[0]},
|
|
1368
|
+
};
|
|
1572
1369
|
}
|
|
1573
1370
|
|
|
1574
|
-
|
|
1575
|
-
queue_tasks.add_multitask(id_multi, subtask_ids);
|
|
1371
|
+
SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
|
|
1576
1372
|
|
|
1577
|
-
|
|
1578
|
-
for (int i = 0; i < prompt_count; i++) {
|
|
1579
|
-
json subtask_data = multiprompt_task.data;
|
|
1580
|
-
subtask_data["prompt"] = subtask_data.at("prompt")[i];
|
|
1581
|
-
|
|
1582
|
-
// subtasks inherit everything else (infill mode, embedding mode, etc.)
|
|
1583
|
-
request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding);
|
|
1584
|
-
}
|
|
1373
|
+
queue_results.send(res);
|
|
1585
1374
|
}
|
|
1586
1375
|
|
|
1587
|
-
|
|
1588
|
-
|
|
1589
|
-
|
|
1590
|
-
{
|
|
1591
|
-
const int id_slot = json_value(task.data, "id_slot", -1);
|
|
1376
|
+
//
|
|
1377
|
+
// Functions to create new task(s) and receive result(s)
|
|
1378
|
+
//
|
|
1592
1379
|
|
|
1593
|
-
|
|
1380
|
+
// break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
|
|
1381
|
+
std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
|
|
1382
|
+
std::vector<server_task> tasks;
|
|
1383
|
+
auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
|
|
1384
|
+
SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
|
|
1385
|
+
server_task task;
|
|
1386
|
+
task.id = queue_tasks.get_new_id();
|
|
1387
|
+
task.inf_type = inf_type;
|
|
1388
|
+
task.type = SERVER_TASK_TYPE_INFERENCE;
|
|
1389
|
+
task.data = task_data;
|
|
1390
|
+
task.prompt_tokens = std::move(prompt_tokens);
|
|
1391
|
+
tasks.push_back(std::move(task));
|
|
1392
|
+
};
|
|
1594
1393
|
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
|
|
1598
|
-
|
|
1599
|
-
if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
|
|
1600
|
-
prompt = json_value(task.data, "prompt", std::string());
|
|
1601
|
-
}
|
|
1394
|
+
static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
|
|
1395
|
+
if (!data.contains("prompt")) {
|
|
1396
|
+
throw std::runtime_error(error_msg);
|
|
1397
|
+
}
|
|
1602
1398
|
|
|
1603
|
-
|
|
1399
|
+
// because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
|
|
1400
|
+
bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
|
|
1401
|
+
std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
|
|
1402
|
+
switch (inf_type) {
|
|
1403
|
+
case SERVER_TASK_INF_TYPE_RERANK:
|
|
1404
|
+
{
|
|
1405
|
+
// prompts[0] is the question
|
|
1406
|
+
// the rest are the answers/documents
|
|
1407
|
+
GGML_ASSERT(tokenized_prompts.size() > 1);
|
|
1408
|
+
SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
|
|
1409
|
+
for (size_t i = 1; i < tokenized_prompts.size(); i++) {
|
|
1410
|
+
data["index"] = i - 1;
|
|
1411
|
+
auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
|
|
1412
|
+
create_task(data, tokens);
|
|
1604
1413
|
}
|
|
1414
|
+
} break;
|
|
1415
|
+
case SERVER_TASK_INF_TYPE_INFILL:
|
|
1416
|
+
{
|
|
1417
|
+
SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
|
1418
|
+
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
|
1419
|
+
data["index"] = i;
|
|
1420
|
+
auto tokens = format_infill(
|
|
1421
|
+
ctx,
|
|
1422
|
+
data.at("input_prefix"),
|
|
1423
|
+
data.at("input_suffix"),
|
|
1424
|
+
data.at("input_extra"),
|
|
1425
|
+
params.n_batch,
|
|
1426
|
+
params.n_predict,
|
|
1427
|
+
slots[0].n_ctx, // TODO: there should be a better way
|
|
1428
|
+
params.spm_infill,
|
|
1429
|
+
tokenized_prompts[i]
|
|
1430
|
+
);
|
|
1431
|
+
create_task(data, tokens);
|
|
1432
|
+
}
|
|
1433
|
+
} break;
|
|
1434
|
+
default:
|
|
1435
|
+
{
|
|
1436
|
+
SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
|
|
1437
|
+
for (size_t i = 0; i < tokenized_prompts.size(); i++) {
|
|
1438
|
+
data["index"] = i;
|
|
1439
|
+
create_task(data, tokenized_prompts[i]);
|
|
1440
|
+
}
|
|
1441
|
+
}
|
|
1442
|
+
}
|
|
1443
|
+
|
|
1444
|
+
return tasks;
|
|
1445
|
+
}
|
|
1446
|
+
|
|
1447
|
+
void cancel_tasks(const std::unordered_set<int> & id_tasks) {
|
|
1448
|
+
std::vector<server_task> cancel_tasks;
|
|
1449
|
+
cancel_tasks.reserve(id_tasks.size());
|
|
1450
|
+
for (const auto & id_task : id_tasks) {
|
|
1451
|
+
SRV_WRN("cancel task, id_task = %d\n", id_task);
|
|
1452
|
+
|
|
1453
|
+
server_task task;
|
|
1454
|
+
task.type = SERVER_TASK_TYPE_CANCEL;
|
|
1455
|
+
task.id_target = id_task;
|
|
1456
|
+
cancel_tasks.push_back(task);
|
|
1457
|
+
queue_results.remove_waiting_task_id(id_task);
|
|
1458
|
+
}
|
|
1459
|
+
// push to beginning of the queue, so it has highest priority
|
|
1460
|
+
queue_tasks.post(cancel_tasks, true);
|
|
1461
|
+
}
|
|
1462
|
+
|
|
1463
|
+
// receive the results from task(s) created by create_tasks_inference
|
|
1464
|
+
void receive_cmpl_results(
|
|
1465
|
+
const std::unordered_set<int> & id_tasks,
|
|
1466
|
+
const std::function<void(std::vector<server_task_result>&)> & result_handler,
|
|
1467
|
+
const std::function<void(json)> & error_handler) {
|
|
1468
|
+
// TODO: currently, there is no way to detect the client has cancelled the request
|
|
1469
|
+
std::vector<server_task_result> results(id_tasks.size());
|
|
1470
|
+
for (size_t i = 0; i < id_tasks.size(); i++) {
|
|
1471
|
+
server_task_result result = queue_results.recv(id_tasks);
|
|
1472
|
+
|
|
1473
|
+
if (result.error) {
|
|
1474
|
+
error_handler(result.data);
|
|
1475
|
+
cancel_tasks(id_tasks);
|
|
1476
|
+
return;
|
|
1477
|
+
}
|
|
1478
|
+
|
|
1479
|
+
const size_t idx = result.data["index"];
|
|
1480
|
+
GGML_ASSERT(idx < results.size() && "index out of range");
|
|
1481
|
+
|
|
1482
|
+
results[idx] = result;
|
|
1483
|
+
}
|
|
1484
|
+
result_handler(results);
|
|
1485
|
+
}
|
|
1486
|
+
|
|
1487
|
+
// receive the results from task(s) created by create_tasks_inference, in stream mode
|
|
1488
|
+
void receive_cmpl_results_stream(
|
|
1489
|
+
const std::unordered_set<int> & id_tasks, const
|
|
1490
|
+
std::function<bool(server_task_result&)> & result_handler, const
|
|
1491
|
+
std::function<void(json)> & error_handler) {
|
|
1492
|
+
size_t n_finished = 0;
|
|
1493
|
+
while (true) {
|
|
1494
|
+
server_task_result result = queue_results.recv(id_tasks);
|
|
1495
|
+
if (!result_handler(result)) {
|
|
1496
|
+
cancel_tasks(id_tasks);
|
|
1497
|
+
break;
|
|
1498
|
+
}
|
|
1499
|
+
|
|
1500
|
+
if (result.error) {
|
|
1501
|
+
error_handler(result.data);
|
|
1502
|
+
cancel_tasks(id_tasks);
|
|
1503
|
+
break;
|
|
1504
|
+
}
|
|
1505
|
+
|
|
1506
|
+
if (result.stop) {
|
|
1507
|
+
if (++n_finished == id_tasks.size()) {
|
|
1508
|
+
break;
|
|
1509
|
+
}
|
|
1510
|
+
}
|
|
1511
|
+
}
|
|
1512
|
+
}
|
|
1513
|
+
|
|
1514
|
+
//
|
|
1515
|
+
// Functions to process the task
|
|
1516
|
+
//
|
|
1517
|
+
|
|
1518
|
+
void process_single_task(server_task task) {
|
|
1519
|
+
switch (task.type) {
|
|
1520
|
+
case SERVER_TASK_TYPE_INFERENCE:
|
|
1521
|
+
{
|
|
1522
|
+
const int id_slot = json_value(task.data, "id_slot", -1);
|
|
1523
|
+
|
|
1524
|
+
server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
|
|
1605
1525
|
|
|
1606
1526
|
if (slot == nullptr) {
|
|
1607
1527
|
// if no slot is available, we defer this task for processing later
|
|
1608
|
-
|
|
1528
|
+
SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
|
|
1609
1529
|
queue_tasks.defer(task);
|
|
1610
1530
|
break;
|
|
1611
1531
|
}
|
|
1612
|
-
if (
|
|
1532
|
+
if (slot->is_processing()) {
|
|
1613
1533
|
// if requested slot is unavailable, we defer this task for processing later
|
|
1614
|
-
|
|
1534
|
+
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
|
1615
1535
|
queue_tasks.defer(task);
|
|
1616
1536
|
break;
|
|
1617
1537
|
}
|
|
1618
1538
|
|
|
1619
|
-
if (task.data.contains("system_prompt")) {
|
|
1620
|
-
std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
|
|
1621
|
-
system_prompt_set(sys_prompt);
|
|
1622
|
-
|
|
1623
|
-
for (server_slot & slot : slots) {
|
|
1624
|
-
slot.n_past = 0;
|
|
1625
|
-
slot.n_past_se = 0;
|
|
1626
|
-
}
|
|
1627
|
-
}
|
|
1628
|
-
|
|
1629
1539
|
slot->reset();
|
|
1630
1540
|
|
|
1631
|
-
slot->id_task
|
|
1632
|
-
slot->
|
|
1633
|
-
slot->
|
|
1634
|
-
slot->
|
|
1541
|
+
slot->id_task = task.id;
|
|
1542
|
+
slot->inf_type = task.inf_type;
|
|
1543
|
+
slot->index = json_value(task.data, "index", 0);
|
|
1544
|
+
slot->prompt_tokens = std::move(task.prompt_tokens);
|
|
1635
1545
|
|
|
1636
1546
|
if (!launch_slot_with_task(*slot, task)) {
|
|
1637
|
-
|
|
1547
|
+
SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
|
|
1638
1548
|
break;
|
|
1639
1549
|
}
|
|
1640
1550
|
} break;
|
|
@@ -1661,12 +1571,13 @@ struct server_context {
|
|
|
1661
1571
|
|
|
1662
1572
|
for (server_slot & slot : slots) {
|
|
1663
1573
|
json slot_data = get_formated_generation(slot);
|
|
1664
|
-
slot_data["id"]
|
|
1665
|
-
slot_data["id_task"]
|
|
1666
|
-
slot_data["
|
|
1667
|
-
slot_data["prompt"]
|
|
1668
|
-
slot_data["next_token"]
|
|
1574
|
+
slot_data["id"] = slot.id;
|
|
1575
|
+
slot_data["id_task"] = slot.id_task;
|
|
1576
|
+
slot_data["is_processing"] = slot.is_processing();
|
|
1577
|
+
slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
|
|
1578
|
+
slot_data["next_token"] = {
|
|
1669
1579
|
{"has_next_token", slot.has_next_token},
|
|
1580
|
+
{"has_new_line", slot.has_new_line},
|
|
1670
1581
|
{"n_remain", slot.n_remaining},
|
|
1671
1582
|
{"n_decoded", slot.n_decoded},
|
|
1672
1583
|
{"stopped_eos", slot.stopped_eos},
|
|
@@ -1675,30 +1586,18 @@ struct server_context {
|
|
|
1675
1586
|
{"stopping_word", slot.stopping_word},
|
|
1676
1587
|
};
|
|
1677
1588
|
|
|
1678
|
-
if (
|
|
1679
|
-
n_idle_slots++;
|
|
1680
|
-
} else {
|
|
1589
|
+
if (slot.is_processing()) {
|
|
1681
1590
|
n_processing_slots++;
|
|
1591
|
+
} else {
|
|
1592
|
+
n_idle_slots++;
|
|
1682
1593
|
}
|
|
1683
1594
|
|
|
1684
1595
|
slots_data.push_back(slot_data);
|
|
1685
1596
|
}
|
|
1686
|
-
|
|
1687
|
-
{"id_task", task.id},
|
|
1688
|
-
{"n_idle_slots", n_idle_slots},
|
|
1689
|
-
{"n_processing_slots", n_processing_slots}
|
|
1690
|
-
});
|
|
1691
|
-
|
|
1692
|
-
LOG_VERBOSE("slot data", {
|
|
1693
|
-
{"id_task", task.id},
|
|
1694
|
-
{"n_idle_slots", n_idle_slots},
|
|
1695
|
-
{"n_processing_slots", n_processing_slots},
|
|
1696
|
-
{"slots", slots_data}
|
|
1697
|
-
});
|
|
1597
|
+
SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
|
|
1698
1598
|
|
|
1699
1599
|
server_task_result res;
|
|
1700
1600
|
res.id = task.id;
|
|
1701
|
-
res.id_multi = task.id_multi;
|
|
1702
1601
|
res.stop = true;
|
|
1703
1602
|
res.error = false;
|
|
1704
1603
|
res.data = {
|
|
@@ -1717,6 +1616,9 @@ struct server_context {
|
|
|
1717
1616
|
{ "n_tokens_predicted", metrics.n_tokens_predicted},
|
|
1718
1617
|
{ "t_tokens_generation", metrics.t_tokens_generation},
|
|
1719
1618
|
|
|
1619
|
+
{ "n_decode_total", metrics.n_decode_total},
|
|
1620
|
+
{ "n_busy_slots_total", metrics.n_busy_slots_total},
|
|
1621
|
+
|
|
1720
1622
|
{ "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
|
|
1721
1623
|
{ "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
|
|
1722
1624
|
|
|
@@ -1736,9 +1638,9 @@ struct server_context {
|
|
|
1736
1638
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
|
1737
1639
|
break;
|
|
1738
1640
|
}
|
|
1739
|
-
if (
|
|
1641
|
+
if (slot->is_processing()) {
|
|
1740
1642
|
// if requested slot is unavailable, we defer this task for processing later
|
|
1741
|
-
|
|
1643
|
+
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
|
1742
1644
|
queue_tasks.defer(task);
|
|
1743
1645
|
break;
|
|
1744
1646
|
}
|
|
@@ -1749,7 +1651,7 @@ struct server_context {
|
|
|
1749
1651
|
std::string filename = task.data.at("filename");
|
|
1750
1652
|
std::string filepath = task.data.at("filepath");
|
|
1751
1653
|
|
|
1752
|
-
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id
|
|
1654
|
+
const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
|
|
1753
1655
|
|
|
1754
1656
|
const int64_t t_end = ggml_time_us();
|
|
1755
1657
|
const double t_save_ms = (t_end - t_start) / 1000.0;
|
|
@@ -1777,9 +1679,9 @@ struct server_context {
|
|
|
1777
1679
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
|
1778
1680
|
break;
|
|
1779
1681
|
}
|
|
1780
|
-
if (
|
|
1682
|
+
if (slot->is_processing()) {
|
|
1781
1683
|
// if requested slot is unavailable, we defer this task for processing later
|
|
1782
|
-
|
|
1684
|
+
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
|
1783
1685
|
queue_tasks.defer(task);
|
|
1784
1686
|
break;
|
|
1785
1687
|
}
|
|
@@ -1791,7 +1693,7 @@ struct server_context {
|
|
|
1791
1693
|
|
|
1792
1694
|
slot->cache_tokens.resize(slot->n_ctx);
|
|
1793
1695
|
size_t token_count = 0;
|
|
1794
|
-
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id
|
|
1696
|
+
size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
|
|
1795
1697
|
if (nread == 0) {
|
|
1796
1698
|
slot->cache_tokens.resize(0);
|
|
1797
1699
|
send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
|
|
@@ -1825,16 +1727,16 @@ struct server_context {
|
|
|
1825
1727
|
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
|
|
1826
1728
|
break;
|
|
1827
1729
|
}
|
|
1828
|
-
if (
|
|
1730
|
+
if (slot->is_processing()) {
|
|
1829
1731
|
// if requested slot is unavailable, we defer this task for processing later
|
|
1830
|
-
|
|
1732
|
+
SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
|
|
1831
1733
|
queue_tasks.defer(task);
|
|
1832
1734
|
break;
|
|
1833
1735
|
}
|
|
1834
1736
|
|
|
1835
1737
|
// Erase token cache
|
|
1836
1738
|
const size_t n_erased = slot->cache_tokens.size();
|
|
1837
|
-
llama_kv_cache_seq_rm(ctx, slot->id
|
|
1739
|
+
llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
|
|
1838
1740
|
slot->cache_tokens.clear();
|
|
1839
1741
|
|
|
1840
1742
|
server_task_result result;
|
|
@@ -1847,69 +1749,34 @@ struct server_context {
|
|
|
1847
1749
|
};
|
|
1848
1750
|
queue_results.send(result);
|
|
1849
1751
|
} break;
|
|
1752
|
+
case SERVER_TASK_TYPE_SET_LORA:
|
|
1753
|
+
{
|
|
1754
|
+
common_lora_adapters_apply(ctx, loras);
|
|
1755
|
+
server_task_result result;
|
|
1756
|
+
result.id = task.id;
|
|
1757
|
+
result.stop = true;
|
|
1758
|
+
result.error = false;
|
|
1759
|
+
result.data = json{{ "success", true }};
|
|
1760
|
+
queue_results.send(result);
|
|
1761
|
+
} break;
|
|
1850
1762
|
}
|
|
1851
1763
|
}
|
|
1852
1764
|
|
|
1853
|
-
void on_finish_multitask(const server_task_multi & multitask) {
|
|
1854
|
-
// all subtasks done == multitask is done
|
|
1855
|
-
server_task_result result;
|
|
1856
|
-
result.id = multitask.id;
|
|
1857
|
-
result.stop = true;
|
|
1858
|
-
result.error = false;
|
|
1859
|
-
|
|
1860
|
-
// collect json results into one json result
|
|
1861
|
-
std::vector<json> result_jsons;
|
|
1862
|
-
for (const auto & subres : multitask.results) {
|
|
1863
|
-
result_jsons.push_back(subres.data);
|
|
1864
|
-
result.error = result.error && subres.error;
|
|
1865
|
-
}
|
|
1866
|
-
result.data = json {
|
|
1867
|
-
{ "results", result_jsons }
|
|
1868
|
-
};
|
|
1869
|
-
|
|
1870
|
-
queue_results.send(result);
|
|
1871
|
-
}
|
|
1872
|
-
|
|
1873
1765
|
void update_slots() {
|
|
1874
|
-
if (system_need_update) {
|
|
1875
|
-
system_prompt_update();
|
|
1876
|
-
}
|
|
1877
|
-
|
|
1878
|
-
// release slots
|
|
1879
|
-
for (auto & slot : slots) {
|
|
1880
|
-
if (slot.command == SLOT_COMMAND_RELEASE) {
|
|
1881
|
-
slot.state = SLOT_STATE_IDLE;
|
|
1882
|
-
slot.command = SLOT_COMMAND_NONE;
|
|
1883
|
-
slot.t_last_used = ggml_time_us();
|
|
1884
|
-
|
|
1885
|
-
LOG_INFO("slot released", {
|
|
1886
|
-
{"id_slot", slot.id},
|
|
1887
|
-
{"id_task", slot.id_task},
|
|
1888
|
-
{"n_ctx", n_ctx},
|
|
1889
|
-
{"n_past", slot.n_past},
|
|
1890
|
-
{"n_system_tokens", system_tokens.size()},
|
|
1891
|
-
{"n_cache_tokens", slot.cache_tokens.size()},
|
|
1892
|
-
{"truncated", slot.truncated}
|
|
1893
|
-
});
|
|
1894
|
-
|
|
1895
|
-
queue_tasks.notify_slot_changed();
|
|
1896
|
-
}
|
|
1897
|
-
}
|
|
1898
|
-
|
|
1899
1766
|
// check if all slots are idle
|
|
1900
1767
|
{
|
|
1901
1768
|
bool all_idle = true;
|
|
1902
1769
|
|
|
1903
1770
|
for (auto & slot : slots) {
|
|
1904
|
-
if (slot.
|
|
1771
|
+
if (slot.is_processing()) {
|
|
1905
1772
|
all_idle = false;
|
|
1906
1773
|
break;
|
|
1907
1774
|
}
|
|
1908
1775
|
}
|
|
1909
1776
|
|
|
1910
1777
|
if (all_idle) {
|
|
1911
|
-
|
|
1912
|
-
if (
|
|
1778
|
+
SRV_INF("%s", "all slots are idle\n");
|
|
1779
|
+
if (clean_kv_cache) {
|
|
1913
1780
|
kv_cache_clear();
|
|
1914
1781
|
}
|
|
1915
1782
|
|
|
@@ -1918,7 +1785,7 @@ struct server_context {
|
|
|
1918
1785
|
}
|
|
1919
1786
|
|
|
1920
1787
|
{
|
|
1921
|
-
|
|
1788
|
+
SRV_DBG("%s", "posting NEXT_RESPONSE\n");
|
|
1922
1789
|
|
|
1923
1790
|
server_task task;
|
|
1924
1791
|
task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
|
|
@@ -1930,59 +1797,51 @@ struct server_context {
|
|
|
1930
1797
|
// apply context-shift if needed
|
|
1931
1798
|
// TODO: simplify and improve
|
|
1932
1799
|
for (server_slot & slot : slots) {
|
|
1933
|
-
if (slot.
|
|
1934
|
-
if (
|
|
1935
|
-
//
|
|
1936
|
-
|
|
1937
|
-
|
|
1938
|
-
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
|
-
{"id_slot", slot.id},
|
|
1942
|
-
{"id_task", slot.id_task},
|
|
1943
|
-
{"n_keep", n_keep},
|
|
1944
|
-
{"n_left", n_left},
|
|
1945
|
-
{"n_discard", n_discard},
|
|
1946
|
-
{"n_ctx", n_ctx},
|
|
1947
|
-
{"n_past", slot.n_past},
|
|
1948
|
-
{"n_system_tokens", system_tokens.size()},
|
|
1949
|
-
{"n_cache_tokens", slot.cache_tokens.size()}
|
|
1950
|
-
});
|
|
1800
|
+
if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
|
|
1801
|
+
if (!params.ctx_shift) {
|
|
1802
|
+
// this check is redundant (for good)
|
|
1803
|
+
// we should never get here, because generation should already stopped in process_token()
|
|
1804
|
+
slot.release();
|
|
1805
|
+
send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
|
|
1806
|
+
continue;
|
|
1807
|
+
}
|
|
1951
1808
|
|
|
1952
|
-
|
|
1953
|
-
|
|
1809
|
+
// Shift context
|
|
1810
|
+
const int n_keep = slot.params.n_keep + add_bos_token;
|
|
1811
|
+
const int n_left = slot.n_past - n_keep;
|
|
1812
|
+
const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
|
|
1954
1813
|
|
|
1955
|
-
|
|
1956
|
-
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
|
|
1957
|
-
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
|
|
1958
|
-
}
|
|
1814
|
+
SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
|
|
1959
1815
|
|
|
1960
|
-
|
|
1961
|
-
|
|
1816
|
+
llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
|
|
1817
|
+
llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
|
|
1962
1818
|
|
|
1963
|
-
|
|
1819
|
+
if (slot.params.cache_prompt) {
|
|
1820
|
+
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
|
|
1821
|
+
slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
|
|
1822
|
+
}
|
|
1964
1823
|
|
|
1965
|
-
slot.
|
|
1824
|
+
slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
|
|
1966
1825
|
}
|
|
1826
|
+
|
|
1827
|
+
slot.n_past -= n_discard;
|
|
1828
|
+
|
|
1829
|
+
slot.truncated = true;
|
|
1967
1830
|
}
|
|
1968
1831
|
}
|
|
1969
1832
|
|
|
1970
1833
|
// start populating the batch for this iteration
|
|
1971
|
-
|
|
1834
|
+
common_batch_clear(batch);
|
|
1972
1835
|
|
|
1973
1836
|
// frist, add sampled tokens from any ongoing sequences
|
|
1974
1837
|
for (auto & slot : slots) {
|
|
1975
|
-
if (slot.state
|
|
1838
|
+
if (slot.state != SLOT_STATE_GENERATING) {
|
|
1976
1839
|
continue;
|
|
1977
1840
|
}
|
|
1978
1841
|
|
|
1979
1842
|
slot.i_batch = batch.n_tokens;
|
|
1980
1843
|
|
|
1981
|
-
|
|
1982
|
-
|
|
1983
|
-
// TODO: we always have to take into account the "system_tokens"
|
|
1984
|
-
// this is not great and needs to be improved somehow
|
|
1985
|
-
llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
|
|
1844
|
+
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
|
1986
1845
|
|
|
1987
1846
|
slot.n_past += 1;
|
|
1988
1847
|
|
|
@@ -1990,15 +1849,8 @@ struct server_context {
|
|
|
1990
1849
|
slot.cache_tokens.push_back(slot.sampled);
|
|
1991
1850
|
}
|
|
1992
1851
|
|
|
1993
|
-
|
|
1994
|
-
|
|
1995
|
-
{"id_task", slot.id_task},
|
|
1996
|
-
{"n_ctx", n_ctx},
|
|
1997
|
-
{"n_past", slot.n_past},
|
|
1998
|
-
{"n_system_tokens", system_tokens.size()},
|
|
1999
|
-
{"n_cache_tokens", slot.cache_tokens.size()},
|
|
2000
|
-
{"truncated", slot.truncated}
|
|
2001
|
-
});
|
|
1852
|
+
SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
|
|
1853
|
+
slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
|
|
2002
1854
|
}
|
|
2003
1855
|
|
|
2004
1856
|
// process in chunks of params.n_batch
|
|
@@ -2008,111 +1860,86 @@ struct server_context {
|
|
|
2008
1860
|
// track if this is an embedding or non-embedding batch
|
|
2009
1861
|
  // if we've added sampled tokens above, we are in non-embedding mode
  // -1: none, 0: non-embedding, 1: embedding
+ // TODO: make enum
  int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

  // next, batch any pending prompts without exceeding n_batch
  if (params.cont_batching || batch.n_tokens == 0) {
  for (auto & slot : slots) {
  // this slot still has a prompt to be processed
- if (slot.state ==
+ if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
  auto & prompt_tokens = slot.prompt_tokens;

- //
- if (
- LOG_VERBOSE("tokenizing prompt", {
- {"id_slot", slot.id},
- {"id_task", slot.id_task}
- });
-
+ // TODO: maybe move branch to outside of this loop in the future
+ if (slot.state == SLOT_STATE_STARTED) {
  slot.t_start_process_prompt = ggml_time_us();
  slot.t_start_generation = 0;

-
-
-
- if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
- params.input_suffix.erase(0, 1);
- suff_rm_leading_spc = false;
- }
-
- auto prefix_tokens = tokenize(slot.params.input_prefix, false);
- auto suffix_tokens = tokenize(slot.params.input_suffix, false);
-
- const int space_token = 29871; // TODO: this should not be hardcoded
- if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
- suffix_tokens.erase(suffix_tokens.begin());
- }
-
- prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
- suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
+ slot.n_past = 0;
+ slot.n_prompt_tokens = prompt_tokens.size();
+ slot.state = SLOT_STATE_PROCESSING_PROMPT;

-
- auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
- if (add_bos) {
- embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
- }
- embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
+ SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);

-
-
-
+ // print prompt tokens (for debugging)
+ if (1) {
+ // first 16 tokens (avoid flooding logs)
+ for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
  }
-
- prompt_tokens = embd_inp;
  } else {
-
+ // all
+ for (int i = 0; i < (int) prompt_tokens.size(); i++) {
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ }
  }

- slot.n_past = 0;
- slot.n_prompt_tokens = prompt_tokens.size();
-
- LOG_VERBOSE("prompt tokenized", {
- {"id_slot", slot.id},
- {"id_task", slot.id_task},
- {"n_ctx", slot.n_ctx},
- {"n_keep", slot.params.n_keep},
- {"n_prompt_tokens", slot.n_prompt_tokens},
- {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
- });
-
  // empty prompt passed -> release the slot and send empty response
  if (prompt_tokens.empty()) {
-
- {"id_slot", slot.id},
- {"id_task", slot.id_task}
- });
+ SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");

- slot.state = SLOT_STATE_PROCESSING;
- slot.command = SLOT_COMMAND_NONE;
  slot.release();
  slot.print_timings();
  send_final_response(slot);
  continue;
  }

- if (slot.
- // this prompt is too large to process - discard it
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
  if (slot.n_prompt_tokens > n_ubatch) {
- slot.state = SLOT_STATE_PROCESSING;
- slot.command = SLOT_COMMAND_NONE;
  slot.release();
  send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
  continue;
  }
+
+ if (slot.n_prompt_tokens > slot.n_ctx) {
+ slot.release();
+ send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
+ continue;
+ }
  } else {
+ if (!params.ctx_shift) {
+ // if context shift is disabled, we make sure prompt size is smaller than KV size
+ // TODO: there should be a separate parameter that control prompt truncation
+ // context shift should be applied only during the generation phase
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
+ slot.release();
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+ continue;
+ }
+ }
  if (slot.params.n_keep < 0) {
  slot.params.n_keep = slot.n_prompt_tokens;
  }
  slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);

- // if input prompt is too big, truncate it
- if (slot.
+ // if input prompt is too big, truncate it
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
  const int n_left = slot.n_ctx - slot.params.n_keep;

  const int n_block_size = n_left / 2;
  const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;

-
+ llama_tokens new_tokens(
  prompt_tokens.begin(),
  prompt_tokens.begin() + slot.params.n_keep);

@@ -2126,54 +1953,73 @@ struct server_context {
  slot.truncated = true;
  slot.n_prompt_tokens = prompt_tokens.size();

-
- {"id_slot", slot.id},
- {"id_task", slot.id_task},
- {"n_ctx", slot.n_ctx},
- {"n_keep", slot.params.n_keep},
- {"n_left", n_left},
- {"n_prompt_tokens", slot.n_prompt_tokens},
- {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
- });
+ SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);

  GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
  }

-
+ if (slot.params.cache_prompt) {
+ // reuse any previously computed tokens that are common with the new prompt
+ slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);

-
-
-
-
- GGML_ASSERT(slot.ga_n == 1);
+ // reuse chunks from the cached prompt by shifting their KV cache in the new position
+ if (params.n_cache_reuse > 0) {
+ size_t head_c = slot.n_past; // cache
+ size_t head_p = slot.n_past; // current prompt

-
-
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
+
+ while (head_c < slot.cache_tokens.size() &&
+ head_p < prompt_tokens.size()) {
+
+ size_t n_match = 0;
+ while (head_c + n_match < slot.cache_tokens.size() &&
+ head_p + n_match < prompt_tokens.size() &&
+ slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+
+ n_match++;
+ }
+
+ if (n_match >= (size_t) params.n_cache_reuse) {
+ SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+ //for (size_t i = head_p; i < head_p + n_match; i++) {
+ // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ //}
+
+ const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+ llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
+ llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
+
+ for (size_t i = 0; i < n_match; i++) {
+ slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+ slot.n_past++;
+ }

-
-
-
+ head_c += n_match;
+ head_p += n_match;
+ } else {
+ head_c += 1;
+ }
+ }
+
+ SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
  }
  }
  }

  if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
  // we have to evaluate at least 1 token to generate logits.
-
- { "id_slot", slot.id },
- { "id_task", slot.id_task }
- });
+ SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);

  slot.n_past--;
- if (slot.ga_i > 0) {
- slot.n_past_se--;
- }
  }

  slot.n_prompt_tokens_processed = 0;
  }

-
+ // non-causal tasks require to fit the entire prompt in the physical batch
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
  // cannot fit the prompt in the current batch - will try next iter
  if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
  continue;
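Aside (not part of the diff): the chunk-reuse loop introduced in the hunk above pairs up identical runs of tokens between the old cache and the new prompt, and keeps any run of at least `n_cache_reuse` tokens by shifting its KV-cache positions instead of re-evaluating it. Below is a minimal, self-contained C++ sketch of that matching step only, using plain `std::vector<int>` in place of the server's slot/cache types; the names and the example token values are made up for illustration, and the KV shift itself is only printed, not performed.

    #include <cstdio>
    #include <vector>

    // Toy version of the matching loop: walk the cached tokens (head_c) and the
    // new prompt (head_p) in parallel, count how long the next identical run is,
    // and "reuse" it when the run reaches n_cache_reuse tokens; otherwise skip
    // one cached token and try again.
    static int count_reusable(const std::vector<int> & cache, const std::vector<int> & prompt,
                              size_t n_past, size_t n_cache_reuse) {
        size_t head_c = n_past; // position in the cached tokens
        size_t head_p = n_past; // position in the new prompt
        int    reused = 0;

        while (head_c < cache.size() && head_p < prompt.size()) {
            size_t n_match = 0;
            while (head_c + n_match < cache.size() && head_p + n_match < prompt.size() &&
                   cache[head_c + n_match] == prompt[head_p + n_match]) {
                n_match++;
            }
            if (n_match >= n_cache_reuse) {
                // the real server shifts the KV cache of this run by (head_p - head_c)
                std::printf("reuse %zu tokens: cache [%zu, %zu) -> prompt [%zu, %zu)\n",
                            n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                reused += (int) n_match;
                head_c += n_match;
                head_p += n_match;
            } else {
                head_c += 1;
            }
        }
        return reused;
    }

    int main() {
        // the cache contains two tokens (40, 41) that the new prompt no longer has
        const std::vector<int> cache  = {1, 2, 3, 40, 41, 4, 5, 6, 7, 8};
        const std::vector<int> prompt = {1, 2, 3, 4, 5, 6, 7, 8};
        std::printf("reusable tokens: %d\n",
                    count_reusable(cache, prompt, /*n_past=*/3, /*n_cache_reuse=*/2));
    }

This is why the feature pays off for prompts that drop or rewrite a middle section while keeping a long tail unchanged: the tail is shifted into place rather than re-decoded.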
@@ -2181,7 +2027,10 @@ struct server_context {
  }

  // check that we are in the right batch_type, if not defer the slot
- bool slot_type =
+ const bool slot_type =
+ slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
+ slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
+
  if (batch_type == -1) {
  batch_type = slot_type;
  } else if (batch_type != slot_type) {
@@ -2189,88 +2038,53 @@ struct server_context {
  }

  // keep only the common part
-
- if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
+ if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
  // could not partially delete (likely using a non-Transformer model)
- llama_kv_cache_seq_rm(ctx, slot.id
-
- p0 = (int) system_tokens.size();
- if (p0 != 0) {
- // copy over the system prompt when there is one
- llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
- }
+ llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);

- // there is no common part left
+ // there is no common part left
  slot.n_past = 0;
- slot.n_past_se = 0;
- slot.ga_i = 0;
- // TODO: is the system prompt ever in the sampling context?
- llama_sampling_reset(slot.ctx_sampling);
  }

+ SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
+
  // remove the non-common part from the cache
  slot.cache_tokens.resize(slot.n_past);

- LOG_INFO("kv cache rm [p0, end)", {
- { "id_slot", slot.id },
- { "id_task", slot.id_task },
- { "p0", p0 }
- });
-
- int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
-
- int32_t ga_i = slot.ga_i;
- int32_t ga_n = slot.ga_n;
- int32_t ga_w = slot.ga_w;
-
  // add prompt tokens for processing in the current batch
-
-
- if (slot.ga_n != 1) {
- while (slot_npast >= ga_i + ga_w) {
- const int bd = (ga_w/ga_n)*(ga_n - 1);
- slot_npast -= bd;
- ga_i += ga_w/ga_n;
- }
- }
-
- llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
+ while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);

  if (slot.params.cache_prompt) {
  slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
  }

  slot.n_prompt_tokens_processed++;
-
+ slot.n_past++;
  }

-
- {"id_slot", slot.id},
- {"n_past", slot.n_past},
- {"n_ctx", n_ctx},
- {"n_tokens", batch.n_tokens},
- {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
- });
+ SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);

- // entire prompt has been processed
+ // entire prompt has been processed
  if (slot.n_past == slot.n_prompt_tokens) {
- slot.state
- slot.command = SLOT_COMMAND_NONE;
+ slot.state = SLOT_STATE_DONE_PROMPT;

  GGML_ASSERT(batch.n_tokens > 0);

+ common_sampler_reset(slot.smpl);
+
+ // Process all prompt tokens through sampler system
+ for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+ common_sampler_accept(slot.smpl, prompt_tokens[i], false);
+ }
+
  // extract the logits only for the last token
  batch.logits[batch.n_tokens - 1] = true;

  slot.n_decoded = 0;
  slot.i_batch = batch.n_tokens - 1;

-
- {"id_slot", slot.id},
- {"n_past", slot.n_past},
- {"n_ctx", n_ctx},
- {"n_tokens", batch.n_tokens},
- });
+ SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
  }
  }

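Aside (not part of the diff): once the prompt is fully batched, the hunk above resets the sampler and accepts every prompt token without sampling, so that state-dependent samplers (repetition penalties, grammars) see the full prompt before generation starts. The following is a rough, self-contained illustration of why that replay matters, using a hypothetical toy penalty state rather than llama.cpp's `common_sampler`; all names in it are invented for the example.

    #include <deque>
    #include <map>
    #include <vector>

    // Toy stand-in for a sampler that remembers the last n_prev tokens and
    // penalizes logits of tokens that already occurred. Replaying the prompt
    // through accept() is what makes the penalty aware of the prompt, which is
    // the role of the common_sampler_accept() loop in the hunk above.
    struct toy_penalty_state {
        size_t          n_prev;
        std::deque<int> prev;

        void reset() { prev.clear(); }

        void accept(int token) {
            prev.push_back(token);
            if (prev.size() > n_prev) {
                prev.pop_front();
            }
        }

        // apply a multiplicative penalty to recently seen tokens
        void apply(std::map<int, float> & logits, float penalty) const {
            for (int token : prev) {
                auto it = logits.find(token);
                if (it != logits.end() && it->second > 0.0f) {
                    it->second /= penalty;
                }
            }
        }
    };

    int main() {
        toy_penalty_state st{ /*n_prev=*/64, {} };
        st.reset();

        const std::vector<int> prompt = {7, 7, 12, 42};
        for (int t : prompt) {
            st.accept(t); // replay the prompt; no sampling happens here
        }

        std::map<int, float> logits = {{7, 2.0f}, {12, 1.5f}, {99, 1.0f}};
        st.apply(logits, /*penalty=*/1.3f); // tokens 7 and 12 are now penalized
        return 0;
    }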
@@ -2281,13 +2095,11 @@ struct server_context {
  }

  if (batch.n_tokens == 0) {
-
+ SRV_WRN("%s", "no tokens to decode\n");
  return;
  }

-
- {"n_tokens", batch.n_tokens},
- });
+ SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);

  // make sure we're in the right embedding mode
  llama_set_embeddings(ctx, batch_type == 1);
@@ -2296,35 +2108,6 @@ struct server_context {
  for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
  const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);

- for (auto & slot : slots) {
- if (slot.ga_n != 1) {
- // context extension via Self-Extend
- // TODO: simplify and/or abstract this
- while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
- const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
- const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
- const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
-
- LOG_TEE("\n");
- LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
- LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
- LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
-
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
- llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
-
- slot.n_past_se -= bd;
-
- slot.ga_i += slot.ga_w / slot.ga_n;
-
- LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
- }
-
- slot.n_past_se += n_tokens;
- }
- }
-
  llama_batch batch_view = {
  n_tokens,
  batch.token + i,
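Aside (not part of the diff): with the Self-Extend block removed, the loop above is reduced to plain windowing: the accumulated batch is walked in windows of at most `n_batch` tokens and a view over each window is decoded. A small standalone sketch of the same index arithmetic, with a hypothetical `decode_ok()` standing in for `llama_decode()`, including the "halve the window and retry" fallback used a few hunks further down:

    #include <algorithm>
    #include <cstdio>

    // Stand-in for llama_decode(); always succeeds in this sketch.
    static bool decode_ok(int /*first*/, int /*count*/) { return true; }

    int main() {
        const int n_tokens_total = 2500;
        int       n_batch        = 512;

        for (int i = 0; i < n_tokens_total; i += n_batch) {
            const int n_tokens = std::min(n_batch, n_tokens_total - i);

            if (!decode_ok(i, n_tokens)) {
                if (n_batch == 1) {
                    std::printf("cannot decode even a single token - giving up\n");
                    break;
                }
                // retry the same region with a smaller window:
                // the loop increment adds n_batch back, so i is unchanged
                n_batch /= 2;
                i -= n_batch;
                continue;
            }

            std::printf("decoded view [%d, %d)\n", i, i + n_tokens);
        }
        return 0;
    }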
@@ -2333,22 +2116,16 @@ struct server_context {
  batch.n_seq_id + i,
  batch.seq_id + i,
  batch.logits + i,
- 0, 0, 0, // unused
  };

  const int ret = llama_decode(ctx, batch_view);
+ metrics.on_decoded(slots);

  if (ret != 0) {
  if (n_batch == 1 || ret < 0) {
  // if you get here, it means the KV cache is full - try increasing it via the context size
-
- {"i", i},
- {"n_batch", ret},
- {"ret", ret},
- });
+ SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
  for (auto & slot : slots) {
- slot.state = SLOT_STATE_PROCESSING;
- slot.command = SLOT_COMMAND_NONE;
  slot.release();
  send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
  }
@@ -2359,32 +2136,42 @@ struct server_context {
  n_batch /= 2;
  i -= n_batch;

-
- {"i", i},
- {"n_batch", n_batch},
- {"ret", ret},
- });
+ SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);

  continue; // continue loop of n_batch
  }

  for (auto & slot : slots) {
- if (slot.
+ if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
  continue; // continue loop of slots
  }

-
-
-
-
-
+ if (slot.state == SLOT_STATE_DONE_PROMPT) {
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
+ // prompt evaluated for embedding
+ send_embedding(slot, batch_view);
+ slot.release();
+ slot.i_batch = -1;
+ continue; // continue loop of slots
+ }
+
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
+ send_rerank(slot, batch_view);
+ slot.release();
+ slot.i_batch = -1;
+ continue; // continue loop of slots
+ }
+
+ // prompt evaluated for next-token prediction
+ slot.state = SLOT_STATE_GENERATING;
+ } else if (slot.state != SLOT_STATE_GENERATING) {
  continue; // continue loop of slots
  }

  completion_token_output result;
- const llama_token id =
+ const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);

-
+ common_sampler_accept(slot.smpl, id, true);

  slot.n_decoded += 1;
  if (slot.n_decoded == 1) {
@@ -2393,37 +2180,19 @@ struct server_context {
  metrics.on_prompt_eval(slot);
  }

- llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
  result.tok = id;

- const
- if (n_probs > 0) {
- const size_t n_valid = slot.ctx_sampling->n_valid;
+ const auto * cur_p = common_sampler_get_candidates(slot.smpl);

-
-
-
-
-
- if (slot.sparams.temp == 0.0f) {
- // With greedy sampling the probabilities have possibly not been calculated.
- for (size_t i = 0; i < n_probs; ++i) {
- result.probs.push_back({
- cur_p.data[i].id,
- i == 0 ? 1.0f : 0.0f
- });
- }
- } else {
- for (size_t i = 0; i < n_probs; ++i) {
- result.probs.push_back({
- cur_p.data[i].id,
- i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
- });
- }
- }
+ for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+ result.probs.push_back({
+ cur_p->data[i].id,
+ i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+ });
  }

  if (!process_token(result, slot)) {
+ // release slot because of stop condition
  slot.release();
  slot.print_timings();
  send_final_response(slot);
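Aside (not part of the diff): the rewritten block above copies the first `n_probs` entries straight from the sampler's candidate array and assigns probability 0 to any index beyond the candidates that survived filtering. A compact standalone sketch of that copy, with plain structs instead of `llama_token_data` (the names here are invented, and out-of-range ids get a placeholder instead of reading past the array):

    #include <cstdio>
    #include <vector>

    struct candidate  { int id; float p; };  // stand-in for a sampler candidate
    struct token_prob { int tok; float prob; };

    // Copy the top n_probs entries; indices past the filtered candidate list get
    // probability 0, mirroring `i >= cur_p->size ? 0.0f : cur_p->data[i].p`.
    static std::vector<token_prob> top_probs(const std::vector<candidate> & cur, size_t n_probs) {
        std::vector<token_prob> out;
        out.reserve(n_probs);
        for (size_t i = 0; i < n_probs; ++i) {
            const int   id = i < cur.size() ? cur[i].id : -1;  // placeholder id
            const float p  = i < cur.size() ? cur[i].p  : 0.0f;
            out.push_back({id, p});
        }
        return out;
    }

    int main() {
        const std::vector<candidate> cur = {{42, 0.70f}, {7, 0.25f}, {13, 0.05f}};
        for (const auto & tp : top_probs(cur, 5)) {
            std::printf("token %d p = %.2f\n", tp.tok, tp.prob);
        }
        return 0;
    }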
@@ -2434,7 +2203,7 @@ struct server_context {
  }
  }

-
+ SRV_DBG("%s", "run slots completed\n");
  }

  json model_meta() const {
@@ -2455,19 +2224,10 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
  return;
  }

-
- {"remote_addr", req.remote_addr},
- {"remote_port", req.remote_port},
- {"status", res.status},
- {"method", req.method},
- {"path", req.path},
- {"params", req.params},
- });
+ LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);

-
-
- {"response", res.body},
- });
+ LOG_DBG("request: %s\n", req.body.c_str());
+ LOG_DBG("response: %s\n", res.body.c_str());
  }

  std::function<void(int)> shutdown_handler;
@@ -2485,28 +2245,22 @@ inline void signal_handler(int signal) {
  }

  int main(int argc, char ** argv) {
- #if SERVER_VERBOSE != 1
- log_disable();
- #endif
  // own arguments required by this example
-
+ common_params params;

- if (!
- gpt_params_print_usage(argc, argv, params);
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
  return 1;
  }

-
-
-
+ common_init();
+
+ // enabling this will output extra debug information in the HTTP responses from the server
+ // see format_final_response_oaicompat()
+ const bool verbose = params.verbosity > 9;

  // struct that contains llama context and inference
  server_context ctx_server;

- if (!params.system_prompt.empty()) {
- ctx_server.system_prompt_set(params.system_prompt);
- }
-
  if (params.model_alias == "unknown") {
  params.model_alias = params.model;
  }
@@ -2514,58 +2268,60 @@ int main(int argc, char ** argv) {
  llama_backend_init();
  llama_numa_init(params.numa);

-
-
-
-
-
-
-
- {"
- {"
- {"
-
+ LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
+ LOG_INF("\n");
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+ LOG_INF("\n");
+
+ // static files
+ std::map<std::string, server_static_file> static_files = {
+ { "/", { index_html, index_html_len, "text/html; charset=utf-8" }},
+ { "/completion.js", { completion_js, completion_js_len, "text/javascript; charset=utf-8" }},
+ { "/deps_daisyui.min.css", { deps_daisyui_min_css, deps_daisyui_min_css_len, "text/css; charset=utf-8" }},
+ { "/deps_markdown-it.js", { deps_markdown_it_js, deps_markdown_it_js_len, "text/javascript; charset=utf-8" }},
+ { "/deps_tailwindcss.js", { deps_tailwindcss_js, deps_tailwindcss_js_len, "text/javascript; charset=utf-8" }},
+ { "/deps_vue.esm-browser.js", { deps_vue_esm_browser_js, deps_vue_esm_browser_js_len, "text/javascript; charset=utf-8" }},
+ };

  std::unique_ptr<httplib::Server> svr;
  #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
  if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
-
+ LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
  svr.reset(
  new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
  );
  } else {
-
+ LOG_INF("Running without SSL\n");
  svr.reset(new httplib::Server());
  }
  #else
+ if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+ LOG_ERR("Server is built without SSL support\n");
+ return 1;
+ }
  svr.reset(new httplib::Server());
  #endif

  std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};

  svr->set_default_headers({{"Server", "llama.cpp"}});
-
- // CORS preflight
- svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
- res.set_header("Access-Control-Allow-Credentials", "true");
- res.set_header("Access-Control-Allow-Methods", "POST");
- res.set_header("Access-Control-Allow-Headers", "*");
- return res.set_content("", "application/json; charset=utf-8");
- });
-
  svr->set_logger(log_server_request);

- auto res_error = [](httplib::Response & res, json error_data) {
+ auto res_error = [](httplib::Response & res, const json & error_data) {
  json final_response {{"error", error_data}};
- res.set_content(final_response.dump(
+ res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
  res.status = json_value(error_data, "code", 500);
  };

+ auto res_ok = [](httplib::Response & res, const json & data) {
+ res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
+ res.status = 200;
+ };
+
  svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
  std::string message;
  try {
- std::rethrow_exception(
+ std::rethrow_exception(ep);
  } catch (std::exception & e) {
  message = e.what();
  } catch (...) {
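Aside (not part of the diff): both `res_error` and the new `res_ok` helper above serialize with `dump(-1, ' ', false, json::error_handler_t::replace)`, so invalid UTF-8 that occasionally appears in generated text is replaced rather than throwing mid-response. A tiny standalone example of that nlohmann::json behaviour, assuming the single-header nlohmann/json library is available on the include path:

    #include <iostream>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        // "\xF0\x28" is not a valid UTF-8 sequence; a plain dump() would throw.
        json j = { {"content", std::string("ok \xF0\x28 truncated byte")} };

        // error_handler_t::replace substitutes U+FFFD for invalid sequences,
        // which is what the server's res_ok/res_error helpers rely on.
        std::cout << j.dump(-1, ' ', false, json::error_handler_t::replace) << "\n";
        return 0;
    }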
@@ -2573,7 +2329,7 @@ int main(int argc, char ** argv) {
  }

  json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
-
+ LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
  res_error(res, formatted_error);
  });

@@ -2588,11 +2344,6 @@ int main(int argc, char ** argv) {
  svr->set_read_timeout (params.timeout_read);
  svr->set_write_timeout(params.timeout_write);

- if (!svr->bind_to_port(params.hostname, params.port)) {
- fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port);
- return 1;
- }
-
  std::unordered_map<std::string, std::string> log_data;

  log_data["hostname"] = params.hostname;
@@ -2608,54 +2359,15 @@ int main(int argc, char ** argv) {
  // Necessary similarity of prompt for slot selection
  ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;

- // load the model
- if (!ctx_server.load_model(params)) {
- state.store(SERVER_STATE_ERROR);
- return 1;
- } else {
- ctx_server.init();
- state.store(SERVER_STATE_READY);
- }
-
- LOG_INFO("model loaded", {});
-
- const auto model_meta = ctx_server.model_meta();
-
- // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
- if (params.chat_template.empty()) {
- if (!ctx_server.validate_model_chat_template()) {
- LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
- params.chat_template = "chatml";
- }
- }
-
- // print sample chat example to make it clear which template is used
- {
- LOG_INFO("chat template", {
- {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
- {"built_in", params.chat_template.empty()},
- });
- }
-
  //
  // Middlewares
  //

- auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
-
-
- "/
- "/
- "/completions",
- "/v1/completions",
- "/chat/completions",
- "/v1/chat/completions",
- "/infill",
- "/tokenize",
- "/detokenize",
- "/embedding",
- "/embeddings",
- "/v1/embeddings",
+ auto middleware_validate_api_key = [&params, &res_error, &static_files](const httplib::Request & req, httplib::Response & res) {
+ static const std::unordered_set<std::string> public_endpoints = {
+ "/health",
+ "/models",
+ "/v1/models",
  };

  // If API key is not set, skip validation
@@ -2663,8 +2375,8 @@ int main(int argc, char ** argv) {
  return true;
  }

- // If path is
- if (
+ // If path is public or is static file, skip validation
+ if (public_endpoints.find(req.path) != public_endpoints.end() || static_files.find(req.path) != static_files.end()) {
  return true;
  }

@@ -2680,17 +2392,42 @@ int main(int argc, char ** argv) {
  }

  // API key is invalid or not provided
- // TODO: make another middleware for CORS related logic
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
  res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));

-
+ LOG_WRN("Unauthorized: Invalid API Key\n");

  return false;
  };

+ auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
+ server_state current_state = state.load();
+ if (current_state == SERVER_STATE_LOADING_MODEL) {
+ auto tmp = string_split<std::string>(req.path, '.');
+ if (req.path == "/" || tmp.back() == "html") {
+ res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
+ res.status = 503;
+ } else {
+ res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
+ }
+ return false;
+ }
+ return true;
+ };
+
  // register server middlewares
- svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
+ svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ // If this is OPTIONS request, skip validation because browsers don't include Authorization header
+ if (req.method == "OPTIONS") {
+ res.set_header("Access-Control-Allow-Credentials", "true");
+ res.set_header("Access-Control-Allow-Methods", "GET, POST");
+ res.set_header("Access-Control-Allow-Headers", "*");
+ res.set_content("", "text/html"); // blank response, no data
+ return httplib::Server::HandlerResponse::Handled; // skip further processing
+ }
+ if (!middleware_server_state(req, res)) {
+ return httplib::Server::HandlerResponse::Handled;
+ }
  if (!middleware_validate_api_key(req, res)) {
  return httplib::Server::HandlerResponse::Handled;
  }
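Aside (not part of the diff): the pre-routing handler above now layers three checks before any route runs — CORS/OPTIONS handling, the server-state gate, and the API-key check, which exempts a small set of public endpoints plus the static files. A stripped-down sketch of that gatekeeping order without httplib; every name in it is assumed for illustration only.

    #include <string>
    #include <unordered_set>

    enum class gate { handled, unauthorized, pass };

    // Order matters, mirroring the handler above: OPTIONS preflight is always
    // answered, a loading server refuses work, and only then is the API key
    // checked - with public endpoints exempt.
    static gate pre_route(const std::string & method, const std::string & path,
                          const std::string & api_key, const std::string & provided_key,
                          bool loading_model) {
        static const std::unordered_set<std::string> public_endpoints = {
            "/health", "/models", "/v1/models",
        };

        if (method == "OPTIONS") {
            return gate::handled;                 // answered with CORS headers, no body
        }
        if (loading_model) {
            return gate::handled;                 // 503 "Loading model"
        }
        if (api_key.empty() || public_endpoints.count(path) > 0) {
            return gate::pass;                    // no key configured, or public path
        }
        return provided_key == api_key ? gate::pass : gate::unauthorized;
    }

    int main() {
        return pre_route("POST", "/completions", "secret", "secret", false) == gate::pass ? 0 : 1;
    }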
@@ -2701,99 +2438,57 @@ int main(int argc, char ** argv) {
  // Route handlers (or controllers)
  //

- const auto handle_health = [&](const httplib::Request
-
-
-
- {
- // request slots data using task queue
- server_task task;
- task.id = ctx_server.queue_tasks.get_new_id();
- task.type = SERVER_TASK_TYPE_METRICS;
- task.id_target = -1;
-
- ctx_server.queue_results.add_waiting_task_id(task.id);
- ctx_server.queue_tasks.post(task);
-
- // get the result
- server_task_result result = ctx_server.queue_results.recv(task.id);
- ctx_server.queue_results.remove_waiting_task_id(task.id);
-
- const int n_idle_slots = result.data.at("idle");
- const int n_processing_slots = result.data.at("processing");
-
- json health = {
- {"status", "ok"},
- {"slots_idle", n_idle_slots},
- {"slots_processing", n_processing_slots}
- };
-
- res.status = 200; // HTTP OK
- if (params.endpoint_slots && req.has_param("include_slots")) {
- health["slots"] = result.data.at("slots");
- }
-
- if (n_idle_slots == 0) {
- health["status"] = "no slot available";
- if (req.has_param("fail_on_no_slot")) {
- res.status = 503; // HTTP Service Unavailable
- }
- }
-
- res.set_content(health.dump(), "application/json");
- break;
- }
- case SERVER_STATE_LOADING_MODEL:
- {
- res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
- } break;
- case SERVER_STATE_ERROR:
- {
- res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
- } break;
- }
+ const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
+ // error and loading states are handled by middleware
+ json health = {{"status", "ok"}};
+ res_ok(res, health);
  };

- const auto handle_slots = [&](const httplib::Request
+ const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
  if (!params.endpoint_slots) {
- res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
+ res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
  return;
  }

  // request slots data using task queue
  server_task task;
  task.id = ctx_server.queue_tasks.get_new_id();
- task.id_multi = -1;
- task.id_target = -1;
  task.type = SERVER_TASK_TYPE_METRICS;

  ctx_server.queue_results.add_waiting_task_id(task.id);
- ctx_server.queue_tasks.post(task);
+ ctx_server.queue_tasks.post(task, true); // high-priority task

  // get the result
  server_task_result result = ctx_server.queue_results.recv(task.id);
  ctx_server.queue_results.remove_waiting_task_id(task.id);

-
-
+ // optionally return "fail_on_no_slot" error
+ const int n_idle_slots = result.data.at("idle");
+ if (req.has_param("fail_on_no_slot")) {
+ if (n_idle_slots == 0) {
+ res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
+ return;
+ }
+ }
+
+ res_ok(res, result.data.at("slots"));
  };

  const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
  if (!params.endpoint_metrics) {
- res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
+ res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
  return;
  }

  // request slots data using task queue
  server_task task;
  task.id = ctx_server.queue_tasks.get_new_id();
- task.id_multi = -1;
  task.id_target = -1;
  task.type = SERVER_TASK_TYPE_METRICS;
  task.data.push_back({{"reset_bucket", true}});

  ctx_server.queue_results.add_waiting_task_id(task.id);
- ctx_server.queue_tasks.post(task);
+ ctx_server.queue_tasks.post(task, true); // high-priority task

  // get the result
  server_task_result result = ctx_server.queue_results.recv(task.id);
@@ -2807,6 +2502,9 @@ int main(int argc, char ** argv) {
  const uint64_t n_tokens_predicted = data.at("n_tokens_predicted");
  const uint64_t t_tokens_generation = data.at("t_tokens_generation");

+ const uint64_t n_decode_total = data.at("n_decode_total");
+ const uint64_t n_busy_slots_total = data.at("n_busy_slots_total");
+
  const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");

  // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
@@ -2827,6 +2525,14 @@ int main(int argc, char ** argv) {
  {"name", "tokens_predicted_seconds_total"},
  {"help", "Predict process time"},
  {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3}
+ }, {
+ {"name", "n_decode_total"},
+ {"help", "Total number of llama_decode() calls"},
+ {"value", n_decode_total}
+ }, {
+ {"name", "n_busy_slots_per_decode"},
+ {"help", "Average number of busy slots per llama_decode() call"},
+ {"value", (float) n_busy_slots_total / (float) n_decode_total}
  }}},
  {"gauge", {{
  {"name", "prompt_tokens_seconds"},
@@ -2879,7 +2585,7 @@ int main(int argc, char ** argv) {
  res.status = 200; // HTTP OK
  };

- const auto handle_slots_save = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+ const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
  json request_data = json::parse(req.body);
  std::string filename = request_data.at("filename");
  if (!fs_validate_filename(filename)) {
@@ -2893,7 +2599,7 @@ int main(int argc, char ** argv) {
  task.data = {
  { "id_slot", id_slot },
  { "filename", filename },
- { "filepath", filepath }
+ { "filepath", filepath },
  };

  const int id_task = ctx_server.queue_tasks.post(task);
@@ -2905,11 +2611,11 @@ int main(int argc, char ** argv) {
  if (result.error) {
  res_error(res, result.data);
  } else {
- res
+ res_ok(res, result.data);
  }
  };

- const auto handle_slots_restore = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+ const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
  json request_data = json::parse(req.body);
  std::string filename = request_data.at("filename");
  if (!fs_validate_filename(filename)) {
@@ -2923,7 +2629,7 @@ int main(int argc, char ** argv) {
  task.data = {
  { "id_slot", id_slot },
  { "filename", filename },
- { "filepath", filepath }
+ { "filepath", filepath },
  };

  const int id_task = ctx_server.queue_tasks.post(task);
@@ -2935,11 +2641,11 @@ int main(int argc, char ** argv) {
  if (result.error) {
  res_error(res, result.data);
  } else {
- res
+ res_ok(res, result.data);
  }
  };

- const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
+ const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
  server_task task;
  task.type = SERVER_TASK_TYPE_SLOT_ERASE;
  task.data = {
@@ -2955,12 +2661,15 @@ int main(int argc, char ** argv) {
  if (result.error) {
  res_error(res, result.data);
  } else {
- res
+ res_ok(res, result.data);
  }
  };

- const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
-
+ const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
+ if (params.slot_save_path.empty()) {
+ res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
+ return;
+ }

  std::string id_slot_str = req.path_params.at("id_slot");
  int id_slot;
@@ -2985,298 +2694,262 @@ int main(int argc, char ** argv) {
|
|
|
2985
2694
|
}
|
|
2986
2695
|
};
|
|
2987
2696
|
|
|
2988
|
-
const auto handle_props = [&ctx_server](const httplib::Request
|
|
2989
|
-
std::string template_key = "tokenizer.chat_template", curr_tmpl;
|
|
2990
|
-
int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
|
|
2991
|
-
if (tlen > 0) {
|
|
2992
|
-
std::vector<char> curr_tmpl_buf(tlen + 1, 0);
|
|
2993
|
-
if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
|
|
2994
|
-
curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
|
|
2995
|
-
}
|
|
2996
|
-
}
|
|
2997
|
-
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
2697
|
+
const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
|
2998
2698
|
json data = {
|
|
2999
|
-
{ "system_prompt", ctx_server.system_prompt.c_str() },
|
|
3000
2699
|
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
|
3001
2700
|
{ "total_slots", ctx_server.params.n_parallel },
|
|
3002
|
-
{ "chat_template",
|
|
2701
|
+
{ "chat_template", llama_get_chat_template(ctx_server.model) },
|
|
3003
2702
|
};
|
|
3004
2703
|
|
|
3005
|
-
res
|
|
2704
|
+
res_ok(res, data);
|
|
3006
2705
|
};
|
|
3007
2706
|
|
|
3008
|
-
const auto
|
|
3009
|
-
if (ctx_server.params.
|
|
3010
|
-
res_error(res, format_error_response("This server does not support
|
|
2707
|
+
const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
|
2708
|
+
if (!ctx_server.params.endpoint_props) {
|
|
2709
|
+
res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
|
|
3011
2710
|
return;
|
|
3012
2711
|
}
|
|
3013
2712
|
|
|
3014
|
-
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
|
3015
|
-
|
|
3016
2713
|
json data = json::parse(req.body);
|
|
3017
2714
|
|
|
3018
|
-
|
|
2715
|
+
// update any props here
|
|
3019
2716
|
|
|
3020
|
-
|
|
3021
|
-
|
|
2717
|
+
res_ok(res, {{ "success", true }});
|
|
2718
|
+
};
|
|
3022
2719
|
|
|
3023
|
-
|
|
3024
|
-
|
|
3025
|
-
|
|
3026
|
-
|
|
3027
|
-
|
|
3028
|
-
res_error(res, result.data);
|
|
3029
|
-
}
|
|
2720
|
+
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
|
|
2721
|
+
if (ctx_server.params.embedding) {
|
|
2722
|
+
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
|
2723
|
+
return;
|
|
2724
|
+
}
|
|
3030
2725
|
|
|
3031
|
-
|
|
3032
|
-
|
|
3033
|
-
|
|
3034
|
-
while (true) {
|
|
3035
|
-
server_task_result result = ctx_server.queue_results.recv(id_task);
|
|
3036
|
-
if (!result.error) {
|
|
3037
|
-
const std::string str =
|
|
3038
|
-
"data: " +
|
|
3039
|
-
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
|
3040
|
-
"\n\n";
|
|
3041
|
-
|
|
3042
|
-
LOG_VERBOSE("data stream", {
|
|
3043
|
-
{ "to_send", str }
|
|
3044
|
-
});
|
|
3045
|
-
|
|
3046
|
-
if (!sink.write(str.c_str(), str.size())) {
|
|
3047
|
-
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
3048
|
-
return false;
|
|
3049
|
-
}
|
|
2726
|
+
std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
|
|
2727
|
+
ctx_server.queue_results.add_waiting_tasks(tasks);
|
|
2728
|
+
ctx_server.queue_tasks.post(tasks);
|
|
3050
2729
|
|
|
3051
|
-
|
|
3052
|
-
|
|
3053
|
-
}
|
|
3054
|
-
} else {
|
|
3055
|
-
const std::string str =
|
|
3056
|
-
"error: " +
|
|
3057
|
-
result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
|
|
3058
|
-
"\n\n";
|
|
3059
|
-
|
|
3060
|
-
LOG_VERBOSE("data stream", {
|
|
3061
|
-
{ "to_send", str }
|
|
3062
|
-
});
|
|
3063
|
-
|
|
3064
|
-
if (!sink.write(str.c_str(), str.size())) {
|
|
3065
|
-
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
3066
|
-
return false;
|
|
3067
|
-
}
|
|
2730
|
+
bool stream = json_value(data, "stream", false);
|
|
2731
|
+
const auto task_ids = server_task::get_list_id(tasks);
|
|
3068
2732
|
|
|
3069
|
-
|
|
2733
|
+
if (!stream) {
|
|
2734
|
+
ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
|
|
2735
|
+
if (results.size() == 1) {
|
|
2736
|
+
// single result
|
|
2737
|
+
res_ok(res, results[0].data);
|
|
2738
|
+
} else {
|
|
2739
|
+
// multiple results (multitask)
|
|
2740
|
+
json arr = json::array();
|
|
2741
|
+
for (const auto & res : results) {
|
|
2742
|
+
arr.push_back(res.data);
|
|
3070
2743
|
}
|
|
2744
|
+
res_ok(res, arr);
|
|
3071
2745
|
}
|
|
2746
|
+
}, [&](const json & error_data) {
|
|
2747
|
+
res_error(res, error_data);
|
|
2748
|
+
});
|
|
3072
2749
|
|
|
3073
|
-
|
|
2750
|
+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
|
2751
|
+
} else {
|
|
2752
|
+
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
|
|
2753
|
+
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
|
|
2754
|
+
return server_sent_event(sink, "data", result.data);
|
|
2755
|
+
}, [&](const json & error_data) {
|
|
2756
|
+
server_sent_event(sink, "error", error_data);
|
|
2757
|
+
});
|
|
3074
2758
|
sink.done();
|
|
3075
|
-
|
|
3076
|
-
return true;
|
|
2759
|
+
return false;
|
|
3077
2760
|
};
|
|
3078
2761
|
|
|
3079
|
-
auto on_complete = [
|
|
3080
|
-
|
|
3081
|
-
ctx_server.request_cancel(id_task);
|
|
3082
|
-
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
2762
|
+
auto on_complete = [task_ids, &ctx_server] (bool) {
|
|
2763
|
+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
|
|
3083
2764
|
};
|
|
3084
2765
|
|
|
3085
2766
|
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
|
|
3086
2767
|
}
|
|
3087
2768
|
};
|
|
3088
2769
|
|
|
3089
|
-
const auto
|
|
3090
|
-
|
|
3091
|
-
|
|
3092
|
-
json models = {
|
|
3093
|
-
{"object", "list"},
|
|
3094
|
-
{"data", {
|
|
3095
|
-
{
|
|
3096
|
-
{"id", params.model_alias},
|
|
3097
|
-
{"object", "model"},
|
|
3098
|
-
{"created", std::time(0)},
|
|
3099
|
-
{"owned_by", "llamacpp"},
|
|
3100
|
-
{"meta", model_meta}
|
|
3101
|
-
},
|
|
3102
|
-
}}
|
|
3103
|
-
};
|
|
3104
|
-
|
|
3105
|
-
res.set_content(models.dump(), "application/json; charset=utf-8");
|
|
2770
|
+
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
|
2771
|
+
json data = json::parse(req.body);
|
|
2772
|
+
return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
|
|
3106
2773
|
};
|
|
3107
2774
|
|
|
3108
|
-
const auto
|
|
3109
|
-
|
|
3110
|
-
|
|
2775
|
+
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
|
2776
|
+
// check model compatibility
|
|
2777
|
+
std::string err;
|
|
2778
|
+
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
|
2779
|
+
err += "prefix token is missing. ";
|
|
2780
|
+
}
|
|
2781
|
+
if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
|
2782
|
+
err += "suffix token is missing. ";
|
|
2783
|
+
}
|
|
2784
|
+
if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
|
2785
|
+
err += "middle token is missing. ";
|
|
2786
|
+
}
|
|
2787
|
+
if (!err.empty()) {
|
|
2788
|
+
res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
|
|
3111
2789
|
return;
|
|
3112
2790
|
}
|
|
3113
2791
|
|
|
3114
|
-
|
|
3115
|
-
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
|
3116
|
-
|
|
3117
|
-
const int id_task = ctx_server.queue_tasks.get_new_id();
|
|
3118
|
-
|
|
3119
|
-
ctx_server.queue_results.add_waiting_task_id(id_task);
|
|
3120
|
-
ctx_server.request_completion(id_task, -1, data, false, false);
|
|
2792
|
+
json data = json::parse(req.body);
|
|
3121
2793
|
|
|
3122
|
-
|
|
3123
|
-
if (!
|
|
3124
|
-
|
|
2794
|
+
// validate input
|
|
2795
|
+
if (!data.contains("input_prefix")) {
|
|
2796
|
+
res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
|
2797
|
+
}
|
|
3125
2798
|
|
|
3126
|
-
|
|
3127
|
-
|
|
2799
|
+
if (!data.contains("input_suffix")) {
|
|
2800
|
+
res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
|
|
2801
|
+
}
|
|
3128
2802
|
|
|
3129
|
-
|
|
3130
|
-
|
|
3131
|
-
|
|
2803
|
+
if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
|
|
2804
|
+
res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
|
|
2805
|
+
return;
|
|
2806
|
+
}
|
|
2807
|
+
json input_extra = json_value(data, "input_extra", json::array());
|
|
2808
|
+
for (const auto & chunk : input_extra) {
|
|
2809
|
+
// { "text": string, "filename": string }
|
|
2810
|
+
if (!chunk.contains("text") || !chunk.at("text").is_string()) {
|
|
2811
|
+
res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
|
|
2812
|
+
return;
|
|
2813
|
+
}
|
|
2814
|
+
// filename is optional
|
|
2815
|
+
if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
|
|
2816
|
+
res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
|
|
2817
|
+
return;
|
|
3132
2818
|
}
|
|
3133
|
-
ctx_server.queue_results.remove_waiting_task_id(id_task);
|
|
3134
|
-
} else {
|
|
3135
|
-
- const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
- while (true) {
- server_task_result result = ctx_server.queue_results.recv(id_task);
- if (!result.error) {
- std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
-
- for (auto it = result_array.begin(); it != result_array.end(); ++it) {
- if (!it->empty()) {
- const std::string str =
- "data: " +
- it->dump(-1, ' ', false, json::error_handler_t::replace) +
- "\n\n";
- LOG_VERBOSE("data stream", {{"to_send", str}});
- if (!sink.write(str.c_str(), str.size())) {
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- return false;
- }
- }
- }
- if (result.stop) {
- break;
- }
- } else {
- const std::string str =
- "error: " +
- result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
- "\n\n";
- LOG_VERBOSE("data stream", {{"to_send", str}});
- if (!sink.write(str.c_str(), str.size())) {
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- return false;
- }
- break;
- }
- }
- sink.done();
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- return true;
- };
-
- auto on_complete = [id_task, &ctx_server](bool) {
- // cancel request
- ctx_server.request_cancel(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- };
-
- res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
  }
+ data["input_extra"] = input_extra; // default to empty array if it's not exist
+
+ return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
  };

-
+ // TODO: maybe merge this function with "handle_completions_generic"
+ const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
  if (ctx_server.params.embedding) {
- res_error(res, format_error_response("This server does not support
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
  return;
  }

-
-
- json data = json::parse(req.body);
+ json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

-
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(tasks);

-
-
+ bool stream = json_value(data, "stream", false);
+ const auto task_ids = server_task::get_list_id(tasks);
+ const auto completion_id = gen_chatcmplid();

- if (!
- server_task_result
-
-
-
-
-
+ if (!stream) {
+ ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
+ // multitask is never support in chat completion, there is only one result
+ json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
+ res_ok(res, result_oai);
+ }, [&](const json & error_data) {
+ res_error(res, error_data);
+ });

- ctx_server.queue_results.
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
  } else {
- const auto chunked_content_provider = [
-
-
-
-
-
- result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
- "\n\n";
-
- LOG_VERBOSE("data stream", {
- { "to_send", str }
- });
-
- if (!sink.write(str.c_str(), str.size())) {
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- return false;
+ const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+ ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
+ std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
+ for (auto & event_data : result_array) {
+ if (event_data.empty()) {
+ continue; // skip the stop token
  }
-
-
- break;
+ if (!server_sent_event(sink, "data", event_data)) {
+ return false; // connection is closed
  }
- } else {
- break;
  }
-
-
-
+ return true; // ok
+ }, [&](const json & error_data) {
+ server_sent_event(sink, "error", error_data);
+ });
+ static const std::string ev_done = "data: [DONE]\n\n";
+ sink.write(ev_done.data(), ev_done.size());
  sink.done();
-
  return true;
  };

- auto on_complete = [
- ctx_server.
+ auto on_complete = [task_ids, &ctx_server] (bool) {
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
  };

  res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
  }
  };

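The rewritten handle_chat_completions above streams OpenAI-style server-sent events through server_sent_event() and terminates the stream with a `data: [DONE]` sentinel. The sketch below is not part of this diff; it is a minimal TypeScript client for the /v1/chat/completions route registered later in this file, assuming a local server on port 8080 and the usual OpenAI-compatible delta fields.

```ts
// Minimal sketch: consume the SSE stream produced by the handler above.
// Assumes http://localhost:8080 and OpenAI-compatible "choices[0].delta.content" chunks.
async function streamChat(prompt: string): Promise<string> {
  const res = await fetch("http://localhost:8080/v1/chat/completions", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      messages: [{ role: "user", content: prompt }],
      stream: true,
    }),
  });
  if (!res.ok || !res.body) throw new Error(`HTTP ${res.status}`);

  const reader = res.body.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  let text = "";
  for (;;) {
    const { done, value } = await reader.read();
    if (done) break;
    buffer += decoder.decode(value, { stream: true });
    // Events are separated by a blank line ("\n\n").
    let sep: number;
    while ((sep = buffer.indexOf("\n\n")) !== -1) {
      const event = buffer.slice(0, sep).trim();
      buffer = buffer.slice(sep + 2);
      if (!event.startsWith("data: ")) continue;
      const payload = event.slice("data: ".length);
      if (payload === "[DONE]") return text; // sentinel written by the handler
      const delta = JSON.parse(payload).choices?.[0]?.delta?.content;
      if (delta) text += delta;
    }
  }
  return text;
}
```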
- const auto
-
+ const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
+ json models = {
+ {"object", "list"},
+ {"data", {
+ {
+ {"id", params.model_alias},
+ {"object", "model"},
+ {"created", std::time(0)},
+ {"owned_by", "llamacpp"},
+ {"meta", ctx_server.model_meta()}
+ },
+ }}
+ };
+
+ res.set_content(models.dump(), MIMETYPE_JSON);
+ };
+
+ const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
  const json body = json::parse(req.body);

-
+ json tokens_response = json::array();
  if (body.count("content") != 0) {
  const bool add_special = json_value(body, "add_special", false);
-
+ const bool with_pieces = json_value(body, "with_pieces", false);
+
+ llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
+
+ if (with_pieces) {
+ for (const auto& token : tokens) {
+ std::string piece = common_token_to_piece(ctx_server.ctx, token);
+ json piece_json;
+
+ // Check if the piece is valid UTF-8
+ if (is_valid_utf8(piece)) {
+ piece_json = piece;
+ } else {
+ // If not valid UTF-8, store as array of byte values
+ piece_json = json::array();
+ for (unsigned char c : piece) {
+ piece_json.push_back(static_cast<int>(c));
+ }
+ }
+
+ tokens_response.push_back({
+ {"id", token},
+ {"piece", piece_json}
+ });
+ }
+ } else {
+ tokens_response = tokens;
+ }
  }
-
-
+
+ const json data = format_tokenizer_response(tokens_response);
+ res_ok(res, data);
  };

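The /tokenize handler above gains an optional with_pieces flag that returns each token id together with its piece, falling back to an array of byte values when the piece is not valid UTF-8. A hedged usage sketch (not part of the diff; the response field name "tokens" is inferred from format_tokenizer_response):

```ts
// Minimal sketch: call /tokenize with the new "with_pieces" flag.
// Response is either an array of token ids, or an array of { id, piece }
// objects where piece is a string or an array of byte values.
type TokenPiece = { id: number; piece: string | number[] };

async function tokenize(content: string, withPieces = false) {
  const res = await fetch("http://localhost:8080/tokenize", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ content, add_special: false, with_pieces: withPieces }),
  });
  if (!res.ok) throw new Error(`tokenize failed: HTTP ${res.status}`);
  const { tokens } = (await res.json()) as { tokens: number[] | TokenPiece[] };
  return tokens;
}

// Example: tokenize("Hello world", true) might yield entries such as
// { id: 9906, piece: "Hello" } (token ids are model-specific).
```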
- const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
  const json body = json::parse(req.body);

  std::string content;
  if (body.count("tokens") != 0) {
- const
+ const llama_tokens tokens = body.at("tokens");
  content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
  }

  const json data = format_detokenized_response(content);
-
+ res_ok(res, data);
  };

- const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
+ const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
  const json body = json::parse(req.body);
  bool is_openai = false;

@@ -3294,42 +2967,157 @@ int main(int argc, char ** argv) {
  }

  // create and queue the task
- json responses;
+ json responses = json::array();
+ bool error = false;
  {
-
- ctx_server.queue_results.
- ctx_server.
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(tasks);

  // get the result
-
-
-
-
-
- responses = result.data.at("results");
- } else {
- // result for single task
- responses = std::vector<json>{result.data};
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+ ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+ for (const auto & res : results) {
+ responses.push_back(res.data);
  }
- }
-
-
-
-
+ }, [&](const json & error_data) {
+ res_error(res, error_data);
+ error = true;
+ });
+
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+ }
+
+ if (error) {
+ return;
  }

  // write JSON response
  json root = is_openai
  ? format_embeddings_response_oaicompat(body, responses)
  : responses[0];
-
+ res_ok(res, root);
  };

- auto
-
- res
- return
- }
+ const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+ if (!ctx_server.params.reranking || ctx_server.params.embedding) {
+ res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
+ return;
+ }
+
+ const json body = json::parse(req.body);
+
+ // TODO: implement
+ //int top_n = 1;
+ //if (body.count("top_n") != 1) {
+ // top_n = body.at("top_n");
+ //} else {
+ // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+ // return;
+ //}
+
+ json query;
+ if (body.count("query") == 1) {
+ query = body.at("query");
+ if (!query.is_string()) {
+ res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ } else {
+ res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+ if (documents.empty()) {
+ res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ // construct prompt object: array of ["query", "doc0", "doc1", ...]
+ json prompt;
+ prompt.push_back(query);
+ for (const auto & doc : documents) {
+ prompt.push_back(doc);
+ }
+
+ LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+ // create and queue the task
+ json responses = json::array();
+ bool error = false;
+ {
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(tasks);
+
+ // get the result
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+ ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+ for (const auto & res : results) {
+ responses.push_back(res.data);
+ }
+ }, [&](const json & error_data) {
+ res_error(res, error_data);
+ error = true;
+ });
+ }
+
+ if (error) {
+ return;
+ }
+
+ // write JSON response
+ json root = format_response_rerank(body, responses);
+ res_ok(res, root);
+ };
+
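The new handle_rerank endpoint accepts a query plus a list of documents and returns relevance scores. A hedged TypeScript sketch follows (not part of the diff; the route name comes from the registrations later in this file, while the response field names "results", "index" and "relevance_score" follow the usual Jina/Cohere-style rerank format produced by format_response_rerank and are assumptions here):

```ts
// Minimal sketch: call the new /v1/rerank endpoint.
// The server must be started with --reranking and a reranker model.
async function rerank(query: string, documents: string[]) {
  const res = await fetch("http://localhost:8080/v1/rerank", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({ query, documents }),
  });
  if (!res.ok) throw new Error(`rerank failed: HTTP ${res.status}`);
  const body = await res.json();
  // Sort results by descending relevance score; each entry carries the
  // index of the corresponding document in the request.
  return [...body.results].sort(
    (a: any, b: any) => b.relevance_score - a.relevance_score,
  );
}
```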
+ const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
+ json result = json::array();
+ for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
+ auto & lora = ctx_server.loras[i];
+ result.push_back({
+ {"id", i},
+ {"path", lora.path},
+ {"scale", lora.scale},
+ });
+ }
+ res_ok(res, result);
+ res.status = 200; // HTTP OK
+ };
+
+ const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
+ const std::vector<json> body = json::parse(req.body);
+ int max_idx = ctx_server.loras.size();
+
+ // clear existing value
+ for (auto & lora : ctx_server.loras) {
+ lora.scale = 0.0f;
+ }
+
+ // set value
+ for (auto entry : body) {
+ int id = entry.at("id");
+ float scale = entry.at("scale");
+ if (0 <= id && id < max_idx) {
+ ctx_server.loras[id].scale = scale;
+ } else {
+ throw std::runtime_error("invalid adapter id");
+ }
+ }
+
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SET_LORA;
+ const int id_task = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+ res_ok(res, result.data);
+ res.status = 200; // HTTP OK
  };

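These two handlers back the /lora-adapters routes registered further down: GET lists the adapters that were loaded at startup, POST hot-swaps their scales (adapters not listed in the request are reset to scale 0 by the "clear existing value" loop above). A hedged client sketch, not part of the diff:

```ts
// Minimal sketch for the LoRA hotswap endpoints added in this diff.
// GET  /lora-adapters -> [{ id, path, scale }, ...]
// POST /lora-adapters -> body is an array of { id, scale }
async function listAdapters() {
  const res = await fetch("http://localhost:8080/lora-adapters");
  return (await res.json()) as { id: number; path: string; scale: number }[];
}

async function applyAdapters(adapters: { id: number; scale: number }[]) {
  const res = await fetch("http://localhost:8080/lora-adapters", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(adapters),
  });
  if (!res.ok) throw new Error(`apply failed: HTTP ${res.status}`);
  return res.json();
}

// Example: enable adapter 0 at full strength, leaving any others disabled.
// await applyAdapters([{ id: 0, scale: 1.0 }]);
```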
  //
@@ -3339,34 +3127,29 @@ int main(int argc, char ** argv) {
  // register static assets routes
  if (!params.public_path.empty()) {
  // Set the base directory for serving static files
- svr->
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- svr->Get("/theme-playground.css", handle_static_file(theme_playground_css, theme_playground_css_len, "text/css; charset=utf-8"));
- svr->Get("/theme-polarnight.css", handle_static_file(theme_polarnight_css, theme_polarnight_css_len, "text/css; charset=utf-8"));
- svr->Get("/theme-snowstorm.css", handle_static_file(theme_snowstorm_css, theme_snowstorm_css_len, "text/css; charset=utf-8"));
- svr->Get("/index-new.html", handle_static_file(index_new_html, index_new_html_len, "text/html; charset=utf-8"));
- svr->Get("/system-prompts.js", handle_static_file(system_prompts_js, system_prompts_js_len, "text/javascript; charset=utf-8"));
- svr->Get("/prompt-formats.js", handle_static_file(prompt_formats_js, prompt_formats_js_len, "text/javascript; charset=utf-8"));
+ bool is_found = svr->set_mount_point("/", params.public_path);
+ if (!is_found) {
+ LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
+ return 1;
+ }
+ } else {
+ // using embedded static files
+ for (const auto & it : static_files) {
+ const server_static_file & static_file = it.second;
+ svr->Get(it.first.c_str(), [&static_file](const httplib::Request &, httplib::Response & res) {
+ res.set_content(reinterpret_cast<const char*>(static_file.data), static_file.size, static_file.mime_type);
+ return false;
+ });
+ }
+ }

  // register API routes
- svr->Get ("/health", handle_health);
- svr->Get ("/slots", handle_slots);
+ svr->Get ("/health", handle_health); // public endpoint (no API key check)
  svr->Get ("/metrics", handle_metrics);
  svr->Get ("/props", handle_props);
- svr->
+ svr->Post("/props", handle_props_change);
+ svr->Get ("/models", handle_models); // public endpoint (no API key check)
+ svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
  svr->Post("/completion", handle_completions); // legacy
  svr->Post("/completions", handle_completions);
  svr->Post("/v1/completions", handle_completions);
@@ -3376,12 +3159,18 @@ int main(int argc, char ** argv) {
  svr->Post("/embedding", handle_embeddings); // legacy
  svr->Post("/embeddings", handle_embeddings);
  svr->Post("/v1/embeddings", handle_embeddings);
+ svr->Post("/rerank", handle_rerank);
+ svr->Post("/reranking", handle_rerank);
+ svr->Post("/v1/rerank", handle_rerank);
+ svr->Post("/v1/reranking", handle_rerank);
  svr->Post("/tokenize", handle_tokenize);
  svr->Post("/detokenize", handle_detokenize);
-
-
-
-
+ // LoRA adapters hotswap
+ svr->Get ("/lora-adapters", handle_lora_adapters_list);
+ svr->Post("/lora-adapters", handle_lora_adapters_apply);
+ // Save & load slots
+ svr->Get ("/slots", handle_slots);
+ svr->Post("/slots/:id_slot", handle_slots_action);

  //
  // Start the server
@@ -3393,36 +3182,67 @@ int main(int argc, char ** argv) {
  log_data["n_threads_http"] = std::to_string(params.n_threads_http);
  svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };

-
+ // clean up function, to be called before exit
+ auto clean_up = [&svr]() {
+ svr->stop();
+ llama_backend_free();
+ };
+
+ // bind HTTP listen port, run the HTTP server in a thread
+ if (!svr->bind_to_port(params.hostname, params.port)) {
+ //LOG_ERROR("couldn't bind HTTP server socket", {
+ // {"hostname", params.hostname},
+ // {"port", params.port},
+ //});
+ LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
+ clean_up();
+ return 1;
+ }
+ std::thread t([&]() { svr->listen_after_bind(); });
+ svr->wait_until_ready();
+
+ LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http);

- //
-
-
-
-
+ // load the model
+ LOG_INF("%s: loading model\n", __func__);
+
+ if (!ctx_server.load_model(params)) {
+ clean_up();
+ t.join();
+ LOG_ERR("%s: exiting due to model loading error\n", __func__);
+ return 1;
+ }
+
+ ctx_server.init();
+ state.store(SERVER_STATE_READY);
+
+ LOG_INF("%s: model loaded\n", __func__);
+
+ // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
+ if (params.chat_template.empty()) {
+ if (!ctx_server.validate_model_chat_template()) {
+ LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
+ params.chat_template = "chatml";
  }
+ }

-
-
+ // print sample chat example to make it clear which template is used
+ LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());

  ctx_server.queue_tasks.on_new_task(std::bind(
-
-
- &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
+ &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
  ctx_server.queue_tasks.on_update_slots(std::bind(
-
- ctx_server.queue_results.on_multitask_update(std::bind(
- &server_queue::update_multitask,
- &ctx_server.queue_tasks,
- std::placeholders::_1,
- std::placeholders::_2,
- std::placeholders::_3
- ));
+ &server_context::update_slots, &ctx_server));

  shutdown_handler = [&](int) {
  ctx_server.queue_tasks.terminate();
  };

+ LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
+
+ ctx_server.queue_tasks.start_loop();
+
  #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
  struct sigaction sigint_action;
  sigint_action.sa_handler = signal_handler;
@@ -3437,12 +3257,8 @@ int main(int argc, char ** argv) {
  SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
  #endif

-
-
- svr->stop();
+ clean_up();
  t.join();

- llama_backend_free();
-
  return 0;
  }