@fugood/llama.node 0.0.1-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204)
  1. package/CMakeLists.txt +85 -0
  2. package/README.md +56 -0
  3. package/bin/darwin/arm64/llama-node.node +0 -0
  4. package/bin/darwin/x64/llama-node.node +0 -0
  5. package/bin/linux/arm64/llama-node.node +0 -0
  6. package/bin/linux/x64/llama-node.node +0 -0
  7. package/bin/win32/arm64/llama-node.node +0 -0
  8. package/bin/win32/arm64/node.lib +0 -0
  9. package/bin/win32/x64/llama-node.node +0 -0
  10. package/bin/win32/x64/node.lib +0 -0
  11. package/lib/binding.js +13 -0
  12. package/lib/binding.ts +57 -0
  13. package/lib/index.js +24 -0
  14. package/lib/index.ts +13 -0
  15. package/package.json +65 -0
  16. package/src/addons.cpp +506 -0
  17. package/src/llama.cpp/CMakeLists.txt +1320 -0
  18. package/src/llama.cpp/build.zig +172 -0
  19. package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
  20. package/src/llama.cpp/common/CMakeLists.txt +87 -0
  21. package/src/llama.cpp/common/base64.hpp +392 -0
  22. package/src/llama.cpp/common/common.cpp +2949 -0
  23. package/src/llama.cpp/common/common.h +324 -0
  24. package/src/llama.cpp/common/console.cpp +501 -0
  25. package/src/llama.cpp/common/console.h +19 -0
  26. package/src/llama.cpp/common/grammar-parser.cpp +440 -0
  27. package/src/llama.cpp/common/grammar-parser.h +29 -0
  28. package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
  29. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
  30. package/src/llama.cpp/common/json.hpp +24766 -0
  31. package/src/llama.cpp/common/log.h +724 -0
  32. package/src/llama.cpp/common/ngram-cache.cpp +282 -0
  33. package/src/llama.cpp/common/ngram-cache.h +94 -0
  34. package/src/llama.cpp/common/sampling.cpp +353 -0
  35. package/src/llama.cpp/common/sampling.h +147 -0
  36. package/src/llama.cpp/common/stb_image.h +8396 -0
  37. package/src/llama.cpp/common/train.cpp +1513 -0
  38. package/src/llama.cpp/common/train.h +233 -0
  39. package/src/llama.cpp/examples/CMakeLists.txt +52 -0
  40. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
  41. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
  42. package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
  43. package/src/llama.cpp/examples/batched/batched.cpp +262 -0
  44. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
  45. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
  46. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
  47. package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
  48. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
  49. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
  50. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
  51. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
  52. package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
  54. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
  55. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
  56. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
  58. package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
  59. package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
  60. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
  61. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
  62. package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
  64. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
  65. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
  66. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
  67. package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
  68. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
  69. package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
  70. package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
  71. package/src/llama.cpp/examples/infill/infill.cpp +767 -0
  72. package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
  73. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
  74. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
  75. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
  76. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
  77. package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
  78. package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
  79. package/src/llama.cpp/examples/llava/clip.h +85 -0
  80. package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
  81. package/src/llama.cpp/examples/llava/llava.cpp +426 -0
  82. package/src/llama.cpp/examples/llava/llava.h +50 -0
  83. package/src/llama.cpp/examples/llava/requirements.txt +3 -0
  84. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
  85. package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
  86. package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
  87. package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
  88. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
  89. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
  90. package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
  91. package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
  92. package/src/llama.cpp/examples/main/main.cpp +957 -0
  93. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
  94. package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
  95. package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
  96. package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
  97. package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
  98. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
  99. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
  100. package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
  101. package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
  102. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
  103. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
  104. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
  106. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
  107. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
  108. package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
  109. package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
  110. package/src/llama.cpp/examples/server/httplib.h +9465 -0
  111. package/src/llama.cpp/examples/server/server.cpp +3826 -0
  112. package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
  113. package/src/llama.cpp/examples/server/utils.hpp +653 -0
  114. package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
  115. package/src/llama.cpp/examples/simple/simple.cpp +183 -0
  116. package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
  117. package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
  118. package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
  119. package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
  120. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
  121. package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
  122. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
  123. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
  124. package/src/llama.cpp/ggml-alloc.c +985 -0
  125. package/src/llama.cpp/ggml-alloc.h +76 -0
  126. package/src/llama.cpp/ggml-backend-impl.h +141 -0
  127. package/src/llama.cpp/ggml-backend.c +2099 -0
  128. package/src/llama.cpp/ggml-backend.h +233 -0
  129. package/src/llama.cpp/ggml-common.h +1853 -0
  130. package/src/llama.cpp/ggml-cuda.h +43 -0
  131. package/src/llama.cpp/ggml-impl.h +265 -0
  132. package/src/llama.cpp/ggml-kompute.cpp +2006 -0
  133. package/src/llama.cpp/ggml-kompute.h +46 -0
  134. package/src/llama.cpp/ggml-metal.h +66 -0
  135. package/src/llama.cpp/ggml-mpi.c +216 -0
  136. package/src/llama.cpp/ggml-mpi.h +39 -0
  137. package/src/llama.cpp/ggml-opencl.cpp +2301 -0
  138. package/src/llama.cpp/ggml-opencl.h +36 -0
  139. package/src/llama.cpp/ggml-quants.c +12678 -0
  140. package/src/llama.cpp/ggml-quants.h +133 -0
  141. package/src/llama.cpp/ggml-sycl.cpp +17882 -0
  142. package/src/llama.cpp/ggml-sycl.h +49 -0
  143. package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
  144. package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
  145. package/src/llama.cpp/ggml-vulkan.h +29 -0
  146. package/src/llama.cpp/ggml.c +21819 -0
  147. package/src/llama.cpp/ggml.h +2403 -0
  148. package/src/llama.cpp/llama.cpp +17468 -0
  149. package/src/llama.cpp/llama.h +1117 -0
  150. package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
  151. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
  152. package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
  153. package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
  154. package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
  155. package/src/llama.cpp/prompts/alpaca.txt +1 -0
  156. package/src/llama.cpp/prompts/assistant.txt +31 -0
  157. package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
  158. package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
  159. package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
  160. package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
  161. package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
  162. package/src/llama.cpp/prompts/chat.txt +28 -0
  163. package/src/llama.cpp/prompts/dan-modified.txt +1 -0
  164. package/src/llama.cpp/prompts/dan.txt +1 -0
  165. package/src/llama.cpp/prompts/mnemonics.txt +93 -0
  166. package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
  167. package/src/llama.cpp/prompts/reason-act.txt +18 -0
  168. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
  169. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
  170. package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
  171. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
  172. package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
  173. package/src/llama.cpp/requirements.txt +12 -0
  174. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
  175. package/src/llama.cpp/scripts/xxd.cmake +16 -0
  176. package/src/llama.cpp/sgemm.cpp +999 -0
  177. package/src/llama.cpp/sgemm.h +12 -0
  178. package/src/llama.cpp/tests/CMakeLists.txt +78 -0
  179. package/src/llama.cpp/tests/get-model.cpp +21 -0
  180. package/src/llama.cpp/tests/get-model.h +2 -0
  181. package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
  182. package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
  183. package/src/llama.cpp/tests/test-c.c +7 -0
  184. package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
  185. package/src/llama.cpp/tests/test-double-float.cpp +57 -0
  186. package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
  187. package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
  188. package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
  189. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
  190. package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
  191. package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
  192. package/src/llama.cpp/tests/test-opt.cpp +181 -0
  193. package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
  194. package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
  195. package/src/llama.cpp/tests/test-rope.cpp +221 -0
  196. package/src/llama.cpp/tests/test-sampling.cpp +301 -0
  197. package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
  198. package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
  199. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
  200. package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
  201. package/src/llama.cpp/unicode-data.cpp +1651 -0
  202. package/src/llama.cpp/unicode-data.h +16 -0
  203. package/src/llama.cpp/unicode.cpp +277 -0
  204. package/src/llama.cpp/unicode.h +28 -0
package/src/llama.cpp/examples/server/server.cpp
@@ -0,0 +1,3826 @@
1
+ #include "utils.hpp"
2
+
3
+ #include "common.h"
4
+ #include "json-schema-to-grammar.h"
5
+ #include "llama.h"
6
+ #include "grammar-parser.h"
7
+
8
+ #ifndef NDEBUG
9
+ // crash the server in debug mode, otherwise send an http 500 error
10
+ #define CPPHTTPLIB_NO_EXCEPTIONS 1
11
+ #endif
12
+ // increase max payload length to allow use of larger context size
13
+ #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
14
+ #include "httplib.h"
15
+ #include "json.hpp"
16
+
17
+ // auto generated files (update with ./deps.sh)
18
+ #include "index.html.hpp"
19
+ #include "index.js.hpp"
20
+ #include "completion.js.hpp"
21
+ #include "json-schema-to-grammar.mjs.hpp"
22
+
23
+ #include <atomic>
24
+ #include <chrono>
25
+ #include <condition_variable>
26
+ #include <cstddef>
27
+ #include <set>
28
+ #include <mutex>
29
+ #include <thread>
30
+ #include <signal.h>
31
+ #include <memory>
32
+
33
+ using json = nlohmann::ordered_json;
34
+
35
+ bool server_verbose = false;
36
+ bool server_log_json = true;
37
+
38
+ enum stop_type {
39
+ STOP_TYPE_FULL,
40
+ STOP_TYPE_PARTIAL,
41
+ };
42
+
43
+ enum slot_state {
44
+ SLOT_STATE_IDLE,
45
+ SLOT_STATE_PROCESSING,
46
+ };
47
+
48
+ enum slot_command {
49
+ SLOT_COMMAND_NONE,
50
+ SLOT_COMMAND_LOAD_PROMPT,
51
+ SLOT_COMMAND_RELEASE,
52
+ };
53
+
54
+ enum server_state {
55
+ SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
56
+ SERVER_STATE_READY, // Server is ready and model is loaded
57
+ SERVER_STATE_ERROR // An error occurred, load_model failed
58
+ };
59
+
60
+ enum server_task_type {
61
+ SERVER_TASK_TYPE_COMPLETION,
62
+ SERVER_TASK_TYPE_CANCEL,
63
+ SERVER_TASK_TYPE_NEXT_RESPONSE,
64
+ SERVER_TASK_TYPE_METRICS,
65
+ SERVER_TASK_TYPE_SLOT_SAVE,
66
+ SERVER_TASK_TYPE_SLOT_RESTORE,
67
+ SERVER_TASK_TYPE_SLOT_ERASE,
68
+ };
69
+
70
+ struct server_task {
71
+ int id = -1; // to be filled by server_queue
72
+ int id_multi = -1;
73
+ int id_target = -1;
74
+
75
+ server_task_type type;
76
+ json data;
77
+
78
+ bool infill = false;
79
+ bool embedding = false;
80
+ };
81
+
82
+ struct server_task_result {
83
+ int id = -1;
84
+ int id_multi = -1;
85
+
86
+ json data;
87
+
88
+ bool stop;
89
+ bool error;
90
+ };
91
+
92
+ struct server_task_multi {
93
+ int id = -1;
94
+
95
+ std::set<int> subtasks_remaining;
96
+ std::vector<server_task_result> results;
97
+ };
98
+
99
+ struct slot_params {
100
+ bool stream = true;
101
+ bool cache_prompt = false; // remember the prompt to avoid reprocessing the entire prompt
102
+
103
+ uint32_t seed = -1; // RNG seed
104
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
105
+ int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
106
+ int32_t n_predict = -1; // new tokens to predict
107
+
108
+ std::vector<std::string> antiprompt;
109
+
110
+ json input_prefix;
111
+ json input_suffix;
112
+ };
113
+
114
+ struct server_params {
115
+ int32_t port = 8080;
116
+ int32_t read_timeout = 600;
117
+ int32_t write_timeout = 600;
118
+ int32_t n_threads_http = -1;
119
+
120
+ std::string hostname = "127.0.0.1";
121
+ std::string public_path = "";
122
+ std::string chat_template = "";
123
+ std::string system_prompt = "";
124
+
125
+ std::vector<std::string> api_keys;
126
+
127
+ #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
128
+ std::string ssl_key_file = "";
129
+ std::string ssl_cert_file = "";
130
+ #endif
131
+
132
+ bool slots_endpoint = true;
133
+ bool metrics_endpoint = false;
134
+ std::string slot_save_path;
135
+ };
136
+
137
+ struct server_slot {
138
+ int id;
139
+ int id_task = -1;
140
+ int id_multi = -1;
141
+
142
+ struct slot_params params;
143
+
144
+ slot_state state = SLOT_STATE_IDLE;
145
+ slot_command command = SLOT_COMMAND_NONE;
146
+
147
+ // used to determine the slot that has been used the longest
148
+ int64_t t_last_used = -1;
149
+
150
+ // generation props
151
+ int32_t n_ctx = 0; // context size per slot
152
+ int32_t n_past = 0;
153
+ int32_t n_decoded = 0;
154
+ int32_t n_remaining = -1;
155
+ int32_t i_batch = -1;
156
+ int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
157
+
158
+ int32_t n_prompt_tokens = 0;
159
+ int32_t n_prompt_tokens_processed = 0;
160
+
161
+ json prompt;
162
+
163
+ // when a task is submitted, we first tokenize the prompt and store it here
164
+ std::vector<llama_token> prompt_tokens;
165
+
166
+ std::string generated_text;
167
+ std::vector<llama_token> cache_tokens;
168
+ std::vector<completion_token_output> generated_token_probs;
169
+
170
+ bool infill = false;
171
+ bool embedding = false;
172
+ bool has_next_token = true;
173
+ bool truncated = false;
174
+ bool stopped_eos = false;
175
+ bool stopped_word = false;
176
+ bool stopped_limit = false;
177
+
178
+ bool oaicompat = false;
179
+
180
+ std::string oaicompat_model;
181
+ std::string stopping_word;
182
+
183
+ // sampling
184
+ llama_token sampled;
185
+ struct llama_sampling_params sparams;
186
+ llama_sampling_context * ctx_sampling = nullptr;
187
+ json json_schema;
188
+
189
+ int32_t ga_i = 0; // group-attention state
190
+ int32_t ga_n = 1; // group-attention factor
191
+ int32_t ga_w = 512; // group-attention width
192
+
193
+ int32_t n_past_se = 0; // self-extend
194
+
195
+ // stats
196
+ size_t n_sent_text = 0; // number of sent text characters
197
+ size_t n_sent_token_probs = 0;
198
+
199
+ int64_t t_start_process_prompt;
200
+ int64_t t_start_generation;
201
+
202
+ double t_prompt_processing; // ms
203
+ double t_token_generation; // ms
204
+
205
+ void reset() {
206
+ n_prompt_tokens = 0;
207
+ generated_text = "";
208
+ truncated = false;
209
+ stopped_eos = false;
210
+ stopped_word = false;
211
+ stopped_limit = false;
212
+ stopping_word = "";
213
+ n_past = 0;
214
+ n_sent_text = 0;
215
+ n_sent_token_probs = 0;
216
+ infill = false;
217
+ ga_i = 0;
218
+ n_past_se = 0;
219
+
220
+ generated_token_probs.clear();
221
+ }
222
+
223
+ bool has_budget(gpt_params &global_params) {
224
+ if (params.n_predict == -1 && global_params.n_predict == -1) {
225
+ return true; // limitless
226
+ }
227
+
228
+ n_remaining = -1;
229
+
230
+ if (params.n_predict != -1) {
231
+ n_remaining = params.n_predict - n_decoded;
232
+ } else if (global_params.n_predict != -1) {
233
+ n_remaining = global_params.n_predict - n_decoded;
234
+ }
235
+
236
+ return n_remaining > 0; // true while there is still budget remaining
237
+ }
238
+
239
+ bool available() const {
240
+ return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
241
+ }
242
+
243
+ bool is_processing() const {
244
+ return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING;
245
+ }
246
+
247
+ void add_token_string(const completion_token_output & token) {
248
+ if (command == SLOT_COMMAND_RELEASE) {
249
+ return;
250
+ }
251
+ generated_token_probs.push_back(token);
252
+ }
253
+
254
+ void release() {
255
+ if (state == SLOT_STATE_PROCESSING) {
256
+ t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
257
+ command = SLOT_COMMAND_RELEASE;
258
+ }
259
+ }
260
+
261
+ json get_formated_timings() const {
262
+ return json {
263
+ {"prompt_n", n_prompt_tokens_processed},
264
+ {"prompt_ms", t_prompt_processing},
265
+ {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed},
266
+ {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed},
267
+
268
+ {"predicted_n", n_decoded},
269
+ {"predicted_ms", t_token_generation},
270
+ {"predicted_per_token_ms", t_token_generation / n_decoded},
271
+ {"predicted_per_second", 1e3 / t_token_generation * n_decoded},
272
+ };
273
+ }
274
+
275
+ size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) {
276
+ size_t stop_pos = std::string::npos;
277
+
278
+ for (const std::string & word : params.antiprompt) {
279
+ size_t pos;
280
+
281
+ if (type == STOP_TYPE_FULL) {
282
+ const size_t tmp = word.size() + last_token_size;
283
+ const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
284
+
285
+ pos = text.find(word, from_pos);
286
+ } else {
287
+ pos = find_partial_stop_string(word, text);
288
+ }
289
+
290
+ if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
291
+ if (type == STOP_TYPE_FULL) {
292
+ stopped_word = true;
293
+ stopping_word = word;
294
+ has_next_token = false;
295
+ }
296
+ stop_pos = pos;
297
+ }
298
+ }
299
+
300
+ return stop_pos;
301
+ }
302
+
303
+ void print_timings() const {
304
+ char buffer[512];
305
+
306
+ double t_token = t_prompt_processing / n_prompt_tokens_processed;
307
+ double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
308
+
309
+ snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
310
+ t_prompt_processing, n_prompt_tokens_processed,
311
+ t_token, n_tokens_second);
312
+
313
+ LOG_INFO(buffer, {
314
+ {"id_slot", id},
315
+ {"id_task", id_task},
316
+ {"t_prompt_processing", t_prompt_processing},
317
+ {"n_prompt_tokens_processed", n_prompt_tokens_processed},
318
+ {"t_token", t_token},
319
+ {"n_tokens_second", n_tokens_second},
320
+ });
321
+
322
+ t_token = t_token_generation / n_decoded;
323
+ n_tokens_second = 1e3 / t_token_generation * n_decoded;
324
+
325
+ snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
326
+ t_token_generation, n_decoded,
327
+ t_token, n_tokens_second);
328
+
329
+ LOG_INFO(buffer, {
330
+ {"id_slot", id},
331
+ {"id_task", id_task},
332
+ {"t_token_generation", t_token_generation},
333
+ {"n_decoded", n_decoded},
334
+ {"t_token", t_token},
335
+ {"n_tokens_second", n_tokens_second},
336
+ });
337
+
338
+ snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
339
+
340
+ LOG_INFO(buffer, {
341
+ {"id_slot", id},
342
+ {"id_task", id_task},
343
+ {"t_prompt_processing", t_prompt_processing},
344
+ {"t_token_generation", t_token_generation},
345
+ {"t_total", t_prompt_processing + t_token_generation},
346
+ });
347
+ }
348
+ };
349
+
350
+ struct server_metrics {
351
+ int64_t t_start = 0;
352
+
353
+ uint64_t n_prompt_tokens_processed_total = 0;
354
+ uint64_t t_prompt_processing_total = 0;
355
+ uint64_t n_tokens_predicted_total = 0;
356
+ uint64_t t_tokens_generation_total = 0;
357
+
358
+ uint64_t n_prompt_tokens_processed = 0;
359
+ uint64_t t_prompt_processing = 0;
360
+
361
+ uint64_t n_tokens_predicted = 0;
362
+ uint64_t t_tokens_generation = 0;
363
+
364
+ void init() {
365
+ t_start = ggml_time_us();
366
+ }
367
+
368
+ void on_prompt_eval(const server_slot & slot) {
369
+ n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed;
370
+ n_prompt_tokens_processed += slot.n_prompt_tokens_processed;
371
+ t_prompt_processing += slot.t_prompt_processing;
372
+ t_prompt_processing_total += slot.t_prompt_processing;
373
+ }
374
+
375
+ void on_prediction(const server_slot & slot) {
376
+ n_tokens_predicted_total += slot.n_decoded;
377
+ n_tokens_predicted += slot.n_decoded;
378
+ t_tokens_generation += slot.t_token_generation;
379
+ t_tokens_generation_total += slot.t_token_generation;
380
+ }
381
+
382
+ void reset_bucket() {
383
+ n_prompt_tokens_processed = 0;
384
+ t_prompt_processing = 0;
385
+ n_tokens_predicted = 0;
386
+ t_tokens_generation = 0;
387
+ }
388
+ };
389
+
390
+ struct server_queue {
391
+ int id = 0;
392
+ bool running;
393
+
394
+ // queues
395
+ std::vector<server_task> queue_tasks;
396
+ std::vector<server_task> queue_tasks_deferred;
397
+
398
+ std::vector<server_task_multi> queue_multitasks;
399
+
400
+ std::mutex mutex_tasks;
401
+ std::condition_variable condition_tasks;
402
+
403
+ // callback functions
404
+ std::function<void(server_task &)> callback_new_task;
405
+ std::function<void(server_task_multi &)> callback_finish_multitask;
406
+ std::function<void(void)> callback_update_slots;
407
+
408
+ // Add a new task to the end of the queue
409
+ int post(server_task task) {
410
+ std::unique_lock<std::mutex> lock(mutex_tasks);
411
+ if (task.id == -1) {
412
+ task.id = id++;
413
+ LOG_VERBOSE("new task id", {{"new_id", task.id}});
414
+ }
415
+ queue_tasks.push_back(std::move(task));
416
+ condition_tasks.notify_one();
417
+ return task.id;
418
+ }
419
+
420
+ // Add a new task, but defer until one slot is available
421
+ void defer(server_task task) {
422
+ std::unique_lock<std::mutex> lock(mutex_tasks);
423
+ queue_tasks_deferred.push_back(std::move(task));
424
+ }
425
+
426
+ // Get the next id for creating a new task
427
+ int get_new_id() {
428
+ std::unique_lock<std::mutex> lock(mutex_tasks);
429
+ int new_id = id++;
430
+ LOG_VERBOSE("new task id", {{"new_id", new_id}});
431
+ return new_id;
432
+ }
433
+
434
+ // Register function to process a new task
435
+ void on_new_task(std::function<void(server_task &)> callback) {
436
+ callback_new_task = std::move(callback);
437
+ }
438
+
439
+ // Register function to process a multitask when it is finished
440
+ void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
441
+ callback_finish_multitask = std::move(callback);
442
+ }
443
+
444
+ // Register the function to be called when all slots data is ready to be processed
445
+ void on_update_slots(std::function<void(void)> callback) {
446
+ callback_update_slots = std::move(callback);
447
+ }
448
+
449
+ // Call when the state of one slot is changed
450
+ void notify_slot_changed() {
451
+ // move deferred tasks back to main loop
452
+ std::unique_lock<std::mutex> lock(mutex_tasks);
453
+ for (auto & task : queue_tasks_deferred) {
454
+ queue_tasks.push_back(std::move(task));
455
+ }
456
+ queue_tasks_deferred.clear();
457
+ }
458
+
459
+ // end the start_loop routine
460
+ void terminate() {
461
+ std::unique_lock<std::mutex> lock(mutex_tasks);
462
+ running = false;
463
+ condition_tasks.notify_all();
464
+ }
465
+
466
+ /**
467
+ * Main loop consists of these steps:
468
+ * - Wait until a new task arrives
469
+ * - Process the task (i.e. maybe copy data into slot)
470
+ * - Check if multitask is finished
471
+ * - Update all slots
472
+ */
473
+ void start_loop() {
474
+ running = true;
475
+
476
+ while (true) {
477
+ LOG_VERBOSE("new task may arrive", {});
478
+
479
+ while (true) {
480
+ std::unique_lock<std::mutex> lock(mutex_tasks);
481
+ if (queue_tasks.empty()) {
482
+ lock.unlock();
483
+ break;
484
+ }
485
+ server_task task = queue_tasks.front();
486
+ queue_tasks.erase(queue_tasks.begin());
487
+ lock.unlock();
488
+ LOG_VERBOSE("callback_new_task", {{"id_task", task.id}});
489
+ callback_new_task(task);
490
+ }
491
+
492
+ LOG_VERBOSE("update_multitasks", {});
493
+
494
+ // check if we have any finished multitasks
495
+ auto queue_iterator = queue_multitasks.begin();
496
+ while (queue_iterator != queue_multitasks.end()) {
497
+ if (queue_iterator->subtasks_remaining.empty()) {
498
+ // all subtasks done == multitask is done
499
+ server_task_multi current_multitask = *queue_iterator;
500
+ callback_finish_multitask(current_multitask);
501
+ // remove this multitask
502
+ queue_iterator = queue_multitasks.erase(queue_iterator);
503
+ } else {
504
+ ++queue_iterator;
505
+ }
506
+ }
507
+
508
+ // all tasks in the current loop are processed, slots data is now ready
509
+ LOG_VERBOSE("callback_update_slots", {});
510
+
511
+ callback_update_slots();
512
+
513
+ LOG_VERBOSE("wait for new task", {});
514
+ {
515
+ std::unique_lock<std::mutex> lock(mutex_tasks);
516
+ if (queue_tasks.empty()) {
517
+ if (!running) {
518
+ LOG_VERBOSE("ending start_loop", {});
519
+ return;
520
+ }
521
+ condition_tasks.wait(lock, [&]{
522
+ return (!queue_tasks.empty() || !running);
523
+ });
524
+ }
525
+ }
526
+ }
527
+ }
528
+
529
+ //
530
+ // functions to manage multitasks
531
+ //
532
+
533
+ // add a multitask by specifying the ids of all subtasks (a subtask is a server_task)
534
+ void add_multitask(int id_multi, std::vector<int> & sub_ids) {
535
+ std::lock_guard<std::mutex> lock(mutex_tasks);
536
+ server_task_multi multi;
537
+ multi.id = id_multi;
538
+ std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
539
+ queue_multitasks.push_back(multi);
540
+ }
541
+
542
+ // update the remaining subtasks, while appending results to the multitask
543
+ void update_multitask(int id_multi, int id_sub, server_task_result & result) {
544
+ std::lock_guard<std::mutex> lock(mutex_tasks);
545
+ for (auto & multitask : queue_multitasks) {
546
+ if (multitask.id == id_multi) {
547
+ multitask.subtasks_remaining.erase(id_sub);
548
+ multitask.results.push_back(result);
549
+ }
550
+ }
551
+ }
552
+ };
553
+
554
+ struct server_response {
555
+ typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
556
+ callback_multitask_t callback_update_multitask;
557
+
558
+ // for keeping track of all tasks waiting for the result
559
+ std::set<int> waiting_task_ids;
560
+
561
+ // the main result queue
562
+ std::vector<server_task_result> queue_results;
563
+
564
+ std::mutex mutex_results;
565
+ std::condition_variable condition_results;
566
+
567
+ // add the id_task to the list of tasks waiting for response
568
+ void add_waiting_task_id(int id_task) {
569
+ LOG_VERBOSE("waiting for task id", {{"id_task", id_task}});
570
+
571
+ std::unique_lock<std::mutex> lock(mutex_results);
572
+ waiting_task_ids.insert(id_task);
573
+ }
574
+
575
+ // when the request is finished, we can remove the task associated with it
576
+ void remove_waiting_task_id(int id_task) {
577
+ LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}});
578
+
579
+ std::unique_lock<std::mutex> lock(mutex_results);
580
+ waiting_task_ids.erase(id_task);
581
+ }
582
+
583
+ // This function blocks the thread until there is a response for this id_task
584
+ server_task_result recv(int id_task) {
585
+ while (true) {
586
+ std::unique_lock<std::mutex> lock(mutex_results);
587
+ condition_results.wait(lock, [&]{
588
+ return !queue_results.empty();
589
+ });
590
+
591
+ for (int i = 0; i < (int) queue_results.size(); i++) {
592
+ if (queue_results[i].id == id_task) {
593
+ assert(queue_results[i].id_multi == -1);
594
+ server_task_result res = queue_results[i];
595
+ queue_results.erase(queue_results.begin() + i);
596
+ return res;
597
+ }
598
+ }
599
+ }
600
+
601
+ // should never reach here
602
+ }
603
+
604
+ // Register the function to update multitask
605
+ void on_multitask_update(callback_multitask_t callback) {
606
+ callback_update_multitask = std::move(callback);
607
+ }
608
+
609
+ // Send a new result to a waiting id_task
610
+ void send(server_task_result result) {
611
+ LOG_VERBOSE("send new result", {{"id_task", result.id}});
612
+
613
+ std::unique_lock<std::mutex> lock(mutex_results);
614
+ for (const auto & id_task : waiting_task_ids) {
615
+ // LOG_TEE("waiting task id %i \n", id_task);
616
+ // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
617
+ if (result.id_multi == id_task) {
618
+ LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}});
619
+ callback_update_multitask(id_task, result.id, result);
620
+ continue;
621
+ }
622
+
623
+ if (result.id == id_task) {
624
+ LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}});
625
+ queue_results.push_back(result);
626
+ condition_results.notify_all();
627
+ return;
628
+ }
629
+ }
630
+ }
631
+ };
632
+
633
+ struct server_context {
634
+ llama_model * model = nullptr;
635
+ llama_context * ctx = nullptr;
636
+
637
+ gpt_params params;
638
+
639
+ llama_batch batch;
640
+
641
+ bool clean_kv_cache = true;
642
+ bool add_bos_token = true;
643
+
644
+ int32_t n_ctx; // total context for all clients / slots
645
+
646
+ // system prompt
647
+ bool system_need_update = false;
648
+
649
+ std::string system_prompt;
650
+ std::vector<llama_token> system_tokens;
651
+
652
+ std::string name_user; // this should be the antiprompt
653
+ std::string name_assistant;
654
+
655
+ // slots / clients
656
+ std::vector<server_slot> slots;
657
+ json default_generation_settings_for_props;
658
+
659
+ server_queue queue_tasks;
660
+ server_response queue_results;
661
+
662
+ server_metrics metrics;
663
+
664
+ ~server_context() {
665
+ if (ctx) {
666
+ llama_free(ctx);
667
+ ctx = nullptr;
668
+ }
669
+
670
+ if (model) {
671
+ llama_free_model(model);
672
+ model = nullptr;
673
+ }
674
+ }
675
+
676
+ bool load_model(const gpt_params & params_) {
677
+ params = params_;
678
+
679
+ // dedicate one sequence to the system prompt
680
+ params.n_parallel += 1;
681
+
682
+ std::tie(model, ctx) = llama_init_from_gpt_params(params);
683
+ params.n_parallel -= 1; // but be sneaky about it
684
+ if (model == nullptr) {
685
+ LOG_ERROR("unable to load model", {{"model", params.model}});
686
+ return false;
687
+ }
688
+
689
+ n_ctx = llama_n_ctx(ctx);
690
+
691
+ add_bos_token = llama_should_add_bos_token(model);
692
+ GGML_ASSERT(llama_add_eos_token(model) != 1);
693
+
694
+ return true;
695
+ }
696
+
697
+ bool validate_model_chat_template() const {
698
+ llama_chat_message chat[] = {{"user", "test"}};
699
+
700
+ const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
701
+
702
+ return res > 0;
703
+ }
704
+
705
+ void init() {
706
+ const int32_t n_ctx_slot = n_ctx / params.n_parallel;
707
+
708
+ LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
709
+
710
+ for (int i = 0; i < params.n_parallel; i++) {
711
+ server_slot slot;
712
+
713
+ slot.id = i;
714
+ slot.n_ctx = n_ctx_slot;
715
+ slot.n_predict = params.n_predict;
716
+
717
+ LOG_INFO("new slot", {
718
+ {"id_slot", slot.id},
719
+ {"n_ctx_slot", slot.n_ctx}
720
+ });
721
+
722
+ const int ga_n = params.grp_attn_n;
723
+ const int ga_w = params.grp_attn_w;
724
+
725
+ if (ga_n != 1) {
726
+ GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
727
+ GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
728
+ //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
729
+ //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
730
+
731
+ LOG_INFO("slot self-extend", {
732
+ {"id_slot", slot.id},
733
+ {"ga_n", ga_n},
734
+ {"ga_w", ga_w}
735
+ });
736
+ }
737
+
738
+ slot.ga_i = 0;
739
+ slot.ga_n = ga_n;
740
+ slot.ga_w = ga_w;
741
+
742
+ slot.reset();
743
+
744
+ slots.push_back(slot);
745
+ }
746
+
747
+ default_generation_settings_for_props = get_formated_generation(slots.front());
748
+ default_generation_settings_for_props["seed"] = -1;
749
+
750
+ // the update_slots() logic will always submit a maximum of n_batch tokens
751
+ // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
752
+ {
753
+ const int32_t n_batch = llama_n_batch(ctx);
754
+
755
+ // only a single seq_id per token is needed
756
+ batch = llama_batch_init(n_batch, 0, 1);
757
+ }
758
+
759
+ metrics.init();
760
+ }
761
+
762
+ std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
763
+ // TODO: currently, we tokenize using special tokens by default
764
+ // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
765
+ // but it's better compared to completely ignoring ChatML and other chat templates
766
+ const bool TMP_FORCE_SPECIAL = true;
767
+
768
+ // If `add_special` is true, we only add BOS when json_prompt is a string,
769
+ // or the first element of the json_prompt array is a string.
770
+ std::vector<llama_token> prompt_tokens;
771
+
772
+ if (json_prompt.is_array()) {
773
+ bool first = true;
774
+ for (const auto & p : json_prompt) {
775
+ if (p.is_string()) {
776
+ auto s = p.template get<std::string>();
777
+
778
+ std::vector<llama_token> p;
779
+ if (first) {
780
+ p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
781
+ first = false;
782
+ } else {
783
+ p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
784
+ }
785
+
786
+ prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
787
+ } else {
788
+ if (first) {
789
+ first = false;
790
+ }
791
+
792
+ prompt_tokens.push_back(p.template get<llama_token>());
793
+ }
794
+ }
795
+ } else {
796
+ auto s = json_prompt.template get<std::string>();
797
+ prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
798
+ }
799
+
800
+ return prompt_tokens;
801
+ }
802
+
803
+ server_slot * get_slot(int id) {
804
+ int64_t t_last = ggml_time_us();
805
+
806
+ server_slot * last_used = nullptr;
807
+
808
+ for (server_slot & slot : slots) {
809
+ if (slot.id == id && slot.available()) {
810
+ return &slot;
811
+ }
812
+
813
+ // among all available slots, find the one that has been least recently used
814
+ if (slot.available() && slot.t_last_used < t_last) {
815
+ last_used = &slot;
816
+ t_last = slot.t_last_used;
817
+ }
818
+ }
819
+
820
+ return last_used;
821
+ }
822
+
823
+ bool launch_slot_with_task(server_slot & slot, const server_task & task) {
824
+ slot_params default_params;
825
+ llama_sampling_params default_sparams;
826
+ auto & data = task.data;
827
+
828
+ if (data.count("__oaicompat") != 0) {
829
+ slot.oaicompat = true;
830
+ slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
831
+ } else {
832
+ slot.oaicompat = false;
833
+ slot.oaicompat_model = "";
834
+ }
835
+
836
+ slot.params.stream = json_value(data, "stream", false);
837
+ slot.params.cache_prompt = json_value(data, "cache_prompt", false);
838
+ slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict);
839
+ slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
840
+ slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
841
+ slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
842
+ slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
843
+ slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
844
+ slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
845
+ slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
846
+ slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
847
+ slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
848
+ slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
849
+ slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
850
+ slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
851
+ slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
852
+ slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
853
+ slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
854
+ slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
855
+ slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
856
+ slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
857
+ slot.params.seed = json_value(data, "seed", default_params.seed);
858
+ slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
859
+ slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
860
+
861
+ // process "json_schema" and "grammar"
862
+ if (data.contains("json_schema") && !data["json_schema"].is_null() && data.contains("grammar") && !data["grammar"].is_null()) {
863
+ send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
864
+ return false;
865
+ } else if (data.contains("json_schema") && !data.contains("grammar")) {
866
+ try {
867
+ auto schema = json_value(data, "json_schema", json::object());
868
+ slot.sparams.grammar = json_schema_to_grammar(schema);
869
+ } catch (const std::exception & e) {
870
+ send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
871
+ return false;
872
+ }
873
+ } else {
874
+ slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
875
+ }
876
+
877
+ if (slot.params.cache_prompt && slot.ga_n != 1) {
878
+ LOG_WARNING("cache_prompt is not supported with group-attention", {});
879
+ slot.params.cache_prompt = false;
880
+ }
881
+
882
+ if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
883
+ // Might be better to reject the request with a 400 ?
884
+ LOG_WARNING("Max tokens to predict exceeds server configuration", {
885
+ {"params.n_predict", slot.params.n_predict},
886
+ {"slot.n_predict", slot.n_predict},
887
+ });
888
+ slot.params.n_predict = slot.n_predict;
889
+ }
890
+
891
+ // infill
892
+ slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
893
+ slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
894
+
895
+ // get prompt
896
+ {
897
+ const auto & prompt = data.find("prompt");
898
+ if (prompt == data.end()) {
899
+ send_error(task, "Either \"prompt\" or \"messages\" must be provided", ERROR_TYPE_INVALID_REQUEST);
900
+ return false;
901
+ } else {
902
+ slot.prompt = *prompt;
903
+ }
904
+ if (slot.prompt.is_array() && slot.prompt.size() == 0) {
905
+ send_error(task, "\"prompt\" cannot be an empty array", ERROR_TYPE_INVALID_REQUEST);
906
+ return false;
907
+ }
908
+ }
909
+
910
+ // penalize user-provided tokens
911
+ {
912
+ slot.sparams.penalty_prompt_tokens.clear();
913
+ slot.sparams.use_penalty_prompt_tokens = false;
914
+
915
+ const auto & penalty_prompt = data.find("penalty_prompt");
916
+
917
+ if (penalty_prompt != data.end()) {
918
+ if (penalty_prompt->is_string()) {
919
+ const auto penalty_prompt_string = penalty_prompt->get<std::string>();
920
+ slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
921
+
922
+ if (slot.params.n_predict > 0) {
923
+ slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
924
+ }
925
+ slot.sparams.use_penalty_prompt_tokens = true;
926
+
927
+ LOG_VERBOSE("penalty_prompt_tokens", {
928
+ {"id_slot", slot.id},
929
+ {"tokens", slot.sparams.penalty_prompt_tokens},
930
+ });
931
+ }
932
+ else if (penalty_prompt->is_array()) {
933
+ const auto n_tokens = penalty_prompt->size();
934
+ slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
935
+
936
+ const int n_vocab = llama_n_vocab(model);
937
+ for (const auto & penalty_token : *penalty_prompt) {
938
+ if (penalty_token.is_number_integer()) {
939
+ const auto tok = penalty_token.get<llama_token>();
940
+ if (tok >= 0 && tok < n_vocab) {
941
+ slot.sparams.penalty_prompt_tokens.push_back(tok);
942
+ }
943
+ }
944
+ }
945
+ slot.sparams.use_penalty_prompt_tokens = true;
946
+
947
+ LOG_VERBOSE("penalty_prompt_tokens", {
948
+ {"id_slot", slot.id},
949
+ {"tokens", slot.sparams.penalty_prompt_tokens},
950
+ });
951
+ }
952
+ }
953
+ }
954
+
955
+ {
956
+ slot.sparams.logit_bias.clear();
957
+
958
+ if (json_value(data, "ignore_eos", false)) {
959
+ slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
960
+ }
961
+
962
+ const auto & logit_bias = data.find("logit_bias");
963
+ if (logit_bias != data.end() && logit_bias->is_array()) {
964
+ const int n_vocab = llama_n_vocab(model);
965
+ for (const auto & el : *logit_bias) {
966
+ // TODO: we may want to throw errors here, in case "el" is incorrect
967
+ if (el.is_array() && el.size() == 2) {
968
+ float bias;
969
+ if (el[1].is_number()) {
970
+ bias = el[1].get<float>();
971
+ } else if (el[1].is_boolean() && !el[1].get<bool>()) {
972
+ bias = -INFINITY;
973
+ } else {
974
+ continue;
975
+ }
976
+
977
+ if (el[0].is_number_integer()) {
978
+ llama_token tok = el[0].get<llama_token>();
979
+ if (tok >= 0 && tok < n_vocab) {
980
+ slot.sparams.logit_bias[tok] = bias;
981
+ }
982
+ } else if (el[0].is_string()) {
983
+ auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
984
+ for (auto tok : toks) {
985
+ slot.sparams.logit_bias[tok] = bias;
986
+ }
987
+ }
988
+ }
989
+ }
990
+ }
991
+ }
992
+
993
+ {
994
+ slot.params.antiprompt.clear();
995
+
996
+ const auto & stop = data.find("stop");
997
+ if (stop != data.end() && stop->is_array()) {
998
+ for (const auto & word : *stop) {
999
+ if (!word.empty()) {
1000
+ slot.params.antiprompt.push_back(word);
1001
+ }
1002
+ }
1003
+ }
1004
+ }
1005
+
1006
+ {
1007
+ const auto & samplers_sequence = data.find("samplers");
1008
+ if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
1009
+ std::vector<std::string> sampler_names;
1010
+ for (const auto & sampler_name : *samplers_sequence) {
1011
+ if (sampler_name.is_string()) {
1012
+ sampler_names.emplace_back(sampler_name);
1013
+ }
1014
+ }
1015
+ slot.sparams.samplers_sequence = sampler_types_from_names(sampler_names, false);
1016
+ } else {
1017
+ slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
1018
+ }
1019
+ }
1020
+
1021
+ {
1022
+ if (slot.ctx_sampling != nullptr) {
1023
+ llama_sampling_free(slot.ctx_sampling);
1024
+ }
1025
+ slot.ctx_sampling = llama_sampling_init(slot.sparams);
1026
+ if (slot.ctx_sampling == nullptr) {
1027
+ // for now, the only error that may happen here is invalid grammar
1028
+ send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
1029
+ return false;
1030
+ }
1031
+ llama_set_rng_seed(ctx, slot.params.seed);
1032
+ }
1033
+
1034
+ slot.command = SLOT_COMMAND_LOAD_PROMPT;
1035
+ slot.prompt_tokens.clear();
1036
+
1037
+ LOG_INFO("slot is processing task", {
1038
+ {"id_slot", slot.id},
1039
+ {"id_task", slot.id_task},
1040
+ });
1041
+
1042
+ return true;
1043
+ }
1044
+
1045
+ void kv_cache_clear() {
1046
+ LOG_VERBOSE("clearing KV cache", {});
1047
+
1048
+ // clear the entire KV cache
1049
+ llama_kv_cache_clear(ctx);
1050
+ clean_kv_cache = false;
1051
+ }
1052
+
1053
+ void system_prompt_update() {
1054
+ LOG_VERBOSE("system prompt update", {
1055
+ {"system_prompt", system_prompt},
1056
+ });
1057
+
1058
+ kv_cache_clear();
1059
+ system_tokens.clear();
1060
+
1061
+ if (!system_prompt.empty()) {
1062
+ system_tokens = ::llama_tokenize(ctx, system_prompt, true);
1063
+
1064
+ llama_batch_clear(batch);
1065
+
1066
+ for (int i = 0; i < (int)system_tokens.size(); ++i) {
1067
+ llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
1068
+ }
1069
+
1070
+ const int32_t n_batch = llama_n_batch(ctx);
1071
+
1072
+ for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
1073
+ const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
1074
+ llama_batch batch_view = {
1075
+ n_tokens,
1076
+ batch.token + i,
1077
+ nullptr,
1078
+ batch.pos + i,
1079
+ batch.n_seq_id + i,
1080
+ batch.seq_id + i,
1081
+ batch.logits + i,
1082
+ 0, 0, 0, // unused
1083
+ };
1084
+
1085
+ if (llama_decode(ctx, batch_view) != 0) {
1086
+ LOG_ERROR("llama_decode() failed", {});
1087
+ return;
1088
+ }
1089
+ }
1090
+
1091
+ // assign the system KV cache to all parallel sequences
1092
+ for (int32_t i = 1; i <= params.n_parallel; ++i) {
1093
+ llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
1094
+ }
1095
+ }
1096
+
1097
+ system_need_update = false;
1098
+ }
1099
+
1100
+ void system_prompt_set(const json & sys_props) {
1101
+ system_prompt = sys_props.value("prompt", "");
1102
+ name_user = sys_props.value("anti_prompt", "");
1103
+ name_assistant = sys_props.value("assistant_name", "");
1104
+
1105
+ LOG_VERBOSE("system prompt process", {
1106
+ {"system_prompt", system_prompt},
1107
+ {"name_user", name_user},
1108
+ {"name_assistant", name_assistant},
1109
+ });
1110
+
1111
+ // release all slots
1112
+ for (server_slot & slot : slots) {
1113
+ slot.release();
1114
+ }
1115
+
1116
+ system_need_update = true;
1117
+ }
1118
+
1119
+ bool process_token(completion_token_output & result, server_slot & slot) {
1120
+ // remember which tokens were sampled - used for repetition penalties during sampling
1121
+ const std::string token_str = llama_token_to_piece(ctx, result.tok);
1122
+ slot.sampled = result.tok;
1123
+
1124
+ // search for stop words and delete them
1125
+ slot.generated_text += token_str;
1126
+ slot.has_next_token = true;
1127
+
1128
+ if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
1129
+ // we can change penalty_prompt_tokens because it is always created from scratch each request
1130
+ slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
1131
+ }
1132
+
1133
+ // check if there is an incomplete UTF-8 character at the end
1134
+ bool incomplete = false;
1135
+ for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
1136
+ unsigned char c = slot.generated_text[slot.generated_text.size() - i];
1137
+ if ((c & 0xC0) == 0x80) {
1138
+ // continuation byte: 10xxxxxx
1139
+ continue;
1140
+ }
1141
+ if ((c & 0xE0) == 0xC0) {
1142
+ // 2-byte character: 110xxxxx ...
1143
+ incomplete = i < 2;
1144
+ } else if ((c & 0xF0) == 0xE0) {
1145
+ // 3-byte character: 1110xxxx ...
1146
+ incomplete = i < 3;
1147
+ } else if ((c & 0xF8) == 0xF0) {
1148
+ // 4-byte character: 11110xxx ...
1149
+ incomplete = i < 4;
1150
+ }
1151
+ // else 1-byte character or invalid byte
1152
+ break;
1153
+ }
1154
+
1155
+ if (!incomplete) {
1156
+ size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
1157
+
1158
+ const std::string str_test = slot.generated_text.substr(pos);
1159
+ bool is_stop_full = false;
1160
+
1161
+ size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
1162
+ if (stop_pos != std::string::npos) {
1163
+ is_stop_full = true;
1164
+ slot.generated_text.erase(
1165
+ slot.generated_text.begin() + pos + stop_pos,
1166
+ slot.generated_text.end());
1167
+ pos = std::min(slot.n_sent_text, slot.generated_text.size());
1168
+ } else {
1169
+ is_stop_full = false;
1170
+ stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
1171
+ }
1172
+
1173
+ // check if there is any token to predict
1174
+ if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
1175
+ // do not send the stop word in the response
1176
+ result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
1177
+ slot.n_sent_text += result.text_to_send.size();
1178
+ // add the token to the slot queue and cache
1179
+ }
1180
+
1181
+ slot.add_token_string(result);
1182
+ if (slot.params.stream) {
1183
+ send_partial_response(slot, result);
1184
+ }
1185
+ }
1186
+
1187
+ if (incomplete) {
1188
+ slot.has_next_token = true;
1189
+ }
1190
+
1191
+ // check the limits
1192
+ if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
1193
+ slot.stopped_limit = true;
1194
+ slot.has_next_token = false;
1195
+
1196
+ LOG_VERBOSE("stopped by limit", {
1197
+ {"id_slot", slot.id},
1198
+ {"id_task", slot.id_task},
1199
+ {"n_decoded", slot.n_decoded},
1200
+ {"n_predict", slot.params.n_predict},
1201
+ });
1202
+ }
1203
+
1204
+ if (llama_token_is_eog(model, result.tok)) {
1205
+ slot.stopped_eos = true;
1206
+ slot.has_next_token = false;
1207
+
1208
+ LOG_VERBOSE("eos token found", {});
1209
+ }
1210
+
1211
+ LOG_VERBOSE("next token", {
1212
+ {"id_slot", slot.id},
1213
+ {"id_task", slot.id_task},
1214
+ {"token", result.tok},
1215
+ {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
1216
+ {"has_next_token", slot.has_next_token},
1217
+ {"n_remain", slot.n_remaining},
1218
+ {"n_decoded", slot.n_decoded},
1219
+ {"stopped_eos", slot.stopped_eos},
1220
+ {"stopped_word", slot.stopped_word},
1221
+ {"stopped_limit", slot.stopped_limit},
1222
+ {"stopping_word", slot.stopping_word},
1223
+ });
1224
+
1225
+ return slot.has_next_token; // continue
1226
+ }
1227
+
1228
+ json get_formated_generation(const server_slot & slot) const {
1229
+ const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
1230
+ const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
1231
+
1232
+ std::vector<std::string> samplers_sequence;
1233
+ samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
1234
+ for (const auto & sampler_type : slot.sparams.samplers_sequence) {
1235
+ samplers_sequence.emplace_back(sampler_type_to_name_string(sampler_type));
1236
+ }
1237
+
1238
+ return json {
1239
+ {"n_ctx", slot.n_ctx},
1240
+ {"n_predict", slot.n_predict},
1241
+ {"model", params.model_alias},
1242
+ {"seed", slot.params.seed},
1243
+ {"temperature", slot.sparams.temp},
1244
+ {"dynatemp_range", slot.sparams.dynatemp_range},
1245
+ {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
1246
+ {"top_k", slot.sparams.top_k},
1247
+ {"top_p", slot.sparams.top_p},
1248
+ {"min_p", slot.sparams.min_p},
1249
+ {"tfs_z", slot.sparams.tfs_z},
1250
+ {"typical_p", slot.sparams.typical_p},
1251
+ {"repeat_last_n", slot.sparams.penalty_last_n},
1252
+ {"repeat_penalty", slot.sparams.penalty_repeat},
1253
+ {"presence_penalty", slot.sparams.penalty_present},
1254
+ {"frequency_penalty", slot.sparams.penalty_freq},
1255
+ {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
1256
+ {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
1257
+ {"mirostat", slot.sparams.mirostat},
1258
+ {"mirostat_tau", slot.sparams.mirostat_tau},
1259
+ {"mirostat_eta", slot.sparams.mirostat_eta},
1260
+ {"penalize_nl", slot.sparams.penalize_nl},
1261
+ {"stop", slot.params.antiprompt},
1262
+ {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
1263
+ {"n_keep", slot.params.n_keep},
1264
+ {"n_discard", slot.params.n_discard},
1265
+ {"ignore_eos", ignore_eos},
1266
+ {"stream", slot.params.stream},
1267
+ {"logit_bias", slot.sparams.logit_bias},
1268
+ {"n_probs", slot.sparams.n_probs},
1269
+ {"min_keep", slot.sparams.min_keep},
1270
+ {"grammar", slot.sparams.grammar},
1271
+ {"samplers", samplers_sequence}
1272
+ };
1273
+ }
1274
+
1275
+ void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1276
+ send_error(task.id, task.id_multi, error, type);
1277
+ }
1278
+
1279
+ void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1280
+ send_error(slot.id_task, slot.id_multi, error, type);
1281
+ }
1282
+
1283
+ void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1284
+ LOG_ERROR("task error", {
1285
+ {"id_multi", id_multi},
1286
+ {"id_task", id_task},
1287
+ {"error", error},
1288
+ });
1289
+
1290
+ server_task_result res;
1291
+ res.id = id_task;
1292
+ res.id_multi = id_multi;
1293
+ res.stop = false;
1294
+ res.error = true;
1295
+ res.data = format_error_response(error, type);
1296
+
1297
+ queue_results.send(res);
1298
+ }
1299
+
1300
+ void send_partial_response(server_slot & slot, completion_token_output tkn) {
1301
+ server_task_result res;
1302
+ res.id = slot.id_task;
1303
+ res.id_multi = slot.id_multi;
1304
+ res.error = false;
1305
+ res.stop = false;
1306
+ res.data = json {
1307
+ {"content", tkn.text_to_send},
1308
+ {"stop", false},
1309
+ {"id_slot", slot.id},
1310
+ {"multimodal", false}
1311
+ };
1312
+
1313
+ if (slot.sparams.n_probs > 0) {
1314
+ const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
1315
+ const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
1316
+ const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
1317
+
1318
+ std::vector<completion_token_output> probs_output;
1319
+ if (probs_pos < probs_stop_pos) {
1320
+ probs_output = std::vector<completion_token_output>(
1321
+ slot.generated_token_probs.begin() + probs_pos,
1322
+ slot.generated_token_probs.begin() + probs_stop_pos);
1323
+ }
1324
+ slot.n_sent_token_probs = probs_stop_pos;
1325
+
1326
+ res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
1327
+ }
1328
+
1329
+ if (slot.oaicompat) {
1330
+ res.data["oaicompat_token_ctr"] = slot.n_decoded;
1331
+ res.data["model"] = slot.oaicompat_model;
1332
+ }
1333
+
1334
+ queue_results.send(res);
1335
+ }
1336
+
1337
+ void send_final_response(const server_slot & slot) {
1338
+ server_task_result res;
1339
+ res.id = slot.id_task;
1340
+ res.id_multi = slot.id_multi;
1341
+ res.error = false;
1342
+ res.stop = true;
1343
+ res.data = json {
1344
+ {"content", !slot.params.stream ? slot.generated_text : ""},
1345
+ {"id_slot", slot.id},
1346
+ {"stop", true},
1347
+ {"model", params.model_alias},
1348
+ {"tokens_predicted", slot.n_decoded},
1349
+ {"tokens_evaluated", slot.n_prompt_tokens},
1350
+ {"generation_settings", get_formated_generation(slot)},
1351
+ {"prompt", slot.prompt},
1352
+ {"truncated", slot.truncated},
1353
+ {"stopped_eos", slot.stopped_eos},
1354
+ {"stopped_word", slot.stopped_word},
1355
+ {"stopped_limit", slot.stopped_limit},
1356
+ {"stopping_word", slot.stopping_word},
1357
+ {"tokens_cached", slot.n_past},
1358
+ {"timings", slot.get_formated_timings()}
1359
+ };
1360
+
1361
+ if (slot.sparams.n_probs > 0) {
1362
+ std::vector<completion_token_output> probs;
1363
+ if (!slot.params.stream && slot.stopped_word) {
1364
+ const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
1365
+
1366
+ probs = std::vector<completion_token_output>(
1367
+ slot.generated_token_probs.begin(),
1368
+ slot.generated_token_probs.end() - stop_word_toks.size());
1369
+ } else {
1370
+ probs = std::vector<completion_token_output>(
1371
+ slot.generated_token_probs.begin(),
1372
+ slot.generated_token_probs.end());
1373
+ }
1374
+
1375
+ res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
1376
+ }
1377
+
1378
+ if (slot.oaicompat) {
1379
+ res.data["oaicompat_token_ctr"] = slot.n_decoded;
1380
+ res.data["model"] = slot.oaicompat_model;
1381
+ }
1382
+
1383
+ queue_results.send(res);
1384
+ }
1385
+
1386
+ void send_embedding(const server_slot & slot, const llama_batch & batch) {
1387
+ server_task_result res;
1388
+ res.id = slot.id_task;
1389
+ res.id_multi = slot.id_multi;
1390
+ res.error = false;
1391
+ res.stop = true;
1392
+
1393
+ const int n_embd = llama_n_embd(model);
1394
+
1395
+ std::vector<float> embd_res(n_embd, 0.0f);
1396
+
1397
+ for (int i = 0; i < batch.n_tokens; ++i) {
1398
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
1399
+ continue;
1400
+ }
1401
+
1402
+ const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
1403
+ if (embd == NULL) {
1404
+ embd = llama_get_embeddings_ith(ctx, i);
1405
+ }
1406
+
1407
+ if (embd == NULL) {
1408
+ LOG_ERROR("failed to get embeddings", {
1409
+ {"token", batch.token [i]},
1410
+ {"seq_id", batch.seq_id[i][0]}
1411
+ });
1412
+
1413
+ res.data = json {
1414
+ {"embedding", std::vector<float>(n_embd, 0.0f)},
1415
+ };
1416
+
1417
+ continue;
1418
+ }
1419
+
1420
+ llama_embd_normalize(embd, embd_res.data(), n_embd);
1421
+
1422
+ res.data = json {
1423
+ {"embedding", embd_res},
1424
+ };
1425
+ }
1426
+
1427
+ queue_results.send(res);
1428
+ }
1429
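Note: the embedding path above falls back to per-token embeddings when a pooled sequence embedding is unavailable, then normalizes the result before sending it. As a minimal sketch only (assuming llama_embd_normalize performs plain L2 normalization, which this diff does not itself confirm):

    // illustrative sketch, not part of the package: L2-normalize an embedding vector
    #include <cmath>
    #include <vector>

    std::vector<float> l2_normalize(const std::vector<float> & v) {
        double sum = 0.0;
        for (float x : v) sum += (double) x * x;                      // accumulate squared magnitude
        const float norm = sum > 0.0 ? (float) std::sqrt(sum) : 1.0f; // avoid division by zero
        std::vector<float> out(v.size());
        for (size_t i = 0; i < v.size(); ++i) out[i] = v[i] / norm;
        return out;                                                   // e.g. {3, 4} -> {0.6, 0.8}
    }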
+
1430
+ void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) {
1431
+ server_task task;
1432
+ task.id = id_task;
1433
+ task.id_multi = id_multi;
1434
+ task.id_target = 0;
1435
+ task.data = std::move(data);
1436
+ task.infill = infill;
1437
+ task.embedding = embedding;
1438
+ task.type = SERVER_TASK_TYPE_COMPLETION;
1439
+
1440
+ // when a completion task's prompt array is not a singleton, we split it into multiple requests
1441
+ // otherwise, it's a single-prompt task and we queue it directly
1442
+ // if there are numbers in the prompt array, it will be treated as an array of tokens
1443
+ if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
1444
+ bool numbers = false;
1445
+ for (const auto & e : task.data.at("prompt")) {
1446
+ if (e.is_number()) {
1447
+ numbers = true;
1448
+ break;
1449
+ }
1450
+ }
1451
+
1452
+ // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
1453
+ // it will completely stall the server. The cause of this bug is not yet known.
1454
+ //
1455
+ // if there are numbers, it needs to be treated like a single prompt,
1456
+ // queue_tasks handles a mix of strings and numbers just fine.
1457
+ if (numbers) {
1458
+ queue_tasks.post(task);
1459
+ } else {
1460
+ split_multiprompt_task(id_task, task);
1461
+ }
1462
+ } else {
1463
+ queue_tasks.post(task);
1464
+ }
1465
+ }
1466
+
1467
+ void request_cancel(int id_task) {
1468
+ server_task task;
1469
+ task.type = SERVER_TASK_TYPE_CANCEL;
1470
+ task.id_target = id_task;
1471
+
1472
+ queue_tasks.post(task);
1473
+ }
1474
+
1475
+ void split_multiprompt_task(int id_multi, const server_task & multiprompt_task) {
1476
+ const int prompt_count = multiprompt_task.data.at("prompt").size();
1477
+ if (prompt_count <= 1) {
1478
+ send_error(multiprompt_task, "error while handling multiple prompts");
1479
+ return;
1480
+ }
1481
+
1482
+ // generate the IDs for all subtasks
1483
+ std::vector<int> subtask_ids(prompt_count);
1484
+ for (int i = 0; i < prompt_count; i++) {
1485
+ subtask_ids[i] = queue_tasks.get_new_id();
1486
+ }
1487
+
1488
+ // queue up the multitask so we can track its subtask progression
1489
+ queue_tasks.add_multitask(id_multi, subtask_ids);
1490
+
1491
+ // add subtasks
1492
+ for (int i = 0; i < prompt_count; i++) {
1493
+ json subtask_data = multiprompt_task.data;
1494
+ subtask_data["prompt"] = subtask_data["prompt"][i];
1495
+
1496
+ // subtasks inherit everything else (infill mode, embedding mode, etc.)
1497
+ request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding);
1498
+ }
1499
+ }
1500
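Note: as a concrete example of the fan-out above, a request whose "prompt" field is the array ["a", "b", "c"] produces one multitask plus three subtasks: each subtask gets a fresh id from queue_tasks.get_new_id(), a copy of the remaining request fields (including the infill and embedding flags) and exactly one prompt, while the multitask id is later used by on_finish_multitask() to merge the three results into a single {"results": [...]} payload.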
+
1501
+ void process_single_task(const server_task & task) {
1502
+ switch (task.type) {
1503
+ case SERVER_TASK_TYPE_COMPLETION:
1504
+ {
1505
+ server_slot * slot = get_slot(json_value(task.data, "id_slot", -1));
1506
+ if (slot == nullptr) {
1507
+ // if no slot is available, we defer this task for processing later
1508
+ LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
1509
+ queue_tasks.defer(task);
1510
+ break;
1511
+ }
1512
+
1513
+ if (task.data.contains("system_prompt")) {
1514
+ system_prompt_set(task.data["system_prompt"]);
1515
+
1516
+ for (server_slot & slot : slots) {
1517
+ slot.n_past = 0;
1518
+ slot.n_past_se = 0;
1519
+ }
1520
+ }
1521
+
1522
+ slot->reset();
1523
+
1524
+ slot->id_task = task.id;
1525
+ slot->id_multi = task.id_multi;
1526
+ slot->infill = task.infill;
1527
+ slot->embedding = task.embedding;
1528
+
1529
+ if (!launch_slot_with_task(*slot, task)) {
1530
+ LOG_ERROR("error while launching slot", task.data);
1531
+ break;
1532
+ }
1533
+ } break;
1534
+ case SERVER_TASK_TYPE_CANCEL:
1535
+ {
1536
+ // release slot linked with the task id
1537
+ for (auto & slot : slots) {
1538
+ if (slot.id_task == task.id_target) {
1539
+ slot.release();
1540
+ break;
1541
+ }
1542
+ }
1543
+ } break;
1544
+ case SERVER_TASK_TYPE_NEXT_RESPONSE:
1545
+ {
1546
+ // do nothing
1547
+ } break;
1548
+ case SERVER_TASK_TYPE_METRICS:
1549
+ {
1550
+ json slots_data = json::array();
1551
+
1552
+ int n_idle_slots = 0;
1553
+ int n_processing_slots = 0;
1554
+
1555
+ for (server_slot & slot : slots) {
1556
+ json slot_data = get_formated_generation(slot);
1557
+ slot_data["id"] = slot.id;
1558
+ slot_data["id_task"] = slot.id_task;
1559
+ slot_data["state"] = slot.state;
1560
+ slot_data["prompt"] = slot.prompt;
1561
+ slot_data["next_token"] = {
1562
+ {"has_next_token", slot.has_next_token},
1563
+ {"n_remain", slot.n_remaining},
1564
+ {"n_decoded", slot.n_decoded},
1565
+ {"stopped_eos", slot.stopped_eos},
1566
+ {"stopped_word", slot.stopped_word},
1567
+ {"stopped_limit", slot.stopped_limit},
1568
+ {"stopping_word", slot.stopping_word},
1569
+ };
1570
+
1571
+ if (slot_data["state"] == SLOT_STATE_IDLE) {
1572
+ n_idle_slots++;
1573
+ } else {
1574
+ n_processing_slots++;
1575
+ }
1576
+
1577
+ slots_data.push_back(slot_data);
1578
+ }
1579
+ LOG_INFO("slot data", {
1580
+ {"id_task", task.id},
1581
+ {"n_idle_slots", n_idle_slots},
1582
+ {"n_processing_slots", n_processing_slots}
1583
+ });
1584
+
1585
+ LOG_VERBOSE("slot data", {
1586
+ {"id_task", task.id},
1587
+ {"n_idle_slots", n_idle_slots},
1588
+ {"n_processing_slots", n_processing_slots},
1589
+ {"slots", slots_data}
1590
+ });
1591
+
1592
+ server_task_result res;
1593
+ res.id = task.id;
1594
+ res.id_multi = task.id_multi;
1595
+ res.stop = true;
1596
+ res.error = false;
1597
+ res.data = {
1598
+ { "idle", n_idle_slots },
1599
+ { "processing", n_processing_slots },
1600
+ { "deferred", queue_tasks.queue_tasks_deferred.size() },
1601
+ { "t_start", metrics.t_start},
1602
+
1603
+ { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
1604
+ { "t_tokens_generation_total", metrics.t_tokens_generation_total},
1605
+ { "n_tokens_predicted_total", metrics.n_tokens_predicted_total},
1606
+ { "t_prompt_processing_total", metrics.t_prompt_processing_total},
1607
+
1608
+ { "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed},
1609
+ { "t_prompt_processing", metrics.t_prompt_processing},
1610
+ { "n_tokens_predicted", metrics.n_tokens_predicted},
1611
+ { "t_tokens_generation", metrics.t_tokens_generation},
1612
+
1613
+ { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
1614
+ { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
1615
+
1616
+ { "slots", slots_data },
1617
+ };
1618
+
1619
+ if (json_value(task.data, "reset_bucket", false)) {
1620
+ metrics.reset_bucket();
1621
+ }
1622
+ queue_results.send(res);
1623
+ } break;
1624
+ case SERVER_TASK_TYPE_SLOT_SAVE:
1625
+ {
1626
+ int id_slot = task.data["id_slot"];
1627
+ server_slot * slot = get_slot(id_slot);
1628
+ if (slot == nullptr) {
1629
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1630
+ break;
1631
+ }
1632
+
1633
+ const size_t token_count = slot->cache_tokens.size();
1634
+ const int64_t t_start = ggml_time_us();
1635
+
1636
+ std::string filename = task.data["filename"];
1637
+ std::string filepath = task.data["filepath"];
1638
+
1639
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
1640
+
1641
+ const int64_t t_end = ggml_time_us();
1642
+ const double t_save_ms = (t_end - t_start) / 1000.0;
1643
+
1644
+ server_task_result result;
1645
+ result.id = task.id;
1646
+ result.stop = true;
1647
+ result.error = false;
1648
+ result.data = json {
1649
+ { "id_slot", id_slot },
1650
+ { "filename", filename },
1651
+ { "n_saved", token_count }, // tokens saved
1652
+ { "n_written", nwrite }, // bytes written
1653
+ { "timings", {
1654
+ { "save_ms", t_save_ms }
1655
+ } }
1656
+ };
1657
+ queue_results.send(result);
1658
+ } break;
1659
+ case SERVER_TASK_TYPE_SLOT_RESTORE:
1660
+ {
1661
+ int id_slot = task.data["id_slot"];
1662
+ server_slot * slot = get_slot(id_slot);
1663
+ if (slot == nullptr) {
1664
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1665
+ break;
1666
+ }
1667
+
1668
+ const int64_t t_start = ggml_time_us();
1669
+
1670
+ std::string filename = task.data["filename"];
1671
+ std::string filepath = task.data["filepath"];
1672
+
1673
+ slot->cache_tokens.resize(slot->n_ctx);
1674
+ size_t token_count = 0;
1675
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
1676
+ if (nread == 0) {
1677
+ slot->cache_tokens.resize(0);
1678
+ send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
1679
+ break;
1680
+ }
1681
+ slot->cache_tokens.resize(token_count);
1682
+
1683
+ const int64_t t_end = ggml_time_us();
1684
+ const double t_restore_ms = (t_end - t_start) / 1000.0;
1685
+
1686
+ server_task_result result;
1687
+ result.id = task.id;
1688
+ result.stop = true;
1689
+ result.error = false;
1690
+ result.data = json {
1691
+ { "id_slot", id_slot },
1692
+ { "filename", filename },
1693
+ { "n_restored", token_count }, // tokens restored
1694
+ { "n_read", nread }, // bytes read
1695
+ { "timings", {
1696
+ { "restore_ms", t_restore_ms }
1697
+ } }
1698
+ };
1699
+ queue_results.send(result);
1700
+ } break;
1701
+ case SERVER_TASK_TYPE_SLOT_ERASE:
1702
+ {
1703
+ int id_slot = task.data["id_slot"];
1704
+ server_slot * slot = get_slot(id_slot);
1705
+ if (slot == nullptr) {
1706
+ send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1707
+ break;
1708
+ }
1709
+
1710
+ // Erase token cache
1711
+ const size_t n_erased = slot->cache_tokens.size();
1712
+ llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
1713
+ slot->cache_tokens.clear();
1714
+
1715
+ server_task_result result;
1716
+ result.id = task.id;
1717
+ result.stop = true;
1718
+ result.error = false;
1719
+ result.data = json {
1720
+ { "id_slot", id_slot },
1721
+ { "n_erased", n_erased }
1722
+ };
1723
+ queue_results.send(result);
1724
+ } break;
1725
+ }
1726
+ }
1727
+
1728
+ void on_finish_multitask(const server_task_multi & multitask) {
1729
+ // all subtasks done == multitask is done
1730
+ server_task_result result;
1731
+ result.id = multitask.id;
1732
+ result.stop = true;
1733
+ result.error = false;
1734
+
1735
+ // collect json results into one json result
1736
+ std::vector<json> result_jsons;
1737
+ for (const auto & subres : multitask.results) {
1738
+ result_jsons.push_back(subres.data);
1739
+ result.error = result.error && subres.error;
1740
+ }
1741
+ result.data = json {
1742
+ { "results", result_jsons }
1743
+ };
1744
+
1745
+ queue_results.send(result);
1746
+ }
1747
+
1748
+ void update_slots() {
1749
+ if (system_need_update) {
1750
+ system_prompt_update();
1751
+ }
1752
+
1753
+ // release slots
1754
+ for (auto & slot : slots) {
1755
+ if (slot.command == SLOT_COMMAND_RELEASE) {
1756
+ slot.state = SLOT_STATE_IDLE;
1757
+ slot.command = SLOT_COMMAND_NONE;
1758
+ slot.t_last_used = ggml_time_us();
1759
+
1760
+ LOG_INFO("slot released", {
1761
+ {"id_slot", slot.id},
1762
+ {"id_task", slot.id_task},
1763
+ {"n_ctx", n_ctx},
1764
+ {"n_past", slot.n_past},
1765
+ {"n_system_tokens", system_tokens.size()},
1766
+ {"n_cache_tokens", slot.cache_tokens.size()},
1767
+ {"truncated", slot.truncated}
1768
+ });
1769
+
1770
+ queue_tasks.notify_slot_changed();
1771
+ }
1772
+ }
1773
+
1774
+ // check if all slots are idle
1775
+ {
1776
+ bool all_idle = true;
1777
+
1778
+ for (auto & slot : slots) {
1779
+ if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
1780
+ all_idle = false;
1781
+ break;
1782
+ }
1783
+ }
1784
+
1785
+ if (all_idle) {
1786
+ LOG_INFO("all slots are idle", {});
1787
+ if (system_prompt.empty() && clean_kv_cache) {
1788
+ kv_cache_clear();
1789
+ }
1790
+
1791
+ return;
1792
+ }
1793
+ }
1794
+
1795
+ {
1796
+ LOG_VERBOSE("posting NEXT_RESPONSE", {});
1797
+
1798
+ server_task task;
1799
+ task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
1800
+ task.id_target = -1;
1801
+
1802
+ queue_tasks.post(task);
1803
+ }
1804
+
1805
+ // apply context-shift if needed
1806
+ // TODO: simplify and improve
1807
+ for (server_slot & slot : slots) {
1808
+ if (slot.ga_n == 1) {
1809
+ if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
1810
+ // Shift context
1811
+ const int n_keep = slot.params.n_keep + add_bos_token;
1812
+ const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
1813
+ const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
1814
+
1815
+ LOG_INFO("slot context shift", {
1816
+ {"id_slot", slot.id},
1817
+ {"id_task", slot.id_task},
1818
+ {"n_keep", n_keep},
1819
+ {"n_left", n_left},
1820
+ {"n_discard", n_discard},
1821
+ {"n_ctx", n_ctx},
1822
+ {"n_past", slot.n_past},
1823
+ {"n_system_tokens", system_tokens.size()},
1824
+ {"n_cache_tokens", slot.cache_tokens.size()}
1825
+ });
1826
+
1827
+ llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
1828
+ llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
1829
+
1830
+ if (slot.params.cache_prompt) {
1831
+ for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
1832
+ slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
1833
+ }
1834
+
1835
+ slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
1836
+ }
1837
+
1838
+ slot.n_past -= n_discard;
1839
+
1840
+ slot.truncated = true;
1841
+ }
1842
+ }
1843
+ }
1844
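Note: a worked example of the context-shift arithmetic above, assuming no system tokens and n_discard left at its default of 0 (auto):

    // illustrative numbers only, not part of the package
    int n_keep    = 256 + 1;           // slot.params.n_keep plus the BOS token
    int n_past    = 2047;              // context is full (n_ctx - 1)
    int n_left    = n_past - n_keep;   // 1790
    int n_discard = n_left / 2;        // 895

KV positions [n_keep, n_keep + n_discard) are removed, the remaining entries are shifted back by n_discard, and n_past drops to 2047 - 895 = 1152, freeing roughly half of the non-kept context for further generation.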
+
1845
+ // start populating the batch for this iteration
1846
+ llama_batch_clear(batch);
1847
+
1848
+ // first, add sampled tokens from any ongoing sequences
1849
+ for (auto & slot : slots) {
1850
+ if (slot.state == SLOT_STATE_IDLE) {
1851
+ continue;
1852
+ }
1853
+
1854
+ slot.i_batch = batch.n_tokens;
1855
+
1856
+ const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
1857
+
1858
+ // TODO: we always have to take into account the "system_tokens"
1859
+ // this is not great and needs to be improved somehow
1860
+ llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
1861
+
1862
+ slot.n_past += 1;
1863
+
1864
+ if (slot.params.cache_prompt) {
1865
+ slot.cache_tokens.push_back(slot.sampled);
1866
+ }
1867
+
1868
+ LOG_VERBOSE("slot decode token", {
1869
+ {"id_slot", slot.id},
1870
+ {"id_task", slot.id_task},
1871
+ {"n_ctx", n_ctx},
1872
+ {"n_past", slot.n_past},
1873
+ {"n_system_tokens", system_tokens.size()},
1874
+ {"n_cache_tokens", slot.cache_tokens.size()},
1875
+ {"truncated", slot.truncated}
1876
+ });
1877
+ }
1878
+
1879
+ // process in chunks of params.n_batch
1880
+ int32_t n_batch = llama_n_batch(ctx);
1881
+ int32_t n_ubatch = llama_n_ubatch(ctx);
1882
+
1883
+ // next, batch any pending prompts without exceeding n_batch
1884
+ if (params.cont_batching || batch.n_tokens == 0) {
1885
+ for (auto & slot : slots) {
1886
+ // this slot still has a prompt to be processed
1887
+ if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
1888
+ auto & prompt_tokens = slot.prompt_tokens;
1889
+
1890
+ // we haven't tokenized the prompt yet - do it now:
1891
+ if (prompt_tokens.empty()) {
1892
+ LOG_VERBOSE("tokenizing prompt", {
1893
+ {"id_slot", slot.id},
1894
+ {"id_task", slot.id_task}
1895
+ });
1896
+
1897
+ slot.t_start_process_prompt = ggml_time_us();
1898
+ slot.t_start_generation = 0;
1899
+
1900
+ if (slot.infill) {
1901
+ bool suff_rm_leading_spc = true;
1902
+ if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
1903
+ params.input_suffix.erase(0, 1);
1904
+ suff_rm_leading_spc = false;
1905
+ }
1906
+
1907
+ auto prefix_tokens = tokenize(slot.params.input_prefix, false);
1908
+ auto suffix_tokens = tokenize(slot.params.input_suffix, false);
1909
+
1910
+ const int space_token = 29871; // TODO: this should not be hardcoded
1911
+ if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
1912
+ suffix_tokens.erase(suffix_tokens.begin());
1913
+ }
1914
+
1915
+ prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
1916
+ prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
1917
+ prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
1918
+ prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
1919
+ prefix_tokens.push_back(llama_token_middle(model));
1920
+ prompt_tokens = prefix_tokens;
1921
+ } else {
1922
+ prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS only if there is no system prompt
1923
+ }
1924
+
1925
+ slot.n_past = 0;
1926
+ slot.n_prompt_tokens = prompt_tokens.size();
1927
+
1928
+ LOG_VERBOSE("prompt tokenized", {
1929
+ {"id_slot", slot.id},
1930
+ {"id_task", slot.id_task},
1931
+ {"n_ctx", slot.n_ctx},
1932
+ {"n_keep", slot.params.n_keep},
1933
+ {"n_prompt_tokens", slot.n_prompt_tokens},
1934
+ {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
1935
+ });
1936
+
1937
+ // empty prompt passed -> release the slot and send empty response
1938
+ if (prompt_tokens.empty()) {
1939
+ LOG_INFO("empty prompt - releasing slot", {
1940
+ {"id_slot", slot.id},
1941
+ {"id_task", slot.id_task}
1942
+ });
1943
+
1944
+ slot.state = SLOT_STATE_PROCESSING;
1945
+ slot.command = SLOT_COMMAND_NONE;
1946
+ slot.release();
1947
+ slot.print_timings();
1948
+ send_final_response(slot);
1949
+ continue;
1950
+ }
1951
+
1952
+ if (slot.embedding) {
1953
+ // this prompt is too large to process - discard it
1954
+ if (slot.n_prompt_tokens > n_ubatch) {
1955
+ slot.state = SLOT_STATE_PROCESSING;
1956
+ slot.command = SLOT_COMMAND_NONE;
1957
+ slot.release();
1958
+ slot.print_timings();
1959
+ send_final_response(slot);
1960
+ continue;
1961
+ }
1962
+ } else {
1963
+ if (slot.params.n_keep < 0) {
1964
+ slot.params.n_keep = slot.n_prompt_tokens;
1965
+ }
1966
+ slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
1967
+
1968
+ // if input prompt is too big, truncate it (if group attention self-extend is disabled)
1969
+ if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
1970
+ const int n_left = slot.n_ctx - slot.params.n_keep;
1971
+
1972
+ const int n_block_size = n_left / 2;
1973
+ const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
1974
+
1975
+ std::vector<llama_token> new_tokens(
1976
+ prompt_tokens.begin(),
1977
+ prompt_tokens.begin() + slot.params.n_keep);
1978
+
1979
+ new_tokens.insert(
1980
+ new_tokens.end(),
1981
+ prompt_tokens.begin() + slot.params.n_keep + erased_blocks * n_block_size,
1982
+ prompt_tokens.end());
1983
+
1984
+ prompt_tokens = std::move(new_tokens);
1985
+
1986
+ slot.truncated = true;
1987
+ slot.n_prompt_tokens = prompt_tokens.size();
1988
+
1989
+ LOG_VERBOSE("input truncated", {
1990
+ {"id_slot", slot.id},
1991
+ {"id_task", slot.id_task},
1992
+ {"n_ctx", slot.n_ctx},
1993
+ {"n_keep", slot.params.n_keep},
1994
+ {"n_left", n_left},
1995
+ {"n_prompt_tokens", slot.n_prompt_tokens},
1996
+ {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
1997
+ });
1998
+
1999
+ GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
2000
+ }
2001
+
2002
+ llama_sampling_reset(slot.ctx_sampling);
2003
+
2004
+ if (!slot.params.cache_prompt) {
2005
+ slot.n_past_se = 0;
2006
+ slot.ga_i = 0;
2007
+ } else {
2008
+ GGML_ASSERT(slot.ga_n == 1);
2009
+
2010
+ // reuse any previously computed tokens that are common with the new prompt
2011
+ slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
2012
+
2013
+ // push the prompt into the sampling context (do not apply grammar)
2014
+ for (int i = 0; i < slot.n_past; ++i) {
2015
+ llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
2016
+ }
2017
+ }
2018
+ }
2019
+
2020
+ if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
2021
+ // we have to evaluate at least 1 token to generate logits.
2022
+ LOG_INFO("we have to evaluate at least 1 token to generate logits", {
2023
+ { "id_slot", slot.id },
2024
+ { "id_task", slot.id_task }
2025
+ });
2026
+
2027
+ slot.n_past--;
2028
+ if (slot.ga_i > 0) {
2029
+ slot.n_past_se--;
2030
+ }
2031
+ }
2032
+
2033
+ slot.n_prompt_tokens_processed = 0;
2034
+ }
2035
+
2036
+ if (slot.embedding) {
2037
+ // cannot fit the prompt in the current batch - will try next iter
2038
+ if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
2039
+ continue;
2040
+ }
2041
+ }
2042
+
2043
+ // keep only the common part
2044
+ int p0 = (int) system_tokens.size() + slot.n_past;
2045
+ if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
2046
+ // could not partially delete (likely using a non-Transformer model)
2047
+ llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
2048
+
2049
+ p0 = (int) system_tokens.size();
2050
+ if (p0 != 0) {
2051
+ // copy over the system prompt when there is one
2052
+ llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
2053
+ }
2054
+
2055
+ // there is no common part left (except for the system prompt)
2056
+ slot.n_past = 0;
2057
+ slot.n_past_se = 0;
2058
+ slot.ga_i = 0;
2059
+ // TODO: is the system prompt ever in the sampling context?
2060
+ llama_sampling_reset(slot.ctx_sampling);
2061
+ }
2062
+
2063
+ // remove the non-common part from the cache
2064
+ slot.cache_tokens.resize(slot.n_past);
2065
+
2066
+ LOG_INFO("kv cache rm [p0, end)", {
2067
+ { "id_slot", slot.id },
2068
+ { "id_task", slot.id_task },
2069
+ { "p0", p0 }
2070
+ });
2071
+
2072
+ int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
2073
+
2074
+ int32_t ga_i = slot.ga_i;
2075
+ int32_t ga_n = slot.ga_n;
2076
+ int32_t ga_w = slot.ga_w;
2077
+
2078
+ // add prompt tokens for processing in the current batch
2079
+ // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
2080
+ for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
2081
+ if (slot.ga_n != 1) {
2082
+ while (slot_npast >= ga_i + ga_w) {
2083
+ const int bd = (ga_w/ga_n)*(ga_n - 1);
2084
+ slot_npast -= bd;
2085
+ ga_i += ga_w/ga_n;
2086
+ }
2087
+ }
2088
+
2089
+ llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
2090
+
2091
+ if (slot.params.cache_prompt) {
2092
+ slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
2093
+ }
2094
+
2095
+ slot.n_prompt_tokens_processed++;
2096
+ slot_npast++;
2097
+ }
2098
+
2099
+ LOG_VERBOSE("prompt processing progress", {
2100
+ {"id_slot", slot.id},
2101
+ {"n_past", slot.n_past},
2102
+ {"n_ctx", n_ctx},
2103
+ {"n_tokens", batch.n_tokens},
2104
+ {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
2105
+ });
2106
+
2107
+ // entire prompt has been processed - start decoding new tokens
2108
+ if (slot.n_past == slot.n_prompt_tokens) {
2109
+ slot.state = SLOT_STATE_PROCESSING;
2110
+ slot.command = SLOT_COMMAND_NONE;
2111
+
2112
+ GGML_ASSERT(batch.n_tokens > 0);
2113
+
2114
+ // extract the logits only for the last token
2115
+ batch.logits[batch.n_tokens - 1] = true;
2116
+
2117
+ slot.n_decoded = 0;
2118
+ slot.i_batch = batch.n_tokens - 1;
2119
+
2120
+ LOG_VERBOSE("prompt done", {
2121
+ {"id_slot", slot.id},
2122
+ {"n_past", slot.n_past},
2123
+ {"n_ctx", n_ctx},
2124
+ {"n_tokens", batch.n_tokens},
2125
+ });
2126
+ }
2127
+ }
2128
+
2129
+ if (batch.n_tokens >= n_batch) {
2130
+ break;
2131
+ }
2132
+ }
2133
+ }
2134
+
2135
+ if (batch.n_tokens == 0) {
2136
+ LOG_VERBOSE("no tokens to decode", {});
2137
+ return;
2138
+ }
2139
+
2140
+ LOG_VERBOSE("decoding batch", {
2141
+ {"n_tokens", batch.n_tokens},
2142
+ });
2143
+
2144
+ // process the created batch of tokens
2145
+ for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
2146
+ const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
2147
+
2148
+ for (auto & slot : slots) {
2149
+ if (slot.ga_n != 1) {
2150
+ // context extension via Self-Extend
2151
+ // TODO: simplify and/or abstract this
2152
+ while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
2153
+ const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
2154
+ const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
2155
+ const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
2156
+
2157
+ LOG_TEE("\n");
2158
+ LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
2159
+ LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
2160
+ LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
2161
+
2162
+ llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
2163
+ llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
2164
+ llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
2165
+
2166
+ slot.n_past_se -= bd;
2167
+
2168
+ slot.ga_i += slot.ga_w / slot.ga_n;
2169
+
2170
+ LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
2171
+ }
2172
+
2173
+ slot.n_past_se += n_tokens;
2174
+ }
2175
+ }
2176
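Note: a worked pass through the Self-Extend loop above, assuming ga_n = 4, ga_w = 512, ga_i = 0 and n_past_se = 600: ib = (4*0)/512 = 0, bd = (512/4)*(4-1) = 384 and dd = 128 - 0 - 512 = -384. The cache span [0, 600) is first shifted by ib*bd = 0, positions [0, 512) are divided by 4 (compressed into [0, 128)), and [512, 600) is shifted by -384 into [128, 216); n_past_se then becomes 600 - 384 = 216 and ga_i becomes 128, so the loop condition 216 >= 128 + 512 fails and the loop exits.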
+
2177
+ llama_batch batch_view = {
2178
+ n_tokens,
2179
+ batch.token + i,
2180
+ nullptr,
2181
+ batch.pos + i,
2182
+ batch.n_seq_id + i,
2183
+ batch.seq_id + i,
2184
+ batch.logits + i,
2185
+ 0, 0, 0, // unused
2186
+ };
2187
+
2188
+ const int ret = llama_decode(ctx, batch_view);
2189
+
2190
+ if (ret != 0) {
2191
+ if (n_batch == 1 || ret < 0) {
2192
+ // if you get here, it means the KV cache is full - try increasing it via the context size
2193
+ LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
2194
+ {"i", i},
2195
+ {"n_batch", ret},
2196
+ {"ret", ret},
2197
+ });
2198
+ for (auto & slot : slots) {
2199
+ slot.state = SLOT_STATE_PROCESSING;
2200
+ slot.command = SLOT_COMMAND_NONE;
2201
+ slot.release();
2202
+ send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
2203
+ }
2204
+ break; // break loop of n_batch
2205
+ }
2206
+
2207
+ // retry with half the batch size to try to find a free slot in the KV cache
2208
+ n_batch /= 2;
2209
+ i -= n_batch;
2210
+
2211
+ LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
2212
+ {"i", i},
2213
+ {"n_batch", n_batch},
2214
+ {"ret", ret},
2215
+ });
2216
+
2217
+ continue; // continue loop of n_batch
2218
+ }
2219
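Note: when llama_decode() reports a recoverable failure (ret > 0, typically no free KV-cache slot for the batch), the code above halves n_batch and subtracts the new n_batch from i; the loop's own i += n_batch increment then restores the original offset, so the same span of tokens is simply retried in chunks half the size. Only when the chunk size has shrunk to a single token (or ret < 0) does the server give up and release every processing slot with an error.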
+
2220
+ for (auto & slot : slots) {
2221
+ if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
2222
+ continue; // continue loop of slots
2223
+ }
2224
+
2225
+ // prompt evaluated for embedding
2226
+ if (slot.embedding) {
2227
+ send_embedding(slot, batch_view);
2228
+ slot.release();
2229
+ slot.i_batch = -1;
2230
+ continue; // continue loop of slots
2231
+ }
2232
+
2233
+ completion_token_output result;
2234
+ const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
2235
+
2236
+ llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
2237
+
2238
+ slot.n_decoded += 1;
2239
+ if (slot.n_decoded == 1) {
2240
+ slot.t_start_generation = ggml_time_us();
2241
+ slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
2242
+ metrics.on_prompt_eval(slot);
2243
+ }
2244
+
2245
+ llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
2246
+ result.tok = id;
2247
+
2248
+ const int32_t n_probs = slot.sparams.n_probs;
2249
+ if (slot.sparams.temp <= 0 && n_probs > 0) {
2250
+ // for llama_sample_token_greedy we need to sort candidates
2251
+ llama_sample_softmax(ctx, &cur_p);
2252
+ }
2253
+
2254
+ for (size_t i = 0; i < std::min(cur_p.size, (size_t) n_probs); ++i) {
2255
+ result.probs.push_back({
2256
+ cur_p.data[i].id,
2257
+ cur_p.data[i].p
2258
+ });
2259
+ }
2260
+
2261
+ if (!process_token(result, slot)) {
2262
+ slot.release();
2263
+ slot.print_timings();
2264
+ send_final_response(slot);
2265
+ metrics.on_prediction(slot);
2266
+ }
2267
+
2268
+ slot.i_batch = -1;
2269
+ }
2270
+ }
2271
+
2272
+ LOG_VERBOSE("run slots completed", {});
2273
+ }
2274
+
2275
+ json model_meta() const {
2276
+ return json {
2277
+ {"vocab_type", llama_vocab_type (model)},
2278
+ {"n_vocab", llama_n_vocab (model)},
2279
+ {"n_ctx_train", llama_n_ctx_train (model)},
2280
+ {"n_embd", llama_n_embd (model)},
2281
+ {"n_params", llama_model_n_params(model)},
2282
+ {"size", llama_model_size (model)},
2283
+ };
2284
+ }
2285
+ };
2286
+
2287
+ static void server_print_usage(const char * argv0, const gpt_params & params, const server_params & sparams) {
2288
+ printf("usage: %s [options]\n", argv0);
2289
+ printf("\n");
2290
+ printf("options:\n");
2291
+ printf(" -h, --help show this help message and exit\n");
2292
+ printf(" -v, --verbose verbose output (default: %s)\n", server_verbose ? "enabled" : "disabled");
2293
+ printf(" -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
2294
+ printf(" -tb N, --threads-batch N number of threads to use during batch and prompt processing (default: same as --threads)\n");
2295
+ printf(" --threads-http N number of threads in the http server pool to process requests (default: max(hardware concurrency - 1, --parallel N + 2))\n");
2296
+ printf(" -c N, --ctx-size N size of the prompt context (default: %d)\n", params.n_ctx);
2297
+ printf(" --rope-scaling {none,linear,yarn}\n");
2298
+ printf(" RoPE frequency scaling method, defaults to linear unless specified by the model\n");
2299
+ printf(" --rope-freq-base N RoPE base frequency (default: loaded from model)\n");
2300
+ printf(" --rope-freq-scale N RoPE frequency scaling factor, expands context by a factor of 1/N\n");
2301
+ printf(" --yarn-ext-factor N YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
2302
+ printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
2303
+ printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
2304
+ printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
2305
+ printf(" --pooling {none,mean,cls} pooling type for embeddings, use model default if unspecified\n");
2306
+ printf(" -dt N, --defrag-thold N\n");
2307
+ printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
2308
+ printf(" -b N, --batch-size N logical maximum batch size (default: %d)\n", params.n_batch);
2309
+ printf(" -ub N, --ubatch-size N physical maximum batch size (default: %d)\n", params.n_ubatch);
2310
+ if (llama_supports_mlock()) {
2311
+ printf(" --mlock force system to keep model in RAM rather than swapping or compressing\n");
2312
+ }
2313
+ if (llama_supports_mmap()) {
2314
+ printf(" --no-mmap do not memory-map model (slower load but may reduce pageouts if not using mlock)\n");
2315
+ }
2316
+ printf(" --numa TYPE attempt optimizations that help on some NUMA systems\n");
2317
+ printf(" - distribute: spread execution evenly over all nodes\n");
2318
+ printf(" - isolate: only spawn threads on CPUs on the node that execution started on\n");
2319
+ printf(" - numactl: use the CPU map provided my numactl\n");
2320
+ if (llama_supports_gpu_offload()) {
2321
+ printf(" -ngl N, --n-gpu-layers N\n");
2322
+ printf(" number of layers to store in VRAM\n");
2323
+ printf(" -sm SPLIT_MODE, --split-mode SPLIT_MODE\n");
2324
+ printf(" how to split the model across multiple GPUs, one of:\n");
2325
+ printf(" - none: use one GPU only\n");
2326
+ printf(" - layer (default): split layers and KV across GPUs\n");
2327
+ printf(" - row: split rows across GPUs\n");
2328
+ printf(" -ts SPLIT --tensor-split SPLIT\n");
2329
+ printf(" fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1\n");
2330
+ printf(" -mg i, --main-gpu i the GPU to use for the model (with split-mode = none),\n");
2331
+ printf(" or for intermediate results and KV (with split-mode = row)\n");
2332
+ printf(" -nkvo, --no-kv-offload\n");
2333
+ printf(" disable KV offload\n");
2334
+ }
2335
+ printf(" -m FNAME, --model FNAME\n");
2336
+ printf(" model path (default: %s)\n", params.model.c_str());
2337
+ printf(" -mu MODEL_URL, --model-url MODEL_URL\n");
2338
+ printf(" model download url (default: unused)\n");
2339
+ printf(" -hfr REPO, --hf-repo REPO\n");
2340
+ printf(" Hugging Face model repository (default: unused)\n");
2341
+ printf(" -hff FILE, --hf-file FILE\n");
2342
+ printf(" Hugging Face model file (default: unused)\n");
2343
+ printf(" -a ALIAS, --alias ALIAS\n");
2344
+ printf(" set an alias for the model, will be added as `model` field in completion response\n");
2345
+ printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
2346
+ printf(" --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
2347
+ printf(" --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
2348
+ printf(" --port PORT port to listen (default (default: %d)\n", sparams.port);
2349
+ printf(" --path PUBLIC_PATH path from which to serve static files (default: disabled)\n");
2350
+ printf(" --api-key API_KEY optional api key to enhance server security. If set, requests must include this key for access.\n");
2351
+ printf(" --api-key-file FNAME path to file containing api keys delimited by new lines. If set, requests must include one of the keys for access.\n");
2352
+ #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2353
+ printf(" --ssl-key-file FNAME path to file a PEM-encoded SSL private key\n");
2354
+ printf(" --ssl-cert-file FNAME path to file a PEM-encoded SSL certificate\n");
2355
+ #endif
2356
+ printf(" -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
2357
+ printf(" --embeddings enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
2358
+ printf(" -np N, --parallel N number of slots for process requests (default: %d)\n", params.n_parallel);
2359
+ printf(" -cb, --cont-batching enable continuous batching (a.k.a dynamic batching) (default: enabled)\n");
2360
+ printf(" -spf FNAME, --system-prompt-file FNAME\n");
2361
+ printf(" set a file to load a system prompt (initial prompt of all slots), this is useful for chat applications.\n");
2362
+ printf(" -ctk TYPE, --cache-type-k TYPE\n");
2363
+ printf(" KV cache data type for K (default: f16)\n");
2364
+ printf(" -ctv TYPE, --cache-type-v TYPE\n");
2365
+ printf(" KV cache data type for V (default: f16)\n");
2366
+ printf(" --log-format log output format: json or text (default: json)\n");
2367
+ printf(" --log-disable disables logging to a file.\n");
2368
+ printf(" --slots-endpoint-disable disables slots monitoring endpoint.\n");
2369
+ printf(" --metrics enable prometheus compatible metrics endpoint (default: %s).\n", sparams.metrics_endpoint ? "enabled" : "disabled");
2370
+ printf(" --slot-save-path PATH path to save slot kv cache (default: disabled)\n");
2371
+ printf("\n");
2372
+ printf(" -n, --n-predict maximum tokens to predict (default: %d)\n", params.n_predict);
2373
+ printf(" --override-kv KEY=TYPE:VALUE\n");
2374
+ printf(" advanced option to override model metadata by key. may be specified multiple times.\n");
2375
+ printf(" types: int, float, bool. example: --override-kv tokenizer.ggml.add_bos_token=bool:false\n");
2376
+ printf(" -gan N, --grp-attn-n N set the group attention factor to extend context size through self-extend(default: 1=disabled), used together with group attention width `--grp-attn-w`\n");
2377
+ printf(" -gaw N, --grp-attn-w N set the group attention width to extend context size through self-extend(default: 512), used together with group attention factor `--grp-attn-n`\n");
2378
+ printf(" --chat-template JINJA_TEMPLATE\n");
2379
+ printf(" set custom jinja chat template (default: template taken from model's metadata)\n");
2380
+ printf(" only commonly used templates are accepted:\n");
2381
+ printf(" https://github.com/ggerganov/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template\n");
2382
+ printf("\n");
2383
+ }
2384
+
2385
+ static void server_params_parse(int argc, char ** argv, server_params & sparams, gpt_params & params) {
2386
+ gpt_params default_params;
2387
+ server_params default_sparams;
2388
+
2389
+ std::string arg;
2390
+ bool invalid_param = false;
2391
+
2392
+ for (int i = 1; i < argc; i++) {
2393
+ arg = argv[i];
2394
+ if (arg == "--port") {
2395
+ if (++i >= argc) {
2396
+ invalid_param = true;
2397
+ break;
2398
+ }
2399
+ sparams.port = std::stoi(argv[i]);
2400
+ } else if (arg == "--host") {
2401
+ if (++i >= argc) {
2402
+ invalid_param = true;
2403
+ break;
2404
+ }
2405
+ sparams.hostname = argv[i];
2406
+ } else if (arg == "--path") {
2407
+ if (++i >= argc) {
2408
+ invalid_param = true;
2409
+ break;
2410
+ }
2411
+ sparams.public_path = argv[i];
2412
+ } else if (arg == "--api-key") {
2413
+ if (++i >= argc) {
2414
+ invalid_param = true;
2415
+ break;
2416
+ }
2417
+ sparams.api_keys.push_back(argv[i]);
2418
+ } else if (arg == "--api-key-file") {
2419
+ if (++i >= argc) {
2420
+ invalid_param = true;
2421
+ break;
2422
+ }
2423
+ std::ifstream key_file(argv[i]);
2424
+ if (!key_file) {
2425
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
2426
+ invalid_param = true;
2427
+ break;
2428
+ }
2429
+ std::string key;
2430
+ while (std::getline(key_file, key)) {
2431
+ if (key.size() > 0) {
2432
+ sparams.api_keys.push_back(key);
2433
+ }
2434
+ }
2435
+ key_file.close();
2436
+
2437
+ }
2438
+ #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2439
+ else if (arg == "--ssl-key-file") {
2440
+ if (++i >= argc) {
2441
+ invalid_param = true;
2442
+ break;
2443
+ }
2444
+ sparams.ssl_key_file = argv[i];
2445
+ } else if (arg == "--ssl-cert-file") {
2446
+ if (++i >= argc) {
2447
+ invalid_param = true;
2448
+ break;
2449
+ }
2450
+ sparams.ssl_cert_file = argv[i];
2451
+ }
2452
+ #endif
2453
+ else if (arg == "--timeout" || arg == "-to") {
2454
+ if (++i >= argc) {
2455
+ invalid_param = true;
2456
+ break;
2457
+ }
2458
+ sparams.read_timeout = std::stoi(argv[i]);
2459
+ sparams.write_timeout = std::stoi(argv[i]);
2460
+ } else if (arg == "-m" || arg == "--model") {
2461
+ if (++i >= argc) {
2462
+ invalid_param = true;
2463
+ break;
2464
+ }
2465
+ params.model = argv[i];
2466
+ } else if (arg == "-mu" || arg == "--model-url") {
2467
+ if (++i >= argc) {
2468
+ invalid_param = true;
2469
+ break;
2470
+ }
2471
+ params.model_url = argv[i];
2472
+ } else if (arg == "-hfr" || arg == "--hf-repo") {
2473
+ if (++i >= argc) {
2474
+ invalid_param = true;
2475
+ break;
2476
+ }
2477
+ params.hf_repo = argv[i];
2478
+ } else if (arg == "-hff" || arg == "--hf-file") {
2479
+ if (++i >= argc) {
2480
+ invalid_param = true;
2481
+ break;
2482
+ }
2483
+ params.hf_file = argv[i];
2484
+ } else if (arg == "-a" || arg == "--alias") {
2485
+ if (++i >= argc) {
2486
+ invalid_param = true;
2487
+ break;
2488
+ }
2489
+ params.model_alias = argv[i];
2490
+ } else if (arg == "-h" || arg == "--help") {
2491
+ server_print_usage(argv[0], default_params, default_sparams);
2492
+ exit(0);
2493
+ } else if (arg == "-c" || arg == "--ctx-size" || arg == "--ctx_size") {
2494
+ if (++i >= argc) {
2495
+ invalid_param = true;
2496
+ break;
2497
+ }
2498
+ params.n_ctx = std::stoi(argv[i]);
2499
+ } else if (arg == "--rope-scaling") {
2500
+ if (++i >= argc) {
2501
+ invalid_param = true;
2502
+ break;
2503
+ }
2504
+ std::string value(argv[i]);
2505
+ /**/ if (value == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
2506
+ else if (value == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }
2507
+ else if (value == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
2508
+ else { invalid_param = true; break; }
2509
+ } else if (arg == "--rope-freq-base") {
2510
+ if (++i >= argc) {
2511
+ invalid_param = true;
2512
+ break;
2513
+ }
2514
+ params.rope_freq_base = std::stof(argv[i]);
2515
+ } else if (arg == "--rope-freq-scale") {
2516
+ if (++i >= argc) {
2517
+ invalid_param = true;
2518
+ break;
2519
+ }
2520
+ params.rope_freq_scale = std::stof(argv[i]);
2521
+ } else if (arg == "--yarn-ext-factor") {
2522
+ if (++i >= argc) {
2523
+ invalid_param = true;
2524
+ break;
2525
+ }
2526
+ params.yarn_ext_factor = std::stof(argv[i]);
2527
+ }
2528
+ else if (arg == "--yarn-attn-factor") {
2529
+ if (++i >= argc) {
2530
+ invalid_param = true;
2531
+ break;
2532
+ }
2533
+ params.yarn_attn_factor = std::stof(argv[i]);
2534
+ } else if (arg == "--yarn-beta-fast") {
2535
+ if (++i >= argc) {
2536
+ invalid_param = true;
2537
+ break;
2538
+ }
2539
+ params.yarn_beta_fast = std::stof(argv[i]);
2540
+ } else if (arg == "--yarn-beta-slow") {
2541
+ if (++i >= argc) {
2542
+ invalid_param = true;
2543
+ break;
2544
+ }
2545
+ params.yarn_beta_slow = std::stof(argv[i]);
2546
+ } else if (arg == "--pooling") {
2547
+ if (++i >= argc) {
2548
+ invalid_param = true;
2549
+ break;
2550
+ }
2551
+ std::string value(argv[i]);
2552
+ /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
2553
+ else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
2554
+ else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
2555
+ else { invalid_param = true; break; }
2556
+ } else if (arg == "--defrag-thold" || arg == "-dt") {
2557
+ if (++i >= argc) {
2558
+ invalid_param = true;
2559
+ break;
2560
+ }
2561
+ params.defrag_thold = std::stof(argv[i]);
2562
+ } else if (arg == "--threads" || arg == "-t") {
2563
+ if (++i >= argc)
2564
+ {
2565
+ invalid_param = true;
2566
+ break;
2567
+ }
2568
+ params.n_threads = std::stoi(argv[i]);
2569
+ } else if (arg == "--grp-attn-n" || arg == "-gan") {
2570
+ if (++i >= argc) {
2571
+ invalid_param = true;
2572
+ break;
2573
+ }
2574
+
2575
+ params.grp_attn_n = std::stoi(argv[i]);
2576
+ } else if (arg == "--grp-attn-w" || arg == "-gaw") {
2577
+ if (++i >= argc) {
2578
+ invalid_param = true;
2579
+ break;
2580
+ }
2581
+
2582
+ params.grp_attn_w = std::stoi(argv[i]);
2583
+ } else if (arg == "--threads-batch" || arg == "-tb") {
2584
+ if (++i >= argc) {
2585
+ invalid_param = true;
2586
+ break;
2587
+ }
2588
+ params.n_threads_batch = std::stoi(argv[i]);
2589
+ } else if (arg == "--threads-http") {
2590
+ if (++i >= argc) {
2591
+ invalid_param = true;
2592
+ break;
2593
+ }
2594
+ sparams.n_threads_http = std::stoi(argv[i]);
2595
+ } else if (arg == "-b" || arg == "--batch-size") {
2596
+ if (++i >= argc) {
2597
+ invalid_param = true;
2598
+ break;
2599
+ }
2600
+ params.n_batch = std::stoi(argv[i]);
2601
+ } else if (arg == "-ub" || arg == "--ubatch-size") {
2602
+ if (++i >= argc) {
2603
+ invalid_param = true;
2604
+ break;
2605
+ }
2606
+ params.n_ubatch = std::stoi(argv[i]);
2607
+ } else if (arg == "--gpu-layers" || arg == "-ngl" || arg == "--n-gpu-layers") {
2608
+ if (++i >= argc) {
2609
+ invalid_param = true;
2610
+ break;
2611
+ }
2612
+ if (llama_supports_gpu_offload()) {
2613
+ params.n_gpu_layers = std::stoi(argv[i]);
2614
+ } else {
2615
+ LOG_WARNING(
2616
+ "Not compiled with GPU offload support, --n-gpu-layers option will be ignored. "
2617
+ "See main README.md for information on enabling GPU BLAS support",
2618
+ {{"n_gpu_layers", params.n_gpu_layers}});
2619
+ }
2620
+ } else if (arg == "-nkvo" || arg == "--no-kv-offload") {
2621
+ params.no_kv_offload = true;
2622
+ } else if (arg == "--split-mode" || arg == "-sm") {
2623
+ if (++i >= argc) {
2624
+ invalid_param = true;
2625
+ break;
2626
+ }
2627
+ std::string arg_next = argv[i];
2628
+ if (arg_next == "none") {
2629
+ params.split_mode = LLAMA_SPLIT_MODE_NONE;
2630
+ } else if (arg_next == "layer") {
2631
+ params.split_mode = LLAMA_SPLIT_MODE_LAYER;
2632
+ } else if (arg_next == "row") {
2633
+ params.split_mode = LLAMA_SPLIT_MODE_ROW;
2634
+ } else {
2635
+ invalid_param = true;
2636
+ break;
2637
+ }
2638
+ #ifndef GGML_USE_CUDA
2639
+ fprintf(stderr, "warning: llama.cpp was compiled without CUDA. Setting the split mode has no effect.\n");
2640
+ #endif // GGML_USE_CUDA
2641
+ } else if (arg == "--tensor-split" || arg == "-ts") {
2642
+ if (++i >= argc) {
2643
+ invalid_param = true;
2644
+ break;
2645
+ }
2646
+ #if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
2647
+ std::string arg_next = argv[i];
2648
+
2649
+ // split string by , and /
2650
+ const std::regex regex{R"([,/]+)"};
2651
+ std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
2652
+ std::vector<std::string> split_arg{it, {}};
2653
+ GGML_ASSERT(split_arg.size() <= llama_max_devices());
2654
+
2655
+ for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) {
2656
+ if (i_device < split_arg.size()) {
2657
+ params.tensor_split[i_device] = std::stof(split_arg[i_device]);
2658
+ } else {
2659
+ params.tensor_split[i_device] = 0.0f;
2660
+ }
2661
+ }
2662
+ #else
2663
+ LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n", {});
2664
+ #endif // GGML_USE_CUDA
2665
+ } else if (arg == "--main-gpu" || arg == "-mg") {
2666
+ if (++i >= argc) {
2667
+ invalid_param = true;
2668
+ break;
2669
+ }
2670
+ #if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
2671
+ params.main_gpu = std::stoi(argv[i]);
2672
+ #else
2673
+ LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a main GPU.", {});
2674
+ #endif
2675
+ } else if (arg == "--lora") {
2676
+ if (++i >= argc) {
2677
+ invalid_param = true;
2678
+ break;
2679
+ }
2680
+ params.lora_adapter.emplace_back(argv[i], 1.0f);
2681
+ params.use_mmap = false;
2682
+ } else if (arg == "--lora-scaled") {
2683
+ if (++i >= argc) {
2684
+ invalid_param = true;
2685
+ break;
2686
+ }
2687
+ const char * lora_adapter = argv[i];
2688
+ if (++i >= argc) {
2689
+ invalid_param = true;
2690
+ break;
2691
+ }
2692
+ params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
2693
+ params.use_mmap = false;
2694
+ } else if (arg == "--lora-base") {
2695
+ if (++i >= argc) {
2696
+ invalid_param = true;
2697
+ break;
2698
+ }
2699
+ params.lora_base = argv[i];
2700
+ } else if (arg == "-v" || arg == "--verbose") {
2701
+ #if SERVER_VERBOSE != 1
2702
+ LOG_WARNING("server.cpp is not built with verbose logging.", {});
2703
+ #else
2704
+ server_verbose = true;
2705
+ #endif
2706
+ } else if (arg == "--mlock") {
2707
+ params.use_mlock = true;
2708
+ } else if (arg == "--no-mmap") {
2709
+ params.use_mmap = false;
2710
+ } else if (arg == "--numa") {
2711
+ if (++i >= argc) {
2712
+ invalid_param = true;
2713
+ break;
2714
+ } else {
2715
+ std::string value(argv[i]);
2716
+ /**/ if (value == "distribute" || value == "" ) { params.numa = GGML_NUMA_STRATEGY_DISTRIBUTE; }
2717
+ else if (value == "isolate") { params.numa = GGML_NUMA_STRATEGY_ISOLATE; }
2718
+ else if (value == "numactl") { params.numa = GGML_NUMA_STRATEGY_NUMACTL; }
2719
+ else { invalid_param = true; break; }
2720
+ }
2721
+ } else if (arg == "--embedding" || arg == "--embeddings") {
2722
+ params.embedding = true;
2723
+ } else if (arg == "-cb" || arg == "--cont-batching") {
2724
+ params.cont_batching = true;
2725
+ } else if (arg == "-np" || arg == "--parallel") {
2726
+ if (++i >= argc) {
2727
+ invalid_param = true;
2728
+ break;
2729
+ }
2730
+ params.n_parallel = std::stoi(argv[i]);
2731
+ } else if (arg == "-n" || arg == "--n-predict") {
2732
+ if (++i >= argc) {
2733
+ invalid_param = true;
2734
+ break;
2735
+ }
2736
+ params.n_predict = std::stoi(argv[i]);
2737
+ } else if (arg == "-spf" || arg == "--system-prompt-file") {
2738
+ if (++i >= argc) {
2739
+ invalid_param = true;
2740
+ break;
2741
+ }
2742
+ std::ifstream file(argv[i]);
2743
+ if (!file) {
2744
+ fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
2745
+ invalid_param = true;
2746
+ break;
2747
+ }
2748
+ std::string system_prompt;
2749
+ std::copy(
2750
+ std::istreambuf_iterator<char>(file),
2751
+ std::istreambuf_iterator<char>(),
2752
+ std::back_inserter(system_prompt)
2753
+ );
2754
+ sparams.system_prompt = system_prompt;
2755
+ } else if (arg == "-ctk" || arg == "--cache-type-k") {
2756
+ params.cache_type_k = argv[++i];
2757
+ } else if (arg == "-ctv" || arg == "--cache-type-v") {
2758
+ params.cache_type_v = argv[++i];
2759
+ } else if (arg == "--log-format") {
2760
+ if (++i >= argc) {
2761
+ invalid_param = true;
2762
+ break;
2763
+ }
2764
+ if (std::strcmp(argv[i], "json") == 0) {
2765
+ server_log_json = true;
2766
+ } else if (std::strcmp(argv[i], "text") == 0) {
2767
+ server_log_json = false;
2768
+ } else {
2769
+ invalid_param = true;
2770
+ break;
2771
+ }
2772
+ } else if (arg == "--log-disable") {
2773
+ log_set_target(stdout);
2774
+ LOG_INFO("logging to file is disabled.", {});
2775
+ } else if (arg == "--slots-endpoint-disable") {
2776
+ sparams.slots_endpoint = false;
2777
+ } else if (arg == "--metrics") {
2778
+ sparams.metrics_endpoint = true;
2779
+ } else if (arg == "--slot-save-path") {
2780
+ if (++i >= argc) {
2781
+ invalid_param = true;
2782
+ break;
2783
+ }
2784
+ sparams.slot_save_path = argv[i];
2785
+ // if it doesn't end with DIRECTORY_SEPARATOR, add it
2786
+ if (!sparams.slot_save_path.empty() && sparams.slot_save_path[sparams.slot_save_path.size() - 1] != DIRECTORY_SEPARATOR) {
2787
+ sparams.slot_save_path += DIRECTORY_SEPARATOR;
2788
+ }
2789
+ } else if (arg == "--chat-template") {
2790
+ if (++i >= argc) {
2791
+ invalid_param = true;
2792
+ break;
2793
+ }
2794
+ if (!verify_custom_template(argv[i])) {
2795
+ fprintf(stderr, "error: the supplied chat template is not supported: %s\n", argv[i]);
2796
+ fprintf(stderr, "note: llama.cpp does not use jinja parser, we only support commonly used templates\n");
2797
+ invalid_param = true;
2798
+ break;
2799
+ }
2800
+ sparams.chat_template = argv[i];
2801
+ } else if (arg == "--override-kv") {
2802
+ if (++i >= argc) {
2803
+ invalid_param = true;
2804
+ break;
2805
+ }
2806
+ char * sep = strchr(argv[i], '=');
2807
+ if (sep == nullptr || sep - argv[i] >= 128) {
2808
+ fprintf(stderr, "error: Malformed KV override: %s\n", argv[i]);
2809
+ invalid_param = true;
2810
+ break;
2811
+ }
2812
+
2813
+ struct llama_model_kv_override kvo;
2814
+ std::strncpy(kvo.key, argv[i], sep - argv[i]);
2815
+ kvo.key[sep - argv[i]] = 0;
2816
+ sep++;
2817
+ if (strncmp(sep, "int:", 4) == 0) {
2818
+ sep += 4;
2819
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
2820
+ kvo.int_value = std::atol(sep);
2821
+ } else if (strncmp(sep, "float:", 6) == 0) {
2822
+ sep += 6;
2823
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
2824
+ kvo.float_value = std::atof(sep);
2825
+ } else if (strncmp(sep, "bool:", 5) == 0) {
2826
+ sep += 5;
2827
+ kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
2828
+ if (std::strcmp(sep, "true") == 0) {
2829
+ kvo.bool_value = true;
2830
+ } else if (std::strcmp(sep, "false") == 0) {
2831
+ kvo.bool_value = false;
2832
+ } else {
2833
+ fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
2834
+ invalid_param = true;
2835
+ break;
2836
+ }
2837
+ } else {
2838
+ fprintf(stderr, "error: Invalid type for KV override: %s\n", argv[i]);
2839
+ invalid_param = true;
2840
+ break;
2841
+ }
2842
+ params.kv_overrides.push_back(kvo);
2843
+ } else {
2844
+ fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
2845
+ server_print_usage(argv[0], default_params, default_sparams);
2846
+ exit(1);
2847
+ }
2848
+ }
2849
+
2850
+ if (!params.kv_overrides.empty()) {
2851
+ params.kv_overrides.emplace_back();
2852
+ params.kv_overrides.back().key[0] = 0;
2853
+ }
2854
+
2855
+ if (invalid_param) {
2856
+ fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str());
2857
+ server_print_usage(argv[0], default_params, default_sparams);
2858
+ exit(1);
2859
+ }
2860
+ }
2861
+
2862
+ static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
2863
+ // skip GH copilot requests when using default port
2864
+ if (req.path == "/v1/health" || req.path == "/v1/completions") {
2865
+ return;
2866
+ }
2867
+
2868
+ LOG_INFO("request", {
2869
+ {"remote_addr", req.remote_addr},
2870
+ {"remote_port", req.remote_port},
2871
+ {"status", res.status},
2872
+ {"method", req.method},
2873
+ {"path", req.path},
2874
+ {"params", req.params},
2875
+ });
2876
+
2877
+ LOG_VERBOSE("request", {
2878
+ {"request", req.body},
2879
+ {"response", res.body},
2880
+ });
2881
+ }
2882
+
2883
+ std::function<void(int)> shutdown_handler;
2884
+ std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
2885
+
2886
+ inline void signal_handler(int signal) {
2887
+ if (is_terminating.test_and_set()) {
2888
+ // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
2889
+ // this is for better developer experience; we can remove it once the server is stable enough
2890
+ fprintf(stderr, "Received second interrupt, terminating immediately.\n");
2891
+ exit(1);
2892
+ }
2893
+
2894
+ shutdown_handler(signal);
2895
+ }
2896
+
2897
+ int main(int argc, char ** argv) {
2898
+ #if SERVER_VERBOSE != 1
2899
+ log_disable();
2900
+ #endif
2901
+ // own arguments required by this example
2902
+ gpt_params params;
2903
+ server_params sparams;
2904
+
2905
+ // struct that holds the llama context and drives inference
2906
+ server_context ctx_server;
2907
+
2908
+ server_params_parse(argc, argv, sparams, params);
2909
+
2910
+ if (!sparams.system_prompt.empty()) {
2911
+ ctx_server.system_prompt_set(json::parse(sparams.system_prompt));
2912
+ }
2913
+
2914
+ if (params.model_alias == "unknown") {
2915
+ params.model_alias = params.model;
2916
+ }
2917
+
2918
+ llama_backend_init();
2919
+ llama_numa_init(params.numa);
2920
+
2921
+ LOG_INFO("build info", {
2922
+ {"build", LLAMA_BUILD_NUMBER},
2923
+ {"commit", LLAMA_COMMIT}
2924
+ });
2925
+
2926
+ LOG_INFO("system info", {
2927
+ {"n_threads", params.n_threads},
2928
+ {"n_threads_batch", params.n_threads_batch},
2929
+ {"total_threads", std::thread::hardware_concurrency()},
2930
+ {"system_info", llama_print_system_info()},
2931
+ });
2932
+
2933
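+ // HTTP(S) server instance: an SSLServer is created when both ssl_key_file and
+ // ssl_cert_file are set and cpp-httplib was built with OpenSSL support,
+ // otherwise a plain httplib::Server is used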
+ std::unique_ptr<httplib::Server> svr;
2934
+ #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2935
+ if (sparams.ssl_key_file != "" && sparams.ssl_cert_file != "") {
2936
+ LOG_INFO("Running with SSL", {{"key", sparams.ssl_key_file}, {"cert", sparams.ssl_cert_file}});
2937
+ svr.reset(
2938
+ new httplib::SSLServer(sparams.ssl_cert_file.c_str(), sparams.ssl_key_file.c_str())
2939
+ );
2940
+ } else {
2941
+ LOG_INFO("Running without SSL", {});
2942
+ svr.reset(new httplib::Server());
2943
+ }
2944
+ #else
2945
+ svr.reset(new httplib::Server());
2946
+ #endif
2947
+
2948
+ std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
2949
+
2950
+ svr->set_default_headers({{"Server", "llama.cpp"}});
2951
+
2952
+ // CORS preflight
2953
+ svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
2954
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
2955
+ res.set_header("Access-Control-Allow-Credentials", "true");
2956
+ res.set_header("Access-Control-Allow-Methods", "POST");
2957
+ res.set_header("Access-Control-Allow-Headers", "*");
2958
+ return res.set_content("", "application/json; charset=utf-8");
2959
+ });
2960
+
2961
+ svr->set_logger(log_server_request);
2962
+
2963
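+ // helper that wraps an error payload as {"error": ...} and takes the HTTP
+ // status from its "code" field (falling back to 500)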
+ auto res_error = [](httplib::Response & res, json error_data) {
2964
+ json final_response {{"error", error_data}};
2965
+ res.set_content(final_response.dump(), "application/json; charset=utf-8");
2966
+ res.status = json_value(error_data, "code", 500);
2967
+ };
2968
+
2969
+ svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
2970
+ std::string message;
2971
+ try {
2972
+ std::rethrow_exception(std::move(ep));
2973
+ } catch (std::exception & e) {
2974
+ message = e.what();
2975
+ } catch (...) {
2976
+ message = "Unknown Exception";
2977
+ }
2978
+
2979
+ json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
2980
+ LOG_VERBOSE("Got exception", formatted_error);
2981
+ res_error(res, formatted_error);
2982
+ });
2983
+
2984
+ svr->set_error_handler([&res_error](const httplib::Request &, httplib::Response & res) {
2985
+ if (res.status == 404) {
2986
+ res_error(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND));
2987
+ }
2988
+ // for other error codes, we skip processing here because it's already done by res_error()
2989
+ });
2990
+
2991
+ // set timeouts and change hostname and port
2992
+ svr->set_read_timeout (sparams.read_timeout);
2993
+ svr->set_write_timeout(sparams.write_timeout);
2994
+
2995
+ if (!svr->bind_to_port(sparams.hostname, sparams.port)) {
2996
+ fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", sparams.hostname.c_str(), sparams.port);
2997
+ return 1;
2998
+ }
2999
+
3000
+ std::unordered_map<std::string, std::string> log_data;
3001
+
3002
+ log_data["hostname"] = sparams.hostname;
3003
+ log_data["port"] = std::to_string(sparams.port);
3004
+
3005
+ if (sparams.api_keys.size() == 1) {
3006
+ auto key = sparams.api_keys[0];
3007
+ log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0));
3008
+ } else if (sparams.api_keys.size() > 1) {
3009
+ log_data["api_key"] = "api_key: " + std::to_string(sparams.api_keys.size()) + " keys loaded";
3010
+ }
3011
+
3012
+ // load the model
3013
+ if (!ctx_server.load_model(params)) {
3014
+ state.store(SERVER_STATE_ERROR);
3015
+ return 1;
3016
+ } else {
3017
+ ctx_server.init();
3018
+ state.store(SERVER_STATE_READY);
3019
+ }
3020
+
3021
+ LOG_INFO("model loaded", {});
3022
+
3023
+ const auto model_meta = ctx_server.model_meta();
3024
+
3025
+ // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
3026
+ if (sparams.chat_template.empty()) {
3027
+ if (!ctx_server.validate_model_chat_template()) {
3028
+ LOG_ERROR("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
3029
+ sparams.chat_template = "chatml";
3030
+ }
3031
+ }
3032
+
3033
+ // print sample chat example to make it clear which template is used
3034
+ {
3035
+ json chat;
3036
+ chat.push_back({{"role", "system"}, {"content", "You are a helpful assistant"}});
3037
+ chat.push_back({{"role", "user"}, {"content", "Hello"}});
3038
+ chat.push_back({{"role", "assistant"}, {"content", "Hi there"}});
3039
+ chat.push_back({{"role", "user"}, {"content", "How are you?"}});
3040
+
3041
+ const std::string chat_example = format_chat(ctx_server.model, sparams.chat_template, chat);
3042
+
3043
+ LOG_INFO("chat template", {
3044
+ {"chat_example", chat_example},
3045
+ {"built_in", sparams.chat_template.empty()},
3046
+ });
3047
+ }
3048
+
3049
+ //
3050
+ // Middlewares
3051
+ //
3052
+
3053
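+ // API key middleware: when sparams.api_keys is non-empty, requests to the
+ // protected endpoints listed below must send "Authorization: Bearer <key>";
+ // other paths (e.g. /health, /metrics, static assets) are left open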
+ auto middleware_validate_api_key = [&sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
3054
+ // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
3055
+ static const std::set<std::string> protected_endpoints = {
3056
+ "/props",
3057
+ "/completion",
3058
+ "/completions",
3059
+ "/v1/completions",
3060
+ "/chat/completions",
3061
+ "/v1/chat/completions",
3062
+ "/infill",
3063
+ "/tokenize",
3064
+ "/detokenize",
3065
+ "/embedding",
3066
+ "/embeddings",
3067
+ "/v1/embeddings",
3068
+ };
3069
+
3070
+ // If API key is not set, skip validation
3071
+ if (sparams.api_keys.empty()) {
3072
+ return true;
3073
+ }
3074
+
3075
+ // If path is not in protected_endpoints list, skip validation
3076
+ if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
3077
+ return true;
3078
+ }
3079
+
3080
+ // Check for API key in the header
3081
+ auto auth_header = req.get_header_value("Authorization");
3082
+
3083
+ std::string prefix = "Bearer ";
3084
+ if (auth_header.substr(0, prefix.size()) == prefix) {
3085
+ std::string received_api_key = auth_header.substr(prefix.size());
3086
+ if (std::find(sparams.api_keys.begin(), sparams.api_keys.end(), received_api_key) != sparams.api_keys.end()) {
3087
+ return true; // API key is valid
3088
+ }
3089
+ }
3090
+
3091
+ // API key is invalid or not provided
3092
+ // TODO: make another middleware for CORS related logic
3093
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3094
+ res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
3095
+
3096
+ LOG_WARNING("Unauthorized: Invalid API Key", {});
3097
+
3098
+ return false;
3099
+ };
3100
+
3101
+ // register server middlewares
3102
+ svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
3103
+ if (!middleware_validate_api_key(req, res)) {
3104
+ return httplib::Server::HandlerResponse::Handled;
3105
+ }
3106
+ return httplib::Server::HandlerResponse::Unhandled;
3107
+ });
3108
+
3109
+ //
3110
+ // Route handlers (or controllers)
3111
+ //
3112
+
3113
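+ // GET /health: reports the server state; when ready it also returns the number
+ // of idle/processing slots, optionally embeds per-slot data (?include_slots,
+ // if the slots endpoint is enabled) and returns 503 with ?fail_on_no_slot
+ // when every slot is busy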
+ const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
3114
+ server_state current_state = state.load();
3115
+ switch (current_state) {
3116
+ case SERVER_STATE_READY:
3117
+ {
3118
+ // request slots data using task queue
3119
+ server_task task;
3120
+ task.id = ctx_server.queue_tasks.get_new_id();
3121
+ task.type = SERVER_TASK_TYPE_METRICS;
3122
+ task.id_target = -1;
3123
+
3124
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3125
+ ctx_server.queue_tasks.post(task);
3126
+
3127
+ // get the result
3128
+ server_task_result result = ctx_server.queue_results.recv(task.id);
3129
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
3130
+
3131
+ const int n_idle_slots = result.data["idle"];
3132
+ const int n_processing_slots = result.data["processing"];
3133
+
3134
+ json health = {
3135
+ {"status", "ok"},
3136
+ {"slots_idle", n_idle_slots},
3137
+ {"slots_processing", n_processing_slots}
3138
+ };
3139
+
3140
+ res.status = 200; // HTTP OK
3141
+ if (sparams.slots_endpoint && req.has_param("include_slots")) {
3142
+ health["slots"] = result.data["slots"];
3143
+ }
3144
+
3145
+ if (n_idle_slots == 0) {
3146
+ health["status"] = "no slot available";
3147
+ if (req.has_param("fail_on_no_slot")) {
3148
+ res.status = 503; // HTTP Service Unavailable
3149
+ }
3150
+ }
3151
+
3152
+ res.set_content(health.dump(), "application/json");
3153
+ break;
3154
+ }
3155
+ case SERVER_STATE_LOADING_MODEL:
3156
+ {
3157
+ res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
3158
+ } break;
3159
+ case SERVER_STATE_ERROR:
3160
+ {
3161
+ res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
3162
+ } break;
3163
+ }
3164
+ };
3165
+
3166
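+ // GET /slots: returns the current state of every slot, gathered through a
+ // metrics task on the task queue; guarded by sparams.slots_endpoint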
+ const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
3167
+ if (!sparams.slots_endpoint) {
3168
+ res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
3169
+ return;
3170
+ }
3171
+
3172
+ // request slots data using task queue
3173
+ server_task task;
3174
+ task.id = ctx_server.queue_tasks.get_new_id();
3175
+ task.id_multi = -1;
3176
+ task.id_target = -1;
3177
+ task.type = SERVER_TASK_TYPE_METRICS;
3178
+
3179
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3180
+ ctx_server.queue_tasks.post(task);
3181
+
3182
+ // get the result
3183
+ server_task_result result = ctx_server.queue_results.recv(task.id);
3184
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
3185
+
3186
+ res.set_content(result.data["slots"].dump(), "application/json");
3187
+ res.status = 200; // HTTP OK
3188
+ };
3189
+
3190
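+ // GET /metrics: exposes counters and gauges in the Prometheus text format
+ // under the "llamacpp:" prefix; the metric buckets are reset on every scrape
+ // (the task below is posted with "reset_bucket": true)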
+ const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
3191
+ if (!sparams.metrics_endpoint) {
3192
+ res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
3193
+ return;
3194
+ }
3195
+
3196
+ // request slots data using task queue
3197
+ server_task task;
3198
+ task.id = ctx_server.queue_tasks.get_new_id();
3199
+ task.id_multi = -1;
3200
+ task.id_target = -1;
3201
+ task.type = SERVER_TASK_TYPE_METRICS;
3202
+ task.data.push_back({{"reset_bucket", true}});
3203
+
3204
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3205
+ ctx_server.queue_tasks.post(task);
3206
+
3207
+ // get the result
3208
+ server_task_result result = ctx_server.queue_results.recv(task.id);
3209
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
3210
+
3211
+ json data = result.data;
3212
+
3213
+ const uint64_t n_prompt_tokens_processed = data["n_prompt_tokens_processed"];
3214
+ const uint64_t t_prompt_processing = data["t_prompt_processing"];
3215
+
3216
+ const uint64_t n_tokens_predicted = data["n_tokens_predicted"];
3217
+ const uint64_t t_tokens_generation = data["t_tokens_generation"];
3218
+
3219
+ const int32_t kv_cache_used_cells = data["kv_cache_used_cells"];
3220
+
3221
+ // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
3222
+ json all_metrics_def = json {
3223
+ {"counter", {{
3224
+ {"name", "prompt_tokens_total"},
3225
+ {"help", "Number of prompt tokens processed."},
3226
+ {"value", (uint64_t) data["n_prompt_tokens_processed_total"]}
3227
+ }, {
3228
+ {"name", "prompt_seconds_total"},
3229
+ {"help", "Prompt process time"},
3230
+ {"value", (uint64_t) data["t_prompt_processing_total"] / 1.e3}
3231
+ }, {
3232
+ {"name", "tokens_predicted_total"},
3233
+ {"help", "Number of generation tokens processed."},
3234
+ {"value", (uint64_t) data["n_tokens_predicted_total"]}
3235
+ }, {
3236
+ {"name", "tokens_predicted_seconds_total"},
3237
+ {"help", "Predict process time"},
3238
+ {"value", (uint64_t) data["t_tokens_generation_total"] / 1.e3}
3239
+ }}},
3240
+ {"gauge", {{
3241
+ {"name", "prompt_tokens_seconds"},
3242
+ {"help", "Average prompt throughput in tokens/s."},
3243
+ {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
3244
+ },{
3245
+ {"name", "predicted_tokens_seconds"},
3246
+ {"help", "Average generation throughput in tokens/s."},
3247
+ {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
3248
+ },{
3249
+ {"name", "kv_cache_usage_ratio"},
3250
+ {"help", "KV-cache usage. 1 means 100 percent usage."},
3251
+ {"value", 1. * kv_cache_used_cells / params.n_ctx}
3252
+ },{
3253
+ {"name", "kv_cache_tokens"},
3254
+ {"help", "KV-cache tokens."},
3255
+ {"value", (uint64_t) data["kv_cache_tokens_count"]}
3256
+ },{
3257
+ {"name", "requests_processing"},
3258
+ {"help", "Number of request processing."},
3259
+ {"value", (uint64_t) data["processing"]}
3260
+ },{
3261
+ {"name", "requests_deferred"},
3262
+ {"help", "Number of request deferred."},
3263
+ {"value", (uint64_t) data["deferred"]}
3264
+ }}}
3265
+ };
3266
+
3267
+ std::stringstream prometheus;
3268
+
3269
+ for (const auto & el : all_metrics_def.items()) {
3270
+ const auto & type = el.key();
3271
+ const auto & metrics_def = el.value();
3272
+
3273
+ for (const auto & metric_def : metrics_def) {
3274
+ const std::string name = metric_def["name"];
3275
+ const std::string help = metric_def["help"];
3276
+
3277
+ auto value = json_value(metric_def, "value", 0.);
3278
+ prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
3279
+ << "# TYPE llamacpp:" << name << " " << type << "\n"
3280
+ << "llamacpp:" << name << " " << value << "\n";
3281
+ }
3282
+ }
3283
+
3284
+ const int64_t t_start = data["t_start"];
3285
+ res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
3286
+
3287
+ res.set_content(prometheus.str(), "text/plain; version=0.0.4");
3288
+ res.status = 200; // HTTP OK
3289
+ };
3290
+
3291
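+ // POST /slots/:id_slot?action=save|restore|erase -- save/restore persist a
+ // slot's cache state to <slot_save_path>/<filename>; the filename is
+ // validated first to avoid path traversal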
+ const auto handle_slots_save = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) {
3292
+ json request_data = json::parse(req.body);
3293
+ std::string filename = request_data["filename"];
3294
+ if (!validate_file_name(filename)) {
3295
+ res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
3296
+ return;
3297
+ }
3298
+ std::string filepath = sparams.slot_save_path + filename;
3299
+
3300
+ server_task task;
3301
+ task.type = SERVER_TASK_TYPE_SLOT_SAVE;
3302
+ task.data = {
3303
+ { "id_slot", id_slot },
3304
+ { "filename", filename },
3305
+ { "filepath", filepath }
3306
+ };
3307
+
3308
+ const int id_task = ctx_server.queue_tasks.post(task);
3309
+ ctx_server.queue_results.add_waiting_task_id(id_task);
3310
+
3311
+ server_task_result result = ctx_server.queue_results.recv(id_task);
3312
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3313
+
3314
+ if (result.error) {
3315
+ res_error(res, result.data);
3316
+ } else {
3317
+ res.set_content(result.data.dump(), "application/json");
3318
+ }
3319
+ };
3320
+
3321
+ const auto handle_slots_restore = [&ctx_server, &res_error, &sparams](const httplib::Request & req, httplib::Response & res, int id_slot) {
3322
+ json request_data = json::parse(req.body);
3323
+ std::string filename = request_data["filename"];
3324
+ if (!validate_file_name(filename)) {
3325
+ res_error(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
3326
+ return;
3327
+ }
3328
+ std::string filepath = sparams.slot_save_path + filename;
3329
+
3330
+ server_task task;
3331
+ task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
3332
+ task.data = {
3333
+ { "id_slot", id_slot },
3334
+ { "filename", filename },
3335
+ { "filepath", filepath }
3336
+ };
3337
+
3338
+ const int id_task = ctx_server.queue_tasks.post(task);
3339
+ ctx_server.queue_results.add_waiting_task_id(id_task);
3340
+
3341
+ server_task_result result = ctx_server.queue_results.recv(id_task);
3342
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3343
+
3344
+ if (result.error) {
3345
+ res_error(res, result.data);
3346
+ } else {
3347
+ res.set_content(result.data.dump(), "application/json");
3348
+ }
3349
+ };
3350
+
3351
+ const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
3352
+ server_task task;
3353
+ task.type = SERVER_TASK_TYPE_SLOT_ERASE;
3354
+ task.data = {
3355
+ { "id_slot", id_slot },
3356
+ };
3357
+
3358
+ const int id_task = ctx_server.queue_tasks.post(task);
3359
+ ctx_server.queue_results.add_waiting_task_id(id_task);
3360
+
3361
+ server_task_result result = ctx_server.queue_results.recv(id_task);
3362
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3363
+
3364
+ if (result.error) {
3365
+ res_error(res, result.data);
3366
+ } else {
3367
+ res.set_content(result.data.dump(), "application/json");
3368
+ }
3369
+ };
3370
+
3371
+ const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
3372
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3373
+
3374
+ std::string id_slot_str = req.path_params.at("id_slot");
3375
+ int id_slot;
3376
+
3377
+ try {
3378
+ id_slot = std::stoi(id_slot_str);
3379
+ } catch (const std::exception &) {
3380
+ res_error(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST));
3381
+ return;
3382
+ }
3383
+
3384
+ std::string action = req.get_param_value("action");
3385
+
3386
+ if (action == "save") {
3387
+ handle_slots_save(req, res, id_slot);
3388
+ } else if (action == "restore") {
3389
+ handle_slots_restore(req, res, id_slot);
3390
+ } else if (action == "erase") {
3391
+ handle_slots_erase(req, res, id_slot);
3392
+ } else {
3393
+ res_error(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST));
3394
+ }
3395
+ };
3396
+
3397
+ const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3398
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3399
+ json data = {
3400
+ { "user_name", ctx_server.name_user.c_str() },
3401
+ { "assistant_name", ctx_server.name_assistant.c_str() },
3402
+ { "default_generation_settings", ctx_server.default_generation_settings_for_props },
3403
+ { "total_slots", ctx_server.params.n_parallel }
3404
+ };
3405
+
3406
+ res.set_content(data.dump(), "application/json; charset=utf-8");
3407
+ };
3408
+
3409
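+ // POST /completion, /completions, /v1/completions: without "stream" the call
+ // blocks until the task result arrives; with "stream": true partial results
+ // are pushed as server-sent events ("data: {...}\n\n") over a chunked
+ // text/event-stream response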
+ const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3410
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3411
+
3412
+ json data = json::parse(req.body);
3413
+
3414
+ const int id_task = ctx_server.queue_tasks.get_new_id();
3415
+
3416
+ ctx_server.queue_results.add_waiting_task_id(id_task);
3417
+ ctx_server.request_completion(id_task, -1, data, false, false);
3418
+
3419
+ if (!json_value(data, "stream", false)) {
3420
+ server_task_result result = ctx_server.queue_results.recv(id_task);
3421
+ if (!result.error && result.stop) {
3422
+ res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
3423
+ } else {
3424
+ res_error(res, result.data);
3425
+ }
3426
+
3427
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3428
+ } else {
3429
+ const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
3430
+ while (true) {
3431
+ server_task_result result = ctx_server.queue_results.recv(id_task);
3432
+ if (!result.error) {
3433
+ const std::string str =
3434
+ "data: " +
3435
+ result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
3436
+ "\n\n";
3437
+
3438
+ LOG_VERBOSE("data stream", {
3439
+ { "to_send", str }
3440
+ });
3441
+
3442
+ if (!sink.write(str.c_str(), str.size())) {
3443
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3444
+ return false;
3445
+ }
3446
+
3447
+ if (result.stop) {
3448
+ break;
3449
+ }
3450
+ } else {
3451
+ const std::string str =
3452
+ "error: " +
3453
+ result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
3454
+ "\n\n";
3455
+
3456
+ LOG_VERBOSE("data stream", {
3457
+ { "to_send", str }
3458
+ });
3459
+
3460
+ if (!sink.write(str.c_str(), str.size())) {
3461
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3462
+ return false;
3463
+ }
3464
+
3465
+ break;
3466
+ }
3467
+ }
3468
+
3469
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3470
+ sink.done();
3471
+
3472
+ return true;
3473
+ };
3474
+
3475
+ auto on_complete = [id_task, &ctx_server] (bool) {
3476
+ // cancel
3477
+ ctx_server.request_cancel(id_task);
3478
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3479
+ };
3480
+
3481
+ res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
3482
+ }
3483
+ };
3484
+
3485
+ const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
3486
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3487
+
3488
+ json models = {
3489
+ {"object", "list"},
3490
+ {"data", {
3491
+ {
3492
+ {"id", params.model_alias},
3493
+ {"object", "model"},
3494
+ {"created", std::time(0)},
3495
+ {"owned_by", "llamacpp"},
3496
+ {"meta", model_meta}
3497
+ },
3498
+ }}
3499
+ };
3500
+
3501
+ res.set_content(models.dump(), "application/json; charset=utf-8");
3502
+ };
3503
+
3504
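+ // POST /chat/completions, /v1/chat/completions: OpenAI-compatible chat API;
+ // the body is converted with oaicompat_completion_params_parse using the
+ // active chat template, and results are reformatted into chat.completion
+ // (or chat.completion.chunk when streaming) objects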
+ const auto handle_chat_completions = [&ctx_server, &sparams, &res_error](const httplib::Request & req, httplib::Response & res) {
3505
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3506
+ json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), sparams.chat_template);
3507
+
3508
+ const int id_task = ctx_server.queue_tasks.get_new_id();
3509
+
3510
+ ctx_server.queue_results.add_waiting_task_id(id_task);
3511
+ ctx_server.request_completion(id_task, -1, data, false, false);
3512
+
3513
+ const auto completion_id = gen_chatcmplid();
3514
+ if (!json_value(data, "stream", false)) {
3515
+ server_task_result result = ctx_server.queue_results.recv(id_task);
3516
+
3517
+ if (!result.error && result.stop) {
3518
+ json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
3519
+
3520
+ res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
3521
+ } else {
3522
+ res_error(res, result.data);
3523
+ }
3524
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3525
+ } else {
3526
+ const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
3527
+ while (true) {
3528
+ server_task_result result = ctx_server.queue_results.recv(id_task);
3529
+ if (!result.error) {
3530
+ std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
3531
+
3532
+ for (auto it = result_array.begin(); it != result_array.end(); ++it) {
3533
+ if (!it->empty()) {
3534
+ const std::string str =
3535
+ "data: " +
3536
+ it->dump(-1, ' ', false, json::error_handler_t::replace) +
3537
+ "\n\n";
3538
+ LOG_VERBOSE("data stream", {{"to_send", str}});
3539
+ if (!sink.write(str.c_str(), str.size())) {
3540
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3541
+ return false;
3542
+ }
3543
+ }
3544
+ }
3545
+ if (result.stop) {
3546
+ break;
3547
+ }
3548
+ } else {
3549
+ const std::string str =
3550
+ "error: " +
3551
+ result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
3552
+ "\n\n";
3553
+ LOG_VERBOSE("data stream", {{"to_send", str}});
3554
+ if (!sink.write(str.c_str(), str.size())) {
3555
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3556
+ return false;
3557
+ }
3558
+ break;
3559
+ }
3560
+ }
3561
+ sink.done();
3562
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3563
+ return true;
3564
+ };
3565
+
3566
+ auto on_complete = [id_task, &ctx_server](bool) {
3567
+ // cancel request
3568
+ ctx_server.request_cancel(id_task);
3569
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3570
+ };
3571
+
3572
+ res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
3573
+ }
3574
+ };
3575
+
3576
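+ // POST /infill: fill-in-the-middle completion; uses the same streaming
+ // protocol as /completion, with request_completion flagging the task as
+ // infill (the fourth argument set to true below)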
+ const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3577
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3578
+
3579
+ json data = json::parse(req.body);
3580
+
3581
+ const int id_task = ctx_server.queue_tasks.get_new_id();
3582
+
3583
+ ctx_server.queue_results.add_waiting_task_id(id_task);
3584
+ ctx_server.request_completion(id_task, -1, data, true, false);
3585
+
3586
+ if (!json_value(data, "stream", false)) {
3587
+ server_task_result result = ctx_server.queue_results.recv(id_task);
3588
+ if (!result.error && result.stop) {
3589
+ res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
3590
+ } else {
3591
+ res_error(res, result.data);
3592
+ }
3593
+
3594
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3595
+ } else {
3596
+ const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
3597
+ while (true) {
3598
+ server_task_result result = ctx_server.queue_results.recv(id_task);
3599
+ if (!result.error) {
3600
+ const std::string str =
3601
+ "data: " +
3602
+ result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
3603
+ "\n\n";
3604
+
3605
+ LOG_VERBOSE("data stream", {
3606
+ { "to_send", str }
3607
+ });
3608
+
3609
+ if (!sink.write(str.c_str(), str.size())) {
3610
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3611
+ return false;
3612
+ }
3613
+
3614
+ if (result.stop) {
3615
+ break;
3616
+ }
3617
+ } else {
3618
+ break;
3619
+ }
3620
+ }
3621
+
3622
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3623
+ sink.done();
3624
+
3625
+ return true;
3626
+ };
3627
+
3628
+ auto on_complete = [id_task, &ctx_server] (bool) {
3629
+ ctx_server.request_cancel(id_task);
3630
+ };
3631
+
3632
+ res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
3633
+ }
3634
+ };
3635
+
3636
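+ // POST /tokenize and /detokenize: thin wrappers around the model tokenizer;
+ // "content" is turned into token ids and "tokens" back into text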
+ const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3637
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3638
+ const json body = json::parse(req.body);
3639
+
3640
+ std::vector<llama_token> tokens;
3641
+ if (body.count("content") != 0) {
3642
+ tokens = ctx_server.tokenize(body["content"], false);
3643
+ }
3644
+ const json data = format_tokenizer_response(tokens);
3645
+ return res.set_content(data.dump(), "application/json; charset=utf-8");
3646
+ };
3647
+
3648
+ const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
3649
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3650
+ const json body = json::parse(req.body);
3651
+
3652
+ std::string content;
3653
+ if (body.count("tokens") != 0) {
3654
+ const std::vector<llama_token> tokens = body["tokens"];
3655
+ content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
3656
+ }
3657
+
3658
+ const json data = format_detokenized_response(content);
3659
+ return res.set_content(data.dump(), "application/json; charset=utf-8");
3660
+ };
3661
+
3662
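+ // POST /embedding, /embeddings, /v1/embeddings: requires the server to be
+ // started with embeddings enabled (params.embedding); accepts either an
+ // OpenAI-style "input" (string or list) or the legacy single-prompt "content"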
+ const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3663
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3664
+ if (!params.embedding) {
3665
+ res.status = 501;
3666
+ res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
3667
+ return;
3668
+ }
3669
+
3670
+ const json body = json::parse(req.body);
3671
+ bool is_openai = false;
3672
+
3673
+ // an input prompt can be a string or a list of tokens (integer)
3674
+ json prompt;
3675
+ if (body.count("input") != 0) {
3676
+ is_openai = true;
3677
+ prompt = body["input"];
3678
+ } else if (body.count("content") != 0) {
3679
+ // with "content", we only support single prompt
3680
+ prompt = std::vector<std::string>{body["content"]};
3681
+ } else {
3682
+ res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
3683
+ return;
3684
+ }
3685
+
3686
+ // create and queue the task
3687
+ json responses;
3688
+ {
3689
+ const int id_task = ctx_server.queue_tasks.get_new_id();
3690
+ ctx_server.queue_results.add_waiting_task_id(id_task);
3691
+ ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
3692
+
3693
+ // get the result
3694
+ server_task_result result = ctx_server.queue_results.recv(id_task);
3695
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
3696
+ if (!result.error) {
3697
+ if (result.data.count("results")) {
3698
+ // result for multi-task
3699
+ responses = result.data["results"];
3700
+ } else {
3701
+ // result for single task
3702
+ responses = std::vector<json>{result.data};
3703
+ }
3704
+ } else {
3705
+ // error received, ignore everything else
3706
+ res_error(res, result.data);
3707
+ return;
3708
+ }
3709
+ }
3710
+
3711
+ // write JSON response
3712
+ json root = is_openai
3713
+ ? format_embeddings_response_oaicompat(body, responses)
3714
+ : responses[0];
3715
+ return res.set_content(root.dump(), "application/json; charset=utf-8");
3716
+ };
3717
+
3718
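+ // helper that returns a handler serving an embedded in-memory asset
+ // (index_html, index_js, completion_js, ...) with the given MIME type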
+ auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
3719
+ return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
3720
+ res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
3721
+ return false;
3722
+ };
3723
+ };
3724
+
3725
+ //
3726
+ // Router
3727
+ //
3728
+
3729
+ // register static assets routes
3730
+ if (!sparams.public_path.empty()) {
3731
+ // Set the base directory for serving static files
3732
+ svr->set_base_dir(sparams.public_path);
3733
+ }
3734
+
3735
+ // using embedded static files
3736
+ svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
3737
+ svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
3738
+ svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
3739
+ svr->Get("/json-schema-to-grammar.mjs", handle_static_file(
3740
+ json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
3741
+
3742
+ // register API routes
3743
+ svr->Get ("/health", handle_health);
3744
+ svr->Get ("/slots", handle_slots);
3745
+ svr->Get ("/metrics", handle_metrics);
3746
+ svr->Get ("/props", handle_props);
3747
+ svr->Get ("/v1/models", handle_models);
3748
+ svr->Post("/completion", handle_completions); // legacy
3749
+ svr->Post("/completions", handle_completions);
3750
+ svr->Post("/v1/completions", handle_completions);
3751
+ svr->Post("/chat/completions", handle_chat_completions);
3752
+ svr->Post("/v1/chat/completions", handle_chat_completions);
3753
+ svr->Post("/infill", handle_infill);
3754
+ svr->Post("/embedding", handle_embeddings); // legacy
3755
+ svr->Post("/embeddings", handle_embeddings);
3756
+ svr->Post("/v1/embeddings", handle_embeddings);
3757
+ svr->Post("/tokenize", handle_tokenize);
3758
+ svr->Post("/detokenize", handle_detokenize);
3759
+ if (!sparams.slot_save_path.empty()) {
3760
+ // only enable slot endpoints if slot_save_path is set
3761
+ svr->Post("/slots/:id_slot", handle_slots_action);
3762
+ }
3763
+
3764
+ //
3765
+ // Start the server
3766
+ //
3767
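+ // size the HTTP thread pool when not set explicitly:
+ // max(n_parallel + 2, hardware_concurrency() - 1)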
+ if (sparams.n_threads_http < 1) {
3768
+ // +2 threads for monitoring endpoints
3769
+ sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
3770
+ }
3771
+ log_data["n_threads_http"] = std::to_string(sparams.n_threads_http);
3772
+ svr->new_task_queue = [&sparams] { return new httplib::ThreadPool(sparams.n_threads_http); };
3773
+
3774
+ LOG_INFO("HTTP server listening", log_data);
3775
+
3776
+ // run the HTTP server in a thread - see comment below
3777
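+ // the HTTP server runs in this worker thread while the main thread drives the
+ // inference loop (queue_tasks.start_loop() below); once the loop terminates,
+ // svr->stop() is called and the thread is joined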
+ std::thread t([&]() {
3778
+ if (!svr->listen_after_bind()) {
3779
+ state.store(SERVER_STATE_ERROR);
3780
+ return 1;
3781
+ }
3782
+
3783
+ return 0;
3784
+ });
3785
+
3786
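+ // wire the task queue to the server context: per-task processing, multitask
+ // completion and slot updates, plus propagation of multitask progress from
+ // the result queue back into the task queue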
+ ctx_server.queue_tasks.on_new_task(std::bind(
3787
+ &server_context::process_single_task, &ctx_server, std::placeholders::_1));
3788
+ ctx_server.queue_tasks.on_finish_multitask(std::bind(
3789
+ &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
3790
+ ctx_server.queue_tasks.on_update_slots(std::bind(
3791
+ &server_context::update_slots, &ctx_server));
3792
+ ctx_server.queue_results.on_multitask_update(std::bind(
3793
+ &server_queue::update_multitask,
3794
+ &ctx_server.queue_tasks,
3795
+ std::placeholders::_1,
3796
+ std::placeholders::_2,
3797
+ std::placeholders::_3
3798
+ ));
3799
+
3800
+ shutdown_handler = [&](int) {
3801
+ ctx_server.queue_tasks.terminate();
3802
+ };
3803
+
3804
+ #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
3805
+ struct sigaction sigint_action;
3806
+ sigint_action.sa_handler = signal_handler;
3807
+ sigemptyset (&sigint_action.sa_mask);
3808
+ sigint_action.sa_flags = 0;
3809
+ sigaction(SIGINT, &sigint_action, NULL);
3810
+ sigaction(SIGTERM, &sigint_action, NULL);
3811
+ #elif defined (_WIN32)
3812
+ auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
3813
+ return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
3814
+ };
3815
+ SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
3816
+ #endif
3817
+
3818
+ ctx_server.queue_tasks.start_loop();
3819
+
3820
+ svr->stop();
3821
+ t.join();
3822
+
3823
+ llama_backend_free();
3824
+
3825
+ return 0;
3826
+ }