@fugood/llama.node 0.3.0 → 0.3.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -10
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +6 -4
- package/src/LlamaCompletionWorker.cpp +6 -6
- package/src/LlamaContext.cpp +7 -9
- package/src/common.hpp +2 -1
- package/src/llama.cpp/.github/workflows/build.yml +98 -24
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +43 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +20 -8
- package/src/llama.cpp/common/CMakeLists.txt +12 -10
- package/src/llama.cpp/common/arg.cpp +2006 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +496 -1632
- package/src/llama.cpp/common/common.h +161 -63
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +3 -0
- package/src/llama.cpp/common/sampling.cpp +348 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/common/train.cpp +2 -0
- package/src/llama.cpp/docs/build.md +36 -1
- package/src/llama.cpp/examples/CMakeLists.txt +0 -1
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +39 -55
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
- package/src/llama.cpp/examples/infill/infill.cpp +117 -132
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +685 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
- package/src/llama.cpp/examples/llava/llava.cpp +110 -24
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
- package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
- package/src/llama.cpp/examples/main/main.cpp +210 -262
- package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
- package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
- package/src/llama.cpp/examples/server/server.cpp +1027 -1073
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +107 -105
- package/src/llama.cpp/examples/simple/simple.cpp +35 -41
- package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
- package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
- package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
- package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
- package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
- package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
- package/src/llama.cpp/ggml/include/ggml.h +293 -186
- package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
- package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
- package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
- package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
- package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
- package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
- package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
- package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
- package/src/llama.cpp/include/llama.h +241 -264
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
- package/src/llama.cpp/src/llama-sampling.h +20 -47
- package/src/llama.cpp/src/llama-vocab.cpp +343 -120
- package/src/llama.cpp/src/llama-vocab.h +33 -17
- package/src/llama.cpp/src/llama.cpp +4247 -1525
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +3 -0
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
- package/src/llama.cpp/tests/test-barrier.cpp +93 -0
- package/src/llama.cpp/tests/test-grad0.cpp +187 -70
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
- package/src/llama.cpp/tests/test-rope.cpp +1 -1
- package/src/llama.cpp/tests/test-sampling.cpp +157 -98
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
package/src/llama.cpp/examples/server/server.cpp

@@ -1,20 +1,17 @@
 #include "utils.hpp"
 
+#include "arg.h"
 #include "common.h"
+#include "log.h"
+#include "sampling.h"
 #include "json-schema-to-grammar.h"
 #include "llama.h"
-#include "grammar-parser.h"
 
-#ifndef NDEBUG
-// crash the server in debug mode, otherwise send an http 500 error
-#define CPPHTTPLIB_NO_EXCEPTIONS 1
-#endif
-// increase max payload length to allow use of larger context size
-#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
-#include "httplib.h"
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
+// mime type for sending response
+#define MIMETYPE_JSON "application/json; charset=utf-8"
 
 // auto generated files (update with ./deps.sh)
 #include "colorthemes.css.hpp"
@@ -32,42 +29,53 @@
 #include "system-prompts.js.hpp"
 #include "prompt-formats.js.hpp"
 #include "json-schema-to-grammar.mjs.hpp"
+#include "loading.html.hpp"
 
 #include <atomic>
-#include <chrono>
 #include <condition_variable>
 #include <cstddef>
-#include <set>
+#include <cinttypes>
+#include <deque>
+#include <memory>
 #include <mutex>
-#include <thread>
 #include <signal.h>
-#include <memory>
+#include <thread>
+#include <unordered_map>
+#include <unordered_set>
 
-using json = nlohmann::ordered_json;
+#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
+
+#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+
+#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
+#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
 
-bool server_verbose = false;
-bool server_log_json = true;
+using json = nlohmann::ordered_json;
 
 enum stop_type {
     STOP_TYPE_FULL,
     STOP_TYPE_PARTIAL,
 };
 
+// state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
 enum slot_state {
     SLOT_STATE_IDLE,
-    SLOT_STATE_PROCESSING,
-};
-
-enum slot_command {
-    SLOT_COMMAND_NONE,
-    SLOT_COMMAND_LOAD_PROMPT,
-    SLOT_COMMAND_RELEASE,
+    SLOT_STATE_PROCESSING_PROMPT,
+    SLOT_STATE_DONE_PROMPT,
+    SLOT_STATE_GENERATING,
 };
 
 enum server_state {
     SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
     SERVER_STATE_READY,         // Server is ready and model is loaded
-    SERVER_STATE_ERROR          // An error occurred, load_model failed
 };
 
 enum server_task_type {
@@ -78,23 +86,37 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
+    SERVER_TASK_TYPE_SET_LORA,
+};
+
+enum server_task_cmpl_type {
+    SERVER_TASK_CMPL_TYPE_NORMAL,
+    SERVER_TASK_CMPL_TYPE_EMBEDDING,
+    SERVER_TASK_CMPL_TYPE_RERANK,
+    SERVER_TASK_CMPL_TYPE_INFILL,
 };
 
 struct server_task {
     int id        = -1; // to be filled by server_queue
-    int id_multi  = -1;
-    int id_target = -1;
+    int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
 
     server_task_type type;
     json data;
 
-    bool infill    = false;
-    bool embedding = false;
+    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
+    // utility function
+    static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
+        std::unordered_set<int> ids(tasks.size());
+        for (size_t i = 0; i < tasks.size(); i++) {
+            ids.insert(tasks[i].id);
+        }
+        return ids;
+    }
 };
 
 struct server_task_result {
     int id       = -1;
-    int id_multi = -1;
 
     json data;
 
@@ -102,13 +124,6 @@ struct server_task_result {
     bool error;
 };
 
-struct server_task_multi {
-    int id = -1;
-
-    std::set<int> subtasks_remaining;
-    std::vector<server_task_result> results;
-};
-
 struct slot_params {
     bool stream       = true;
     bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
@@ -126,12 +141,13 @@ struct slot_params {
 struct server_slot {
     int id;
     int id_task = -1;
-    int id_multi = -1;
+
+    // the index relative to completion multi-task request
+    size_t index = 0;
 
     struct slot_params params;
 
     slot_state state = SLOT_STATE_IDLE;
-    slot_command command = SLOT_COMMAND_NONE;
 
     // used to determine the slot that has been used the longest
    int64_t t_last_used = -1;
@@ -156,8 +172,8 @@ struct server_slot {
     std::vector<llama_token> cache_tokens;
     std::vector<completion_token_output> generated_token_probs;
 
-    bool infill    = false;
-    bool embedding = false;
+    server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+
     bool has_next_token = true;
     bool truncated      = false;
     bool stopped_eos    = false;
@@ -170,11 +186,13 @@ struct server_slot {
     std::string stopping_word;
 
     // sampling
-    llama_token sampled;
-    struct llama_sampling_params sparams;
-    llama_sampling_context * ctx_sampling = nullptr;
     json json_schema;
 
+    struct gpt_sampler_params sparams;
+    struct gpt_sampler * smpl = nullptr;
+
+    llama_token sampled;
+
     int32_t ga_i = 0;   // group-attention state
     int32_t ga_n = 1;   // group-attention factor
     int32_t ga_w = 512; // group-attention width
@@ -191,7 +209,11 @@ struct server_slot {
     double t_prompt_processing; // ms
     double t_token_generation;  // ms
 
+    std::function<void(int)> callback_on_release;
+
     void reset() {
+        SLT_DBG(*this, "%s", "\n");
+
         n_prompt_tokens = 0;
         generated_text  = "";
         truncated       = false;
@@ -202,7 +224,7 @@ struct server_slot {
         n_past             = 0;
         n_sent_text        = 0;
         n_sent_token_probs = 0;
-        infill             = false;
+        cmpl_type          = SERVER_TASK_CMPL_TYPE_NORMAL;
         ga_i               = 0;
         n_past_se          = 0;
 
@@ -225,25 +247,25 @@ struct server_slot {
         return n_remaining > 0; // no budget
     }
 
-    bool available() const {
-        return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
-    }
-
     bool is_processing() const {
-        return state == SLOT_STATE_PROCESSING;
+        return state != SLOT_STATE_IDLE;
     }
 
-    void add_token_string(const completion_token_output & token) {
-        if (command == SLOT_COMMAND_RELEASE) {
+    void add_token(const completion_token_output & token) {
+        if (!is_processing()) {
+            SLT_WRN(*this, "%s", "slot is not processing\n");
            return;
         }
         generated_token_probs.push_back(token);
     }
 
     void release() {
-        if (state == SLOT_STATE_PROCESSING) {
+        if (is_processing()) {
+            SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
+
             t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
-            command = SLOT_COMMAND_RELEASE;
+            state = SLOT_STATE_IDLE;
+            callback_on_release(id);
         }
     }
 
@@ -290,49 +312,20 @@ struct server_slot {
     }
 
     void print_timings() const {
-        char buffer[512];
-
-        double t_token = t_prompt_processing / n_prompt_tokens_processed;
-        double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
-
-        snprintf(buffer, 512, "prompt eval time     = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
-                t_prompt_processing, n_prompt_tokens_processed,
-                t_token, n_tokens_second);
-
-        LOG_INFO(buffer, {
-            {"id_slot",                   id},
-            {"id_task",                   id_task},
-            {"t_prompt_processing",       t_prompt_processing},
-            {"n_prompt_tokens_processed", n_prompt_tokens_processed},
-            {"t_token",                   t_token},
-            {"n_tokens_second",           n_tokens_second},
-        });
-
-        t_token = t_token_generation / n_decoded;
-        n_tokens_second = 1e3 / t_token_generation * n_decoded;
-
-        snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs   (%8.2f ms per token, %8.2f tokens per second)",
-                t_token_generation, n_decoded,
-                t_token, n_tokens_second);
-
-        LOG_INFO(buffer, {
-            {"id_slot",            id},
-            {"id_task",            id_task},
-            {"t_token_generation", t_token_generation},
-            {"n_decoded",          n_decoded},
-            {"t_token",            t_token},
-            {"n_tokens_second",    n_tokens_second},
-        });
-
-        snprintf(buffer, 512, "          total time = %10.2f ms", t_prompt_processing + t_token_generation);
-
-        LOG_INFO(buffer, {
-            {"id_slot",             id},
-            {"id_task",             id_task},
-            {"t_prompt_processing", t_prompt_processing},
-            {"t_token_generation",  t_token_generation},
-            {"t_total",             t_prompt_processing + t_token_generation},
-        });
+        const double t_prompt        =       t_prompt_processing / n_prompt_tokens_processed;
+        const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+
+        const double t_gen        =       t_token_generation / n_decoded;
+        const double n_gen_second = 1e3 / t_token_generation * n_decoded;
+
+        SLT_INF(*this,
+                "\n"
+                "\rprompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+                "\r eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+                "\r total time = %10.2f ms / %5d tokens\n",
+                t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
+                t_token_generation, n_decoded, t_gen, n_gen_second,
+                t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
     }
 };
 
@@ -350,6 +343,9 @@ struct server_metrics {
     uint64_t n_tokens_predicted  = 0;
     uint64_t t_tokens_generation = 0;
 
+    uint64_t n_decode_total     = 0;
+    uint64_t n_busy_slots_total = 0;
+
     void init() {
         t_start = ggml_time_us();
     }
@@ -368,6 +364,15 @@ struct server_metrics {
         t_tokens_generation_total += slot.t_token_generation;
     }
 
+    void on_decoded(const std::vector<server_slot> & slots) {
+        n_decode_total++;
+        for (const auto & slot : slots) {
+            if (slot.is_processing()) {
+                n_busy_slots_total++;
+            }
+        }
+    }
+
     void reset_bucket() {
         n_prompt_tokens_processed = 0;
         t_prompt_processing       = 0;
@@ -381,42 +386,62 @@ struct server_queue {
     bool running;
 
     // queues
-    std::vector<server_task> queue_tasks;
-    std::vector<server_task> queue_tasks_deferred;
-
-    std::vector<server_task_multi> queue_multitasks;
+    std::deque<server_task> queue_tasks;
+    std::deque<server_task> queue_tasks_deferred;
 
     std::mutex mutex_tasks;
     std::condition_variable condition_tasks;
 
     // callback functions
-    std::function<void(server_task &)> callback_new_task;
-    std::function<void(server_task_multi &)> callback_finish_multitask;
-    std::function<void(void)> callback_update_slots;
+    std::function<void(server_task&)> callback_new_task;
+    std::function<void(void)>         callback_update_slots;
 
     // Add a new task to the end of the queue
-    int post(server_task task) {
+    int post(server_task task, bool front = false) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         if (task.id == -1) {
             task.id = id++;
-            LOG_VERBOSE("new task id", {{"new_id", task.id}});
         }
-        queue_tasks.push_back(std::move(task));
+        QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
+        if (front) {
+            queue_tasks.push_front(std::move(task));
+        } else {
+            queue_tasks.push_back(std::move(task));
+        }
         condition_tasks.notify_one();
         return task.id;
     }
 
+    // multi-task version of post()
+    int post(std::vector<server_task> & tasks, bool front = false) {
+        std::unique_lock<std::mutex> lock(mutex_tasks);
+        for (auto & task : tasks) {
+            if (task.id == -1) {
+                task.id = id++;
+            }
+            QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
+            if (front) {
+                queue_tasks.push_front(std::move(task));
+            } else {
+                queue_tasks.push_back(std::move(task));
+            }
+        }
+        condition_tasks.notify_one();
+        return 0;
+    }
+
     // Add a new task, but defer until one slot is available
     void defer(server_task task) {
         std::unique_lock<std::mutex> lock(mutex_tasks);
+        QUE_DBG("defer task, id = %d\n", task.id);
         queue_tasks_deferred.push_back(std::move(task));
+        condition_tasks.notify_one();
     }
 
-    // Get the next id for creating anew task
+    // Get the next id for creating a new task
     int get_new_id() {
         std::unique_lock<std::mutex> lock(mutex_tasks);
         int new_id = id++;
-        LOG_VERBOSE("new task id", {{"new_id", new_id}});
         return new_id;
     }
 
@@ -425,24 +450,19 @@ struct server_queue {
         callback_new_task = std::move(callback);
     }
 
-    // Register function to process a multitask when it is finished
-    void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
-        callback_finish_multitask = std::move(callback);
-    }
-
     // Register the function to be called when all slots data is ready to be processed
     void on_update_slots(std::function<void(void)> callback) {
         callback_update_slots = std::move(callback);
     }
 
-    // Call when the state of one slot is changed
-    void notify_slot_changed() {
-        // move deferred tasks back to main loop
+    // Call when the state of one slot is changed, it will move one task from deferred to main queue
+    void pop_deferred_task() {
         std::unique_lock<std::mutex> lock(mutex_tasks);
-        for (auto & task : queue_tasks_deferred) {
-            queue_tasks.push_back(std::move(task));
+        if (!queue_tasks_deferred.empty()) {
+            queue_tasks.emplace_back(std::move(queue_tasks_deferred.front()));
+            queue_tasks_deferred.pop_front();
         }
-        queue_tasks_deferred.clear();
+        condition_tasks.notify_one();
    }
 
     // end the start_loop routine
@@ -463,7 +483,7 @@ struct server_queue {
         running = true;
 
         while (true) {
-            LOG_VERBOSE("new task may arrive", {});
+            QUE_DBG("%s", "processing new tasks\n");
 
             while (true) {
                 std::unique_lock<std::mutex> lock(mutex_tasks);
@@ -472,39 +492,24 @@ struct server_queue {
                     break;
                 }
                 server_task task = queue_tasks.front();
-                queue_tasks.erase(queue_tasks.begin());
+                queue_tasks.pop_front();
                 lock.unlock();
-                LOG_VERBOSE("callback_new_task", {{"id_task", task.id}});
-                callback_new_task(task);
-            }
-
-            LOG_VERBOSE("update_multitasks", {});
 
-            // check if we have any finished multitasks
-            auto queue_iterator = queue_multitasks.begin();
-            while (queue_iterator != queue_multitasks.end()) {
-                if (queue_iterator->subtasks_remaining.empty()) {
-                    // all subtasks done == multitask is done
-                    server_task_multi current_multitask = *queue_iterator;
-                    callback_finish_multitask(current_multitask);
-                    // remove this multitask
-                    queue_iterator = queue_multitasks.erase(queue_iterator);
-                } else {
-                    ++queue_iterator;
-                }
+                QUE_DBG("processing task, id = %d\n", task.id);
+                callback_new_task(task);
             }
 
             // all tasks in the current loop is processed, slots data is now ready
-            LOG_VERBOSE("callback_update_slots", {});
+            QUE_DBG("%s", "update slots\n");
 
             callback_update_slots();
 
-            LOG_VERBOSE("wait for new task", {});
+            QUE_DBG("%s", "waiting for new tasks\n");
             {
                 std::unique_lock<std::mutex> lock(mutex_tasks);
                 if (queue_tasks.empty()) {
                     if (!running) {
-                        LOG_VERBOSE("ending start_loop", {});
+                        QUE_DBG("%s", "terminate\n");
                         return;
                     }
                     condition_tasks.wait(lock, [&]{
@@ -514,38 +519,11 @@ struct server_queue {
                     }
                 }
             }
-
-    //
-    // functions to manage multitasks
-    //
-
-    // add a multitask by specifying the id of all subtask (subtask is a server_task)
-    void add_multitask(int id_multi, std::vector<int> & sub_ids) {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        server_task_multi multi;
-        multi.id = id_multi;
-        std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
-        queue_multitasks.push_back(multi);
-    }
-
-    // updatethe remaining subtasks, while appending results to multitask
-    void update_multitask(int id_multi, int id_sub, server_task_result & result) {
-        std::lock_guard<std::mutex> lock(mutex_tasks);
-        for (auto & multitask : queue_multitasks) {
-            if (multitask.id == id_multi) {
-                multitask.subtasks_remaining.erase(id_sub);
-                multitask.results.push_back(result);
-            }
-        }
-    }
 };
 
 struct server_response {
-    typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
-    callback_multitask_t callback_update_multitask;
-
     // for keeping track of all tasks waiting for the result
-    std::set<int> waiting_task_ids;
+    std::unordered_set<int> waiting_task_ids;
 
     // the main result queue
     std::vector<server_task_result> queue_results;
@@ -555,22 +533,40 @@ struct server_response {
 
     // add the id_task to the list of tasks waiting for response
     void add_waiting_task_id(int id_task) {
-        LOG_VERBOSE("waiting for task id", {{"id_task", id_task}});
+        SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
 
         std::unique_lock<std::mutex> lock(mutex_results);
         waiting_task_ids.insert(id_task);
     }
 
+    void add_waiting_tasks(const std::vector<server_task> & tasks) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+
+        for (const auto & task : tasks) {
+            SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
+            waiting_task_ids.insert(task.id);
+        }
+    }
+
     // when the request is finished, we can remove task associated with it
     void remove_waiting_task_id(int id_task) {
-        LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}});
+        SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
 
         std::unique_lock<std::mutex> lock(mutex_results);
         waiting_task_ids.erase(id_task);
     }
 
-    // This function blocks the thread until there is a response for this id_task
-    server_task_result recv(int id_task) {
+    void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
+        std::unique_lock<std::mutex> lock(mutex_results);
+
+        for (const auto & id_task : id_tasks) {
+            SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
+            waiting_task_ids.erase(id_task);
+        }
+    }
+
+    // This function blocks the thread until there is a response for one of the id_tasks
+    server_task_result recv(const std::unordered_set<int> & id_tasks) {
         while (true) {
             std::unique_lock<std::mutex> lock(mutex_results);
             condition_results.wait(lock, [&]{
@@ -578,8 +574,7 @@ struct server_response {
             });
 
             for (int i = 0; i < (int) queue_results.size(); i++) {
-                if (queue_results[i].id == id_task) {
-                    assert(queue_results[i].id_multi == -1);
+                if (id_tasks.find(queue_results[i].id) != id_tasks.end()) {
                     server_task_result res = queue_results[i];
                    queue_results.erase(queue_results.begin() + i);
                     return res;
@@ -590,28 +585,22 @@ struct server_response {
         // should never reach here
     }
 
-    // Register the function to update multitask
-    void on_multitask_update(callback_multitask_t callback) {
-        callback_update_multitask = std::move(callback);
+    // single-task version of recv()
+    server_task_result recv(int id_task) {
+        std::unordered_set<int> id_tasks = {id_task};
+        return recv(id_tasks);
     }
 
     // Send a new result to a waiting id_task
-    void send(server_task_result result) {
-        LOG_VERBOSE("send new result", {{"id_task", result.id}});
+    void send(server_task_result & result) {
+        SRV_DBG("sending result for task id = %d\n", result.id);
 
         std::unique_lock<std::mutex> lock(mutex_results);
         for (const auto & id_task : waiting_task_ids) {
-            // LOG_TEE("waiting task id %i \n", id_task);
-            // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
-            if (result.id_multi == id_task) {
-                LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}});
-                callback_update_multitask(id_task, result.id, result);
-                continue;
-            }
-
             if (result.id == id_task) {
-                LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}});
-                queue_results.push_back(result);
+                SRV_DBG("task id = %d moved to result queue\n", result.id);
+
+                queue_results.push_back(std::move(result));
                 condition_results.notify_all();
                 return;
             }
@@ -622,13 +611,15 @@ struct server_response {
 struct server_context {
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
+    std::vector<llama_lora_adapter_container> loras;
 
     gpt_params params;
 
-    llama_batch batch;
+    llama_batch batch = {};
 
     bool clean_kv_cache = true;
     bool add_bos_token  = true;
+    bool has_eos_token  = false;
 
     int32_t n_ctx; // total context for all clients / slots
 
@@ -663,8 +654,8 @@ struct server_context {
 
         // Clear any sampling context
         for (server_slot & slot : slots) {
-            if (slot.ctx_sampling != nullptr) {
-                llama_sampling_free(slot.ctx_sampling);
+            if (slot.smpl != nullptr) {
+                gpt_sampler_free(slot.smpl);
             }
         }
 
@@ -677,17 +668,23 @@ struct server_context {
         // dedicate one sequence to the system prompt
         params.n_parallel += 1;
 
-        std::tie(model, ctx) = llama_init_from_gpt_params(params);
+        llama_init_result llama_init = llama_init_from_gpt_params(params);
+
+        model = llama_init.model;
+        ctx   = llama_init.context;
+        loras = llama_init.lora_adapters;
+
         params.n_parallel -= 1; // but be sneaky about it
+
         if (model == nullptr) {
-            LOG_ERROR("unable to load model", {{"model", params.model}});
+            SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
            return false;
         }
 
         n_ctx = llama_n_ctx(ctx);
 
-        add_bos_token = llama_should_add_bos_token(model);
-        GGML_ASSERT(llama_add_eos_token(model) != 1);
+        add_bos_token = llama_add_bos_token(model);
+        has_eos_token = !llama_add_eos_token(model);
 
         return true;
     }
@@ -703,7 +700,7 @@ struct server_context {
     void init() {
         const int32_t n_ctx_slot = n_ctx / params.n_parallel;
 
-        LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
+        SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);
 
         for (int i = 0; i < params.n_parallel; i++) {
             server_slot slot;
@@ -712,10 +709,7 @@ struct server_context {
            slot.n_ctx = n_ctx_slot;
            slot.n_predict = params.n_predict;
 
-            LOG_INFO("new slot", {
-                {"id_slot",    slot.id},
-                {"n_ctx_slot", slot.n_ctx}
-            });
+            SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
 
            const int ga_n = params.grp_attn_n;
            const int ga_w = params.grp_attn_w;
@@ -726,11 +720,7 @@ struct server_context {
                //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
                //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
 
-                LOG_INFO("slot self-extend", {
-                    {"id_slot", slot.id},
-                    {"ga_n",    ga_n},
-                    {"ga_w",    ga_w}
-                });
+                SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
            }
 
            slot.ga_i = 0;
@@ -739,6 +729,10 @@ struct server_context {
 
            slot.sparams = params.sparams;
 
+            slot.callback_on_release = [this](int) {
+                queue_tasks.pop_deferred_task();
+            };
+
            slot.reset();
 
            slots.push_back(slot);
@@ -747,13 +741,13 @@ struct server_context {
        default_generation_settings_for_props = get_formated_generation(slots.front());
        default_generation_settings_for_props["seed"] = -1;
 
-        // the update_slots() logic will always submit a maximum of n_batch tokens
+        // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
        // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
        {
            const int32_t n_batch = llama_n_batch(ctx);
 
            // only a single seq_id per token is needed
-            batch = llama_batch_init(n_batch, 0, 1);
+            batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
        }
 
        metrics.init();
@@ -820,7 +814,7 @@ struct server_context {
 
            for (server_slot & slot : slots) {
                // skip the slot if it is not available
-                if (!slot.available()) {
+                if (slot.is_processing()) {
                    continue;
                }
 
@@ -849,11 +843,7 @@ struct server_context {
            }
 
            if (ret != nullptr) {
-                LOG_VERBOSE("selected slot by lcp similarity", {
-                    {"id_slot",     ret->id},
-                    {"max_lcp_len", max_lcp_len},
-                    {"similarity",  similarity},
-                });
+                SLT_DBG(*ret, "selected slot by lcp similarity, max_lcp_len = %d, similarity = %f\n", max_lcp_len, similarity);
            }
        }
 
@@ -862,7 +852,7 @@ struct server_context {
            int64_t t_last = ggml_time_us();
            for (server_slot & slot : slots) {
                // skip the slot if it is not available
-                if (!slot.available()) {
+                if (slot.is_processing()) {
                    continue;
                }
 
@@ -874,10 +864,7 @@ struct server_context {
            }
 
            if (ret != nullptr) {
-                LOG_VERBOSE("selected slot by lru", {
-                    {"id_slot", ret->id},
-                    {"t_last",  t_last},
-                });
+                SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last);
            }
        }
 
@@ -887,8 +874,8 @@ struct server_context {
    bool launch_slot_with_task(server_slot & slot, const server_task & task) {
        slot_params default_params;
        // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
-        llama_sampling_params default_sparams = params.sparams;
-        auto & data = task.data;
+        auto default_sparams = params.sparams;
+        const auto & data = task.data;
 
        if (data.count("__oaicompat") != 0) {
            slot.oaicompat = true;
@@ -900,12 +887,12 @@ struct server_context {
 
        slot.params.stream = json_value(data, "stream", false);
        slot.params.cache_prompt = json_value(data, "cache_prompt", false);
-        slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict);
+        slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
        slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
        slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
        slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
        slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
-        slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
+        slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
        slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
        slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
        slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
@@ -927,7 +914,8 @@ struct server_context {
        if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
            send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
            return false;
-        } else if (data.contains("json_schema") && !data.contains("grammar")) {
+        }
+        if (data.contains("json_schema") && !data.contains("grammar")) {
            try {
                auto schema = json_value(data, "json_schema", json::object());
                slot.sparams.grammar = json_schema_to_grammar(schema);
@@ -940,17 +928,14 @@ struct server_context {
        }
 
        if (slot.params.cache_prompt && slot.ga_n != 1) {
-            LOG_WARNING("cache_prompt is not supported with group-attention", {});
            slot.params.cache_prompt = false;
+            SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
        }
 
        if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
            // Might be better to reject the request with a 400 ?
-            LOG_WARNING("Max tokens to predict exceeds server configuration", {
-                {"params.n_predict", slot.params.n_predict},
-                {"slot.n_predict",   slot.n_predict},
-            });
            slot.params.n_predict = slot.n_predict;
+            SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
        }
 
        // infill
@@ -958,7 +943,7 @@ struct server_context {
        slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
 
        // get prompt
-        if (!task.infill) {
+        if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
            const auto & prompt = data.find("prompt");
            if (prompt == data.end()) {
                send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
@@ -969,62 +954,28 @@ struct server_context {
                (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
                (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
                slot.prompt = *prompt;
-            } else {
-                send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
-                return false;
-            }
-        }
-
-        // penalize user-provided tokens
-        {
-                slot.sparams.penalty_prompt_tokens.clear();
-                slot.sparams.use_penalty_prompt_tokens = false;
-
-                const auto & penalty_prompt = data.find("penalty_prompt");
-
-                if (penalty_prompt != data.end()) {
-                    if (penalty_prompt->is_string()) {
-                        const auto penalty_prompt_string = penalty_prompt->get<std::string>();
-                        slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
-
-                        if (slot.params.n_predict > 0) {
-                            slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
-                        }
-                        slot.sparams.use_penalty_prompt_tokens = true;
-
-                        LOG_VERBOSE("penalty_prompt_tokens", {
-                            {"id_slot", slot.id},
-                            {"tokens",  slot.sparams.penalty_prompt_tokens},
-                        });
-                    }
-                    else if (penalty_prompt->is_array()) {
-                        const auto n_tokens = penalty_prompt->size();
-                        slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
-
-                        const int n_vocab = llama_n_vocab(model);
-                        for (const auto & penalty_token : *penalty_prompt) {
-                            if (penalty_token.is_number_integer()) {
-                                const auto tok = penalty_token.get<llama_token>();
-                                if (tok >= 0 && tok < n_vocab) {
-                                    slot.sparams.penalty_prompt_tokens.push_back(tok);
-                                }
-                            }
+            } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
+                slot.prompt = prompt->at(0);
+            } else if (prompt->is_array() && prompt->size() > 1) {
+                // array of strings
+                for (const auto & el : *prompt) {
+                    if (!el.is_string()) {
+                        send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                        return false;
                    }
-                        slot.sparams.use_penalty_prompt_tokens = true;
-
-                        LOG_VERBOSE("penalty_prompt_tokens", {
-                            {"id_slot", slot.id},
-                            {"tokens",  slot.sparams.penalty_prompt_tokens},
-                        });
                }
+                slot.prompt = *prompt;
+            } else {
+                send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
+                return false;
            }
        }
 
        {
            slot.sparams.logit_bias.clear();
 
-            if (json_value(data, "ignore_eos", false)) {
-                slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
+            if (json_value(data, "ignore_eos", false) && has_eos_token) {
+                slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
            }
 
            const auto & logit_bias = data.find("logit_bias");
@@ -1045,12 +996,12 @@ struct server_context {
                    if (el[0].is_number_integer()) {
                        llama_token tok = el[0].get<llama_token>();
                        if (tok >= 0 && tok < n_vocab) {
-                            slot.sparams.logit_bias[tok] = bias;
+                            slot.sparams.logit_bias.push_back({tok, bias});
                        }
                    } else if (el[0].is_string()) {
                        auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
                        for (auto tok : toks) {
-                            slot.sparams.logit_bias[tok] = bias;
+                            slot.sparams.logit_bias.push_back({tok, bias});
                        }
                    }
                }
@@ -1072,45 +1023,43 @@ struct server_context {
        }
 
        {
-            const auto & samplers_sequence = data.find("samplers");
-            if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
+            const auto & samplers = data.find("samplers");
+            if (samplers != data.end() && samplers->is_array()) {
                std::vector<std::string> sampler_names;
-                for (const auto & sampler_name : *samplers_sequence) {
-                    if (sampler_name.is_string()) {
-                        sampler_names.emplace_back(sampler_name);
+                for (const auto & name : *samplers) {
+                    if (name.is_string()) {
+                        sampler_names.emplace_back(name);
                    }
                }
-                slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
+                slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
            } else {
-                slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
+                slot.sparams.samplers = default_sparams.samplers;
            }
        }
 
        {
-            if (slot.ctx_sampling != nullptr) {
-                llama_sampling_free(slot.ctx_sampling);
+            if (slot.smpl != nullptr) {
+                gpt_sampler_free(slot.smpl);
            }
-            slot.ctx_sampling = llama_sampling_init(slot.sparams);
-            if (slot.ctx_sampling == nullptr) {
+
+            slot.smpl = gpt_sampler_init(model, slot.sparams);
+            if (slot.smpl == nullptr) {
                // for now, the only error that may happen here is invalid grammar
                send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
                return false;
            }
        }
 
-        slot.command = SLOT_COMMAND_LOAD_PROMPT;
+        slot.state = SLOT_STATE_PROCESSING_PROMPT;
        slot.prompt_tokens.clear();
 
-        LOG_INFO("slot is processing task", {
-            {"id_slot", slot.id},
-            {"id_task", slot.id_task},
-        });
+        SLT_INF(slot, "%s", "processing task\n");
 
        return true;
    }
 
    void kv_cache_clear() {
-        LOG_VERBOSE("clearing KV cache", {});
+        SRV_DBG("%s", "clearing KV cache\n");
 
        // clear the entire KV cache
        llama_kv_cache_clear(ctx);
@@ -1118,9 +1067,7 @@ struct server_context {
    }
 
    void system_prompt_update() {
-        LOG_VERBOSE("system prompt update", {
-            {"system_prompt", system_prompt},
-        });
+        SRV_DBG("updating system prompt: '%s'\n", system_prompt.c_str());
 
        kv_cache_clear();
        system_tokens.clear();
@@ -1128,29 +1075,20 @@ struct server_context {
        if (!system_prompt.empty()) {
            system_tokens = ::llama_tokenize(ctx, system_prompt, true);
 
-            llama_batch_clear(batch);
+            const int32_t n_batch = llama_n_batch(ctx);
+            const int32_t n_tokens_prompt = system_tokens.size();
 
-            for (int32_t i = 0; i < (int32_t) system_tokens.size(); ++i) {
-                llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
-            }
+            for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
+                const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
 
-            const int32_t n_batch = llama_n_batch(ctx);
+                llama_batch_clear(batch);
 
-            for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
-                const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
-                llama_batch batch_view = {
-                    n_tokens,
-                    batch.token    + i,
-                    nullptr,
-                    batch.pos      + i,
-                    batch.n_seq_id + i,
-                    batch.seq_id   + i,
-                    batch.logits   + i,
-                    0, 0, 0, // unused
-                };
+                for (int32_t j = 0; j < n_tokens; ++j) {
+                    llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
+                }
 
-                if (llama_decode(ctx, batch_view) != 0) {
-                    LOG_TEE("%s: llama_decode() failed\n", __func__);
+                if (llama_decode(ctx, batch) != 0) {
+                    SRV_ERR("%s", "llama_decode() failed\n");
                    return;
                }
            }
@@ -1165,11 +1103,9 @@ struct server_context {
    }
 
    bool system_prompt_set(const std::string & sys_prompt) {
-        system_prompt = sys_prompt;
+        SRV_DBG("system prompt set: '%s'\n", system_prompt.c_str());
 
-        LOG_VERBOSE("system prompt process", {
-            {"system_prompt", system_prompt},
-        });
+        system_prompt = sys_prompt;
 
        // release all slots
        for (server_slot & slot : slots) {
@@ -1189,11 +1125,6 @@ struct server_context {
        slot.generated_text += token_str;
        slot.has_next_token = true;
 
-        if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
-            // we can change penalty_prompt_tokens because it is always created from scratch each request
-            slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
-        }
-
        // check if there is incomplete UTF-8 character at the end
        bool incomplete = false;
        for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
@@ -1242,7 +1173,7 @@ struct server_context {
            // add the token to slot queue and cache
        }
 
-        slot.add_token_string(result);
+        slot.add_token(result);
        if (slot.params.stream) {
            send_partial_response(slot, result);
        }
@@ -1257,74 +1188,56 @@ struct server_context {
            slot.stopped_limit = true;
            slot.has_next_token = false;
 
-            LOG_VERBOSE("stopped by limit", {
-                {"id_slot",   slot.id},
-                {"id_task",   slot.id_task},
-                {"n_decoded", slot.n_decoded},
-                {"n_predict", slot.params.n_predict},
-            });
+            SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
+        }
+
+        // if context shift is disabled, we stop when it reaches the context limit
+        if (slot.n_decoded >= slot.n_ctx) {
+            slot.truncated      = true;
+            slot.stopped_limit  = true;
+            slot.has_next_token = false;
+
+            SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
        }
 
        if (llama_token_is_eog(model, result.tok)) {
            slot.stopped_eos = true;
            slot.has_next_token = false;
 
-            LOG_VERBOSE("eos token found", {});
-        }
-
-        auto n_ctx_train = llama_n_ctx_train(model);
-
-        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
-            && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
-            LOG_WARNING("n_predict is not set and self-context extend is disabled."
-                        " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
-                { "id_slot",              slot.id },
-                { "params.n_predict",     slot.params.n_predict },
-                { "slot.n_prompt_tokens", slot.n_prompt_tokens },
-                { "slot.n_decoded",       slot.n_decoded },
-                { "slot.n_predict",       slot.n_predict },
-                { "n_slots",              params.n_parallel },
-                { "slot.n_ctx",           slot.n_ctx },
-                { "n_ctx",                n_ctx },
-                { "n_ctx_train",          n_ctx_train },
-                { "ga_n",                 slot.ga_n },
-            });
+            SLT_DBG(slot, "%s", "stopped by EOS\n");
+        }
+
+        const auto n_ctx_train = llama_n_ctx_train(model);
+
+        if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
            slot.truncated      = true;
            slot.stopped_limit  = true;
            slot.has_next_token = false; // stop prediction
+
+            SLT_WRN(slot,
+                    "n_predict (%d) is not set and self-context extend is disabled. "
+                    "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
+                    slot.params.n_predict, n_ctx_train);
        }
 
-        LOG_VERBOSE("next token", {
-            {"id_slot",        slot.id},
-            {"id_task",        slot.id_task},
-            {"token",          result.tok},
-            {"token_text",     tokens_to_output_formatted_string(ctx, result.tok)},
-            {"has_next_token", slot.has_next_token},
-            {"n_remain",       slot.n_remaining},
-            {"n_decoded",      slot.n_decoded},
-            {"stopped_eos",    slot.stopped_eos},
-            {"stopped_word",   slot.stopped_word},
-            {"stopped_limit",  slot.stopped_limit},
-            {"stopping_word",  slot.stopping_word},
-        });
+        SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
 
        return slot.has_next_token; // continue
    }
 
    json get_formated_generation(const server_slot & slot) const {
-        const auto eos_bias   = slot.sparams.logit_bias.find(llama_token_eos(model));
-        const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
-
-        std::vector<std::string> samplers_sequence;
-        samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
-        for (const auto & sampler_type : slot.sparams.samplers_sequence) {
-            samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
+        std::vector<std::string> samplers;
+        samplers.reserve(slot.sparams.samplers.size());
+        for (const auto & sampler : slot.sparams.samplers) {
+            samplers.emplace_back(gpt_sampler_type_to_str(sampler));
        }
 
        return json {
            {"n_ctx",             slot.n_ctx},
-            {"n_predict",         slot.n_predict},
+            {"n_predict",         slot.n_predict}, // Server configured n_predict
            {"model",             params.model_alias},
            {"seed",              slot.sparams.seed},
+            {"seed_cur",          slot.smpl ? gpt_sampler_get_seed(slot.smpl) : 0},
            {"temperature",       slot.sparams.temp},
            {"dynatemp_range",    slot.sparams.dynatemp_range},
            {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
@@ -1332,49 +1245,42 @@ struct server_context {
            {"top_p",             slot.sparams.top_p},
            {"min_p",             slot.sparams.min_p},
            {"tfs_z",             slot.sparams.tfs_z},
-            {"typical_p",         slot.sparams.typical_p},
+            {"typical_p",         slot.sparams.typ_p},
            {"repeat_last_n",     slot.sparams.penalty_last_n},
            {"repeat_penalty",    slot.sparams.penalty_repeat},
            {"presence_penalty",  slot.sparams.penalty_present},
            {"frequency_penalty", slot.sparams.penalty_freq},
-            {"penalty_prompt_tokens",     slot.sparams.penalty_prompt_tokens},
-            {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
            {"mirostat",          slot.sparams.mirostat},
            {"mirostat_tau",      slot.sparams.mirostat_tau},
            {"mirostat_eta",      slot.sparams.mirostat_eta},
            {"penalize_nl",       slot.sparams.penalize_nl},
            {"stop",              slot.params.antiprompt},
-            {"n_predict",         slot.params.n_predict},
+            {"max_tokens",        slot.params.n_predict}, // User configured n_predict
            {"n_keep",            slot.params.n_keep},
            {"n_discard",         slot.params.n_discard},
-            {"ignore_eos",        ignore_eos},
+            {"ignore_eos",        slot.sparams.ignore_eos},
|
|
1351
1262
|
{"stream", slot.params.stream},
|
|
1352
|
-
|
|
1263
|
+
//{"logit_bias", slot.sparams.logit_bias},
|
|
1353
1264
|
{"n_probs", slot.sparams.n_probs},
|
|
1354
1265
|
{"min_keep", slot.sparams.min_keep},
|
|
1355
1266
|
{"grammar", slot.sparams.grammar},
|
|
1356
|
-
{"samplers",
|
|
1267
|
+
{"samplers", samplers},
|
|
1357
1268
|
};
|
|
1358
1269
|
}
|
|
1359
1270
|
|
|
1360
1271
|
void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
1361
|
-
send_error(task.id,
|
|
1272
|
+
send_error(task.id, error, type);
|
|
1362
1273
|
}
|
|
1363
1274
|
|
|
1364
1275
|
void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
1365
|
-
send_error(slot.id_task,
|
|
1276
|
+
send_error(slot.id_task, error, type);
|
|
1366
1277
|
}
|
|
1367
1278
|
|
|
1368
|
-
void send_error(const int id_task, const
|
|
1369
|
-
|
|
1370
|
-
{"id_multi", id_multi},
|
|
1371
|
-
{"id_task", id_task},
|
|
1372
|
-
{"error", error},
|
|
1373
|
-
});
|
|
1279
|
+
void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
|
|
1280
|
+
SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
|
|
1374
1281
|
|
|
1375
1282
|
server_task_result res;
|
|
1376
1283
|
res.id = id_task;
|
|
1377
|
-
res.id_multi = id_multi;
|
|
1378
1284
|
res.stop = false;
|
|
1379
1285
|
res.error = true;
|
|
1380
1286
|
res.data = format_error_response(error, type);
|
|
@@ -1385,14 +1291,14 @@ struct server_context {
    void send_partial_response(server_slot & slot, completion_token_output tkn) {
        server_task_result res;
        res.id       = slot.id_task;
-        res.id_multi = slot.id_multi;
        res.error    = false;
        res.stop     = false;
        res.data     = json {
            {"content",    tkn.text_to_send},
            {"stop",       false},
            {"id_slot",    slot.id},
-            {"multimodal", false}
+            {"multimodal", false},
+            {"index",      slot.index},
        };
 
        if (slot.sparams.n_probs > 0) {
@@ -1422,7 +1328,6 @@ struct server_context {
    void send_final_response(const server_slot & slot) {
        server_task_result res;
        res.id       = slot.id_task;
-        res.id_multi = slot.id_multi;
        res.error    = false;
        res.stop     = true;
        res.data     = json {
@@ -1440,7 +1345,8 @@ struct server_context {
            {"stopped_limit",       slot.stopped_limit},
            {"stopping_word",       slot.stopping_word},
            {"tokens_cached",       slot.n_past},
-            {"timings",             slot.get_formated_timings()}
+            {"timings",             slot.get_formated_timings()},
+            {"index",               slot.index},
        };
 
        if (slot.sparams.n_probs > 0) {
@@ -1472,7 +1378,6 @@ struct server_context {
    void send_embedding(const server_slot & slot, const llama_batch & batch) {
        server_task_result res;
        res.id       = slot.id_task;
-        res.id_multi = slot.id_multi;
        res.error    = false;
        res.stop     = true;
 
@@ -1491,13 +1396,11 @@ struct server_context {
            }
 
            if (embd == NULL) {
-                LOG_ERROR("failed to get embeddings", {
-                    {"token",  batch.token [i]},
-                    {"seq_id", batch.seq_id[i][0]}
-                });
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
 
                res.data = json {
                    {"embedding", std::vector<float>(n_embd, 0.0f)},
+                    {"index",     slot.index},
                };
 
                continue;
@@ -1507,83 +1410,191 @@ struct server_context {
 
            res.data = json {
                {"embedding", embd_res},
+                {"index",     slot.index},
            };
        }
 
+        SLT_DBG(slot, "%s", "sending embeddings\n");
+
        queue_results.send(res);
    }
 
-    void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) {
-        server_task task;
-        task.id        = id_task;
-        task.id_multi  = id_multi;
-        task.id_target = 0;
-        task.data      = std::move(data);
-        task.infill    = infill;
-        task.embedding = embedding;
-        task.type      = SERVER_TASK_TYPE_COMPLETION;
-
-        // when a completion task's prompt array is not a singleton, we split it into multiple requests
-        // otherwise, it's a single-prompt task, we actually queue it
-        // if there's numbers in the prompt array it will be treated as an array of tokens
-        if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
-            bool numbers = false;
-            for (const auto & e : task.data.at("prompt")) {
-                if (e.is_number()) {
-                    numbers = true;
-                    break;
-                }
+    void send_rerank(const server_slot & slot, const llama_batch & batch) {
+        server_task_result res;
+        res.id    = slot.id_task;
+        res.error = false;
+        res.stop  = true;
+
+        for (int i = 0; i < batch.n_tokens; ++i) {
+            if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+                continue;
            }
        }
 
-            // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
-            // it will completely stall the server. I don't know where the bug for this is.
-            //
-            // if there are numbers, it needs to be treated like a single prompt,
-            // queue_tasks handles a mix of strings and numbers just fine.
-            if (numbers) {
-                queue_tasks.post(task);
-            } else {
-                split_multiprompt_task(id_task, task);
+            const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
+            if (embd == NULL) {
+                embd = llama_get_embeddings_ith(ctx, i);
            }
-        } else {
-            queue_tasks.post(task);
+
+            if (embd == NULL) {
+                SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
+
+                res.data = json {
+                    {"index", slot.index},
+                    {"score", -1e6},
+                };
+
+                continue;
+            }
+
+            res.data = json {
+                {"index", slot.index},
+                {"score", embd[0]},
+            };
        }
-    }
 
-    void request_cancel(int id_task) {
-        server_task task;
-        task.type      = SERVER_TASK_TYPE_CANCEL;
-        task.id_target = id_task;
+        SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
 
-        queue_tasks.post(task);
+        queue_results.send(res);
    }
 
-    void split_multiprompt_task(int id_multi, const server_task & multiprompt_task) {
-        const int prompt_count = multiprompt_task.data.at("prompt").size();
-        if (prompt_count <= 1) {
-            send_error(multiprompt_task, "error while handling multiple prompts");
-            return;
+    //
+    // Functions to create new task(s) and receive result(s)
+    //
+
+    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+        std::vector<server_task> tasks;
+        auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
+            server_task task;
+            task.id        = queue_tasks.get_new_id();
+            task.cmpl_type = cmpl_type;
+            task.type      = SERVER_TASK_TYPE_COMPLETION;
+            if (replace_prompt) {
+                task.data = task_data;
+                task.data["prompt"] = std::move(prompt);
+            } else {
+                task.data = std::move(task_data);
+            }
+            tasks.push_back(std::move(task));
+        };
+
+        static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
+        if (!data.contains("prompt")) {
+            throw std::runtime_error(error_msg);
+        }
+
+        json prompt = data.at("prompt");
+
+        // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
+        if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
+            data["index"] = 0;
+            create_task(data, false, nullptr);
+        }
+        // otherwise, it's a multiple-prompt task, we break it into smaller tasks
+        else if (prompt.is_array()) {
+            std::vector<json> prompts = prompt;
+            if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                // prompts[0] is the question
+                // the rest are the answers/documents
+                SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
+                for (size_t i = 1; i < prompts.size(); i++) {
+                    json qd;
+                    qd.push_back(prompts[0]);
+                    qd.push_back(prompts[i]);
+                    data["index"] = i - 1;
+                    create_task(data, true, qd);
+                }
+            } else {
+                SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
+                for (size_t i = 0; i < prompts.size(); i++) {
+                    const auto & e = prompts[i];
+                    if (e.is_string() || json_is_array_of_numbers(e)) {
+                        data["index"] = i;
+                        create_task(data, true, e);
+                    } else {
+                        throw std::runtime_error(error_msg);
+                    }
+                }
+            }
        }
+        // invalid case
+        else {
+            throw std::runtime_error(error_msg);
+        }
+
+        return tasks;
+    }
+
+    void cancel_tasks(const std::unordered_set<int> & id_tasks) {
+        std::vector<server_task> cancel_tasks;
+        cancel_tasks.reserve(id_tasks.size());
+        for (const auto & id_task : id_tasks) {
+            SRV_WRN("cancel task, id_task = %d\n", id_task);
+
+            server_task task;
+            task.type      = SERVER_TASK_TYPE_CANCEL;
+            task.id_target = id_task;
+            cancel_tasks.push_back(task);
+            queue_results.remove_waiting_task_id(id_task);
+        }
+        // push to beginning of the queue, so it has highest priority
+        queue_tasks.post(cancel_tasks, true);
+    }
+
+    // receive the results from task(s) created by create_tasks_cmpl
+    void receive_cmpl_results(
+            const std::unordered_set<int> & id_tasks,
+            const std::function<void(std::vector<server_task_result>&)> & result_handler,
+            const std::function<void(json)> & error_handler) {
+        // TODO: currently, there is no way to detect the client has cancelled the request
+        std::vector<server_task_result> results(id_tasks.size());
+        for (size_t i = 0; i < id_tasks.size(); i++) {
+            server_task_result result = queue_results.recv(id_tasks);
+
+            if (result.error) {
+                error_handler(result.data);
+                cancel_tasks(id_tasks);
+                return;
+            }
 
-        // generate all the ID for subtask
-        std::vector<int> subtask_ids(prompt_count);
-        for (int i = 0; i < prompt_count; i++) {
-            subtask_ids[i] = queue_tasks.get_new_id();
+            const size_t idx = result.data["index"];
+            GGML_ASSERT(idx < results.size() && "index out of range");
+
+            results[idx] = result;
        }
+        result_handler(results);
+    }
 
-        // queue up the multitask so we can track its subtask progression
-        queue_tasks.add_multitask(id_multi, subtask_ids);
+    // receive the results from task(s) created by create_tasks_cmpl, in stream mode
+    void receive_cmpl_results_stream(
+            const std::unordered_set<int> & id_tasks, const
+            std::function<bool(server_task_result&)> & result_handler, const
+            std::function<void(json)> & error_handler) {
+        size_t n_finished = 0;
+        while (true) {
+            server_task_result result = queue_results.recv(id_tasks);
+            if (!result_handler(result)) {
+                cancel_tasks(id_tasks);
+                break;
+            }
 
-        // add subtasks
-        for (int i = 0; i < prompt_count; i++) {
-            json subtask_data = multiprompt_task.data;
-            subtask_data["prompt"] = subtask_data.at("prompt")[i];
+            if (result.error) {
+                error_handler(result.data);
+                cancel_tasks(id_tasks);
+                break;
+            }
 
-            // subtasks inherit everything else (infill mode, embedding mode, etc.)
-            request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding);
+            if (result.stop) {
+                if (++n_finished == id_tasks.size()) {
+                    break;
+                }
+            }
        }
    }
 
+    //
+    // Functions to process the task
+    //
+
    void process_single_task(const server_task & task) {
        switch (task.type) {
            case SERVER_TASK_TYPE_COMPLETION:
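The hunk above replaces the old multitask plumbing (request_completion / split_multiprompt_task) with create_tasks_cmpl, cancel_tasks, and receive_cmpl_results. A minimal usage sketch, not part of the diff: SERVER_TASK_CMPL_TYPE_NORMAL, the waiting-id registration, and both handler bodies are assumptions modeled on the route handlers later in this file.

    // sketch only: driving the new task helpers from a route handler
    std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL); // enum value assumed

    std::unordered_set<int> id_tasks;
    for (const auto & task : tasks) {
        id_tasks.insert(task.id);
        ctx_server.queue_results.add_waiting_task_id(task.id); // assumed, mirrors the metrics handler below
    }
    ctx_server.queue_tasks.post(tasks, /*front=*/false); // vector overload as used by cancel_tasks above

    ctx_server.receive_cmpl_results(id_tasks,
        [&](std::vector<server_task_result> & results) {
            // results[i] is the completion for prompt index i: receive_cmpl_results
            // slots each incoming result into results[result.data["index"]]
            for (auto & r : results) { /* serialize r.data into the HTTP response */ }
        },
        [&](json error_data) { /* res_error(res, error_data); */ });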
@@ -1605,13 +1616,13 @@ struct server_context {
 
                    if (slot == nullptr) {
                        // if no slot is available, we defer this task for processing later
-                        LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
+                        SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
                        queue_tasks.defer(task);
                        break;
                    }
-                    if (!slot->available()) {
+                    if (slot->is_processing()) {
                        // if requested slot is unavailable, we defer this task for processing later
-                        LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
                        queue_tasks.defer(task);
                        break;
                    }
@@ -1629,12 +1640,11 @@ struct server_context {
                    slot->reset();
 
                    slot->id_task   = task.id;
-                    slot->id_multi  = task.id_multi;
-                    slot->infill    = task.infill;
-                    slot->embedding = task.embedding;
+                    slot->cmpl_type = task.cmpl_type;
+                    slot->index     = json_value(task.data, "index", 0);
 
                    if (!launch_slot_with_task(*slot, task)) {
-                        LOG_ERROR("error while launching slot with task", {{"id_task", task.id}});
+                        SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
                        break;
                    }
                } break;
@@ -1683,22 +1693,10 @@ struct server_context {
 
                        slots_data.push_back(slot_data);
                    }
-                    LOG_INFO("slot data", {
-                        {"id_task",            task.id},
-                        {"n_idle_slots",       n_idle_slots},
-                        {"n_processing_slots", n_processing_slots}
-                    });
-
-                    LOG_VERBOSE("slot data", {
-                        {"id_task",            task.id},
-                        {"n_idle_slots",       n_idle_slots},
-                        {"n_processing_slots", n_processing_slots},
-                        {"slots",              slots_data}
-                    });
+                    SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
 
                    server_task_result res;
                    res.id       = task.id;
-                    res.id_multi = task.id_multi;
                    res.stop     = true;
                    res.error    = false;
                    res.data     = {
@@ -1717,6 +1715,9 @@ struct server_context {
                        { "n_tokens_predicted",  metrics.n_tokens_predicted},
                        { "t_tokens_generation", metrics.t_tokens_generation},
 
+                        { "n_decode_total",     metrics.n_decode_total},
+                        { "n_busy_slots_total", metrics.n_busy_slots_total},
+
                        { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
                        { "kv_cache_used_cells",   llama_get_kv_cache_used_cells(ctx)},
 
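A note on the two counters surfaced above. The server_metrics struct itself is outside this diff, so the semantics below are inferred from the metrics.on_decoded(slots) call added in a later hunk:

    // inferred semantics, not confirmed by this diff:
    // n_decode_total     - incremented once per llama_decode() call
    // n_busy_slots_total - incremented by the number of busy slots at each decode,
    //                      so n_busy_slots_total / n_decode_total approximates the
    //                      average number of slots served per batch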
@@ -1736,9 +1737,9 @@ struct server_context {
                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                        break;
                    }
-                    if (!slot->available()) {
+                    if (slot->is_processing()) {
                        // if requested slot is unavailable, we defer this task for processing later
-                        LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
                        queue_tasks.defer(task);
                        break;
                    }
@@ -1777,9 +1778,9 @@ struct server_context {
                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                        break;
                    }
-                    if (!slot->available()) {
+                    if (slot->is_processing()) {
                        // if requested slot is unavailable, we defer this task for processing later
-                        LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
                        queue_tasks.defer(task);
                        break;
                    }
@@ -1825,9 +1826,9 @@ struct server_context {
                        send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
                        break;
                    }
-                    if (!slot->available()) {
+                    if (slot->is_processing()) {
                        // if requested slot is unavailable, we defer this task for processing later
-                        LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+                        SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
                        queue_tasks.defer(task);
                        break;
                    }
@@ -1847,68 +1848,37 @@ struct server_context {
                    };
                    queue_results.send(result);
                } break;
+            case SERVER_TASK_TYPE_SET_LORA:
+                {
+                    llama_lora_adapters_apply(ctx, loras);
+                    server_task_result result;
+                    result.id    = task.id;
+                    result.stop  = true;
+                    result.error = false;
+                    result.data  = json{{ "success", true }};
+                    queue_results.send(result);
+                } break;
        }
    }
 
-    void on_finish_multitask(const server_task_multi & multitask) {
-        // all subtasks done == multitask is done
-        server_task_result result;
-        result.id    = multitask.id;
-        result.stop  = true;
-        result.error = false;
-
-        // collect json results into one json result
-        std::vector<json> result_jsons;
-        for (const auto & subres : multitask.results) {
-            result_jsons.push_back(subres.data);
-            result.error = result.error && subres.error;
-        }
-        result.data = json {
-            { "results", result_jsons }
-        };
-
-        queue_results.send(result);
-    }
-
    void update_slots() {
        if (system_need_update) {
            system_prompt_update();
        }
 
-        // release slots
-        for (auto & slot : slots) {
-            if (slot.command == SLOT_COMMAND_RELEASE) {
-                slot.state = SLOT_STATE_IDLE;
-                slot.command = SLOT_COMMAND_NONE;
-                slot.t_last_used = ggml_time_us();
-
-                LOG_INFO("slot released", {
-                    {"id_slot",         slot.id},
-                    {"id_task",         slot.id_task},
-                    {"n_ctx",           n_ctx},
-                    {"n_past",          slot.n_past},
-                    {"n_system_tokens", system_tokens.size()},
-                    {"n_cache_tokens",  slot.cache_tokens.size()},
-                    {"truncated",       slot.truncated}
-                });
-
-                queue_tasks.notify_slot_changed();
-            }
-        }
-
        // check if all slots are idle
        {
            bool all_idle = true;
 
            for (auto & slot : slots) {
-                if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
+                if (slot.is_processing()) {
                    all_idle = false;
                    break;
                }
            }
 
            if (all_idle) {
-                LOG_INFO("all slots are idle", {});
+                SRV_INF("%s", "all slots are idle\n");
                if (system_prompt.empty() && clean_kv_cache) {
                    kv_cache_clear();
                }
@@ -1918,7 +1888,7 @@ struct server_context {
        }
 
        {
-            LOG_VERBOSE("posting NEXT_RESPONSE", {});
+            SRV_DBG("%s", "posting NEXT_RESPONSE\n");
 
            server_task task;
            task.type      = SERVER_TASK_TYPE_NEXT_RESPONSE;
@@ -1932,22 +1902,20 @@ struct server_context {
        for (server_slot & slot : slots) {
            if (slot.ga_n == 1) {
                if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
+                    if (!params.ctx_shift) {
+                        // this check is redundant (for good)
+                        // we should never get here, because generation should already stopped in process_token()
+                        slot.release();
+                        send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
+                        continue;
+                    }
+
                    // Shift context
                    const int n_keep    = slot.params.n_keep + add_bos_token;
                    const int n_left    = (int) system_tokens.size() + slot.n_past - n_keep;
                    const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
 
-                    LOG_INFO("slot context shift", {
-                        {"id_slot",         slot.id},
-                        {"id_task",         slot.id_task},
-                        {"n_keep",          n_keep},
-                        {"n_left",          n_left},
-                        {"n_discard",       n_discard},
-                        {"n_ctx",           n_ctx},
-                        {"n_past",          slot.n_past},
-                        {"n_system_tokens", system_tokens.size()},
-                        {"n_cache_tokens",  slot.cache_tokens.size()}
-                    });
+                    SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
 
                    llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
                    llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
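A worked example of the context-shift arithmetic in the hunk above (values are illustrative and system_tokens is taken as empty):

    const int add_bos_token = 1;                  // assumed: the model adds BOS
    const int n_past        = 4095;               // slot one token away from n_ctx = 4096
    const int n_keep        = 256 + add_bos_token;
    const int n_left        = n_past - n_keep;    // 3838 shiftable tokens
    const int n_discard     = n_left / 2;         // 1919 when slot.params.n_discard == 0
    // llama_kv_cache_seq_rm drops positions [n_keep, n_keep + n_discard) and
    // llama_kv_cache_seq_add slides the remaining tokens left by n_discard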
@@ -1972,7 +1940,7 @@ struct server_context {
 
        // frist, add sampled tokens from any ongoing sequences
        for (auto & slot : slots) {
-            if (slot.state == SLOT_STATE_IDLE) {
+            if (slot.state != SLOT_STATE_GENERATING) {
                continue;
            }
 
@@ -1990,15 +1958,8 @@ struct server_context {
                slot.cache_tokens.push_back(slot.sampled);
            }
 
-            LOG_VERBOSE("slot decode token", {
-                {"id_slot",         slot.id},
-                {"id_task",         slot.id_task},
-                {"n_ctx",           n_ctx},
-                {"n_past",          slot.n_past},
-                {"n_system_tokens", system_tokens.size()},
-                {"n_cache_tokens",  slot.cache_tokens.size()},
-                {"truncated",       slot.truncated}
-            });
+            SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_system_tokens = %d, n_cache_tokens = %d, truncated = %d\n",
+                    slot.n_ctx, slot.n_past, (int) system_tokens.size(), (int) slot.cache_tokens.size(), slot.truncated);
        }
 
        // process in chunks of params.n_batch
@@ -2008,27 +1969,25 @@ struct server_context {
        // track if this is an embedding or non-embedding batch
        // if we've added sampled tokens above, we are in non-embedding mode
        // -1: none, 0: non-embedding, 1: embedding
+        // TODO: make enum
        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
 
        // next, batch any pending prompts without exceeding n_batch
        if (params.cont_batching || batch.n_tokens == 0) {
            for (auto & slot : slots) {
                // this slot still has a prompt to be processed
-                if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
+                if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
                    auto & prompt_tokens = slot.prompt_tokens;
 
                    // we haven't tokenized the prompt yet - do it now:
                    if (prompt_tokens.empty()) {
-                        LOG_VERBOSE("tokenizing prompt", {
-                            {"id_slot", slot.id},
-                            {"id_task", slot.id_task}
-                        });
+                        SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
 
                        slot.t_start_process_prompt = ggml_time_us();
                        slot.t_start_generation = 0;
 
-                        if (slot.infill) {
-                            const bool add_bos = llama_should_add_bos_token(model);
+                        if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
+                            const bool add_bos = llama_add_bos_token(model);
                            bool suff_rm_leading_spc = true;
                            if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
                                params.input_suffix.erase(0, 1);
@@ -2059,6 +2018,29 @@ struct server_context {
                            }
 
                            prompt_tokens = embd_inp;
+                        } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                            // require slot.prompt to be array of 2 strings
+                            if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                                SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                                slot.release();
+                                send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                                continue;
+                            }
+
+                            // prompt: [BOS]query[EOS][SEP]doc[EOS]
+                            prompt_tokens.clear();
+                            prompt_tokens.push_back(llama_token_bos(model));
+                            {
+                                const auto part = tokenize(slot.prompt[0], false);
+                                prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                            }
+                            prompt_tokens.push_back(llama_token_eos(model));
+                            prompt_tokens.push_back(llama_token_sep(model));
+                            {
+                                const auto part = tokenize(slot.prompt[1], false);
+                                prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                            }
+                            prompt_tokens.push_back(llama_token_eos(model));
                        } else {
                            prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
                        }
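For reference, the rerank prompt assembled in the hunk above has a fixed shape; the query/document text below is invented for illustration:

    // prompt_tokens = [BOS] + tok(query) + [EOS] + [SEP] + tok(doc) + [EOS], e.g.
    //
    //   [BOS] how tall is the Eiffel Tower? [EOS] [SEP] The tower is 330 m tall. [EOS]
    //
    // (the BOS/EOS/SEP ids come from the model's vocab); the relevance score is
    // later read back as embd[0] of the sequence embedding in send_rerank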
@@ -2066,40 +2048,34 @@ struct server_context {
                        slot.n_past = 0;
                        slot.n_prompt_tokens = prompt_tokens.size();
 
-                        LOG_VERBOSE("prompt tokenized", {
-                            {"id_slot",         slot.id},
-                            {"id_task",         slot.id_task},
-                            {"n_ctx",           slot.n_ctx},
-                            {"n_keep",          slot.params.n_keep},
-                            {"n_prompt_tokens", slot.n_prompt_tokens},
-                            {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
-                        });
+                        SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
 
                        // empty prompt passed -> release the slot and send empty response
                        if (prompt_tokens.empty()) {
-                            LOG_INFO("empty prompt - releasing slot", {
-                                {"id_slot", slot.id},
-                                {"id_task", slot.id_task}
-                            });
+                            SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
 
-                            slot.state = SLOT_STATE_PROCESSING;
-                            slot.command = SLOT_COMMAND_NONE;
                            slot.release();
                            slot.print_timings();
                            send_final_response(slot);
                            continue;
                        }
 
-                        if (slot.embedding) {
+                        if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                            // this prompt is too large to process - discard it
                            if (slot.n_prompt_tokens > n_ubatch) {
-                                slot.state = SLOT_STATE_PROCESSING;
-                                slot.command = SLOT_COMMAND_NONE;
                                slot.release();
                                send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
                                continue;
                            }
                        } else {
+                            if (!params.ctx_shift) {
+                                // if context shift is disabled, we make sure prompt size is smaller than KV size
+                                if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
+                                    slot.release();
+                                    send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
+                                    continue;
+                                }
+                            }
                            if (slot.params.n_keep < 0) {
                                slot.params.n_keep = slot.n_prompt_tokens;
                            }
@@ -2126,20 +2102,12 @@ struct server_context {
                                slot.truncated = true;
                                slot.n_prompt_tokens = prompt_tokens.size();
 
-                                LOG_VERBOSE("input truncated", {
-                                    {"id_slot",         slot.id},
-                                    {"id_task",         slot.id_task},
-                                    {"n_ctx",           slot.n_ctx},
-                                    {"n_keep",          slot.params.n_keep},
-                                    {"n_left",          n_left},
-                                    {"n_prompt_tokens", slot.n_prompt_tokens},
-                                    {"prompt_tokens",   tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
-                                });
+                                SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
 
                                GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
                            }
 
-                            llama_sampling_reset(slot.ctx_sampling);
+                            gpt_sampler_reset(slot.smpl);
 
                            if (!slot.params.cache_prompt) {
                                slot.n_past_se = 0;
@@ -2152,17 +2120,14 @@ struct server_context {
 
                                // push the prompt into the sampling context (do not apply grammar)
                                for (int i = 0; i < slot.n_past; ++i) {
-                                    llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
+                                    gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
                                }
                            }
                        }
 
                        if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
                            // we have to evaluate at least 1 token to generate logits.
-                            LOG_INFO("we have to evaluate at least 1 token to generate logits", {
-                                { "id_slot", slot.id },
-                                { "id_task", slot.id_task }
-                            });
+                            SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
 
                            slot.n_past--;
                            if (slot.ga_i > 0) {
@@ -2173,7 +2138,8 @@ struct server_context {
                        slot.n_prompt_tokens_processed = 0;
                    }
 
-                    if (slot.embedding) {
+                    // non-causal tasks require to fit the entire prompt in the physical batch
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
                        // cannot fit the prompt in the current batch - will try next iter
                        if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
                            continue;
@@ -2181,7 +2147,10 @@ struct server_context {
                    }
 
                    // check that we are in the right batch_type, if not defer the slot
-                    bool slot_type = slot.embedding ? 1 : 0;
+                    const bool slot_type =
+                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
+                        slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
+
                    if (batch_type == -1) {
                        batch_type = slot_type;
                    } else if (batch_type != slot_type) {
@@ -2205,17 +2174,13 @@ struct server_context {
                            slot.n_past_se = 0;
                            slot.ga_i = 0;
                            // TODO: is the system prompt ever in the sampling context?
-                            llama_sampling_reset(slot.ctx_sampling);
+                            gpt_sampler_reset(slot.smpl);
                        }
 
                        // remove the non-common part from the cache
                        slot.cache_tokens.resize(slot.n_past);
 
-                        LOG_INFO("kv cache rm [p0, end)", {
-                            { "id_slot", slot.id },
-                            { "id_task", slot.id_task },
-                            { "p0",      p0 }
-                        });
+                        SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
 
                        int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
 
@@ -2244,18 +2209,11 @@ struct server_context {
                        slot_npast++;
                    }
 
-                    LOG_VERBOSE("prompt processing progress", {
-                        {"id_slot",  slot.id},
-                        {"n_past",   slot.n_past},
-                        {"n_ctx",    n_ctx},
-                        {"n_tokens", batch.n_tokens},
-                        {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
-                    });
+                    SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
 
-                        // entire prompt has been processed
+                    // entire prompt has been processed
                    if (slot.n_past == slot.n_prompt_tokens) {
-                        slot.state   = SLOT_STATE_PROCESSING;
-                        slot.command = SLOT_COMMAND_NONE;
+                        slot.state = SLOT_STATE_DONE_PROMPT;
 
                        GGML_ASSERT(batch.n_tokens > 0);
 
@@ -2265,12 +2223,7 @@ struct server_context {
                        slot.n_decoded = 0;
                        slot.i_batch   = batch.n_tokens - 1;
 
-                        LOG_VERBOSE("prompt done", {
-                            {"id_slot",  slot.id},
-                            {"n_past",   slot.n_past},
-                            {"n_ctx",    n_ctx},
-                            {"n_tokens", batch.n_tokens},
-                        });
+                        SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
                    }
                }
 
@@ -2281,13 +2234,11 @@ struct server_context {
        }
 
        if (batch.n_tokens == 0) {
-            LOG_VERBOSE("no tokens to decode", {});
+            SRV_WRN("%s", "no tokens to decode\n");
            return;
        }
 
-        LOG_VERBOSE("decoding batch", {
-            {"n_tokens", batch.n_tokens},
-        });
+        SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
        // make sure we're in the right embedding mode
        llama_set_embeddings(ctx, batch_type == 1);
@@ -2305,10 +2256,9 @@ struct server_context {
                    const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
                    const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
 
-                    LOG_TEE("\n");
-                    LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
-                    LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
-                    LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
+                    SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
+                    SLT_DBG(slot, "div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
+                    SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
                    llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
                    llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
@@ -2318,7 +2268,7 @@ struct server_context {
 
                    slot.ga_i += slot.ga_w / slot.ga_n;
 
-                    LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
+                    SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
                }
 
                slot.n_past_se += n_tokens;
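Illustrative numbers for the self-extend (group attention) bookkeeping logged above; ga_n and ga_w are user parameters and the rest follows the formulas in this hunk:

    const int ga_n = 4;                          // group factor (example value)
    const int ga_w = 512;                        // window size (example value)
    const int ib   = 0;                          // first shift iteration
    const int bd   = (ga_w / ga_n) * (ga_n - 1); // 384
    // seq_add shifts [ga_i, n_past_se) by ib * bd, then seq_div divides the
    // window [ga_i, ga_i + ga_w) by ga_n, so 512 cached positions collapse to 128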
@@ -2337,18 +2287,13 @@ struct server_context {
            };
 
            const int ret = llama_decode(ctx, batch_view);
+            metrics.on_decoded(slots);
 
            if (ret != 0) {
                if (n_batch == 1 || ret < 0) {
                    // if you get here, it means the KV cache is full - try increasing it via the context size
-                    LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
-                        {"i",       i},
-                        {"n_batch", ret},
-                        {"ret",     ret},
-                    });
+                    SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
                    for (auto & slot : slots) {
-                        slot.state = SLOT_STATE_PROCESSING;
-                        slot.command = SLOT_COMMAND_NONE;
                        slot.release();
                        send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
                    }
@@ -2359,32 +2304,42 @@ struct server_context {
                n_batch /= 2;
                i -= n_batch;
 
-                LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
-                    {"i",       i},
-                    {"n_batch", n_batch},
-                    {"ret",     ret},
-                });
+                SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
 
                continue; // continue loop of n_batch
            }
 
            for (auto & slot : slots) {
-                if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
+                if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                    continue; // continue loop of slots
                }
 
-                // prompt evaluated for embedding
-                if (slot.embedding) {
-                    send_embedding(slot, batch_view);
-                    slot.release();
-                    slot.i_batch = -1;
-                    continue; // continue loop of slots
-                }
+                if (slot.state == SLOT_STATE_DONE_PROMPT) {
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
+                        // prompt evaluated for embedding
+                        send_embedding(slot, batch_view);
+                        slot.release();
+                        slot.i_batch = -1;
+                        continue; // continue loop of slots
+                    }
+
+                    if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
+                        send_rerank(slot, batch_view);
+                        slot.release();
+                        slot.i_batch = -1;
+                        continue; // continue loop of slots
+                    }
+
+                    // prompt evaluated for next-token prediction
+                    slot.state = SLOT_STATE_GENERATING;
+                } else if (slot.state != SLOT_STATE_GENERATING) {
+                    continue; // continue loop of slots
+                }
 
                completion_token_output result;
-                const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
+                const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
-                llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
+                gpt_sampler_accept(slot.smpl, id, true);
 
                slot.n_decoded += 1;
                if (slot.n_decoded == 1) {
@@ -2393,37 +2348,19 @@ struct server_context {
                    metrics.on_prompt_eval(slot);
                }
 
-                llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
                result.tok = id;
 
-                const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
-                if (n_probs > 0) {
-                    const size_t n_valid = slot.ctx_sampling->n_valid;
-
-                    // Make sure at least n_probs top tokens are at the front of the vector:
-                    if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
-                        llama_sample_top_k(ctx, &cur_p, n_probs, 0);
-                    }
+                const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
 
-                    if (slot.sparams.temp == 0.0f) {
-                        // With greedy sampling the probability of the sampled token is always 1.0
-                        for (size_t i = 0; i < n_probs; ++i) {
-                            result.probs.push_back({
-                                cur_p.data[i].id,
-                                i == 0 ? 1.0f : 0.0f
-                            });
-                        }
-                    } else {
-                        for (size_t i = 0; i < n_probs; ++i) {
-                            result.probs.push_back({
-                                cur_p.data[i].id,
-                                i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
-                            });
-                        }
-                    }
-                }
+                for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
+                    result.probs.push_back({
+                        cur_p->data[i].id,
+                        i >= cur_p->size ? 0.0f : cur_p->data[i].p,
+                    });
+                }
 
                if (!process_token(result, slot)) {
+                    // release slot because of stop condition
                    slot.release();
                    slot.print_timings();
                    send_final_response(slot);
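One reading of the simplified n_probs loop above (hedged; the gpt_sampler internals are not part of this diff): cur_p->data appears to point at the sampler's full candidate buffer while cur_p->size is the count left after truncation samplers, which would explain why only the probability, not the token id, is clamped:

    // for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i)
    //     id: cur_p->data[i].id                     (assumed valid beyond cur_p->size
    //                                                because the buffer is vocab-sized)
    //     p : i >= cur_p->size ? 0.0f : data[i].p   (filtered-out tokens report 0.0)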
@@ -2434,7 +2371,7 @@ struct server_context {
            }
        }
 
-        LOG_VERBOSE("run slots completed", {});
+        SRV_DBG("%s", "run slots completed\n");
    }
 
    json model_meta() const {
@@ -2455,19 +2392,10 @@ static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
        return;
    }
 
-    LOG_INFO("request", {
-        {"remote_addr", req.remote_addr},
-        {"remote_port", req.remote_port},
-        {"status",      res.status},
-        {"method",      req.method},
-        {"path",        req.path},
-        {"params",      req.params},
-    });
+    LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
 
-    LOG_VERBOSE("request", {
-        {"request",  req.body},
-        {"response", res.body},
-    });
+    LOG_DBG("request:  %s\n", req.body.c_str());
+    LOG_DBG("response: %s\n", res.body.c_str());
}
 
std::function<void(int)> shutdown_handler;
@@ -2485,20 +2413,18 @@ inline void signal_handler(int signal) {
}
 
int main(int argc, char ** argv) {
-#if SERVER_VERBOSE != 1
-    log_disable();
-#endif
    // own arguments required by this example
    gpt_params params;
 
-    if (!gpt_params_parse(argc, argv, params)) {
-        gpt_params_print_usage(argc, argv, params);
+    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
        return 1;
    }
 
-
-
-
+    gpt_init();
+
+    // enabling this will output extra debug information in the HTTP responses from the server
+    // see format_final_response_oaicompat()
+    const bool verbose = params.verbosity > 9;
 
    // struct that contains llama context and inference
    server_context ctx_server;
@@ -2514,30 +2440,27 @@ int main(int argc, char ** argv) {
    llama_backend_init();
    llama_numa_init(params.numa);
 
-    LOG_INFO("build info", {
-        {"build",  LLAMA_BUILD_NUMBER},
-        {"commit", LLAMA_COMMIT}
-    });
-
-    LOG_INFO("system info", {
-        {"n_threads",       params.n_threads},
-        {"n_threads_batch", params.n_threads_batch},
-        {"total_threads",   std::thread::hardware_concurrency()},
-        {"system_info",     llama_print_system_info()},
-    });
+    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
+    LOG_INF("\n");
+    LOG_INF("%s\n", gpt_params_get_system_info(params).c_str());
+    LOG_INF("\n");
 
    std::unique_ptr<httplib::Server> svr;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
-        LOG_INFO("Running with SSL", {{"key", params.ssl_file_key}, {"cert", params.ssl_file_cert}});
+        LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
        svr.reset(
            new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
        );
    } else {
-        LOG_INFO("Running without SSL", {});
+        LOG_INF("Running without SSL\n");
        svr.reset(new httplib::Server());
    }
#else
+    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+        LOG_ERR("Server is built without SSL support\n");
+        return 1;
+    }
    svr.reset(new httplib::Server());
#endif
 
@@ -2546,26 +2469,31 @@ int main(int argc, char ** argv) {
    svr->set_default_headers({{"Server", "llama.cpp"}});
 
    // CORS preflight
-    svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
+        // Access-Control-Allow-Origin is already set by middleware
        res.set_header("Access-Control-Allow-Credentials", "true");
        res.set_header("Access-Control-Allow-Methods",     "POST");
        res.set_header("Access-Control-Allow-Headers",     "*");
-        return res.set_content("", "application/json; charset=utf-8");
+        return res.set_content("", "text/html"); // blank response, no data
    });
 
    svr->set_logger(log_server_request);
 
-    auto res_error = [](httplib::Response & res, json error_data) {
+    auto res_error = [](httplib::Response & res, const json & error_data) {
        json final_response {{"error", error_data}};
-        res.set_content(final_response.dump(), "application/json; charset=utf-8");
+        res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
        res.status = json_value(error_data, "code", 500);
    };
 
+    auto res_ok = [](httplib::Response & res, const json & data) {
+        res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
+        res.status = 200;
+    };
+
    svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
        std::string message;
        try {
-            std::rethrow_exception(std::move(ep));
+            std::rethrow_exception(ep);
        } catch (std::exception & e) {
            message = e.what();
        } catch (...) {
@@ -2573,7 +2501,7 @@ int main(int argc, char ** argv) {
        }
 
        json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
-        LOG_VERBOSE("Got exception", formatted_error);
+        LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
        res_error(res, formatted_error);
    });
 
@@ -2588,11 +2516,6 @@ int main(int argc, char ** argv) {
    svr->set_read_timeout (params.timeout_read);
    svr->set_write_timeout(params.timeout_write);
 
-    if (!svr->bind_to_port(params.hostname, params.port)) {
-        fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port);
-        return 1;
-    }
-
    std::unordered_map<std::string, std::string> log_data;
 
    log_data["hostname"] = params.hostname;
@@ -2608,42 +2531,13 @@ int main(int argc, char ** argv) {
    // Necessary similarity of prompt for slot selection
    ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
 
-    // load the model
-    if (!ctx_server.load_model(params)) {
-        state.store(SERVER_STATE_ERROR);
-        return 1;
-    } else {
-        ctx_server.init();
-        state.store(SERVER_STATE_READY);
-    }
-
-    LOG_INFO("model loaded", {});
-
-    const auto model_meta = ctx_server.model_meta();
-
-    // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
-    if (params.chat_template.empty()) {
-        if (!ctx_server.validate_model_chat_template()) {
-            LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
-            params.chat_template = "chatml";
-        }
-    }
-
-    // print sample chat example to make it clear which template is used
-    {
-        LOG_INFO("chat template", {
-            {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
-            {"built_in",     params.chat_template.empty()},
-        });
-    }
-
    //
    // Middlewares
    //
 
    auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
        // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
-        static const std::set<std::string> protected_endpoints = {
+        static const std::unordered_set<std::string> protected_endpoints = {
            "/props",
            "/completion",
            "/completions",
@@ -2680,17 +2574,34 @@ int main(int argc, char ** argv) {
        }
 
        // API key is invalid or not provided
-        // TODO: make another middleware for CORS related logic
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
 
-        LOG_WARNING("Unauthorized: Invalid API Key", {});
+        LOG_WRN("Unauthorized: Invalid API Key\n");
 
        return false;
    };
 
+    auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
+        server_state current_state = state.load();
+        if (current_state == SERVER_STATE_LOADING_MODEL) {
+            auto tmp = string_split(req.path, '.');
+            if (req.path == "/" || tmp.back() == "html") {
+                res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
+                res.status = 503;
+            } else {
+                res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
+            }
+            return false;
+        }
+        return true;
+    };
+
    // register server middlewares
-    svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
+    svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        if (!middleware_server_state(req, res)) {
+            return httplib::Server::HandlerResponse::Handled;
+        }
        if (!middleware_validate_api_key(req, res)) {
            return httplib::Server::HandlerResponse::Handled;
        }
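Behavior summary of the middleware_server_state added above (illustrative, inferred from the hunk itself):

    // while state == SERVER_STATE_LOADING_MODEL:
    //   "/" or any *.html path -> embedded loading_html page, HTTP 503
    //   any other path         -> JSON {"error": ... "Loading model"} via res_error
    // the check runs in set_pre_routing_handler, i.e. before API-key validation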
@@ -2701,99 +2612,57 @@ int main(int argc, char ** argv) {
     // Route handlers (or controllers)
     //
 
-    const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
-        server_state current_state = state.load();
-        switch (current_state) {
-            case SERVER_STATE_READY:
-                {
-                    // request slots data using task queue
-                    server_task task;
-                    task.id   = ctx_server.queue_tasks.get_new_id();
-                    task.type = SERVER_TASK_TYPE_METRICS;
-                    task.id_target = -1;
-
-                    ctx_server.queue_results.add_waiting_task_id(task.id);
-                    ctx_server.queue_tasks.post(task);
-
-                    // get the result
-                    server_task_result result = ctx_server.queue_results.recv(task.id);
-                    ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-                    const int n_idle_slots       = result.data.at("idle");
-                    const int n_processing_slots = result.data.at("processing");
-
-                    json health = {
-                        {"status",           "ok"},
-                        {"slots_idle",       n_idle_slots},
-                        {"slots_processing", n_processing_slots}
-                    };
-
-                    res.status = 200; // HTTP OK
-                    if (params.endpoint_slots && req.has_param("include_slots")) {
-                        health["slots"] = result.data.at("slots");
-                    }
-
-                    if (n_idle_slots == 0) {
-                        health["status"] = "no slot available";
-                        if (req.has_param("fail_on_no_slot")) {
-                            res.status = 503; // HTTP Service Unavailable
-                        }
-                    }
-
-                    res.set_content(health.dump(), "application/json");
-                    break;
-                }
-            case SERVER_STATE_LOADING_MODEL:
-                {
-                    res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
-                } break;
-            case SERVER_STATE_ERROR:
-                {
-                    res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
-                } break;
-        }
+    const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
+        // error and loading states are handled by middleware
+        json health = {{"status", "ok"}};
+        res_ok(res, health);
     };
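With loading and error states now rejected up front by the middleware, `/health` collapses to a constant `{"status": "ok"}` response: if a request reaches the handler at all, the model is loaded. Slot counts, previously mixed into the health payload, move to the dedicated `/slots` endpoint below.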
 
-    const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
+    const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
         if (!params.endpoint_slots) {
-            res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
         // request slots data using task queue
         server_task task;
         task.id = ctx_server.queue_tasks.get_new_id();
-        task.id_multi  = -1;
-        task.id_target = -1;
         task.type = SERVER_TASK_TYPE_METRICS;
 
         ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
+        ctx_server.queue_tasks.post(task, true); // high-priority task
 
         // get the result
         server_task_result result = ctx_server.queue_results.recv(task.id);
         ctx_server.queue_results.remove_waiting_task_id(task.id);
 
-        res.set_content(result.data.at("slots").dump(), "application/json");
-        res.status = 200; // HTTP OK
+        // optionally return "fail_on_no_slot" error
+        const int n_idle_slots = result.data.at("idle");
+        if (req.has_param("fail_on_no_slot")) {
+            if (n_idle_slots == 0) {
+                res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
+                return;
+            }
+        }
+
+        res_ok(res, result.data.at("slots"));
     };
 
     const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
         if (!params.endpoint_metrics) {
-            res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
+            res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
         // request slots data using task queue
         server_task task;
         task.id = ctx_server.queue_tasks.get_new_id();
-        task.id_multi  = -1;
         task.id_target = -1;
         task.type = SERVER_TASK_TYPE_METRICS;
         task.data.push_back({{"reset_bucket", true}});
 
         ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
+        ctx_server.queue_tasks.post(task, true); // high-priority task
 
         // get the result
         server_task_result result = ctx_server.queue_results.recv(task.id);
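Both monitoring handlers now post their metrics task with a second argument that, judging by the in-diff comment, queues it ahead of pending completion work, so `/slots` and `/metrics` stay responsive while long generations are in flight. A hedged sketch of such a front-of-queue post, assuming a deque-backed queue — this is illustrative, not the package's actual `server_queue`:

```cpp
#include <condition_variable>
#include <deque>
#include <mutex>

struct task { int id; };

struct task_queue {
    std::deque<task>        tasks;
    std::mutex              mtx;
    std::condition_variable cv;

    // front = true puts the task ahead of pending work (high priority)
    int post(task t, bool front = false) {
        std::lock_guard<std::mutex> lk(mtx);
        if (front) {
            tasks.push_front(t);
        } else {
            tasks.push_back(t);
        }
        cv.notify_one();
        return t.id;
    }
};
```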
@@ -2807,6 +2676,9 @@ int main(int argc, char ** argv) {
         const uint64_t n_tokens_predicted  = data.at("n_tokens_predicted");
         const uint64_t t_tokens_generation = data.at("t_tokens_generation");
 
+        const uint64_t n_decode_total     = data.at("n_decode_total");
+        const uint64_t n_busy_slots_total = data.at("n_busy_slots_total");
+
         const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");
 
         // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
@@ -2827,6 +2699,14 @@ int main(int argc, char ** argv) {
                     {"name",  "tokens_predicted_seconds_total"},
                     {"help",  "Predict process time"},
                     {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3}
+            }, {
+                    {"name",  "n_decode_total"},
+                    {"help",  "Total number of llama_decode() calls"},
+                    {"value", n_decode_total}
+            }, {
+                    {"name",  "n_busy_slots_per_decode"},
+                    {"help",  "Average number of busy slots per llama_decode() call"},
+                    {"value", (float) n_busy_slots_total / (float) n_decode_total}
             }}},
             {"gauge", {{
                     {"name",  "prompt_tokens_seconds"},
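The new gauge is a simple ratio: `n_busy_slots_per_decode = n_busy_slots_total / n_decode_total`. For example, after 1,000 `llama_decode()` calls that served 2,500 busy slot-decodes in total, the gauge reads 2.5, i.e. on average 2.5 requests were batched into each decode. Note the expression divides without a zero guard, so the value is NaN until the first decode has happened.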
@@ -2879,7 +2759,7 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
     };
 
-    const auto handle_slots_save = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+    const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
         std::string filename = request_data.at("filename");
         if (!fs_validate_filename(filename)) {
@@ -2893,7 +2773,7 @@ int main(int argc, char ** argv) {
         task.data = {
             { "id_slot",  id_slot },
             { "filename", filename },
-            { "filepath", filepath }
+            { "filepath", filepath },
         };
 
         const int id_task = ctx_server.queue_tasks.post(task);
@@ -2905,11 +2785,11 @@ int main(int argc, char ** argv) {
         if (result.error) {
             res_error(res, result.data);
         } else {
-            res.set_content(result.data.dump(), "application/json");
+            res_ok(res, result.data);
         }
     };
 
-    const auto handle_slots_restore = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+    const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
         json request_data = json::parse(req.body);
         std::string filename = request_data.at("filename");
         if (!fs_validate_filename(filename)) {
@@ -2923,7 +2803,7 @@ int main(int argc, char ** argv) {
         task.data = {
             { "id_slot",  id_slot },
             { "filename", filename },
-            { "filepath", filepath }
+            { "filepath", filepath },
         };
 
         const int id_task = ctx_server.queue_tasks.post(task);
@@ -2935,11 +2815,11 @@ int main(int argc, char ** argv) {
         if (result.error) {
             res_error(res, result.data);
         } else {
-            res.set_content(result.data.dump(), "application/json");
+            res_ok(res, result.data);
        }
     };
 
-    const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
+    const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
         server_task task;
         task.type = SERVER_TASK_TYPE_SLOT_ERASE;
         task.data = {
@@ -2955,12 +2835,15 @@ int main(int argc, char ** argv) {
         if (result.error) {
             res_error(res, result.data);
         } else {
-            res.set_content(result.data.dump(), "application/json");
+            res_ok(res, result.data);
         }
     };
 
-    const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
+        if (params.slot_save_path.empty()) {
+            res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
 
         std::string id_slot_str = req.path_params.at("id_slot");
         int id_slot;
@@ -2985,7 +2868,7 @@ int main(int argc, char ** argv) {
         }
     };
 
-    const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
         std::string template_key = "tokenizer.chat_template", curr_tmpl;
         int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
         if (tlen > 0) {
@@ -2994,274 +2877,190 @@ int main(int argc, char ** argv) {
                 curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
             }
         }
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = {
             { "system_prompt",               ctx_server.system_prompt.c_str() },
             { "default_generation_settings", ctx_server.default_generation_settings_for_props },
             { "total_slots",                 ctx_server.params.n_parallel },
-            { "chat_template",               curr_tmpl.c_str() }
+            { "chat_template",               curr_tmpl.c_str() },
         };
 
-        res.set_content(data.dump(), "application/json; charset=utf-8");
+        res_ok(res, data);
     };
 
-    const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+    const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
             return;
         }
 
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
+        ctx_server.queue_results.add_waiting_tasks(tasks);
+        ctx_server.queue_tasks.post(tasks);
 
-        json data = json::parse(req.body);
-
-        const int id_task = ctx_server.queue_tasks.get_new_id();
+        bool stream = json_value(data, "stream", false);
+        const auto task_ids = server_task::get_list_id(tasks);
 
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, data, false, false);
-
-        if (!json_value(data, "stream", false)) {
-            server_task_result result = ctx_server.queue_results.recv(id_task);
-            if (!result.error && result.stop) {
-                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
-            } else {
-                res_error(res, result.data);
-            }
-
-            ctx_server.queue_results.remove_waiting_task_id(id_task);
-        } else {
-            const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
-                while (true) {
-                    server_task_result result = ctx_server.queue_results.recv(id_task);
-                    if (!result.error) {
-                        const std::string str =
-                            "data: " +
-                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                            "\n\n";
-
-                        LOG_VERBOSE("data stream", {
-                            { "to_send", str }
-                        });
-
-                        if (!sink.write(str.c_str(), str.size())) {
-                            ctx_server.queue_results.remove_waiting_task_id(id_task);
-                            return false;
-                        }
-
-                        if (result.stop) {
-                            break;
-                        }
-                    } else {
-                        const std::string str =
-                            "error: " +
-                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                            "\n\n";
-
-                        LOG_VERBOSE("data stream", {
-                            { "to_send", str }
-                        });
-
-                        if (!sink.write(str.c_str(), str.size())) {
-                            ctx_server.queue_results.remove_waiting_task_id(id_task);
-                            return false;
-                        }
-
-                        break;
+        if (!stream) {
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                if (results.size() == 1) {
+                    // single result
+                    res_ok(res, results[0].data);
+                } else {
+                    // multiple results (multitask)
+                    json arr = json::array();
+                    for (const auto & res : results) {
+                        arr.push_back(res.data);
                     }
+                    res_ok(res, arr);
                 }
+            }, [&](const json & error_data) {
+                res_error(res, error_data);
+            });
 
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
+            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+        } else {
+            const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
+                ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
+                    return server_sent_event(sink, "data", result.data);
+                }, [&](const json & error_data) {
+                    server_sent_event(sink, "error", error_data);
+                });
                 sink.done();
-
-                return true;
+                return false;
             };
 
-            auto on_complete = [id_task, &ctx_server] (bool) {
-                ctx_server.request_cancel(id_task);
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
+            auto on_complete = [task_ids, &ctx_server] (bool) {
+                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
            };
 
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
         }
     };
 
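The refactor replaces the hand-rolled `while (true)` receive loop with `receive_cmpl_results_stream` plus a shared `server_sent_event` helper. A hedged sketch of what an SSE write helper of this shape looks like, based on the `"data: " + dump + "\n\n"` framing visible in the removed code (the package's real helper lives elsewhere in the server sources):

```cpp
#include <string>
#include <httplib.h>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Writes one Server-Sent Event frame ("<event>: <json>\n\n") to the sink.
// Returns false once the client has disconnected, so callers can stop early.
static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
    const std::string str =
        std::string(event) + ": " +
        data.dump(-1, ' ', false, json::error_handler_t::replace) +
        "\n\n";
    return sink.write(str.c_str(), str.size());
}
```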
-    const auto handle_models = [&params, &ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
-        json models = {
-            {"object", "list"},
-            {"data", {
-                 {
-                     {"id",       params.model_alias},
-                     {"object",   "model"},
-                     {"created",  std::time(0)},
-                     {"owned_by", "llamacpp"},
-                     {"meta",     model_meta}
-                 },
-            }}
-        };
-
-        res.set_content(models.dump(), "application/json; charset=utf-8");
+    const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        json data = json::parse(req.body);
+        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
+    };
+
+    const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+        json data = json::parse(req.body);
+        return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
     };
 
-    const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+    // TODO: maybe merge this function with "handle_completions_generic"
+    const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
            return;
         }
 
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
-        const int id_task = ctx_server.queue_tasks.get_new_id();
-
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, data, false, false);
+        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+        ctx_server.queue_results.add_waiting_tasks(tasks);
+        ctx_server.queue_tasks.post(tasks);
 
+        bool stream = json_value(data, "stream", false);
+        const auto task_ids = server_task::get_list_id(tasks);
         const auto completion_id = gen_chatcmplid();
-        if (!json_value(data, "stream", false)) {
-            server_task_result result = ctx_server.queue_results.recv(id_task);
 
-            if (!result.error && result.stop) {
-                json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
+        if (!stream) {
+            ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
+                // multitask is never support in chat completion, there is only one result
+                json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
+                res_ok(res, result_oai);
+            }, [&](const json & error_data) {
+                res_error(res, error_data);
+            });
 
-                res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
-            } else {
-                res_error(res, result.data);
-            }
-            ctx_server.queue_results.remove_waiting_task_id(id_task);
+            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
-                while (true) {
-                    server_task_result result = ctx_server.queue_results.recv(id_task);
-                    if (!result.error) {
-                        std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
-
-                        for (auto it = result_array.begin(); it != result_array.end(); ++it) {
-                            if (!it->empty()) {
-                                const std::string str =
-                                    "data: " +
-                                    it->dump(-1, ' ', false, json::error_handler_t::replace) +
-                                    "\n\n";
-                                LOG_VERBOSE("data stream", {{"to_send", str}});
-                                if (!sink.write(str.c_str(), str.size())) {
-                                    ctx_server.queue_results.remove_waiting_task_id(id_task);
-                                    return false;
-                                }
-                            }
-                        }
-                        if (result.stop) {
-                            break;
+            const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+                ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
+                    std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
+                    for (auto & event_data : result_array) {
+                        if (event_data.empty()) {
+                            continue; // skip the stop token
                         }
-                    } else {
-                        const std::string str =
-                            "error: " +
-                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                            "\n\n";
-                        LOG_VERBOSE("data stream", {{"to_send", str}});
-                        if (!sink.write(str.c_str(), str.size())) {
-                            ctx_server.queue_results.remove_waiting_task_id(id_task);
-                            return false;
+                        if (!server_sent_event(sink, "data", event_data)) {
+                            return false; // connection is closed
                         }
-                        break;
                     }
-                }
+                    return true; // ok
+                }, [&](const json & error_data) {
+                    server_sent_event(sink, "error", error_data);
+                });
+                static const std::string ev_done = "data: [DONE]\n\n";
+                sink.write(ev_done.data(), ev_done.size());
                 sink.done();
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
                 return true;
             };
 
-            auto on_complete = [id_task, &ctx_server] (bool) {
-                ctx_server.request_cancel(id_task);
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
+            auto on_complete = [task_ids, &ctx_server] (bool) {
+                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
             };
 
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
         }
     };
 
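Two details of the new OpenAI-compatible stream are worth noting: chunks that format to an empty JSON object (the stop token) are skipped rather than sent, and the stream is explicitly terminated with a literal `data: [DONE]\n\n` frame before the sink closes, which is the sentinel OpenAI client libraries wait for before ending iteration.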
-    const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
-        if (ctx_server.params.embedding) {
-            res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
-        }
-
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
-        json data = json::parse(req.body);
-
-        const int id_task = ctx_server.queue_tasks.get_new_id();
-
-        ctx_server.queue_results.add_waiting_task_id(id_task);
-        ctx_server.request_completion(id_task, -1, data, true, false);
-
-        if (!json_value(data, "stream", false)) {
-            server_task_result result = ctx_server.queue_results.recv(id_task);
-            if (!result.error && result.stop) {
-                res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
-            } else {
-                res_error(res, result.data);
-            }
-
-            ctx_server.queue_results.remove_waiting_task_id(id_task);
-        } else {
-            const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
-                while (true) {
-                    server_task_result result = ctx_server.queue_results.recv(id_task);
-                    if (!result.error) {
-                        const std::string str =
-                            "data: " +
-                            result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-                            "\n\n";
-
-                        LOG_VERBOSE("data stream", {
-                            { "to_send", str }
-                        });
-
-                        if (!sink.write(str.c_str(), str.size())) {
-                            ctx_server.queue_results.remove_waiting_task_id(id_task);
-                            return false;
-                        }
-
-                        if (result.stop) {
-                            break;
-                        }
-                    } else {
-                        break;
-                    }
-                }
-
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
-                sink.done();
-
-                return true;
-            };
-
-            auto on_complete = [id_task, &ctx_server] (bool) {
-                ctx_server.request_cancel(id_task);
-                ctx_server.queue_results.remove_waiting_task_id(id_task);
-            };
-
-            res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
-        }
-    };
-
-    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        const json body = json::parse(req.body);
-
-        std::vector<llama_token> tokens;
-        if (body.count("content") != 0) {
-            const bool add_special = json_value(body, "add_special", false);
-            tokens = ctx_server.tokenize(body.at("content"), add_special);
-        }
-        const json data = format_tokenizer_response(tokens);
-        return res.set_content(data.dump(), "application/json; charset=utf-8");
+    const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
+        json models = {
+            {"object", "list"},
+            {"data", {
+                 {
+                     {"id",       params.model_alias},
+                     {"object",   "model"},
+                     {"created",  std::time(0)},
+                     {"owned_by", "llamacpp"},
+                     {"meta",     ctx_server.model_meta()}
+                 },
+            }}
+        };
+
+        res.set_content(models.dump(), MIMETYPE_JSON);
+    };
+
+    const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        const json body = json::parse(req.body);
+
+        json tokens_response = json::array();
+        if (body.count("content") != 0) {
+            const bool add_special = json_value(body, "add_special", false);
+            const bool with_pieces = json_value(body, "with_pieces", false);
+            std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special);
+
+            if (with_pieces) {
+                for (const auto& token : tokens) {
+                    std::string piece = llama_token_to_piece(ctx_server.ctx, token);
+                    json piece_json;
+
+                    // Check if the piece is valid UTF-8
+                    if (is_valid_utf8(piece)) {
+                        piece_json = piece;
+                    } else {
+                        // If not valid UTF-8, store as array of byte values
+                        piece_json = json::array();
+                        for (unsigned char c : piece) {
+                            piece_json.push_back(static_cast<int>(c));
+                        }
+                    }
+
+                    tokens_response.push_back({
+                        {"id", token},
+                        {"piece", piece_json}
+                    });
+                }
+            } else {
+                tokens_response = tokens;
+            }
+        }
+
+        const json data = format_tokenizer_response(tokens_response);
+        res_ok(res, data);
     };
 
-    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+    const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
         const json body = json::parse(req.body);
 
         std::string content;
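With `with_pieces`, each token id is returned together with its decoded piece; because a single token can carry bytes that are not valid UTF-8 (which would break JSON encoding), such pieces are emitted as arrays of byte values instead. A simplified sketch of the kind of check an `is_valid_utf8` helper performs — this version only validates sequence lengths and continuation bytes, so the real helper may well be stricter:

```cpp
#include <string>

// Minimal UTF-8 validity check (illustrative sketch, not the package's helper).
static bool is_valid_utf8_sketch(const std::string & s) {
    size_t i = 0;
    while (i < s.size()) {
        unsigned char c = s[i];
        // determine the expected sequence length from the lead byte
        size_t len = c < 0x80        ? 1
                   : (c >> 5) == 0x6 ? 2   // 110xxxxx
                   : (c >> 4) == 0xE ? 3   // 1110xxxx
                   : (c >> 3) == 0x1E ? 4  // 11110xxx
                   : 0;
        if (len == 0 || i + len > s.size()) {
            return false;
        }
        // every following byte must be a 10xxxxxx continuation byte
        for (size_t j = 1; j < len; ++j) {
            if ((static_cast<unsigned char>(s[i + j]) & 0xC0) != 0x80) {
                return false;
            }
        }
        i += len;
    }
    return true;
}
```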
@@ -3271,12 +3070,15 @@ int main(int argc, char ** argv) {
         }
 
         const json data = format_detokenized_response(content);
-        return res.set_content(data.dump(), "application/json; charset=utf-8");
+        res_ok(res, data);
     };
 
-    const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
-        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
+    const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        // TODO: somehow clean up this checks in the future
+        if (!ctx_server.params.embedding || ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
         const json body = json::parse(req.body);
         bool is_openai = false;
 
@@ -3294,35 +3096,156 @@ int main(int argc, char ** argv) {
         }
 
         // create and queue the task
-        json responses;
+        json responses = json::array();
+        bool error = false;
         {
-            const int id_task = ctx_server.queue_tasks.get_new_id();
-            ctx_server.queue_results.add_waiting_task_id(id_task);
-            ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
 
             // get the result
-            server_task_result result = ctx_server.queue_results.recv(id_task);
-            ctx_server.queue_results.remove_waiting_task_id(id_task);
-            if (!result.error) {
-                if (result.data.count("results")) {
-                    // result for multi-task
-                    responses = result.data.at("results");
-                } else {
-                    // result for single task
-                    responses = std::vector<json>{result.data};
-                }
-            } else {
-                // error received, ignore everything else
-                res_error(res, result.data);
-                return;
-            }
+            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                for (const auto & res : results) {
+                    responses.push_back(res.data);
+                }
+            }, [&](const json & error_data) {
+                res_error(res, error_data);
+                error = true;
+            });
+
+            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+        }
+
+        if (error) {
+            return;
         }
 
         // write JSON response
         json root = is_openai
             ? format_embeddings_response_oaicompat(body, responses)
             : responses[0];
-        res.set_content(root.dump(), "application/json; charset=utf-8");
+        res_ok(res, root);
+    };
+
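Each prompt in the request is fanned out as its own embedding task, and the per-task results are gathered into a JSON array before formatting. An illustrative OpenAI-style request body, assuming the `input` field name of the OpenAI embeddings API for the `/v1/embeddings` route (the non-OpenAI route reads a plain `content` field) and nlohmann::json as used throughout the server:

```cpp
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // two inputs -> two embedding tasks -> a "data" array with two entries
    json body = {
        {"input", json::array({"first sentence to embed", "second sentence to embed"})},
    };
    // POST body.dump() to /v1/embeddings; each result lands in the handler's
    // responses[] array in task order.
    return 0;
}
```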
+    const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+        if (!ctx_server.params.reranking) {
+            res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+        const json body = json::parse(req.body);
+
+        // TODO: implement
+        //int top_n = 1;
+        //if (body.count("top_n") != 1) {
+        //    top_n = body.at("top_n");
+        //} else {
+        //    res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+        //    return;
+        //}
+
+        json query;
+        if (body.count("query") == 1) {
+            query = body.at("query");
+            if (!query.is_string()) {
+                res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        } else {
+            res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+        if (documents.empty()) {
+            res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+            return;
+        }
+
+        // construct prompt object: array of ["query", "doc0", "doc1", ...]
+        json prompt;
+        prompt.push_back(query);
+        for (const auto & doc : documents) {
+            prompt.push_back(doc);
+        }
+
+        LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+        // create and queue the task
+        json responses = json::array();
+        bool error = false;
+        {
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            ctx_server.queue_results.add_waiting_tasks(tasks);
+            ctx_server.queue_tasks.post(tasks);
+
+            // get the result
+            std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+            ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+                for (const auto & res : results) {
+                    responses.push_back(res.data);
+                }
+            }, [&](const json & error_data) {
+                res_error(res, error_data);
+                error = true;
+            });
+        }
+
+        if (error) {
+            return;
+        }
+
+        // write JSON response
+        json root = format_response_rerank(body, responses);
+        res_ok(res, root);
+    };
+
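As the in-diff comment says, the handler flattens the request into a prompt array of the form `["query", "doc0", "doc1", ...]` and dispatches one rerank task per document. An illustrative request body (the payload values are made up for the example):

```cpp
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // flattened internally into: ["what is panda?", "hi", "it is a bear", ...]
    json body = {
        {"query", "what is panda?"},
        {"documents", json::array({
            "hi",
            "it is a bear",
            "The giant panda is a bear species endemic to China."
        })},
    };
    // POST body.dump() to /rerank (also /reranking, /v1/rerank, /v1/reranking);
    // the response scores each document's relevance to the query.
    return 0;
}
```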
+    const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
+        json result = json::array();
+        for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
+            auto & lora = ctx_server.loras[i];
+            result.push_back({
+                {"id", i},
+                {"path", lora.path},
+                {"scale", lora.scale},
+            });
+        }
+        res_ok(res, result);
+        res.status = 200; // HTTP OK
+    };
+
+    const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
+        const std::vector<json> body = json::parse(req.body);
+        int max_idx = ctx_server.loras.size();
+
+        // clear existing value
+        for (auto & lora : ctx_server.loras) {
+            lora.scale = 0.0f;
+        }
+
+        // set value
+        for (auto entry : body) {
+            int id      = entry.at("id");
+            float scale = entry.at("scale");
+            if (0 <= id && id < max_idx) {
+                ctx_server.loras[id].scale = scale;
+            } else {
+                throw std::runtime_error("invalid adapter id");
+            }
+        }
+
+        server_task task;
+        task.type = SERVER_TASK_TYPE_SET_LORA;
+        const int id_task = ctx_server.queue_tasks.post(task);
+        ctx_server.queue_results.add_waiting_task_id(id_task);
+
+        server_task_result result = ctx_server.queue_results.recv(id_task);
+        ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+        res_ok(res, result.data);
+        res.status = 200; // HTTP OK
     };
 
     auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
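The apply handler is reset-then-set: every adapter's scale is first cleared to 0.0, then each id listed in the request receives its requested scale, and a `SET_LORA` task propagates the change to the context. An illustrative request body (the ids and scales are made up for the example):

```cpp
#include <nlohmann/json.hpp>

using json = nlohmann::json;

int main() {
    // POST /lora-adapters: adapters omitted from the list end up with scale 0.0;
    // ids outside the loaded range cause the handler to throw.
    json body = json::array({
        { {"id", 0}, {"scale", 0.5} },
        { {"id", 1}, {"scale", 1.0} },
    });
    // GET /lora-adapters afterwards returns, per adapter: {"id", "path", "scale"}.
    return 0;
}
```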
@@ -3363,7 +3286,6 @@ int main(int argc, char ** argv) {
 
     // register API routes
     svr->Get ("/health",    handle_health);
-    svr->Get ("/slots",     handle_slots);
     svr->Get ("/metrics",   handle_metrics);
     svr->Get ("/props",     handle_props);
     svr->Get ("/v1/models", handle_models);
@@ -3376,12 +3298,18 @@ int main(int argc, char ** argv) {
     svr->Post("/embedding",        handle_embeddings); // legacy
     svr->Post("/embeddings",       handle_embeddings);
     svr->Post("/v1/embeddings",    handle_embeddings);
+    svr->Post("/rerank",           handle_rerank);
+    svr->Post("/reranking",        handle_rerank);
+    svr->Post("/v1/rerank",        handle_rerank);
+    svr->Post("/v1/reranking",     handle_rerank);
     svr->Post("/tokenize",         handle_tokenize);
     svr->Post("/detokenize",       handle_detokenize);
-    if (!params.slot_save_path.empty()) {
-        // only enable slot endpoints if slot_save_path is set
-        svr->Post("/slots/:id_slot", handle_slots_action);
-    }
+    // LoRA adapters hotswap
+    svr->Get ("/lora-adapters",    handle_lora_adapters_list);
+    svr->Post("/lora-adapters",    handle_lora_adapters_apply);
+    // Save & load slots
+    svr->Get ("/slots",            handle_slots);
+    svr->Post("/slots/:id_slot",   handle_slots_action);
 
     //
     // Start the server
@@ -3393,36 +3321,66 @@ int main(int argc, char ** argv) {
     log_data["n_threads_http"] =  std::to_string(params.n_threads_http);
     svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };
 
-    LOG_INFO("HTTP server listening", log_data);
+    // clean up function, to be called before exit
+    auto clean_up = [&svr]() {
+        svr->stop();
+        llama_backend_free();
+    };
+
+    // bind HTTP listen port, run the HTTP server in a thread
+    if (!svr->bind_to_port(params.hostname, params.port)) {
+        //LOG_ERROR("couldn't bind HTTP server socket", {
+        //    {"hostname", params.hostname},
+        //    {"port", params.port},
+        //});
+        LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
+        clean_up();
+        return 1;
+    }
+    std::thread t([&]() { svr->listen_after_bind(); });
+    svr->wait_until_ready();
 
-    // run the HTTP server in a thread - see comment below
-    std::thread t([&]() {
-        if (!svr->listen_after_bind()) {
-            state.store(SERVER_STATE_ERROR);
-            return 1;
+    LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http);
+
+    // load the model
+    LOG_INF("%s: loading model\n", __func__);
+
+    if (!ctx_server.load_model(params)) {
+        clean_up();
+        t.join();
+        LOG_ERR("%s: exiting due to model loading error\n", __func__);
+        return 1;
+    }
+
+    ctx_server.init();
+    state.store(SERVER_STATE_READY);
+
+    LOG_INF("%s: model loaded\n", __func__);
+
+    // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
+    if (params.chat_template.empty()) {
+        if (!ctx_server.validate_model_chat_template()) {
+            LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
+            params.chat_template = "chatml";
         }
+    }
 
-        return 0;
-    });
+    // print sample chat example to make it clear which template is used
+    LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
 
     ctx_server.queue_tasks.on_new_task(std::bind(
-        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
-    ctx_server.queue_tasks.on_finish_multitask(std::bind(
-        &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
+        &server_context::process_single_task, &ctx_server, std::placeholders::_1));
     ctx_server.queue_tasks.on_update_slots(std::bind(
-        &server_context::update_slots, &ctx_server));
-    ctx_server.queue_results.on_multitask_update(std::bind(
-        &server_queue::update_multitask,
-        &ctx_server.queue_tasks,
-        std::placeholders::_1,
-        std::placeholders::_2,
-        std::placeholders::_3
-    ));
+        &server_context::update_slots, &ctx_server));
 
     shutdown_handler = [&](int) {
         ctx_server.queue_tasks.terminate();
     };
 
+    LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
+
+    ctx_server.queue_tasks.start_loop();
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
     struct sigaction sigint_action;
     sigint_action.sa_handler = signal_handler;
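The startup order is inverted here: the socket is bound and served on a worker thread first, so early clients get an explicit 503 "Loading model" from the middleware instead of a connection refused, and the heavy model load happens afterwards. A minimal sketch of this bind-then-load sequence, assuming a recent cpp-httplib (names other than the httplib calls are illustrative):

```cpp
#include <atomic>
#include <thread>
#include <httplib.h>

int main() {
    httplib::Server svr;
    std::atomic<bool> ready{false};

    // bind first: from this point on, clients can connect
    if (!svr.bind_to_port("127.0.0.1", 8080)) {
        return 1;
    }
    std::thread t([&] { svr.listen_after_bind(); });
    svr.wait_until_ready();

    // ...long model load happens here, while requests are answered with 503...
    ready.store(true);

    // ...main task loop runs; on shutdown:
    svr.stop();
    t.join();
    return 0;
}
```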
@@ -3437,12 +3395,8 @@ int main(int argc, char ** argv) {
     SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
 #endif
 
-    ctx_server.queue_tasks.start_loop();
-
-    svr->stop();
+    clean_up();
     t.join();
 
-    llama_backend_free();
-
     return 0;
 }