@fugood/llama.node 0.3.3 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225) hide show
  1. package/CMakeLists.txt +5 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/EmbeddingWorker.cpp +15 -5
  19. package/src/EmbeddingWorker.h +2 -1
  20. package/src/LlamaCompletionWorker.cpp +1 -1
  21. package/src/LlamaContext.cpp +81 -18
  22. package/src/LlamaContext.h +2 -0
  23. package/src/llama.cpp/.github/workflows/build.yml +197 -159
  24. package/src/llama.cpp/.github/workflows/docker.yml +5 -8
  25. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  26. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  27. package/src/llama.cpp/CMakeLists.txt +11 -6
  28. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  29. package/src/llama.cpp/cmake/common.cmake +33 -0
  30. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  31. package/src/llama.cpp/common/CMakeLists.txt +6 -2
  32. package/src/llama.cpp/common/arg.cpp +426 -245
  33. package/src/llama.cpp/common/common.cpp +143 -80
  34. package/src/llama.cpp/common/common.h +81 -24
  35. package/src/llama.cpp/common/sampling.cpp +53 -19
  36. package/src/llama.cpp/common/sampling.h +22 -1
  37. package/src/llama.cpp/common/speculative.cpp +274 -0
  38. package/src/llama.cpp/common/speculative.h +28 -0
  39. package/src/llama.cpp/docs/build.md +101 -148
  40. package/src/llama.cpp/examples/CMakeLists.txt +32 -13
  41. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +5 -4
  43. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  46. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  47. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  48. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  49. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  50. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  51. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  52. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  54. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  55. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  57. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  58. package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
  59. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  60. package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
  61. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/infill/infill.cpp +1 -1
  63. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  64. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
  65. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  66. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  67. package/src/llama.cpp/examples/llava/clip.cpp +262 -66
  68. package/src/llama.cpp/examples/llava/clip.h +8 -2
  69. package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
  70. package/src/llama.cpp/examples/llava/llava.cpp +46 -19
  71. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
  72. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  73. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
  75. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  76. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
  77. package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
  78. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/main/main.cpp +9 -5
  80. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  81. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
  83. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  84. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  87. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  88. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  89. package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
  90. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  91. package/src/llama.cpp/examples/run/run.cpp +911 -0
  92. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
  94. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
  95. package/src/llama.cpp/examples/server/server.cpp +1758 -886
  96. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  97. package/src/llama.cpp/examples/server/utils.hpp +94 -304
  98. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  99. package/src/llama.cpp/examples/simple/simple.cpp +4 -0
  100. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
  101. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
  102. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
  104. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  106. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
  108. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  109. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  110. package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
  111. package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
  112. package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
  113. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  114. package/src/llama.cpp/ggml/include/ggml.h +106 -24
  115. package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
  116. package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
  117. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
  118. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
  119. package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
  120. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
  121. package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
  122. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
  123. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  124. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  125. package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
  126. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  127. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  128. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  129. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  130. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  131. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  132. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  133. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  134. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  135. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
  136. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  137. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  138. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
  139. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
  140. package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
  141. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  142. package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
  143. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
  151. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
  152. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
  153. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  155. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
  156. package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
  157. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
  158. package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
  159. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
  160. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
  161. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
  162. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  163. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  164. package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
  165. package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
  167. package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
  169. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
  172. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  173. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  174. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
  175. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
  176. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  177. package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
  178. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  179. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  180. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
  181. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
  182. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
  183. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  184. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  185. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
  187. package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
  188. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
  189. package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
  190. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
  191. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
  192. package/src/llama.cpp/ggml/src/ggml.c +367 -207
  193. package/src/llama.cpp/include/llama-cpp.h +25 -0
  194. package/src/llama.cpp/include/llama.h +26 -19
  195. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  196. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  197. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  198. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  199. package/src/llama.cpp/src/CMakeLists.txt +2 -7
  200. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  201. package/src/llama.cpp/src/llama-grammar.h +2 -5
  202. package/src/llama.cpp/src/llama-sampling.cpp +35 -90
  203. package/src/llama.cpp/src/llama-vocab.cpp +6 -1
  204. package/src/llama.cpp/src/llama.cpp +1748 -640
  205. package/src/llama.cpp/src/unicode.cpp +62 -51
  206. package/src/llama.cpp/src/unicode.h +9 -10
  207. package/src/llama.cpp/tests/CMakeLists.txt +48 -37
  208. package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
  209. package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
  210. package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
  211. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  212. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  213. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  214. package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
  215. package/src/llama.cpp/tests/test-rope.cpp +61 -20
  216. package/src/llama.cpp/tests/test-sampling.cpp +2 -2
  217. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  218. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  219. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  220. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  221. package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
  222. package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
  223. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
  224. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
  225. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
@@ -2,10 +2,11 @@
2
2
 
3
3
  #include "arg.h"
4
4
  #include "common.h"
5
- #include "log.h"
6
- #include "sampling.h"
7
5
  #include "json-schema-to-grammar.h"
8
6
  #include "llama.h"
7
+ #include "log.h"
8
+ #include "sampling.h"
9
+ #include "speculative.h"
9
10
 
10
11
  // Change JSON_ASSERT from assert() to GGML_ASSERT:
11
12
  #define JSON_ASSERT GGML_ASSERT
@@ -14,13 +15,8 @@
14
15
  #define MIMETYPE_JSON "application/json; charset=utf-8"
15
16
 
16
17
  // auto generated files (update with ./deps.sh)
17
- #include "index.html.hpp"
18
- #include "completion.js.hpp"
18
+ #include "index.html.gz.hpp"
19
19
  #include "loading.html.hpp"
20
- #include "deps_daisyui.min.css.hpp"
21
- #include "deps_markdown-it.js.hpp"
22
- #include "deps_tailwindcss.js.hpp"
23
- #include "deps_vue.esm-browser.js.hpp"
24
20
 
25
21
  #include <atomic>
26
22
  #include <condition_variable>
@@ -37,8 +33,10 @@
37
33
  using json = nlohmann::ordered_json;
38
34
 
39
35
  enum stop_type {
40
- STOP_TYPE_FULL,
41
- STOP_TYPE_PARTIAL,
36
+ STOP_TYPE_NONE,
37
+ STOP_TYPE_EOS,
38
+ STOP_TYPE_WORD,
39
+ STOP_TYPE_LIMIT,
42
40
  };
43
41
 
44
42
  // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
@@ -56,7 +54,10 @@ enum server_state {
56
54
  };
57
55
 
58
56
  enum server_task_type {
59
- SERVER_TASK_TYPE_INFERENCE,
57
+ SERVER_TASK_TYPE_COMPLETION,
58
+ SERVER_TASK_TYPE_EMBEDDING,
59
+ SERVER_TASK_TYPE_RERANK,
60
+ SERVER_TASK_TYPE_INFILL,
60
61
  SERVER_TASK_TYPE_CANCEL,
61
62
  SERVER_TASK_TYPE_NEXT_RESPONSE,
62
63
  SERVER_TASK_TYPE_METRICS,
@@ -66,22 +67,309 @@ enum server_task_type {
66
67
  SERVER_TASK_TYPE_SET_LORA,
67
68
  };
68
69
 
69
- enum server_task_inf_type {
70
- SERVER_TASK_INF_TYPE_COMPLETION,
71
- SERVER_TASK_INF_TYPE_EMBEDDING,
72
- SERVER_TASK_INF_TYPE_RERANK,
73
- SERVER_TASK_INF_TYPE_INFILL,
70
+ // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
71
+ enum error_type {
72
+ ERROR_TYPE_INVALID_REQUEST,
73
+ ERROR_TYPE_AUTHENTICATION,
74
+ ERROR_TYPE_SERVER,
75
+ ERROR_TYPE_NOT_FOUND,
76
+ ERROR_TYPE_PERMISSION,
77
+ ERROR_TYPE_UNAVAILABLE, // custom error
78
+ ERROR_TYPE_NOT_SUPPORTED, // custom error
79
+ };
80
+
81
+ struct slot_params {
82
+ bool stream = true;
83
+ bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
84
+ bool return_tokens = false;
85
+
86
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
87
+ int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
88
+ int32_t n_predict = -1; // new tokens to predict
89
+ int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
90
+
91
+ int64_t t_max_prompt_ms = -1; // TODO: implement
92
+ int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
93
+
94
+ std::vector<std::string> antiprompt;
95
+ bool timings_per_token = false;
96
+ bool post_sampling_probs = false;
97
+ bool ignore_eos = false;
98
+
99
+ struct common_params_sampling sampling;
100
+ struct common_params_speculative speculative;
101
+
102
+ // OAI-compat fields
103
+ bool verbose = false;
104
+ bool oaicompat = false;
105
+ bool oaicompat_chat = true;
106
+ std::string oaicompat_model;
107
+ std::string oaicompat_cmpl_id;
108
+
109
+ json to_json() const {
110
+ std::vector<std::string> samplers;
111
+ samplers.reserve(sampling.samplers.size());
112
+ for (const auto & sampler : sampling.samplers) {
113
+ samplers.emplace_back(common_sampler_type_to_str(sampler));
114
+ }
115
+
116
+ return json {
117
+ {"n_predict", n_predict}, // Server configured n_predict
118
+ {"seed", sampling.seed},
119
+ {"temperature", sampling.temp},
120
+ {"dynatemp_range", sampling.dynatemp_range},
121
+ {"dynatemp_exponent", sampling.dynatemp_exponent},
122
+ {"top_k", sampling.top_k},
123
+ {"top_p", sampling.top_p},
124
+ {"min_p", sampling.min_p},
125
+ {"xtc_probability", sampling.xtc_probability},
126
+ {"xtc_threshold", sampling.xtc_threshold},
127
+ {"typical_p", sampling.typ_p},
128
+ {"repeat_last_n", sampling.penalty_last_n},
129
+ {"repeat_penalty", sampling.penalty_repeat},
130
+ {"presence_penalty", sampling.penalty_present},
131
+ {"frequency_penalty", sampling.penalty_freq},
132
+ {"dry_multiplier", sampling.dry_multiplier},
133
+ {"dry_base", sampling.dry_base},
134
+ {"dry_allowed_length", sampling.dry_allowed_length},
135
+ {"dry_penalty_last_n", sampling.dry_penalty_last_n},
136
+ {"dry_sequence_breakers", sampling.dry_sequence_breakers},
137
+ {"mirostat", sampling.mirostat},
138
+ {"mirostat_tau", sampling.mirostat_tau},
139
+ {"mirostat_eta", sampling.mirostat_eta},
140
+ {"stop", antiprompt},
141
+ {"max_tokens", n_predict}, // User configured n_predict
142
+ {"n_keep", n_keep},
143
+ {"n_discard", n_discard},
144
+ {"ignore_eos", sampling.ignore_eos},
145
+ {"stream", stream},
146
+ {"logit_bias", format_logit_bias(sampling.logit_bias)},
147
+ {"n_probs", sampling.n_probs},
148
+ {"min_keep", sampling.min_keep},
149
+ {"grammar", sampling.grammar},
150
+ {"samplers", samplers},
151
+ {"speculative.n_max", speculative.n_max},
152
+ {"speculative.n_min", speculative.n_min},
153
+ {"speculative.p_min", speculative.p_min},
154
+ {"timings_per_token", timings_per_token},
155
+ {"post_sampling_probs", post_sampling_probs},
156
+ };
157
+ }
74
158
  };
75
159
 
76
160
  struct server_task {
77
- int id = -1; // to be filled by server_queue
78
- int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
161
+ int id = -1; // to be filled by server_queue
162
+ int index = -1; // used when there are multiple prompts (batch request)
79
163
 
80
- llama_tokens prompt_tokens;
81
164
  server_task_type type;
82
- json data;
83
165
 
84
- server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
166
+ // used by SERVER_TASK_TYPE_CANCEL
167
+ int id_target = -1;
168
+
169
+ // used by SERVER_TASK_TYPE_INFERENCE
170
+ slot_params params;
171
+ llama_tokens prompt_tokens;
172
+ int id_selected_slot = -1;
173
+
174
+ // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
175
+ struct slot_action {
176
+ int slot_id;
177
+ std::string filename;
178
+ std::string filepath;
179
+ };
180
+ slot_action slot_action;
181
+
182
+ // used by SERVER_TASK_TYPE_METRICS
183
+ bool metrics_reset_bucket = false;
184
+
185
+ server_task(server_task_type type) : type(type) {}
186
+
187
+ static slot_params params_from_json_cmpl(
188
+ const llama_model * model,
189
+ const llama_context * ctx,
190
+ const common_params & params_base,
191
+ const json & data) {
192
+ slot_params params;
193
+
194
+ // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
195
+ slot_params defaults;
196
+ defaults.sampling = params_base.sampling;
197
+ defaults.speculative = params_base.speculative;
198
+
199
+ // enabling this will output extra debug information in the HTTP responses from the server
200
+ params.verbose = params_base.verbosity > 9;
201
+ params.timings_per_token = json_value(data, "timings_per_token", false);
202
+
203
+ params.stream = json_value(data, "stream", false);
204
+ params.cache_prompt = json_value(data, "cache_prompt", true);
205
+ params.return_tokens = json_value(data, "return_tokens", false);
206
+ params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
207
+ params.n_indent = json_value(data, "n_indent", defaults.n_indent);
208
+ params.n_keep = json_value(data, "n_keep", defaults.n_keep);
209
+ params.n_discard = json_value(data, "n_discard", defaults.n_discard);
210
+ //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
211
+ params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
212
+
213
+ params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
214
+ params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
215
+ params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
216
+ params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
217
+ params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
218
+ params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
219
+ params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
220
+ params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
221
+ params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
222
+ params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
223
+ params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
224
+ params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
225
+ params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
226
+ params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
227
+ params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
228
+ params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
229
+ params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
230
+ params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
231
+ params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
232
+ params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
233
+ params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
234
+ params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
235
+ params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
236
+ params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
237
+
238
+ params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
239
+ params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
240
+ params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
241
+
242
+ params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
243
+ params.speculative.n_min = std::max(params.speculative.n_min, 2);
244
+ params.speculative.n_max = std::max(params.speculative.n_max, 0);
245
+
246
+ // TODO: add more sanity checks for the input parameters
247
+
248
+ if (params.sampling.penalty_last_n < -1) {
249
+ throw std::runtime_error("Error: repeat_last_n must be >= -1");
250
+ }
251
+
252
+ if (params.sampling.dry_penalty_last_n < -1) {
253
+ throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
254
+ }
255
+
256
+ if (params.sampling.penalty_last_n == -1) {
257
+ // note: should be the slot's context and not the full context, but it's ok
258
+ params.sampling.penalty_last_n = llama_n_ctx(ctx);
259
+ }
260
+
261
+ if (params.sampling.dry_penalty_last_n == -1) {
262
+ params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
263
+ }
264
+
265
+ if (params.sampling.dry_base < 1.0f) {
266
+ params.sampling.dry_base = defaults.sampling.dry_base;
267
+ }
268
+
269
+ // sequence breakers for DRY
270
+ {
271
+ // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
272
+ // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
273
+
274
+ if (data.contains("dry_sequence_breakers")) {
275
+ params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
276
+ if (params.sampling.dry_sequence_breakers.empty()) {
277
+ throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
278
+ }
279
+ }
280
+ }
281
+
282
+ // process "json_schema" and "grammar"
283
+ if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
284
+ throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
285
+ }
286
+ if (data.contains("json_schema") && !data.contains("grammar")) {
287
+ try {
288
+ auto schema = json_value(data, "json_schema", json::object());
289
+ params.sampling.grammar = json_schema_to_grammar(schema);
290
+ } catch (const std::exception & e) {
291
+ throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
292
+ }
293
+ } else {
294
+ params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
295
+ }
296
+
297
+ {
298
+ params.sampling.logit_bias.clear();
299
+ params.ignore_eos = json_value(data, "ignore_eos", false);
300
+
301
+ const auto & logit_bias = data.find("logit_bias");
302
+ if (logit_bias != data.end() && logit_bias->is_array()) {
303
+ const int n_vocab = llama_n_vocab(model);
304
+ for (const auto & el : *logit_bias) {
305
+ // TODO: we may want to throw errors here, in case "el" is incorrect
306
+ if (el.is_array() && el.size() == 2) {
307
+ float bias;
308
+ if (el[1].is_number()) {
309
+ bias = el[1].get<float>();
310
+ } else if (el[1].is_boolean() && !el[1].get<bool>()) {
311
+ bias = -INFINITY;
312
+ } else {
313
+ continue;
314
+ }
315
+
316
+ if (el[0].is_number_integer()) {
317
+ llama_token tok = el[0].get<llama_token>();
318
+ if (tok >= 0 && tok < n_vocab) {
319
+ params.sampling.logit_bias.push_back({tok, bias});
320
+ }
321
+ } else if (el[0].is_string()) {
322
+ auto toks = common_tokenize(model, el[0].get<std::string>(), false);
323
+ for (auto tok : toks) {
324
+ params.sampling.logit_bias.push_back({tok, bias});
325
+ }
326
+ }
327
+ }
328
+ }
329
+ }
330
+ }
331
+
332
+ {
333
+ params.antiprompt.clear();
334
+
335
+ const auto & stop = data.find("stop");
336
+ if (stop != data.end() && stop->is_array()) {
337
+ for (const auto & word : *stop) {
338
+ if (!word.empty()) {
339
+ params.antiprompt.push_back(word);
340
+ }
341
+ }
342
+ }
343
+ }
344
+
345
+ {
346
+ const auto & samplers = data.find("samplers");
347
+ if (samplers != data.end()) {
348
+ if (samplers->is_array()) {
349
+ std::vector<std::string> sampler_names;
350
+ for (const auto & name : *samplers) {
351
+ if (name.is_string()) {
352
+ sampler_names.emplace_back(name);
353
+ }
354
+ }
355
+ params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
356
+ } else if (samplers->is_string()){
357
+ std::string sampler_string;
358
+ for (const auto & name : *samplers) {
359
+ sampler_string += name;
360
+ }
361
+ params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
362
+ }
363
+ } else {
364
+ params.sampling.samplers = defaults.sampling.samplers;
365
+ }
366
+ }
367
+
368
+ std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
369
+ params.oaicompat_model = json_value(data, "model", model_name);
370
+
371
+ return params;
372
+ }
85
373
 
86
374
  // utility function
87
375
  static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -93,40 +381,628 @@ struct server_task {
93
381
  }
94
382
  };
95
383
 
384
+ struct result_timings {
385
+ int32_t prompt_n = -1;
386
+ double prompt_ms;
387
+ double prompt_per_token_ms;
388
+ double prompt_per_second;
389
+
390
+ int32_t predicted_n = -1;
391
+ double predicted_ms;
392
+ double predicted_per_token_ms;
393
+ double predicted_per_second;
394
+
395
+ json to_json() const {
396
+ return {
397
+ {"prompt_n", prompt_n},
398
+ {"prompt_ms", prompt_ms},
399
+ {"prompt_per_token_ms", prompt_per_token_ms},
400
+ {"prompt_per_second", prompt_per_second},
401
+
402
+ {"predicted_n", predicted_n},
403
+ {"predicted_ms", predicted_ms},
404
+ {"predicted_per_token_ms", predicted_per_token_ms},
405
+ {"predicted_per_second", predicted_per_second},
406
+ };
407
+ }
408
+ };
409
+
96
410
  struct server_task_result {
97
- int id = -1;
411
+ int id = -1;
412
+ int id_slot = -1;
413
+ virtual bool is_error() {
414
+ // only used by server_task_result_error
415
+ return false;
416
+ }
417
+ virtual bool is_stop() {
418
+ // only used by server_task_result_cmpl_*
419
+ return false;
420
+ }
421
+ virtual int get_index() {
422
+ return -1;
423
+ }
424
+ virtual json to_json() = 0;
425
+ virtual ~server_task_result() = default;
426
+ };
98
427
 
99
- json data;
428
+ // using shared_ptr for polymorphism of server_task_result
429
+ using server_task_result_ptr = std::unique_ptr<server_task_result>;
100
430
 
101
- bool stop;
102
- bool error;
431
+ inline std::string stop_type_to_str(stop_type type) {
432
+ switch (type) {
433
+ case STOP_TYPE_EOS: return "eos";
434
+ case STOP_TYPE_WORD: return "word";
435
+ case STOP_TYPE_LIMIT: return "limit";
436
+ default: return "none";
437
+ }
438
+ }
439
+
440
+ struct completion_token_output {
441
+ llama_token tok;
442
+ float prob;
443
+ std::string text_to_send;
444
+ struct prob_info {
445
+ llama_token tok;
446
+ std::string txt;
447
+ float prob;
448
+ };
449
+ std::vector<prob_info> probs;
450
+
451
+ json to_json(bool post_sampling_probs) const {
452
+ json probs_for_token = json::array();
453
+ for (const auto & p : probs) {
454
+ std::string txt(p.txt);
455
+ txt.resize(validate_utf8(txt));
456
+ probs_for_token.push_back(json {
457
+ {"id", p.tok},
458
+ {"token", txt},
459
+ {"bytes", str_to_bytes(p.txt)},
460
+ {
461
+ post_sampling_probs ? "prob" : "logprob",
462
+ post_sampling_probs ? p.prob : logarithm(p.prob)
463
+ },
464
+ });
465
+ }
466
+ return probs_for_token;
467
+ }
468
+
469
+ static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
470
+ json out = json::array();
471
+ for (const auto & p : probs) {
472
+ std::string txt(p.text_to_send);
473
+ txt.resize(validate_utf8(txt));
474
+ out.push_back(json {
475
+ {"id", p.tok},
476
+ {"token", txt},
477
+ {"bytes", str_to_bytes(p.text_to_send)},
478
+ {
479
+ post_sampling_probs ? "prob" : "logprob",
480
+ post_sampling_probs ? p.prob : logarithm(p.prob)
481
+ },
482
+ {
483
+ post_sampling_probs ? "top_probs" : "top_logprobs",
484
+ p.to_json(post_sampling_probs)
485
+ },
486
+ });
487
+ }
488
+ return out;
489
+ }
490
+
491
+ static float logarithm(float x) {
492
+ // nlohmann::json converts -inf to null, so we need to prevent that
493
+ return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
494
+ }
495
+
496
+ static std::vector<unsigned char> str_to_bytes(const std::string & str) {
497
+ std::vector<unsigned char> bytes;
498
+ for (unsigned char c : str) {
499
+ bytes.push_back(c);
500
+ }
501
+ return bytes;
502
+ }
103
503
  };
104
504
 
105
- struct server_static_file {
106
- const unsigned char * data;
107
- unsigned int size;
108
- const char * mime_type;
505
+ struct server_task_result_cmpl_final : server_task_result {
506
+ int index = 0;
507
+
508
+ std::string content;
509
+ llama_tokens tokens;
510
+
511
+ bool stream;
512
+ result_timings timings;
513
+ std::string prompt;
514
+
515
+ bool truncated;
516
+ int32_t n_decoded;
517
+ int32_t n_prompt_tokens;
518
+ int32_t n_tokens_cached;
519
+ bool has_new_line;
520
+ std::string stopping_word;
521
+ stop_type stop = STOP_TYPE_NONE;
522
+
523
+ bool post_sampling_probs;
524
+ std::vector<completion_token_output> probs_output;
525
+
526
+ slot_params generation_params;
527
+
528
+ // OAI-compat fields
529
+ bool verbose = false;
530
+ bool oaicompat = false;
531
+ bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
532
+ std::string oaicompat_model;
533
+ std::string oaicompat_cmpl_id;
534
+
535
+ virtual int get_index() override {
536
+ return index;
537
+ }
538
+
539
+ virtual bool is_stop() override {
540
+ return true; // in stream mode, final responses are considered stop
541
+ }
542
+
543
+ virtual json to_json() override {
544
+ return oaicompat
545
+ ? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
546
+ : to_json_non_oaicompat();
547
+ }
548
+
549
+ json to_json_non_oaicompat() {
550
+ json res = json {
551
+ {"index", index},
552
+ {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
553
+ {"tokens", stream ? llama_tokens {} : tokens},
554
+ {"id_slot", id_slot},
555
+ {"stop", true},
556
+ {"model", oaicompat_model},
557
+ {"tokens_predicted", n_decoded},
558
+ {"tokens_evaluated", n_prompt_tokens},
559
+ {"generation_settings", generation_params.to_json()},
560
+ {"prompt", prompt},
561
+ {"has_new_line", has_new_line},
562
+ {"truncated", truncated},
563
+ {"stop_type", stop_type_to_str(stop)},
564
+ {"stopping_word", stopping_word},
565
+ {"tokens_cached", n_tokens_cached},
566
+ {"timings", timings.to_json()},
567
+ };
568
+ if (!stream && !probs_output.empty()) {
569
+ res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
570
+ }
571
+ return res;
572
+ }
573
+
574
+ json to_json_oaicompat_chat() {
575
+ std::string finish_reason = "length";
576
+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
577
+ finish_reason = "stop";
578
+ }
579
+
580
+ json choice = json{
581
+ {"finish_reason", finish_reason},
582
+ {"index", 0},
583
+ {"message", json {
584
+ {"content", content},
585
+ {"role", "assistant"}
586
+ }
587
+ }};
588
+
589
+ if (!stream && probs_output.size() > 0) {
590
+ choice["logprobs"] = json{
591
+ {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
592
+ };
593
+ }
594
+
595
+ std::time_t t = std::time(0);
596
+
597
+ json res = json {
598
+ {"choices", json::array({choice})},
599
+ {"created", t},
600
+ {"model", oaicompat_model},
601
+ {"object", "chat.completion"},
602
+ {"usage", json {
603
+ {"completion_tokens", n_decoded},
604
+ {"prompt_tokens", n_prompt_tokens},
605
+ {"total_tokens", n_decoded + n_prompt_tokens}
606
+ }},
607
+ {"id", oaicompat_cmpl_id}
608
+ };
609
+
610
+ // extra fields for debugging purposes
611
+ if (verbose) {
612
+ res["__verbose"] = to_json_non_oaicompat();
613
+ }
614
+ if (timings.prompt_n >= 0) {
615
+ res.push_back({"timings", timings.to_json()});
616
+ }
617
+
618
+ return res;
619
+ }
620
+
621
+ json to_json_oaicompat_chat_stream() {
622
+ std::time_t t = std::time(0);
623
+ std::string finish_reason = "length";
624
+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
625
+ finish_reason = "stop";
626
+ }
627
+
628
+ json choice = json{
629
+ {"finish_reason", finish_reason},
630
+ {"index", 0},
631
+ {"delta", json::object()}
632
+ };
633
+
634
+ json ret = json {
635
+ {"choices", json::array({choice})},
636
+ {"created", t},
637
+ {"id", oaicompat_cmpl_id},
638
+ {"model", oaicompat_model},
639
+ {"object", "chat.completion.chunk"},
640
+ {"usage", json {
641
+ {"completion_tokens", n_decoded},
642
+ {"prompt_tokens", n_prompt_tokens},
643
+ {"total_tokens", n_decoded + n_prompt_tokens},
644
+ }},
645
+ };
646
+
647
+ if (timings.prompt_n >= 0) {
648
+ ret.push_back({"timings", timings.to_json()});
649
+ }
650
+
651
+ return ret;
652
+ }
109
653
  };
110
654
 
111
- struct slot_params {
112
- bool stream = true;
113
- bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
655
+ struct server_task_result_cmpl_partial : server_task_result {
656
+ int index = 0;
114
657
 
115
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
116
- int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
117
- int32_t n_predict = -1; // new tokens to predict
118
- int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
658
+ std::string content;
659
+ llama_tokens tokens;
119
660
 
120
- int64_t t_max_prompt_ms = -1; // TODO: implement
121
- int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
661
+ int32_t n_decoded;
662
+ int32_t n_prompt_tokens;
122
663
 
123
- std::vector<std::string> antiprompt;
664
+ bool post_sampling_probs;
665
+ completion_token_output prob_output;
666
+ result_timings timings;
667
+
668
+ // OAI-compat fields
669
+ bool verbose = false;
670
+ bool oaicompat = false;
671
+ bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
672
+ std::string oaicompat_model;
673
+ std::string oaicompat_cmpl_id;
674
+
675
+ virtual int get_index() override {
676
+ return index;
677
+ }
678
+
679
+ virtual bool is_stop() override {
680
+ return false; // in stream mode, partial responses are not considered stop
681
+ }
682
+
683
+ virtual json to_json() override {
684
+ return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
685
+ }
686
+
687
+ json to_json_non_oaicompat() {
688
+ // non-OAI-compat JSON
689
+ json res = json {
690
+ {"index", index},
691
+ {"content", content},
692
+ {"tokens", tokens},
693
+ {"stop", false},
694
+ {"id_slot", id_slot},
695
+ {"tokens_predicted", n_decoded},
696
+ {"tokens_evaluated", n_prompt_tokens},
697
+ };
698
+ // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
699
+ if (timings.prompt_n > 0) {
700
+ res.push_back({"timings", timings.to_json()});
701
+ }
702
+ if (!prob_output.probs.empty()) {
703
+ res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
704
+ }
705
+ return res;
706
+ }
707
+
708
+ json to_json_oaicompat() {
709
+ bool first = n_decoded == 0;
710
+ std::time_t t = std::time(0);
711
+ json choices;
712
+
713
+ if (first) {
714
+ if (content.empty()) {
715
+ choices = json::array({json{{"finish_reason", nullptr},
716
+ {"index", 0},
717
+ {"delta", json{{"role", "assistant"}}}}});
718
+ } else {
719
+ // We have to send this as two updates to conform to openai behavior
720
+ json initial_ret = json{{"choices", json::array({json{
721
+ {"finish_reason", nullptr},
722
+ {"index", 0},
723
+ {"delta", json{
724
+ {"role", "assistant"}
725
+ }}}})},
726
+ {"created", t},
727
+ {"id", oaicompat_cmpl_id},
728
+ {"model", oaicompat_model},
729
+ {"object", "chat.completion.chunk"}};
730
+
731
+ json second_ret = json{
732
+ {"choices", json::array({json{{"finish_reason", nullptr},
733
+ {"index", 0},
734
+ {"delta", json {
735
+ {"content", content}}}
736
+ }})},
737
+ {"created", t},
738
+ {"id", oaicompat_cmpl_id},
739
+ {"model", oaicompat_model},
740
+ {"object", "chat.completion.chunk"}};
741
+
742
+ return std::vector<json>({initial_ret, second_ret});
743
+ }
744
+ } else {
745
+ choices = json::array({json{
746
+ {"finish_reason", nullptr},
747
+ {"index", 0},
748
+ {"delta",
749
+ json {
750
+ {"content", content},
751
+ }},
752
+ }});
753
+ }
754
+
755
+ GGML_ASSERT(choices.size() >= 1);
756
+
757
+ if (prob_output.probs.size() > 0) {
758
+ choices[0]["logprobs"] = json{
759
+ {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
760
+ };
761
+ }
762
+
763
+ json ret = json {
764
+ {"choices", choices},
765
+ {"created", t},
766
+ {"id", oaicompat_cmpl_id},
767
+ {"model", oaicompat_model},
768
+ {"object", "chat.completion.chunk"}
769
+ };
770
+
771
+ if (timings.prompt_n >= 0) {
772
+ ret.push_back({"timings", timings.to_json()});
773
+ }
774
+
775
+ return std::vector<json>({ret});
776
+ }
777
+ };
778
+
779
+ struct server_task_result_embd : server_task_result {
780
+ int index = 0;
781
+ std::vector<std::vector<float>> embedding;
782
+
783
+ int32_t n_tokens;
784
+
785
+ // OAI-compat fields
786
+ bool oaicompat = false;
787
+
788
+ virtual int get_index() override {
789
+ return index;
790
+ }
791
+
792
+ virtual json to_json() override {
793
+ return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
794
+ }
795
+
796
+ json to_json_non_oaicompat() {
797
+ return json {
798
+ {"index", index},
799
+ {"embedding", embedding},
800
+ };
801
+ }
802
+
803
+ json to_json_oaicompat() {
804
+ return json {
805
+ {"index", index},
806
+ {"embedding", embedding[0]},
807
+ {"tokens_evaluated", n_tokens},
808
+ };
809
+ }
810
+ };
811
+
812
+ struct server_task_result_rerank : server_task_result {
813
+ int index = 0;
814
+ float score = -1e6;
815
+
816
+ int32_t n_tokens;
817
+
818
+ virtual int get_index() override {
819
+ return index;
820
+ }
821
+
822
+ virtual json to_json() override {
823
+ return json {
824
+ {"index", index},
825
+ {"score", score},
826
+ {"tokens_evaluated", n_tokens},
827
+ };
828
+ }
829
+ };
830
+
831
+ // this function maybe used outside of server_task_result_error
832
+ static json format_error_response(const std::string & message, const enum error_type type) {
833
+ std::string type_str;
834
+ int code = 500;
835
+ switch (type) {
836
+ case ERROR_TYPE_INVALID_REQUEST:
837
+ type_str = "invalid_request_error";
838
+ code = 400;
839
+ break;
840
+ case ERROR_TYPE_AUTHENTICATION:
841
+ type_str = "authentication_error";
842
+ code = 401;
843
+ break;
844
+ case ERROR_TYPE_NOT_FOUND:
845
+ type_str = "not_found_error";
846
+ code = 404;
847
+ break;
848
+ case ERROR_TYPE_SERVER:
849
+ type_str = "server_error";
850
+ code = 500;
851
+ break;
852
+ case ERROR_TYPE_PERMISSION:
853
+ type_str = "permission_error";
854
+ code = 403;
855
+ break;
856
+ case ERROR_TYPE_NOT_SUPPORTED:
857
+ type_str = "not_supported_error";
858
+ code = 501;
859
+ break;
860
+ case ERROR_TYPE_UNAVAILABLE:
861
+ type_str = "unavailable_error";
862
+ code = 503;
863
+ break;
864
+ }
865
+ return json {
866
+ {"code", code},
867
+ {"message", message},
868
+ {"type", type_str},
869
+ };
870
+ }
871
+
872
+ struct server_task_result_error : server_task_result {
873
+ int index = 0;
874
+ error_type err_type = ERROR_TYPE_SERVER;
875
+ std::string err_msg;
876
+
877
+ virtual bool is_error() override {
878
+ return true;
879
+ }
880
+
881
+ virtual json to_json() override {
882
+ return format_error_response(err_msg, err_type);
883
+ }
884
+ };
885
+
886
+ struct server_task_result_metrics : server_task_result {
887
+ int n_idle_slots;
888
+ int n_processing_slots;
889
+ int n_tasks_deferred;
890
+ int64_t t_start;
891
+
892
+ int32_t kv_cache_tokens_count;
893
+ int32_t kv_cache_used_cells;
894
+
895
+ // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
896
+ uint64_t n_prompt_tokens_processed_total = 0;
897
+ uint64_t t_prompt_processing_total = 0;
898
+ uint64_t n_tokens_predicted_total = 0;
899
+ uint64_t t_tokens_generation_total = 0;
900
+
901
+ uint64_t n_prompt_tokens_processed = 0;
902
+ uint64_t t_prompt_processing = 0;
903
+
904
+ uint64_t n_tokens_predicted = 0;
905
+ uint64_t t_tokens_generation = 0;
906
+
907
+ uint64_t n_decode_total = 0;
908
+ uint64_t n_busy_slots_total = 0;
909
+
910
+ // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
911
+ // therefore, we use json to temporarily store the slot.to_json() result
912
+ json slots_data = json::array();
913
+
914
+ virtual json to_json() override {
915
+ return json {
916
+ { "idle", n_idle_slots },
917
+ { "processing", n_processing_slots },
918
+ { "deferred", n_tasks_deferred },
919
+ { "t_start", t_start },
920
+
921
+ { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
922
+ { "t_tokens_generation_total", t_tokens_generation_total },
923
+ { "n_tokens_predicted_total", n_tokens_predicted_total },
924
+ { "t_prompt_processing_total", t_prompt_processing_total },
925
+
926
+ { "n_prompt_tokens_processed", n_prompt_tokens_processed },
927
+ { "t_prompt_processing", t_prompt_processing },
928
+ { "n_tokens_predicted", n_tokens_predicted },
929
+ { "t_tokens_generation", t_tokens_generation },
930
+
931
+ { "n_decode_total", n_decode_total },
932
+ { "n_busy_slots_total", n_busy_slots_total },
933
+
934
+ { "kv_cache_tokens_count", kv_cache_tokens_count },
935
+ { "kv_cache_used_cells", kv_cache_used_cells },
936
+
937
+ { "slots", slots_data },
938
+ };
939
+ }
940
+ };
941
+
942
+ struct server_task_result_slot_save_load : server_task_result {
943
+ std::string filename;
944
+ bool is_save; // true = save, false = load
945
+
946
+ size_t n_tokens;
947
+ size_t n_bytes;
948
+ double t_ms;
949
+
950
+ virtual json to_json() override {
951
+ if (is_save) {
952
+ return json {
953
+ { "id_slot", id_slot },
954
+ { "filename", filename },
955
+ { "n_saved", n_tokens },
956
+ { "n_written", n_bytes },
957
+ { "timings", {
958
+ { "save_ms", t_ms }
959
+ }},
960
+ };
961
+ } else {
962
+ return json {
963
+ { "id_slot", id_slot },
964
+ { "filename", filename },
965
+ { "n_restored", n_tokens },
966
+ { "n_read", n_bytes },
967
+ { "timings", {
968
+ { "restore_ms", t_ms }
969
+ }},
970
+ };
971
+ }
972
+ }
973
+ };
974
+
975
+ struct server_task_result_slot_erase : server_task_result {
976
+ size_t n_erased;
977
+
978
+ virtual json to_json() override {
979
+ return json {
980
+ { "id_slot", id_slot },
981
+ { "n_erased", n_erased },
982
+ };
983
+ }
984
+ };
985
+
986
+ struct server_task_result_apply_lora : server_task_result {
987
+ virtual json to_json() override {
988
+ return json {{ "success", true }};
989
+ }
124
990
  };
125
991
 
126
992
  struct server_slot {
127
993
  int id;
128
994
  int id_task = -1;
129
995
 
996
+ // only used for completion/embedding/infill/rerank
997
+ server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
998
+
999
+ llama_batch batch_spec = {};
1000
+
1001
+ llama_context * ctx = nullptr;
1002
+ llama_context * ctx_dft = nullptr;
1003
+
1004
+ common_speculative * spec = nullptr;
1005
+
130
1006
  // the index relative to completion multi-task request
131
1007
  size_t index = 0;
132
1008
 
@@ -154,35 +1030,29 @@ struct server_slot {
154
1030
 
155
1031
  size_t last_nl_pos = 0;
156
1032
 
157
- std::string generated_text;
1033
+ std::string generated_text;
1034
+ llama_tokens generated_tokens;
1035
+
158
1036
  llama_tokens cache_tokens;
159
- std::vector<completion_token_output> generated_token_probs;
160
1037
 
161
- server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
1038
+ std::vector<completion_token_output> generated_token_probs;
162
1039
 
163
1040
  bool has_next_token = true;
164
1041
  bool has_new_line = false;
165
1042
  bool truncated = false;
166
- bool stopped_eos = false;
167
- bool stopped_word = false;
168
- bool stopped_limit = false;
1043
+ stop_type stop;
169
1044
 
170
- bool oaicompat = false;
171
-
172
- std::string oaicompat_model;
173
1045
  std::string stopping_word;
174
1046
 
175
1047
  // sampling
176
1048
  json json_schema;
177
1049
 
178
- struct common_sampler_params sparams;
179
1050
  struct common_sampler * smpl = nullptr;
180
1051
 
181
1052
  llama_token sampled;
182
1053
 
183
1054
  // stats
184
1055
  size_t n_sent_text = 0; // number of sent text character
185
- size_t n_sent_token_probs = 0;
186
1056
 
187
1057
  int64_t t_start_process_prompt;
188
1058
  int64_t t_start_generation;
@@ -200,19 +1070,21 @@ struct server_slot {
200
1070
  generated_text = "";
201
1071
  has_new_line = false;
202
1072
  truncated = false;
203
- stopped_eos = false;
204
- stopped_word = false;
205
- stopped_limit = false;
1073
+ stop = STOP_TYPE_NONE;
206
1074
  stopping_word = "";
207
1075
  n_past = 0;
208
1076
  n_sent_text = 0;
209
- n_sent_token_probs = 0;
210
- inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
1077
+ task_type = SERVER_TASK_TYPE_COMPLETION;
211
1078
 
1079
+ generated_tokens.clear();
212
1080
  generated_token_probs.clear();
213
1081
  }
214
1082
 
215
- bool has_budget(common_params &global_params) {
1083
+ bool is_non_causal() const {
1084
+ return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
1085
+ }
1086
+
1087
+ bool has_budget(const common_params & global_params) {
216
1088
  if (params.n_predict == -1 && global_params.n_predict == -1) {
217
1089
  return true; // limitless
218
1090
  }
@@ -232,6 +1104,10 @@ struct server_slot {
232
1104
  return state != SLOT_STATE_IDLE;
233
1105
  }
234
1106
 
1107
+ bool can_speculate() const {
1108
+ return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
1109
+ }
1110
+
235
1111
  void add_token(const completion_token_output & token) {
236
1112
  if (!is_processing()) {
237
1113
  SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -251,38 +1127,40 @@ struct server_slot {
251
1127
  }
252
1128
  }
253
1129
 
254
- json get_formated_timings() const {
255
- return json {
256
- {"prompt_n", n_prompt_tokens_processed},
257
- {"prompt_ms", t_prompt_processing},
258
- {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed},
259
- {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed},
260
-
261
- {"predicted_n", n_decoded},
262
- {"predicted_ms", t_token_generation},
263
- {"predicted_per_token_ms", t_token_generation / n_decoded},
264
- {"predicted_per_second", 1e3 / t_token_generation * n_decoded},
265
- };
1130
+ result_timings get_timings() const {
1131
+ result_timings timings;
1132
+ timings.prompt_n = n_prompt_tokens_processed;
1133
+ timings.prompt_ms = t_prompt_processing;
1134
+ timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
1135
+ timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
1136
+
1137
+ timings.predicted_n = n_decoded;
1138
+ timings.predicted_ms = t_token_generation;
1139
+ timings.predicted_per_token_ms = t_token_generation / n_decoded;
1140
+ timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
1141
+
1142
+ return timings;
266
1143
  }
267
1144
 
268
- size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) {
1145
+ size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
269
1146
  size_t stop_pos = std::string::npos;
270
1147
 
271
1148
  for (const std::string & word : params.antiprompt) {
272
1149
  size_t pos;
273
1150
 
274
- if (type == STOP_TYPE_FULL) {
1151
+ if (is_full_stop) {
275
1152
  const size_t tmp = word.size() + last_token_size;
276
1153
  const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
277
1154
 
278
1155
  pos = text.find(word, from_pos);
279
1156
  } else {
1157
+ // otherwise, partial stop
280
1158
  pos = find_partial_stop_string(word, text);
281
1159
  }
282
1160
 
283
1161
  if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
284
- if (type == STOP_TYPE_FULL) {
285
- stopped_word = true;
1162
+ if (is_full_stop) {
1163
+ stop = STOP_TYPE_WORD;
286
1164
  stopping_word = word;
287
1165
  has_next_token = false;
288
1166
  }
@@ -302,13 +1180,35 @@ struct server_slot {
302
1180
 
303
1181
  SLT_INF(*this,
304
1182
  "\n"
305
- "\rprompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
306
- "\r eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
307
- "\r total time = %10.2f ms / %5d tokens\n",
1183
+ "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
1184
+ " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
1185
+ " total time = %10.2f ms / %5d tokens\n",
308
1186
  t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
309
1187
  t_token_generation, n_decoded, t_gen, n_gen_second,
310
1188
  t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
311
1189
  }
1190
+
1191
+ json to_json() const {
1192
+ return json {
1193
+ {"id", id},
1194
+ {"id_task", id_task},
1195
+ {"n_ctx", n_ctx},
1196
+ {"speculative", can_speculate()},
1197
+ {"is_processing", is_processing()},
1198
+ {"non_causal", is_non_causal()},
1199
+ {"params", params.to_json()},
1200
+ {"prompt", common_detokenize(ctx, prompt_tokens)},
1201
+ {"next_token",
1202
+ {
1203
+ {"has_next_token", has_next_token},
1204
+ {"has_new_line", has_new_line},
1205
+ {"n_remain", n_remaining},
1206
+ {"n_decoded", n_decoded},
1207
+ {"stopping_word", stopping_word},
1208
+ }
1209
+ },
1210
+ };
1211
+ }
312
1212
  };
313
1213
 
314
1214
  struct server_metrics {
@@ -381,9 +1281,7 @@ struct server_queue {
381
1281
  // Add a new task to the end of the queue
382
1282
  int post(server_task task, bool front = false) {
383
1283
  std::unique_lock<std::mutex> lock(mutex_tasks);
384
- if (task.id == -1) {
385
- task.id = id++;
386
- }
1284
+ GGML_ASSERT(task.id != -1);
387
1285
  QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
388
1286
  if (front) {
389
1287
  queue_tasks.push_front(std::move(task));
@@ -507,8 +1405,8 @@ struct server_response {
507
1405
  // for keeping track of all tasks waiting for the result
508
1406
  std::unordered_set<int> waiting_task_ids;
509
1407
 
510
- // the main result queue
511
- std::vector<server_task_result> queue_results;
1408
+ // the main result queue (using ptr for polymorphism)
1409
+ std::vector<server_task_result_ptr> queue_results;
512
1410
 
513
1411
  std::mutex mutex_results;
514
1412
  std::condition_variable condition_results;
@@ -548,7 +1446,7 @@ struct server_response {
548
1446
  }
549
1447
 
550
1448
  // This function blocks the thread until there is a response for one of the id_tasks
551
- server_task_result recv(const std::unordered_set<int> & id_tasks) {
1449
+ server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
552
1450
  while (true) {
553
1451
  std::unique_lock<std::mutex> lock(mutex_results);
554
1452
  condition_results.wait(lock, [&]{
@@ -556,8 +1454,8 @@ struct server_response {
556
1454
  });
557
1455
 
558
1456
  for (int i = 0; i < (int) queue_results.size(); i++) {
559
- if (id_tasks.find(queue_results[i].id) != id_tasks.end()) {
560
- server_task_result res = queue_results[i];
1457
+ if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
1458
+ server_task_result_ptr res = std::move(queue_results[i]);
561
1459
  queue_results.erase(queue_results.begin() + i);
562
1460
  return res;
563
1461
  }
@@ -568,21 +1466,21 @@ struct server_response {
568
1466
  }
569
1467
 
570
1468
  // single-task version of recv()
571
- server_task_result recv(int id_task) {
1469
+ server_task_result_ptr recv(int id_task) {
572
1470
  std::unordered_set<int> id_tasks = {id_task};
573
1471
  return recv(id_tasks);
574
1472
  }
575
1473
 
576
1474
  // Send a new result to a waiting id_task
577
- void send(server_task_result & result) {
578
- SRV_DBG("sending result for task id = %d\n", result.id);
1475
+ void send(server_task_result_ptr && result) {
1476
+ SRV_DBG("sending result for task id = %d\n", result->id);
579
1477
 
580
1478
  std::unique_lock<std::mutex> lock(mutex_results);
581
1479
  for (const auto & id_task : waiting_task_ids) {
582
- if (result.id == id_task) {
583
- SRV_DBG("task id = %d moved to result queue\n", result.id);
1480
+ if (result->id == id_task) {
1481
+ SRV_DBG("task id = %d pushed to result queue\n", result->id);
584
1482
 
585
- queue_results.push_back(std::move(result));
1483
+ queue_results.emplace_back(std::move(result));
586
1484
  condition_results.notify_all();
587
1485
  return;
588
1486
  }
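
The surrounding hunks move the result queue from value semantics to owning pointers (server_task_result_ptr), so heterogeneous result types can sit behind one interface and be handed off without copying. A standalone sketch of that pattern follows, assuming illustrative names (result_base, result_queue) rather than the server's actual types:

    #include <condition_variable>
    #include <deque>
    #include <memory>
    #include <mutex>
    #include <utility>

    struct result_base {
        int id = -1;
        virtual ~result_base() = default;              // virtual dtor so unique_ptr destroys the derived type
        virtual bool is_error() const { return false; }
    };

    using result_ptr = std::unique_ptr<result_base>;

    struct result_queue {
        std::mutex              mtx;
        std::condition_variable cv;
        std::deque<result_ptr>  results;

        void send(result_ptr && res) {
            std::lock_guard<std::mutex> lock(mtx);
            results.push_back(std::move(res));         // the queue takes ownership
            cv.notify_all();
        }

        // Block until a result with the requested id arrives, then hand ownership back.
        result_ptr recv(int id) {
            std::unique_lock<std::mutex> lock(mtx);
            while (true) {
                for (auto it = results.begin(); it != results.end(); ++it) {
                    if ((*it)->id == id) {
                        result_ptr res = std::move(*it);
                        results.erase(it);
                        return res;
                    }
                }
                cv.wait(lock);
            }
        }
    };

Moving the matched entry out of the container before erasing it mirrors what the diff does with std::move in recv() and emplace_back in send().
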
@@ -591,11 +1489,14 @@ struct server_response {
591
1489
  };
592
1490
 
593
1491
  struct server_context {
1492
+ common_params params_base;
1493
+
594
1494
  llama_model * model = nullptr;
595
1495
  llama_context * ctx = nullptr;
596
1496
  std::vector<common_lora_adapter_container> loras;
597
1497
 
598
- common_params params;
1498
+ llama_model * model_dft = nullptr;
1499
+ llama_context_params cparams_dft;
599
1500
 
600
1501
  llama_batch batch = {};
601
1502
 
@@ -628,34 +1529,90 @@ struct server_context {
628
1529
  model = nullptr;
629
1530
  }
630
1531
 
1532
+ if (model_dft) {
1533
+ llama_free_model(model_dft);
1534
+ model_dft = nullptr;
1535
+ }
1536
+
631
1537
  // Clear any sampling context
632
1538
  for (server_slot & slot : slots) {
633
- if (slot.smpl != nullptr) {
634
- common_sampler_free(slot.smpl);
635
- }
1539
+ common_sampler_free(slot.smpl);
1540
+ slot.smpl = nullptr;
1541
+
1542
+ llama_free(slot.ctx_dft);
1543
+ slot.ctx_dft = nullptr;
1544
+
1545
+ common_speculative_free(slot.spec);
1546
+ slot.spec = nullptr;
1547
+
1548
+ llama_batch_free(slot.batch_spec);
636
1549
  }
637
1550
 
638
1551
  llama_batch_free(batch);
639
1552
  }
640
1553
 
641
- bool load_model(const common_params & params_) {
642
- params = params_;
1554
+ bool load_model(const common_params & params) {
1555
+ SRV_INF("loading model '%s'\n", params.model.c_str());
643
1556
 
644
- common_init_result llama_init = common_init_from_params(params);
1557
+ params_base = params;
1558
+
1559
+ common_init_result llama_init = common_init_from_params(params_base);
645
1560
 
646
1561
  model = llama_init.model;
647
1562
  ctx = llama_init.context;
648
1563
  loras = llama_init.lora_adapters;
649
1564
 
650
1565
  if (model == nullptr) {
651
- SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
1566
+ SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
652
1567
  return false;
653
1568
  }
654
1569
 
655
1570
  n_ctx = llama_n_ctx(ctx);
656
1571
 
657
1572
  add_bos_token = llama_add_bos_token(model);
658
- has_eos_token = !llama_add_eos_token(model);
1573
+ has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL;
1574
+
1575
+ if (!params_base.speculative.model.empty()) {
1576
+ SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
1577
+
1578
+ auto params_dft = params_base;
1579
+
1580
+ params_dft.devices = params_base.speculative.devices;
1581
+ params_dft.model = params_base.speculative.model;
1582
+ params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
1583
+ params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
1584
+ params_dft.n_parallel = 1;
1585
+
1586
+ common_init_result llama_init_dft = common_init_from_params(params_dft);
1587
+
1588
+ model_dft = llama_init_dft.model;
1589
+
1590
+ if (model_dft == nullptr) {
1591
+ SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
1592
+ return false;
1593
+ }
1594
+
1595
+ if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
1596
+ SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
1597
+
1598
+ llama_free (llama_init_dft.context);
1599
+ llama_free_model(llama_init_dft.model);
1600
+
1601
+ return false;
1602
+ }
1603
+
1604
+ const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
1605
+
1606
+ cparams_dft = common_context_params_to_llama(params_dft);
1607
+ cparams_dft.n_batch = n_ctx_dft;
1608
+
1609
+ // force F16 KV cache for the draft model for extra performance
1610
+ cparams_dft.type_k = GGML_TYPE_F16;
1611
+ cparams_dft.type_v = GGML_TYPE_F16;
1612
+
1613
+ // the context is not needed - we will create one for each slot
1614
+ llama_free(llama_init_dft.context);
1615
+ }
659
1616
 
660
1617
  return true;
661
1618
  }
@@ -674,20 +1631,37 @@ struct server_context {
674
1631
  }
675
1632
 
676
1633
  void init() {
677
- const int32_t n_ctx_slot = n_ctx / params.n_parallel;
1634
+ const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
678
1635
 
679
- SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);
1636
+ SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
680
1637
 
681
- for (int i = 0; i < params.n_parallel; i++) {
1638
+ for (int i = 0; i < params_base.n_parallel; i++) {
682
1639
  server_slot slot;
683
1640
 
684
1641
  slot.id = i;
1642
+ slot.ctx = ctx;
685
1643
  slot.n_ctx = n_ctx_slot;
686
- slot.n_predict = params.n_predict;
1644
+ slot.n_predict = params_base.n_predict;
1645
+
1646
+ if (model_dft) {
1647
+ slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
1648
+
1649
+ slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
1650
+ if (slot.ctx_dft == nullptr) {
1651
+ SRV_ERR("%s", "failed to create draft context\n");
1652
+ return;
1653
+ }
1654
+
1655
+ slot.spec = common_speculative_init(slot.ctx_dft);
1656
+ if (slot.spec == nullptr) {
1657
+ SRV_ERR("%s", "failed to create speculator\n");
1658
+ return;
1659
+ }
1660
+ }
687
1661
 
688
1662
  SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
689
1663
 
690
- slot.sparams = params.sparams;
1664
+ slot.params.sampling = params_base.sampling;
691
1665
 
692
1666
  slot.callback_on_release = [this](int) {
693
1667
  queue_tasks.pop_deferred_task();
@@ -698,8 +1672,7 @@ struct server_context {
698
1672
  slots.push_back(slot);
699
1673
  }
700
1674
 
701
- default_generation_settings_for_props = get_formated_generation(slots.front());
702
- default_generation_settings_for_props["seed"] = -1;
1675
+ default_generation_settings_for_props = slots[0].to_json();
703
1676
 
704
1677
  // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
705
1678
  // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
@@ -707,7 +1680,7 @@ struct server_context {
707
1680
  const int32_t n_batch = llama_n_batch(ctx);
708
1681
 
709
1682
  // only a single seq_id per token is needed
710
- batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
1683
+ batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
711
1684
  }
712
1685
 
713
1686
  metrics.init();
@@ -743,7 +1716,7 @@ struct server_context {
743
1716
  }
744
1717
 
745
1718
  // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
746
- int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);
1719
+ int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
747
1720
 
748
1721
  // fraction of the common subsequence length compared to the current slot's prompt length
749
1722
  float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
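
The slot-selection heuristic above scores each slot by the length of the longest common subsequence between its cached prompt and the incoming prompt, then normalizes by the cached prompt length. As a generic illustration, a textbook O(n·m) dynamic program for the LCS length over token ids might look like the following; it is not necessarily how common_lcs is implemented:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Length of the longest common subsequence of two token sequences,
    // using two rolling rows of the DP table.
    static size_t lcs_length(const std::vector<int32_t> & a, const std::vector<int32_t> & b) {
        std::vector<size_t> prev(b.size() + 1, 0);
        std::vector<size_t> cur (b.size() + 1, 0);
        for (size_t i = 1; i <= a.size(); ++i) {
            for (size_t j = 1; j <= b.size(); ++j) {
                cur[j] = (a[i - 1] == b[j - 1]) ? prev[j - 1] + 1
                                                : std::max(prev[j], cur[j - 1]);
            }
            std::swap(prev, cur);
        }
        return prev[b.size()];
    }

Dividing that length by the cached token count gives the similarity fraction used above to pick the most promising slot.
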
@@ -786,87 +1759,14 @@ struct server_context {
786
1759
  }
787
1760
 
788
1761
  bool launch_slot_with_task(server_slot & slot, const server_task & task) {
789
- slot_params default_params;
790
- // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
791
- auto default_sparams = params.sparams;
792
- const auto & data = task.data;
793
-
794
- if (data.count("__oaicompat") != 0) {
795
- slot.oaicompat = true;
796
- slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
797
- } else {
798
- slot.oaicompat = false;
799
- slot.oaicompat_model = "";
800
- }
801
-
802
- slot.params.stream = json_value(data, "stream", false);
803
- slot.params.cache_prompt = json_value(data, "cache_prompt", false);
804
- slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
805
- slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
806
- slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
807
- slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
808
- slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
809
- slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
810
- slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
811
- slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
812
- slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
813
- slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
814
- slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
815
- slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
816
- slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
817
- slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
818
- slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
819
- slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
820
- slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
821
- slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
822
- slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
823
- slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
824
- slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
825
- slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
826
- slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
827
- slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
828
- slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
829
- slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
830
- slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
831
- slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
832
- //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
833
- slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
834
-
835
- if (slot.sparams.dry_base < 1.0f)
836
- {
837
- slot.sparams.dry_base = default_sparams.dry_base;
838
- }
839
-
840
- // sequence breakers for DRY
841
- {
842
- // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
843
- // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
844
-
845
- if (data.contains("dry_sequence_breakers")) {
846
- slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
847
- if (slot.sparams.dry_sequence_breakers.empty()) {
848
- send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
849
- return false;
850
- }
851
- }
852
- }
1762
+ slot.reset();
1763
+ slot.id_task = task.id;
1764
+ slot.index = task.index;
1765
+ slot.task_type = task.type;
1766
+ slot.params = std::move(task.params);
1767
+ slot.prompt_tokens = std::move(task.prompt_tokens);
853
1768
 
854
- // process "json_schema" and "grammar"
855
- if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
856
- send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
857
- return false;
858
- }
859
- if (data.contains("json_schema") && !data.contains("grammar")) {
860
- try {
861
- auto schema = json_value(data, "json_schema", json::object());
862
- slot.sparams.grammar = json_schema_to_grammar(schema);
863
- } catch (const std::exception & e) {
864
- send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
865
- return false;
866
- }
867
- } else {
868
- slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
869
- }
1769
+ SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
870
1770
 
871
1771
  if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
872
1772
  // Might be better to reject the request with a 400 ?
@@ -874,86 +1774,16 @@ struct server_context {
874
1774
  SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
875
1775
  }
876
1776
 
877
- {
878
- slot.sparams.logit_bias.clear();
879
-
880
- if (json_value(data, "ignore_eos", false) && has_eos_token) {
881
- slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
882
- }
883
-
884
- const auto & logit_bias = data.find("logit_bias");
885
- if (logit_bias != data.end() && logit_bias->is_array()) {
886
- const int n_vocab = llama_n_vocab(model);
887
- for (const auto & el : *logit_bias) {
888
- // TODO: we may want to throw errors here, in case "el" is incorrect
889
- if (el.is_array() && el.size() == 2) {
890
- float bias;
891
- if (el[1].is_number()) {
892
- bias = el[1].get<float>();
893
- } else if (el[1].is_boolean() && !el[1].get<bool>()) {
894
- bias = -INFINITY;
895
- } else {
896
- continue;
897
- }
898
-
899
- if (el[0].is_number_integer()) {
900
- llama_token tok = el[0].get<llama_token>();
901
- if (tok >= 0 && tok < n_vocab) {
902
- slot.sparams.logit_bias.push_back({tok, bias});
903
- }
904
- } else if (el[0].is_string()) {
905
- auto toks = common_tokenize(model, el[0].get<std::string>(), false);
906
- for (auto tok : toks) {
907
- slot.sparams.logit_bias.push_back({tok, bias});
908
- }
909
- }
910
- }
911
- }
912
- }
913
- }
914
-
915
- {
916
- slot.params.antiprompt.clear();
917
-
918
- const auto & stop = data.find("stop");
919
- if (stop != data.end() && stop->is_array()) {
920
- for (const auto & word : *stop) {
921
- if (!word.empty()) {
922
- slot.params.antiprompt.push_back(word);
923
- }
924
- }
925
- }
926
- }
927
-
928
- {
929
- const auto & samplers = data.find("samplers");
930
- if (samplers != data.end()) {
931
- if (samplers->is_array()) {
932
- std::vector<std::string> sampler_names;
933
- for (const auto & name : *samplers) {
934
- if (name.is_string()) {
935
- sampler_names.emplace_back(name);
936
- }
937
- }
938
- slot.sparams.samplers = common_sampler_types_from_names(sampler_names, false);
939
- } else if (samplers->is_string()){
940
- std::string sampler_string;
941
- for (const auto & name : *samplers) {
942
- sampler_string += name;
943
- }
944
- slot.sparams.samplers = common_sampler_types_from_chars(sampler_string);
945
- }
946
- } else {
947
- slot.sparams.samplers = default_sparams.samplers;
948
- }
949
- }
950
-
1777
+ if (slot.params.ignore_eos && has_eos_token) {
1778
+ slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
1779
+ }
1780
+
951
1781
  {
952
1782
  if (slot.smpl != nullptr) {
953
1783
  common_sampler_free(slot.smpl);
954
1784
  }
955
1785
 
956
- slot.smpl = common_sampler_init(model, slot.sparams);
1786
+ slot.smpl = common_sampler_init(model, slot.params.sampling);
957
1787
  if (slot.smpl == nullptr) {
958
1788
  // for now, the only error that may happen here is invalid grammar
959
1789
  send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -961,6 +1791,12 @@ struct server_context {
961
1791
  }
962
1792
  }
963
1793
 
1794
+ if (slot.ctx_dft) {
1795
+ llama_batch_free(slot.batch_spec);
1796
+
1797
+ slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
1798
+ }
1799
+
964
1800
  slot.state = SLOT_STATE_STARTED;
965
1801
 
966
1802
  SLT_INF(slot, "%s", "processing task\n");
@@ -978,49 +1814,33 @@ struct server_context {
978
1814
 
979
1815
  bool process_token(completion_token_output & result, server_slot & slot) {
980
1816
  // remember which tokens were sampled - used for repetition penalties during sampling
981
- const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
1817
+ const std::string token_str = result.text_to_send;
982
1818
  slot.sampled = result.tok;
983
1819
 
984
- // search stop word and delete it
985
1820
  slot.generated_text += token_str;
1821
+ if (slot.params.return_tokens) {
1822
+ slot.generated_tokens.push_back(result.tok);
1823
+ }
986
1824
  slot.has_next_token = true;
987
1825
 
988
1826
  // check if there is incomplete UTF-8 character at the end
989
- bool incomplete = false;
990
- for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
991
- unsigned char c = slot.generated_text[slot.generated_text.size() - i];
992
- if ((c & 0xC0) == 0x80) {
993
- // continuation byte: 10xxxxxx
994
- continue;
995
- }
996
- if ((c & 0xE0) == 0xC0) {
997
- // 2-byte character: 110xxxxx ...
998
- incomplete = i < 2;
999
- } else if ((c & 0xF0) == 0xE0) {
1000
- // 3-byte character: 1110xxxx ...
1001
- incomplete = i < 3;
1002
- } else if ((c & 0xF8) == 0xF0) {
1003
- // 4-byte character: 11110xxx ...
1004
- incomplete = i < 4;
1005
- }
1006
- // else 1-byte character or invalid byte
1007
- break;
1008
- }
1827
+ bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
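
The removed block above scanned the last few bytes by hand to decide whether the generated text currently ends inside a multi-byte UTF-8 sequence; the new code delegates the same question to validate_utf8. A self-contained restatement of that inline check, kept purely as an illustration (the shipped helper may treat invalid byte sequences differently):

    #include <string>

    // True when `s` ends part-way through a multi-byte UTF-8 sequence, i.e. the
    // lead byte announces more continuation bytes than are actually present.
    static bool ends_with_incomplete_utf8(const std::string & s) {
        for (size_t i = 1; i <= 4 && i <= s.size(); ++i) {
            const unsigned char c = s[s.size() - i];
            if ((c & 0xC0) == 0x80) {
                continue;                              // continuation byte 10xxxxxx, keep looking for the lead byte
            }
            if ((c & 0xE0) == 0xC0) { return i < 2; }  // 110xxxxx starts a 2-byte sequence
            if ((c & 0xF0) == 0xE0) { return i < 3; }  // 1110xxxx starts a 3-byte sequence
            if ((c & 0xF8) == 0xF0) { return i < 4; }  // 11110xxx starts a 4-byte sequence
            return false;                              // single-byte (ASCII) lead or invalid byte
        }
        return false;                                  // only continuation bytes in the last four positions
    }

Holding back an incomplete sequence matters for streaming: emitting the partial bytes would hand the client invalid UTF-8 that it cannot render until the rest of the character arrives.
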
1009
1828
 
1829
+ // search stop word and delete it
1010
1830
  if (!incomplete) {
1011
1831
  size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
1012
1832
 
1013
1833
  const std::string str_test = slot.generated_text.substr(pos);
1014
1834
  bool send_text = true;
1015
1835
 
1016
- size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
1836
+ size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
1017
1837
  if (stop_pos != std::string::npos) {
1018
1838
  slot.generated_text.erase(
1019
1839
  slot.generated_text.begin() + pos + stop_pos,
1020
1840
  slot.generated_text.end());
1021
1841
  pos = std::min(slot.n_sent_text, slot.generated_text.size());
1022
1842
  } else if (slot.has_next_token) {
1023
- stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
1843
+ stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
1024
1844
  send_text = stop_pos == std::string::npos;
1025
1845
  }
1026
1846
 
@@ -1043,8 +1863,8 @@ struct server_context {
1043
1863
  }
1044
1864
 
1045
1865
  // check the limits
1046
- if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
1047
- slot.stopped_limit = true;
1866
+ if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
1867
+ slot.stop = STOP_TYPE_LIMIT;
1048
1868
  slot.has_next_token = false;
1049
1869
 
1050
1870
  SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
@@ -1053,7 +1873,7 @@ struct server_context {
1053
1873
  if (slot.has_new_line) {
1054
1874
  // if we have already seen a new line, we stop after a certain time limit
1055
1875
  if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
1056
- slot.stopped_limit = true;
1876
+ slot.stop = STOP_TYPE_LIMIT;
1057
1877
  slot.has_next_token = false;
1058
1878
 
1059
1879
  SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
@@ -1073,7 +1893,7 @@ struct server_context {
1073
1893
  }
1074
1894
 
1075
1895
  if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
1076
- slot.stopped_limit = true;
1896
+ slot.stop = STOP_TYPE_LIMIT;
1077
1897
  slot.has_next_token = false;
1078
1898
 
1079
1899
  // cut the last line
@@ -1102,7 +1922,7 @@ struct server_context {
1102
1922
  // if context shift is disabled, we stop when it reaches the context limit
1103
1923
  if (slot.n_past >= slot.n_ctx) {
1104
1924
  slot.truncated = true;
1105
- slot.stopped_limit = true;
1925
+ slot.stop = STOP_TYPE_LIMIT;
1106
1926
  slot.has_next_token = false;
1107
1927
 
1108
1928
  SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
@@ -1110,7 +1930,7 @@ struct server_context {
1110
1930
  }
1111
1931
 
1112
1932
  if (llama_token_is_eog(model, result.tok)) {
1113
- slot.stopped_eos = true;
1933
+ slot.stop = STOP_TYPE_EOS;
1114
1934
  slot.has_next_token = false;
1115
1935
 
1116
1936
  SLT_DBG(slot, "%s", "stopped by EOS\n");
@@ -1120,7 +1940,7 @@ struct server_context {
1120
1940
 
1121
1941
  if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
1122
1942
  slot.truncated = true;
1123
- slot.stopped_limit = true;
1943
+ slot.stop = STOP_TYPE_LIMIT;
1124
1944
  slot.has_next_token = false; // stop prediction
1125
1945
 
1126
1946
  SLT_WRN(slot,
@@ -1134,53 +1954,53 @@ struct server_context {
1134
1954
  return slot.has_next_token; // continue
1135
1955
  }
1136
1956
 
1137
- json get_formated_generation(const server_slot & slot) const {
1138
- std::vector<std::string> samplers;
1139
- samplers.reserve(slot.sparams.samplers.size());
1140
- for (const auto & sampler : slot.sparams.samplers) {
1141
- samplers.emplace_back(common_sampler_type_to_str(sampler));
1142
- }
1957
+ void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
1958
+ size_t n_probs = slot.params.sampling.n_probs;
1959
+ size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
1960
+ if (post_sampling) {
1961
+ const auto * cur_p = common_sampler_get_candidates(slot.smpl);
1962
+ const size_t max_probs = cur_p->size;
1963
+
1964
+ // set probability for sampled token
1965
+ for (size_t i = 0; i < max_probs; i++) {
1966
+ if (cur_p->data[i].id == result.tok) {
1967
+ result.prob = cur_p->data[i].p;
1968
+ break;
1969
+ }
1970
+ }
1143
1971
 
1144
- return json {
1145
- {"n_ctx", slot.n_ctx},
1146
- {"n_predict", slot.n_predict}, // Server configured n_predict
1147
- {"model", params.model_alias},
1148
- {"seed", slot.sparams.seed},
1149
- {"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
1150
- {"temperature", slot.sparams.temp},
1151
- {"dynatemp_range", slot.sparams.dynatemp_range},
1152
- {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
1153
- {"top_k", slot.sparams.top_k},
1154
- {"top_p", slot.sparams.top_p},
1155
- {"min_p", slot.sparams.min_p},
1156
- {"xtc_probability", slot.sparams.xtc_probability},
1157
- {"xtc_threshold", slot.sparams.xtc_threshold},
1158
- {"typical_p", slot.sparams.typ_p},
1159
- {"repeat_last_n", slot.sparams.penalty_last_n},
1160
- {"repeat_penalty", slot.sparams.penalty_repeat},
1161
- {"presence_penalty", slot.sparams.penalty_present},
1162
- {"frequency_penalty", slot.sparams.penalty_freq},
1163
- {"dry_multiplier", slot.sparams.dry_multiplier},
1164
- {"dry_base", slot.sparams.dry_base},
1165
- {"dry_allowed_length", slot.sparams.dry_allowed_length},
1166
- {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
1167
- {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
1168
- {"mirostat", slot.sparams.mirostat},
1169
- {"mirostat_tau", slot.sparams.mirostat_tau},
1170
- {"mirostat_eta", slot.sparams.mirostat_eta},
1171
- {"penalize_nl", slot.sparams.penalize_nl},
1172
- {"stop", slot.params.antiprompt},
1173
- {"max_tokens", slot.params.n_predict}, // User configured n_predict
1174
- {"n_keep", slot.params.n_keep},
1175
- {"n_discard", slot.params.n_discard},
1176
- {"ignore_eos", slot.sparams.ignore_eos},
1177
- {"stream", slot.params.stream},
1178
- //{"logit_bias", slot.sparams.logit_bias},
1179
- {"n_probs", slot.sparams.n_probs},
1180
- {"min_keep", slot.sparams.min_keep},
1181
- {"grammar", slot.sparams.grammar},
1182
- {"samplers", samplers},
1183
- };
1972
+ // set probability for top n_probs tokens
1973
+ result.probs.reserve(max_probs);
1974
+ for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
1975
+ result.probs.push_back({
1976
+ cur_p->data[i].id,
1977
+ common_detokenize(ctx, {cur_p->data[i].id}, special),
1978
+ cur_p->data[i].p
1979
+ });
1980
+ }
1981
+ } else {
1982
+ // TODO: optimize this with min-p optimization
1983
+ std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
1984
+
1985
+ // set probability for sampled token
1986
+ for (size_t i = 0; i < n_vocab; i++) {
1987
+ // set probability for sampled token
1988
+ if (cur[i].id == result.tok) {
1989
+ result.prob = cur[i].p;
1990
+ break;
1991
+ }
1992
+ }
1993
+
1994
+ // set probability for top n_probs tokens
1995
+ result.probs.reserve(n_probs);
1996
+ for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
1997
+ result.probs.push_back({
1998
+ cur[i].id,
1999
+ common_detokenize(ctx, {cur[i].id}, special),
2000
+ cur[i].p
2001
+ });
2002
+ }
2003
+ }
1184
2004
  }
1185
2005
 
1186
2006
  void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
@@ -1194,108 +2014,99 @@ struct server_context {
1194
2014
  void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1195
2015
  SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
1196
2016
 
1197
- server_task_result res;
1198
- res.id = id_task;
1199
- res.stop = false;
1200
- res.error = true;
1201
- res.data = format_error_response(error, type);
1202
-
1203
- queue_results.send(res);
1204
- }
1205
-
1206
- void send_partial_response(server_slot & slot, completion_token_output tkn) {
1207
- server_task_result res;
1208
- res.id = slot.id_task;
1209
- res.error = false;
1210
- res.stop = false;
1211
- res.data = json {
1212
- {"content", tkn.text_to_send},
1213
- {"stop", false},
1214
- {"id_slot", slot.id},
1215
- {"multimodal", false},
1216
- {"index", slot.index},
1217
- };
2017
+ auto res = std::make_unique<server_task_result_error>();
2018
+ res->id = id_task;
2019
+ res->err_type = type;
2020
+ res->err_msg = error;
1218
2021
 
1219
- if (slot.sparams.n_probs > 0) {
1220
- const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
1221
- const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
1222
- const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
2022
+ queue_results.send(std::move(res));
2023
+ }
1223
2024
 
1224
- std::vector<completion_token_output> probs_output;
1225
- if (probs_pos < probs_stop_pos) {
1226
- probs_output = std::vector<completion_token_output>(
1227
- slot.generated_token_probs.begin() + probs_pos,
1228
- slot.generated_token_probs.begin() + probs_stop_pos);
1229
- }
1230
- slot.n_sent_token_probs = probs_stop_pos;
2025
+ void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
2026
+ auto res = std::make_unique<server_task_result_cmpl_partial>();
2027
+
2028
+ res->id = slot.id_task;
2029
+ res->index = slot.index;
2030
+ res->content = tkn.text_to_send;
2031
+ res->tokens = { tkn.tok };
2032
+
2033
+ res->n_decoded = slot.n_decoded;
2034
+ res->n_prompt_tokens = slot.n_prompt_tokens;
2035
+ res->post_sampling_probs = slot.params.post_sampling_probs;
1231
2036
 
1232
- res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
2037
+ res->verbose = slot.params.verbose;
2038
+ res->oaicompat = slot.params.oaicompat;
2039
+ res->oaicompat_chat = slot.params.oaicompat_chat;
2040
+ res->oaicompat_model = slot.params.oaicompat_model;
2041
+ res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
2042
+
2043
+ // populate res.probs_output
2044
+ if (slot.params.sampling.n_probs > 0) {
2045
+ res->prob_output = tkn; // copy the token probs
1233
2046
  }
1234
2047
 
1235
- if (slot.oaicompat) {
1236
- res.data["oaicompat_token_ctr"] = slot.n_decoded;
1237
- res.data["model"] = slot.oaicompat_model;
2048
+ // populate timings if this is final response or timings_per_token is enabled
2049
+ if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
2050
+ res->timings = slot.get_timings();
1238
2051
  }
1239
2052
 
1240
- queue_results.send(res);
2053
+ queue_results.send(std::move(res));
1241
2054
  }
1242
2055
 
1243
- void send_final_response(const server_slot & slot) {
1244
- server_task_result res;
1245
- res.id = slot.id_task;
1246
- res.error = false;
1247
- res.stop = true;
1248
- res.data = json {
1249
- {"content", !slot.params.stream ? slot.generated_text : ""},
1250
- {"id_slot", slot.id},
1251
- {"stop", true},
1252
- {"model", params.model_alias},
1253
- {"tokens_predicted", slot.n_decoded},
1254
- {"tokens_evaluated", slot.n_prompt_tokens},
1255
- {"generation_settings", get_formated_generation(slot)},
1256
- {"prompt", common_detokenize(ctx, slot.prompt_tokens)},
1257
- {"has_new_line", slot.has_new_line},
1258
- {"truncated", slot.truncated},
1259
- {"stopped_eos", slot.stopped_eos},
1260
- {"stopped_word", slot.stopped_word},
1261
- {"stopped_limit", slot.stopped_limit},
1262
- {"stopping_word", slot.stopping_word},
1263
- {"tokens_cached", slot.n_past},
1264
- {"timings", slot.get_formated_timings()},
1265
- {"index", slot.index},
1266
- };
1267
-
1268
- if (slot.sparams.n_probs > 0) {
1269
- std::vector<completion_token_output> probs;
1270
- if (!slot.params.stream && slot.stopped_word) {
2056
+ void send_final_response(server_slot & slot) {
2057
+ auto res = std::make_unique<server_task_result_cmpl_final>();
2058
+ res->id = slot.id_task;
2059
+ res->id_slot = slot.id;
2060
+
2061
+ res->index = slot.index;
2062
+ res->content = slot.generated_text;
2063
+ res->tokens = slot.generated_tokens;
2064
+ res->timings = slot.get_timings();
2065
+ res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
2066
+
2067
+ res->truncated = slot.truncated;
2068
+ res->n_decoded = slot.n_decoded;
2069
+ res->n_prompt_tokens = slot.n_prompt_tokens;
2070
+ res->n_tokens_cached = slot.n_past;
2071
+ res->has_new_line = slot.has_new_line;
2072
+ res->stopping_word = slot.stopping_word;
2073
+ res->stop = slot.stop;
2074
+ res->post_sampling_probs = slot.params.post_sampling_probs;
2075
+
2076
+ res->verbose = slot.params.verbose;
2077
+ res->stream = slot.params.stream;
2078
+ res->oaicompat = slot.params.oaicompat;
2079
+ res->oaicompat_chat = slot.params.oaicompat_chat;
2080
+ res->oaicompat_model = slot.params.oaicompat_model;
2081
+ res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
2082
+
2083
+ // populate res.probs_output
2084
+ if (slot.params.sampling.n_probs > 0) {
2085
+ if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
1271
2086
  const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
1272
2087
 
1273
2088
  size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
1274
- probs = std::vector<completion_token_output>(
2089
+ res->probs_output = std::vector<completion_token_output>(
1275
2090
  slot.generated_token_probs.begin(),
1276
2091
  slot.generated_token_probs.end() - safe_offset);
1277
2092
  } else {
1278
- probs = std::vector<completion_token_output>(
2093
+ res->probs_output = std::vector<completion_token_output>(
1279
2094
  slot.generated_token_probs.begin(),
1280
2095
  slot.generated_token_probs.end());
1281
2096
  }
1282
-
1283
- res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
1284
2097
  }
1285
2098
 
1286
- if (slot.oaicompat) {
1287
- res.data["oaicompat_token_ctr"] = slot.n_decoded;
1288
- res.data["model"] = slot.oaicompat_model;
1289
- }
2099
+ res->generation_params = slot.params; // copy the parameters
1290
2100
 
1291
- queue_results.send(res);
2101
+ queue_results.send(std::move(res));
1292
2102
  }
1293
2103
 
1294
2104
  void send_embedding(const server_slot & slot, const llama_batch & batch) {
1295
- server_task_result res;
1296
- res.id = slot.id_task;
1297
- res.error = false;
1298
- res.stop = true;
2105
+ auto res = std::make_unique<server_task_result_embd>();
2106
+ res->id = slot.id_task;
2107
+ res->index = slot.index;
2108
+ res->n_tokens = slot.n_prompt_tokens;
2109
+ res->oaicompat = slot.params.oaicompat;
1299
2110
 
1300
2111
  const int n_embd = llama_n_embd(model);
1301
2112
 
@@ -1314,32 +2125,30 @@ struct server_context {
1314
2125
  if (embd == NULL) {
1315
2126
  SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
1316
2127
 
1317
- res.data = json {
1318
- {"embedding", std::vector<float>(n_embd, 0.0f)},
1319
- {"index", slot.index},
1320
- };
1321
-
2128
+ res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
1322
2129
  continue;
1323
2130
  }
1324
2131
 
1325
- common_embd_normalize(embd, embd_res.data(), n_embd);
1326
-
1327
- res.data = json {
1328
- {"embedding", embd_res},
1329
- {"index", slot.index},
1330
- };
2132
+ // normalize only when there is pooling
2133
+ // TODO: configurable
2134
+ if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
2135
+ common_embd_normalize(embd, embd_res.data(), n_embd, 2);
2136
+ res->embedding.push_back(embd_res);
2137
+ } else {
2138
+ res->embedding.push_back({ embd, embd + n_embd });
2139
+ }
1331
2140
  }
1332
2141
 
1333
2142
  SLT_DBG(slot, "%s", "sending embeddings\n");
1334
2143
 
1335
- queue_results.send(res);
2144
+ queue_results.send(std::move(res));
1336
2145
  }
1337
2146
 
1338
2147
  void send_rerank(const server_slot & slot, const llama_batch & batch) {
1339
- server_task_result res;
1340
- res.id = slot.id_task;
1341
- res.error = false;
1342
- res.stop = true;
2148
+ auto res = std::make_unique<server_task_result_rerank>();
2149
+ res->id = slot.id_task;
2150
+ res->index = slot.index;
2151
+ res->n_tokens = slot.n_prompt_tokens;
1343
2152
 
1344
2153
  for (int i = 0; i < batch.n_tokens; ++i) {
1345
2154
  if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
@@ -1354,104 +2163,29 @@ struct server_context {
1354
2163
  if (embd == NULL) {
1355
2164
  SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
1356
2165
 
1357
- res.data = json {
1358
- {"index", slot.index},
1359
- {"score", -1e6},
1360
- };
1361
-
2166
+ res->score = -1e6;
1362
2167
  continue;
1363
2168
  }
1364
2169
 
1365
- res.data = json {
1366
- {"index", slot.index},
1367
- {"score", embd[0]},
1368
- };
2170
+ res->score = embd[0];
1369
2171
  }
1370
2172
 
1371
- SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
2173
+ SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);
1372
2174
 
1373
- queue_results.send(res);
2175
+ queue_results.send(std::move(res));
1374
2176
  }
1375
2177
 
1376
2178
  //
1377
2179
  // Functions to create new task(s) and receive result(s)
1378
2180
  //
1379
2181
 
1380
- // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
1381
- std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
1382
- std::vector<server_task> tasks;
1383
- auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
1384
- SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
1385
- server_task task;
1386
- task.id = queue_tasks.get_new_id();
1387
- task.inf_type = inf_type;
1388
- task.type = SERVER_TASK_TYPE_INFERENCE;
1389
- task.data = task_data;
1390
- task.prompt_tokens = std::move(prompt_tokens);
1391
- tasks.push_back(std::move(task));
1392
- };
1393
-
1394
- static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
1395
- if (!data.contains("prompt")) {
1396
- throw std::runtime_error(error_msg);
1397
- }
1398
-
1399
- // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
1400
- bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
1401
- std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
1402
- switch (inf_type) {
1403
- case SERVER_TASK_INF_TYPE_RERANK:
1404
- {
1405
- // prompts[0] is the question
1406
- // the rest are the answers/documents
1407
- GGML_ASSERT(tokenized_prompts.size() > 1);
1408
- SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
1409
- for (size_t i = 1; i < tokenized_prompts.size(); i++) {
1410
- data["index"] = i - 1;
1411
- auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
1412
- create_task(data, tokens);
1413
- }
1414
- } break;
1415
- case SERVER_TASK_INF_TYPE_INFILL:
1416
- {
1417
- SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
1418
- for (size_t i = 0; i < tokenized_prompts.size(); i++) {
1419
- data["index"] = i;
1420
- auto tokens = format_infill(
1421
- ctx,
1422
- data.at("input_prefix"),
1423
- data.at("input_suffix"),
1424
- data.at("input_extra"),
1425
- params.n_batch,
1426
- params.n_predict,
1427
- slots[0].n_ctx, // TODO: there should be a better way
1428
- params.spm_infill,
1429
- tokenized_prompts[i]
1430
- );
1431
- create_task(data, tokens);
1432
- }
1433
- } break;
1434
- default:
1435
- {
1436
- SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
1437
- for (size_t i = 0; i < tokenized_prompts.size(); i++) {
1438
- data["index"] = i;
1439
- create_task(data, tokenized_prompts[i]);
1440
- }
1441
- }
1442
- }
1443
-
1444
- return tasks;
1445
- }
1446
-
1447
2182
  void cancel_tasks(const std::unordered_set<int> & id_tasks) {
1448
2183
  std::vector<server_task> cancel_tasks;
1449
2184
  cancel_tasks.reserve(id_tasks.size());
1450
2185
  for (const auto & id_task : id_tasks) {
1451
2186
  SRV_WRN("cancel task, id_task = %d\n", id_task);
1452
2187
 
1453
- server_task task;
1454
- task.type = SERVER_TASK_TYPE_CANCEL;
2188
+ server_task task(SERVER_TASK_TYPE_CANCEL);
1455
2189
  task.id_target = id_task;
1456
2190
  cancel_tasks.push_back(task);
1457
2191
  queue_results.remove_waiting_task_id(id_task);
@@ -1460,50 +2194,58 @@ struct server_context {
1460
2194
  queue_tasks.post(cancel_tasks, true);
1461
2195
  }
1462
2196
 
1463
- // receive the results from task(s) created by create_tasks_inference
1464
- void receive_cmpl_results(
2197
+ // receive the results from task(s)
2198
+ void receive_multi_results(
1465
2199
  const std::unordered_set<int> & id_tasks,
1466
- const std::function<void(std::vector<server_task_result>&)> & result_handler,
2200
+ const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
1467
2201
  const std::function<void(json)> & error_handler) {
1468
- // TODO: currently, there is no way to detect the client has cancelled the request
1469
- std::vector<server_task_result> results(id_tasks.size());
2202
+ std::vector<server_task_result_ptr> results(id_tasks.size());
1470
2203
  for (size_t i = 0; i < id_tasks.size(); i++) {
1471
- server_task_result result = queue_results.recv(id_tasks);
2204
+ server_task_result_ptr result = queue_results.recv(id_tasks);
1472
2205
 
1473
- if (result.error) {
1474
- error_handler(result.data);
2206
+ if (result->is_error()) {
2207
+ error_handler(result->to_json());
1475
2208
  cancel_tasks(id_tasks);
1476
2209
  return;
1477
2210
  }
1478
2211
 
1479
- const size_t idx = result.data["index"];
2212
+ GGML_ASSERT(
2213
+ dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
2214
+ || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
2215
+ || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
2216
+ );
2217
+ const size_t idx = result->get_index();
1480
2218
  GGML_ASSERT(idx < results.size() && "index out of range");
1481
-
1482
- results[idx] = result;
2219
+ results[idx] = std::move(result);
1483
2220
  }
1484
2221
  result_handler(results);
1485
2222
  }
1486
2223
 
1487
- // receive the results from task(s) created by create_tasks_inference, in stream mode
2224
+ // receive the results from task(s), in stream mode
1488
2225
  void receive_cmpl_results_stream(
1489
- const std::unordered_set<int> & id_tasks, const
1490
- std::function<bool(server_task_result&)> & result_handler, const
1491
- std::function<void(json)> & error_handler) {
2226
+ const std::unordered_set<int> & id_tasks,
2227
+ const std::function<bool(server_task_result_ptr&)> & result_handler,
2228
+ const std::function<void(json)> & error_handler) {
1492
2229
  size_t n_finished = 0;
1493
2230
  while (true) {
1494
- server_task_result result = queue_results.recv(id_tasks);
1495
- if (!result_handler(result)) {
2231
+ server_task_result_ptr result = queue_results.recv(id_tasks);
2232
+
2233
+ if (result->is_error()) {
2234
+ error_handler(result->to_json());
1496
2235
  cancel_tasks(id_tasks);
1497
- break;
2236
+ return;
1498
2237
  }
1499
2238
 
1500
- if (result.error) {
1501
- error_handler(result.data);
2239
+ GGML_ASSERT(
2240
+ dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
2241
+ || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
2242
+ );
2243
+ if (!result_handler(result)) {
1502
2244
  cancel_tasks(id_tasks);
1503
2245
  break;
1504
2246
  }
1505
2247
 
1506
- if (result.stop) {
2248
+ if (result->is_stop()) {
1507
2249
  if (++n_finished == id_tasks.size()) {
1508
2250
  break;
1509
2251
  }
@@ -1517,9 +2259,12 @@ struct server_context {
1517
2259
 
1518
2260
  void process_single_task(server_task task) {
1519
2261
  switch (task.type) {
1520
- case SERVER_TASK_TYPE_INFERENCE:
2262
+ case SERVER_TASK_TYPE_COMPLETION:
2263
+ case SERVER_TASK_TYPE_INFILL:
2264
+ case SERVER_TASK_TYPE_EMBEDDING:
2265
+ case SERVER_TASK_TYPE_RERANK:
1521
2266
  {
1522
- const int id_slot = json_value(task.data, "id_slot", -1);
2267
+ const int id_slot = task.id_selected_slot;
1523
2268
 
1524
2269
  server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
1525
2270
 
@@ -1536,13 +2281,6 @@ struct server_context {
1536
2281
  break;
1537
2282
  }
1538
2283
 
1539
- slot->reset();
1540
-
1541
- slot->id_task = task.id;
1542
- slot->inf_type = task.inf_type;
1543
- slot->index = json_value(task.data, "index", 0);
1544
- slot->prompt_tokens = std::move(task.prompt_tokens);
1545
-
1546
2284
  if (!launch_slot_with_task(*slot, task)) {
1547
2285
  SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
1548
2286
  break;
@@ -1570,21 +2308,7 @@ struct server_context {
1570
2308
  int n_processing_slots = 0;
1571
2309
 
1572
2310
  for (server_slot & slot : slots) {
1573
- json slot_data = get_formated_generation(slot);
1574
- slot_data["id"] = slot.id;
1575
- slot_data["id_task"] = slot.id_task;
1576
- slot_data["is_processing"] = slot.is_processing();
1577
- slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
1578
- slot_data["next_token"] = {
1579
- {"has_next_token", slot.has_next_token},
1580
- {"has_new_line", slot.has_new_line},
1581
- {"n_remain", slot.n_remaining},
1582
- {"n_decoded", slot.n_decoded},
1583
- {"stopped_eos", slot.stopped_eos},
1584
- {"stopped_word", slot.stopped_word},
1585
- {"stopped_limit", slot.stopped_limit},
1586
- {"stopping_word", slot.stopping_word},
1587
- };
2311
+ json slot_data = slot.to_json();
1588
2312
 
1589
2313
  if (slot.is_processing()) {
1590
2314
  n_processing_slots++;
@@ -1596,43 +2320,38 @@ struct server_context {
1596
2320
  }
1597
2321
  SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
1598
2322
 
1599
- server_task_result res;
1600
- res.id = task.id;
1601
- res.stop = true;
1602
- res.error = false;
1603
- res.data = {
1604
- { "idle", n_idle_slots },
1605
- { "processing", n_processing_slots },
1606
- { "deferred", queue_tasks.queue_tasks_deferred.size() },
1607
- { "t_start", metrics.t_start},
1608
-
1609
- { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
1610
- { "t_tokens_generation_total", metrics.t_tokens_generation_total},
1611
- { "n_tokens_predicted_total", metrics.n_tokens_predicted_total},
1612
- { "t_prompt_processing_total", metrics.t_prompt_processing_total},
2323
+ auto res = std::make_unique<server_task_result_metrics>();
2324
+ res->id = task.id;
2325
+ res->slots_data = std::move(slots_data);
2326
+ res->n_idle_slots = n_idle_slots;
2327
+ res->n_processing_slots = n_processing_slots;
2328
+ res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
2329
+ res->t_start = metrics.t_start;
1613
2330
 
1614
- { "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed},
1615
- { "t_prompt_processing", metrics.t_prompt_processing},
1616
- { "n_tokens_predicted", metrics.n_tokens_predicted},
1617
- { "t_tokens_generation", metrics.t_tokens_generation},
2331
+ res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
2332
+ res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
1618
2333
 
1619
- { "n_decode_total", metrics.n_decode_total},
1620
- { "n_busy_slots_total", metrics.n_busy_slots_total},
2334
+ res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
2335
+ res->t_prompt_processing_total = metrics.t_prompt_processing_total;
2336
+ res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
2337
+ res->t_tokens_generation_total = metrics.t_tokens_generation_total;
1621
2338
 
1622
- { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
1623
- { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
2339
+ res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
2340
+ res->t_prompt_processing = metrics.t_prompt_processing;
2341
+ res->n_tokens_predicted = metrics.n_tokens_predicted;
2342
+ res->t_tokens_generation = metrics.t_tokens_generation;
1624
2343
 
1625
- { "slots", slots_data },
1626
- };
2344
+ res->n_decode_total = metrics.n_decode_total;
2345
+ res->n_busy_slots_total = metrics.n_busy_slots_total;
1627
2346
 
1628
- if (json_value(task.data, "reset_bucket", false)) {
2347
+ if (task.metrics_reset_bucket) {
1629
2348
  metrics.reset_bucket();
1630
2349
  }
1631
- queue_results.send(res);
2350
+ queue_results.send(std::move(res));
1632
2351
  } break;
1633
2352
  case SERVER_TASK_TYPE_SLOT_SAVE:
1634
2353
  {
1635
- int id_slot = task.data.at("id_slot");
2354
+ int id_slot = task.slot_action.slot_id;
1636
2355
  server_slot * slot = get_slot_by_id(id_slot);
1637
2356
  if (slot == nullptr) {
1638
2357
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1648,32 +2367,27 @@ struct server_context {
1648
2367
  const size_t token_count = slot->cache_tokens.size();
1649
2368
  const int64_t t_start = ggml_time_us();
1650
2369
 
1651
- std::string filename = task.data.at("filename");
1652
- std::string filepath = task.data.at("filepath");
2370
+ std::string filename = task.slot_action.filename;
2371
+ std::string filepath = task.slot_action.filepath;
1653
2372
 
1654
2373
  const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
1655
2374
 
1656
2375
  const int64_t t_end = ggml_time_us();
1657
2376
  const double t_save_ms = (t_end - t_start) / 1000.0;
1658
2377
 
1659
- server_task_result result;
1660
- result.id = task.id;
1661
- result.stop = true;
1662
- result.error = false;
1663
- result.data = json {
1664
- { "id_slot", id_slot },
1665
- { "filename", filename },
1666
- { "n_saved", token_count }, // tokens saved
1667
- { "n_written", nwrite }, // bytes written
1668
- { "timings", {
1669
- { "save_ms", t_save_ms }
1670
- } }
1671
- };
1672
- queue_results.send(result);
2378
+ auto res = std::make_unique<server_task_result_slot_save_load>();
2379
+ res->id = task.id;
2380
+ res->id_slot = id_slot;
2381
+ res->filename = filename;
2382
+ res->is_save = true;
2383
+ res->n_tokens = token_count;
2384
+ res->n_bytes = nwrite;
2385
+ res->t_ms = t_save_ms;
2386
+ queue_results.send(std::move(res));
1673
2387
  } break;
1674
2388
  case SERVER_TASK_TYPE_SLOT_RESTORE:
1675
2389
  {
1676
- int id_slot = task.data.at("id_slot");
2390
+ int id_slot = task.slot_action.slot_id;
1677
2391
  server_slot * slot = get_slot_by_id(id_slot);
1678
2392
  if (slot == nullptr) {
1679
2393
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1688,8 +2402,8 @@ struct server_context {
1688
2402
 
1689
2403
  const int64_t t_start = ggml_time_us();
1690
2404
 
1691
- std::string filename = task.data.at("filename");
1692
- std::string filepath = task.data.at("filepath");
2405
+ std::string filename = task.slot_action.filename;
2406
+ std::string filepath = task.slot_action.filepath;
1693
2407
 
1694
2408
  slot->cache_tokens.resize(slot->n_ctx);
1695
2409
  size_t token_count = 0;
@@ -1704,24 +2418,19 @@ struct server_context {
1704
2418
  const int64_t t_end = ggml_time_us();
1705
2419
  const double t_restore_ms = (t_end - t_start) / 1000.0;
1706
2420
 
1707
- server_task_result result;
1708
- result.id = task.id;
1709
- result.stop = true;
1710
- result.error = false;
1711
- result.data = json {
1712
- { "id_slot", id_slot },
1713
- { "filename", filename },
1714
- { "n_restored", token_count }, // tokens restored
1715
- { "n_read", nread }, // bytes read
1716
- { "timings", {
1717
- { "restore_ms", t_restore_ms }
1718
- } }
1719
- };
1720
- queue_results.send(result);
2421
+ auto res = std::make_unique<server_task_result_slot_save_load>();
2422
+ res->id = task.id;
2423
+ res->id_slot = id_slot;
2424
+ res->filename = filename;
2425
+ res->is_save = false;
2426
+ res->n_tokens = token_count;
2427
+ res->n_bytes = nread;
2428
+ res->t_ms = t_restore_ms;
2429
+ queue_results.send(std::move(res));
1721
2430
  } break;
1722
2431
  case SERVER_TASK_TYPE_SLOT_ERASE:
1723
2432
  {
1724
- int id_slot = task.data.at("id_slot");
2433
+ int id_slot = task.slot_action.slot_id;
1725
2434
  server_slot * slot = get_slot_by_id(id_slot);
1726
2435
  if (slot == nullptr) {
1727
2436
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1739,25 +2448,18 @@ struct server_context {
1739
2448
  llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
1740
2449
  slot->cache_tokens.clear();
1741
2450
 
1742
- server_task_result result;
1743
- result.id = task.id;
1744
- result.stop = true;
1745
- result.error = false;
1746
- result.data = json {
1747
- { "id_slot", id_slot },
1748
- { "n_erased", n_erased }
1749
- };
1750
- queue_results.send(result);
2451
+ auto res = std::make_unique<server_task_result_slot_erase>();
2452
+ res->id = task.id;
2453
+ res->id_slot = id_slot;
2454
+ res->n_erased = n_erased;
2455
+ queue_results.send(std::move(res));
1751
2456
  } break;
1752
2457
  case SERVER_TASK_TYPE_SET_LORA:
1753
2458
  {
1754
2459
  common_lora_adapters_apply(ctx, loras);
1755
- server_task_result result;
1756
- result.id = task.id;
1757
- result.stop = true;
1758
- result.error = false;
1759
- result.data = json{{ "success", true }};
1760
- queue_results.send(result);
2460
+ auto res = std::make_unique<server_task_result_apply_lora>();
2461
+ res->id = task.id;
2462
+ queue_results.send(std::move(res));
1761
2463
  } break;
1762
2464
  }
1763
2465
  }
@@ -1787,10 +2489,8 @@ struct server_context {
1787
2489
  {
1788
2490
  SRV_DBG("%s", "posting NEXT_RESPONSE\n");
1789
2491
 
1790
- server_task task;
1791
- task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
1792
- task.id_target = -1;
1793
-
2492
+ server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
2493
+ task.id = queue_tasks.get_new_id();
1794
2494
  queue_tasks.post(task);
1795
2495
  }
1796
2496
 
@@ -1798,7 +2498,7 @@ struct server_context {
1798
2498
  // TODO: simplify and improve
1799
2499
  for (server_slot & slot : slots) {
1800
2500
  if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
1801
- if (!params.ctx_shift) {
2501
+ if (!params_base.ctx_shift) {
1802
2502
  // this check is redundant (for good)
1803
2503
  // we should never get here, because generation should already stopped in process_token()
1804
2504
  slot.release();
@@ -1864,7 +2564,7 @@ struct server_context {
1864
2564
  int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
1865
2565
 
1866
2566
  // next, batch any pending prompts without exceeding n_batch
1867
- if (params.cont_batching || batch.n_tokens == 0) {
2567
+ if (params_base.cont_batching || batch.n_tokens == 0) {
1868
2568
  for (auto & slot : slots) {
1869
2569
  // this slot still has a prompt to be processed
1870
2570
  if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
@@ -1904,7 +2604,7 @@ struct server_context {
1904
2604
  continue;
1905
2605
  }
1906
2606
 
1907
- if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
2607
+ if (slot.is_non_causal()) {
1908
2608
  if (slot.n_prompt_tokens > n_ubatch) {
1909
2609
  slot.release();
1910
2610
  send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
@@ -1917,7 +2617,7 @@ struct server_context {
1917
2617
  continue;
1918
2618
  }
1919
2619
  } else {
1920
- if (!params.ctx_shift) {
2620
+ if (!params_base.ctx_shift) {
1921
2621
  // if context shift is disabled, we make sure prompt size is smaller than KV size
1922
2622
  // TODO: there should be a separate parameter that control prompt truncation
1923
2623
  // context shift should be applied only during the generation phase
@@ -1960,14 +2660,14 @@ struct server_context {
1960
2660
 
1961
2661
  if (slot.params.cache_prompt) {
1962
2662
  // reuse any previously computed tokens that are common with the new prompt
1963
- slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
2663
+ slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
1964
2664
 
1965
2665
  // reuse chunks from the cached prompt by shifting their KV cache in the new position
1966
- if (params.n_cache_reuse > 0) {
2666
+ if (params_base.n_cache_reuse > 0) {
1967
2667
  size_t head_c = slot.n_past; // cache
1968
2668
  size_t head_p = slot.n_past; // current prompt
1969
2669
 
1970
- SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
2670
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
1971
2671
 
1972
2672
  while (head_c < slot.cache_tokens.size() &&
1973
2673
  head_p < prompt_tokens.size()) {
@@ -1980,7 +2680,7 @@ struct server_context {
1980
2680
  n_match++;
1981
2681
  }
1982
2682
 
1983
- if (n_match >= (size_t) params.n_cache_reuse) {
2683
+ if (n_match >= (size_t) params_base.n_cache_reuse) {
1984
2684
  SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
1985
2685
  //for (size_t i = head_p; i < head_p + n_match; i++) {
1986
2686
  // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
@@ -2019,7 +2719,7 @@ struct server_context {
2019
2719
  }
2020
2720
 
2021
2721
  // non-causal tasks require the entire prompt to fit in the physical batch
2022
- if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
2722
+ if (slot.is_non_causal()) {
2023
2723
  // cannot fit the prompt in the current batch - will try next iter
2024
2724
  if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
2025
2725
  continue;
@@ -2027,10 +2727,7 @@ struct server_context {
2027
2727
  }
2028
2728
 
2029
2729
  // check that we are in the right batch_type, if not defer the slot
2030
- const bool slot_type =
2031
- slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
2032
- slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
2033
-
2730
+ int slot_type = slot.is_non_causal();
2034
2731
  if (batch_type == -1) {
2035
2732
  batch_type = slot_type;
2036
2733
  } else if (batch_type != slot_type) {
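The repeated inf_type comparisons are folded into a slot helper. Inferred from the call sites in this hunk and the one above, is_non_causal() presumably reduces to:

    // Assumed definition, reconstructed from the replaced conditions.
    bool server_slot::is_non_causal() const {
        return task_type == SERVER_TASK_TYPE_EMBEDDING ||
               task_type == SERVER_TASK_TYPE_RERANK;
    }

Using its result as an int for slot_type keeps the old 0/1 batch-type convention intact.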
@@ -2053,7 +2750,10 @@ struct server_context {
2053
2750
 
2054
2751
  // add prompt tokens for processing in the current batch
2055
2752
  while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
2056
- common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
2753
+ // without pooling, we want to output the embeddings for all the tokens in the batch
2754
+ const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
2755
+
2756
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
2057
2757
 
2058
2758
  if (slot.params.cache_prompt) {
2059
2759
  slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
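The new need_embd flag marks every prompt token as an output when the pooling type is LLAMA_POOLING_TYPE_NONE. The reason, hedged as an illustration rather than a quote of the embedding path: without pooling there is no single per-sequence embedding, so the server must later read one vector per token, roughly like this:

    // Illustration only; the destination buffer and error handling are omitted.
    const int n_embd = llama_n_embd(model);
    for (int i = 0; i < batch.n_tokens; ++i) {
        // llama_get_embeddings_ith() only yields data for tokens added with the output flag set;
        // behaviour for other indices varies by llama.cpp version, so treat null defensively.
        const float * embd = llama_get_embeddings_ith(ctx, i);
        if (embd == nullptr) {
            continue;
        }
        // copy n_embd floats for token i into the response ...
    }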
@@ -2147,7 +2847,7 @@ struct server_context {
2147
2847
  }
2148
2848
 
2149
2849
  if (slot.state == SLOT_STATE_DONE_PROMPT) {
2150
- if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
2850
+ if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
2151
2851
  // prompt evaluated for embedding
2152
2852
  send_embedding(slot, batch_view);
2153
2853
  slot.release();
@@ -2155,7 +2855,7 @@ struct server_context {
2155
2855
  continue; // continue loop of slots
2156
2856
  }
2157
2857
 
2158
- if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
2858
+ if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
2159
2859
  send_rerank(slot, batch_view);
2160
2860
  slot.release();
2161
2861
  slot.i_batch = -1;
@@ -2168,27 +2868,33 @@ struct server_context {
2168
2868
  continue; // continue loop of slots
2169
2869
  }
2170
2870
 
2171
- completion_token_output result;
2172
- const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
2871
+ const int tok_idx = slot.i_batch - i;
2872
+
2873
+ llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
2874
+
2875
+ slot.i_batch = -1;
2173
2876
 
2174
2877
  common_sampler_accept(slot.smpl, id, true);
2175
2878
 
2176
2879
  slot.n_decoded += 1;
2880
+
2881
+ const int64_t t_current = ggml_time_us();
2882
+
2177
2883
  if (slot.n_decoded == 1) {
2178
- slot.t_start_generation = ggml_time_us();
2884
+ slot.t_start_generation = t_current;
2179
2885
  slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
2180
2886
  metrics.on_prompt_eval(slot);
2181
2887
  }
2182
2888
 
2183
- result.tok = id;
2889
+ slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
2184
2890
 
2185
- const auto * cur_p = common_sampler_get_candidates(slot.smpl);
2891
+ completion_token_output result;
2892
+ result.tok = id;
2893
+ result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
2894
+ result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
2186
2895
 
2187
- for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
2188
- result.probs.push_back({
2189
- cur_p->data[i].id,
2190
- i >= cur_p->size ? 0.0f : cur_p->data[i].p,
2191
- });
2896
+ if (slot.params.sampling.n_probs > 0) {
2897
+ populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
2192
2898
  }
2193
2899
 
2194
2900
  if (!process_token(result, slot)) {
@@ -2197,9 +2903,98 @@ struct server_context {
2197
2903
  slot.print_timings();
2198
2904
  send_final_response(slot);
2199
2905
  metrics.on_prediction(slot);
2906
+ continue;
2200
2907
  }
2908
+ }
2201
2909
 
2202
- slot.i_batch = -1;
2910
+ // do speculative decoding
2911
+ for (auto & slot : slots) {
2912
+ if (!slot.is_processing() || !slot.can_speculate()) {
2913
+ continue;
2914
+ }
2915
+
2916
+ if (slot.state != SLOT_STATE_GENERATING) {
2917
+ continue;
2918
+ }
2919
+
2920
+ // determine the max draft that fits the current slot state
2921
+ int n_draft_max = slot.params.speculative.n_max;
2922
+
2923
+ // note: n_past is not yet increased for the `id` token sampled above
2924
+ // also, need to leave space for 1 extra token to allow context shifts
2925
+ n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
2926
+
2927
+ if (slot.n_remaining > 0) {
2928
+ n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
2929
+ }
2930
+
2931
+ SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
2932
+
2933
+ if (n_draft_max < slot.params.speculative.n_min) {
2934
+ SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
2935
+
2936
+ continue;
2937
+ }
2938
+
2939
+ llama_token id = slot.sampled;
2940
+
2941
+ struct common_speculative_params params_spec;
2942
+ params_spec.n_draft = n_draft_max;
2943
+ params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
2944
+ params_spec.p_min = slot.params.speculative.p_min;
2945
+
2946
+ llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
2947
+
2948
+ // ignore small drafts
2949
+ if (slot.params.speculative.n_min > (int) draft.size()) {
2950
+ SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
2951
+
2952
+ continue;
2953
+ }
2954
+
2955
+ // construct the speculation batch
2956
+ common_batch_clear(slot.batch_spec);
2957
+ common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
2958
+
2959
+ for (size_t i = 0; i < draft.size(); ++i) {
2960
+ common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
2961
+ }
2962
+
2963
+ SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
2964
+
2965
+ llama_decode(ctx, slot.batch_spec);
2966
+
2967
+ // the accepted tokens from the speculation
2968
+ const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
2969
+
2970
+ slot.n_past += ids.size();
2971
+ slot.n_decoded += ids.size();
2972
+
2973
+ slot.cache_tokens.push_back(id);
2974
+ slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
2975
+
2976
+ llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
2977
+
2978
+ for (size_t i = 0; i < ids.size(); ++i) {
2979
+ completion_token_output result;
2980
+
2981
+ result.tok = ids[i];
2982
+ result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
2983
+ result.prob = 1.0f; // set later
2984
+
2985
+ // TODO: set result.probs
2986
+
2987
+ if (!process_token(result, slot)) {
2988
+ // release slot because of stop condition
2989
+ slot.release();
2990
+ slot.print_timings();
2991
+ send_final_response(slot);
2992
+ metrics.on_prediction(slot);
2993
+ break;
2994
+ }
2995
+ }
2996
+
2997
+ SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
2203
2998
  }
2204
2999
  }
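The acceptance step above leans on common_sampler_sample_and_accept_n(). A conceptual sketch of the assumed semantics, not the upstream implementation: sample from the target model at each index of the speculative batch and keep drafted tokens only while they agree, always emitting one corrective token from the target model past the last match. That is why ids always holds at least one element and ids.size() - 1 is the count of accepted draft tokens logged above.

    // Conceptual sketch; assumes the speculative batch was decoded so that index i
    // holds the target model's logits after the last accepted token plus draft[0..i-1].
    static std::vector<llama_token> sample_and_accept_n_sketch(
            common_sampler * smpl, llama_context * ctx, const std::vector<llama_token> & draft) {
        std::vector<llama_token> accepted;
        for (size_t i = 0; i <= draft.size(); ++i) {
            const llama_token id = common_sampler_sample(smpl, ctx, (int) i);
            common_sampler_accept(smpl, id, true);
            accepted.push_back(id);
            if (i < draft.size() && id != draft[i]) {
                break; // first disagreement: discard the remaining draft tokens
            }
        }
        return accepted;
    }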
2205
3000
 
@@ -2254,17 +3049,9 @@ int main(int argc, char ** argv) {
2254
3049
 
2255
3050
  common_init();
2256
3051
 
2257
- // enabling this will output extra debug information in the HTTP responses from the server
2258
- // see format_final_response_oaicompat()
2259
- const bool verbose = params.verbosity > 9;
2260
-
2261
3052
  // struct that contains llama context and inference
2262
3053
  server_context ctx_server;
2263
3054
 
2264
- if (params.model_alias == "unknown") {
2265
- params.model_alias = params.model;
2266
- }
2267
-
2268
3055
  llama_backend_init();
2269
3056
  llama_numa_init(params.numa);
2270
3057
 
@@ -2273,16 +3060,6 @@ int main(int argc, char ** argv) {
2273
3060
  LOG_INF("%s\n", common_params_get_system_info(params).c_str());
2274
3061
  LOG_INF("\n");
2275
3062
 
2276
- // static files
2277
- std::map<std::string, server_static_file> static_files = {
2278
- { "/", { index_html, index_html_len, "text/html; charset=utf-8" }},
2279
- { "/completion.js", { completion_js, completion_js_len, "text/javascript; charset=utf-8" }},
2280
- { "/deps_daisyui.min.css", { deps_daisyui_min_css, deps_daisyui_min_css_len, "text/css; charset=utf-8" }},
2281
- { "/deps_markdown-it.js", { deps_markdown_it_js, deps_markdown_it_js_len, "text/javascript; charset=utf-8" }},
2282
- { "/deps_tailwindcss.js", { deps_tailwindcss_js, deps_tailwindcss_js_len, "text/javascript; charset=utf-8" }},
2283
- { "/deps_vue.esm-browser.js", { deps_vue_esm_browser_js, deps_vue_esm_browser_js_len, "text/javascript; charset=utf-8" }},
2284
- };
2285
-
2286
3063
  std::unique_ptr<httplib::Server> svr;
2287
3064
  #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2288
3065
  if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
@@ -2309,20 +3086,20 @@ int main(int argc, char ** argv) {
2309
3086
 
2310
3087
  auto res_error = [](httplib::Response & res, const json & error_data) {
2311
3088
  json final_response {{"error", error_data}};
2312
- res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
3089
+ res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
2313
3090
  res.status = json_value(error_data, "code", 500);
2314
3091
  };
2315
3092
 
2316
3093
  auto res_ok = [](httplib::Response & res, const json & data) {
2317
- res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
3094
+ res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
2318
3095
  res.status = 200;
2319
3096
  };
2320
3097
 
2321
- svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
3098
+ svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
2322
3099
  std::string message;
2323
3100
  try {
2324
3101
  std::rethrow_exception(ep);
2325
- } catch (std::exception & e) {
3102
+ } catch (const std::exception & e) {
2326
3103
  message = e.what();
2327
3104
  } catch (...) {
2328
3105
  message = "Unknown Exception";
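Both res_error and res_ok above now serialize through safe_json_to_str(). Judging by the inline dump() call it replaces, the helper presumably just renders the JSON with invalid UTF-8 replaced instead of throwing, roughly:

    // Presumed equivalent of the removed inline call; an assumption, not the upstream code.
    static std::string safe_json_to_str(const json & data) {
        return data.dump(-1, ' ', false, json::error_handler_t::replace);
    }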
@@ -2363,7 +3140,7 @@ int main(int argc, char ** argv) {
2363
3140
  // Middlewares
2364
3141
  //
2365
3142
 
2366
- auto middleware_validate_api_key = [&params, &res_error, &static_files](const httplib::Request & req, httplib::Response & res) {
3143
+ auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
2367
3144
  static const std::unordered_set<std::string> public_endpoints = {
2368
3145
  "/health",
2369
3146
  "/models",
@@ -2376,7 +3153,7 @@ int main(int argc, char ** argv) {
2376
3153
  }
2377
3154
 
2378
3155
  // If path is public or is static file, skip validation
2379
- if (public_endpoints.find(req.path) != public_endpoints.end() || static_files.find(req.path) != static_files.end()) {
3156
+ if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
2380
3157
  return true;
2381
3158
  }
2382
3159
 
@@ -2451,27 +3228,33 @@ int main(int argc, char ** argv) {
2451
3228
  }
2452
3229
 
2453
3230
  // request slots data using task queue
2454
- server_task task;
3231
+ server_task task(SERVER_TASK_TYPE_METRICS);
2455
3232
  task.id = ctx_server.queue_tasks.get_new_id();
2456
- task.type = SERVER_TASK_TYPE_METRICS;
2457
-
2458
3233
  ctx_server.queue_results.add_waiting_task_id(task.id);
2459
3234
  ctx_server.queue_tasks.post(task, true); // high-priority task
2460
3235
 
2461
3236
  // get the result
2462
- server_task_result result = ctx_server.queue_results.recv(task.id);
3237
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
2463
3238
  ctx_server.queue_results.remove_waiting_task_id(task.id);
2464
3239
 
3240
+ if (result->is_error()) {
3241
+ res_error(res, result->to_json());
3242
+ return;
3243
+ }
3244
+
3245
+ // TODO: get rid of this dynamic_cast
3246
+ auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
3247
+ GGML_ASSERT(res_metrics != nullptr);
3248
+
2465
3249
  // optionally return "fail_on_no_slot" error
2466
- const int n_idle_slots = result.data.at("idle");
2467
3250
  if (req.has_param("fail_on_no_slot")) {
2468
- if (n_idle_slots == 0) {
3251
+ if (res_metrics->n_idle_slots == 0) {
2469
3252
  res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
2470
3253
  return;
2471
3254
  }
2472
3255
  }
2473
3256
 
2474
- res_ok(res, result.data.at("slots"));
3257
+ res_ok(res, res_metrics->slots_data);
2475
3258
  };
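The /slots handler above now receives a server_task_result_ptr and downcasts it to the concrete result type. A rough sketch of the interface this implies, inferred from the is_error()/to_json() calls and the dynamic_casts rather than copied from the refactored headers:

    // Assumed shape only; concrete types such as server_task_result_metrics or
    // server_task_result_slot_save_load derive from this and add their own fields.
    struct server_task_result {
        int id = -1;
        virtual bool is_error() { return false; }  // overridden by the error result
        virtual json to_json()  = 0;               // each result serializes itself
        virtual ~server_task_result() = default;
    };
    using server_task_result_ptr = std::unique_ptr<server_task_result>;  // needs <memory>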
2476
3259
 
2477
3260
  const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
@@ -2481,83 +3264,77 @@ int main(int argc, char ** argv) {
2481
3264
  }
2482
3265
 
2483
3266
  // request slots data using task queue
2484
- server_task task;
3267
+ server_task task(SERVER_TASK_TYPE_METRICS);
2485
3268
  task.id = ctx_server.queue_tasks.get_new_id();
2486
- task.id_target = -1;
2487
- task.type = SERVER_TASK_TYPE_METRICS;
2488
- task.data.push_back({{"reset_bucket", true}});
3269
+ task.metrics_reset_bucket = true;
2489
3270
 
2490
3271
  ctx_server.queue_results.add_waiting_task_id(task.id);
2491
3272
  ctx_server.queue_tasks.post(task, true); // high-priority task
2492
3273
 
2493
3274
  // get the result
2494
- server_task_result result = ctx_server.queue_results.recv(task.id);
3275
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
2495
3276
  ctx_server.queue_results.remove_waiting_task_id(task.id);
2496
3277
 
2497
- json data = result.data;
2498
-
2499
- const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed");
2500
- const uint64_t t_prompt_processing = data.at("t_prompt_processing");
2501
-
2502
- const uint64_t n_tokens_predicted = data.at("n_tokens_predicted");
2503
- const uint64_t t_tokens_generation = data.at("t_tokens_generation");
2504
-
2505
- const uint64_t n_decode_total = data.at("n_decode_total");
2506
- const uint64_t n_busy_slots_total = data.at("n_busy_slots_total");
3278
+ if (result->is_error()) {
3279
+ res_error(res, result->to_json());
3280
+ return;
3281
+ }
2507
3282
 
2508
- const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");
3283
+ // TODO: get rid of this dynamic_cast
3284
+ auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
3285
+ GGML_ASSERT(res_metrics != nullptr);
2509
3286
 
2510
3287
  // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
2511
3288
  json all_metrics_def = json {
2512
3289
  {"counter", {{
2513
3290
  {"name", "prompt_tokens_total"},
2514
3291
  {"help", "Number of prompt tokens processed."},
2515
- {"value", (uint64_t) data.at("n_prompt_tokens_processed_total")}
3292
+ {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total}
2516
3293
  }, {
2517
3294
  {"name", "prompt_seconds_total"},
2518
3295
  {"help", "Prompt process time"},
2519
- {"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3}
3296
+ {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3}
2520
3297
  }, {
2521
3298
  {"name", "tokens_predicted_total"},
2522
3299
  {"help", "Number of generation tokens processed."},
2523
- {"value", (uint64_t) data.at("n_tokens_predicted_total")}
3300
+ {"value", (uint64_t) res_metrics->n_tokens_predicted_total}
2524
3301
  }, {
2525
3302
  {"name", "tokens_predicted_seconds_total"},
2526
3303
  {"help", "Predict process time"},
2527
- {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3}
3304
+ {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3}
2528
3305
  }, {
2529
3306
  {"name", "n_decode_total"},
2530
3307
  {"help", "Total number of llama_decode() calls"},
2531
- {"value", n_decode_total}
3308
+ {"value", res_metrics->n_decode_total}
2532
3309
  }, {
2533
3310
  {"name", "n_busy_slots_per_decode"},
2534
3311
  {"help", "Average number of busy slots per llama_decode() call"},
2535
- {"value", (float) n_busy_slots_total / (float) n_decode_total}
3312
+ {"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total}
2536
3313
  }}},
2537
3314
  {"gauge", {{
2538
3315
  {"name", "prompt_tokens_seconds"},
2539
3316
  {"help", "Average prompt throughput in tokens/s."},
2540
- {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
3317
+ {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.}
2541
3318
  },{
2542
3319
  {"name", "predicted_tokens_seconds"},
2543
3320
  {"help", "Average generation throughput in tokens/s."},
2544
- {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
3321
+ {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
2545
3322
  },{
2546
3323
  {"name", "kv_cache_usage_ratio"},
2547
3324
  {"help", "KV-cache usage. 1 means 100 percent usage."},
2548
- {"value", 1. * kv_cache_used_cells / params.n_ctx}
3325
+ {"value", 1. * res_metrics->kv_cache_used_cells / params.n_ctx}
2549
3326
  },{
2550
3327
  {"name", "kv_cache_tokens"},
2551
3328
  {"help", "KV-cache tokens."},
2552
- {"value", (uint64_t) data.at("kv_cache_tokens_count")}
3329
+ {"value", (uint64_t) res_metrics->kv_cache_tokens_count}
2553
3330
  },{
2554
3331
  {"name", "requests_processing"},
2555
3332
  {"help", "Number of request processing."},
2556
- {"value", (uint64_t) data.at("processing")}
3333
+ {"value", (uint64_t) res_metrics->n_processing_slots}
2557
3334
  },{
2558
3335
  {"name", "requests_deferred"},
2559
3336
  {"help", "Number of request deferred."},
2560
- {"value", (uint64_t) data.at("deferred")}
3337
+ {"value", (uint64_t) res_metrics->n_tasks_deferred}
2561
3338
  }}}
2562
3339
  };
2563
3340
 
@@ -2578,8 +3355,7 @@ int main(int argc, char ** argv) {
2578
3355
  }
2579
3356
  }
2580
3357
 
2581
- const int64_t t_start = data.at("t_start");
2582
- res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
3358
+ res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start));
2583
3359
 
2584
3360
  res.set_content(prometheus.str(), "text/plain; version=0.0.4");
2585
3361
  res.status = 200; // HTTP OK
@@ -2594,25 +3370,24 @@ int main(int argc, char ** argv) {
2594
3370
  }
2595
3371
  std::string filepath = params.slot_save_path + filename;
2596
3372
 
2597
- server_task task;
2598
- task.type = SERVER_TASK_TYPE_SLOT_SAVE;
2599
- task.data = {
2600
- { "id_slot", id_slot },
2601
- { "filename", filename },
2602
- { "filepath", filepath },
2603
- };
3373
+ server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
3374
+ task.id = ctx_server.queue_tasks.get_new_id();
3375
+ task.slot_action.slot_id = id_slot;
3376
+ task.slot_action.filename = filename;
3377
+ task.slot_action.filepath = filepath;
2604
3378
 
2605
- const int id_task = ctx_server.queue_tasks.post(task);
2606
- ctx_server.queue_results.add_waiting_task_id(id_task);
3379
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3380
+ ctx_server.queue_tasks.post(task);
2607
3381
 
2608
- server_task_result result = ctx_server.queue_results.recv(id_task);
2609
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3382
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3383
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
2610
3384
 
2611
- if (result.error) {
2612
- res_error(res, result.data);
2613
- } else {
2614
- res_ok(res, result.data);
3385
+ if (result->is_error()) {
3386
+ res_error(res, result->to_json());
3387
+ return;
2615
3388
  }
3389
+
3390
+ res_ok(res, result->to_json());
2616
3391
  };
2617
3392
 
2618
3393
  const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
@@ -2624,45 +3399,45 @@ int main(int argc, char ** argv) {
2624
3399
  }
2625
3400
  std::string filepath = params.slot_save_path + filename;
2626
3401
 
2627
- server_task task;
2628
- task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
2629
- task.data = {
2630
- { "id_slot", id_slot },
2631
- { "filename", filename },
2632
- { "filepath", filepath },
2633
- };
3402
+ server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
3403
+ task.id = ctx_server.queue_tasks.get_new_id();
3404
+ task.slot_action.slot_id = id_slot;
3405
+ task.slot_action.filename = filename;
3406
+ task.slot_action.filepath = filepath;
2634
3407
 
2635
- const int id_task = ctx_server.queue_tasks.post(task);
2636
- ctx_server.queue_results.add_waiting_task_id(id_task);
3408
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3409
+ ctx_server.queue_tasks.post(task);
2637
3410
 
2638
- server_task_result result = ctx_server.queue_results.recv(id_task);
2639
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3411
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3412
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
2640
3413
 
2641
- if (result.error) {
2642
- res_error(res, result.data);
2643
- } else {
2644
- res_ok(res, result.data);
3414
+ if (result->is_error()) {
3415
+ res_error(res, result->to_json());
3416
+ return;
2645
3417
  }
3418
+
3419
+ GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
3420
+ res_ok(res, result->to_json());
2646
3421
  };
2647
3422
 
2648
3423
  const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
2649
- server_task task;
2650
- task.type = SERVER_TASK_TYPE_SLOT_ERASE;
2651
- task.data = {
2652
- { "id_slot", id_slot },
2653
- };
3424
+ server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
3425
+ task.id = ctx_server.queue_tasks.get_new_id();
3426
+ task.slot_action.slot_id = id_slot;
2654
3427
 
2655
- const int id_task = ctx_server.queue_tasks.post(task);
2656
- ctx_server.queue_results.add_waiting_task_id(id_task);
3428
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3429
+ ctx_server.queue_tasks.post(task);
2657
3430
 
2658
- server_task_result result = ctx_server.queue_results.recv(id_task);
2659
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3431
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3432
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
2660
3433
 
2661
- if (result.error) {
2662
- res_error(res, result.data);
2663
- } else {
2664
- res_ok(res, result.data);
3434
+ if (result->is_error()) {
3435
+ res_error(res, result->to_json());
3436
+ return;
2665
3437
  }
3438
+
3439
+ GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
3440
+ res_ok(res, result->to_json());
2666
3441
  };
2667
3442
 
2668
3443
  const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
@@ -2695,9 +3470,11 @@ int main(int argc, char ** argv) {
2695
3470
  };
2696
3471
 
2697
3472
  const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
3473
+ // this endpoint is publicly available, please only return what is safe to be exposed
2698
3474
  json data = {
2699
3475
  { "default_generation_settings", ctx_server.default_generation_settings_for_props },
2700
- { "total_slots", ctx_server.params.n_parallel },
3476
+ { "total_slots", ctx_server.params_base.n_parallel },
3477
+ { "model_path", ctx_server.params_base.model },
2701
3478
  { "chat_template", llama_get_chat_template(ctx_server.model) },
2702
3479
  };
2703
3480
 
@@ -2705,7 +3482,7 @@ int main(int argc, char ** argv) {
2705
3482
  };
2706
3483
 
2707
3484
  const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
2708
- if (!ctx_server.params.endpoint_props) {
3485
+ if (!ctx_server.params_base.endpoint_props) {
2709
3486
  res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
2710
3487
  return;
2711
3488
  }
@@ -2717,13 +3494,50 @@ int main(int argc, char ** argv) {
2717
3494
  res_ok(res, {{ "success", true }});
2718
3495
  };
2719
3496
 
2720
- const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
2721
- if (ctx_server.params.embedding) {
3497
+ // handle completion-like requests (completion, chat, infill)
3498
+ // we can optionally provide a custom format for partial results and final results
3499
+ const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
3500
+ server_task_type type,
3501
+ json & data,
3502
+ httplib::Response & res,
3503
+ bool oaicompat = false,
3504
+ bool oaicompat_chat = false) {
3505
+ GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
3506
+
3507
+ if (ctx_server.params_base.embedding) {
2722
3508
  res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
2723
3509
  return;
2724
3510
  }
2725
3511
 
2726
- std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
3512
+ auto completion_id = gen_chatcmplid();
3513
+ std::vector<server_task> tasks;
3514
+
3515
+ try {
3516
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
3517
+ tasks.reserve(tokenized_prompts.size());
3518
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
3519
+ server_task task = server_task(type);
3520
+
3521
+ task.id = ctx_server.queue_tasks.get_new_id();
3522
+ task.index = i;
3523
+
3524
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
3525
+ task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
3526
+ task.id_selected_slot = json_value(data, "id_slot", -1);
3527
+
3528
+ // OAI-compat
3529
+ task.params.oaicompat = oaicompat;
3530
+ task.params.oaicompat_chat = oaicompat_chat;
3531
+ task.params.oaicompat_cmpl_id = completion_id;
3532
+ // oaicompat_model is already populated by params_from_json_cmpl
3533
+
3534
+ tasks.push_back(task);
3535
+ }
3536
+ } catch (const std::exception & e) {
3537
+ res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
3538
+ return;
3539
+ }
3540
+
2727
3541
  ctx_server.queue_results.add_waiting_tasks(tasks);
2728
3542
  ctx_server.queue_tasks.post(tasks);
2729
3543
 
@@ -2731,15 +3545,15 @@ int main(int argc, char ** argv) {
2731
3545
  const auto task_ids = server_task::get_list_id(tasks);
2732
3546
 
2733
3547
  if (!stream) {
2734
- ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
3548
+ ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
2735
3549
  if (results.size() == 1) {
2736
3550
  // single result
2737
- res_ok(res, results[0].data);
3551
+ res_ok(res, results[0]->to_json());
2738
3552
  } else {
2739
3553
  // multiple results (multitask)
2740
3554
  json arr = json::array();
2741
- for (const auto & res : results) {
2742
- arr.push_back(res.data);
3555
+ for (auto & res : results) {
3556
+ arr.push_back(res->to_json());
2743
3557
  }
2744
3558
  res_ok(res, arr);
2745
3559
  }
@@ -2749,12 +3563,26 @@ int main(int argc, char ** argv) {
2749
3563
 
2750
3564
  ctx_server.queue_results.remove_waiting_task_ids(task_ids);
2751
3565
  } else {
2752
- const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
2753
- ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
2754
- return server_sent_event(sink, "data", result.data);
3566
+ const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
3567
+ ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
3568
+ json res_json = result->to_json();
3569
+ if (res_json.is_array()) {
3570
+ for (const auto & res : res_json) {
3571
+ if (!server_sent_event(sink, "data", res)) {
3572
+ return false;
3573
+ }
3574
+ }
3575
+ return true;
3576
+ } else {
3577
+ return server_sent_event(sink, "data", res_json);
3578
+ }
2755
3579
  }, [&](const json & error_data) {
2756
3580
  server_sent_event(sink, "error", error_data);
2757
3581
  });
3582
+ if (oaicompat) {
3583
+ static const std::string ev_done = "data: [DONE]\n\n";
3584
+ sink.write(ev_done.data(), ev_done.size());
3585
+ }
2758
3586
  sink.done();
2759
3587
  return false;
2760
3588
  };
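The streaming branch emits each partial result through server_sent_event(). An assumed minimal implementation, shown so the return-false-to-stop contract above is clear: it writes one SSE frame ("<event>: <json>\n\n") and reports whether the client connection is still alive.

    // Assumed behaviour, for illustration; httplib::DataSink::write returns false once the client is gone.
    static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
        const std::string str =
            std::string(event) + ": " +
            data.dump(-1, ' ', false, json::error_handler_t::replace) +
            "\n\n"; // SSE frames end with a blank line
        return sink.write(str.c_str(), str.size());
    }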
@@ -2769,7 +3597,12 @@ int main(int argc, char ** argv) {
2769
3597
 
2770
3598
  const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2771
3599
  json data = json::parse(req.body);
2772
- return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
3600
+ return handle_completions_generic(
3601
+ SERVER_TASK_TYPE_COMPLETION,
3602
+ data,
3603
+ res,
3604
+ /* oaicompat */ false,
3605
+ /* oaicompat_chat */ false);
2773
3606
  };
2774
3607
 
2775
3608
  const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
@@ -2792,6 +3625,11 @@ int main(int argc, char ** argv) {
2792
3625
  json data = json::parse(req.body);
2793
3626
 
2794
3627
  // validate input
3628
+ if (data.contains("prompt") && !data.at("prompt").is_string()) {
3629
+ // prompt is optional
3630
+ res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
3631
+ }
3632
+
2795
3633
  if (!data.contains("input_prefix")) {
2796
3634
  res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
2797
3635
  }
@@ -2801,9 +3639,11 @@ int main(int argc, char ** argv) {
2801
3639
  }
2802
3640
 
2803
3641
  if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
3642
+ // input_extra is optional
2804
3643
  res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
2805
3644
  return;
2806
3645
  }
3646
+
2807
3647
  json input_extra = json_value(data, "input_extra", json::array());
2808
3648
  for (const auto & chunk : input_extra) {
2809
3649
  // { "text": string, "filename": string }
@@ -2819,67 +3659,40 @@ int main(int argc, char ** argv) {
2819
3659
  }
2820
3660
  data["input_extra"] = input_extra; // default to an empty array if it doesn't exist
2821
3661
 
2822
- return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
3662
+ std::string prompt = json_value(data, "prompt", std::string());
3663
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
3664
+ SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
3665
+ data["prompt"] = format_infill(
3666
+ ctx_server.ctx,
3667
+ data.at("input_prefix"),
3668
+ data.at("input_suffix"),
3669
+ data.at("input_extra"),
3670
+ ctx_server.params_base.n_batch,
3671
+ ctx_server.params_base.n_predict,
3672
+ ctx_server.slots[0].n_ctx, // TODO: there should be a better way
3673
+ ctx_server.params_base.spm_infill,
3674
+ tokenized_prompts[0]
3675
+ );
3676
+
3677
+ return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
2823
3678
  };
2824
3679
 
2825
- // TODO: maybe merge this function with "handle_completions_generic"
2826
- const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
2827
- if (ctx_server.params.embedding) {
3680
+ const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
3681
+ if (ctx_server.params_base.embedding) {
2828
3682
  res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
2829
3683
  return;
2830
3684
  }
2831
3685
 
2832
3686
  json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
2833
-
2834
- std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
2835
- ctx_server.queue_results.add_waiting_tasks(tasks);
2836
- ctx_server.queue_tasks.post(tasks);
2837
-
2838
- bool stream = json_value(data, "stream", false);
2839
- const auto task_ids = server_task::get_list_id(tasks);
2840
- const auto completion_id = gen_chatcmplid();
2841
-
2842
- if (!stream) {
2843
- ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
2844
- // multitask is never support in chat completion, there is only one result
2845
- json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
2846
- res_ok(res, result_oai);
2847
- }, [&](const json & error_data) {
2848
- res_error(res, error_data);
2849
- });
2850
-
2851
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
2852
- } else {
2853
- const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
2854
- ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
2855
- std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
2856
- for (auto & event_data : result_array) {
2857
- if (event_data.empty()) {
2858
- continue; // skip the stop token
2859
- }
2860
- if (!server_sent_event(sink, "data", event_data)) {
2861
- return false; // connection is closed
2862
- }
2863
- }
2864
- return true; // ok
2865
- }, [&](const json & error_data) {
2866
- server_sent_event(sink, "error", error_data);
2867
- });
2868
- static const std::string ev_done = "data: [DONE]\n\n";
2869
- sink.write(ev_done.data(), ev_done.size());
2870
- sink.done();
2871
- return true;
2872
- };
2873
-
2874
- auto on_complete = [task_ids, &ctx_server] (bool) {
2875
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
2876
- };
2877
-
2878
- res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
2879
- }
3687
+ return handle_completions_generic(
3688
+ SERVER_TASK_TYPE_COMPLETION,
3689
+ data,
3690
+ res,
3691
+ /* oaicompat */ true,
3692
+ /* oaicompat_chat */ true);
2880
3693
  };
2881
3694
 
2882
- const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
3695
+ const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
2883
3696
  json models = {
2884
3697
  {"object", "list"},
2885
3698
  {"data", {
@@ -2893,7 +3706,7 @@ int main(int argc, char ** argv) {
2893
3706
  }}
2894
3707
  };
2895
3708
 
2896
- res.set_content(models.dump(), MIMETYPE_JSON);
3709
+ res_ok(res, models);
2897
3710
  };
2898
3711
 
2899
3712
  const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
@@ -2949,37 +3762,63 @@ int main(int argc, char ** argv) {
2949
3762
  res_ok(res, data);
2950
3763
  };
2951
3764
 
2952
- const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3765
+ const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
2953
3766
  const json body = json::parse(req.body);
2954
- bool is_openai = false;
2955
3767
 
2956
- // an input prompt can be a string or a list of tokens (integer)
3768
+ if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
3769
+ res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
3770
+ return;
3771
+ }
3772
+
3773
+ // for the shape of input/content, see tokenize_input_prompts()
2957
3774
  json prompt;
2958
3775
  if (body.count("input") != 0) {
2959
- is_openai = true;
2960
3776
  prompt = body.at("input");
2961
- } else if (body.count("content") != 0) {
2962
- // with "content", we only support single prompt
2963
- prompt = std::vector<std::string>{body.at("content")};
3777
+ } else if (body.contains("content")) {
3778
+ oaicompat = false;
3779
+ prompt = body.at("content");
2964
3780
  } else {
2965
3781
  res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
2966
3782
  return;
2967
3783
  }
2968
3784
 
3785
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
3786
+ for (const auto & tokens : tokenized_prompts) {
3787
+ // this check is necessary for models that do not add a BOS token to the input
3788
+ if (tokens.empty()) {
3789
+ res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
3790
+ return;
3791
+ }
3792
+ }
3793
+
2969
3794
  // create and queue the task
2970
3795
  json responses = json::array();
2971
3796
  bool error = false;
2972
3797
  {
2973
- std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
3798
+ std::vector<server_task> tasks;
3799
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
3800
+ server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
3801
+
3802
+ task.id = ctx_server.queue_tasks.get_new_id();
3803
+ task.index = i;
3804
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
3805
+
3806
+ // OAI-compat
3807
+ task.params.oaicompat = oaicompat;
3808
+
3809
+ tasks.push_back(task);
3810
+ }
3811
+
2974
3812
  ctx_server.queue_results.add_waiting_tasks(tasks);
2975
3813
  ctx_server.queue_tasks.post(tasks);
2976
3814
 
2977
3815
  // get the result
2978
3816
  std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
2979
3817
 
2980
- ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
2981
- for (const auto & res : results) {
2982
- responses.push_back(res.data);
3818
+ ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
3819
+ for (auto & res : results) {
3820
+ GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
3821
+ responses.push_back(res->to_json());
2983
3822
  }
2984
3823
  }, [&](const json & error_data) {
2985
3824
  res_error(res, error_data);
@@ -2994,14 +3833,20 @@ int main(int argc, char ** argv) {
2994
3833
  }
2995
3834
 
2996
3835
  // write JSON response
2997
- json root = is_openai
2998
- ? format_embeddings_response_oaicompat(body, responses)
2999
- : responses[0];
3836
+ json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
3000
3837
  res_ok(res, root);
3001
3838
  };
3002
3839
 
3840
+ const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3841
+ handle_embeddings_impl(req, res, false);
3842
+ };
3843
+
3844
+ const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3845
+ handle_embeddings_impl(req, res, true);
3846
+ };
3847
+
3003
3848
  const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3004
- if (!ctx_server.params.reranking || ctx_server.params.embedding) {
3849
+ if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
3005
3850
  res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
3006
3851
  return;
3007
3852
  }
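Above, handle_embeddings_impl wraps the per-prompt results with format_embeddings_response_oaicompat() when oaicompat is set. A hedged sketch of the response shape this implies, modeled on the OpenAI embeddings API (field handling is an assumption, not copied from the source):

    // Sketch: one {object, index, embedding} entry per input, wrapped in a "list" object.
    static json format_embeddings_response_oaicompat_sketch(const json & request, const json & embeddings) {
        json data = json::array();
        int index = 0;
        for (const auto & elem : embeddings) {
            data.push_back(json{
                {"object",    "embedding"},
                {"index",     index++},
                {"embedding", json_value(elem, "embedding", json::array())}
            });
        }
        return json{
            {"object", "list"},
            {"model",  json_value(request, "model", std::string("unknown"))},
            {"data",   data}
            // usage/token accounting intentionally omitted from this sketch
        };
    }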
@@ -3035,29 +3880,33 @@ int main(int argc, char ** argv) {
3035
3880
  return;
3036
3881
  }
3037
3882
 
3038
- // construct prompt object: array of ["query", "doc0", "doc1", ...]
3039
- json prompt;
3040
- prompt.push_back(query);
3041
- for (const auto & doc : documents) {
3042
- prompt.push_back(doc);
3043
- }
3044
-
3045
- LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
3883
+ llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0];
3046
3884
 
3047
3885
  // create and queue the task
3048
3886
  json responses = json::array();
3049
3887
  bool error = false;
3050
3888
  {
3051
- std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
3889
+ std::vector<server_task> tasks;
3890
+ std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
3891
+ tasks.reserve(tokenized_docs.size());
3892
+ for (size_t i = 0; i < tokenized_docs.size(); i++) {
3893
+ server_task task = server_task(SERVER_TASK_TYPE_RERANK);
3894
+ task.id = ctx_server.queue_tasks.get_new_id();
3895
+ task.index = i;
3896
+ task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
3897
+ tasks.push_back(task);
3898
+ }
3899
+
3052
3900
  ctx_server.queue_results.add_waiting_tasks(tasks);
3053
3901
  ctx_server.queue_tasks.post(tasks);
3054
3902
 
3055
3903
  // get the result
3056
3904
  std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
3057
3905
 
3058
- ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
3059
- for (const auto & res : results) {
3060
- responses.push_back(res.data);
3906
+ ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
3907
+ for (auto & res : results) {
3908
+ GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
3909
+ responses.push_back(res->to_json());
3061
3910
  }
3062
3911
  }, [&](const json & error_data) {
3063
3912
  res_error(res, error_data);
@@ -3108,36 +3957,47 @@ int main(int argc, char ** argv) {
3108
3957
  }
3109
3958
  }
3110
3959
 
3111
- server_task task;
3112
- task.type = SERVER_TASK_TYPE_SET_LORA;
3113
- const int id_task = ctx_server.queue_tasks.post(task);
3114
- ctx_server.queue_results.add_waiting_task_id(id_task);
3960
+ server_task task(SERVER_TASK_TYPE_SET_LORA);
3961
+ task.id = ctx_server.queue_tasks.get_new_id();
3962
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3963
+ ctx_server.queue_tasks.post(task);
3964
+
3965
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3966
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
3115
3967
 
3116
- server_task_result result = ctx_server.queue_results.recv(id_task);
3117
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3968
+ if (result->is_error()) {
3969
+ res_error(res, result->to_json());
3970
+ return;
3971
+ }
3118
3972
 
3119
- res_ok(res, result.data);
3120
- res.status = 200; // HTTP OK
3973
+ GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
3974
+ res_ok(res, result->to_json());
3121
3975
  };
3122
3976
 
3123
3977
  //
3124
3978
  // Router
3125
3979
  //
3126
3980
 
3127
- // register static assets routes
3128
- if (!params.public_path.empty()) {
3129
- // Set the base directory for serving static files
3130
- bool is_found = svr->set_mount_point("/", params.public_path);
3131
- if (!is_found) {
3132
- LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
3133
- return 1;
3134
- }
3981
+ if (!params.webui) {
3982
+ LOG_INF("Web UI is disabled\n");
3135
3983
  } else {
3136
- // using embedded static files
3137
- for (const auto & it : static_files) {
3138
- const server_static_file & static_file = it.second;
3139
- svr->Get(it.first.c_str(), [&static_file](const httplib::Request &, httplib::Response & res) {
3140
- res.set_content(reinterpret_cast<const char*>(static_file.data), static_file.size, static_file.mime_type);
3984
+ // register static assets routes
3985
+ if (!params.public_path.empty()) {
3986
+ // Set the base directory for serving static files
3987
+ bool is_found = svr->set_mount_point("/", params.public_path);
3988
+ if (!is_found) {
3989
+ LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
3990
+ return 1;
3991
+ }
3992
+ } else {
3993
+ // using embedded static index.html
3994
+ svr->Get("/", [](const httplib::Request & req, httplib::Response & res) {
3995
+ if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
3996
+ res.set_content("Error: gzip is not supported by this browser", "text/plain");
3997
+ } else {
3998
+ res.set_header("Content-Encoding", "gzip");
3999
+ res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
4000
+ }
3141
4001
  return false;
3142
4002
  });
3143
4003
  }
@@ -3158,7 +4018,7 @@ int main(int argc, char ** argv) {
3158
4018
  svr->Post("/infill", handle_infill);
3159
4019
  svr->Post("/embedding", handle_embeddings); // legacy
3160
4020
  svr->Post("/embeddings", handle_embeddings);
3161
- svr->Post("/v1/embeddings", handle_embeddings);
4021
+ svr->Post("/v1/embeddings", handle_embeddings_oai);
3162
4022
  svr->Post("/rerank", handle_rerank);
3163
4023
  svr->Post("/reranking", handle_rerank);
3164
4024
  svr->Post("/v1/rerank", handle_rerank);
@@ -3188,8 +4048,18 @@ int main(int argc, char ** argv) {
3188
4048
  llama_backend_free();
3189
4049
  };
3190
4050
 
3191
- // bind HTTP listen port, run the HTTP server in a thread
3192
- if (!svr->bind_to_port(params.hostname, params.port)) {
4051
+ // bind HTTP listen port
4052
+ bool was_bound = false;
4053
+ if (params.port == 0) {
4054
+ int bound_port = svr->bind_to_any_port(params.hostname);
4055
+ if ((was_bound = (bound_port >= 0))) {
4056
+ params.port = bound_port;
4057
+ }
4058
+ } else {
4059
+ was_bound = svr->bind_to_port(params.hostname, params.port);
4060
+ }
4061
+
4062
+ if (!was_bound) {
3193
4063
  //LOG_ERROR("couldn't bind HTTP server socket", {
3194
4064
  // {"hostname", params.hostname},
3195
4065
  // {"port", params.port},
@@ -3198,6 +4068,8 @@ int main(int argc, char ** argv) {
3198
4068
  clean_up();
3199
4069
  return 1;
3200
4070
  }
4071
+
4072
+ // run the HTTP server in a thread
3201
4073
  std::thread t([&]() { svr->listen_after_bind(); });
3202
4074
  svr->wait_until_ready();
3203
4075