@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
package/src/llama.cpp/examples/server/server.cpp
@@ -1,20 +1,17 @@
1
1
  #include "utils.hpp"
2
2
 
3
+ #include "arg.h"
3
4
  #include "common.h"
5
+ #include "log.h"
6
+ #include "sampling.h"
4
7
  #include "json-schema-to-grammar.h"
5
8
  #include "llama.h"
6
- #include "grammar-parser.h"
7
9
 
8
- #ifndef NDEBUG
9
- // crash the server in debug mode, otherwise send an http 500 error
10
- #define CPPHTTPLIB_NO_EXCEPTIONS 1
11
- #endif
12
- // increase max payload length to allow use of larger context size
13
- #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
14
- #include "httplib.h"
15
10
  // Change JSON_ASSERT from assert() to GGML_ASSERT:
16
11
  #define JSON_ASSERT GGML_ASSERT
17
12
  #include "json.hpp"
13
+ // mime type for sending response
14
+ #define MIMETYPE_JSON "application/json; charset=utf-8"
18
15
 
19
16
  // auto generated files (update with ./deps.sh)
20
17
  #include "colorthemes.css.hpp"
@@ -32,42 +29,53 @@
32
29
  #include "system-prompts.js.hpp"
33
30
  #include "prompt-formats.js.hpp"
34
31
  #include "json-schema-to-grammar.mjs.hpp"
32
+ #include "loading.html.hpp"
35
33
 
36
34
  #include <atomic>
37
- #include <chrono>
38
35
  #include <condition_variable>
39
36
  #include <cstddef>
40
- #include <set>
37
+ #include <cinttypes>
38
+ #include <deque>
39
+ #include <memory>
41
40
  #include <mutex>
42
- #include <thread>
43
41
  #include <signal.h>
44
- #include <memory>
42
+ #include <thread>
43
+ #include <unordered_map>
44
+ #include <unordered_set>
45
45
 
46
- using json = nlohmann::ordered_json;
46
+ #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
47
+ #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
48
+ #define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
49
+ #define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
50
+
51
+ #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
52
+ #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
53
+ #define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
54
+ #define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
55
+
56
+ #define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
57
+ #define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
58
+ #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
59
+ #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
47
60
 
48
- bool server_verbose = false;
49
- bool server_log_json = true;
61
+ using json = nlohmann::ordered_json;
50
62
 
51
63
  enum stop_type {
52
64
  STOP_TYPE_FULL,
53
65
  STOP_TYPE_PARTIAL,
54
66
  };
55
67
 
68
+ // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
56
69
  enum slot_state {
57
70
  SLOT_STATE_IDLE,
58
- SLOT_STATE_PROCESSING,
59
- };
60
-
61
- enum slot_command {
62
- SLOT_COMMAND_NONE,
63
- SLOT_COMMAND_LOAD_PROMPT,
64
- SLOT_COMMAND_RELEASE,
71
+ SLOT_STATE_PROCESSING_PROMPT,
72
+ SLOT_STATE_DONE_PROMPT,
73
+ SLOT_STATE_GENERATING,
65
74
  };
66
75
 
67
76
  enum server_state {
68
77
  SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
69
78
  SERVER_STATE_READY, // Server is ready and model is loaded
70
- SERVER_STATE_ERROR // An error occurred, load_model failed
71
79
  };
72
80
 
73
81
  enum server_task_type {
@@ -78,23 +86,37 @@ enum server_task_type {
78
86
  SERVER_TASK_TYPE_SLOT_SAVE,
79
87
  SERVER_TASK_TYPE_SLOT_RESTORE,
80
88
  SERVER_TASK_TYPE_SLOT_ERASE,
89
+ SERVER_TASK_TYPE_SET_LORA,
90
+ };
91
+
92
+ enum server_task_cmpl_type {
93
+ SERVER_TASK_CMPL_TYPE_NORMAL,
94
+ SERVER_TASK_CMPL_TYPE_EMBEDDING,
95
+ SERVER_TASK_CMPL_TYPE_RERANK,
96
+ SERVER_TASK_CMPL_TYPE_INFILL,
81
97
  };
82
98
 
83
99
  struct server_task {
84
100
  int id = -1; // to be filled by server_queue
85
- int id_multi = -1;
86
- int id_target = -1;
101
+ int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
87
102
 
88
103
  server_task_type type;
89
104
  json data;
90
105
 
91
- bool infill = false;
92
- bool embedding = false;
106
+ server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
107
+
108
+ // utility function
109
+ static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
110
+ std::unordered_set<int> ids(tasks.size());
111
+ for (size_t i = 0; i < tasks.size(); i++) {
112
+ ids.insert(tasks[i].id);
113
+ }
114
+ return ids;
115
+ }
93
116
  };
94
117
 
95
118
  struct server_task_result {
96
119
  int id = -1;
97
- int id_multi = -1;
98
120
 
99
121
  json data;
100
122
 
@@ -102,13 +124,6 @@ struct server_task_result {
102
124
  bool error;
103
125
  };
104
126
 
105
- struct server_task_multi {
106
- int id = -1;
107
-
108
- std::set<int> subtasks_remaining;
109
- std::vector<server_task_result> results;
110
- };
111
-
112
127
  struct slot_params {
113
128
  bool stream = true;
114
129
  bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
@@ -126,12 +141,13 @@ struct slot_params {
126
141
  struct server_slot {
127
142
  int id;
128
143
  int id_task = -1;
129
- int id_multi = -1;
144
+
145
+ // the index relative to completion multi-task request
146
+ size_t index = 0;
130
147
 
131
148
  struct slot_params params;
132
149
 
133
150
  slot_state state = SLOT_STATE_IDLE;
134
- slot_command command = SLOT_COMMAND_NONE;
135
151
 
136
152
  // used to determine the slot that has been used the longest
137
153
  int64_t t_last_used = -1;
@@ -156,8 +172,8 @@ struct server_slot {
156
172
  std::vector<llama_token> cache_tokens;
157
173
  std::vector<completion_token_output> generated_token_probs;
158
174
 
159
- bool infill = false;
160
- bool embedding = false;
175
+ server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
176
+
161
177
  bool has_next_token = true;
162
178
  bool truncated = false;
163
179
  bool stopped_eos = false;
@@ -170,11 +186,13 @@ struct server_slot {
170
186
  std::string stopping_word;
171
187
 
172
188
  // sampling
173
- llama_token sampled;
174
- struct llama_sampling_params sparams;
175
- llama_sampling_context * ctx_sampling = nullptr;
176
189
  json json_schema;
177
190
 
191
+ struct gpt_sampler_params sparams;
192
+ struct gpt_sampler * smpl = nullptr;
193
+
194
+ llama_token sampled;
195
+
178
196
  int32_t ga_i = 0; // group-attention state
179
197
  int32_t ga_n = 1; // group-attention factor
180
198
  int32_t ga_w = 512; // group-attention width
@@ -191,7 +209,11 @@ struct server_slot {
191
209
  double t_prompt_processing; // ms
192
210
  double t_token_generation; // ms
193
211
 
212
+ std::function<void(int)> callback_on_release;
213
+
194
214
  void reset() {
215
+ SLT_DBG(*this, "%s", "\n");
216
+
195
217
  n_prompt_tokens = 0;
196
218
  generated_text = "";
197
219
  truncated = false;
@@ -202,7 +224,7 @@ struct server_slot {
202
224
  n_past = 0;
203
225
  n_sent_text = 0;
204
226
  n_sent_token_probs = 0;
205
- infill = false;
227
+ cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
206
228
  ga_i = 0;
207
229
  n_past_se = 0;
208
230
 
@@ -225,25 +247,25 @@ struct server_slot {
225
247
  return n_remaining > 0; // no budget
226
248
  }
227
249
 
228
- bool available() const {
229
- return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
230
- }
231
-
232
250
  bool is_processing() const {
233
- return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING;
251
+ return state != SLOT_STATE_IDLE;
234
252
  }
235
253
 
236
- void add_token_string(const completion_token_output & token) {
237
- if (command == SLOT_COMMAND_RELEASE) {
254
+ void add_token(const completion_token_output & token) {
255
+ if (!is_processing()) {
256
+ SLT_WRN(*this, "%s", "slot is not processing\n");
238
257
  return;
239
258
  }
240
259
  generated_token_probs.push_back(token);
241
260
  }
242
261
 
243
262
  void release() {
244
- if (state == SLOT_STATE_PROCESSING) {
263
+ if (is_processing()) {
264
+ SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
265
+
245
266
  t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
246
- command = SLOT_COMMAND_RELEASE;
267
+ state = SLOT_STATE_IDLE;
268
+ callback_on_release(id);
247
269
  }
248
270
  }
249
271
 
@@ -290,49 +312,20 @@ struct server_slot {
290
312
  }
291
313
 
292
314
  void print_timings() const {
293
- char buffer[512];
294
-
295
- double t_token = t_prompt_processing / n_prompt_tokens_processed;
296
- double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
297
-
298
- snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
299
- t_prompt_processing, n_prompt_tokens_processed,
300
- t_token, n_tokens_second);
301
-
302
- LOG_INFO(buffer, {
303
- {"id_slot", id},
304
- {"id_task", id_task},
305
- {"t_prompt_processing", t_prompt_processing},
306
- {"n_prompt_tokens_processed", n_prompt_tokens_processed},
307
- {"t_token", t_token},
308
- {"n_tokens_second", n_tokens_second},
309
- });
310
-
311
- t_token = t_token_generation / n_decoded;
312
- n_tokens_second = 1e3 / t_token_generation * n_decoded;
313
-
314
- snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
315
- t_token_generation, n_decoded,
316
- t_token, n_tokens_second);
317
-
318
- LOG_INFO(buffer, {
319
- {"id_slot", id},
320
- {"id_task", id_task},
321
- {"t_token_generation", t_token_generation},
322
- {"n_decoded", n_decoded},
323
- {"t_token", t_token},
324
- {"n_tokens_second", n_tokens_second},
325
- });
326
-
327
- snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
328
-
329
- LOG_INFO(buffer, {
330
- {"id_slot", id},
331
- {"id_task", id_task},
332
- {"t_prompt_processing", t_prompt_processing},
333
- {"t_token_generation", t_token_generation},
334
- {"t_total", t_prompt_processing + t_token_generation},
335
- });
315
+ const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
316
+ const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
317
+
318
+ const double t_gen = t_token_generation / n_decoded;
319
+ const double n_gen_second = 1e3 / t_token_generation * n_decoded;
320
+
321
+ SLT_INF(*this,
322
+ "\n"
323
+ "\rprompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
324
+ "\r eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
325
+ "\r total time = %10.2f ms / %5d tokens\n",
326
+ t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
327
+ t_token_generation, n_decoded, t_gen, n_gen_second,
328
+ t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
336
329
  }
337
330
  };
338
331
 
@@ -350,6 +343,9 @@ struct server_metrics {
350
343
  uint64_t n_tokens_predicted = 0;
351
344
  uint64_t t_tokens_generation = 0;
352
345
 
346
+ uint64_t n_decode_total = 0;
347
+ uint64_t n_busy_slots_total = 0;
348
+
353
349
  void init() {
354
350
  t_start = ggml_time_us();
355
351
  }
@@ -368,6 +364,15 @@ struct server_metrics {
368
364
  t_tokens_generation_total += slot.t_token_generation;
369
365
  }
370
366
 
367
+ void on_decoded(const std::vector<server_slot> & slots) {
368
+ n_decode_total++;
369
+ for (const auto & slot : slots) {
370
+ if (slot.is_processing()) {
371
+ n_busy_slots_total++;
372
+ }
373
+ }
374
+ }
375
+
371
376
  void reset_bucket() {
372
377
  n_prompt_tokens_processed = 0;
373
378
  t_prompt_processing = 0;
@@ -381,42 +386,62 @@ struct server_queue {
381
386
  bool running;
382
387
 
383
388
  // queues
384
- std::vector<server_task> queue_tasks;
385
- std::vector<server_task> queue_tasks_deferred;
386
-
387
- std::vector<server_task_multi> queue_multitasks;
389
+ std::deque<server_task> queue_tasks;
390
+ std::deque<server_task> queue_tasks_deferred;
388
391
 
389
392
  std::mutex mutex_tasks;
390
393
  std::condition_variable condition_tasks;
391
394
 
392
395
  // callback functions
393
- std::function<void(server_task &)> callback_new_task;
394
- std::function<void(server_task_multi &)> callback_finish_multitask;
395
- std::function<void(void)> callback_update_slots;
396
+ std::function<void(server_task&)> callback_new_task;
397
+ std::function<void(void)> callback_update_slots;
396
398
 
397
399
  // Add a new task to the end of the queue
398
- int post(server_task task) {
400
+ int post(server_task task, bool front = false) {
399
401
  std::unique_lock<std::mutex> lock(mutex_tasks);
400
402
  if (task.id == -1) {
401
403
  task.id = id++;
402
- LOG_VERBOSE("new task id", {{"new_id", task.id}});
403
404
  }
404
- queue_tasks.push_back(std::move(task));
405
+ QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
406
+ if (front) {
407
+ queue_tasks.push_front(std::move(task));
408
+ } else {
409
+ queue_tasks.push_back(std::move(task));
410
+ }
405
411
  condition_tasks.notify_one();
406
412
  return task.id;
407
413
  }
408
414
 
415
+ // multi-task version of post()
416
+ int post(std::vector<server_task> & tasks, bool front = false) {
417
+ std::unique_lock<std::mutex> lock(mutex_tasks);
418
+ for (auto & task : tasks) {
419
+ if (task.id == -1) {
420
+ task.id = id++;
421
+ }
422
+ QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
423
+ if (front) {
424
+ queue_tasks.push_front(std::move(task));
425
+ } else {
426
+ queue_tasks.push_back(std::move(task));
427
+ }
428
+ }
429
+ condition_tasks.notify_one();
430
+ return 0;
431
+ }
432
+
409
433
  // Add a new task, but defer until one slot is available
410
434
  void defer(server_task task) {
411
435
  std::unique_lock<std::mutex> lock(mutex_tasks);
436
+ QUE_DBG("defer task, id = %d\n", task.id);
412
437
  queue_tasks_deferred.push_back(std::move(task));
438
+ condition_tasks.notify_one();
413
439
  }
414
440
 
415
- // Get the next id for creating anew task
441
+ // Get the next id for creating a new task
416
442
  int get_new_id() {
417
443
  std::unique_lock<std::mutex> lock(mutex_tasks);
418
444
  int new_id = id++;
419
- LOG_VERBOSE("new task id", {{"new_id", new_id}});
420
445
  return new_id;
421
446
  }
422
447
 
@@ -425,24 +450,19 @@ struct server_queue {
425
450
  callback_new_task = std::move(callback);
426
451
  }
427
452
 
428
- // Register function to process a multitask when it is finished
429
- void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
430
- callback_finish_multitask = std::move(callback);
431
- }
432
-
433
453
  // Register the function to be called when all slots data is ready to be processed
434
454
  void on_update_slots(std::function<void(void)> callback) {
435
455
  callback_update_slots = std::move(callback);
436
456
  }
437
457
 
438
- // Call when the state of one slot is changed
439
- void notify_slot_changed() {
440
- // move deferred tasks back to main loop
458
+ // Call when the state of one slot is changed, it will move one task from deferred to main queue
459
+ void pop_deferred_task() {
441
460
  std::unique_lock<std::mutex> lock(mutex_tasks);
442
- for (auto & task : queue_tasks_deferred) {
443
- queue_tasks.push_back(std::move(task));
461
+ if (!queue_tasks_deferred.empty()) {
462
+ queue_tasks.emplace_back(std::move(queue_tasks_deferred.front()));
463
+ queue_tasks_deferred.pop_front();
444
464
  }
445
- queue_tasks_deferred.clear();
465
+ condition_tasks.notify_one();
446
466
  }
447
467
 
448
468
  // end the start_loop routine
@@ -463,7 +483,7 @@ struct server_queue {
463
483
  running = true;
464
484
 
465
485
  while (true) {
466
- LOG_VERBOSE("new task may arrive", {});
486
+ QUE_DBG("%s", "processing new tasks\n");
467
487
 
468
488
  while (true) {
469
489
  std::unique_lock<std::mutex> lock(mutex_tasks);
@@ -472,39 +492,24 @@ struct server_queue {
472
492
  break;
473
493
  }
474
494
  server_task task = queue_tasks.front();
475
- queue_tasks.erase(queue_tasks.begin());
495
+ queue_tasks.pop_front();
476
496
  lock.unlock();
477
- LOG_VERBOSE("callback_new_task", {{"id_task", task.id}});
478
- callback_new_task(task);
479
- }
480
-
481
- LOG_VERBOSE("update_multitasks", {});
482
497
 
483
- // check if we have any finished multitasks
484
- auto queue_iterator = queue_multitasks.begin();
485
- while (queue_iterator != queue_multitasks.end()) {
486
- if (queue_iterator->subtasks_remaining.empty()) {
487
- // all subtasks done == multitask is done
488
- server_task_multi current_multitask = *queue_iterator;
489
- callback_finish_multitask(current_multitask);
490
- // remove this multitask
491
- queue_iterator = queue_multitasks.erase(queue_iterator);
492
- } else {
493
- ++queue_iterator;
494
- }
498
+ QUE_DBG("processing task, id = %d\n", task.id);
499
+ callback_new_task(task);
495
500
  }
496
501
 
497
502
  // all tasks in the current loop is processed, slots data is now ready
498
- LOG_VERBOSE("callback_update_slots", {});
503
+ QUE_DBG("%s", "update slots\n");
499
504
 
500
505
  callback_update_slots();
501
506
 
502
- LOG_VERBOSE("wait for new task", {});
507
+ QUE_DBG("%s", "waiting for new tasks\n");
503
508
  {
504
509
  std::unique_lock<std::mutex> lock(mutex_tasks);
505
510
  if (queue_tasks.empty()) {
506
511
  if (!running) {
507
- LOG_VERBOSE("ending start_loop", {});
512
+ QUE_DBG("%s", "terminate\n");
508
513
  return;
509
514
  }
510
515
  condition_tasks.wait(lock, [&]{
@@ -514,38 +519,11 @@ struct server_queue {
514
519
  }
515
520
  }
516
521
  }
517
-
518
- //
519
- // functions to manage multitasks
520
- //
521
-
522
- // add a multitask by specifying the id of all subtask (subtask is a server_task)
523
- void add_multitask(int id_multi, std::vector<int> & sub_ids) {
524
- std::lock_guard<std::mutex> lock(mutex_tasks);
525
- server_task_multi multi;
526
- multi.id = id_multi;
527
- std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
528
- queue_multitasks.push_back(multi);
529
- }
530
-
531
- // updatethe remaining subtasks, while appending results to multitask
532
- void update_multitask(int id_multi, int id_sub, server_task_result & result) {
533
- std::lock_guard<std::mutex> lock(mutex_tasks);
534
- for (auto & multitask : queue_multitasks) {
535
- if (multitask.id == id_multi) {
536
- multitask.subtasks_remaining.erase(id_sub);
537
- multitask.results.push_back(result);
538
- }
539
- }
540
- }
541
522
  };
542
523
 
543
524
  struct server_response {
544
- typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
545
- callback_multitask_t callback_update_multitask;
546
-
547
525
  // for keeping track of all tasks waiting for the result
548
- std::set<int> waiting_task_ids;
526
+ std::unordered_set<int> waiting_task_ids;
549
527
 
550
528
  // the main result queue
551
529
  std::vector<server_task_result> queue_results;
@@ -555,22 +533,40 @@ struct server_response {
555
533
 
556
534
  // add the id_task to the list of tasks waiting for response
557
535
  void add_waiting_task_id(int id_task) {
558
- LOG_VERBOSE("waiting for task id", {{"id_task", id_task}});
536
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
559
537
 
560
538
  std::unique_lock<std::mutex> lock(mutex_results);
561
539
  waiting_task_ids.insert(id_task);
562
540
  }
563
541
 
542
+ void add_waiting_tasks(const std::vector<server_task> & tasks) {
543
+ std::unique_lock<std::mutex> lock(mutex_results);
544
+
545
+ for (const auto & task : tasks) {
546
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
547
+ waiting_task_ids.insert(task.id);
548
+ }
549
+ }
550
+
564
551
  // when the request is finished, we can remove task associated with it
565
552
  void remove_waiting_task_id(int id_task) {
566
- LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}});
553
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
567
554
 
568
555
  std::unique_lock<std::mutex> lock(mutex_results);
569
556
  waiting_task_ids.erase(id_task);
570
557
  }
571
558
 
572
- // This function blocks the thread until there is a response for this id_task
573
- server_task_result recv(int id_task) {
559
+ void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
560
+ std::unique_lock<std::mutex> lock(mutex_results);
561
+
562
+ for (const auto & id_task : id_tasks) {
563
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
564
+ waiting_task_ids.erase(id_task);
565
+ }
566
+ }
567
+
568
+ // This function blocks the thread until there is a response for one of the id_tasks
569
+ server_task_result recv(const std::unordered_set<int> & id_tasks) {
574
570
  while (true) {
575
571
  std::unique_lock<std::mutex> lock(mutex_results);
576
572
  condition_results.wait(lock, [&]{
@@ -578,8 +574,7 @@ struct server_response {
578
574
  });
579
575
 
580
576
  for (int i = 0; i < (int) queue_results.size(); i++) {
581
- if (queue_results[i].id == id_task) {
582
- assert(queue_results[i].id_multi == -1);
577
+ if (id_tasks.find(queue_results[i].id) != id_tasks.end()) {
583
578
  server_task_result res = queue_results[i];
584
579
  queue_results.erase(queue_results.begin() + i);
585
580
  return res;
@@ -590,28 +585,22 @@ struct server_response {
590
585
  // should never reach here
591
586
  }
592
587
 
593
- // Register the function to update multitask
594
- void on_multitask_update(callback_multitask_t callback) {
595
- callback_update_multitask = std::move(callback);
588
+ // single-task version of recv()
589
+ server_task_result recv(int id_task) {
590
+ std::unordered_set<int> id_tasks = {id_task};
591
+ return recv(id_tasks);
596
592
  }
597
593
 
598
594
  // Send a new result to a waiting id_task
599
- void send(server_task_result result) {
600
- LOG_VERBOSE("send new result", {{"id_task", result.id}});
595
+ void send(server_task_result & result) {
596
+ SRV_DBG("sending result for task id = %d\n", result.id);
601
597
 
602
598
  std::unique_lock<std::mutex> lock(mutex_results);
603
599
  for (const auto & id_task : waiting_task_ids) {
604
- // LOG_TEE("waiting task id %i \n", id_task);
605
- // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
606
- if (result.id_multi == id_task) {
607
- LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}});
608
- callback_update_multitask(id_task, result.id, result);
609
- continue;
610
- }
611
-
612
600
  if (result.id == id_task) {
613
- LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}});
614
- queue_results.push_back(result);
601
+ SRV_DBG("task id = %d moved to result queue\n", result.id);
602
+
603
+ queue_results.push_back(std::move(result));
615
604
  condition_results.notify_all();
616
605
  return;
617
606
  }
@@ -622,13 +611,15 @@ struct server_response {
622
611
  struct server_context {
623
612
  llama_model * model = nullptr;
624
613
  llama_context * ctx = nullptr;
614
+ std::vector<llama_lora_adapter_container> loras;
625
615
 
626
616
  gpt_params params;
627
617
 
628
- llama_batch batch;
618
+ llama_batch batch = {};
629
619
 
630
620
  bool clean_kv_cache = true;
631
621
  bool add_bos_token = true;
622
+ bool has_eos_token = false;
632
623
 
633
624
  int32_t n_ctx; // total context for all clients / slots
634
625
 
@@ -663,8 +654,8 @@ struct server_context {
663
654
 
664
655
  // Clear any sampling context
665
656
  for (server_slot & slot : slots) {
666
- if (slot.ctx_sampling != nullptr) {
667
- llama_sampling_free(slot.ctx_sampling);
657
+ if (slot.smpl != nullptr) {
658
+ gpt_sampler_free(slot.smpl);
668
659
  }
669
660
  }
670
661
 
@@ -677,17 +668,23 @@ struct server_context {
677
668
  // dedicate one sequence to the system prompt
678
669
  params.n_parallel += 1;
679
670
 
680
- std::tie(model, ctx) = llama_init_from_gpt_params(params);
671
+ llama_init_result llama_init = llama_init_from_gpt_params(params);
672
+
673
+ model = llama_init.model;
674
+ ctx = llama_init.context;
675
+ loras = llama_init.lora_adapters;
676
+
681
677
  params.n_parallel -= 1; // but be sneaky about it
678
+
682
679
  if (model == nullptr) {
683
- LOG_ERROR("unable to load model", {{"model", params.model}});
680
+ SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
684
681
  return false;
685
682
  }
686
683
 
687
684
  n_ctx = llama_n_ctx(ctx);
688
685
 
689
- add_bos_token = llama_should_add_bos_token(model);
690
- GGML_ASSERT(llama_add_eos_token(model) != 1);
686
+ add_bos_token = llama_add_bos_token(model);
687
+ has_eos_token = !llama_add_eos_token(model);
691
688
 
692
689
  return true;
693
690
  }
@@ -703,7 +700,7 @@ struct server_context {
703
700
  void init() {
704
701
  const int32_t n_ctx_slot = n_ctx / params.n_parallel;
705
702
 
706
- LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
703
+ SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);
707
704
 
708
705
  for (int i = 0; i < params.n_parallel; i++) {
709
706
  server_slot slot;
@@ -712,10 +709,7 @@ struct server_context {
712
709
  slot.n_ctx = n_ctx_slot;
713
710
  slot.n_predict = params.n_predict;
714
711
 
715
- LOG_INFO("new slot", {
716
- {"id_slot", slot.id},
717
- {"n_ctx_slot", slot.n_ctx}
718
- });
712
+ SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
719
713
 
720
714
  const int ga_n = params.grp_attn_n;
721
715
  const int ga_w = params.grp_attn_w;
@@ -726,11 +720,7 @@ struct server_context {
726
720
  //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
727
721
  //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
728
722
 
729
- LOG_INFO("slot self-extend", {
730
- {"id_slot", slot.id},
731
- {"ga_n", ga_n},
732
- {"ga_w", ga_w}
733
- });
723
+ SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
734
724
  }
735
725
 
736
726
  slot.ga_i = 0;
@@ -739,6 +729,10 @@ struct server_context {
739
729
 
740
730
  slot.sparams = params.sparams;
741
731
 
732
+ slot.callback_on_release = [this](int) {
733
+ queue_tasks.pop_deferred_task();
734
+ };
735
+
742
736
  slot.reset();
743
737
 
744
738
  slots.push_back(slot);
@@ -747,13 +741,13 @@ struct server_context {
747
741
  default_generation_settings_for_props = get_formated_generation(slots.front());
748
742
  default_generation_settings_for_props["seed"] = -1;
749
743
 
750
- // the update_slots() logic will always submit a maximum of n_batch tokens
744
+ // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
751
745
  // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
752
746
  {
753
747
  const int32_t n_batch = llama_n_batch(ctx);
754
748
 
755
749
  // only a single seq_id per token is needed
756
- batch = llama_batch_init(n_batch, 0, 1);
750
+ batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
757
751
  }
758
752
 
759
753
  metrics.init();
@@ -820,7 +814,7 @@ struct server_context {
820
814
 
821
815
  for (server_slot & slot : slots) {
822
816
  // skip the slot if it is not available
823
- if (!slot.available()) {
817
+ if (slot.is_processing()) {
824
818
  continue;
825
819
  }
826
820
 
@@ -849,11 +843,7 @@ struct server_context {
849
843
  }
850
844
 
851
845
  if (ret != nullptr) {
852
- LOG_VERBOSE("selected slot by lcp similarity", {
853
- {"id_slot", ret->id},
854
- {"max_lcp_len", max_lcp_len},
855
- {"similarity", similarity},
856
- });
846
+ SLT_DBG(*ret, "selected slot by lcp similarity, max_lcp_len = %d, similarity = %f\n", max_lcp_len, similarity);
857
847
  }
858
848
  }
859
849
 
@@ -862,7 +852,7 @@ struct server_context {
862
852
  int64_t t_last = ggml_time_us();
863
853
  for (server_slot & slot : slots) {
864
854
  // skip the slot if it is not available
865
- if (!slot.available()) {
855
+ if (slot.is_processing()) {
866
856
  continue;
867
857
  }
868
858
 
@@ -874,10 +864,7 @@ struct server_context {
874
864
  }
875
865
 
876
866
  if (ret != nullptr) {
877
- LOG_VERBOSE("selected slot by lru", {
878
- {"id_slot", ret->id},
879
- {"t_last", t_last},
880
- });
867
+ SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last);
881
868
  }
882
869
  }
883
870
 
@@ -887,8 +874,8 @@ struct server_context {
887
874
  bool launch_slot_with_task(server_slot & slot, const server_task & task) {
888
875
  slot_params default_params;
889
876
  // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
890
- llama_sampling_params default_sparams = params.sparams;
891
- auto & data = task.data;
877
+ auto default_sparams = params.sparams;
878
+ const auto & data = task.data;
892
879
 
893
880
  if (data.count("__oaicompat") != 0) {
894
881
  slot.oaicompat = true;
@@ -900,12 +887,12 @@ struct server_context {
900
887
 
901
888
  slot.params.stream = json_value(data, "stream", false);
902
889
  slot.params.cache_prompt = json_value(data, "cache_prompt", false);
903
- slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict);
890
+ slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
904
891
  slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
905
892
  slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
906
893
  slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
907
894
  slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
908
- slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
895
+ slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
909
896
  slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
910
897
  slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
911
898
  slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
@@ -927,7 +914,8 @@ struct server_context {
927
914
  if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
928
915
  send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
929
916
  return false;
930
- } else if (data.contains("json_schema") && !data.contains("grammar")) {
917
+ }
918
+ if (data.contains("json_schema") && !data.contains("grammar")) {
931
919
  try {
932
920
  auto schema = json_value(data, "json_schema", json::object());
933
921
  slot.sparams.grammar = json_schema_to_grammar(schema);
@@ -940,17 +928,14 @@ struct server_context {
940
928
  }
941
929
 
942
930
  if (slot.params.cache_prompt && slot.ga_n != 1) {
943
- LOG_WARNING("cache_prompt is not supported with group-attention", {});
944
931
  slot.params.cache_prompt = false;
932
+ SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
945
933
  }
946
934
 
947
935
  if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
948
936
  // Might be better to reject the request with a 400 ?
949
- LOG_WARNING("Max tokens to predict exceeds server configuration", {
950
- {"params.n_predict", slot.params.n_predict},
951
- {"slot.n_predict", slot.n_predict},
952
- });
953
937
  slot.params.n_predict = slot.n_predict;
938
+ SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
954
939
  }
955
940
 
956
941
  // infill
@@ -958,7 +943,7 @@ struct server_context {
958
943
  slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
959
944
 
960
945
  // get prompt
961
- if (!task.infill) {
946
+ if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
962
947
  const auto & prompt = data.find("prompt");
963
948
  if (prompt == data.end()) {
964
949
  send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
@@ -969,62 +954,28 @@ struct server_context {
969
954
  (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
970
955
  (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
971
956
  slot.prompt = *prompt;
972
- } else {
973
- send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
974
- return false;
975
- }
976
- }
977
-
978
- // penalize user-provided tokens
979
- {
980
- slot.sparams.penalty_prompt_tokens.clear();
981
- slot.sparams.use_penalty_prompt_tokens = false;
982
-
983
- const auto & penalty_prompt = data.find("penalty_prompt");
984
-
985
- if (penalty_prompt != data.end()) {
986
- if (penalty_prompt->is_string()) {
987
- const auto penalty_prompt_string = penalty_prompt->get<std::string>();
988
- slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
989
-
990
- if (slot.params.n_predict > 0) {
991
- slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
992
- }
993
- slot.sparams.use_penalty_prompt_tokens = true;
994
-
995
- LOG_VERBOSE("penalty_prompt_tokens", {
996
- {"id_slot", slot.id},
997
- {"tokens", slot.sparams.penalty_prompt_tokens},
998
- });
999
- }
1000
- else if (penalty_prompt->is_array()) {
1001
- const auto n_tokens = penalty_prompt->size();
1002
- slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
1003
-
1004
- const int n_vocab = llama_n_vocab(model);
1005
- for (const auto & penalty_token : *penalty_prompt) {
1006
- if (penalty_token.is_number_integer()) {
1007
- const auto tok = penalty_token.get<llama_token>();
1008
- if (tok >= 0 && tok < n_vocab) {
1009
- slot.sparams.penalty_prompt_tokens.push_back(tok);
1010
- }
1011
- }
957
+ } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
958
+ slot.prompt = prompt->at(0);
959
+ } else if (prompt->is_array() && prompt->size() > 1) {
960
+ // array of strings
961
+ for (const auto & el : *prompt) {
962
+ if (!el.is_string()) {
963
+ send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
964
+ return false;
1012
965
  }
1013
- slot.sparams.use_penalty_prompt_tokens = true;
1014
-
1015
- LOG_VERBOSE("penalty_prompt_tokens", {
1016
- {"id_slot", slot.id},
1017
- {"tokens", slot.sparams.penalty_prompt_tokens},
1018
- });
1019
966
  }
967
+ slot.prompt = *prompt;
968
+ } else {
969
+ send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
970
+ return false;
1020
971
  }
1021
972
  }
1022
973
 
1023
974
  {
1024
975
  slot.sparams.logit_bias.clear();
1025
976
 
1026
- if (json_value(data, "ignore_eos", false)) {
1027
- slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
977
+ if (json_value(data, "ignore_eos", false) && has_eos_token) {
978
+ slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
1028
979
  }
1029
980
 
1030
981
  const auto & logit_bias = data.find("logit_bias");
@@ -1045,12 +996,12 @@ struct server_context {
1045
996
  if (el[0].is_number_integer()) {
1046
997
  llama_token tok = el[0].get<llama_token>();
1047
998
  if (tok >= 0 && tok < n_vocab) {
1048
- slot.sparams.logit_bias[tok] = bias;
999
+ slot.sparams.logit_bias.push_back({tok, bias});
1049
1000
  }
1050
1001
  } else if (el[0].is_string()) {
1051
1002
  auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
1052
1003
  for (auto tok : toks) {
1053
- slot.sparams.logit_bias[tok] = bias;
1004
+ slot.sparams.logit_bias.push_back({tok, bias});
1054
1005
  }
1055
1006
  }
1056
1007
  }
@@ -1072,45 +1023,43 @@ struct server_context {
1072
1023
  }
1073
1024
 
1074
1025
  {
1075
- const auto & samplers_sequence = data.find("samplers");
1076
- if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
1026
+ const auto & samplers = data.find("samplers");
1027
+ if (samplers != data.end() && samplers->is_array()) {
1077
1028
  std::vector<std::string> sampler_names;
1078
- for (const auto & sampler_name : *samplers_sequence) {
1079
- if (sampler_name.is_string()) {
1080
- sampler_names.emplace_back(sampler_name);
1029
+ for (const auto & name : *samplers) {
1030
+ if (name.is_string()) {
1031
+ sampler_names.emplace_back(name);
1081
1032
  }
1082
1033
  }
1083
- slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
1034
+ slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
1084
1035
  } else {
1085
- slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
1036
+ slot.sparams.samplers = default_sparams.samplers;
1086
1037
  }
1087
1038
  }
1088
1039
 
1089
1040
  {
1090
- if (slot.ctx_sampling != nullptr) {
1091
- llama_sampling_free(slot.ctx_sampling);
1041
+ if (slot.smpl != nullptr) {
1042
+ gpt_sampler_free(slot.smpl);
1092
1043
  }
1093
- slot.ctx_sampling = llama_sampling_init(slot.sparams);
1094
- if (slot.ctx_sampling == nullptr) {
1044
+
1045
+ slot.smpl = gpt_sampler_init(model, slot.sparams);
1046
+ if (slot.smpl == nullptr) {
1095
1047
  // for now, the only error that may happen here is invalid grammar
1096
1048
  send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
1097
1049
  return false;
1098
1050
  }
1099
1051
  }
1100
1052
 
1101
- slot.command = SLOT_COMMAND_LOAD_PROMPT;
1053
+ slot.state = SLOT_STATE_PROCESSING_PROMPT;
1102
1054
  slot.prompt_tokens.clear();
1103
1055
 
1104
- LOG_INFO("slot is processing task", {
1105
- {"id_slot", slot.id},
1106
- {"id_task", slot.id_task},
1107
- });
1056
+ SLT_INF(slot, "%s", "processing task\n");
1108
1057
 
1109
1058
  return true;
1110
1059
  }
1111
1060
 
1112
1061
  void kv_cache_clear() {
1113
- LOG_VERBOSE("clearing KV cache", {});
1062
+ SRV_DBG("%s", "clearing KV cache\n");
1114
1063
 
1115
1064
  // clear the entire KV cache
1116
1065
  llama_kv_cache_clear(ctx);
@@ -1118,9 +1067,7 @@ struct server_context {
1118
1067
  }
1119
1068
 
1120
1069
  void system_prompt_update() {
1121
- LOG_VERBOSE("system prompt update", {
1122
- {"system_prompt", system_prompt},
1123
- });
1070
+ SRV_DBG("updating system prompt: '%s'\n", system_prompt.c_str());
1124
1071
 
1125
1072
  kv_cache_clear();
1126
1073
  system_tokens.clear();
@@ -1128,29 +1075,20 @@ struct server_context {
1128
1075
  if (!system_prompt.empty()) {
1129
1076
  system_tokens = ::llama_tokenize(ctx, system_prompt, true);
1130
1077
 
1131
- llama_batch_clear(batch);
1078
+ const int32_t n_batch = llama_n_batch(ctx);
1079
+ const int32_t n_tokens_prompt = system_tokens.size();
1132
1080
 
1133
- for (int i = 0; i < (int)system_tokens.size(); ++i) {
1134
- llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
1135
- }
1081
+ for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
1082
+ const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
1136
1083
 
1137
- const int32_t n_batch = llama_n_batch(ctx);
1084
+ llama_batch_clear(batch);
1138
1085
 
1139
- for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
1140
- const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
1141
- llama_batch batch_view = {
1142
- n_tokens,
1143
- batch.token + i,
1144
- nullptr,
1145
- batch.pos + i,
1146
- batch.n_seq_id + i,
1147
- batch.seq_id + i,
1148
- batch.logits + i,
1149
- 0, 0, 0, // unused
1150
- };
1086
+ for (int32_t j = 0; j < n_tokens; ++j) {
1087
+ llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
1088
+ }
1151
1089
 
1152
- if (llama_decode(ctx, batch_view) != 0) {
1153
- LOG_ERROR("llama_decode() failed", {});
1090
+ if (llama_decode(ctx, batch) != 0) {
1091
+ SRV_ERR("%s", "llama_decode() failed\n");
1154
1092
  return;
1155
1093
  }
1156
1094
  }
@@ -1165,11 +1103,9 @@ struct server_context {
1165
1103
  }
1166
1104
 
1167
1105
  bool system_prompt_set(const std::string & sys_prompt) {
1168
- system_prompt = sys_prompt;
1106
+ SRV_DBG("system prompt set: '%s'\n", system_prompt.c_str());
1169
1107
 
1170
- LOG_VERBOSE("system prompt process", {
1171
- {"system_prompt", system_prompt},
1172
- });
1108
+ system_prompt = sys_prompt;
1173
1109
 
1174
1110
  // release all slots
1175
1111
  for (server_slot & slot : slots) {
@@ -1189,11 +1125,6 @@ struct server_context {
1189
1125
  slot.generated_text += token_str;
1190
1126
  slot.has_next_token = true;
1191
1127
 
1192
- if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
1193
- // we can change penalty_prompt_tokens because it is always created from scratch each request
1194
- slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
1195
- }
1196
-
1197
1128
  // check if there is incomplete UTF-8 character at the end
1198
1129
  bool incomplete = false;
1199
1130
  for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
@@ -1242,7 +1173,7 @@ struct server_context {
1242
1173
  // add the token to slot queue and cache
1243
1174
  }
1244
1175
 
1245
- slot.add_token_string(result);
1176
+ slot.add_token(result);
1246
1177
  if (slot.params.stream) {
1247
1178
  send_partial_response(slot, result);
1248
1179
  }
@@ -1257,74 +1188,56 @@ struct server_context {
1257
1188
  slot.stopped_limit = true;
1258
1189
  slot.has_next_token = false;
1259
1190
 
1260
- LOG_VERBOSE("stopped by limit", {
1261
- {"id_slot", slot.id},
1262
- {"id_task", slot.id_task},
1263
- {"n_decoded", slot.n_decoded},
1264
- {"n_predict", slot.params.n_predict},
1265
- });
1191
+ SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
1192
+ }
1193
+
1194
+ // if context shift is disabled, we stop when it reaches the context limit
1195
+ if (slot.n_decoded >= slot.n_ctx) {
1196
+ slot.truncated = true;
1197
+ slot.stopped_limit = true;
1198
+ slot.has_next_token = false;
1199
+
1200
+ SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
1266
1201
  }
1267
1202
 
1268
1203
  if (llama_token_is_eog(model, result.tok)) {
1269
1204
  slot.stopped_eos = true;
1270
1205
  slot.has_next_token = false;
1271
1206
 
1272
- LOG_VERBOSE("eos token found", {});
1273
- }
1274
-
1275
- auto n_ctx_train = llama_n_ctx_train(model);
1276
- if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
1277
- && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
1278
- LOG_WARNING("n_predict is not set and self-context extend is disabled."
1279
- " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
1280
- { "id_slot", slot.id },
1281
- { "params.n_predict", slot.params.n_predict },
1282
- { "slot.n_prompt_tokens", slot.n_prompt_tokens },
1283
- { "slot.n_decoded", slot.n_decoded },
1284
- { "slot.n_predict", slot.n_predict },
1285
- { "n_slots", params.n_parallel },
1286
- { "slot.n_ctx", slot.n_ctx },
1287
- { "n_ctx", n_ctx },
1288
- { "n_ctx_train", n_ctx_train },
1289
- { "ga_n", slot.ga_n },
1290
- });
1207
+ SLT_DBG(slot, "%s", "stopped by EOS\n");
1208
+ }
1209
+
1210
+ const auto n_ctx_train = llama_n_ctx_train(model);
1211
+
1212
+ if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
1291
1213
  slot.truncated = true;
1292
1214
  slot.stopped_limit = true;
1293
1215
  slot.has_next_token = false; // stop prediction
1216
+
1217
+ SLT_WRN(slot,
1218
+ "n_predict (%d) is not set and self-context extend is disabled. "
1219
+ "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
1220
+ slot.params.n_predict, n_ctx_train);
1294
1221
  }
1295
1222
 
1296
- LOG_VERBOSE("next token", {
1297
- {"id_slot", slot.id},
1298
- {"id_task", slot.id_task},
1299
- {"token", result.tok},
1300
- {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
1301
- {"has_next_token", slot.has_next_token},
1302
- {"n_remain", slot.n_remaining},
1303
- {"n_decoded", slot.n_decoded},
1304
- {"stopped_eos", slot.stopped_eos},
1305
- {"stopped_word", slot.stopped_word},
1306
- {"stopped_limit", slot.stopped_limit},
1307
- {"stopping_word", slot.stopping_word},
1308
- });
1223
+ SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
1309
1224
 
1310
1225
  return slot.has_next_token; // continue
1311
1226
  }
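The stop conditions checked above can be condensed into a small standalone sketch (toy struct and field names, not the server's own; n_predict < 0 stands for "no user limit"):

    #include <cstdio>

    struct toy_slot {
        int  n_decoded;   // tokens generated so far
        int  n_predict;   // user limit, < 0 means unlimited
        int  n_ctx;       // per-slot context size
        bool saw_eog;     // an end-of-generation token was sampled
    };

    // mirrors the order of the checks above: predict budget, context capacity, EOG token
    static bool has_next_token(const toy_slot & s) {
        if (s.n_predict >= 0 && s.n_decoded >= s.n_predict) return false; // stopped by limit
        if (s.n_decoded >= s.n_ctx)                         return false; // ran out of context
        if (s.saw_eog)                                      return false; // stopped by EOS/EOG
        return true;
    }

    int main() {
        std::printf("%d\n", has_next_token({ 10, 16, 4096, false })); // 1 - keep generating
        std::printf("%d\n", has_next_token({ 16, 16, 4096, false })); // 0 - n_predict reached
        std::printf("%d\n", has_next_token({ 10, -1, 4096, true  })); // 0 - EOG sampled
        return 0;
    }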
1312
1227
 
1313
1228
  json get_formated_generation(const server_slot & slot) const {
1314
- const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
1315
- const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
1316
-
1317
- std::vector<std::string> samplers_sequence;
1318
- samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
1319
- for (const auto & sampler_type : slot.sparams.samplers_sequence) {
1320
- samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
1229
+ std::vector<std::string> samplers;
1230
+ samplers.reserve(slot.sparams.samplers.size());
1231
+ for (const auto & sampler : slot.sparams.samplers) {
1232
+ samplers.emplace_back(gpt_sampler_type_to_str(sampler));
1321
1233
  }
1322
1234
 
1323
1235
  return json {
1324
1236
  {"n_ctx", slot.n_ctx},
1325
- {"n_predict", slot.n_predict},
1237
+ {"n_predict", slot.n_predict}, // Server configured n_predict
1326
1238
  {"model", params.model_alias},
1327
1239
  {"seed", slot.sparams.seed},
1240
+ {"seed_cur", slot.smpl ? gpt_sampler_get_seed(slot.smpl) : 0},
1328
1241
  {"temperature", slot.sparams.temp},
1329
1242
  {"dynatemp_range", slot.sparams.dynatemp_range},
1330
1243
  {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
@@ -1332,49 +1245,42 @@ struct server_context {
1332
1245
  {"top_p", slot.sparams.top_p},
1333
1246
  {"min_p", slot.sparams.min_p},
1334
1247
  {"tfs_z", slot.sparams.tfs_z},
1335
- {"typical_p", slot.sparams.typical_p},
1248
+ {"typical_p", slot.sparams.typ_p},
1336
1249
  {"repeat_last_n", slot.sparams.penalty_last_n},
1337
1250
  {"repeat_penalty", slot.sparams.penalty_repeat},
1338
1251
  {"presence_penalty", slot.sparams.penalty_present},
1339
1252
  {"frequency_penalty", slot.sparams.penalty_freq},
1340
- {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
1341
- {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
1342
1253
  {"mirostat", slot.sparams.mirostat},
1343
1254
  {"mirostat_tau", slot.sparams.mirostat_tau},
1344
1255
  {"mirostat_eta", slot.sparams.mirostat_eta},
1345
1256
  {"penalize_nl", slot.sparams.penalize_nl},
1346
1257
  {"stop", slot.params.antiprompt},
1347
- {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
1258
+ {"max_tokens", slot.params.n_predict}, // User configured n_predict
1348
1259
  {"n_keep", slot.params.n_keep},
1349
1260
  {"n_discard", slot.params.n_discard},
1350
- {"ignore_eos", ignore_eos},
1261
+ {"ignore_eos", slot.sparams.ignore_eos},
1351
1262
  {"stream", slot.params.stream},
1352
- {"logit_bias", slot.sparams.logit_bias},
1263
+ //{"logit_bias", slot.sparams.logit_bias},
1353
1264
  {"n_probs", slot.sparams.n_probs},
1354
1265
  {"min_keep", slot.sparams.min_keep},
1355
1266
  {"grammar", slot.sparams.grammar},
1356
- {"samplers", samplers_sequence}
1267
+ {"samplers", samplers},
1357
1268
  };
1358
1269
  }
1359
1270
 
1360
1271
  void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1361
- send_error(task.id, task.id_multi, error, type);
1272
+ send_error(task.id, error, type);
1362
1273
  }
1363
1274
 
1364
1275
  void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1365
- send_error(slot.id_task, slot.id_multi, error, type);
1276
+ send_error(slot.id_task, error, type);
1366
1277
  }
1367
1278
 
1368
- void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1369
- LOG_ERROR("task error", {
1370
- {"id_multi", id_multi},
1371
- {"id_task", id_task},
1372
- {"error", error},
1373
- });
1279
+ void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1280
+ SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
1374
1281
 
1375
1282
  server_task_result res;
1376
1283
  res.id = id_task;
1377
- res.id_multi = id_multi;
1378
1284
  res.stop = false;
1379
1285
  res.error = true;
1380
1286
  res.data = format_error_response(error, type);
@@ -1385,14 +1291,14 @@ struct server_context {
1385
1291
  void send_partial_response(server_slot & slot, completion_token_output tkn) {
1386
1292
  server_task_result res;
1387
1293
  res.id = slot.id_task;
1388
- res.id_multi = slot.id_multi;
1389
1294
  res.error = false;
1390
1295
  res.stop = false;
1391
1296
  res.data = json {
1392
1297
  {"content", tkn.text_to_send},
1393
1298
  {"stop", false},
1394
1299
  {"id_slot", slot.id},
1395
- {"multimodal", false}
1300
+ {"multimodal", false},
1301
+ {"index", slot.index},
1396
1302
  };
1397
1303
 
1398
1304
  if (slot.sparams.n_probs > 0) {
@@ -1422,7 +1328,6 @@ struct server_context {
1422
1328
  void send_final_response(const server_slot & slot) {
1423
1329
  server_task_result res;
1424
1330
  res.id = slot.id_task;
1425
- res.id_multi = slot.id_multi;
1426
1331
  res.error = false;
1427
1332
  res.stop = true;
1428
1333
  res.data = json {
@@ -1440,7 +1345,8 @@ struct server_context {
1440
1345
  {"stopped_limit", slot.stopped_limit},
1441
1346
  {"stopping_word", slot.stopping_word},
1442
1347
  {"tokens_cached", slot.n_past},
1443
- {"timings", slot.get_formated_timings()}
1348
+ {"timings", slot.get_formated_timings()},
1349
+ {"index", slot.index},
1444
1350
  };
1445
1351
 
1446
1352
  if (slot.sparams.n_probs > 0) {
@@ -1472,7 +1378,6 @@ struct server_context {
1472
1378
  void send_embedding(const server_slot & slot, const llama_batch & batch) {
1473
1379
  server_task_result res;
1474
1380
  res.id = slot.id_task;
1475
- res.id_multi = slot.id_multi;
1476
1381
  res.error = false;
1477
1382
  res.stop = true;
1478
1383
 
@@ -1491,13 +1396,11 @@ struct server_context {
1491
1396
  }
1492
1397
 
1493
1398
  if (embd == NULL) {
1494
- LOG_ERROR("failed to get embeddings", {
1495
- {"token", batch.token [i]},
1496
- {"seq_id", batch.seq_id[i][0]}
1497
- });
1399
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
1498
1400
 
1499
1401
  res.data = json {
1500
1402
  {"embedding", std::vector<float>(n_embd, 0.0f)},
1403
+ {"index", slot.index},
1501
1404
  };
1502
1405
 
1503
1406
  continue;
@@ -1507,83 +1410,191 @@ struct server_context {
1507
1410
 
1508
1411
  res.data = json {
1509
1412
  {"embedding", embd_res},
1413
+ {"index", slot.index},
1510
1414
  };
1511
1415
  }
1512
1416
 
1417
+ SLT_DBG(slot, "%s", "sending embeddings\n");
1418
+
1513
1419
  queue_results.send(res);
1514
1420
  }
1515
1421
 
1516
- void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) {
1517
- server_task task;
1518
- task.id = id_task;
1519
- task.id_multi = id_multi;
1520
- task.id_target = 0;
1521
- task.data = std::move(data);
1522
- task.infill = infill;
1523
- task.embedding = embedding;
1524
- task.type = SERVER_TASK_TYPE_COMPLETION;
1525
-
1526
- // when a completion task's prompt array is not a singleton, we split it into multiple requests
1527
- // otherwise, it's a single-prompt task, we actually queue it
1528
- // if there's numbers in the prompt array it will be treated as an array of tokens
1529
- if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
1530
- bool numbers = false;
1531
- for (const auto & e : task.data.at("prompt")) {
1532
- if (e.is_number()) {
1533
- numbers = true;
1534
- break;
1535
- }
1422
+ void send_rerank(const server_slot & slot, const llama_batch & batch) {
1423
+ server_task_result res;
1424
+ res.id = slot.id_task;
1425
+ res.error = false;
1426
+ res.stop = true;
1427
+
1428
+ for (int i = 0; i < batch.n_tokens; ++i) {
1429
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
1430
+ continue;
1536
1431
  }
1537
1432
 
1538
- // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
1539
- // it will completely stall the server. I don't know where the bug for this is.
1540
- //
1541
- // if there are numbers, it needs to be treated like a single prompt,
1542
- // queue_tasks handles a mix of strings and numbers just fine.
1543
- if (numbers) {
1544
- queue_tasks.post(task);
1545
- } else {
1546
- split_multiprompt_task(id_task, task);
1433
+ const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
1434
+ if (embd == NULL) {
1435
+ embd = llama_get_embeddings_ith(ctx, i);
1547
1436
  }
1548
- } else {
1549
- queue_tasks.post(task);
1437
+
1438
+ if (embd == NULL) {
1439
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
1440
+
1441
+ res.data = json {
1442
+ {"index", slot.index},
1443
+ {"score", -1e6},
1444
+ };
1445
+
1446
+ continue;
1447
+ }
1448
+
1449
+ res.data = json {
1450
+ {"index", slot.index},
1451
+ {"score", embd[0]},
1452
+ };
1550
1453
  }
1551
- }
1552
1454
 
1553
- void request_cancel(int id_task) {
1554
- server_task task;
1555
- task.type = SERVER_TASK_TYPE_CANCEL;
1556
- task.id_target = id_task;
1455
+ SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
1557
1456
 
1558
- queue_tasks.post(task);
1457
+ queue_results.send(res);
1559
1458
  }
1560
1459
 
1561
- void split_multiprompt_task(int id_multi, const server_task & multiprompt_task) {
1562
- const int prompt_count = multiprompt_task.data.at("prompt").size();
1563
- if (prompt_count <= 1) {
1564
- send_error(multiprompt_task, "error while handling multiple prompts");
1565
- return;
1460
+ //
1461
+ // Functions to create new task(s) and receive result(s)
1462
+ //
1463
+
1464
+ std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
1465
+ std::vector<server_task> tasks;
1466
+ auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
1467
+ server_task task;
1468
+ task.id = queue_tasks.get_new_id();
1469
+ task.cmpl_type = cmpl_type;
1470
+ task.type = SERVER_TASK_TYPE_COMPLETION;
1471
+ if (replace_prompt) {
1472
+ task.data = task_data;
1473
+ task.data["prompt"] = std::move(prompt);
1474
+ } else {
1475
+ task.data = std::move(task_data);
1476
+ }
1477
+ tasks.push_back(std::move(task));
1478
+ };
1479
+
1480
+ static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
1481
+ if (!data.contains("prompt")) {
1482
+ throw std::runtime_error(error_msg);
1483
+ }
1484
+
1485
+ json prompt = data.at("prompt");
1486
+
1487
+ // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create a single task
1488
+ if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
1489
+ data["index"] = 0;
1490
+ create_task(data, false, nullptr);
1491
+ }
1492
+ // otherwise, it's a multiple-prompt task, so we break it into smaller tasks
1493
+ else if (prompt.is_array()) {
1494
+ std::vector<json> prompts = prompt;
1495
+ if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
1496
+ // prompts[0] is the question
1497
+ // the rest are the answers/documents
1498
+ SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
1499
+ for (size_t i = 1; i < prompts.size(); i++) {
1500
+ json qd;
1501
+ qd.push_back(prompts[0]);
1502
+ qd.push_back(prompts[i]);
1503
+ data["index"] = i - 1;
1504
+ create_task(data, true, qd);
1505
+ }
1506
+ } else {
1507
+ SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
1508
+ for (size_t i = 0; i < prompts.size(); i++) {
1509
+ const auto & e = prompts[i];
1510
+ if (e.is_string() || json_is_array_of_numbers(e)) {
1511
+ data["index"] = i;
1512
+ create_task(data, true, e);
1513
+ } else {
1514
+ throw std::runtime_error(error_msg);
1515
+ }
1516
+ }
1517
+ }
1566
1518
  }
1519
+ // invalid case
1520
+ else {
1521
+ throw std::runtime_error(error_msg);
1522
+ }
1523
+
1524
+ return tasks;
1525
+ }
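To make the fan-out concrete, here is a self-contained sketch (hypothetical toy types, standard C++ only) of how a rerank request with prompts = { question, doc_0, doc_1, ... } is split into one task per document, each carrying the index that is later used to order the results:

    #include <cstdio>
    #include <string>
    #include <vector>

    struct toy_task {
        int         index;     // position of the document within the request
        std::string question;
        std::string document;
    };

    static std::vector<toy_task> split_rerank(const std::vector<std::string> & prompts) {
        std::vector<toy_task> tasks;
        // prompts[0] is the question, the rest are the documents to score against it
        for (size_t i = 1; i < prompts.size(); i++) {
            tasks.push_back({ (int) i - 1, prompts[0], prompts[i] });
        }
        return tasks;
    }

    int main() {
        const auto tasks = split_rerank({ "what is pizza?", "doc about pizza", "doc about cars" });
        for (const auto & t : tasks) {
            std::printf("index = %d, question = '%s', document = '%s'\n",
                        t.index, t.question.c_str(), t.document.c_str());
        }
        return 0;
    }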
1526
+
1527
+ void cancel_tasks(const std::unordered_set<int> & id_tasks) {
1528
+ std::vector<server_task> cancel_tasks;
1529
+ cancel_tasks.reserve(id_tasks.size());
1530
+ for (const auto & id_task : id_tasks) {
1531
+ SRV_WRN("cancel task, id_task = %d\n", id_task);
1532
+
1533
+ server_task task;
1534
+ task.type = SERVER_TASK_TYPE_CANCEL;
1535
+ task.id_target = id_task;
1536
+ cancel_tasks.push_back(task);
1537
+ queue_results.remove_waiting_task_id(id_task);
1538
+ }
1539
+ // push to beginning of the queue, so it has highest priority
1540
+ queue_tasks.post(cancel_tasks, true);
1541
+ }
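The "post at the front" behaviour can be pictured with a plain std::deque (the real queue type is not shown in this diff, so this is only an analogy):

    #include <cstdio>
    #include <deque>
    #include <string>

    int main() {
        std::deque<std::string> queue = { "completion task #7", "completion task #8" };

        // queue_tasks.post(cancel_tasks, /*front=*/true) behaves like a push to the head:
        queue.push_front("cancel task targeting #7");

        for (const auto & t : queue) {
            std::printf("%s\n", t.c_str()); // the cancel task is dequeued first
        }
        return 0;
    }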
1542
+
1543
+ // receive the results from task(s) created by create_tasks_cmpl
1544
+ void receive_cmpl_results(
1545
+ const std::unordered_set<int> & id_tasks,
1546
+ const std::function<void(std::vector<server_task_result>&)> & result_handler,
1547
+ const std::function<void(json)> & error_handler) {
1548
+ // TODO: currently, there is no way to detect the client has cancelled the request
1549
+ std::vector<server_task_result> results(id_tasks.size());
1550
+ for (size_t i = 0; i < id_tasks.size(); i++) {
1551
+ server_task_result result = queue_results.recv(id_tasks);
1552
+
1553
+ if (result.error) {
1554
+ error_handler(result.data);
1555
+ cancel_tasks(id_tasks);
1556
+ return;
1557
+ }
1567
1558
 
1568
- // generate all the ID for subtask
1569
- std::vector<int> subtask_ids(prompt_count);
1570
- for (int i = 0; i < prompt_count; i++) {
1571
- subtask_ids[i] = queue_tasks.get_new_id();
1559
+ const size_t idx = result.data["index"];
1560
+ GGML_ASSERT(idx < results.size() && "index out of range");
1561
+
1562
+ results[idx] = result;
1572
1563
  }
1564
+ result_handler(results);
1565
+ }
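A tiny standalone illustration of the index-based gathering used here (made-up results; the point is that arrival order does not matter, the stored index does):

    #include <cstdio>
    #include <vector>

    struct toy_result {
        int          index;   // the "index" stored in the task data at creation time
        const char * text;
    };

    int main() {
        // results can come back in any order...
        const toy_result arrivals[] = { { 2, "gamma" }, { 0, "alpha" }, { 1, "beta" } };

        // ...and are slotted back into request order by index, exactly one per task
        std::vector<const char *> ordered(3);
        for (const auto & r : arrivals) {
            ordered[r.index] = r.text;
        }

        for (const auto * s : ordered) {
            std::printf("%s\n", s); // alpha, beta, gamma
        }
        return 0;
    }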
1573
1566
 
1574
- // queue up the multitask so we can track its subtask progression
1575
- queue_tasks.add_multitask(id_multi, subtask_ids);
1567
+ // receive the results from task(s) created by create_tasks_cmpl, in stream mode
1568
+ void receive_cmpl_results_stream(
1569
+ const std::unordered_set<int> & id_tasks,
1570
+ const std::function<bool(server_task_result&)> & result_handler,
1571
+ const std::function<void(json)> & error_handler) {
1572
+ size_t n_finished = 0;
1573
+ while (true) {
1574
+ server_task_result result = queue_results.recv(id_tasks);
1575
+ if (!result_handler(result)) {
1576
+ cancel_tasks(id_tasks);
1577
+ break;
1578
+ }
1576
1579
 
1577
- // add subtasks
1578
- for (int i = 0; i < prompt_count; i++) {
1579
- json subtask_data = multiprompt_task.data;
1580
- subtask_data["prompt"] = subtask_data.at("prompt")[i];
1580
+ if (result.error) {
1581
+ error_handler(result.data);
1582
+ cancel_tasks(id_tasks);
1583
+ break;
1584
+ }
1581
1585
 
1582
- // subtasks inherit everything else (infill mode, embedding mode, etc.)
1583
- request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding);
1586
+ if (result.stop) {
1587
+ if (++n_finished == id_tasks.size()) {
1588
+ break;
1589
+ }
1590
+ }
1584
1591
  }
1585
1592
  }
1586
1593
 
1594
+ //
1595
+ // Functions to process the task
1596
+ //
1597
+
1587
1598
  void process_single_task(const server_task & task) {
1588
1599
  switch (task.type) {
1589
1600
  case SERVER_TASK_TYPE_COMPLETION:
@@ -1605,13 +1616,13 @@ struct server_context {
1605
1616
 
1606
1617
  if (slot == nullptr) {
1607
1618
  // if no slot is available, we defer this task for processing later
1608
- LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
1619
+ SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
1609
1620
  queue_tasks.defer(task);
1610
1621
  break;
1611
1622
  }
1612
- if (!slot->available()) {
1623
+ if (slot->is_processing()) {
1613
1624
  // if requested slot is unavailable, we defer this task for processing later
1614
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1625
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
1615
1626
  queue_tasks.defer(task);
1616
1627
  break;
1617
1628
  }
@@ -1629,12 +1640,11 @@ struct server_context {
1629
1640
  slot->reset();
1630
1641
 
1631
1642
  slot->id_task = task.id;
1632
- slot->id_multi = task.id_multi;
1633
- slot->infill = task.infill;
1634
- slot->embedding = task.embedding;
1643
+ slot->cmpl_type = task.cmpl_type;
1644
+ slot->index = json_value(task.data, "index", 0);
1635
1645
 
1636
1646
  if (!launch_slot_with_task(*slot, task)) {
1637
- LOG_ERROR("error while launching slot", task.data);
1647
+ SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
1638
1648
  break;
1639
1649
  }
1640
1650
  } break;
@@ -1683,22 +1693,10 @@ struct server_context {
1683
1693
 
1684
1694
  slots_data.push_back(slot_data);
1685
1695
  }
1686
- LOG_INFO("slot data", {
1687
- {"id_task", task.id},
1688
- {"n_idle_slots", n_idle_slots},
1689
- {"n_processing_slots", n_processing_slots}
1690
- });
1691
-
1692
- LOG_VERBOSE("slot data", {
1693
- {"id_task", task.id},
1694
- {"n_idle_slots", n_idle_slots},
1695
- {"n_processing_slots", n_processing_slots},
1696
- {"slots", slots_data}
1697
- });
1696
+ SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
1698
1697
 
1699
1698
  server_task_result res;
1700
1699
  res.id = task.id;
1701
- res.id_multi = task.id_multi;
1702
1700
  res.stop = true;
1703
1701
  res.error = false;
1704
1702
  res.data = {
@@ -1717,6 +1715,9 @@ struct server_context {
1717
1715
  { "n_tokens_predicted", metrics.n_tokens_predicted},
1718
1716
  { "t_tokens_generation", metrics.t_tokens_generation},
1719
1717
 
1718
+ { "n_decode_total", metrics.n_decode_total},
1719
+ { "n_busy_slots_total", metrics.n_busy_slots_total},
1720
+
1720
1721
  { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
1721
1722
  { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
1722
1723
 
@@ -1736,9 +1737,9 @@ struct server_context {
1736
1737
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1737
1738
  break;
1738
1739
  }
1739
- if (!slot->available()) {
1740
+ if (slot->is_processing()) {
1740
1741
  // if requested slot is unavailable, we defer this task for processing later
1741
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1742
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
1742
1743
  queue_tasks.defer(task);
1743
1744
  break;
1744
1745
  }
@@ -1777,9 +1778,9 @@ struct server_context {
1777
1778
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1778
1779
  break;
1779
1780
  }
1780
- if (!slot->available()) {
1781
+ if (slot->is_processing()) {
1781
1782
  // if requested slot is unavailable, we defer this task for processing later
1782
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1783
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
1783
1784
  queue_tasks.defer(task);
1784
1785
  break;
1785
1786
  }
@@ -1825,9 +1826,9 @@ struct server_context {
1825
1826
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1826
1827
  break;
1827
1828
  }
1828
- if (!slot->available()) {
1829
+ if (slot->is_processing()) {
1829
1830
  // if requested slot is unavailable, we defer this task for processing later
1830
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1831
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
1831
1832
  queue_tasks.defer(task);
1832
1833
  break;
1833
1834
  }
@@ -1847,68 +1848,37 @@ struct server_context {
1847
1848
  };
1848
1849
  queue_results.send(result);
1849
1850
  } break;
1851
+ case SERVER_TASK_TYPE_SET_LORA:
1852
+ {
1853
+ llama_lora_adapters_apply(ctx, loras);
1854
+ server_task_result result;
1855
+ result.id = task.id;
1856
+ result.stop = true;
1857
+ result.error = false;
1858
+ result.data = json{{ "success", true }};
1859
+ queue_results.send(result);
1860
+ } break;
1850
1861
  }
1851
1862
  }
1852
1863
 
1853
- void on_finish_multitask(const server_task_multi & multitask) {
1854
- // all subtasks done == multitask is done
1855
- server_task_result result;
1856
- result.id = multitask.id;
1857
- result.stop = true;
1858
- result.error = false;
1859
-
1860
- // collect json results into one json result
1861
- std::vector<json> result_jsons;
1862
- for (const auto & subres : multitask.results) {
1863
- result_jsons.push_back(subres.data);
1864
- result.error = result.error && subres.error;
1865
- }
1866
- result.data = json {
1867
- { "results", result_jsons }
1868
- };
1869
-
1870
- queue_results.send(result);
1871
- }
1872
-
1873
1864
  void update_slots() {
1874
1865
  if (system_need_update) {
1875
1866
  system_prompt_update();
1876
1867
  }
1877
1868
 
1878
- // release slots
1879
- for (auto & slot : slots) {
1880
- if (slot.command == SLOT_COMMAND_RELEASE) {
1881
- slot.state = SLOT_STATE_IDLE;
1882
- slot.command = SLOT_COMMAND_NONE;
1883
- slot.t_last_used = ggml_time_us();
1884
-
1885
- LOG_INFO("slot released", {
1886
- {"id_slot", slot.id},
1887
- {"id_task", slot.id_task},
1888
- {"n_ctx", n_ctx},
1889
- {"n_past", slot.n_past},
1890
- {"n_system_tokens", system_tokens.size()},
1891
- {"n_cache_tokens", slot.cache_tokens.size()},
1892
- {"truncated", slot.truncated}
1893
- });
1894
-
1895
- queue_tasks.notify_slot_changed();
1896
- }
1897
- }
1898
-
1899
1869
  // check if all slots are idle
1900
1870
  {
1901
1871
  bool all_idle = true;
1902
1872
 
1903
1873
  for (auto & slot : slots) {
1904
- if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
1874
+ if (slot.is_processing()) {
1905
1875
  all_idle = false;
1906
1876
  break;
1907
1877
  }
1908
1878
  }
1909
1879
 
1910
1880
  if (all_idle) {
1911
- LOG_INFO("all slots are idle", {});
1881
+ SRV_INF("%s", "all slots are idle\n");
1912
1882
  if (system_prompt.empty() && clean_kv_cache) {
1913
1883
  kv_cache_clear();
1914
1884
  }
@@ -1918,7 +1888,7 @@ struct server_context {
1918
1888
  }
1919
1889
 
1920
1890
  {
1921
- LOG_VERBOSE("posting NEXT_RESPONSE", {});
1891
+ SRV_DBG("%s", "posting NEXT_RESPONSE\n");
1922
1892
 
1923
1893
  server_task task;
1924
1894
  task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
@@ -1932,22 +1902,20 @@ struct server_context {
1932
1902
  for (server_slot & slot : slots) {
1933
1903
  if (slot.ga_n == 1) {
1934
1904
  if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
1905
+ if (!params.ctx_shift) {
1906
+ // this check is redundant (kept intentionally, as a safeguard)
1907
+ // we should never get here, because generation should have already stopped in process_token()
1908
+ slot.release();
1909
+ send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
1910
+ continue;
1911
+ }
1912
+
1935
1913
  // Shift context
1936
1914
  const int n_keep = slot.params.n_keep + add_bos_token;
1937
1915
  const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
1938
1916
  const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
1939
1917
 
1940
- LOG_INFO("slot context shift", {
1941
- {"id_slot", slot.id},
1942
- {"id_task", slot.id_task},
1943
- {"n_keep", n_keep},
1944
- {"n_left", n_left},
1945
- {"n_discard", n_discard},
1946
- {"n_ctx", n_ctx},
1947
- {"n_past", slot.n_past},
1948
- {"n_system_tokens", system_tokens.size()},
1949
- {"n_cache_tokens", slot.cache_tokens.size()}
1950
- });
1918
+ SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
1951
1919
 
1952
1920
  llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
1953
1921
  llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
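As a worked example of the arithmetic above (illustrative numbers, no system prompt): with n_past = 4095 and n_keep = 256, half of the non-kept tokens are discarded and the remainder is slid down in the KV cache.

    #include <cstdio>

    int main() {
        const int n_past    = 4095; // assumed slot position, one token away from the limit
        const int n_keep    = 256;  // assumed slot.params.n_keep (BOS already accounted for)
        const int n_left    = n_past - n_keep;  // 3839 tokens that are allowed to move
        const int n_discard = n_left / 2;       // 1919 (when slot.params.n_discard == 0)

        std::printf("remove KV range [%d, %d)\n", n_keep, n_keep + n_discard);  // [256, 2175)
        std::printf("shift  KV range [%d, %d) by %d\n",
                    n_keep + n_discard, n_past, -n_discard);                    // [2175, 4095) -> [256, 2176)
        return 0;
    }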
@@ -1972,7 +1940,7 @@ struct server_context {
1972
1940
 
1973
1941
  // first, add sampled tokens from any ongoing sequences
1974
1942
  for (auto & slot : slots) {
1975
- if (slot.state == SLOT_STATE_IDLE) {
1943
+ if (slot.state != SLOT_STATE_GENERATING) {
1976
1944
  continue;
1977
1945
  }
1978
1946
 
@@ -1990,15 +1958,8 @@ struct server_context {
1990
1958
  slot.cache_tokens.push_back(slot.sampled);
1991
1959
  }
1992
1960
 
1993
- LOG_VERBOSE("slot decode token", {
1994
- {"id_slot", slot.id},
1995
- {"id_task", slot.id_task},
1996
- {"n_ctx", n_ctx},
1997
- {"n_past", slot.n_past},
1998
- {"n_system_tokens", system_tokens.size()},
1999
- {"n_cache_tokens", slot.cache_tokens.size()},
2000
- {"truncated", slot.truncated}
2001
- });
1961
+ SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_system_tokens = %d, n_cache_tokens = %d, truncated = %d\n",
1962
+ slot.n_ctx, slot.n_past, (int) system_tokens.size(), (int) slot.cache_tokens.size(), slot.truncated);
2002
1963
  }
2003
1964
 
2004
1965
  // process in chunks of params.n_batch
@@ -2008,27 +1969,25 @@ struct server_context {
2008
1969
  // track if this is an embedding or non-embedding batch
2009
1970
  // if we've added sampled tokens above, we are in non-embedding mode
2010
1971
  // -1: none, 0: non-embedding, 1: embedding
1972
+ // TODO: make enum
2011
1973
  int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
2012
1974
 
2013
1975
  // next, batch any pending prompts without exceeding n_batch
2014
1976
  if (params.cont_batching || batch.n_tokens == 0) {
2015
1977
  for (auto & slot : slots) {
2016
1978
  // this slot still has a prompt to be processed
2017
- if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
1979
+ if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
2018
1980
  auto & prompt_tokens = slot.prompt_tokens;
2019
1981
 
2020
1982
  // we haven't tokenized the prompt yet - do it now:
2021
1983
  if (prompt_tokens.empty()) {
2022
- LOG_VERBOSE("tokenizing prompt", {
2023
- {"id_slot", slot.id},
2024
- {"id_task", slot.id_task}
2025
- });
1984
+ SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
2026
1985
 
2027
1986
  slot.t_start_process_prompt = ggml_time_us();
2028
1987
  slot.t_start_generation = 0;
2029
1988
 
2030
- if (slot.infill) {
2031
- const bool add_bos = llama_should_add_bos_token(model);
1989
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
1990
+ const bool add_bos = llama_add_bos_token(model);
2032
1991
  bool suff_rm_leading_spc = true;
2033
1992
  if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
2034
1993
  params.input_suffix.erase(0, 1);
@@ -2059,6 +2018,29 @@ struct server_context {
2059
2018
  }
2060
2019
 
2061
2020
  prompt_tokens = embd_inp;
2021
+ } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2022
+ // require slot.prompt to be array of 2 strings
2023
+ if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
2024
+ SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
2025
+ slot.release();
2026
+ send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
2027
+ continue;
2028
+ }
2029
+
2030
+ // prompt: [BOS]query[EOS][SEP]doc[EOS]
2031
+ prompt_tokens.clear();
2032
+ prompt_tokens.push_back(llama_token_bos(model));
2033
+ {
2034
+ const auto part = tokenize(slot.prompt[0], false);
2035
+ prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
2036
+ }
2037
+ prompt_tokens.push_back(llama_token_eos(model));
2038
+ prompt_tokens.push_back(llama_token_sep(model));
2039
+ {
2040
+ const auto part = tokenize(slot.prompt[1], false);
2041
+ prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
2042
+ }
2043
+ prompt_tokens.push_back(llama_token_eos(model));
2062
2044
  } else {
2063
2045
  prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
2064
2046
  }
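For the rerank branch above, the [BOS]query[EOS][SEP]doc[EOS] layout looks like this in a standalone sketch (toy token ids, not a real vocabulary):

    #include <cstdio>
    #include <vector>

    int main() {
        const int BOS = 1, EOS = 2, SEP = 3;                // toy special-token ids
        const std::vector<int> query = { 10, 11, 12 };      // pretend-tokenized question
        const std::vector<int> doc   = { 20, 21 };          // pretend-tokenized document

        std::vector<int> prompt;
        prompt.push_back(BOS);
        prompt.insert(prompt.end(), query.begin(), query.end());
        prompt.push_back(EOS);
        prompt.push_back(SEP);
        prompt.insert(prompt.end(), doc.begin(), doc.end());
        prompt.push_back(EOS);

        for (int t : prompt) {
            std::printf("%d ", t);  // 1 10 11 12 2 3 20 21 2
        }
        std::printf("\n");
        return 0;
    }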
@@ -2066,40 +2048,34 @@ struct server_context {
2066
2048
  slot.n_past = 0;
2067
2049
  slot.n_prompt_tokens = prompt_tokens.size();
2068
2050
 
2069
- LOG_VERBOSE("prompt tokenized", {
2070
- {"id_slot", slot.id},
2071
- {"id_task", slot.id_task},
2072
- {"n_ctx", slot.n_ctx},
2073
- {"n_keep", slot.params.n_keep},
2074
- {"n_prompt_tokens", slot.n_prompt_tokens},
2075
- {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
2076
- });
2051
+ SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
2077
2052
 
2078
2053
  // empty prompt passed -> release the slot and send empty response
2079
2054
  if (prompt_tokens.empty()) {
2080
- LOG_INFO("empty prompt - releasing slot", {
2081
- {"id_slot", slot.id},
2082
- {"id_task", slot.id_task}
2083
- });
2055
+ SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
2084
2056
 
2085
- slot.state = SLOT_STATE_PROCESSING;
2086
- slot.command = SLOT_COMMAND_NONE;
2087
2057
  slot.release();
2088
2058
  slot.print_timings();
2089
2059
  send_final_response(slot);
2090
2060
  continue;
2091
2061
  }
2092
2062
 
2093
- if (slot.embedding) {
2063
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2094
2064
  // this prompt is too large to process - discard it
2095
2065
  if (slot.n_prompt_tokens > n_ubatch) {
2096
- slot.state = SLOT_STATE_PROCESSING;
2097
- slot.command = SLOT_COMMAND_NONE;
2098
2066
  slot.release();
2099
2067
  send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
2100
2068
  continue;
2101
2069
  }
2102
2070
  } else {
2071
+ if (!params.ctx_shift) {
2072
+ // if context shift is disabled, we make sure prompt size is smaller than KV size
2073
+ if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
2074
+ slot.release();
2075
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
2076
+ continue;
2077
+ }
2078
+ }
2103
2079
  if (slot.params.n_keep < 0) {
2104
2080
  slot.params.n_keep = slot.n_prompt_tokens;
2105
2081
  }
@@ -2126,20 +2102,12 @@ struct server_context {
2126
2102
  slot.truncated = true;
2127
2103
  slot.n_prompt_tokens = prompt_tokens.size();
2128
2104
 
2129
- LOG_VERBOSE("input truncated", {
2130
- {"id_slot", slot.id},
2131
- {"id_task", slot.id_task},
2132
- {"n_ctx", slot.n_ctx},
2133
- {"n_keep", slot.params.n_keep},
2134
- {"n_left", n_left},
2135
- {"n_prompt_tokens", slot.n_prompt_tokens},
2136
- {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
2137
- });
2105
+ SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
2138
2106
 
2139
2107
  GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
2140
2108
  }
2141
2109
 
2142
- llama_sampling_reset(slot.ctx_sampling);
2110
+ gpt_sampler_reset(slot.smpl);
2143
2111
 
2144
2112
  if (!slot.params.cache_prompt) {
2145
2113
  slot.n_past_se = 0;
@@ -2152,17 +2120,14 @@ struct server_context {
2152
2120
 
2153
2121
  // push the prompt into the sampling context (do not apply grammar)
2154
2122
  for (int i = 0; i < slot.n_past; ++i) {
2155
- llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
2123
+ gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
2156
2124
  }
2157
2125
  }
2158
2126
  }
2159
2127
 
2160
2128
  if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
2161
2129
  // we have to evaluate at least 1 token to generate logits.
2162
- LOG_INFO("we have to evaluate at least 1 token to generate logits", {
2163
- { "id_slot", slot.id },
2164
- { "id_task", slot.id_task }
2165
- });
2130
+ SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
2166
2131
 
2167
2132
  slot.n_past--;
2168
2133
  if (slot.ga_i > 0) {
@@ -2173,7 +2138,8 @@ struct server_context {
2173
2138
  slot.n_prompt_tokens_processed = 0;
2174
2139
  }
2175
2140
 
2176
- if (slot.embedding) {
2141
+ // non-causal tasks require the entire prompt to fit in the physical batch
2142
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2177
2143
  // cannot fit the prompt in the current batch - will try next iter
2178
2144
  if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
2179
2145
  continue;
@@ -2181,7 +2147,10 @@ struct server_context {
2181
2147
  }
2182
2148
 
2183
2149
  // check that we are in the right batch_type, if not defer the slot
2184
- bool slot_type = slot.embedding ? 1 : 0;
2150
+ const bool slot_type =
2151
+ slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
2152
+ slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
2153
+
2185
2154
  if (batch_type == -1) {
2186
2155
  batch_type = slot_type;
2187
2156
  } else if (batch_type != slot_type) {
@@ -2205,17 +2174,13 @@ struct server_context {
2205
2174
  slot.n_past_se = 0;
2206
2175
  slot.ga_i = 0;
2207
2176
  // TODO: is the system prompt ever in the sampling context?
2208
- llama_sampling_reset(slot.ctx_sampling);
2177
+ gpt_sampler_reset(slot.smpl);
2209
2178
  }
2210
2179
 
2211
2180
  // remove the non-common part from the cache
2212
2181
  slot.cache_tokens.resize(slot.n_past);
2213
2182
 
2214
- LOG_INFO("kv cache rm [p0, end)", {
2215
- { "id_slot", slot.id },
2216
- { "id_task", slot.id_task },
2217
- { "p0", p0 }
2218
- });
2183
+ SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
2219
2184
 
2220
2185
  int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
2221
2186
 
@@ -2244,18 +2209,11 @@ struct server_context {
2244
2209
  slot_npast++;
2245
2210
  }
2246
2211
 
2247
- LOG_VERBOSE("prompt processing progress", {
2248
- {"id_slot", slot.id},
2249
- {"n_past", slot.n_past},
2250
- {"n_ctx", n_ctx},
2251
- {"n_tokens", batch.n_tokens},
2252
- {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
2253
- });
2212
+ SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
2254
2213
 
2255
- // entire prompt has been processed - start decoding new tokens
2214
+ // entire prompt has been processed
2256
2215
  if (slot.n_past == slot.n_prompt_tokens) {
2257
- slot.state = SLOT_STATE_PROCESSING;
2258
- slot.command = SLOT_COMMAND_NONE;
2216
+ slot.state = SLOT_STATE_DONE_PROMPT;
2259
2217
 
2260
2218
  GGML_ASSERT(batch.n_tokens > 0);
2261
2219
 
@@ -2265,12 +2223,7 @@ struct server_context {
2265
2223
  slot.n_decoded = 0;
2266
2224
  slot.i_batch = batch.n_tokens - 1;
2267
2225
 
2268
- LOG_VERBOSE("prompt done", {
2269
- {"id_slot", slot.id},
2270
- {"n_past", slot.n_past},
2271
- {"n_ctx", n_ctx},
2272
- {"n_tokens", batch.n_tokens},
2273
- });
2226
+ SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
2274
2227
  }
2275
2228
  }
2276
2229
 
@@ -2281,13 +2234,11 @@ struct server_context {
2281
2234
  }
2282
2235
 
2283
2236
  if (batch.n_tokens == 0) {
2284
- LOG_VERBOSE("no tokens to decode", {});
2237
+ SRV_WRN("%s", "no tokens to decode\n");
2285
2238
  return;
2286
2239
  }
2287
2240
 
2288
- LOG_VERBOSE("decoding batch", {
2289
- {"n_tokens", batch.n_tokens},
2290
- });
2241
+ SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
2291
2242
 
2292
2243
  // make sure we're in the right embedding mode
2293
2244
  llama_set_embeddings(ctx, batch_type == 1);
@@ -2305,10 +2256,9 @@ struct server_context {
2305
2256
  const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
2306
2257
  const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
2307
2258
 
2308
- LOG_TEE("\n");
2309
- LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
2310
- LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
2311
- LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
2259
+ SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
2260
+ SLT_DBG(slot, "div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
2261
+ SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
2312
2262
 
2313
2263
  llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
2314
2264
  llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
@@ -2318,7 +2268,7 @@ struct server_context {
2318
2268
 
2319
2269
  slot.ga_i += slot.ga_w / slot.ga_n;
2320
2270
 
2321
- LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
2271
+ SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
2322
2272
  }
2323
2273
 
2324
2274
  slot.n_past_se += n_tokens;
@@ -2337,18 +2287,13 @@ struct server_context {
2337
2287
  };
2338
2288
 
2339
2289
  const int ret = llama_decode(ctx, batch_view);
2290
+ metrics.on_decoded(slots);
2340
2291
 
2341
2292
  if (ret != 0) {
2342
2293
  if (n_batch == 1 || ret < 0) {
2343
2294
  // if you get here, it means the KV cache is full - try increasing it via the context size
2344
- LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
2345
- {"i", i},
2346
- {"n_batch", ret},
2347
- {"ret", ret},
2348
- });
2295
+ SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
2349
2296
  for (auto & slot : slots) {
2350
- slot.state = SLOT_STATE_PROCESSING;
2351
- slot.command = SLOT_COMMAND_NONE;
2352
2297
  slot.release();
2353
2298
  send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
2354
2299
  }
@@ -2359,32 +2304,42 @@ struct server_context {
2359
2304
  n_batch /= 2;
2360
2305
  i -= n_batch;
2361
2306
 
2362
- LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
2363
- {"i", i},
2364
- {"n_batch", n_batch},
2365
- {"ret", ret},
2366
- });
2307
+ SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
2367
2308
 
2368
2309
  continue; // continue loop of n_batch
2369
2310
  }
2370
2311
 
2371
2312
  for (auto & slot : slots) {
2372
- if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
2313
+ if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
2373
2314
  continue; // continue loop of slots
2374
2315
  }
2375
2316
 
2376
- // prompt evaluated for embedding
2377
- if (slot.embedding) {
2378
- send_embedding(slot, batch_view);
2379
- slot.release();
2380
- slot.i_batch = -1;
2381
- continue; // continue loop of slots
2382
- }
2317
+ if (slot.state == SLOT_STATE_DONE_PROMPT) {
2318
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
2319
+ // prompt evaluated for embedding
2320
+ send_embedding(slot, batch_view);
2321
+ slot.release();
2322
+ slot.i_batch = -1;
2323
+ continue; // continue loop of slots
2324
+ }
2325
+
2326
+ if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2327
+ send_rerank(slot, batch_view);
2328
+ slot.release();
2329
+ slot.i_batch = -1;
2330
+ continue; // continue loop of slots
2331
+ }
2332
+
2333
+ // prompt evaluated for next-token prediction
2334
+ slot.state = SLOT_STATE_GENERATING;
2335
+ } else if (slot.state != SLOT_STATE_GENERATING) {
2336
+ continue; // continue loop of slots
2337
+ }
2383
2338
 
2384
2339
  completion_token_output result;
2385
- const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
2340
+ const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
2386
2341
 
2387
- llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
2342
+ gpt_sampler_accept(slot.smpl, id, true);
2388
2343
 
2389
2344
  slot.n_decoded += 1;
2390
2345
  if (slot.n_decoded == 1) {
@@ -2393,37 +2348,19 @@ struct server_context {
2393
2348
  metrics.on_prompt_eval(slot);
2394
2349
  }
2395
2350
 
2396
- llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
2397
2351
  result.tok = id;
2398
2352
 
2399
- const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
2400
- if (n_probs > 0) {
2401
- const size_t n_valid = slot.ctx_sampling->n_valid;
2402
-
2403
- // Make sure at least n_probs top tokens are at the front of the vector:
2404
- if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
2405
- llama_sample_top_k(ctx, &cur_p, n_probs, 0);
2406
- }
2353
+ const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
2407
2354
 
2408
- if (slot.sparams.temp == 0.0f) {
2409
- // With greedy sampling the probabilities have possibly not been calculated.
2410
- for (size_t i = 0; i < n_probs; ++i) {
2411
- result.probs.push_back({
2412
- cur_p.data[i].id,
2413
- i == 0 ? 1.0f : 0.0f
2414
- });
2415
- }
2416
- } else {
2417
- for (size_t i = 0; i < n_probs; ++i) {
2418
- result.probs.push_back({
2419
- cur_p.data[i].id,
2420
- i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
2421
- });
2422
- }
2423
- }
2355
+ for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
2356
+ result.probs.push_back({
2357
+ cur_p->data[i].id,
2358
+ i >= cur_p->size ? 0.0f : cur_p->data[i].p,
2359
+ });
2424
2360
  }
2425
2361
 
2426
2362
  if (!process_token(result, slot)) {
2363
+ // release slot because of stop condition
2427
2364
  slot.release();
2428
2365
  slot.print_timings();
2429
2366
  send_final_response(slot);
@@ -2434,7 +2371,7 @@ struct server_context {
2434
2371
  }
2435
2372
  }
2436
2373
 
2437
- LOG_VERBOSE("run slots completed", {});
2374
+ SRV_DBG("%s", "run slots completed\n");
2438
2375
  }
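Taken together, the state changes in this function replace the old state+command pair with a single progression; a condensed sketch (enum names assumed from this diff, not the server's real declarations):

    #include <cstdio>

    enum toy_slot_state {
        TOY_STATE_IDLE,               // slot free, waiting for a task
        TOY_STATE_PROCESSING_PROMPT,  // prompt tokens being batched and decoded
        TOY_STATE_DONE_PROMPT,        // prompt fully decoded; embedding/rerank results are sent here
        TOY_STATE_GENERATING,         // next-token prediction until a stop condition, then back to IDLE
    };

    static const char * to_str(toy_slot_state s) {
        switch (s) {
            case TOY_STATE_IDLE:              return "IDLE";
            case TOY_STATE_PROCESSING_PROMPT: return "PROCESSING_PROMPT";
            case TOY_STATE_DONE_PROMPT:       return "DONE_PROMPT";
            case TOY_STATE_GENERATING:        return "GENERATING";
        }
        return "?";
    }

    int main() {
        const toy_slot_state flow[] = {
            TOY_STATE_IDLE, TOY_STATE_PROCESSING_PROMPT, TOY_STATE_DONE_PROMPT,
            TOY_STATE_GENERATING, TOY_STATE_IDLE,
        };
        for (toy_slot_state s : flow) std::printf("%s ", to_str(s));
        std::printf("\n");
        return 0;
    }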
2439
2376
 
2440
2377
  json model_meta() const {
@@ -2455,19 +2392,10 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
2455
2392
  return;
2456
2393
  }
2457
2394
 
2458
- LOG_INFO("request", {
2459
- {"remote_addr", req.remote_addr},
2460
- {"remote_port", req.remote_port},
2461
- {"status", res.status},
2462
- {"method", req.method},
2463
- {"path", req.path},
2464
- {"params", req.params},
2465
- });
2395
+ LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
2466
2396
 
2467
- LOG_VERBOSE("request", {
2468
- {"request", req.body},
2469
- {"response", res.body},
2470
- });
2397
+ LOG_DBG("request: %s\n", req.body.c_str());
2398
+ LOG_DBG("response: %s\n", res.body.c_str());
2471
2399
  }
2472
2400
 
2473
2401
  std::function<void(int)> shutdown_handler;
@@ -2485,20 +2413,18 @@ inline void signal_handler(int signal) {
2485
2413
  }
2486
2414
 
2487
2415
  int main(int argc, char ** argv) {
2488
- #if SERVER_VERBOSE != 1
2489
- log_disable();
2490
- #endif
2491
2416
  // own arguments required by this example
2492
2417
  gpt_params params;
2493
2418
 
2494
- if (!gpt_params_parse(argc, argv, params)) {
2495
- gpt_params_print_usage(argc, argv, params);
2419
+ if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
2496
2420
  return 1;
2497
2421
  }
2498
2422
 
2499
- // TODO: not great to use extern vars
2500
- server_log_json = params.log_json;
2501
- server_verbose = params.verbosity > 0;
2423
+ gpt_init();
2424
+
2425
+ // enabling this will output extra debug information in the HTTP responses from the server
2426
+ // see format_final_response_oaicompat()
2427
+ const bool verbose = params.verbosity > 9;
2502
2428
 
2503
2429
  // struct that contains llama context and inference
2504
2430
  server_context ctx_server;
@@ -2514,30 +2440,27 @@ int main(int argc, char ** argv) {
2514
2440
  llama_backend_init();
2515
2441
  llama_numa_init(params.numa);
2516
2442
 
2517
- LOG_INFO("build info", {
2518
- {"build", LLAMA_BUILD_NUMBER},
2519
- {"commit", LLAMA_COMMIT}
2520
- });
2521
-
2522
- LOG_INFO("system info", {
2523
- {"n_threads", params.n_threads},
2524
- {"n_threads_batch", params.n_threads_batch},
2525
- {"total_threads", std::thread::hardware_concurrency()},
2526
- {"system_info", llama_print_system_info()},
2527
- });
2443
+ LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
2444
+ LOG_INF("\n");
2445
+ LOG_INF("%s\n", gpt_params_get_system_info(params).c_str());
2446
+ LOG_INF("\n");
2528
2447
 
2529
2448
  std::unique_ptr<httplib::Server> svr;
2530
2449
  #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2531
2450
  if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
2532
- LOG_INFO("Running with SSL", {{"key", params.ssl_file_key}, {"cert", params.ssl_file_cert}});
2451
+ LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
2533
2452
  svr.reset(
2534
2453
  new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
2535
2454
  );
2536
2455
  } else {
2537
- LOG_INFO("Running without SSL", {});
2456
+ LOG_INF("Running without SSL\n");
2538
2457
  svr.reset(new httplib::Server());
2539
2458
  }
2540
2459
  #else
2460
+ if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
2461
+ LOG_ERR("Server is built without SSL support\n");
2462
+ return 1;
2463
+ }
2541
2464
  svr.reset(new httplib::Server());
2542
2465
  #endif
2543
2466
 
@@ -2546,26 +2469,31 @@ int main(int argc, char ** argv) {
2546
2469
  svr->set_default_headers({{"Server", "llama.cpp"}});
2547
2470
 
2548
2471
  // CORS preflight
2549
- svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
2550
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
2472
+ svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
2473
+ // Access-Control-Allow-Origin is already set by middleware
2551
2474
  res.set_header("Access-Control-Allow-Credentials", "true");
2552
2475
  res.set_header("Access-Control-Allow-Methods", "POST");
2553
2476
  res.set_header("Access-Control-Allow-Headers", "*");
2554
- return res.set_content("", "application/json; charset=utf-8");
2477
+ return res.set_content("", "text/html"); // blank response, no data
2555
2478
  });
2556
2479
 
2557
2480
  svr->set_logger(log_server_request);
2558
2481
 
2559
- auto res_error = [](httplib::Response & res, json error_data) {
2482
+ auto res_error = [](httplib::Response & res, const json & error_data) {
2560
2483
  json final_response {{"error", error_data}};
2561
- res.set_content(final_response.dump(), "application/json; charset=utf-8");
2484
+ res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
2562
2485
  res.status = json_value(error_data, "code", 500);
2563
2486
  };
2564
2487
 
2488
+ auto res_ok = [](httplib::Response & res, const json & data) {
2489
+ res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
2490
+ res.status = 200;
2491
+ };
2492
+
2565
2493
  svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
2566
2494
  std::string message;
2567
2495
  try {
2568
- std::rethrow_exception(std::move(ep));
2496
+ std::rethrow_exception(ep);
2569
2497
  } catch (std::exception & e) {
2570
2498
  message = e.what();
2571
2499
  } catch (...) {
@@ -2573,7 +2501,7 @@ int main(int argc, char ** argv) {
2573
2501
  }
2574
2502
 
2575
2503
  json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
2576
- LOG_VERBOSE("Got exception", formatted_error);
2504
+ LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
2577
2505
  res_error(res, formatted_error);
2578
2506
  });
2579
2507
 
@@ -2588,11 +2516,6 @@ int main(int argc, char ** argv) {
2588
2516
  svr->set_read_timeout (params.timeout_read);
2589
2517
  svr->set_write_timeout(params.timeout_write);
2590
2518
 
2591
- if (!svr->bind_to_port(params.hostname, params.port)) {
2592
- fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port);
2593
- return 1;
2594
- }
2595
-
2596
2519
  std::unordered_map<std::string, std::string> log_data;
2597
2520
 
2598
2521
  log_data["hostname"] = params.hostname;
@@ -2608,42 +2531,13 @@ int main(int argc, char ** argv) {
2608
2531
  // Necessary similarity of prompt for slot selection
2609
2532
  ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
2610
2533
 
2611
- // load the model
2612
- if (!ctx_server.load_model(params)) {
2613
- state.store(SERVER_STATE_ERROR);
2614
- return 1;
2615
- } else {
2616
- ctx_server.init();
2617
- state.store(SERVER_STATE_READY);
2618
- }
2619
-
2620
- LOG_INFO("model loaded", {});
2621
-
2622
- const auto model_meta = ctx_server.model_meta();
2623
-
2624
- // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
2625
- if (params.chat_template.empty()) {
2626
- if (!ctx_server.validate_model_chat_template()) {
2627
- LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
2628
- params.chat_template = "chatml";
2629
- }
2630
- }
2631
-
2632
- // print sample chat example to make it clear which template is used
2633
- {
2634
- LOG_INFO("chat template", {
2635
- {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
2636
- {"built_in", params.chat_template.empty()},
2637
- });
2638
- }
2639
-
2640
2534
  //
2641
2535
  // Middlewares
2642
2536
  //
2643
2537
 
2644
2538
  auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
2645
2539
  // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
2646
- static const std::set<std::string> protected_endpoints = {
2540
+ static const std::unordered_set<std::string> protected_endpoints = {
2647
2541
  "/props",
2648
2542
  "/completion",
2649
2543
  "/completions",
@@ -2680,17 +2574,34 @@ int main(int argc, char ** argv) {
2680
2574
  }
2681
2575
 
2682
2576
  // API key is invalid or not provided
2683
- // TODO: make another middleware for CORS related logic
2684
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
2685
2577
  res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
2686
2578
 
2687
- LOG_WARNING("Unauthorized: Invalid API Key", {});
2579
+ LOG_WRN("Unauthorized: Invalid API Key\n");
2688
2580
 
2689
2581
  return false;
2690
2582
  };
2691
2583
 
2584
+ auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
2585
+ server_state current_state = state.load();
2586
+ if (current_state == SERVER_STATE_LOADING_MODEL) {
2587
+ auto tmp = string_split(req.path, '.');
2588
+ if (req.path == "/" || tmp.back() == "html") {
2589
+ res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
2590
+ res.status = 503;
2591
+ } else {
2592
+ res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
2593
+ }
2594
+ return false;
2595
+ }
2596
+ return true;
2597
+ };
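The extension check in this middleware can be isolated as follows (hypothetical helper; string_split is replaced with a plain find, standard C++ only): while the model is loading, browser-facing paths get the HTML placeholder with a 503, everything else gets the JSON error.

    #include <cstdio>
    #include <string>

    static bool wants_html(const std::string & path) {
        if (path == "/") {
            return true;
        }
        const auto dot = path.find_last_of('.');
        return dot != std::string::npos && path.substr(dot + 1) == "html";
    }

    int main() {
        std::printf("%d %d %d\n",
                    wants_html("/"),                     // 1 -> loading page, status 503
                    wants_html("/index.html"),           // 1 -> loading page, status 503
                    wants_html("/v1/chat/completions")); // 0 -> JSON "Loading model" error
        return 0;
    }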
2598
+
2692
2599
  // register server middlewares
2693
- svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
2600
+ svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
2601
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
2602
+ if (!middleware_server_state(req, res)) {
2603
+ return httplib::Server::HandlerResponse::Handled;
2604
+ }
2694
2605
  if (!middleware_validate_api_key(req, res)) {
2695
2606
  return httplib::Server::HandlerResponse::Handled;
2696
2607
  }
@@ -2701,99 +2612,57 @@ int main(int argc, char ** argv) {
2701
2612
  // Route handlers (or controllers)
2702
2613
  //
2703
2614
 
2704
- const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
2705
- server_state current_state = state.load();
2706
- switch (current_state) {
2707
- case SERVER_STATE_READY:
2708
- {
2709
- // request slots data using task queue
- server_task task;
- task.id = ctx_server.queue_tasks.get_new_id();
- task.type = SERVER_TASK_TYPE_METRICS;
- task.id_target = -1;
-
- ctx_server.queue_results.add_waiting_task_id(task.id);
- ctx_server.queue_tasks.post(task);
-
- // get the result
- server_task_result result = ctx_server.queue_results.recv(task.id);
- ctx_server.queue_results.remove_waiting_task_id(task.id);
-
- const int n_idle_slots = result.data.at("idle");
- const int n_processing_slots = result.data.at("processing");
-
- json health = {
- {"status", "ok"},
- {"slots_idle", n_idle_slots},
- {"slots_processing", n_processing_slots}
- };
-
- res.status = 200; // HTTP OK
- if (params.endpoint_slots && req.has_param("include_slots")) {
- health["slots"] = result.data.at("slots");
- }
-
- if (n_idle_slots == 0) {
- health["status"] = "no slot available";
- if (req.has_param("fail_on_no_slot")) {
- res.status = 503; // HTTP Service Unavailable
- }
- }
-
- res.set_content(health.dump(), "application/json");
- break;
- }
- case SERVER_STATE_LOADING_MODEL:
- {
- res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
- } break;
- case SERVER_STATE_ERROR:
- {
- res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
- } break;
- }
+ const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
+ // error and loading states are handled by middleware
+ json health = {{"status", "ok"}};
+ res_ok(res, health);
 };
 
- const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
+ const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
 if (!params.endpoint_slots) {
- res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
+ res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
 return;
 }
 
 // request slots data using task queue
 server_task task;
 task.id = ctx_server.queue_tasks.get_new_id();
- task.id_multi = -1;
- task.id_target = -1;
 task.type = SERVER_TASK_TYPE_METRICS;
 
 ctx_server.queue_results.add_waiting_task_id(task.id);
- ctx_server.queue_tasks.post(task);
+ ctx_server.queue_tasks.post(task, true); // high-priority task
 
 // get the result
 server_task_result result = ctx_server.queue_results.recv(task.id);
 ctx_server.queue_results.remove_waiting_task_id(task.id);
 
- res.set_content(result.data.at("slots").dump(), "application/json");
- res.status = 200; // HTTP OK
+ // optionally return "fail_on_no_slot" error
+ const int n_idle_slots = result.data.at("idle");
+ if (req.has_param("fail_on_no_slot")) {
+ if (n_idle_slots == 0) {
+ res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
+ return;
+ }
+ }
+
+ res_ok(res, result.data.at("slots"));
 };
 
 const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
 if (!params.endpoint_metrics) {
- res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
+ res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
 return;
 }
 
 // request slots data using task queue
 server_task task;
 task.id = ctx_server.queue_tasks.get_new_id();
- task.id_multi = -1;
 task.id_target = -1;
 task.type = SERVER_TASK_TYPE_METRICS;
 task.data.push_back({{"reset_bucket", true}});
 
 ctx_server.queue_results.add_waiting_task_id(task.id);
- ctx_server.queue_tasks.post(task);
+ ctx_server.queue_tasks.post(task, true); // high-priority task
 
 // get the result
 server_task_result result = ctx_server.queue_results.recv(task.id);
@@ -2807,6 +2676,9 @@ int main(int argc, char ** argv) {
 const uint64_t n_tokens_predicted = data.at("n_tokens_predicted");
 const uint64_t t_tokens_generation = data.at("t_tokens_generation");
 
+ const uint64_t n_decode_total = data.at("n_decode_total");
+ const uint64_t n_busy_slots_total = data.at("n_busy_slots_total");
+
 const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");
 
 // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
@@ -2827,6 +2699,14 @@ int main(int argc, char ** argv) {
 {"name", "tokens_predicted_seconds_total"},
 {"help", "Predict process time"},
 {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3}
+ }, {
+ {"name", "n_decode_total"},
+ {"help", "Total number of llama_decode() calls"},
+ {"value", n_decode_total}
+ }, {
+ {"name", "n_busy_slots_per_decode"},
+ {"help", "Average number of busy slots per llama_decode() call"},
+ {"value", (float) n_busy_slots_total / (float) n_decode_total}
 }}},
 {"gauge", {{
 {"name", "prompt_tokens_seconds"},
@@ -2879,7 +2759,7 @@ int main(int argc, char ** argv) {
 res.status = 200; // HTTP OK
 };
 
- const auto handle_slots_save = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+ const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
 json request_data = json::parse(req.body);
 std::string filename = request_data.at("filename");
 if (!fs_validate_filename(filename)) {
@@ -2893,7 +2773,7 @@ int main(int argc, char ** argv) {
 task.data = {
 { "id_slot", id_slot },
 { "filename", filename },
- { "filepath", filepath }
+ { "filepath", filepath },
 };
 
 const int id_task = ctx_server.queue_tasks.post(task);
@@ -2905,11 +2785,11 @@ int main(int argc, char ** argv) {
 if (result.error) {
 res_error(res, result.data);
 } else {
- res.set_content(result.data.dump(), "application/json");
+ res_ok(res, result.data);
 }
 };
 
- const auto handle_slots_restore = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
+ const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
 json request_data = json::parse(req.body);
 std::string filename = request_data.at("filename");
 if (!fs_validate_filename(filename)) {
@@ -2923,7 +2803,7 @@ int main(int argc, char ** argv) {
 task.data = {
 { "id_slot", id_slot },
 { "filename", filename },
- { "filepath", filepath }
+ { "filepath", filepath },
 };
 
 const int id_task = ctx_server.queue_tasks.post(task);
@@ -2935,11 +2815,11 @@ int main(int argc, char ** argv) {
 if (result.error) {
 res_error(res, result.data);
 } else {
- res.set_content(result.data.dump(), "application/json");
+ res_ok(res, result.data);
 }
 };
 
- const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
+ const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
 server_task task;
 task.type = SERVER_TASK_TYPE_SLOT_ERASE;
 task.data = {
@@ -2955,12 +2835,15 @@ int main(int argc, char ** argv) {
 if (result.error) {
 res_error(res, result.data);
 } else {
- res.set_content(result.data.dump(), "application/json");
+ res_ok(res, result.data);
 }
 };
 
- const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
+ if (params.slot_save_path.empty()) {
+ res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
+ return;
+ }
 
 std::string id_slot_str = req.path_params.at("id_slot");
 int id_slot;
@@ -2985,7 +2868,7 @@ int main(int argc, char ** argv) {
 }
 };
 
- const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
+ const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
 std::string template_key = "tokenizer.chat_template", curr_tmpl;
 int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
 if (tlen > 0) {
@@ -2994,274 +2877,190 @@ int main(int argc, char ** argv) {
 curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
 }
 }
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 json data = {
 { "system_prompt", ctx_server.system_prompt.c_str() },
 { "default_generation_settings", ctx_server.default_generation_settings_for_props },
 { "total_slots", ctx_server.params.n_parallel },
- { "chat_template", curr_tmpl.c_str() }
+ { "chat_template", curr_tmpl.c_str() },
 };
 
- res.set_content(data.dump(), "application/json; charset=utf-8");
+ res_ok(res, data);
 };
 
- const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
- if (ctx_server.params.embedding) {
- res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+ const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
+ if (ctx_server.params.embedding || ctx_server.params.reranking) {
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
 return;
 }
 
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(tasks);
 
- json data = json::parse(req.body);
-
- const int id_task = ctx_server.queue_tasks.get_new_id();
+ bool stream = json_value(data, "stream", false);
+ const auto task_ids = server_task::get_list_id(tasks);
 
- ctx_server.queue_results.add_waiting_task_id(id_task);
- ctx_server.request_completion(id_task, -1, data, false, false);
-
- if (!json_value(data, "stream", false)) {
- server_task_result result = ctx_server.queue_results.recv(id_task);
- if (!result.error && result.stop) {
- res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
- } else {
- res_error(res, result.data);
- }
-
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- } else {
- const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
- while (true) {
- server_task_result result = ctx_server.queue_results.recv(id_task);
- if (!result.error) {
- const std::string str =
- "data: " +
- result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
- "\n\n";
-
- LOG_VERBOSE("data stream", {
- { "to_send", str }
- });
-
- if (!sink.write(str.c_str(), str.size())) {
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- return false;
- }
-
- if (result.stop) {
- break;
- }
- } else {
- const std::string str =
- "error: " +
- result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
- "\n\n";
-
- LOG_VERBOSE("data stream", {
- { "to_send", str }
- });
-
- if (!sink.write(str.c_str(), str.size())) {
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- return false;
- }
-
- break;
+ if (!stream) {
+ ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+ if (results.size() == 1) {
+ // single result
+ res_ok(res, results[0].data);
+ } else {
+ // multiple results (multitask)
+ json arr = json::array();
+ for (const auto & res : results) {
+ arr.push_back(res.data);
 }
+ res_ok(res, arr);
 }
+ }, [&](const json & error_data) {
+ res_error(res, error_data);
+ });
 
- ctx_server.queue_results.remove_waiting_task_id(id_task);
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+ } else {
+ const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
+ ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
+ return server_sent_event(sink, "data", result.data);
+ }, [&](const json & error_data) {
+ server_sent_event(sink, "error", error_data);
+ });
 sink.done();
-
- return true;
+ return false;
 };
 
- auto on_complete = [id_task, &ctx_server] (bool) {
- // cancel
- ctx_server.request_cancel(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
+ auto on_complete = [task_ids, &ctx_server] (bool) {
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
 };
 
 res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
 }
 };
 
- const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
- json models = {
- {"object", "list"},
- {"data", {
- {
- {"id", params.model_alias},
- {"object", "model"},
- {"created", std::time(0)},
- {"owned_by", "llamacpp"},
- {"meta", model_meta}
- },
- }}
- };
+ const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+ json data = json::parse(req.body);
+ return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
+ };
 
- res.set_content(models.dump(), "application/json; charset=utf-8");
+ const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
+ json data = json::parse(req.body);
+ return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
 };
 
- const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
- if (ctx_server.params.embedding) {
- res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+ // TODO: maybe merge this function with "handle_completions_generic"
+ const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
+ if (ctx_server.params.embedding || ctx_server.params.reranking) {
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
 return;
 }
 
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
 json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
- const int id_task = ctx_server.queue_tasks.get_new_id();
-
- ctx_server.queue_results.add_waiting_task_id(id_task);
- ctx_server.request_completion(id_task, -1, data, false, false);
+ std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(tasks);
 
+ bool stream = json_value(data, "stream", false);
+ const auto task_ids = server_task::get_list_id(tasks);
 const auto completion_id = gen_chatcmplid();
- if (!json_value(data, "stream", false)) {
- server_task_result result = ctx_server.queue_results.recv(id_task);
 
- if (!result.error && result.stop) {
- json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
+ if (!stream) {
+ ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
+ // multitask is never support in chat completion, there is only one result
+ json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
+ res_ok(res, result_oai);
+ }, [&](const json & error_data) {
+ res_error(res, error_data);
+ });
 
- res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
- } else {
- res_error(res, result.data);
- }
- ctx_server.queue_results.remove_waiting_task_id(id_task);
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
 } else {
- const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
- while (true) {
- server_task_result result = ctx_server.queue_results.recv(id_task);
- if (!result.error) {
- std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
-
- for (auto it = result_array.begin(); it != result_array.end(); ++it) {
- if (!it->empty()) {
- const std::string str =
- "data: " +
- it->dump(-1, ' ', false, json::error_handler_t::replace) +
- "\n\n";
- LOG_VERBOSE("data stream", {{"to_send", str}});
- if (!sink.write(str.c_str(), str.size())) {
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- return false;
- }
- }
- }
- if (result.stop) {
- break;
+ const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
+ ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
+ std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
+ for (auto & event_data : result_array) {
+ if (event_data.empty()) {
+ continue; // skip the stop token
 }
- } else {
- const std::string str =
- "error: " +
- result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
- "\n\n";
- LOG_VERBOSE("data stream", {{"to_send", str}});
- if (!sink.write(str.c_str(), str.size())) {
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- return false;
+ if (!server_sent_event(sink, "data", event_data)) {
+ return false; // connection is closed
 }
- break;
 }
- }
+ return true; // ok
+ }, [&](const json & error_data) {
+ server_sent_event(sink, "error", error_data);
+ });
+ static const std::string ev_done = "data: [DONE]\n\n";
+ sink.write(ev_done.data(), ev_done.size());
 sink.done();
- ctx_server.queue_results.remove_waiting_task_id(id_task);
 return true;
 };
 
- auto on_complete = [id_task, &ctx_server](bool) {
- // cancel request
- ctx_server.request_cancel(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
+ auto on_complete = [task_ids, &ctx_server] (bool) {
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
 };
 
 res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
 }
 };
 
- const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
- if (ctx_server.params.embedding) {
- res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
- return;
- }
-
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
- json data = json::parse(req.body);
+ const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
+ json models = {
+ {"object", "list"},
+ {"data", {
+ {
+ {"id", params.model_alias},
+ {"object", "model"},
+ {"created", std::time(0)},
+ {"owned_by", "llamacpp"},
+ {"meta", ctx_server.model_meta()}
+ },
+ }}
+ };
 
- const int id_task = ctx_server.queue_tasks.get_new_id();
+ res.set_content(models.dump(), MIMETYPE_JSON);
+ };
 
- ctx_server.queue_results.add_waiting_task_id(id_task);
- ctx_server.request_completion(id_task, -1, data, true, false);
+ const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
+ const json body = json::parse(req.body);
 
- if (!json_value(data, "stream", false)) {
- server_task_result result = ctx_server.queue_results.recv(id_task);
- if (!result.error && result.stop) {
- res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
- } else {
- res_error(res, result.data);
- }
+ json tokens_response = json::array();
+ if (body.count("content") != 0) {
+ const bool add_special = json_value(body, "add_special", false);
+ const bool with_pieces = json_value(body, "with_pieces", false);
+ std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special);
 
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- } else {
- const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
- while (true) {
- server_task_result result = ctx_server.queue_results.recv(id_task);
- if (!result.error) {
- const std::string str =
- "data: " +
- result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
- "\n\n";
-
- LOG_VERBOSE("data stream", {
- { "to_send", str }
- });
-
- if (!sink.write(str.c_str(), str.size())) {
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- return false;
- }
+ if (with_pieces) {
+ for (const auto& token : tokens) {
+ std::string piece = llama_token_to_piece(ctx_server.ctx, token);
+ json piece_json;
 
- if (result.stop) {
- break;
- }
+ // Check if the piece is valid UTF-8
+ if (is_valid_utf8(piece)) {
+ piece_json = piece;
 } else {
- break;
+ // If not valid UTF-8, store as array of byte values
+ piece_json = json::array();
+ for (unsigned char c : piece) {
+ piece_json.push_back(static_cast<int>(c));
+ }
 }
- }
-
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- sink.done();
 
- return true;
- };
-
- auto on_complete = [id_task, &ctx_server] (bool) {
- ctx_server.request_cancel(id_task);
- };
-
- res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
+ tokens_response.push_back({
+ {"id", token},
+ {"piece", piece_json}
+ });
+ }
+ } else {
+ tokens_response = tokens;
+ }
 }
- };
-
- const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
- const json body = json::parse(req.body);
 
- std::vector<llama_token> tokens;
- if (body.count("content") != 0) {
- const bool add_special = json_value(body, "add_special", false);
- tokens = ctx_server.tokenize(body.at("content"), add_special);
- }
- const json data = format_tokenizer_response(tokens);
- return res.set_content(data.dump(), "application/json; charset=utf-8");
+ const json data = format_tokenizer_response(tokens_response);
+ res_ok(res, data);
 };
 
- const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
 const json body = json::parse(req.body);
 
 std::string content;
@@ -3271,12 +3070,15 @@ int main(int argc, char ** argv) {
 }
 
 const json data = format_detokenized_response(content);
- return res.set_content(data.dump(), "application/json; charset=utf-8");
+ res_ok(res, data);
 };
 
- const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
+ const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+ // TODO: somehow clean up this checks in the future
+ if (!ctx_server.params.embedding || ctx_server.params.reranking) {
+ res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+ return;
+ }
 const json body = json::parse(req.body);
 bool is_openai = false;
 
@@ -3294,35 +3096,156 @@ int main(int argc, char ** argv) {
 }
 
 // create and queue the task
- json responses;
+ json responses = json::array();
+ bool error = false;
 {
- const int id_task = ctx_server.queue_tasks.get_new_id();
- ctx_server.queue_results.add_waiting_task_id(id_task);
- ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
+ std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(tasks);
 
 // get the result
- server_task_result result = ctx_server.queue_results.recv(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- if (!result.error) {
- if (result.data.count("results")) {
- // result for multi-task
- responses = result.data.at("results");
- } else {
- // result for single task
- responses = std::vector<json>{result.data};
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+ ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+ for (const auto & res : results) {
+ responses.push_back(res.data);
 }
- } else {
- // error received, ignore everything else
- res_error(res, result.data);
- return;
- }
+ }, [&](const json & error_data) {
+ res_error(res, error_data);
+ error = true;
+ });
+
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+ }
+
+ if (error) {
+ return;
 }
 
 // write JSON response
 json root = is_openai
 ? format_embeddings_response_oaicompat(body, responses)
 : responses[0];
- return res.set_content(root.dump(), "application/json; charset=utf-8");
+ res_ok(res, root);
+ };
+
+ const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+ if (!ctx_server.params.reranking) {
+ res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
+ return;
+ }
+ const json body = json::parse(req.body);
+
+ // TODO: implement
+ //int top_n = 1;
+ //if (body.count("top_n") != 1) {
+ // top_n = body.at("top_n");
+ //} else {
+ // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+ // return;
+ //}
+
+ json query;
+ if (body.count("query") == 1) {
+ query = body.at("query");
+ if (!query.is_string()) {
+ res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ } else {
+ res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+ if (documents.empty()) {
+ res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ // construct prompt object: array of ["query", "doc0", "doc1", ...]
+ json prompt;
+ prompt.push_back(query);
+ for (const auto & doc : documents) {
+ prompt.push_back(doc);
+ }
+
+ LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+ // create and queue the task
+ json responses = json::array();
+ bool error = false;
+ {
+ std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(tasks);
+
+ // get the result
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+ ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+ for (const auto & res : results) {
+ responses.push_back(res.data);
+ }
+ }, [&](const json & error_data) {
+ res_error(res, error_data);
+ error = true;
+ });
+ }
+
+ if (error) {
+ return;
+ }
+
+ // write JSON response
+ json root = format_response_rerank(body, responses);
+ res_ok(res, root);
+ };
+
+ const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
+ json result = json::array();
+ for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
+ auto & lora = ctx_server.loras[i];
+ result.push_back({
+ {"id", i},
+ {"path", lora.path},
+ {"scale", lora.scale},
+ });
+ }
+ res_ok(res, result);
+ res.status = 200; // HTTP OK
+ };
+
+ const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
+ const std::vector<json> body = json::parse(req.body);
+ int max_idx = ctx_server.loras.size();
+
+ // clear existing value
+ for (auto & lora : ctx_server.loras) {
+ lora.scale = 0.0f;
+ }
+
+ // set value
+ for (auto entry : body) {
+ int id = entry.at("id");
+ float scale = entry.at("scale");
+ if (0 <= id && id < max_idx) {
+ ctx_server.loras[id].scale = scale;
+ } else {
+ throw std::runtime_error("invalid adapter id");
+ }
+ }
+
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SET_LORA;
+ const int id_task = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+ res_ok(res, result.data);
+ res.status = 200; // HTTP OK
 };
 
 auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
@@ -3363,7 +3286,6 @@ int main(int argc, char ** argv) {
 
 // register API routes
 svr->Get ("/health", handle_health);
- svr->Get ("/slots", handle_slots);
 svr->Get ("/metrics", handle_metrics);
 svr->Get ("/props", handle_props);
 svr->Get ("/v1/models", handle_models);
@@ -3376,12 +3298,18 @@ int main(int argc, char ** argv) {
 svr->Post("/embedding", handle_embeddings); // legacy
 svr->Post("/embeddings", handle_embeddings);
 svr->Post("/v1/embeddings", handle_embeddings);
+ svr->Post("/rerank", handle_rerank);
+ svr->Post("/reranking", handle_rerank);
+ svr->Post("/v1/rerank", handle_rerank);
+ svr->Post("/v1/reranking", handle_rerank);
 svr->Post("/tokenize", handle_tokenize);
 svr->Post("/detokenize", handle_detokenize);
- if (!params.slot_save_path.empty()) {
- // only enable slot endpoints if slot_save_path is set
- svr->Post("/slots/:id_slot", handle_slots_action);
- }
+ // LoRA adapters hotswap
+ svr->Get ("/lora-adapters", handle_lora_adapters_list);
+ svr->Post("/lora-adapters", handle_lora_adapters_apply);
+ // Save & load slots
+ svr->Get ("/slots", handle_slots);
+ svr->Post("/slots/:id_slot", handle_slots_action);
 
 //
 // Start the server
@@ -3393,36 +3321,66 @@ int main(int argc, char ** argv) {
 log_data["n_threads_http"] = std::to_string(params.n_threads_http);
 svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };
 
- LOG_INFO("HTTP server listening", log_data);
+ // clean up function, to be called before exit
+ auto clean_up = [&svr]() {
+ svr->stop();
+ llama_backend_free();
+ };
+
+ // bind HTTP listen port, run the HTTP server in a thread
+ if (!svr->bind_to_port(params.hostname, params.port)) {
+ //LOG_ERROR("couldn't bind HTTP server socket", {
+ // {"hostname", params.hostname},
+ // {"port", params.port},
+ //});
+ LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
+ clean_up();
+ return 1;
+ }
+ std::thread t([&]() { svr->listen_after_bind(); });
+ svr->wait_until_ready();
 
- // run the HTTP server in a thread - see comment below
- std::thread t([&]() {
- if (!svr->listen_after_bind()) {
- state.store(SERVER_STATE_ERROR);
- return 1;
+ LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http);
+
+ // load the model
+ LOG_INF("%s: loading model\n", __func__);
+
+ if (!ctx_server.load_model(params)) {
+ clean_up();
+ t.join();
+ LOG_ERR("%s: exiting due to model loading error\n", __func__);
+ return 1;
+ }
+
+ ctx_server.init();
+ state.store(SERVER_STATE_READY);
+
+ LOG_INF("%s: model loaded\n", __func__);
+
+ // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
+ if (params.chat_template.empty()) {
+ if (!ctx_server.validate_model_chat_template()) {
+ LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
+ params.chat_template = "chatml";
 }
+ }
 
- return 0;
- });
+ // print sample chat example to make it clear which template is used
+ LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
 
 ctx_server.queue_tasks.on_new_task(std::bind(
- &server_context::process_single_task, &ctx_server, std::placeholders::_1));
- ctx_server.queue_tasks.on_finish_multitask(std::bind(
- &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
+ &server_context::process_single_task, &ctx_server, std::placeholders::_1));
 ctx_server.queue_tasks.on_update_slots(std::bind(
- &server_context::update_slots, &ctx_server));
- ctx_server.queue_results.on_multitask_update(std::bind(
- &server_queue::update_multitask,
- &ctx_server.queue_tasks,
- std::placeholders::_1,
- std::placeholders::_2,
- std::placeholders::_3
- ));
+ &server_context::update_slots, &ctx_server));
 
 shutdown_handler = [&](int) {
 ctx_server.queue_tasks.terminate();
 };
 
+ LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
+
+ ctx_server.queue_tasks.start_loop();
+
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
 struct sigaction sigint_action;
 sigint_action.sa_handler = signal_handler;
@@ -3437,12 +3395,8 @@ int main(int argc, char ** argv) {
 SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
 #endif
 
- ctx_server.queue_tasks.start_loop();
-
- svr->stop();
+ clean_up();
 t.join();
 
- llama_backend_free();
-
 return 0;
 }