@fugood/llama.node 0.3.3 → 0.3.4
This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- package/CMakeLists.txt +5 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +15 -5
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +1 -1
- package/src/LlamaContext.cpp +81 -18
- package/src/LlamaContext.h +2 -0
- package/src/llama.cpp/.github/workflows/build.yml +197 -159
- package/src/llama.cpp/.github/workflows/docker.yml +5 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +11 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -2
- package/src/llama.cpp/common/arg.cpp +426 -245
- package/src/llama.cpp/common/common.cpp +143 -80
- package/src/llama.cpp/common/common.h +81 -24
- package/src/llama.cpp/common/sampling.cpp +53 -19
- package/src/llama.cpp/common/sampling.h +22 -1
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +101 -148
- package/src/llama.cpp/examples/CMakeLists.txt +32 -13
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +5 -4
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +262 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/llava.cpp +46 -19
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
- package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +9 -5
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
- package/src/llama.cpp/examples/server/server.cpp +1758 -886
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +94 -304
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +4 -0
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
- package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +106 -24
- package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
- package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
- package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
- package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
- package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
- package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
- package/src/llama.cpp/ggml/src/ggml.c +367 -207
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +26 -19
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/src/CMakeLists.txt +2 -7
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +35 -90
- package/src/llama.cpp/src/llama-vocab.cpp +6 -1
- package/src/llama.cpp/src/llama.cpp +1748 -640
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -37
- package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
- package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
- package/src/llama.cpp/tests/test-rope.cpp +61 -20
- package/src/llama.cpp/tests/test-sampling.cpp +2 -2
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
|
@@ -2,10 +2,11 @@
|
|
|
2
2
|
|
|
3
3
|
#include "arg.h"
|
|
4
4
|
#include "common.h"
|
|
5
|
-
#include "log.h"
|
|
6
|
-
#include "sampling.h"
|
|
7
5
|
#include "json-schema-to-grammar.h"
|
|
8
6
|
#include "llama.h"
|
|
7
|
+
#include "log.h"
|
|
8
|
+
#include "sampling.h"
|
|
9
|
+
#include "speculative.h"
|
|
9
10
|
|
|
10
11
|
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
|
11
12
|
#define JSON_ASSERT GGML_ASSERT
|
|
@@ -14,13 +15,8 @@
|
|
|
14
15
|
#define MIMETYPE_JSON "application/json; charset=utf-8"
|
|
15
16
|
|
|
16
17
|
// auto generated files (update with ./deps.sh)
|
|
17
|
-
#include "index.html.hpp"
|
|
18
|
-
#include "completion.js.hpp"
|
|
18
|
+
#include "index.html.gz.hpp"
|
|
19
19
|
#include "loading.html.hpp"
|
|
20
|
-
#include "deps_daisyui.min.css.hpp"
|
|
21
|
-
#include "deps_markdown-it.js.hpp"
|
|
22
|
-
#include "deps_tailwindcss.js.hpp"
|
|
23
|
-
#include "deps_vue.esm-browser.js.hpp"
|
|
24
20
|
|
|
25
21
|
#include <atomic>
|
|
26
22
|
#include <condition_variable>
|
|
@@ -37,8 +33,10 @@
|
|
|
37
33
|
using json = nlohmann::ordered_json;
|
|
38
34
|
|
|
39
35
|
enum stop_type {
|
|
40
|
-
|
|
41
|
-
|
|
36
|
+
STOP_TYPE_NONE,
|
|
37
|
+
STOP_TYPE_EOS,
|
|
38
|
+
STOP_TYPE_WORD,
|
|
39
|
+
STOP_TYPE_LIMIT,
|
|
42
40
|
};
|
|
43
41
|
|
|
44
42
|
// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
|
|
@@ -56,7 +54,10 @@ enum server_state {
|
|
|
56
54
|
};
|
|
57
55
|
|
|
58
56
|
enum server_task_type {
|
|
59
|
-
|
|
57
|
+
SERVER_TASK_TYPE_COMPLETION,
|
|
58
|
+
SERVER_TASK_TYPE_EMBEDDING,
|
|
59
|
+
SERVER_TASK_TYPE_RERANK,
|
|
60
|
+
SERVER_TASK_TYPE_INFILL,
|
|
60
61
|
SERVER_TASK_TYPE_CANCEL,
|
|
61
62
|
SERVER_TASK_TYPE_NEXT_RESPONSE,
|
|
62
63
|
SERVER_TASK_TYPE_METRICS,
|
|
@@ -66,22 +67,309 @@ enum server_task_type {
|
|
|
66
67
|
SERVER_TASK_TYPE_SET_LORA,
|
|
67
68
|
};
|
|
68
69
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
70
|
+
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
|
71
|
+
enum error_type {
|
|
72
|
+
ERROR_TYPE_INVALID_REQUEST,
|
|
73
|
+
ERROR_TYPE_AUTHENTICATION,
|
|
74
|
+
ERROR_TYPE_SERVER,
|
|
75
|
+
ERROR_TYPE_NOT_FOUND,
|
|
76
|
+
ERROR_TYPE_PERMISSION,
|
|
77
|
+
ERROR_TYPE_UNAVAILABLE, // custom error
|
|
78
|
+
ERROR_TYPE_NOT_SUPPORTED, // custom error
|
|
79
|
+
};
|
|
80
|
+
|
|
81
|
+
struct slot_params {
|
|
82
|
+
bool stream = true;
|
|
83
|
+
bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
|
|
84
|
+
bool return_tokens = false;
|
|
85
|
+
|
|
86
|
+
int32_t n_keep = 0; // number of tokens to keep from initial prompt
|
|
87
|
+
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
|
|
88
|
+
int32_t n_predict = -1; // new tokens to predict
|
|
89
|
+
int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
|
|
90
|
+
|
|
91
|
+
int64_t t_max_prompt_ms = -1; // TODO: implement
|
|
92
|
+
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
|
93
|
+
|
|
94
|
+
std::vector<std::string> antiprompt;
|
|
95
|
+
bool timings_per_token = false;
|
|
96
|
+
bool post_sampling_probs = false;
|
|
97
|
+
bool ignore_eos = false;
|
|
98
|
+
|
|
99
|
+
struct common_params_sampling sampling;
|
|
100
|
+
struct common_params_speculative speculative;
|
|
101
|
+
|
|
102
|
+
// OAI-compat fields
|
|
103
|
+
bool verbose = false;
|
|
104
|
+
bool oaicompat = false;
|
|
105
|
+
bool oaicompat_chat = true;
|
|
106
|
+
std::string oaicompat_model;
|
|
107
|
+
std::string oaicompat_cmpl_id;
|
|
108
|
+
|
|
109
|
+
json to_json() const {
|
|
110
|
+
std::vector<std::string> samplers;
|
|
111
|
+
samplers.reserve(sampling.samplers.size());
|
|
112
|
+
for (const auto & sampler : sampling.samplers) {
|
|
113
|
+
samplers.emplace_back(common_sampler_type_to_str(sampler));
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
return json {
|
|
117
|
+
{"n_predict", n_predict}, // Server configured n_predict
|
|
118
|
+
{"seed", sampling.seed},
|
|
119
|
+
{"temperature", sampling.temp},
|
|
120
|
+
{"dynatemp_range", sampling.dynatemp_range},
|
|
121
|
+
{"dynatemp_exponent", sampling.dynatemp_exponent},
|
|
122
|
+
{"top_k", sampling.top_k},
|
|
123
|
+
{"top_p", sampling.top_p},
|
|
124
|
+
{"min_p", sampling.min_p},
|
|
125
|
+
{"xtc_probability", sampling.xtc_probability},
|
|
126
|
+
{"xtc_threshold", sampling.xtc_threshold},
|
|
127
|
+
{"typical_p", sampling.typ_p},
|
|
128
|
+
{"repeat_last_n", sampling.penalty_last_n},
|
|
129
|
+
{"repeat_penalty", sampling.penalty_repeat},
|
|
130
|
+
{"presence_penalty", sampling.penalty_present},
|
|
131
|
+
{"frequency_penalty", sampling.penalty_freq},
|
|
132
|
+
{"dry_multiplier", sampling.dry_multiplier},
|
|
133
|
+
{"dry_base", sampling.dry_base},
|
|
134
|
+
{"dry_allowed_length", sampling.dry_allowed_length},
|
|
135
|
+
{"dry_penalty_last_n", sampling.dry_penalty_last_n},
|
|
136
|
+
{"dry_sequence_breakers", sampling.dry_sequence_breakers},
|
|
137
|
+
{"mirostat", sampling.mirostat},
|
|
138
|
+
{"mirostat_tau", sampling.mirostat_tau},
|
|
139
|
+
{"mirostat_eta", sampling.mirostat_eta},
|
|
140
|
+
{"stop", antiprompt},
|
|
141
|
+
{"max_tokens", n_predict}, // User configured n_predict
|
|
142
|
+
{"n_keep", n_keep},
|
|
143
|
+
{"n_discard", n_discard},
|
|
144
|
+
{"ignore_eos", sampling.ignore_eos},
|
|
145
|
+
{"stream", stream},
|
|
146
|
+
{"logit_bias", format_logit_bias(sampling.logit_bias)},
|
|
147
|
+
{"n_probs", sampling.n_probs},
|
|
148
|
+
{"min_keep", sampling.min_keep},
|
|
149
|
+
{"grammar", sampling.grammar},
|
|
150
|
+
{"samplers", samplers},
|
|
151
|
+
{"speculative.n_max", speculative.n_max},
|
|
152
|
+
{"speculative.n_min", speculative.n_min},
|
|
153
|
+
{"speculative.p_min", speculative.p_min},
|
|
154
|
+
{"timings_per_token", timings_per_token},
|
|
155
|
+
{"post_sampling_probs", post_sampling_probs},
|
|
156
|
+
};
|
|
157
|
+
}
|
|
74
158
|
};
|
|
75
159
|
|
|
76
160
|
struct server_task {
|
|
77
|
-
int id
|
|
78
|
-
int
|
|
161
|
+
int id = -1; // to be filled by server_queue
|
|
162
|
+
int index = -1; // used when there are multiple prompts (batch request)
|
|
79
163
|
|
|
80
|
-
llama_tokens prompt_tokens;
|
|
81
164
|
server_task_type type;
|
|
82
|
-
json data;
|
|
83
165
|
|
|
84
|
-
|
|
166
|
+
// used by SERVER_TASK_TYPE_CANCEL
|
|
167
|
+
int id_target = -1;
|
|
168
|
+
|
|
169
|
+
// used by SERVER_TASK_TYPE_INFERENCE
|
|
170
|
+
slot_params params;
|
|
171
|
+
llama_tokens prompt_tokens;
|
|
172
|
+
int id_selected_slot = -1;
|
|
173
|
+
|
|
174
|
+
// used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
|
|
175
|
+
struct slot_action {
|
|
176
|
+
int slot_id;
|
|
177
|
+
std::string filename;
|
|
178
|
+
std::string filepath;
|
|
179
|
+
};
|
|
180
|
+
slot_action slot_action;
|
|
181
|
+
|
|
182
|
+
// used by SERVER_TASK_TYPE_METRICS
|
|
183
|
+
bool metrics_reset_bucket = false;
|
|
184
|
+
|
|
185
|
+
server_task(server_task_type type) : type(type) {}
|
|
186
|
+
|
|
187
|
+
static slot_params params_from_json_cmpl(
|
|
188
|
+
const llama_model * model,
|
|
189
|
+
const llama_context * ctx,
|
|
190
|
+
const common_params & params_base,
|
|
191
|
+
const json & data) {
|
|
192
|
+
slot_params params;
|
|
193
|
+
|
|
194
|
+
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
|
|
195
|
+
slot_params defaults;
|
|
196
|
+
defaults.sampling = params_base.sampling;
|
|
197
|
+
defaults.speculative = params_base.speculative;
|
|
198
|
+
|
|
199
|
+
// enabling this will output extra debug information in the HTTP responses from the server
|
|
200
|
+
params.verbose = params_base.verbosity > 9;
|
|
201
|
+
params.timings_per_token = json_value(data, "timings_per_token", false);
|
|
202
|
+
|
|
203
|
+
params.stream = json_value(data, "stream", false);
|
|
204
|
+
params.cache_prompt = json_value(data, "cache_prompt", true);
|
|
205
|
+
params.return_tokens = json_value(data, "return_tokens", false);
|
|
206
|
+
params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
|
|
207
|
+
params.n_indent = json_value(data, "n_indent", defaults.n_indent);
|
|
208
|
+
params.n_keep = json_value(data, "n_keep", defaults.n_keep);
|
|
209
|
+
params.n_discard = json_value(data, "n_discard", defaults.n_discard);
|
|
210
|
+
//params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
|
|
211
|
+
params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
|
|
212
|
+
|
|
213
|
+
params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
|
|
214
|
+
params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
|
|
215
|
+
params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
|
|
216
|
+
params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
|
|
217
|
+
params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
|
|
218
|
+
params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
|
|
219
|
+
params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
|
|
220
|
+
params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
|
|
221
|
+
params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
|
|
222
|
+
params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
|
|
223
|
+
params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
|
|
224
|
+
params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
|
|
225
|
+
params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
|
|
226
|
+
params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
|
|
227
|
+
params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
|
|
228
|
+
params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
|
|
229
|
+
params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
|
|
230
|
+
params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
|
|
231
|
+
params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
|
|
232
|
+
params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
|
|
233
|
+
params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
|
|
234
|
+
params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
|
|
235
|
+
params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
|
|
236
|
+
params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
|
|
237
|
+
|
|
238
|
+
params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
|
|
239
|
+
params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
|
|
240
|
+
params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
|
|
241
|
+
|
|
242
|
+
params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
|
|
243
|
+
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
|
244
|
+
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
|
245
|
+
|
|
246
|
+
// TODO: add more sanity checks for the input parameters
|
|
247
|
+
|
|
248
|
+
if (params.sampling.penalty_last_n < -1) {
|
|
249
|
+
throw std::runtime_error("Error: repeat_last_n must be >= -1");
|
|
250
|
+
}
|
|
251
|
+
|
|
252
|
+
if (params.sampling.dry_penalty_last_n < -1) {
|
|
253
|
+
throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
|
|
254
|
+
}
|
|
255
|
+
|
|
256
|
+
if (params.sampling.penalty_last_n == -1) {
|
|
257
|
+
// note: should be the slot's context and not the full context, but it's ok
|
|
258
|
+
params.sampling.penalty_last_n = llama_n_ctx(ctx);
|
|
259
|
+
}
|
|
260
|
+
|
|
261
|
+
if (params.sampling.dry_penalty_last_n == -1) {
|
|
262
|
+
params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
|
|
263
|
+
}
|
|
264
|
+
|
|
265
|
+
if (params.sampling.dry_base < 1.0f) {
|
|
266
|
+
params.sampling.dry_base = defaults.sampling.dry_base;
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
// sequence breakers for DRY
|
|
270
|
+
{
|
|
271
|
+
// Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
|
|
272
|
+
// Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
|
|
273
|
+
|
|
274
|
+
if (data.contains("dry_sequence_breakers")) {
|
|
275
|
+
params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
|
|
276
|
+
if (params.sampling.dry_sequence_breakers.empty()) {
|
|
277
|
+
throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
|
|
278
|
+
}
|
|
279
|
+
}
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
// process "json_schema" and "grammar"
|
|
283
|
+
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
|
284
|
+
throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
|
|
285
|
+
}
|
|
286
|
+
if (data.contains("json_schema") && !data.contains("grammar")) {
|
|
287
|
+
try {
|
|
288
|
+
auto schema = json_value(data, "json_schema", json::object());
|
|
289
|
+
params.sampling.grammar = json_schema_to_grammar(schema);
|
|
290
|
+
} catch (const std::exception & e) {
|
|
291
|
+
throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
|
|
292
|
+
}
|
|
293
|
+
} else {
|
|
294
|
+
params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
{
|
|
298
|
+
params.sampling.logit_bias.clear();
|
|
299
|
+
params.ignore_eos = json_value(data, "ignore_eos", false);
|
|
300
|
+
|
|
301
|
+
const auto & logit_bias = data.find("logit_bias");
|
|
302
|
+
if (logit_bias != data.end() && logit_bias->is_array()) {
|
|
303
|
+
const int n_vocab = llama_n_vocab(model);
|
|
304
|
+
for (const auto & el : *logit_bias) {
|
|
305
|
+
// TODO: we may want to throw errors here, in case "el" is incorrect
|
|
306
|
+
if (el.is_array() && el.size() == 2) {
|
|
307
|
+
float bias;
|
|
308
|
+
if (el[1].is_number()) {
|
|
309
|
+
bias = el[1].get<float>();
|
|
310
|
+
} else if (el[1].is_boolean() && !el[1].get<bool>()) {
|
|
311
|
+
bias = -INFINITY;
|
|
312
|
+
} else {
|
|
313
|
+
continue;
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
if (el[0].is_number_integer()) {
|
|
317
|
+
llama_token tok = el[0].get<llama_token>();
|
|
318
|
+
if (tok >= 0 && tok < n_vocab) {
|
|
319
|
+
params.sampling.logit_bias.push_back({tok, bias});
|
|
320
|
+
}
|
|
321
|
+
} else if (el[0].is_string()) {
|
|
322
|
+
auto toks = common_tokenize(model, el[0].get<std::string>(), false);
|
|
323
|
+
for (auto tok : toks) {
|
|
324
|
+
params.sampling.logit_bias.push_back({tok, bias});
|
|
325
|
+
}
|
|
326
|
+
}
|
|
327
|
+
}
|
|
328
|
+
}
|
|
329
|
+
}
|
|
330
|
+
}
|
|
331
|
+
|
|
332
|
+
{
|
|
333
|
+
params.antiprompt.clear();
|
|
334
|
+
|
|
335
|
+
const auto & stop = data.find("stop");
|
|
336
|
+
if (stop != data.end() && stop->is_array()) {
|
|
337
|
+
for (const auto & word : *stop) {
|
|
338
|
+
if (!word.empty()) {
|
|
339
|
+
params.antiprompt.push_back(word);
|
|
340
|
+
}
|
|
341
|
+
}
|
|
342
|
+
}
|
|
343
|
+
}
|
|
344
|
+
|
|
345
|
+
{
|
|
346
|
+
const auto & samplers = data.find("samplers");
|
|
347
|
+
if (samplers != data.end()) {
|
|
348
|
+
if (samplers->is_array()) {
|
|
349
|
+
std::vector<std::string> sampler_names;
|
|
350
|
+
for (const auto & name : *samplers) {
|
|
351
|
+
if (name.is_string()) {
|
|
352
|
+
sampler_names.emplace_back(name);
|
|
353
|
+
}
|
|
354
|
+
}
|
|
355
|
+
params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
|
|
356
|
+
} else if (samplers->is_string()){
|
|
357
|
+
std::string sampler_string;
|
|
358
|
+
for (const auto & name : *samplers) {
|
|
359
|
+
sampler_string += name;
|
|
360
|
+
}
|
|
361
|
+
params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
|
|
362
|
+
}
|
|
363
|
+
} else {
|
|
364
|
+
params.sampling.samplers = defaults.sampling.samplers;
|
|
365
|
+
}
|
|
366
|
+
}
|
|
367
|
+
|
|
368
|
+
std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
|
|
369
|
+
params.oaicompat_model = json_value(data, "model", model_name);
|
|
370
|
+
|
|
371
|
+
return params;
|
|
372
|
+
}
|
|
85
373
|
|
|
86
374
|
// utility function
|
|
87
375
|
static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
|
|
@@ -93,40 +381,628 @@ struct server_task {
|
|
|
93
381
|
}
|
|
94
382
|
};
|
|
95
383
|
|
|
384
|
+
struct result_timings {
|
|
385
|
+
int32_t prompt_n = -1;
|
|
386
|
+
double prompt_ms;
|
|
387
|
+
double prompt_per_token_ms;
|
|
388
|
+
double prompt_per_second;
|
|
389
|
+
|
|
390
|
+
int32_t predicted_n = -1;
|
|
391
|
+
double predicted_ms;
|
|
392
|
+
double predicted_per_token_ms;
|
|
393
|
+
double predicted_per_second;
|
|
394
|
+
|
|
395
|
+
json to_json() const {
|
|
396
|
+
return {
|
|
397
|
+
{"prompt_n", prompt_n},
|
|
398
|
+
{"prompt_ms", prompt_ms},
|
|
399
|
+
{"prompt_per_token_ms", prompt_per_token_ms},
|
|
400
|
+
{"prompt_per_second", prompt_per_second},
|
|
401
|
+
|
|
402
|
+
{"predicted_n", predicted_n},
|
|
403
|
+
{"predicted_ms", predicted_ms},
|
|
404
|
+
{"predicted_per_token_ms", predicted_per_token_ms},
|
|
405
|
+
{"predicted_per_second", predicted_per_second},
|
|
406
|
+
};
|
|
407
|
+
}
|
|
408
|
+
};
|
|
409
|
+
|
|
96
410
|
struct server_task_result {
|
|
97
|
-
int id
|
|
411
|
+
int id = -1;
|
|
412
|
+
int id_slot = -1;
|
|
413
|
+
virtual bool is_error() {
|
|
414
|
+
// only used by server_task_result_error
|
|
415
|
+
return false;
|
|
416
|
+
}
|
|
417
|
+
virtual bool is_stop() {
|
|
418
|
+
// only used by server_task_result_cmpl_*
|
|
419
|
+
return false;
|
|
420
|
+
}
|
|
421
|
+
virtual int get_index() {
|
|
422
|
+
return -1;
|
|
423
|
+
}
|
|
424
|
+
virtual json to_json() = 0;
|
|
425
|
+
virtual ~server_task_result() = default;
|
|
426
|
+
};
|
|
98
427
|
|
|
99
|
-
|
|
428
|
+
// using shared_ptr for polymorphism of server_task_result
|
|
429
|
+
using server_task_result_ptr = std::unique_ptr<server_task_result>;
|
|
100
430
|
|
|
101
|
-
|
|
102
|
-
|
|
431
|
+
inline std::string stop_type_to_str(stop_type type) {
|
|
432
|
+
switch (type) {
|
|
433
|
+
case STOP_TYPE_EOS: return "eos";
|
|
434
|
+
case STOP_TYPE_WORD: return "word";
|
|
435
|
+
case STOP_TYPE_LIMIT: return "limit";
|
|
436
|
+
default: return "none";
|
|
437
|
+
}
|
|
438
|
+
}
|
|
439
|
+
|
|
440
|
+
struct completion_token_output {
|
|
441
|
+
llama_token tok;
|
|
442
|
+
float prob;
|
|
443
|
+
std::string text_to_send;
|
|
444
|
+
struct prob_info {
|
|
445
|
+
llama_token tok;
|
|
446
|
+
std::string txt;
|
|
447
|
+
float prob;
|
|
448
|
+
};
|
|
449
|
+
std::vector<prob_info> probs;
|
|
450
|
+
|
|
451
|
+
json to_json(bool post_sampling_probs) const {
|
|
452
|
+
json probs_for_token = json::array();
|
|
453
|
+
for (const auto & p : probs) {
|
|
454
|
+
std::string txt(p.txt);
|
|
455
|
+
txt.resize(validate_utf8(txt));
|
|
456
|
+
probs_for_token.push_back(json {
|
|
457
|
+
{"id", p.tok},
|
|
458
|
+
{"token", txt},
|
|
459
|
+
{"bytes", str_to_bytes(p.txt)},
|
|
460
|
+
{
|
|
461
|
+
post_sampling_probs ? "prob" : "logprob",
|
|
462
|
+
post_sampling_probs ? p.prob : logarithm(p.prob)
|
|
463
|
+
},
|
|
464
|
+
});
|
|
465
|
+
}
|
|
466
|
+
return probs_for_token;
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
|
|
470
|
+
json out = json::array();
|
|
471
|
+
for (const auto & p : probs) {
|
|
472
|
+
std::string txt(p.text_to_send);
|
|
473
|
+
txt.resize(validate_utf8(txt));
|
|
474
|
+
out.push_back(json {
|
|
475
|
+
{"id", p.tok},
|
|
476
|
+
{"token", txt},
|
|
477
|
+
{"bytes", str_to_bytes(p.text_to_send)},
|
|
478
|
+
{
|
|
479
|
+
post_sampling_probs ? "prob" : "logprob",
|
|
480
|
+
post_sampling_probs ? p.prob : logarithm(p.prob)
|
|
481
|
+
},
|
|
482
|
+
{
|
|
483
|
+
post_sampling_probs ? "top_probs" : "top_logprobs",
|
|
484
|
+
p.to_json(post_sampling_probs)
|
|
485
|
+
},
|
|
486
|
+
});
|
|
487
|
+
}
|
|
488
|
+
return out;
|
|
489
|
+
}
|
|
490
|
+
|
|
491
|
+
static float logarithm(float x) {
|
|
492
|
+
// nlohmann::json converts -inf to null, so we need to prevent that
|
|
493
|
+
return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
|
|
494
|
+
}
|
|
495
|
+
|
|
496
|
+
static std::vector<unsigned char> str_to_bytes(const std::string & str) {
|
|
497
|
+
std::vector<unsigned char> bytes;
|
|
498
|
+
for (unsigned char c : str) {
|
|
499
|
+
bytes.push_back(c);
|
|
500
|
+
}
|
|
501
|
+
return bytes;
|
|
502
|
+
}
|
|
103
503
|
};
|
|
104
504
|
|
|
105
|
-
struct
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
505
|
+
struct server_task_result_cmpl_final : server_task_result {
|
|
506
|
+
int index = 0;
|
|
507
|
+
|
|
508
|
+
std::string content;
|
|
509
|
+
llama_tokens tokens;
|
|
510
|
+
|
|
511
|
+
bool stream;
|
|
512
|
+
result_timings timings;
|
|
513
|
+
std::string prompt;
|
|
514
|
+
|
|
515
|
+
bool truncated;
|
|
516
|
+
int32_t n_decoded;
|
|
517
|
+
int32_t n_prompt_tokens;
|
|
518
|
+
int32_t n_tokens_cached;
|
|
519
|
+
bool has_new_line;
|
|
520
|
+
std::string stopping_word;
|
|
521
|
+
stop_type stop = STOP_TYPE_NONE;
|
|
522
|
+
|
|
523
|
+
bool post_sampling_probs;
|
|
524
|
+
std::vector<completion_token_output> probs_output;
|
|
525
|
+
|
|
526
|
+
slot_params generation_params;
|
|
527
|
+
|
|
528
|
+
// OAI-compat fields
|
|
529
|
+
bool verbose = false;
|
|
530
|
+
bool oaicompat = false;
|
|
531
|
+
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
|
|
532
|
+
std::string oaicompat_model;
|
|
533
|
+
std::string oaicompat_cmpl_id;
|
|
534
|
+
|
|
535
|
+
virtual int get_index() override {
|
|
536
|
+
return index;
|
|
537
|
+
}
|
|
538
|
+
|
|
539
|
+
virtual bool is_stop() override {
|
|
540
|
+
return true; // in stream mode, final responses are considered stop
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
virtual json to_json() override {
|
|
544
|
+
return oaicompat
|
|
545
|
+
? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
|
|
546
|
+
: to_json_non_oaicompat();
|
|
547
|
+
}
|
|
548
|
+
|
|
549
|
+
json to_json_non_oaicompat() {
|
|
550
|
+
json res = json {
|
|
551
|
+
{"index", index},
|
|
552
|
+
{"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
|
553
|
+
{"tokens", stream ? llama_tokens {} : tokens},
|
|
554
|
+
{"id_slot", id_slot},
|
|
555
|
+
{"stop", true},
|
|
556
|
+
{"model", oaicompat_model},
|
|
557
|
+
{"tokens_predicted", n_decoded},
|
|
558
|
+
{"tokens_evaluated", n_prompt_tokens},
|
|
559
|
+
{"generation_settings", generation_params.to_json()},
|
|
560
|
+
{"prompt", prompt},
|
|
561
|
+
{"has_new_line", has_new_line},
|
|
562
|
+
{"truncated", truncated},
|
|
563
|
+
{"stop_type", stop_type_to_str(stop)},
|
|
564
|
+
{"stopping_word", stopping_word},
|
|
565
|
+
{"tokens_cached", n_tokens_cached},
|
|
566
|
+
{"timings", timings.to_json()},
|
|
567
|
+
};
|
|
568
|
+
if (!stream && !probs_output.empty()) {
|
|
569
|
+
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
|
|
570
|
+
}
|
|
571
|
+
return res;
|
|
572
|
+
}
|
|
573
|
+
|
|
574
|
+
json to_json_oaicompat_chat() {
|
|
575
|
+
std::string finish_reason = "length";
|
|
576
|
+
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
|
577
|
+
finish_reason = "stop";
|
|
578
|
+
}
|
|
579
|
+
|
|
580
|
+
json choice = json{
|
|
581
|
+
{"finish_reason", finish_reason},
|
|
582
|
+
{"index", 0},
|
|
583
|
+
{"message", json {
|
|
584
|
+
{"content", content},
|
|
585
|
+
{"role", "assistant"}
|
|
586
|
+
}
|
|
587
|
+
}};
|
|
588
|
+
|
|
589
|
+
if (!stream && probs_output.size() > 0) {
|
|
590
|
+
choice["logprobs"] = json{
|
|
591
|
+
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
|
592
|
+
};
|
|
593
|
+
}
|
|
594
|
+
|
|
595
|
+
std::time_t t = std::time(0);
|
|
596
|
+
|
|
597
|
+
json res = json {
|
|
598
|
+
{"choices", json::array({choice})},
|
|
599
|
+
{"created", t},
|
|
600
|
+
{"model", oaicompat_model},
|
|
601
|
+
{"object", "chat.completion"},
|
|
602
|
+
{"usage", json {
|
|
603
|
+
{"completion_tokens", n_decoded},
|
|
604
|
+
{"prompt_tokens", n_prompt_tokens},
|
|
605
|
+
{"total_tokens", n_decoded + n_prompt_tokens}
|
|
606
|
+
}},
|
|
607
|
+
{"id", oaicompat_cmpl_id}
|
|
608
|
+
};
|
|
609
|
+
|
|
610
|
+
// extra fields for debugging purposes
|
|
611
|
+
if (verbose) {
|
|
612
|
+
res["__verbose"] = to_json_non_oaicompat();
|
|
613
|
+
}
|
|
614
|
+
if (timings.prompt_n >= 0) {
|
|
615
|
+
res.push_back({"timings", timings.to_json()});
|
|
616
|
+
}
|
|
617
|
+
|
|
618
|
+
return res;
|
|
619
|
+
}
|
|
620
|
+
|
|
621
|
+
json to_json_oaicompat_chat_stream() {
|
|
622
|
+
std::time_t t = std::time(0);
|
|
623
|
+
std::string finish_reason = "length";
|
|
624
|
+
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
|
625
|
+
finish_reason = "stop";
|
|
626
|
+
}
|
|
627
|
+
|
|
628
|
+
json choice = json{
|
|
629
|
+
{"finish_reason", finish_reason},
|
|
630
|
+
{"index", 0},
|
|
631
|
+
{"delta", json::object()}
|
|
632
|
+
};
|
|
633
|
+
|
|
634
|
+
json ret = json {
|
|
635
|
+
{"choices", json::array({choice})},
|
|
636
|
+
{"created", t},
|
|
637
|
+
{"id", oaicompat_cmpl_id},
|
|
638
|
+
{"model", oaicompat_model},
|
|
639
|
+
{"object", "chat.completion.chunk"},
|
|
640
|
+
{"usage", json {
|
|
641
|
+
{"completion_tokens", n_decoded},
|
|
642
|
+
{"prompt_tokens", n_prompt_tokens},
|
|
643
|
+
{"total_tokens", n_decoded + n_prompt_tokens},
|
|
644
|
+
}},
|
|
645
|
+
};
|
|
646
|
+
|
|
647
|
+
if (timings.prompt_n >= 0) {
|
|
648
|
+
ret.push_back({"timings", timings.to_json()});
|
|
649
|
+
}
|
|
650
|
+
|
|
651
|
+
return ret;
|
|
652
|
+
}
|
|
109
653
|
};
|
|
110
654
|
|
|
111
|
-
struct
|
|
112
|
-
|
|
113
|
-
bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
|
|
655
|
+
struct server_task_result_cmpl_partial : server_task_result {
|
|
656
|
+
int index = 0;
|
|
114
657
|
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
int32_t n_predict = -1; // new tokens to predict
|
|
118
|
-
int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
|
|
658
|
+
std::string content;
|
|
659
|
+
llama_tokens tokens;
|
|
119
660
|
|
|
120
|
-
|
|
121
|
-
|
|
661
|
+
int32_t n_decoded;
|
|
662
|
+
int32_t n_prompt_tokens;
|
|
122
663
|
|
|
123
|
-
|
|
664
|
+
bool post_sampling_probs;
|
|
665
|
+
completion_token_output prob_output;
|
|
666
|
+
result_timings timings;
|
|
667
|
+
|
|
668
|
+
// OAI-compat fields
|
|
669
|
+
bool verbose = false;
|
|
670
|
+
bool oaicompat = false;
|
|
671
|
+
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
|
|
672
|
+
std::string oaicompat_model;
|
|
673
|
+
std::string oaicompat_cmpl_id;
|
|
674
|
+
|
|
675
|
+
virtual int get_index() override {
|
|
676
|
+
return index;
|
|
677
|
+
}
|
|
678
|
+
|
|
679
|
+
virtual bool is_stop() override {
|
|
680
|
+
return false; // in stream mode, partial responses are not considered stop
|
|
681
|
+
}
|
|
682
|
+
|
|
683
|
+
virtual json to_json() override {
|
|
684
|
+
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
|
685
|
+
}
|
|
686
|
+
|
|
687
|
+
json to_json_non_oaicompat() {
|
|
688
|
+
// non-OAI-compat JSON
|
|
689
|
+
json res = json {
|
|
690
|
+
{"index", index},
|
|
691
|
+
{"content", content},
|
|
692
|
+
{"tokens", tokens},
|
|
693
|
+
{"stop", false},
|
|
694
|
+
{"id_slot", id_slot},
|
|
695
|
+
{"tokens_predicted", n_decoded},
|
|
696
|
+
{"tokens_evaluated", n_prompt_tokens},
|
|
697
|
+
};
|
|
698
|
+
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
|
|
699
|
+
if (timings.prompt_n > 0) {
|
|
700
|
+
res.push_back({"timings", timings.to_json()});
|
|
701
|
+
}
|
|
702
|
+
if (!prob_output.probs.empty()) {
|
|
703
|
+
res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
|
|
704
|
+
}
|
|
705
|
+
return res;
|
|
706
|
+
}
|
|
707
|
+
|
|
708
|
+
json to_json_oaicompat() {
|
|
709
|
+
bool first = n_decoded == 0;
|
|
710
|
+
std::time_t t = std::time(0);
|
|
711
|
+
json choices;
|
|
712
|
+
|
|
713
|
+
if (first) {
|
|
714
|
+
if (content.empty()) {
|
|
715
|
+
choices = json::array({json{{"finish_reason", nullptr},
|
|
716
|
+
{"index", 0},
|
|
717
|
+
{"delta", json{{"role", "assistant"}}}}});
|
|
718
|
+
} else {
|
|
719
|
+
// We have to send this as two updates to conform to openai behavior
|
|
720
|
+
json initial_ret = json{{"choices", json::array({json{
|
|
721
|
+
{"finish_reason", nullptr},
|
|
722
|
+
{"index", 0},
|
|
723
|
+
{"delta", json{
|
|
724
|
+
{"role", "assistant"}
|
|
725
|
+
}}}})},
|
|
726
|
+
{"created", t},
|
|
727
|
+
{"id", oaicompat_cmpl_id},
|
|
728
|
+
{"model", oaicompat_model},
|
|
729
|
+
{"object", "chat.completion.chunk"}};
|
|
730
|
+
|
|
731
|
+
json second_ret = json{
|
|
732
|
+
{"choices", json::array({json{{"finish_reason", nullptr},
|
|
733
|
+
{"index", 0},
|
|
734
|
+
{"delta", json {
|
|
735
|
+
{"content", content}}}
|
|
736
|
+
}})},
|
|
737
|
+
{"created", t},
|
|
738
|
+
{"id", oaicompat_cmpl_id},
|
|
739
|
+
{"model", oaicompat_model},
|
|
740
|
+
{"object", "chat.completion.chunk"}};
|
|
741
|
+
|
|
742
|
+
return std::vector<json>({initial_ret, second_ret});
|
|
743
|
+
}
|
|
744
|
+
} else {
|
|
745
|
+
choices = json::array({json{
|
|
746
|
+
{"finish_reason", nullptr},
|
|
747
|
+
{"index", 0},
|
|
748
|
+
{"delta",
|
|
749
|
+
json {
|
|
750
|
+
{"content", content},
|
|
751
|
+
}},
|
|
752
|
+
}});
|
|
753
|
+
}
|
|
754
|
+
|
|
755
|
+
GGML_ASSERT(choices.size() >= 1);
|
|
756
|
+
|
|
757
|
+
if (prob_output.probs.size() > 0) {
|
|
758
|
+
choices[0]["logprobs"] = json{
|
|
759
|
+
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
|
760
|
+
};
|
|
761
|
+
}
|
|
762
|
+
|
|
763
|
+
json ret = json {
|
|
764
|
+
{"choices", choices},
|
|
765
|
+
{"created", t},
|
|
766
|
+
{"id", oaicompat_cmpl_id},
|
|
767
|
+
{"model", oaicompat_model},
|
|
768
|
+
{"object", "chat.completion.chunk"}
|
|
769
|
+
};
|
|
770
|
+
|
|
771
|
+
if (timings.prompt_n >= 0) {
|
|
772
|
+
ret.push_back({"timings", timings.to_json()});
|
|
773
|
+
}
|
|
774
|
+
|
|
775
|
+
return std::vector<json>({ret});
|
|
776
|
+
}
|
|
777
|
+
};
|
|
778
|
+
|
|
779
|
+
struct server_task_result_embd : server_task_result {
|
|
780
|
+
int index = 0;
|
|
781
|
+
std::vector<std::vector<float>> embedding;
|
|
782
|
+
|
|
783
|
+
int32_t n_tokens;
|
|
784
|
+
|
|
785
|
+
// OAI-compat fields
|
|
786
|
+
bool oaicompat = false;
|
|
787
|
+
|
|
788
|
+
virtual int get_index() override {
|
|
789
|
+
return index;
|
|
790
|
+
}
|
|
791
|
+
|
|
792
|
+
virtual json to_json() override {
|
|
793
|
+
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
|
794
|
+
}
|
|
795
|
+
|
|
796
|
+
json to_json_non_oaicompat() {
|
|
797
|
+
return json {
|
|
798
|
+
{"index", index},
|
|
799
|
+
{"embedding", embedding},
|
|
800
|
+
};
|
|
801
|
+
}
|
|
802
|
+
|
|
803
|
+
json to_json_oaicompat() {
|
|
804
|
+
return json {
|
|
805
|
+
{"index", index},
|
|
806
|
+
{"embedding", embedding[0]},
|
|
807
|
+
{"tokens_evaluated", n_tokens},
|
|
808
|
+
};
|
|
809
|
+
}
|
|
810
|
+
};
|
|
811
|
+
|
|
812
|
+
struct server_task_result_rerank : server_task_result {
|
|
813
|
+
int index = 0;
|
|
814
|
+
float score = -1e6;
|
|
815
|
+
|
|
816
|
+
int32_t n_tokens;
|
|
817
|
+
|
|
818
|
+
virtual int get_index() override {
|
|
819
|
+
return index;
|
|
820
|
+
}
|
|
821
|
+
|
|
822
|
+
virtual json to_json() override {
|
|
823
|
+
return json {
|
|
824
|
+
{"index", index},
|
|
825
|
+
{"score", score},
|
|
826
|
+
{"tokens_evaluated", n_tokens},
|
|
827
|
+
};
|
|
828
|
+
}
|
|
829
|
+
};
|
|
830
|
+
|
|
831
|
+
// this function maybe used outside of server_task_result_error
|
|
832
|
+
static json format_error_response(const std::string & message, const enum error_type type) {
|
|
833
|
+
std::string type_str;
|
|
834
|
+
int code = 500;
|
|
835
|
+
switch (type) {
|
|
836
|
+
case ERROR_TYPE_INVALID_REQUEST:
|
|
837
|
+
type_str = "invalid_request_error";
|
|
838
|
+
code = 400;
|
|
839
|
+
break;
|
|
840
|
+
case ERROR_TYPE_AUTHENTICATION:
|
|
841
|
+
type_str = "authentication_error";
|
|
842
|
+
code = 401;
|
|
843
|
+
break;
|
|
844
|
+
case ERROR_TYPE_NOT_FOUND:
|
|
845
|
+
type_str = "not_found_error";
|
|
846
|
+
code = 404;
|
|
847
|
+
break;
|
|
848
|
+
case ERROR_TYPE_SERVER:
|
|
849
|
+
type_str = "server_error";
|
|
850
|
+
code = 500;
|
|
851
|
+
break;
|
|
852
|
+
case ERROR_TYPE_PERMISSION:
|
|
853
|
+
type_str = "permission_error";
|
|
854
|
+
code = 403;
|
|
855
|
+
break;
|
|
856
|
+
case ERROR_TYPE_NOT_SUPPORTED:
|
|
857
|
+
type_str = "not_supported_error";
|
|
858
|
+
code = 501;
|
|
859
|
+
break;
|
|
860
|
+
case ERROR_TYPE_UNAVAILABLE:
|
|
861
|
+
type_str = "unavailable_error";
|
|
862
|
+
code = 503;
|
|
863
|
+
break;
|
|
864
|
+
}
|
|
865
|
+
return json {
|
|
866
|
+
{"code", code},
|
|
867
|
+
{"message", message},
|
|
868
|
+
{"type", type_str},
|
|
869
|
+
};
|
|
870
|
+
}
|
|
871
|
+
|
|
872
|
+
struct server_task_result_error : server_task_result {
|
|
873
|
+
int index = 0;
|
|
874
|
+
error_type err_type = ERROR_TYPE_SERVER;
|
|
875
|
+
std::string err_msg;
|
|
876
|
+
|
|
877
|
+
virtual bool is_error() override {
|
|
878
|
+
return true;
|
|
879
|
+
}
|
|
880
|
+
|
|
881
|
+
virtual json to_json() override {
|
|
882
|
+
return format_error_response(err_msg, err_type);
|
|
883
|
+
}
|
|
884
|
+
};
|
|
885
|
+
|
|
886
|
+
struct server_task_result_metrics : server_task_result {
|
|
887
|
+
int n_idle_slots;
|
|
888
|
+
int n_processing_slots;
|
|
889
|
+
int n_tasks_deferred;
|
|
890
|
+
int64_t t_start;
|
|
891
|
+
|
|
892
|
+
int32_t kv_cache_tokens_count;
|
|
893
|
+
int32_t kv_cache_used_cells;
|
|
894
|
+
|
|
895
|
+
// TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
|
|
896
|
+
uint64_t n_prompt_tokens_processed_total = 0;
|
|
897
|
+
uint64_t t_prompt_processing_total = 0;
|
|
898
|
+
uint64_t n_tokens_predicted_total = 0;
|
|
899
|
+
uint64_t t_tokens_generation_total = 0;
|
|
900
|
+
|
|
901
|
+
uint64_t n_prompt_tokens_processed = 0;
|
|
902
|
+
uint64_t t_prompt_processing = 0;
|
|
903
|
+
|
|
904
|
+
uint64_t n_tokens_predicted = 0;
|
|
905
|
+
uint64_t t_tokens_generation = 0;
|
|
906
|
+
|
|
907
|
+
uint64_t n_decode_total = 0;
|
|
908
|
+
uint64_t n_busy_slots_total = 0;
|
|
909
|
+
|
|
910
|
+
// while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
|
|
911
|
+
// therefore, we use json to temporarily store the slot.to_json() result
|
|
912
|
+
json slots_data = json::array();
|
|
913
|
+
|
|
914
|
+
virtual json to_json() override {
|
|
915
|
+
return json {
|
|
916
|
+
{ "idle", n_idle_slots },
|
|
917
|
+
{ "processing", n_processing_slots },
|
|
918
|
+
{ "deferred", n_tasks_deferred },
|
|
919
|
+
{ "t_start", t_start },
|
|
920
|
+
|
|
921
|
+
{ "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
|
|
922
|
+
{ "t_tokens_generation_total", t_tokens_generation_total },
|
|
923
|
+
{ "n_tokens_predicted_total", n_tokens_predicted_total },
|
|
924
|
+
{ "t_prompt_processing_total", t_prompt_processing_total },
|
|
925
|
+
|
|
926
|
+
{ "n_prompt_tokens_processed", n_prompt_tokens_processed },
|
|
927
|
+
{ "t_prompt_processing", t_prompt_processing },
|
|
928
|
+
{ "n_tokens_predicted", n_tokens_predicted },
|
|
929
|
+
{ "t_tokens_generation", t_tokens_generation },
|
|
930
|
+
|
|
931
|
+
{ "n_decode_total", n_decode_total },
|
|
932
|
+
{ "n_busy_slots_total", n_busy_slots_total },
|
|
933
|
+
|
|
934
|
+
{ "kv_cache_tokens_count", kv_cache_tokens_count },
|
|
935
|
+
{ "kv_cache_used_cells", kv_cache_used_cells },
|
|
936
|
+
|
|
937
|
+
{ "slots", slots_data },
|
|
938
|
+
};
|
|
939
|
+
}
|
|
940
|
+
};
|
|
941
|
+
|
|
942
|
+
struct server_task_result_slot_save_load : server_task_result {
|
|
943
|
+
std::string filename;
|
|
944
|
+
bool is_save; // true = save, false = load
|
|
945
|
+
|
|
946
|
+
size_t n_tokens;
|
|
947
|
+
size_t n_bytes;
|
|
948
|
+
double t_ms;
|
|
949
|
+
|
|
950
|
+
virtual json to_json() override {
|
|
951
|
+
if (is_save) {
|
|
952
|
+
return json {
|
|
953
|
+
{ "id_slot", id_slot },
|
|
954
|
+
{ "filename", filename },
|
|
955
|
+
{ "n_saved", n_tokens },
|
|
956
|
+
{ "n_written", n_bytes },
|
|
957
|
+
{ "timings", {
|
|
958
|
+
{ "save_ms", t_ms }
|
|
959
|
+
}},
|
|
960
|
+
};
|
|
961
|
+
} else {
|
|
962
|
+
return json {
|
|
963
|
+
{ "id_slot", id_slot },
|
|
964
|
+
{ "filename", filename },
|
|
965
|
+
{ "n_restored", n_tokens },
|
|
966
|
+
{ "n_read", n_bytes },
|
|
967
|
+
{ "timings", {
|
|
968
|
+
{ "restore_ms", t_ms }
|
|
969
|
+
}},
|
|
970
|
+
};
|
|
971
|
+
}
|
|
972
|
+
}
|
|
973
|
+
};
|
|
974
|
+
|
|
975
|
+
struct server_task_result_slot_erase : server_task_result {
|
|
976
|
+
size_t n_erased;
|
|
977
|
+
|
|
978
|
+
virtual json to_json() override {
|
|
979
|
+
return json {
|
|
980
|
+
{ "id_slot", id_slot },
|
|
981
|
+
{ "n_erased", n_erased },
|
|
982
|
+
};
|
|
983
|
+
}
|
|
984
|
+
};
|
|
985
|
+
|
|
986
|
+
struct server_task_result_apply_lora : server_task_result {
|
|
987
|
+
virtual json to_json() override {
|
|
988
|
+
return json {{ "success", true }};
|
|
989
|
+
}
|
|
124
990
|
};
|
|
125
991
|
|
|
126
992
|
struct server_slot {
|
|
127
993
|
int id;
|
|
128
994
|
int id_task = -1;
|
|
129
995
|
|
|
996
|
+
// only used for completion/embedding/infill/rerank
|
|
997
|
+
server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
|
|
998
|
+
|
|
999
|
+
llama_batch batch_spec = {};
|
|
1000
|
+
|
|
1001
|
+
llama_context * ctx = nullptr;
|
|
1002
|
+
llama_context * ctx_dft = nullptr;
|
|
1003
|
+
|
|
1004
|
+
common_speculative * spec = nullptr;
|
|
1005
|
+
|
|
130
1006
|
// the index relative to completion multi-task request
|
|
131
1007
|
size_t index = 0;
|
|
132
1008
|
|
|
@@ -154,35 +1030,29 @@ struct server_slot {
|
|
|
154
1030
|
|
|
155
1031
|
size_t last_nl_pos = 0;
|
 
-    std::string
+    std::string generated_text;
+    llama_tokens generated_tokens;
+
     llama_tokens cache_tokens;
-    std::vector<completion_token_output> generated_token_probs;
 
-
+    std::vector<completion_token_output> generated_token_probs;
 
     bool has_next_token = true;
     bool has_new_line = false;
     bool truncated = false;
-
-    bool stopped_word = false;
-    bool stopped_limit = false;
+    stop_type stop;
 
-    bool oaicompat = false;
-
-    std::string oaicompat_model;
     std::string stopping_word;
 
     // sampling
     json json_schema;
 
-    struct common_sampler_params sparams;
     struct common_sampler * smpl = nullptr;
 
     llama_token sampled;
 
     // stats
     size_t n_sent_text = 0; // number of sent text character
-    size_t n_sent_token_probs = 0;
 
     int64_t t_start_process_prompt;
     int64_t t_start_generation;
@@ -200,19 +1070,21 @@ struct server_slot {
         generated_text = "";
         has_new_line = false;
         truncated = false;
-
-        stopped_word = false;
-        stopped_limit = false;
+        stop = STOP_TYPE_NONE;
         stopping_word = "";
         n_past = 0;
         n_sent_text = 0;
-
-        inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
+        task_type = SERVER_TASK_TYPE_COMPLETION;
 
+        generated_tokens.clear();
         generated_token_probs.clear();
     }
 
-    bool
+    bool is_non_causal() const {
+        return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
+    }
+
+    bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
         }
@@ -232,6 +1104,10 @@ struct server_slot {
         return state != SLOT_STATE_IDLE;
     }
 
+    bool can_speculate() const {
+        return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
+    }
+
     void add_token(const completion_token_output & token) {
         if (!is_processing()) {
             SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -251,38 +1127,40 @@ struct server_slot {
         }
     }
 
-
-
-
-
-
-
-
-
-
-
-
-
+    result_timings get_timings() const {
+        result_timings timings;
+        timings.prompt_n = n_prompt_tokens_processed;
+        timings.prompt_ms = t_prompt_processing;
+        timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
+        timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+
+        timings.predicted_n = n_decoded;
+        timings.predicted_ms = t_token_generation;
+        timings.predicted_per_token_ms = t_token_generation / n_decoded;
+        timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
+
+        return timings;
     }
 
-    size_t find_stopping_strings(const std::string & text, const size_t last_token_size,
+    size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
         size_t stop_pos = std::string::npos;
 
         for (const std::string & word : params.antiprompt) {
             size_t pos;
 
-            if (
+            if (is_full_stop) {
                 const size_t tmp = word.size() + last_token_size;
                 const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
 
                 pos = text.find(word, from_pos);
             } else {
+                // otherwise, partial stop
                 pos = find_partial_stop_string(word, text);
             }
 
             if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
-                if (
-
+                if (is_full_stop) {
+                    stop = STOP_TYPE_WORD;
                     stopping_word = word;
                     has_next_token = false;
                 }
@@ -302,13 +1180,35 @@ struct server_slot {
 
         SLT_INF(*this,
                 "\n"
-                "
-                "
-                "
+                "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+                " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+                " total time = %10.2f ms / %5d tokens\n",
                 t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
                 t_token_generation, n_decoded, t_gen, n_gen_second,
                 t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
     }
+
+    json to_json() const {
+        return json {
+            {"id", id},
+            {"id_task", id_task},
+            {"n_ctx", n_ctx},
+            {"speculative", can_speculate()},
+            {"is_processing", is_processing()},
+            {"non_causal", is_non_causal()},
+            {"params", params.to_json()},
+            {"prompt", common_detokenize(ctx, prompt_tokens)},
+            {"next_token",
+                {
+                    {"has_next_token", has_next_token},
+                    {"has_new_line", has_new_line},
+                    {"n_remain", n_remaining},
+                    {"n_decoded", n_decoded},
+                    {"stopping_word", stopping_word},
+                }
+            },
+        };
+    }
 };
 
 struct server_metrics {
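For readers following the slot refactor above: the separate `stopped_word` / `stopped_limit` booleans are collapsed into a single `stop` field of type `stop_type`. A minimal sketch of what that enum presumably looks like, using only the enumerator names that appear in these hunks (the actual declaration lives elsewhere in the server sources and is not part of this diff):

```cpp
// Sketch only; declaration assumed from the enumerators referenced in this diff.
enum stop_type {
    STOP_TYPE_NONE,   // slot.reset() default
    STOP_TYPE_EOS,    // end-of-generation token sampled
    STOP_TYPE_WORD,   // a stopping string matched (full stop)
    STOP_TYPE_LIMIT,  // n_predict / time / context budget exhausted
};

// The old boolean flags can be recovered from the single field if a caller needs them:
// bool stopped_word  = (slot.stop == STOP_TYPE_WORD);
// bool stopped_limit = (slot.stop == STOP_TYPE_LIMIT);
```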
@@ -381,9 +1281,7 @@ struct server_queue {
     // Add a new task to the end of the queue
     int post(server_task task, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
-
-            task.id = id++;
-        }
+        GGML_ASSERT(task.id != -1);
         QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
         if (front) {
             queue_tasks.push_front(std::move(task));
@@ -507,8 +1405,8 @@ struct server_response {
     // for keeping track of all tasks waiting for the result
     std::unordered_set<int> waiting_task_ids;
 
-    // the main result queue
-    std::vector<
+    // the main result queue (using ptr for polymorphism)
+    std::vector<server_task_result_ptr> queue_results;
 
     std::mutex mutex_results;
     std::condition_variable condition_results;
@@ -548,7 +1446,7 @@ struct server_response {
     }
 
     // This function blocks the thread until there is a response for one of the id_tasks
-
+    server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
         while (true) {
             std::unique_lock<std::mutex> lock(mutex_results);
             condition_results.wait(lock, [&]{
@@ -556,8 +1454,8 @@ struct server_response {
             });
 
             for (int i = 0; i < (int) queue_results.size(); i++) {
-                if (id_tasks.find(queue_results[i]
-
+                if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+                    server_task_result_ptr res = std::move(queue_results[i]);
                     queue_results.erase(queue_results.begin() + i);
                     return res;
                 }
@@ -568,21 +1466,21 @@ struct server_response {
     }
 
     // single-task version of recv()
-
+    server_task_result_ptr recv(int id_task) {
         std::unordered_set<int> id_tasks = {id_task};
         return recv(id_tasks);
     }
 
     // Send a new result to a waiting id_task
-    void send(
-        SRV_DBG("sending result for task id = %d\n", result
+    void send(server_task_result_ptr && result) {
+        SRV_DBG("sending result for task id = %d\n", result->id);
 
         std::unique_lock<std::mutex> lock(mutex_results);
         for (const auto & id_task : waiting_task_ids) {
-            if (result
-                SRV_DBG("task id = %d
+            if (result->id == id_task) {
+                SRV_DBG("task id = %d pushed to result queue\n", result->id);
 
-                queue_results.
+                queue_results.emplace_back(std::move(result));
                 condition_results.notify_all();
                 return;
             }
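The result queue now stores owning pointers instead of plain structs, so results move through `send()` / `recv()` without copies and can carry subclass-specific payloads. A hedged usage sketch, assuming `server_task_result_ptr` is an owning alias along the lines of `std::unique_ptr<server_task_result>` (the alias itself is declared in the server headers, not in this hunk):

```cpp
// Hypothetical caller, illustrating the move-based API introduced above.
void wait_for_one_result(server_response & queue_results, int id_task) {
    // blocks until a result for id_task is available, then transfers ownership out
    server_task_result_ptr res = queue_results.recv(id_task);

    if (res->is_error()) {
        // error results are reported back as JSON by the caller
        return;
    }
    // use the result; it is freed automatically when res goes out of scope
}
```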
@@ -591,11 +1489,14 @@ struct server_context {
 };
 
 struct server_context {
+    common_params params_base;
+
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
     std::vector<common_lora_adapter_container> loras;
 
-
+    llama_model * model_dft = nullptr;
+    llama_context_params cparams_dft;
 
     llama_batch batch = {};
 
@@ -628,34 +1529,90 @@ struct server_context {
             model = nullptr;
         }
 
+        if (model_dft) {
+            llama_free_model(model_dft);
+            model_dft = nullptr;
+        }
+
         // Clear any sampling context
         for (server_slot & slot : slots) {
-
-
-
+            common_sampler_free(slot.smpl);
+            slot.smpl = nullptr;
+
+            llama_free(slot.ctx_dft);
+            slot.ctx_dft = nullptr;
+
+            common_speculative_free(slot.spec);
+            slot.spec = nullptr;
+
+            llama_batch_free(slot.batch_spec);
         }
 
         llama_batch_free(batch);
     }
 
-    bool load_model(const common_params &
-
+    bool load_model(const common_params & params) {
+        SRV_INF("loading model '%s'\n", params.model.c_str());
 
-
+        params_base = params;
+
+        common_init_result llama_init = common_init_from_params(params_base);
 
         model = llama_init.model;
         ctx = llama_init.context;
         loras = llama_init.lora_adapters;
 
         if (model == nullptr) {
-            SRV_ERR("failed to load model, '%s'\n",
+            SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
             return false;
         }
 
         n_ctx = llama_n_ctx(ctx);
 
         add_bos_token = llama_add_bos_token(model);
-        has_eos_token =
+        has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL;
+
+        if (!params_base.speculative.model.empty()) {
+            SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
+
+            auto params_dft = params_base;
+
+            params_dft.devices = params_base.speculative.devices;
+            params_dft.model = params_base.speculative.model;
+            params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
+            params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
+            params_dft.n_parallel = 1;
+
+            common_init_result llama_init_dft = common_init_from_params(params_dft);
+
+            model_dft = llama_init_dft.model;
+
+            if (model_dft == nullptr) {
+                SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
+                return false;
+            }
+
+            if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
+                SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
+
+                llama_free (llama_init_dft.context);
+                llama_free_model(llama_init_dft.model);
+
+                return false;
+            }
+
+            const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
+
+            cparams_dft = common_context_params_to_llama(params_dft);
+            cparams_dft.n_batch = n_ctx_dft;
+
+            // force F16 KV cache for the draft model for extra performance
+            cparams_dft.type_k = GGML_TYPE_F16;
+            cparams_dft.type_v = GGML_TYPE_F16;
+
+            // the context is not needed - we will create one for each slot
+            llama_free(llama_init_dft.context);
+        }
 
         return true;
     }
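One detail worth spelling out from the draft-model setup above: when `params_base.speculative.n_ctx` is left at 0, the draft context defaults to the target context divided by the number of parallel slots, and the draft batch is then sized to that whole context. A self-contained example with illustrative numbers (not taken from the package):

```cpp
#include <cstdio>

int main() {
    // illustrative values only
    const int n_ctx             = 8192; // target model context (params_base.n_ctx)
    const int n_parallel        = 4;    // server slots (params_base.n_parallel)
    const int n_ctx_speculative = 0;    // params_base.speculative.n_ctx not set

    // same expression as in the hunk above
    const int n_ctx_dft = n_ctx_speculative == 0 ? n_ctx / n_parallel : n_ctx_speculative;

    std::printf("draft context per slot: %d tokens\n", n_ctx_dft); // prints 2048
    return 0;
}
```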
@@ -674,20 +1631,37 @@ struct server_context {
     }
 
     void init() {
-        const int32_t n_ctx_slot = n_ctx /
+        const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
 
-        SRV_INF("initializing slots, n_slots = %d\n",
+        SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
-        for (int i = 0; i <
+        for (int i = 0; i < params_base.n_parallel; i++) {
             server_slot slot;
 
             slot.id = i;
+            slot.ctx = ctx;
             slot.n_ctx = n_ctx_slot;
-            slot.n_predict =
+            slot.n_predict = params_base.n_predict;
+
+            if (model_dft) {
+                slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
+
+                slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
+                if (slot.ctx_dft == nullptr) {
+                    SRV_ERR("%s", "failed to create draft context\n");
+                    return;
+                }
+
+                slot.spec = common_speculative_init(slot.ctx_dft);
+                if (slot.spec == nullptr) {
+                    SRV_ERR("%s", "failed to create speculator\n");
+                    return;
+                }
+            }
 
             SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
-            slot.
+            slot.params.sampling = params_base.sampling;
 
             slot.callback_on_release = [this](int) {
                 queue_tasks.pop_deferred_task();
@@ -698,8 +1672,7 @@ struct server_context {
             slots.push_back(slot);
         }
 
-        default_generation_settings_for_props =
-        default_generation_settings_for_props["seed"] = -1;
+        default_generation_settings_for_props = slots[0].to_json();
 
         // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
         // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
@@ -707,7 +1680,7 @@ struct server_context {
             const int32_t n_batch = llama_n_batch(ctx);
 
             // only a single seq_id per token is needed
-            batch = llama_batch_init(std::max(n_batch,
+            batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
         }
 
         metrics.init();
@@ -743,7 +1716,7 @@ struct server_context {
             }
 
             // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
-            int cur_lcs_len =
+            int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
 
             // fraction of the common subsequence length compared to the current slot's prompt length
             float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
@@ -786,87 +1759,14 @@ struct server_context {
     }
 
     bool launch_slot_with_task(server_slot & slot, const server_task & task) {
-
-
-
-
-
-
-            slot.oaicompat = true;
-            slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
-        } else {
-            slot.oaicompat = false;
-            slot.oaicompat_model = "";
-        }
-
-        slot.params.stream = json_value(data, "stream", false);
-        slot.params.cache_prompt = json_value(data, "cache_prompt", false);
-        slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
-        slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
-        slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
-        slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
-        slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
-        slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
-        slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
-        slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
-        slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
-        slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
-        slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
-        slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
-        slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
-        slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
-        slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
-        slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
-        slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
-        slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
-        slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
-        slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
-        slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
-        slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
-        slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
-        slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
-        slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
-        slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
-        slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
-        slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
-        //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
-        slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
-
-        if (slot.sparams.dry_base < 1.0f)
-        {
-            slot.sparams.dry_base = default_sparams.dry_base;
-        }
-
-        // sequence breakers for DRY
-        {
-            // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
-            // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
-
-            if (data.contains("dry_sequence_breakers")) {
-                slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
-                if (slot.sparams.dry_sequence_breakers.empty()) {
-                    send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
-                    return false;
-                }
-            }
-        }
+        slot.reset();
+        slot.id_task = task.id;
+        slot.index = task.index;
+        slot.task_type = task.type;
+        slot.params = std::move(task.params);
+        slot.prompt_tokens = std::move(task.prompt_tokens);
 
-
-        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
-            send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
-            return false;
-        }
-        if (data.contains("json_schema") && !data.contains("grammar")) {
-            try {
-                auto schema = json_value(data, "json_schema", json::object());
-                slot.sparams.grammar = json_schema_to_grammar(schema);
-            } catch (const std::exception & e) {
-                send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
-                return false;
-            }
-        } else {
-            slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
-        }
+        SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
             // Might be better to reject the request with a 400 ?
@@ -874,86 +1774,16 @@ struct server_context {
             SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
         }
 
-        {
-            slot.
-
-
-                slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
-            }
-
-            const auto & logit_bias = data.find("logit_bias");
-            if (logit_bias != data.end() && logit_bias->is_array()) {
-                const int n_vocab = llama_n_vocab(model);
-                for (const auto & el : *logit_bias) {
-                    // TODO: we may want to throw errors here, in case "el" is incorrect
-                    if (el.is_array() && el.size() == 2) {
-                        float bias;
-                        if (el[1].is_number()) {
-                            bias = el[1].get<float>();
-                        } else if (el[1].is_boolean() && !el[1].get<bool>()) {
-                            bias = -INFINITY;
-                        } else {
-                            continue;
-                        }
-
-                        if (el[0].is_number_integer()) {
-                            llama_token tok = el[0].get<llama_token>();
-                            if (tok >= 0 && tok < n_vocab) {
-                                slot.sparams.logit_bias.push_back({tok, bias});
-                            }
-                        } else if (el[0].is_string()) {
-                            auto toks = common_tokenize(model, el[0].get<std::string>(), false);
-                            for (auto tok : toks) {
-                                slot.sparams.logit_bias.push_back({tok, bias});
-                            }
-                        }
-                    }
-                }
-            }
-        }
-
-        {
-            slot.params.antiprompt.clear();
-
-            const auto & stop = data.find("stop");
-            if (stop != data.end() && stop->is_array()) {
-                for (const auto & word : *stop) {
-                    if (!word.empty()) {
-                        slot.params.antiprompt.push_back(word);
-                    }
-                }
-            }
-        }
-
-        {
-            const auto & samplers = data.find("samplers");
-            if (samplers != data.end()) {
-                if (samplers->is_array()) {
-                    std::vector<std::string> sampler_names;
-                    for (const auto & name : *samplers) {
-                        if (name.is_string()) {
-                            sampler_names.emplace_back(name);
-                        }
-                    }
-                    slot.sparams.samplers = common_sampler_types_from_names(sampler_names, false);
-                } else if (samplers->is_string()){
-                    std::string sampler_string;
-                    for (const auto & name : *samplers) {
-                        sampler_string += name;
-                    }
-                    slot.sparams.samplers = common_sampler_types_from_chars(sampler_string);
-                }
-            } else {
-                slot.sparams.samplers = default_sparams.samplers;
-            }
-        }
-
+        if (slot.params.ignore_eos && has_eos_token) {
+            slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
+        }
+
         {
             if (slot.smpl != nullptr) {
                 common_sampler_free(slot.smpl);
             }
 
-            slot.smpl = common_sampler_init(model, slot.
+            slot.smpl = common_sampler_init(model, slot.params.sampling);
             if (slot.smpl == nullptr) {
                 // for now, the only error that may happen here is invalid grammar
                 send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -961,6 +1791,12 @@ struct server_context {
             }
         }
 
+        if (slot.ctx_dft) {
+            llama_batch_free(slot.batch_spec);
+
+            slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
+        }
+
         slot.state = SLOT_STATE_STARTED;
 
         SLT_INF(slot, "%s", "processing task\n");
@@ -978,49 +1814,33 @@ struct server_context {
 
     bool process_token(completion_token_output & result, server_slot & slot) {
         // remember which tokens were sampled - used for repetition penalties during sampling
-        const std::string token_str =
+        const std::string token_str = result.text_to_send;
         slot.sampled = result.tok;
 
-        // search stop word and delete it
         slot.generated_text += token_str;
+        if (slot.params.return_tokens) {
+            slot.generated_tokens.push_back(result.tok);
+        }
         slot.has_next_token = true;
 
         // check if there is incomplete UTF-8 character at the end
-        bool incomplete =
-        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
-            unsigned char c = slot.generated_text[slot.generated_text.size() - i];
-            if ((c & 0xC0) == 0x80) {
-                // continuation byte: 10xxxxxx
-                continue;
-            }
-            if ((c & 0xE0) == 0xC0) {
-                // 2-byte character: 110xxxxx ...
-                incomplete = i < 2;
-            } else if ((c & 0xF0) == 0xE0) {
-                // 3-byte character: 1110xxxx ...
-                incomplete = i < 3;
-            } else if ((c & 0xF8) == 0xF0) {
-                // 4-byte character: 11110xxx ...
-                incomplete = i < 4;
-            }
-            // else 1-byte character or invalid byte
-            break;
-        }
+        bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
 
+        // search stop word and delete it
         if (!incomplete) {
             size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
 
             const std::string str_test = slot.generated_text.substr(pos);
             bool send_text = true;
 
-            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(),
+            size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
             if (stop_pos != std::string::npos) {
                 slot.generated_text.erase(
                     slot.generated_text.begin() + pos + stop_pos,
                     slot.generated_text.end());
                 pos = std::min(slot.n_sent_text, slot.generated_text.size());
             } else if (slot.has_next_token) {
-                stop_pos = slot.find_stopping_strings(str_test, token_str.size(),
+                stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
                 send_text = stop_pos == std::string::npos;
             }
 
@@ -1043,8 +1863,8 @@ struct server_context {
         }
 
         // check the limits
-        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(
-            slot.
+        if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
+            slot.stop = STOP_TYPE_LIMIT;
             slot.has_next_token = false;
 
             SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
@@ -1053,7 +1873,7 @@ struct server_context {
         if (slot.has_new_line) {
             // if we have already seen a new line, we stop after a certain time limit
             if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
-                slot.
+                slot.stop = STOP_TYPE_LIMIT;
                 slot.has_next_token = false;
 
                 SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
@@ -1073,7 +1893,7 @@ struct server_context {
             }
 
             if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
-                slot.
+                slot.stop = STOP_TYPE_LIMIT;
                 slot.has_next_token = false;
 
                 // cut the last line
@@ -1102,7 +1922,7 @@ struct server_context {
        // if context shift is disabled, we stop when it reaches the context limit
        if (slot.n_past >= slot.n_ctx) {
            slot.truncated = true;
-           slot.
+           slot.stop = STOP_TYPE_LIMIT;
            slot.has_next_token = false;
 
            SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
@@ -1110,7 +1930,7 @@ struct server_context {
        }
 
        if (llama_token_is_eog(model, result.tok)) {
-           slot.
+           slot.stop = STOP_TYPE_EOS;
            slot.has_next_token = false;
 
            SLT_DBG(slot, "%s", "stopped by EOS\n");
@@ -1120,7 +1940,7 @@ struct server_context {
 
        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
            slot.truncated = true;
-           slot.
+           slot.stop = STOP_TYPE_LIMIT;
            slot.has_next_token = false; // stop prediction
 
            SLT_WRN(slot,
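The hand-rolled UTF-8 scan in `process_token()` is replaced above by a single call to `validate_utf8()`. The comparison `validate_utf8(s) < s.size()` implies the helper returns the length of the longest valid UTF-8 prefix; a short sketch under that assumption (the helper itself lives in the server utilities, not in this hunk):

```cpp
// Assumed return convention: number of bytes in the longest valid UTF-8 prefix.
// Example: "caf\xC3" ends mid-way through a two-byte sequence (the second byte of
// 'e-acute' is missing), so the generated text is held back until the next token
// completes the character.
std::string s = "caf\xC3";
bool incomplete = validate_utf8(s) < s.size(); // true: do not stream this chunk yet
```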
@@ -1134,53 +1954,53 @@ struct server_context {
         return slot.has_next_token; // continue
     }
 
-
-
-
-
-
-
+    void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
+        size_t n_probs = slot.params.sampling.n_probs;
+        size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
+        if (post_sampling) {
+            const auto * cur_p = common_sampler_get_candidates(slot.smpl);
+            const size_t max_probs = cur_p->size;
+
+            // set probability for sampled token
+            for (size_t i = 0; i < max_probs; i++) {
+                if (cur_p->data[i].id == result.tok) {
+                    result.prob = cur_p->data[i].p;
+                    break;
+                }
+            }
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        {
-
-
-
-
-
-
-
-
-        {
-
-
-
-
-
-
-
-            {"ignore_eos", slot.sparams.ignore_eos},
-            {"stream", slot.params.stream},
-            //{"logit_bias", slot.sparams.logit_bias},
-            {"n_probs", slot.sparams.n_probs},
-            {"min_keep", slot.sparams.min_keep},
-            {"grammar", slot.sparams.grammar},
-            {"samplers", samplers},
-        };
+            // set probability for top n_probs tokens
+            result.probs.reserve(max_probs);
+            for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
+                result.probs.push_back({
+                    cur_p->data[i].id,
+                    common_detokenize(ctx, {cur_p->data[i].id}, special),
+                    cur_p->data[i].p
+                });
+            }
+        } else {
+            // TODO: optimize this with min-p optimization
+            std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
+
+            // set probability for sampled token
+            for (size_t i = 0; i < n_vocab; i++) {
+                // set probability for sampled token
+                if (cur[i].id == result.tok) {
+                    result.prob = cur[i].p;
+                    break;
+                }
+            }
+
+            // set probability for top n_probs tokens
+            result.probs.reserve(n_probs);
+            for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
+                result.probs.push_back({
+                    cur[i].id,
+                    common_detokenize(ctx, {cur[i].id}, special),
+                    cur[i].p
+                });
+            }
+        }
     }
 
     void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
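As a quick arithmetic note on the probability path above: the post-sampling branch reports at most `n_probs` entries out of the sampler's candidate list, while the other branch draws from the full vocabulary. A tiny illustration with assumed values:

```cpp
#include <algorithm>
#include <cstddef>

int main() {
    // assumed request/runtime values, for illustration only
    std::size_t n_probs   = 10; // slot.params.sampling.n_probs
    std::size_t max_probs = 40; // candidates kept by the sampler at this step

    std::size_t returned = std::min(max_probs, n_probs); // post-sampling path: 10 entries
    return static_cast<int>(returned);
}
```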
@@ -1194,108 +2014,99 @@ struct server_context {
     void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
         SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
 
-
-        res
-        res
-        res
-        res.data = format_error_response(error, type);
-
-        queue_results.send(res);
-    }
-
-    void send_partial_response(server_slot & slot, completion_token_output tkn) {
-        server_task_result res;
-        res.id = slot.id_task;
-        res.error = false;
-        res.stop = false;
-        res.data = json {
-            {"content", tkn.text_to_send},
-            {"stop", false},
-            {"id_slot", slot.id},
-            {"multimodal", false},
-            {"index", slot.index},
-        };
+        auto res = std::make_unique<server_task_result_error>();
+        res->id = id_task;
+        res->err_type = type;
+        res->err_msg = error;
 
-
-
-        const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
-        const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
+        queue_results.send(std::move(res));
+    }
 
-
-
-
-
-
-
-
+    void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
+        auto res = std::make_unique<server_task_result_cmpl_partial>();
+
+        res->id = slot.id_task;
+        res->index = slot.index;
+        res->content = tkn.text_to_send;
+        res->tokens = { tkn.tok };
+
+        res->n_decoded = slot.n_decoded;
+        res->n_prompt_tokens = slot.n_prompt_tokens;
+        res->post_sampling_probs = slot.params.post_sampling_probs;
 
-
+        res->verbose = slot.params.verbose;
+        res->oaicompat = slot.params.oaicompat;
+        res->oaicompat_chat = slot.params.oaicompat_chat;
+        res->oaicompat_model = slot.params.oaicompat_model;
+        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+
+        // populate res.probs_output
+        if (slot.params.sampling.n_probs > 0) {
+            res->prob_output = tkn; // copy the token probs
         }
 
-        if
-
-        res
+        // populate timings if this is final response or timings_per_token is enabled
+        if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
+            res->timings = slot.get_timings();
         }
 
-        queue_results.send(res);
+        queue_results.send(std::move(res));
     }
 
-    void send_final_response(
-
-        res
-        res
-
-        res
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    void send_final_response(server_slot & slot) {
+        auto res = std::make_unique<server_task_result_cmpl_final>();
+        res->id = slot.id_task;
+        res->id_slot = slot.id;
+
+        res->index = slot.index;
+        res->content = slot.generated_text;
+        res->tokens = slot.generated_tokens;
+        res->timings = slot.get_timings();
+        res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
+
+        res->truncated = slot.truncated;
+        res->n_decoded = slot.n_decoded;
+        res->n_prompt_tokens = slot.n_prompt_tokens;
+        res->n_tokens_cached = slot.n_past;
+        res->has_new_line = slot.has_new_line;
+        res->stopping_word = slot.stopping_word;
+        res->stop = slot.stop;
+        res->post_sampling_probs = slot.params.post_sampling_probs;
+
+        res->verbose = slot.params.verbose;
+        res->stream = slot.params.stream;
+        res->oaicompat = slot.params.oaicompat;
+        res->oaicompat_chat = slot.params.oaicompat_chat;
+        res->oaicompat_model = slot.params.oaicompat_model;
+        res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
+
+        // populate res.probs_output
+        if (slot.params.sampling.n_probs > 0) {
+            if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
                 const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
 
                 size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
-
+                res->probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end() - safe_offset);
             } else {
-
+                res->probs_output = std::vector<completion_token_output>(
                     slot.generated_token_probs.begin(),
                     slot.generated_token_probs.end());
             }
-
-            res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
         }
 
-
-            res.data["oaicompat_token_ctr"] = slot.n_decoded;
-            res.data["model"] = slot.oaicompat_model;
-        }
+        res->generation_params = slot.params; // copy the parameters
 
-        queue_results.send(res);
+        queue_results.send(std::move(res));
     }
 
     void send_embedding(const server_slot & slot, const llama_batch & batch) {
-
-        res
-        res
-        res
+        auto res = std::make_unique<server_task_result_embd>();
+        res->id = slot.id_task;
+        res->index = slot.index;
+        res->n_tokens = slot.n_prompt_tokens;
+        res->oaicompat = slot.params.oaicompat;
 
         const int n_embd = llama_n_embd(model);
 
@@ -1314,32 +2125,30 @@ struct server_context {
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
-                res.
-                    {"embedding", std::vector<float>(n_embd, 0.0f)},
-                    {"index", slot.index},
-                };
-
+                res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
                 continue;
             }
 
-
-
-
-
-
-            }
+            // normalize only when there is pooling
+            // TODO: configurable
+            if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
+                common_embd_normalize(embd, embd_res.data(), n_embd, 2);
+                res->embedding.push_back(embd_res);
+            } else {
+                res->embedding.push_back({ embd, embd + n_embd });
+            }
         }
 
         SLT_DBG(slot, "%s", "sending embeddings\n");
 
-        queue_results.send(res);
+        queue_results.send(std::move(res));
     }
 
     void send_rerank(const server_slot & slot, const llama_batch & batch) {
-
-        res
-        res
-        res
+        auto res = std::make_unique<server_task_result_rerank>();
+        res->id = slot.id_task;
+        res->index = slot.index;
+        res->n_tokens = slot.n_prompt_tokens;
 
         for (int i = 0; i < batch.n_tokens; ++i) {
             if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -1354,104 +2163,29 @@ struct server_context {
             if (embd == NULL) {
                 SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
-                res
-                    {"index", slot.index},
-                    {"score", -1e6},
-                };
-
+                res->score = -1e6;
                 continue;
             }
 
-            res
-                {"index", slot.index},
-                {"score", embd[0]},
-            };
+            res->score = embd[0];
         }
 
-        SLT_DBG(slot, "sending rerank result, res =
+        SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);
 
-        queue_results.send(res);
+        queue_results.send(std::move(res));
     }
 
     //
     // Functions to create new task(s) and receive result(s)
     //
 
-    // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
-    std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
-        std::vector<server_task> tasks;
-        auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
-            SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
-            server_task task;
-            task.id = queue_tasks.get_new_id();
-            task.inf_type = inf_type;
-            task.type = SERVER_TASK_TYPE_INFERENCE;
-            task.data = task_data;
-            task.prompt_tokens = std::move(prompt_tokens);
-            tasks.push_back(std::move(task));
-        };
-
-        static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
-        if (!data.contains("prompt")) {
-            throw std::runtime_error(error_msg);
-        }
-
-        // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
-        bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
-        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
-        switch (inf_type) {
-            case SERVER_TASK_INF_TYPE_RERANK:
-                {
-                    // prompts[0] is the question
-                    // the rest are the answers/documents
-                    GGML_ASSERT(tokenized_prompts.size() > 1);
-                    SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
-                    for (size_t i = 1; i < tokenized_prompts.size(); i++) {
-                        data["index"] = i - 1;
-                        auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
-                        create_task(data, tokens);
-                    }
-                } break;
-            case SERVER_TASK_INF_TYPE_INFILL:
-                {
-                    SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
-                    for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-                        data["index"] = i;
-                        auto tokens = format_infill(
-                            ctx,
-                            data.at("input_prefix"),
-                            data.at("input_suffix"),
-                            data.at("input_extra"),
-                            params.n_batch,
-                            params.n_predict,
-                            slots[0].n_ctx, // TODO: there should be a better way
-                            params.spm_infill,
-                            tokenized_prompts[i]
-                        );
-                        create_task(data, tokens);
-                    }
-                } break;
-            default:
-                {
-                    SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
-                    for (size_t i = 0; i < tokenized_prompts.size(); i++) {
-                        data["index"] = i;
-                        create_task(data, tokenized_prompts[i]);
-                    }
-                }
-        }
-
-        return tasks;
-    }
-
     void cancel_tasks(const std::unordered_set<int> & id_tasks) {
         std::vector<server_task> cancel_tasks;
         cancel_tasks.reserve(id_tasks.size());
         for (const auto & id_task : id_tasks) {
             SRV_WRN("cancel task, id_task = %d\n", id_task);
 
-            server_task task;
-            task.type = SERVER_TASK_TYPE_CANCEL;
+            server_task task(SERVER_TASK_TYPE_CANCEL);
             task.id_target = id_task;
             cancel_tasks.push_back(task);
             queue_results.remove_waiting_task_id(id_task);
@@ -1460,50 +2194,58 @@ struct server_context {
         queue_tasks.post(cancel_tasks, true);
     }
 
-    // receive the results from task(s)
-    void
+    // receive the results from task(s)
+    void receive_multi_results(
             const std::unordered_set<int> & id_tasks,
-            const std::function<void(std::vector<
+            const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
             const std::function<void(json)> & error_handler) {
-
-        std::vector<server_task_result> results(id_tasks.size());
+        std::vector<server_task_result_ptr> results(id_tasks.size());
         for (size_t i = 0; i < id_tasks.size(); i++) {
-
+            server_task_result_ptr result = queue_results.recv(id_tasks);
 
-            if (result
-                error_handler(result
+            if (result->is_error()) {
+                error_handler(result->to_json());
                 cancel_tasks(id_tasks);
                 return;
             }
 
-
+            GGML_ASSERT(
+                dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
+                || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
+            );
+            const size_t idx = result->get_index();
             GGML_ASSERT(idx < results.size() && "index out of range");
-
-            results[idx] = result;
+            results[idx] = std::move(result);
         }
         result_handler(results);
     }
 
-    // receive the results from task(s)
+    // receive the results from task(s), in stream mode
     void receive_cmpl_results_stream(
-            const std::unordered_set<int> & id_tasks,
-            std::function<bool(
-            std::function<void(json)> & error_handler) {
+            const std::unordered_set<int> & id_tasks,
+            const std::function<bool(server_task_result_ptr&)> & result_handler,
+            const std::function<void(json)> & error_handler) {
         size_t n_finished = 0;
         while (true) {
-
-
+            server_task_result_ptr result = queue_results.recv(id_tasks);
+
+            if (result->is_error()) {
+                error_handler(result->to_json());
                 cancel_tasks(id_tasks);
-
+                return;
             }
 
-
-
+            GGML_ASSERT(
+                dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+            );
+            if (!result_handler(result)) {
                 cancel_tasks(id_tasks);
                 break;
             }
 
-            if (result
+            if (result->is_stop()) {
                 if (++n_finished == id_tasks.size()) {
                     break;
                 }
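Since results are now polymorphic, callers of `receive_multi_results()` can branch on the concrete type they expect. A hypothetical handler, using only class and field names that appear in these hunks (the handler itself is made up for illustration):

```cpp
// Hypothetical result handler passed to receive_multi_results().
auto handle_results = [](std::vector<server_task_result_ptr> & results) {
    for (auto & res : results) {
        if (auto * cmpl = dynamic_cast<server_task_result_cmpl_final *>(res.get())) {
            // finished completion: cmpl->content, cmpl->tokens, cmpl->timings, ...
        } else if (auto * embd = dynamic_cast<server_task_result_embd *>(res.get())) {
            // embedding result: embd->embedding holds one vector per input row
        } else if (auto * rank = dynamic_cast<server_task_result_rerank *>(res.get())) {
            // rerank result: rank->score
        }
    }
};
```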
@@ -1517,9 +2259,12 @@ struct server_context {
 
     void process_single_task(server_task task) {
         switch (task.type) {
-            case
+            case SERVER_TASK_TYPE_COMPLETION:
+            case SERVER_TASK_TYPE_INFILL:
+            case SERVER_TASK_TYPE_EMBEDDING:
+            case SERVER_TASK_TYPE_RERANK:
                 {
-                    const int id_slot =
+                    const int id_slot = task.id_selected_slot;
 
                     server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
 
@@ -1536,13 +2281,6 @@ struct server_context {
                         break;
                     }
 
-                    slot->reset();
-
-                    slot->id_task = task.id;
-                    slot->inf_type = task.inf_type;
-                    slot->index = json_value(task.data, "index", 0);
-                    slot->prompt_tokens = std::move(task.prompt_tokens);
-
                     if (!launch_slot_with_task(*slot, task)) {
                         SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
                         break;
@@ -1570,21 +2308,7 @@ struct server_context {
                     int n_processing_slots = 0;
 
                     for (server_slot & slot : slots) {
-                        json slot_data =
-                        slot_data["id"] = slot.id;
-                        slot_data["id_task"] = slot.id_task;
-                        slot_data["is_processing"] = slot.is_processing();
-                        slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
-                        slot_data["next_token"] = {
-                            {"has_next_token", slot.has_next_token},
-                            {"has_new_line", slot.has_new_line},
-                            {"n_remain", slot.n_remaining},
-                            {"n_decoded", slot.n_decoded},
-                            {"stopped_eos", slot.stopped_eos},
-                            {"stopped_word", slot.stopped_word},
-                            {"stopped_limit", slot.stopped_limit},
-                            {"stopping_word", slot.stopping_word},
-                        };
+                        json slot_data = slot.to_json();
 
                         if (slot.is_processing()) {
                             n_processing_slots++;
@@ -1596,43 +2320,38 @@ struct server_context {
                    }
                    SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
 
-
-                    res
-                    res
-                    res
-                    res
-
-
-                        { "deferred", queue_tasks.queue_tasks_deferred.size() },
-                        { "t_start", metrics.t_start},
-
-                        { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
-                        { "t_tokens_generation_total", metrics.t_tokens_generation_total},
-                        { "n_tokens_predicted_total", metrics.n_tokens_predicted_total},
-                        { "t_prompt_processing_total", metrics.t_prompt_processing_total},
+                    auto res = std::make_unique<server_task_result_metrics>();
+                    res->id = task.id;
+                    res->slots_data = std::move(slots_data);
+                    res->n_idle_slots = n_idle_slots;
+                    res->n_processing_slots = n_processing_slots;
+                    res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
+                    res->t_start = metrics.t_start;
 
-
-
-                        { "n_tokens_predicted", metrics.n_tokens_predicted},
-                        { "t_tokens_generation", metrics.t_tokens_generation},
+                    res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
+                    res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
 
-
-
+                    res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
+                    res->t_prompt_processing_total = metrics.t_prompt_processing_total;
+                    res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
+                    res->t_tokens_generation_total = metrics.t_tokens_generation_total;
 
-
-
+                    res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
+                    res->t_prompt_processing = metrics.t_prompt_processing;
+                    res->n_tokens_predicted = metrics.n_tokens_predicted;
+                    res->t_tokens_generation = metrics.t_tokens_generation;
 
-
-
+                    res->n_decode_total = metrics.n_decode_total;
+                    res->n_busy_slots_total = metrics.n_busy_slots_total;
 
-                    if (
+                    if (task.metrics_reset_bucket) {
                         metrics.reset_bucket();
                     }
-                    queue_results.send(res);
+                    queue_results.send(std::move(res));
                 } break;
             case SERVER_TASK_TYPE_SLOT_SAVE:
                 {
-                    int id_slot = task.
+                    int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1648,32 +2367,27 @@ struct server_context {
|
|
|
1648
2367
|
const size_t token_count = slot->cache_tokens.size();
|
|
1649
2368
|
const int64_t t_start = ggml_time_us();
|
|
1650
2369
|
|
|
1651
|
-
std::string filename = task.
|
|
1652
|
-
std::string filepath = task.
|
|
2370
|
+
std::string filename = task.slot_action.filename;
|
|
2371
|
+
+                    std::string filepath = task.slot_action.filepath;

                     const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);

                     const int64_t t_end = ggml_time_us();
                     const double t_save_ms = (t_end - t_start) / 1000.0;

-
-
-
-
-
-
-
-
-
-                        { "timings", {
-                            { "save_ms", t_save_ms }
-                        } }
-                    };
-                    queue_results.send(result);
+                    auto res = std::make_unique<server_task_result_slot_save_load>();
+                    res->id = task.id;
+                    res->id_slot = id_slot;
+                    res->filename = filename;
+                    res->is_save = true;
+                    res->n_tokens = token_count;
+                    res->n_bytes = nwrite;
+                    res->t_ms = t_save_ms;
+                    queue_results.send(std::move(res));
                 } break;
             case SERVER_TASK_TYPE_SLOT_RESTORE:
                 {
-                    int id_slot = task.
+                    int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1688,8 +2402,8 @@ struct server_context {

                     const int64_t t_start = ggml_time_us();

-                    std::string filename = task.
-                    std::string filepath = task.
+                    std::string filename = task.slot_action.filename;
+                    std::string filepath = task.slot_action.filepath;

                     slot->cache_tokens.resize(slot->n_ctx);
                     size_t token_count = 0;
@@ -1704,24 +2418,19 @@ struct server_context {
                     const int64_t t_end = ggml_time_us();
                     const double t_restore_ms = (t_end - t_start) / 1000.0;

-
-
-
-
-
-
-
-
-
-                        { "timings", {
-                            { "restore_ms", t_restore_ms }
-                        } }
-                    };
-                    queue_results.send(result);
+                    auto res = std::make_unique<server_task_result_slot_save_load>();
+                    res->id = task.id;
+                    res->id_slot = id_slot;
+                    res->filename = filename;
+                    res->is_save = false;
+                    res->n_tokens = token_count;
+                    res->n_bytes = nread;
+                    res->t_ms = t_restore_ms;
+                    queue_results.send(std::move(res));
                 } break;
             case SERVER_TASK_TYPE_SLOT_ERASE:
                 {
-                    int id_slot = task.
+                    int id_slot = task.slot_action.slot_id;
                     server_slot * slot = get_slot_by_id(id_slot);
                     if (slot == nullptr) {
                         send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1739,25 +2448,18 @@ struct server_context {
                     llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
                     slot->cache_tokens.clear();

-
-
-
-
-
-                        { "id_slot", id_slot },
-                        { "n_erased", n_erased }
-                    };
-                    queue_results.send(result);
+                    auto res = std::make_unique<server_task_result_slot_erase>();
+                    res->id = task.id;
+                    res->id_slot = id_slot;
+                    res->n_erased = n_erased;
+                    queue_results.send(std::move(res));
                 } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
                     common_lora_adapters_apply(ctx, loras);
-
-
-
-                    result.error = false;
-                    result.data = json{{ "success", true }};
-                    queue_results.send(result);
+                    auto res = std::make_unique<server_task_result_apply_lora>();
+                    res->id = task.id;
+                    queue_results.send(std::move(res));
                 } break;
         }
     }
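
Across these task cases the handler stops filling an untyped `result.data` JSON blob and instead sends a typed result object (`server_task_result_slot_save_load`, `server_task_result_slot_erase`, `server_task_result_apply_lora`) through the result queue. The field names above come straight from the diff; the overall shape of such a result hierarchy is only sketched below — the base class and the exact `to_json()` layout are assumptions, not the upstream definitions:

```cpp
#include <cstddef>
#include <memory>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// Hedged sketch of the typed-result pattern: each task kind gets its own struct
// and serializes itself, instead of every handler hand-assembling a json blob.
struct sketch_task_result {
    int id = -1;
    virtual ~sketch_task_result() = default;
    virtual bool is_error() const { return false; }
    virtual json to_json() const = 0;
};

struct sketch_result_slot_save_load : sketch_task_result {
    int         id_slot  = -1;
    std::string filename;
    bool        is_save  = true; // true = save, false = restore
    size_t      n_tokens = 0;
    size_t      n_bytes  = 0;
    double      t_ms     = 0.0;

    json to_json() const override {
        // response field names here are illustrative assumptions
        return json {
            { "id_slot",  id_slot },
            { "filename", filename },
            { "n_tokens", n_tokens },
            { "n_bytes",  n_bytes },
            { "timings",  { { is_save ? "save_ms" : "restore_ms", t_ms } } },
        };
    }
};

using sketch_task_result_ptr = std::unique_ptr<sketch_task_result>;
```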
@@ -1787,10 +2489,8 @@ struct server_context {
        {
            SRV_DBG("%s", "posting NEXT_RESPONSE\n");

-            server_task task;
-            task.
-            task.id_target = -1;
-
+            server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
+            task.id = queue_tasks.get_new_id();
            queue_tasks.post(task);
        }

@@ -1798,7 +2498,7 @@ struct server_context {
        // TODO: simplify and improve
        for (server_slot & slot : slots) {
            if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
-                if (!
+                if (!params_base.ctx_shift) {
                    // this check is redundant (for good)
                    // we should never get here, because generation should already stopped in process_token()
                    slot.release();
@@ -1864,7 +2564,7 @@ struct server_context {
        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;

        // next, batch any pending prompts without exceeding n_batch
-        if (
+        if (params_base.cont_batching || batch.n_tokens == 0) {
            for (auto & slot : slots) {
                // this slot still has a prompt to be processed
                if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
@@ -1904,7 +2604,7 @@ struct server_context {
                        continue;
                    }

-                    if (slot.
+                    if (slot.is_non_causal()) {
                        if (slot.n_prompt_tokens > n_ubatch) {
                            slot.release();
                            send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -1917,7 +2617,7 @@ struct server_context {
                            continue;
                        }
                    } else {
-                        if (!
+                        if (!params_base.ctx_shift) {
                            // if context shift is disabled, we make sure prompt size is smaller than KV size
                            // TODO: there should be a separate parameter that control prompt truncation
                            // context shift should be applied only during the generation phase
@@ -1960,14 +2660,14 @@ struct server_context {

                    if (slot.params.cache_prompt) {
                        // reuse any previously computed tokens that are common with the new prompt
-                        slot.n_past =
+                        slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);

                        // reuse chunks from the cached prompt by shifting their KV cache in the new position
-                        if (
+                        if (params_base.n_cache_reuse > 0) {
                            size_t head_c = slot.n_past; // cache
                            size_t head_p = slot.n_past; // current prompt

-                            SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n",
+                            SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);

                            while (head_c < slot.cache_tokens.size() &&
                                   head_p < prompt_tokens.size()) {
@@ -1980,7 +2680,7 @@ struct server_context {
                                    n_match++;
                                }

-                                if (n_match >= (size_t)
+                                if (n_match >= (size_t) params_base.n_cache_reuse) {
                                    SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
                                    //for (size_t i = head_p; i < head_p + n_match; i++) {
                                    //    SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
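
`common_lcp()` replaces the previous inline prefix computation: the slot keeps whatever prefix of its cached tokens matches the new prompt and only re-evaluates the rest. A minimal, self-contained sketch of a longest-common-prefix helper with that behavior (the name comes from the diff; this body is an illustration, not the common-library implementation):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using token_id = int32_t;

// Length of the longest common prefix of two token sequences; used to decide how
// many cached KV entries can be kept when a new prompt arrives for the same slot.
size_t common_prefix_len(const std::vector<token_id> & a, const std::vector<token_id> & b) {
    size_t i = 0;
    while (i < a.size() && i < b.size() && a[i] == b[i]) {
        ++i;
    }
    return i;
}
```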
@@ -2019,7 +2719,7 @@ struct server_context {
                    }

                    // non-causal tasks require to fit the entire prompt in the physical batch
-                    if (slot.
+                    if (slot.is_non_causal()) {
                        // cannot fit the prompt in the current batch - will try next iter
                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                            continue;
@@ -2027,10 +2727,7 @@ struct server_context {
                    }

                    // check that we are in the right batch_type, if not defer the slot
-
-                        slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
-                        slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
-
+                    int slot_type = slot.is_non_causal();
                    if (batch_type == -1) {
                        batch_type = slot_type;
                    } else if (batch_type != slot_type) {
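
The removed ternary over `slot.inf_type` suggests that the new `slot.is_non_causal()` helper simply folds the embedding/rerank check into one place. A minimal sketch of what such a helper could look like — the task-type names are taken from the surrounding hunks, but this body is an assumption, not the upstream member function:

```cpp
// Hypothetical sketch: "non-causal" slots are the ones doing embedding or rerank work,
// which are batched separately from causal text generation.
enum sketch_task_type { SKETCH_TASK_COMPLETION, SKETCH_TASK_EMBEDDING, SKETCH_TASK_RERANK };

bool is_non_causal_sketch(sketch_task_type task_type) {
    return task_type == SKETCH_TASK_EMBEDDING || task_type == SKETCH_TASK_RERANK;
}
```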
@@ -2053,7 +2750,10 @@ struct server_context {

                    // add prompt tokens for processing in the current batch
                    while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
-
+                        // without pooling, we want to output the embeddings for all the tokens in the batch
+                        const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
+
+                        common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);

                        if (slot.params.cache_prompt) {
                            slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
@@ -2147,7 +2847,7 @@ struct server_context {
            }

            if (slot.state == SLOT_STATE_DONE_PROMPT) {
-                if (slot.
+                if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
                    // prompt evaluated for embedding
                    send_embedding(slot, batch_view);
                    slot.release();
@@ -2155,7 +2855,7 @@ struct server_context {
                    continue; // continue loop of slots
                }

-                if (slot.
+                if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
                    send_rerank(slot, batch_view);
                    slot.release();
                    slot.i_batch = -1;
@@ -2168,27 +2868,33 @@ struct server_context {
                continue; // continue loop of slots
            }

-
-
+            const int tok_idx = slot.i_batch - i;
+
+            llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
+
+            slot.i_batch = -1;

            common_sampler_accept(slot.smpl, id, true);

            slot.n_decoded += 1;
+
+            const int64_t t_current = ggml_time_us();
+
            if (slot.n_decoded == 1) {
-                slot.t_start_generation =
+                slot.t_start_generation = t_current;
                slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
                metrics.on_prompt_eval(slot);
            }

-
+            slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;

-
+            completion_token_output result;
+            result.tok = id;
+            result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+            result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs

-
-                result.
-                    cur_p->data[i].id,
-                    i >= cur_p->size ? 0.0f : cur_p->data[i].p,
-                });
+            if (slot.params.sampling.n_probs > 0) {
+                populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
            }

            if (!process_token(result, slot)) {
@@ -2197,9 +2903,98 @@ struct server_context {
                slot.print_timings();
                send_final_response(slot);
                metrics.on_prediction(slot);
+                continue;
            }
+        }

-
+        // do speculative decoding
+        for (auto & slot : slots) {
+            if (!slot.is_processing() || !slot.can_speculate()) {
+                continue;
+            }
+
+            if (slot.state != SLOT_STATE_GENERATING) {
+                continue;
+            }
+
+            // determine the max draft that fits the current slot state
+            int n_draft_max = slot.params.speculative.n_max;
+
+            // note: n_past is not yet increased for the `id` token sampled above
+            // also, need to leave space for 1 extra token to allow context shifts
+            n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
+
+            if (slot.n_remaining > 0) {
+                n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
+            }
+
+            SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
+
+            if (n_draft_max < slot.params.speculative.n_min) {
+                SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
+
+                continue;
+            }
+
+            llama_token id = slot.sampled;
+
+            struct common_speculative_params params_spec;
+            params_spec.n_draft = n_draft_max;
+            params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
+            params_spec.p_min = slot.params.speculative.p_min;
+
+            llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
+
+            // ignore small drafts
+            if (slot.params.speculative.n_min > (int) draft.size()) {
+                SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
+
+                continue;
+            }
+
+            // construct the speculation batch
+            common_batch_clear(slot.batch_spec);
+            common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
+            }
+
+            SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
+
+            llama_decode(ctx, slot.batch_spec);
+
+            // the accepted tokens from the speculation
+            const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
+
+            slot.n_past += ids.size();
+            slot.n_decoded += ids.size();
+
+            slot.cache_tokens.push_back(id);
+            slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
+
+            llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
+
+            for (size_t i = 0; i < ids.size(); ++i) {
+                completion_token_output result;
+
+                result.tok = ids[i];
+                result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
+                result.prob = 1.0f; // set later
+
+                // TODO: set result.probs
+
+                if (!process_token(result, slot)) {
+                    // release slot because of stop condition
+                    slot.release();
+                    slot.print_timings();
+                    send_final_response(slot);
+                    metrics.on_prediction(slot);
+                    break;
+                }
+            }
+
+            SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
        }
    }
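
The new speculative-decoding pass follows the usual draft-then-verify pattern: the draft model proposes `draft` tokens, the target model scores them in one `llama_decode()` call, and `common_sampler_sample_and_accept_n()` walks the draft until the target sampler disagrees, so the last entry of `ids` is always the target's own token and `ids.size() - 1` is the number of accepted draft tokens. A standalone sketch of that acceptance walk, with a hypothetical `sample_at` callback standing in for the target sampler (not the upstream helper):

```cpp
#include <cstdint>
#include <functional>
#include <vector>

using tok = int32_t;

// Minimal sketch of draft verification: compare the target model's sample at each
// draft position with the drafted token; stop at the first mismatch. `sample_at`
// is a hypothetical stand-in for sampling from the target model at batch index i.
std::vector<tok> accept_draft(const std::vector<tok> & draft,
                              const std::function<tok(size_t)> & sample_at) {
    std::vector<tok> accepted;
    for (size_t i = 0; i <= draft.size(); ++i) {
        const tok t = sample_at(i);   // the target model's choice at this position
        accepted.push_back(t);        // the target token is always kept
        if (i == draft.size() || t != draft[i]) {
            break;                    // first disagreement ends the accepted run
        }
    }
    return accepted;                  // accepted.size() - 1 draft tokens were kept
}
```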
@@ -2254,17 +3049,9 @@ int main(int argc, char ** argv) {

    common_init();

-    // enabling this will output extra debug information in the HTTP responses from the server
-    // see format_final_response_oaicompat()
-    const bool verbose = params.verbosity > 9;
-
    // struct that contains llama context and inference
    server_context ctx_server;

-    if (params.model_alias == "unknown") {
-        params.model_alias = params.model;
-    }
-
    llama_backend_init();
    llama_numa_init(params.numa);

@@ -2273,16 +3060,6 @@ int main(int argc, char ** argv) {
    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    LOG_INF("\n");

-    // static files
-    std::map<std::string, server_static_file> static_files = {
-        { "/", { index_html, index_html_len, "text/html; charset=utf-8" }},
-        { "/completion.js", { completion_js, completion_js_len, "text/javascript; charset=utf-8" }},
-        { "/deps_daisyui.min.css", { deps_daisyui_min_css, deps_daisyui_min_css_len, "text/css; charset=utf-8" }},
-        { "/deps_markdown-it.js", { deps_markdown_it_js, deps_markdown_it_js_len, "text/javascript; charset=utf-8" }},
-        { "/deps_tailwindcss.js", { deps_tailwindcss_js, deps_tailwindcss_js_len, "text/javascript; charset=utf-8" }},
-        { "/deps_vue.esm-browser.js", { deps_vue_esm_browser_js, deps_vue_esm_browser_js_len, "text/javascript; charset=utf-8" }},
-    };
-
    std::unique_ptr<httplib::Server> svr;
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
@@ -2309,20 +3086,20 @@ int main(int argc, char ** argv) {

    auto res_error = [](httplib::Response & res, const json & error_data) {
        json final_response {{"error", error_data}};
-        res.set_content(final_response
+        res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
        res.status = json_value(error_data, "code", 500);
    };

    auto res_ok = [](httplib::Response & res, const json & data) {
-        res.set_content(data
+        res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
        res.status = 200;
    };

-    svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
+    svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
        std::string message;
        try {
            std::rethrow_exception(ep);
-        } catch (std::exception & e) {
+        } catch (const std::exception & e) {
            message = e.what();
        } catch (...) {
            message = "Unknown Exception";
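
Both response helpers now serialize through `safe_json_to_str()` instead of calling `dump()` directly. The helper itself is defined elsewhere in the server sources; as a rough sketch of the intent — the name comes from the diff, the body below is an assumption (dump with invalid UTF-8 replaced rather than throwing), not the upstream implementation:

```cpp
#include <string>
#include <nlohmann/json.hpp>

// Hedged sketch of a safe_json_to_str-style helper: serialize JSON so that stray
// invalid UTF-8 in generated text cannot make the HTTP handler throw mid-response.
static std::string safe_json_to_str_sketch(const nlohmann::json & data) {
    return data.dump(-1, ' ', false, nlohmann::json::error_handler_t::replace);
}
```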
@@ -2363,7 +3140,7 @@ int main(int argc, char ** argv) {
    // Middlewares
    //

-    auto middleware_validate_api_key = [&params, &res_error
+    auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
        static const std::unordered_set<std::string> public_endpoints = {
            "/health",
            "/models",
@@ -2376,7 +3153,7 @@ int main(int argc, char ** argv) {
        }

        // If path is public or is static file, skip validation
-        if (public_endpoints.find(req.path) != public_endpoints.end() ||
+        if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
            return true;
        }

@@ -2451,27 +3228,33 @@ int main(int argc, char ** argv) {
        }

        // request slots data using task queue
-        server_task task;
+        server_task task(SERVER_TASK_TYPE_METRICS);
        task.id = ctx_server.queue_tasks.get_new_id();
-        task.type = SERVER_TASK_TYPE_METRICS;
-
        ctx_server.queue_results.add_waiting_task_id(task.id);
        ctx_server.queue_tasks.post(task, true); // high-priority task

        // get the result
-
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
        ctx_server.queue_results.remove_waiting_task_id(task.id);

+        if (result->is_error()) {
+            res_error(res, result->to_json());
+            return;
+        }
+
+        // TODO: get rid of this dynamic_cast
+        auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
+        GGML_ASSERT(res_metrics != nullptr);
+
        // optionally return "fail_on_no_slot" error
-        const int n_idle_slots = result.data.at("idle");
        if (req.has_param("fail_on_no_slot")) {
-            if (n_idle_slots == 0) {
+            if (res_metrics->n_idle_slots == 0) {
                res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
                return;
            }
        }

-        res_ok(res,
+        res_ok(res, res_metrics->slots_data);
    };

    const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
@@ -2481,83 +3264,77 @@ int main(int argc, char ** argv) {
        }

        // request slots data using task queue
-        server_task task;
+        server_task task(SERVER_TASK_TYPE_METRICS);
        task.id = ctx_server.queue_tasks.get_new_id();
-        task.
-        task.type = SERVER_TASK_TYPE_METRICS;
-        task.data.push_back({{"reset_bucket", true}});
+        task.metrics_reset_bucket = true;

        ctx_server.queue_results.add_waiting_task_id(task.id);
        ctx_server.queue_tasks.post(task, true); // high-priority task

        // get the result
-
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
        ctx_server.queue_results.remove_waiting_task_id(task.id);

-
-
-
-
-
-        const uint64_t n_tokens_predicted = data.at("n_tokens_predicted");
-        const uint64_t t_tokens_generation = data.at("t_tokens_generation");
-
-        const uint64_t n_decode_total = data.at("n_decode_total");
-        const uint64_t n_busy_slots_total = data.at("n_busy_slots_total");
+        if (result->is_error()) {
+            res_error(res, result->to_json());
+            return;
+        }

-
+        // TODO: get rid of this dynamic_cast
+        auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
+        GGML_ASSERT(res_metrics != nullptr);

        // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
        json all_metrics_def = json {
            {"counter", {{
                    {"name", "prompt_tokens_total"},
                    {"help", "Number of prompt tokens processed."},
-                    {"value", (uint64_t)
+                    {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total}
            }, {
                    {"name", "prompt_seconds_total"},
                    {"help", "Prompt process time"},
-                    {"value", (uint64_t)
+                    {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3}
            }, {
                    {"name", "tokens_predicted_total"},
                    {"help", "Number of generation tokens processed."},
-                    {"value", (uint64_t)
+                    {"value", (uint64_t) res_metrics->n_tokens_predicted_total}
            }, {
                    {"name", "tokens_predicted_seconds_total"},
                    {"help", "Predict process time"},
-                    {"value", (uint64_t)
+                    {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3}
            }, {
                    {"name", "n_decode_total"},
                    {"help", "Total number of llama_decode() calls"},
-                    {"value", n_decode_total}
+                    {"value", res_metrics->n_decode_total}
            }, {
                    {"name", "n_busy_slots_per_decode"},
                    {"help", "Average number of busy slots per llama_decode() call"},
-                    {"value", (float) n_busy_slots_total / (float) n_decode_total}
+                    {"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total}
            }}},
            {"gauge", {{
                    {"name", "prompt_tokens_seconds"},
                    {"help", "Average prompt throughput in tokens/s."},
-                    {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
+                    {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.}
            },{
                    {"name", "predicted_tokens_seconds"},
                    {"help", "Average generation throughput in tokens/s."},
-                    {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
+                    {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
            },{
                    {"name", "kv_cache_usage_ratio"},
                    {"help", "KV-cache usage. 1 means 100 percent usage."},
-                    {"value", 1. * kv_cache_used_cells / params.n_ctx}
+                    {"value", 1. * res_metrics->kv_cache_used_cells / params.n_ctx}
            },{
                    {"name", "kv_cache_tokens"},
                    {"help", "KV-cache tokens."},
-                    {"value", (uint64_t)
+                    {"value", (uint64_t) res_metrics->kv_cache_tokens_count}
            },{
                    {"name", "requests_processing"},
                    {"help", "Number of request processing."},
-                    {"value", (uint64_t)
+                    {"value", (uint64_t) res_metrics->n_processing_slots}
            },{
                    {"name", "requests_deferred"},
                    {"help", "Number of request deferred."},
-                    {"value", (uint64_t)
+                    {"value", (uint64_t) res_metrics->n_tasks_deferred}
            }}}
        };

@@ -2578,8 +3355,7 @@ int main(int argc, char ** argv) {
            }
        }

-
-        res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
+        res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start));

        res.set_content(prometheus.str(), "text/plain; version=0.0.4");
        res.status = 200; // HTTP OK
@@ -2594,25 +3370,24 @@ int main(int argc, char ** argv) {
        }
        std::string filepath = params.slot_save_path + filename;

-        server_task task;
-        task.
-        task.
-
-
-            { "filepath", filepath },
-        };
+        server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.slot_action.slot_id = id_slot;
+        task.slot_action.filename = filename;
+        task.slot_action.filepath = filepath;

-
-        ctx_server.
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);

-
-        ctx_server.queue_results.remove_waiting_task_id(
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);

-        if (result
-            res_error(res, result
-
-            res_ok(res, result.data);
+        if (result->is_error()) {
+            res_error(res, result->to_json());
+            return;
        }
+
+        res_ok(res, result->to_json());
    };

    const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
@@ -2624,45 +3399,45 @@ int main(int argc, char ** argv) {
        }
        std::string filepath = params.slot_save_path + filename;

-        server_task task;
-        task.
-        task.
-
-
-            { "filepath", filepath },
-        };
+        server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.slot_action.slot_id = id_slot;
+        task.slot_action.filename = filename;
+        task.slot_action.filepath = filepath;

-
-        ctx_server.
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);

-
-        ctx_server.queue_results.remove_waiting_task_id(
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);

-        if (result
-            res_error(res, result
-
-            res_ok(res, result.data);
+        if (result->is_error()) {
+            res_error(res, result->to_json());
+            return;
        }
+
+        GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
+        res_ok(res, result->to_json());
    };

    const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
-        server_task task;
-        task.
-        task.
-            { "id_slot", id_slot },
-        };
+        server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.slot_action.slot_id = id_slot;

-
-        ctx_server.
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);

-
-        ctx_server.queue_results.remove_waiting_task_id(
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);

-        if (result
-            res_error(res, result
-
-            res_ok(res, result.data);
+        if (result->is_error()) {
+            res_error(res, result->to_json());
+            return;
        }
+
+        GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
+        res_ok(res, result->to_json());
    };

    const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
@@ -2695,9 +3470,11 @@ int main(int argc, char ** argv) {
    };

    const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
+        // this endpoint is publicly available, please only return what is safe to be exposed
        json data = {
            { "default_generation_settings", ctx_server.default_generation_settings_for_props },
-            { "total_slots", ctx_server.
+            { "total_slots", ctx_server.params_base.n_parallel },
+            { "model_path", ctx_server.params_base.model },
            { "chat_template", llama_get_chat_template(ctx_server.model) },
        };

@@ -2705,7 +3482,7 @@ int main(int argc, char ** argv) {
    };

    const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.
+        if (!ctx_server.params_base.endpoint_props) {
            res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
@@ -2717,13 +3494,50 @@ int main(int argc, char ** argv) {
        res_ok(res, {{ "success", true }});
    };

-
-
+    // handle completion-like requests (completion, chat, infill)
+    // we can optionally provide a custom format for partial results and final results
+    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
+            server_task_type type,
+            json & data,
+            httplib::Response & res,
+            bool oaicompat = false,
+            bool oaicompat_chat = false) {
+        GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
+
+        if (ctx_server.params_base.embedding) {
            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

-
+        auto completion_id = gen_chatcmplid();
+        std::vector<server_task> tasks;
+
+        try {
+            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
+            tasks.reserve(tokenized_prompts.size());
+            for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+                server_task task = server_task(type);
+
+                task.id = ctx_server.queue_tasks.get_new_id();
+                task.index = i;
+
+                task.prompt_tokens = std::move(tokenized_prompts[i]);
+                task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
+                task.id_selected_slot = json_value(data, "id_slot", -1);
+
+                // OAI-compat
+                task.params.oaicompat = oaicompat;
+                task.params.oaicompat_chat = oaicompat_chat;
+                task.params.oaicompat_cmpl_id = completion_id;
+                // oaicompat_model is already populated by params_from_json_cmpl
+
+                tasks.push_back(task);
+            }
+        } catch (const std::exception & e) {
+            res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
        ctx_server.queue_results.add_waiting_tasks(tasks);
        ctx_server.queue_tasks.post(tasks);

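
Note that a single request to this handler can carry several prompts: each one becomes its own `server_task` carrying an `index`, and the responses are later reassembled in request order. A self-contained sketch of that fan-out/collect pattern (the types below are illustrative stand-ins, not server types):

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Illustrative sketch: one request, many prompts, each task keeps its index so the
// results can be ordered by request position rather than by completion time.
struct toy_result { size_t index; std::string text; };

std::vector<std::string> collect_in_order(const std::vector<toy_result> & results, size_t n_prompts) {
    std::vector<std::string> ordered(n_prompts);
    for (const auto & r : results) {
        ordered[r.index] = r.text; // index decouples completion order from request order
    }
    return ordered;
}
```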
@@ -2731,15 +3545,15 @@ int main(int argc, char ** argv) {
        const auto task_ids = server_task::get_list_id(tasks);

        if (!stream) {
-            ctx_server.
+            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
                if (results.size() == 1) {
                    // single result
-                    res_ok(res, results[0]
+                    res_ok(res, results[0]->to_json());
                } else {
                    // multiple results (multitask)
                    json arr = json::array();
-                    for (
-                        arr.push_back(res
+                    for (auto & res : results) {
+                        arr.push_back(res->to_json());
                    }
                    res_ok(res, arr);
                }
@@ -2749,12 +3563,26 @@ int main(int argc, char ** argv) {

            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
        } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](
-
+            const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
+                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
+                    json res_json = result->to_json();
+                    if (res_json.is_array()) {
+                        for (const auto & res : res_json) {
+                            if (!server_sent_event(sink, "data", res)) {
+                                return false;
+                            }
+                        }
+                        return true;
+                    } else {
+                        return server_sent_event(sink, "data", res_json);
+                    }
                }, [&](const json & error_data) {
                    server_sent_event(sink, "error", error_data);
                });
+                if (oaicompat) {
+                    static const std::string ev_done = "data: [DONE]\n\n";
+                    sink.write(ev_done.data(), ev_done.size());
+                }
                sink.done();
                return false;
            };
@@ -2769,7 +3597,12 @@ int main(int argc, char ** argv) {

    const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
        json data = json::parse(req.body);
-        return handle_completions_generic(
+        return handle_completions_generic(
+            SERVER_TASK_TYPE_COMPLETION,
+            data,
+            res,
+            /* oaicompat */ false,
+            /* oaicompat_chat */ false);
    };

    const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -2792,6 +3625,11 @@ int main(int argc, char ** argv) {
        json data = json::parse(req.body);

        // validate input
+        if (data.contains("prompt") && !data.at("prompt").is_string()) {
+            // prompt is optional
+            res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+        }
+
        if (!data.contains("input_prefix")) {
            res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
        }
@@ -2801,9 +3639,11 @@ int main(int argc, char ** argv) {
        }

        if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
+            // input_extra is optional
            res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
            return;
        }
+
        json input_extra = json_value(data, "input_extra", json::array());
        for (const auto & chunk : input_extra) {
            // { "text": string, "filename": string }
@@ -2819,67 +3659,40 @@ int main(int argc, char ** argv) {
        }
        data["input_extra"] = input_extra; // default to empty array if it's not exist

-
+        std::string prompt = json_value(data, "prompt", std::string());
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+        data["prompt"] = format_infill(
+            ctx_server.ctx,
+            data.at("input_prefix"),
+            data.at("input_suffix"),
+            data.at("input_extra"),
+            ctx_server.params_base.n_batch,
+            ctx_server.params_base.n_predict,
+            ctx_server.slots[0].n_ctx, // TODO: there should be a better way
+            ctx_server.params_base.spm_infill,
+            tokenized_prompts[0]
+        );
+
+        return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
    };

-
-
-        if (ctx_server.params.embedding) {
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params_base.embedding) {
            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }

        json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
-
-
-
-
-
-
-        const auto task_ids = server_task::get_list_id(tasks);
-        const auto completion_id = gen_chatcmplid();
-
-        if (!stream) {
-            ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
-                // multitask is never support in chat completion, there is only one result
-                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
-                res_ok(res, result_oai);
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-            });
-
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-        } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
-                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
-                    for (auto & event_data : result_array) {
-                        if (event_data.empty()) {
-                            continue; // skip the stop token
-                        }
-                        if (!server_sent_event(sink, "data", event_data)) {
-                            return false; // connection is closed
-                        }
-                    }
-                    return true; // ok
-                }, [&](const json & error_data) {
-                    server_sent_event(sink, "error", error_data);
-                });
-                static const std::string ev_done = "data: [DONE]\n\n";
-                sink.write(ev_done.data(), ev_done.size());
-                sink.done();
-                return true;
-            };
-
-            auto on_complete = [task_ids, &ctx_server] (bool) {
-                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-            };
-
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-        }
+        return handle_completions_generic(
+            SERVER_TASK_TYPE_COMPLETION,
+            data,
+            res,
+            /* oaicompat */ true,
+            /* oaicompat_chat */ true);
    };

-    const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
+    const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
        json models = {
            {"object", "list"},
            {"data", {
@@ -2893,7 +3706,7 @@ int main(int argc, char ** argv) {
            }}
        };

-        res
+        res_ok(res, models);
    };

    const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
@@ -2949,37 +3762,63 @@ int main(int argc, char ** argv) {
        res_ok(res, data);
    };

-    const auto
+    const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
        const json body = json::parse(req.body);
-        bool is_openai = false;

-
+        if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
+            res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        // for the shape of input/content, see tokenize_input_prompts()
        json prompt;
        if (body.count("input") != 0) {
-            is_openai = true;
            prompt = body.at("input");
-        } else if (body.
-
-            prompt =
+        } else if (body.contains("content")) {
+            oaicompat = false;
+            prompt = body.at("content");
        } else {
            res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
            return;
        }

+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        for (const auto & tokens : tokenized_prompts) {
+            // this check is necessary for models that do not add BOS token to the input
+            if (tokens.empty()) {
+                res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+
        // create and queue the task
        json responses = json::array();
        bool error = false;
        {
-            std::vector<server_task> tasks
+            std::vector<server_task> tasks;
+            for (size_t i = 0; i < tokenized_prompts.size(); i++) {
+                server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
+
+                task.id = ctx_server.queue_tasks.get_new_id();
+                task.index = i;
+                task.prompt_tokens = std::move(tokenized_prompts[i]);
+
+                // OAI-compat
+                task.params.oaicompat = oaicompat;
+
+                tasks.push_back(task);
+            }
+
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(tasks);

            // get the result
            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);

-            ctx_server.
-            for (
-
+            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+                for (auto & res : results) {
+                    GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
+                    responses.push_back(res->to_json());
                }
            }, [&](const json & error_data) {
                res_error(res, error_data);
@@ -2994,14 +3833,20 @@ int main(int argc, char ** argv) {
        }

        // write JSON response
-        json root =
-            ? format_embeddings_response_oaicompat(body, responses)
-            : responses[0];
+        json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
        res_ok(res, root);
    };

+    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, false);
+    };
+
+    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
+        handle_embeddings_impl(req, res, true);
+    };
+
    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.
+        if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
            return;
        }
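
The OAI-compatible path now rejects pooling type `none`, because without pooling the model yields one vector per token rather than one vector per input, which does not fit the OpenAI embeddings response shape. A small, purely illustrative sketch of that distinction (the mean pooling below is an assumption about what a pooled embedding is, not upstream code):

```cpp
#include <cstddef>
#include <vector>

// With pooling: one embedding per input. Without pooling (LLAMA_POOLING_TYPE_NONE):
// one embedding per token, i.e. a matrix, which an OpenAI-style client cannot consume.
// Illustrative mean pooling over per-token embeddings.
std::vector<float> mean_pool(const std::vector<std::vector<float>> & per_token) {
    if (per_token.empty()) {
        return {};
    }
    std::vector<float> out(per_token[0].size(), 0.0f);
    for (const auto & tok_embd : per_token) {
        for (size_t i = 0; i < out.size(); ++i) {
            out[i] += tok_embd[i];
        }
    }
    for (float & v : out) {
        v /= (float) per_token.size();
    }
    return out;
}
```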
@@ -3035,29 +3880,33 @@ int main(int argc, char ** argv) {
            return;
        }

-
-        json prompt;
-        prompt.push_back(query);
-        for (const auto & doc : documents) {
-            prompt.push_back(doc);
-        }
-
-        LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+        llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0];

        // create and queue the task
        json responses = json::array();
        bool error = false;
        {
-            std::vector<server_task> tasks
+            std::vector<server_task> tasks;
+            std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
+            tasks.reserve(tokenized_docs.size());
+            for (size_t i = 0; i < tokenized_docs.size(); i++) {
+                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
+                task.id = ctx_server.queue_tasks.get_new_id();
+                task.index = i;
+                task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
+                tasks.push_back(task);
+            }
+
            ctx_server.queue_results.add_waiting_tasks(tasks);
            ctx_server.queue_tasks.post(tasks);

            // get the result
            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);

-            ctx_server.
-            for (
-
+            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
+                for (auto & res : results) {
+                    GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
+                    responses.push_back(res->to_json());
                }
            }, [&](const json & error_data) {
                res_error(res, error_data);
@@ -3108,36 +3957,47 @@ int main(int argc, char ** argv) {
            }
        }

-        server_task task;
-        task.
-
-        ctx_server.
+        server_task task(SERVER_TASK_TYPE_SET_LORA);
+        task.id = ctx_server.queue_tasks.get_new_id();
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
+
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);

-
-
+        if (result->is_error()) {
+            res_error(res, result->to_json());
+            return;
+        }

-
-        res
+        GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
+        res_ok(res, result->to_json());
    };

    //
    // Router
    //

-
-
-    // Set the base directory for serving static files
-    bool is_found = svr->set_mount_point("/", params.public_path);
-    if (!is_found) {
-        LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
-        return 1;
-    }
+    if (!params.webui) {
+        LOG_INF("Web UI is disabled\n");
    } else {
-        //
-
-
-        svr->
-
+        // register static assets routes
+        if (!params.public_path.empty()) {
+            // Set the base directory for serving static files
+            bool is_found = svr->set_mount_point("/", params.public_path);
+            if (!is_found) {
+                LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
+                return 1;
+            }
+        } else {
+            // using embedded static index.html
+            svr->Get("/", [](const httplib::Request & req, httplib::Response & res) {
+                if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
+                    res.set_content("Error: gzip is not supported by this browser", "text/plain");
+                } else {
+                    res.set_header("Content-Encoding", "gzip");
+                    res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
+                }
                return false;
            });
        }
@@ -3158,7 +4018,7 @@ int main(int argc, char ** argv) {
    svr->Post("/infill", handle_infill);
    svr->Post("/embedding", handle_embeddings); // legacy
    svr->Post("/embeddings", handle_embeddings);
-    svr->Post("/v1/embeddings",
+    svr->Post("/v1/embeddings", handle_embeddings_oai);
    svr->Post("/rerank", handle_rerank);
    svr->Post("/reranking", handle_rerank);
    svr->Post("/v1/rerank", handle_rerank);
@@ -3188,8 +4048,18 @@ int main(int argc, char ** argv) {
        llama_backend_free();
    };

-    // bind HTTP listen port
-
+    // bind HTTP listen port
+    bool was_bound = false;
+    if (params.port == 0) {
+        int bound_port = svr->bind_to_any_port(params.hostname);
+        if ((was_bound = (bound_port >= 0))) {
+            params.port = bound_port;
+        }
+    } else {
+        was_bound = svr->bind_to_port(params.hostname, params.port);
+    }
+
+    if (!was_bound) {
        //LOG_ERROR("couldn't bind HTTP server socket", {
        //    {"hostname", params.hostname},
        //    {"port", params.port},
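
The new binding logic lets `--port 0` mean "pick any free port": cpp-httplib's `bind_to_any_port()` returns the ephemeral port the OS assigned (negative on failure), and the server writes it back into `params.port` so later log output and the readiness check report the real port. A compact restatement of that decision, assuming the cpp-httplib `Server` API as used above (shown standalone only for clarity, not additional upstream code):

```cpp
#include <string>
#include <httplib.h>

// Bind to an explicit port, or let the OS choose one when port == 0.
// On success with port == 0, the chosen ephemeral port is written back.
bool bind_listen_port(httplib::Server & svr, const std::string & hostname, int & port) {
    if (port == 0) {
        const int bound_port = svr.bind_to_any_port(hostname);
        if (bound_port < 0) {
            return false;
        }
        port = bound_port; // report the port the OS picked
        return true;
    }
    return svr.bind_to_port(hostname, port);
}
```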
@@ -3198,6 +4068,8 @@ int main(int argc, char ** argv) {
        clean_up();
        return 1;
    }
+
+    // run the HTTP server in a thread
    std::thread t([&]() { svr->listen_after_bind(); });
    svr->wait_until_ready();