@fugood/llama.node 0.3.1 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. package/CMakeLists.txt +1 -8
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +4 -2
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +10 -10
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +14 -17
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +5 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +137 -29
  25. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  26. package/src/llama.cpp/.github/workflows/docker.yml +46 -34
  27. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  28. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  29. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  30. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  31. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  32. package/src/llama.cpp/CMakeLists.txt +26 -11
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/common/CMakeLists.txt +10 -10
  35. package/src/llama.cpp/common/arg.cpp +2041 -0
  36. package/src/llama.cpp/common/arg.h +77 -0
  37. package/src/llama.cpp/common/common.cpp +523 -1861
  38. package/src/llama.cpp/common/common.h +234 -106
  39. package/src/llama.cpp/common/console.cpp +3 -0
  40. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  41. package/src/llama.cpp/common/log.cpp +401 -0
  42. package/src/llama.cpp/common/log.h +66 -698
  43. package/src/llama.cpp/common/ngram-cache.cpp +39 -36
  44. package/src/llama.cpp/common/ngram-cache.h +19 -19
  45. package/src/llama.cpp/common/sampling.cpp +356 -350
  46. package/src/llama.cpp/common/sampling.h +62 -139
  47. package/src/llama.cpp/common/stb_image.h +5990 -6398
  48. package/src/llama.cpp/docs/build.md +72 -17
  49. package/src/llama.cpp/examples/CMakeLists.txt +1 -2
  50. package/src/llama.cpp/examples/batched/batched.cpp +49 -65
  51. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
  52. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  53. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
  54. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  55. package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
  56. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
  58. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  59. package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
  60. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  61. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  62. package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
  63. package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
  64. package/src/llama.cpp/examples/infill/infill.cpp +131 -192
  65. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
  66. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  67. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
  68. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  69. package/src/llama.cpp/examples/llava/clip.cpp +686 -150
  70. package/src/llama.cpp/examples/llava/clip.h +11 -2
  71. package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
  72. package/src/llama.cpp/examples/llava/llava.cpp +146 -26
  73. package/src/llama.cpp/examples/llava/llava.h +2 -3
  74. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  75. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  76. package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
  77. package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
  78. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  79. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
  80. package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
  81. package/src/llama.cpp/examples/main/main.cpp +216 -313
  82. package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
  83. package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
  84. package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
  85. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  86. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  87. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
  88. package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
  89. package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
  90. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
  91. package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
  92. package/src/llama.cpp/examples/server/server.cpp +1347 -1531
  93. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  94. package/src/llama.cpp/examples/server/utils.hpp +396 -107
  95. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  96. package/src/llama.cpp/examples/simple/simple.cpp +132 -106
  97. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  98. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  99. package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
  100. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  101. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  102. package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
  103. package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
  104. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  105. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  106. package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
  107. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  108. package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
  109. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  110. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  111. package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
  112. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  113. package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
  114. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  115. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  116. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  117. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  118. package/src/llama.cpp/ggml/include/ggml.h +272 -505
  119. package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
  120. package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
  121. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  122. package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
  123. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  124. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  125. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  126. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  127. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  128. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
  129. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  130. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
  131. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  132. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
  133. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  134. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  135. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  136. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  137. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  138. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
  139. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  140. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  141. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  142. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  143. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  144. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  145. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  146. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
  147. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  148. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  149. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  150. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  151. package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
  152. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  153. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  154. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  155. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  156. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  157. package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
  158. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  159. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
  160. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  161. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  162. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  163. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  164. package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
  165. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
  166. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  167. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
  168. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  169. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
  170. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
  171. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
  172. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  173. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  174. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  175. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  176. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  177. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  178. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  179. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
  180. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  181. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  182. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  183. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
  184. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  185. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  186. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
  187. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
  188. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  189. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  190. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  192. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  193. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  194. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  195. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  196. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  197. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  198. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  199. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  200. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
  201. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
  202. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
  203. package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
  204. package/src/llama.cpp/include/llama.h +296 -285
  205. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  206. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  207. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  208. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  209. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  210. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  211. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  212. package/src/llama.cpp/src/llama-grammar.h +120 -15
  213. package/src/llama.cpp/src/llama-impl.h +156 -1
  214. package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
  215. package/src/llama.cpp/src/llama-sampling.h +39 -47
  216. package/src/llama.cpp/src/llama-vocab.cpp +390 -127
  217. package/src/llama.cpp/src/llama-vocab.h +60 -20
  218. package/src/llama.cpp/src/llama.cpp +6215 -3263
  219. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  220. package/src/llama.cpp/src/unicode-data.h +4 -4
  221. package/src/llama.cpp/src/unicode.cpp +15 -7
  222. package/src/llama.cpp/tests/CMakeLists.txt +4 -2
  223. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  224. package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
  225. package/src/llama.cpp/tests/test-barrier.cpp +94 -0
  226. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  227. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  228. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  229. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
  230. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  231. package/src/llama.cpp/tests/test-log.cpp +39 -0
  232. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  233. package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
  234. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  235. package/src/llama.cpp/tests/test-rope.cpp +2 -1
  236. package/src/llama.cpp/tests/test-sampling.cpp +226 -142
  237. package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
  238. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  239. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  240. package/patches/llama.patch +0 -22
  241. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  242. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  243. package/src/llama.cpp/common/grammar-parser.h +0 -29
  244. package/src/llama.cpp/common/train.cpp +0 -1513
  245. package/src/llama.cpp/common/train.h +0 -233
  246. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
  247. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  248. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
  249. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
  250. package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
  251. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  252. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -1,100 +1,100 @@
  #include "utils.hpp"

+ #include "arg.h"
  #include "common.h"
+ #include "log.h"
+ #include "sampling.h"
  #include "json-schema-to-grammar.h"
  #include "llama.h"
- #include "grammar-parser.h"

- #ifndef NDEBUG
- // crash the server in debug mode, otherwise send an http 500 error
- #define CPPHTTPLIB_NO_EXCEPTIONS 1
- #endif
- // increase max payload length to allow use of larger context size
- #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
- #include "httplib.h"
  // Change JSON_ASSERT from assert() to GGML_ASSERT:
  #define JSON_ASSERT GGML_ASSERT
  #include "json.hpp"
+ // mime type for sending response
+ #define MIMETYPE_JSON "application/json; charset=utf-8"

  // auto generated files (update with ./deps.sh)
- #include "colorthemes.css.hpp"
- #include "style.css.hpp"
- #include "theme-beeninorder.css.hpp"
- #include "theme-ketivah.css.hpp"
- #include "theme-mangotango.css.hpp"
- #include "theme-playground.css.hpp"
- #include "theme-polarnight.css.hpp"
- #include "theme-snowstorm.css.hpp"
  #include "index.html.hpp"
- #include "index-new.html.hpp"
- #include "index.js.hpp"
  #include "completion.js.hpp"
- #include "system-prompts.js.hpp"
- #include "prompt-formats.js.hpp"
- #include "json-schema-to-grammar.mjs.hpp"
+ #include "loading.html.hpp"
+ #include "deps_daisyui.min.css.hpp"
+ #include "deps_markdown-it.js.hpp"
+ #include "deps_tailwindcss.js.hpp"
+ #include "deps_vue.esm-browser.js.hpp"

  #include <atomic>
- #include <chrono>
  #include <condition_variable>
  #include <cstddef>
- #include <set>
+ #include <cinttypes>
+ #include <deque>
+ #include <memory>
  #include <mutex>
- #include <thread>
  #include <signal.h>
- #include <memory>
+ #include <thread>
+ #include <unordered_map>
+ #include <unordered_set>

  using json = nlohmann::ordered_json;

- bool server_verbose = false;
- bool server_log_json = true;
-
  enum stop_type {
  STOP_TYPE_FULL,
  STOP_TYPE_PARTIAL,
  };

+ // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
  enum slot_state {
  SLOT_STATE_IDLE,
- SLOT_STATE_PROCESSING,
- };
-
- enum slot_command {
- SLOT_COMMAND_NONE,
- SLOT_COMMAND_LOAD_PROMPT,
- SLOT_COMMAND_RELEASE,
+ SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
+ SLOT_STATE_PROCESSING_PROMPT,
+ SLOT_STATE_DONE_PROMPT,
+ SLOT_STATE_GENERATING,
  };

  enum server_state {
  SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet
  SERVER_STATE_READY, // Server is ready and model is loaded
- SERVER_STATE_ERROR // An error occurred, load_model failed
  };

  enum server_task_type {
- SERVER_TASK_TYPE_COMPLETION,
+ SERVER_TASK_TYPE_INFERENCE,
  SERVER_TASK_TYPE_CANCEL,
  SERVER_TASK_TYPE_NEXT_RESPONSE,
  SERVER_TASK_TYPE_METRICS,
  SERVER_TASK_TYPE_SLOT_SAVE,
  SERVER_TASK_TYPE_SLOT_RESTORE,
  SERVER_TASK_TYPE_SLOT_ERASE,
+ SERVER_TASK_TYPE_SET_LORA,
+ };
+
+ enum server_task_inf_type {
+ SERVER_TASK_INF_TYPE_COMPLETION,
+ SERVER_TASK_INF_TYPE_EMBEDDING,
+ SERVER_TASK_INF_TYPE_RERANK,
+ SERVER_TASK_INF_TYPE_INFILL,
  };

  struct server_task {
  int id = -1; // to be filled by server_queue
- int id_multi = -1;
- int id_target = -1;
+ int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL

+ llama_tokens prompt_tokens;
  server_task_type type;
  json data;

- bool infill = false;
- bool embedding = false;
+ server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
+
+ // utility function
+ static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
+ std::unordered_set<int> ids(tasks.size());
+ for (size_t i = 0; i < tasks.size(); i++) {
+ ids.insert(tasks[i].id);
+ }
+ return ids;
+ }
  };

  struct server_task_result {
  int id = -1;
- int id_multi = -1;

  json data;

@@ -102,36 +102,37 @@ struct server_task_result {
  bool error;
  };

- struct server_task_multi {
- int id = -1;
-
- std::set<int> subtasks_remaining;
- std::vector<server_task_result> results;
+ struct server_static_file {
+ const unsigned char * data;
+ unsigned int size;
+ const char * mime_type;
  };

  struct slot_params {
  bool stream = true;
  bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt

- int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
- int32_t n_predict = -1; // new tokens to predict
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+ int32_t n_predict = -1; // new tokens to predict
+ int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters

- std::vector<std::string> antiprompt;
+ int64_t t_max_prompt_ms = -1; // TODO: implement
+ int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

- json input_prefix;
- json input_suffix;
+ std::vector<std::string> antiprompt;
  };

  struct server_slot {
  int id;
  int id_task = -1;
- int id_multi = -1;
+
+ // the index relative to completion multi-task request
+ size_t index = 0;

  struct slot_params params;

  slot_state state = SLOT_STATE_IDLE;
- slot_command command = SLOT_COMMAND_NONE;

  // used to determine the slot that has been used the longest
  int64_t t_last_used = -1;
@@ -144,21 +145,23 @@ struct server_slot {
  int32_t i_batch = -1;
  int32_t n_predict = -1; // TODO: disambiguate from params.n_predict

+ // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
  int32_t n_prompt_tokens = 0;
  int32_t n_prompt_tokens_processed = 0;

- json prompt; // can be either a string, array of strings or array of token ids
+ // input prompt tokens
+ llama_tokens prompt_tokens;

- // when a task is submitted, we first tokenize the prompt and store it here
- std::vector<llama_token> prompt_tokens;
+ size_t last_nl_pos = 0;

  std::string generated_text;
- std::vector<llama_token> cache_tokens;
+ llama_tokens cache_tokens;
  std::vector<completion_token_output> generated_token_probs;

- bool infill = false;
- bool embedding = false;
+ server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;
+
  bool has_next_token = true;
+ bool has_new_line = false;
  bool truncated = false;
  bool stopped_eos = false;
  bool stopped_word = false;
@@ -170,30 +173,32 @@ struct server_slot {
  std::string stopping_word;

  // sampling
- llama_token sampled;
- struct llama_sampling_params sparams;
- llama_sampling_context * ctx_sampling = nullptr;
  json json_schema;

- int32_t ga_i = 0; // group-attention state
- int32_t ga_n = 1; // group-attention factor
- int32_t ga_w = 512; // group-attention width
+ struct common_sampler_params sparams;
+ struct common_sampler * smpl = nullptr;

- int32_t n_past_se = 0; // self-extend
+ llama_token sampled;

  // stats
- size_t n_sent_text = 0; // number of sent text character
+ size_t n_sent_text = 0; // number of sent text character
  size_t n_sent_token_probs = 0;

  int64_t t_start_process_prompt;
  int64_t t_start_generation;

  double t_prompt_processing; // ms
- double t_token_generation; // ms
+ double t_token_generation; // ms
+
+ std::function<void(int)> callback_on_release;

  void reset() {
+ SLT_DBG(*this, "%s", "\n");
+
  n_prompt_tokens = 0;
+ last_nl_pos = 0;
  generated_text = "";
+ has_new_line = false;
  truncated = false;
  stopped_eos = false;
  stopped_word = false;
@@ -202,14 +207,12 @@ struct server_slot {
  n_past = 0;
  n_sent_text = 0;
  n_sent_token_probs = 0;
- infill = false;
- ga_i = 0;
- n_past_se = 0;
+ inf_type = SERVER_TASK_INF_TYPE_COMPLETION;

  generated_token_probs.clear();
  }

- bool has_budget(gpt_params &global_params) {
+ bool has_budget(common_params &global_params) {
  if (params.n_predict == -1 && global_params.n_predict == -1) {
  return true; // limitless
  }
@@ -225,25 +228,26 @@ struct server_slot {
  return n_remaining > 0; // no budget
  }

- bool available() const {
- return state == SLOT_STATE_IDLE && command == SLOT_COMMAND_NONE;
- }
-
  bool is_processing() const {
- return (state == SLOT_STATE_IDLE && command == SLOT_COMMAND_LOAD_PROMPT) || state == SLOT_STATE_PROCESSING;
+ return state != SLOT_STATE_IDLE;
  }

- void add_token_string(const completion_token_output & token) {
- if (command == SLOT_COMMAND_RELEASE) {
+ void add_token(const completion_token_output & token) {
+ if (!is_processing()) {
+ SLT_WRN(*this, "%s", "slot is not processing\n");
  return;
  }
  generated_token_probs.push_back(token);
  }

  void release() {
- if (state == SLOT_STATE_PROCESSING) {
+ if (is_processing()) {
+ SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
+
+ t_last_used = ggml_time_us();
  t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
- command = SLOT_COMMAND_RELEASE;
+ state = SLOT_STATE_IDLE;
+ callback_on_release(id);
  }
  }

@@ -290,49 +294,20 @@ struct server_slot {
  }

  void print_timings() const {
- char buffer[512];
-
- double t_token = t_prompt_processing / n_prompt_tokens_processed;
- double n_tokens_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
-
- snprintf(buffer, 512, "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)",
- t_prompt_processing, n_prompt_tokens_processed,
- t_token, n_tokens_second);
-
- LOG_INFO(buffer, {
- {"id_slot", id},
- {"id_task", id_task},
- {"t_prompt_processing", t_prompt_processing},
- {"n_prompt_tokens_processed", n_prompt_tokens_processed},
- {"t_token", t_token},
- {"n_tokens_second", n_tokens_second},
- });
-
- t_token = t_token_generation / n_decoded;
- n_tokens_second = 1e3 / t_token_generation * n_decoded;
-
- snprintf(buffer, 512, "generation eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)",
- t_token_generation, n_decoded,
- t_token, n_tokens_second);
-
- LOG_INFO(buffer, {
- {"id_slot", id},
- {"id_task", id_task},
- {"t_token_generation", t_token_generation},
- {"n_decoded", n_decoded},
- {"t_token", t_token},
- {"n_tokens_second", n_tokens_second},
- });
-
- snprintf(buffer, 512, " total time = %10.2f ms", t_prompt_processing + t_token_generation);
-
- LOG_INFO(buffer, {
- {"id_slot", id},
- {"id_task", id_task},
- {"t_prompt_processing", t_prompt_processing},
- {"t_token_generation", t_token_generation},
- {"t_total", t_prompt_processing + t_token_generation},
- });
+ const double t_prompt = t_prompt_processing / n_prompt_tokens_processed;
+ const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
+
+ const double t_gen = t_token_generation / n_decoded;
+ const double n_gen_second = 1e3 / t_token_generation * n_decoded;
+
+ SLT_INF(*this,
+ "\n"
+ "\rprompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+ "\r eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
+ "\r total time = %10.2f ms / %5d tokens\n",
+ t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
+ t_token_generation, n_decoded, t_gen, n_gen_second,
+ t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
  }
  };

@@ -350,6 +325,9 @@ struct server_metrics {
  uint64_t n_tokens_predicted = 0;
  uint64_t t_tokens_generation = 0;

+ uint64_t n_decode_total = 0;
+ uint64_t n_busy_slots_total = 0;
+
  void init() {
  t_start = ggml_time_us();
  }
@@ -368,6 +346,15 @@ struct server_metrics {
  t_tokens_generation_total += slot.t_token_generation;
  }

+ void on_decoded(const std::vector<server_slot> & slots) {
+ n_decode_total++;
+ for (const auto & slot : slots) {
+ if (slot.is_processing()) {
+ n_busy_slots_total++;
+ }
+ }
+ }
+
  void reset_bucket() {
  n_prompt_tokens_processed = 0;
  t_prompt_processing = 0;
@@ -381,68 +368,83 @@ struct server_queue {
  bool running;

  // queues
- std::vector<server_task> queue_tasks;
- std::vector<server_task> queue_tasks_deferred;
-
- std::vector<server_task_multi> queue_multitasks;
+ std::deque<server_task> queue_tasks;
+ std::deque<server_task> queue_tasks_deferred;

  std::mutex mutex_tasks;
  std::condition_variable condition_tasks;

  // callback functions
- std::function<void(server_task &)> callback_new_task;
- std::function<void(server_task_multi &)> callback_finish_multitask;
- std::function<void(void)> callback_update_slots;
+ std::function<void(server_task)> callback_new_task;
+ std::function<void(void)> callback_update_slots;

  // Add a new task to the end of the queue
- int post(server_task task) {
+ int post(server_task task, bool front = false) {
  std::unique_lock<std::mutex> lock(mutex_tasks);
  if (task.id == -1) {
  task.id = id++;
- LOG_VERBOSE("new task id", {{"new_id", task.id}});
  }
- queue_tasks.push_back(std::move(task));
+ QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
+ if (front) {
+ queue_tasks.push_front(std::move(task));
+ } else {
+ queue_tasks.push_back(std::move(task));
+ }
  condition_tasks.notify_one();
  return task.id;
  }

+ // multi-task version of post()
+ int post(std::vector<server_task> & tasks, bool front = false) {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ for (auto & task : tasks) {
+ if (task.id == -1) {
+ task.id = id++;
+ }
+ QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front);
+ if (front) {
+ queue_tasks.push_front(std::move(task));
+ } else {
+ queue_tasks.push_back(std::move(task));
+ }
+ }
+ condition_tasks.notify_one();
+ return 0;
+ }
+
  // Add a new task, but defer until one slot is available
  void defer(server_task task) {
  std::unique_lock<std::mutex> lock(mutex_tasks);
+ QUE_DBG("defer task, id = %d\n", task.id);
  queue_tasks_deferred.push_back(std::move(task));
+ condition_tasks.notify_one();
  }

- // Get the next id for creating anew task
+ // Get the next id for creating a new task
  int get_new_id() {
  std::unique_lock<std::mutex> lock(mutex_tasks);
  int new_id = id++;
- LOG_VERBOSE("new task id", {{"new_id", new_id}});
  return new_id;
  }

  // Register function to process a new task
- void on_new_task(std::function<void(server_task &)> callback) {
+ void on_new_task(std::function<void(server_task)> callback) {
  callback_new_task = std::move(callback);
  }

- // Register function to process a multitask when it is finished
- void on_finish_multitask(std::function<void(server_task_multi&)> callback) {
- callback_finish_multitask = std::move(callback);
- }
-
  // Register the function to be called when all slots data is ready to be processed
  void on_update_slots(std::function<void(void)> callback) {
  callback_update_slots = std::move(callback);
  }

- // Call when the state of one slot is changed
- void notify_slot_changed() {
- // move deferred tasks back to main loop
+ // Call when the state of one slot is changed, it will move one task from deferred to main queue
+ void pop_deferred_task() {
  std::unique_lock<std::mutex> lock(mutex_tasks);
- for (auto & task : queue_tasks_deferred) {
- queue_tasks.push_back(std::move(task));
+ if (!queue_tasks_deferred.empty()) {
+ queue_tasks.emplace_back(std::move(queue_tasks_deferred.front()));
+ queue_tasks_deferred.pop_front();
  }
- queue_tasks_deferred.clear();
+ condition_tasks.notify_one();
  }

  // end the start_loop routine
@@ -463,7 +465,7 @@ struct server_queue {
  running = true;

  while (true) {
- LOG_VERBOSE("new task may arrive", {});
+ QUE_DBG("%s", "processing new tasks\n");

  while (true) {
  std::unique_lock<std::mutex> lock(mutex_tasks);
@@ -472,39 +474,24 @@ struct server_queue {
  break;
  }
  server_task task = queue_tasks.front();
- queue_tasks.erase(queue_tasks.begin());
+ queue_tasks.pop_front();
  lock.unlock();
- LOG_VERBOSE("callback_new_task", {{"id_task", task.id}});
- callback_new_task(task);
- }

- LOG_VERBOSE("update_multitasks", {});
-
- // check if we have any finished multitasks
- auto queue_iterator = queue_multitasks.begin();
- while (queue_iterator != queue_multitasks.end()) {
- if (queue_iterator->subtasks_remaining.empty()) {
- // all subtasks done == multitask is done
- server_task_multi current_multitask = *queue_iterator;
- callback_finish_multitask(current_multitask);
- // remove this multitask
- queue_iterator = queue_multitasks.erase(queue_iterator);
- } else {
- ++queue_iterator;
- }
+ QUE_DBG("processing task, id = %d\n", task.id);
+ callback_new_task(std::move(task));
  }

  // all tasks in the current loop is processed, slots data is now ready
- LOG_VERBOSE("callback_update_slots", {});
+ QUE_DBG("%s", "update slots\n");

  callback_update_slots();

- LOG_VERBOSE("wait for new task", {});
+ QUE_DBG("%s", "waiting for new tasks\n");
  {
  std::unique_lock<std::mutex> lock(mutex_tasks);
  if (queue_tasks.empty()) {
  if (!running) {
- LOG_VERBOSE("ending start_loop", {});
+ QUE_DBG("%s", "terminate\n");
  return;
  }
  condition_tasks.wait(lock, [&]{
@@ -514,38 +501,11 @@ struct server_queue {
  }
  }
  }
-
- //
- // functions to manage multitasks
- //
-
- // add a multitask by specifying the id of all subtask (subtask is a server_task)
- void add_multitask(int id_multi, std::vector<int> & sub_ids) {
- std::lock_guard<std::mutex> lock(mutex_tasks);
- server_task_multi multi;
- multi.id = id_multi;
- std::copy(sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
- queue_multitasks.push_back(multi);
- }
-
- // updatethe remaining subtasks, while appending results to multitask
- void update_multitask(int id_multi, int id_sub, server_task_result & result) {
- std::lock_guard<std::mutex> lock(mutex_tasks);
- for (auto & multitask : queue_multitasks) {
- if (multitask.id == id_multi) {
- multitask.subtasks_remaining.erase(id_sub);
- multitask.results.push_back(result);
- }
- }
- }
  };

  struct server_response {
- typedef std::function<void(int, int, server_task_result &)> callback_multitask_t;
- callback_multitask_t callback_update_multitask;
-
  // for keeping track of all tasks waiting for the result
- std::set<int> waiting_task_ids;
+ std::unordered_set<int> waiting_task_ids;

  // the main result queue
  std::vector<server_task_result> queue_results;
@@ -555,22 +515,40 @@ struct server_response {

  // add the id_task to the list of tasks waiting for response
  void add_waiting_task_id(int id_task) {
- LOG_VERBOSE("waiting for task id", {{"id_task", id_task}});
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());

  std::unique_lock<std::mutex> lock(mutex_results);
  waiting_task_ids.insert(id_task);
  }

+ void add_waiting_tasks(const std::vector<server_task> & tasks) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+
+ for (const auto & task : tasks) {
+ SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
+ waiting_task_ids.insert(task.id);
+ }
+ }
+
  // when the request is finished, we can remove task associated with it
  void remove_waiting_task_id(int id_task) {
- LOG_VERBOSE("remove waiting for task id", {{"id_task", id_task}});
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());

  std::unique_lock<std::mutex> lock(mutex_results);
  waiting_task_ids.erase(id_task);
  }

- // This function blocks the thread until there is a response for this id_task
- server_task_result recv(int id_task) {
+ void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
+ std::unique_lock<std::mutex> lock(mutex_results);
+
+ for (const auto & id_task : id_tasks) {
+ SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
+ waiting_task_ids.erase(id_task);
+ }
+ }
+
+ // This function blocks the thread until there is a response for one of the id_tasks
+ server_task_result recv(const std::unordered_set<int> & id_tasks) {
  while (true) {
  std::unique_lock<std::mutex> lock(mutex_results);
  condition_results.wait(lock, [&]{
@@ -578,8 +556,7 @@ struct server_response {
  });

  for (int i = 0; i < (int) queue_results.size(); i++) {
- if (queue_results[i].id == id_task) {
- assert(queue_results[i].id_multi == -1);
+ if (id_tasks.find(queue_results[i].id) != id_tasks.end()) {
  server_task_result res = queue_results[i];
  queue_results.erase(queue_results.begin() + i);
  return res;
@@ -590,28 +567,22 @@ struct server_response {
  // should never reach here
  }

- // Register the function to update multitask
- void on_multitask_update(callback_multitask_t callback) {
- callback_update_multitask = std::move(callback);
+ // single-task version of recv()
+ server_task_result recv(int id_task) {
+ std::unordered_set<int> id_tasks = {id_task};
+ return recv(id_tasks);
  }

  // Send a new result to a waiting id_task
- void send(server_task_result result) {
- LOG_VERBOSE("send new result", {{"id_task", result.id}});
+ void send(server_task_result & result) {
+ SRV_DBG("sending result for task id = %d\n", result.id);

  std::unique_lock<std::mutex> lock(mutex_results);
  for (const auto & id_task : waiting_task_ids) {
- // LOG_TEE("waiting task id %i \n", id_task);
- // for now, tasks that have associated parent multitasks just get erased once multitask picks up the result
- if (result.id_multi == id_task) {
- LOG_VERBOSE("callback_update_multitask", {{"id_task", id_task}});
- callback_update_multitask(id_task, result.id, result);
- continue;
- }
-
  if (result.id == id_task) {
- LOG_VERBOSE("queue_results.push_back", {{"id_task", id_task}});
- queue_results.push_back(result);
+ SRV_DBG("task id = %d moved to result queue\n", result.id);
+
+ queue_results.push_back(std::move(result));
  condition_results.notify_all();
  return;
  }
@@ -622,22 +593,18 @@ struct server_response {
  struct server_context {
  llama_model * model = nullptr;
  llama_context * ctx = nullptr;
+ std::vector<common_lora_adapter_container> loras;

- gpt_params params;
+ common_params params;

- llama_batch batch;
+ llama_batch batch = {};

  bool clean_kv_cache = true;
  bool add_bos_token = true;
+ bool has_eos_token = false;

  int32_t n_ctx; // total context for all clients / slots

- // system prompt
- bool system_need_update = false;
-
- std::string system_prompt;
- std::vector<llama_token> system_tokens;
-
  // slots / clients
  std::vector<server_slot> slots;
  json default_generation_settings_for_props;
@@ -663,47 +630,53 @@ struct server_context {

  // Clear any sampling context
  for (server_slot & slot : slots) {
- if (slot.ctx_sampling != nullptr) {
- llama_sampling_free(slot.ctx_sampling);
+ if (slot.smpl != nullptr) {
+ common_sampler_free(slot.smpl);
  }
  }

  llama_batch_free(batch);
  }

- bool load_model(const gpt_params & params_) {
+ bool load_model(const common_params & params_) {
  params = params_;

- // dedicate one sequence to the system prompt
- params.n_parallel += 1;
+ common_init_result llama_init = common_init_from_params(params);
+
+ model = llama_init.model;
+ ctx = llama_init.context;
+ loras = llama_init.lora_adapters;

- std::tie(model, ctx) = llama_init_from_gpt_params(params);
- params.n_parallel -= 1; // but be sneaky about it
  if (model == nullptr) {
- LOG_ERROR("unable to load model", {{"model", params.model}});
+ SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
  return false;
  }

  n_ctx = llama_n_ctx(ctx);

- add_bos_token = llama_should_add_bos_token(model);
- GGML_ASSERT(llama_add_eos_token(model) != 1);
+ add_bos_token = llama_add_bos_token(model);
+ has_eos_token = !llama_add_eos_token(model);

  return true;
  }

  bool validate_model_chat_template() const {
- llama_chat_message chat[] = {{"user", "test"}};
-
- const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
-
- return res > 0;
+ std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+ std::string template_key = "tokenizer.chat_template";
+ int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+ if (res >= 0) {
+ llama_chat_message chat[] = {{"user", "test"}};
+ std::string tmpl = std::string(model_template.data(), model_template.size());
+ int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
+ return chat_res > 0;
+ }
+ return false;
  }

  void init() {
  const int32_t n_ctx_slot = n_ctx / params.n_parallel;

- LOG_INFO("initializing slots", {{"n_slots", params.n_parallel}});
+ SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);

  for (int i = 0; i < params.n_parallel; i++) {
  server_slot slot;
@@ -712,33 +685,14 @@ struct server_context {
  slot.n_ctx = n_ctx_slot;
  slot.n_predict = params.n_predict;

- LOG_INFO("new slot", {
- {"id_slot", slot.id},
- {"n_ctx_slot", slot.n_ctx}
- });
-
- const int ga_n = params.grp_attn_n;
- const int ga_w = params.grp_attn_w;
-
- if (ga_n != 1) {
- GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
- GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
- //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
- //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
-
- LOG_INFO("slot self-extend", {
- {"id_slot", slot.id},
- {"ga_n", ga_n},
- {"ga_w", ga_w}
- });
- }
-
- slot.ga_i = 0;
- slot.ga_n = ga_n;
- slot.ga_w = ga_w;
+ SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

  slot.sparams = params.sparams;

+ slot.callback_on_release = [this](int) {
+ queue_tasks.pop_deferred_task();
+ };
+
  slot.reset();

  slots.push_back(slot);
@@ -747,59 +701,18 @@ struct server_context {
  default_generation_settings_for_props = get_formated_generation(slots.front());
  default_generation_settings_for_props["seed"] = -1;

- // the update_slots() logic will always submit a maximum of n_batch tokens
+ // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
  // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
  {
  const int32_t n_batch = llama_n_batch(ctx);

  // only a single seq_id per token is needed
- batch = llama_batch_init(n_batch, 0, 1);
+ batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
  }

  metrics.init();
  }

- std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
- // TODO: currently, we tokenize using special tokens by default
- // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
- // but it's better compared to completely ignoring ChatML and other chat templates
- const bool TMP_FORCE_SPECIAL = true;
-
- // If `add_bos` is true, we only add BOS, when json_prompt is a string,
- // or the first element of the json_prompt array is a string.
- std::vector<llama_token> prompt_tokens;
-
- if (json_prompt.is_array()) {
- bool first = true;
- for (const auto & p : json_prompt) {
- if (p.is_string()) {
- auto s = p.template get<std::string>();
-
- std::vector<llama_token> p;
- if (first) {
- p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
- first = false;
- } else {
- p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
- }
-
- prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
- } else {
- if (first) {
- first = false;
- }
-
- prompt_tokens.push_back(p.template get<llama_token>());
- }
- }
- } else {
- auto s = json_prompt.template get<std::string>();
- prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
- }
-
- return prompt_tokens;
- }
-
  server_slot * get_slot_by_id(int id) {
  for (server_slot & slot : slots) {
  if (slot.id == id) {
@@ -810,50 +723,41 @@ struct server_context {
  return nullptr;
  }

- server_slot * get_available_slot(const std::string & prompt) {
+ server_slot * get_available_slot(const server_task & task) {
  server_slot * ret = nullptr;

  // find the slot that has at least n% prompt similarity
- if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
- int max_lcp_len = 0;
+ if (ret == nullptr && slot_prompt_similarity != 0.0f) {
+ int lcs_len = 0;
  float similarity = 0;

  for (server_slot & slot : slots) {
  // skip the slot if it is not available
- if (!slot.available()) {
+ if (slot.is_processing()) {
  continue;
  }

- // skip the slot if it does not contains prompt
- if (!slot.prompt.is_string()) {
+ // skip the slot if it does not contains cached tokens
+ if (slot.cache_tokens.empty()) {
  continue;
  }

- // current slot's prompt
- std::string slot_prompt = slot.prompt.get<std::string>();
-
- // length of the current slot's prompt
- int slot_prompt_len = slot_prompt.size();
+ // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
+ int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);

- // length of the Longest Common Prefix between the current slot's prompt and the input prompt
- int lcp_len = common_part(slot_prompt, prompt);
-
- // fraction of the common substring length compared to the current slot's prompt length
- similarity = static_cast<float>(lcp_len) / slot_prompt_len;
+ // fraction of the common subsequence length compared to the current slot's prompt length
+ float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());

  // select the current slot if the criteria match
- if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
- max_lcp_len = lcp_len;
+ if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
+ lcs_len = cur_lcs_len;
+ similarity = cur_similarity;
  ret = &slot;
  }
  }

  if (ret != nullptr) {
- LOG_VERBOSE("selected slot by lcp similarity", {
- {"id_slot", ret->id},
- {"max_lcp_len", max_lcp_len},
- {"similarity", similarity},
- });
+ SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity);
  }
  }

@@ -862,7 +766,7 @@ struct server_context {
  int64_t t_last = ggml_time_us();
  for (server_slot & slot : slots) {
  // skip the slot if it is not available
- if (!slot.available()) {
+ if (slot.is_processing()) {
  continue;
  }

@@ -874,10 +778,7 @@ struct server_context {
  }

  if (ret != nullptr) {
- LOG_VERBOSE("selected slot by lru", {
- {"id_slot", ret->id},
- {"t_last", t_last},
- });
+ SLT_DBG(*ret, "selected slot by lru, t_last = %" PRId64 "\n", t_last);
  }
  }

@@ -887,8 +788,8 @@ struct server_context {
  bool launch_slot_with_task(server_slot & slot, const server_task & task) {
  slot_params default_params;
  // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
- llama_sampling_params default_sparams = params.sparams;
- auto & data = task.data;
+ auto default_sparams = params.sparams;
+ const auto & data = task.data;

  if (data.count("__oaicompat") != 0) {
  slot.oaicompat = true;
@@ -898,133 +799,86 @@ struct server_context {
  slot.oaicompat_model = "";
  }

- slot.params.stream = json_value(data, "stream", false);
- slot.params.cache_prompt = json_value(data, "cache_prompt", false);
- slot.params.n_predict = json_value(data, "n_predict", default_params.n_predict);
- slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
- slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
- slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
- slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
- slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
- slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
- slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
- slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
- slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
- slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
- slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
- slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
- slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
- slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
- slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
- slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
- slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
- slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
- slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
- slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
- slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
+ slot.params.stream = json_value(data, "stream", false);
+ slot.params.cache_prompt = json_value(data, "cache_prompt", false);
+ slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
+ slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
+ slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
+ slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
+ slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
+ slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
+ slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
+ slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
+ slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
+ slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
+ slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
+ slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
+ slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
+ slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
+ slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
+ slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
+ slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
+ slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
+ slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
+ slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
+ slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
+ slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
+ slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
+ slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
+ slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
+ slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
+ slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
+ slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
+ //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
+ slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
+
+ if (slot.sparams.dry_base < 1.0f)
+ {
+ slot.sparams.dry_base = default_sparams.dry_base;
+ }
+
+ // sequence breakers for DRY
+ {
+ // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+ // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+ if (data.contains("dry_sequence_breakers")) {
+ slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+ if (slot.sparams.dry_sequence_breakers.empty()) {
+ send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
+ }
+ }

  // process "json_schema" and "grammar"
  if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
  send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
  return false;
- } else if (data.contains("json_schema") && !data.contains("grammar")) {
+ }
+ if (data.contains("json_schema") && !data.contains("grammar")) {
  try {
- auto schema = json_value(data, "json_schema", json::object());
- slot.sparams.grammar = json_schema_to_grammar(schema);
+ auto schema = json_value(data, "json_schema", json::object());
+ slot.sparams.grammar = json_schema_to_grammar(schema);
  } catch (const std::exception & e) {
  send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
  return false;
  }
  } else {
- slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
- }
-
- if (slot.params.cache_prompt && slot.ga_n != 1) {
943
- LOG_WARNING("cache_prompt is not supported with group-attention", {});
944
- slot.params.cache_prompt = false;
868
+ slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
945
869
  }
946
870
 
947
871
  if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
948
872
  // Might be better to reject the request with a 400 ?
949
- LOG_WARNING("Max tokens to predict exceeds server configuration", {
950
- {"params.n_predict", slot.params.n_predict},
951
- {"slot.n_predict", slot.n_predict},
952
- });
953
873
  slot.params.n_predict = slot.n_predict;
954
- }
955
-
956
- // infill
957
- slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
958
- slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
959
-
960
- // get prompt
961
- if (!task.infill) {
962
- const auto & prompt = data.find("prompt");
963
- if (prompt == data.end()) {
964
- send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
965
- return false;
966
- }
967
-
968
- if ((prompt->is_string()) ||
969
- (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
970
- (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
971
- slot.prompt = *prompt;
972
- } else {
973
- send_error(task, "\"prompt\" must be a string or an array of integers", ERROR_TYPE_INVALID_REQUEST);
974
- return false;
975
- }
976
- }
977
-
978
- // penalize user-provided tokens
979
- {
980
- slot.sparams.penalty_prompt_tokens.clear();
981
- slot.sparams.use_penalty_prompt_tokens = false;
982
-
983
- const auto & penalty_prompt = data.find("penalty_prompt");
984
-
985
- if (penalty_prompt != data.end()) {
986
- if (penalty_prompt->is_string()) {
987
- const auto penalty_prompt_string = penalty_prompt->get<std::string>();
988
- slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
989
-
990
- if (slot.params.n_predict > 0) {
991
- slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
992
- }
993
- slot.sparams.use_penalty_prompt_tokens = true;
994
-
995
- LOG_VERBOSE("penalty_prompt_tokens", {
996
- {"id_slot", slot.id},
997
- {"tokens", slot.sparams.penalty_prompt_tokens},
998
- });
999
- }
1000
- else if (penalty_prompt->is_array()) {
1001
- const auto n_tokens = penalty_prompt->size();
1002
- slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
1003
-
1004
- const int n_vocab = llama_n_vocab(model);
1005
- for (const auto & penalty_token : *penalty_prompt) {
1006
- if (penalty_token.is_number_integer()) {
1007
- const auto tok = penalty_token.get<llama_token>();
1008
- if (tok >= 0 && tok < n_vocab) {
1009
- slot.sparams.penalty_prompt_tokens.push_back(tok);
1010
- }
1011
- }
1012
- }
1013
- slot.sparams.use_penalty_prompt_tokens = true;
1014
-
1015
- LOG_VERBOSE("penalty_prompt_tokens", {
1016
- {"id_slot", slot.id},
1017
- {"tokens", slot.sparams.penalty_prompt_tokens},
1018
- });
1019
- }
1020
- }
874
+ SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
1021
875
  }
1022
876
 
1023
877
  {
1024
878
  slot.sparams.logit_bias.clear();
1025
879
 
1026
- if (json_value(data, "ignore_eos", false)) {
1027
- slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
880
+ if (json_value(data, "ignore_eos", false) && has_eos_token) {
881
+ slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
1028
882
  }
1029
883
 
1030
884
  const auto & logit_bias = data.find("logit_bias");
@@ -1045,12 +899,12 @@ struct server_context {
  if (el[0].is_number_integer()) {
  llama_token tok = el[0].get<llama_token>();
  if (tok >= 0 && tok < n_vocab) {
- slot.sparams.logit_bias[tok] = bias;
+ slot.sparams.logit_bias.push_back({tok, bias});
  }
  } else if (el[0].is_string()) {
- auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
+ auto toks = common_tokenize(model, el[0].get<std::string>(), false);
  for (auto tok : toks) {
- slot.sparams.logit_bias[tok] = bias;
+ slot.sparams.logit_bias.push_back({tok, bias});
  }
  }
  }
@@ -1072,128 +926,65 @@ struct server_context {
  }

  {
- const auto & samplers_sequence = data.find("samplers");
- if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
- std::vector<std::string> sampler_names;
- for (const auto & sampler_name : *samplers_sequence) {
- if (sampler_name.is_string()) {
- sampler_names.emplace_back(sampler_name);
+ const auto & samplers = data.find("samplers");
+ if (samplers != data.end()) {
+ if (samplers->is_array()) {
+ std::vector<std::string> sampler_names;
+ for (const auto & name : *samplers) {
+ if (name.is_string()) {
+ sampler_names.emplace_back(name);
+ }
  }
+ slot.sparams.samplers = common_sampler_types_from_names(sampler_names, false);
+ } else if (samplers->is_string()){
+ std::string sampler_string;
+ for (const auto & name : *samplers) {
+ sampler_string += name;
+ }
+ slot.sparams.samplers = common_sampler_types_from_chars(sampler_string);
  }
- slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
  } else {
- slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
+ slot.sparams.samplers = default_sparams.samplers;
  }
  }

  {
- if (slot.ctx_sampling != nullptr) {
- llama_sampling_free(slot.ctx_sampling);
+ if (slot.smpl != nullptr) {
+ common_sampler_free(slot.smpl);
  }
- slot.ctx_sampling = llama_sampling_init(slot.sparams);
- if (slot.ctx_sampling == nullptr) {
+
+ slot.smpl = common_sampler_init(model, slot.sparams);
+ if (slot.smpl == nullptr) {
  // for now, the only error that may happen here is invalid grammar
  send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
  return false;
  }
  }

- slot.command = SLOT_COMMAND_LOAD_PROMPT;
- slot.prompt_tokens.clear();
+ slot.state = SLOT_STATE_STARTED;

- LOG_INFO("slot is processing task", {
- {"id_slot", slot.id},
- {"id_task", slot.id_task},
- });
+ SLT_INF(slot, "%s", "processing task\n");

  return true;
  }

  void kv_cache_clear() {
- LOG_VERBOSE("clearing KV cache", {});
+ SRV_DBG("%s", "clearing KV cache\n");

  // clear the entire KV cache
  llama_kv_cache_clear(ctx);
  clean_kv_cache = false;
  }

- void system_prompt_update() {
- LOG_VERBOSE("system prompt update", {
- {"system_prompt", system_prompt},
- });
-
- kv_cache_clear();
- system_tokens.clear();
-
- if (!system_prompt.empty()) {
- system_tokens = ::llama_tokenize(ctx, system_prompt, true);
-
- llama_batch_clear(batch);
-
- for (int i = 0; i < (int)system_tokens.size(); ++i) {
- llama_batch_add(batch, system_tokens[i], i, { 0 }, false);
- }
-
- const int32_t n_batch = llama_n_batch(ctx);
-
- for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
- const int32_t n_tokens = std::min(params.n_batch, batch.n_tokens - i);
- llama_batch batch_view = {
- n_tokens,
- batch.token + i,
- nullptr,
- batch.pos + i,
- batch.n_seq_id + i,
- batch.seq_id + i,
- batch.logits + i,
- 0, 0, 0, // unused
- };
-
- if (llama_decode(ctx, batch_view) != 0) {
- LOG_ERROR("llama_decode() failed", {});
- return;
- }
- }
-
- // assign the system KV cache to all parallel sequences
- for (int32_t i = 1; i <= params.n_parallel; ++i) {
- llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
- }
- }
-
- system_need_update = false;
- }
-
- bool system_prompt_set(const std::string & sys_prompt) {
- system_prompt = sys_prompt;
-
- LOG_VERBOSE("system prompt process", {
- {"system_prompt", system_prompt},
- });
-
- // release all slots
- for (server_slot & slot : slots) {
- slot.release();
- }
-
- system_need_update = true;
- return true;
- }
-
  bool process_token(completion_token_output & result, server_slot & slot) {
  // remember which tokens were sampled - used for repetition penalties during sampling
- const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
+ const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
  slot.sampled = result.tok;

  // search stop word and delete it
  slot.generated_text += token_str;
  slot.has_next_token = true;

- if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
- // we can change penalty_prompt_tokens because it is always created from scratch each request
- slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
- }
-
  // check if there is incomplete UTF-8 character at the end
  bool incomplete = false;
  for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
@@ -1220,29 +1011,28 @@ struct server_context {
  size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

  const std::string str_test = slot.generated_text.substr(pos);
- bool is_stop_full = false;
+ bool send_text = true;

  size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
  if (stop_pos != std::string::npos) {
- is_stop_full = true;
  slot.generated_text.erase(
  slot.generated_text.begin() + pos + stop_pos,
  slot.generated_text.end());
  pos = std::min(slot.n_sent_text, slot.generated_text.size());
- } else {
- is_stop_full = false;
+ } else if (slot.has_next_token) {
  stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
+ send_text = stop_pos == std::string::npos;
  }

  // check if there is any token to predict
- if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
+ if (send_text) {
  // no send the stop word in the response
  result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
  slot.n_sent_text += result.text_to_send.size();
  // add the token to slot queue and cache
  }

- slot.add_token_string(result);
+ slot.add_token(result);
  if (slot.params.stream) {
  send_partial_response(slot, result);
  }
@@ -1257,124 +1047,155 @@ struct server_context {
1257
1047
  slot.stopped_limit = true;
1258
1048
  slot.has_next_token = false;
1259
1049
 
1260
- LOG_VERBOSE("stopped by limit", {
1261
- {"id_slot", slot.id},
1262
- {"id_task", slot.id_task},
1263
- {"n_decoded", slot.n_decoded},
1264
- {"n_predict", slot.params.n_predict},
1265
- });
1050
+ SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
1051
+ }
1052
+
1053
+ if (slot.has_new_line) {
1054
+ // if we have already seen a new line, we stop after a certain time limit
1055
+ if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
1056
+ slot.stopped_limit = true;
1057
+ slot.has_next_token = false;
1058
+
1059
+ SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
1060
+ }
1061
+
1062
+ // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
1063
+ if (slot.params.n_indent > 0) {
1064
+ // check the current indentation
1065
+ // TODO: improve by not doing it more than once for each new line
1066
+ if (slot.last_nl_pos > 0) {
1067
+ size_t pos = slot.last_nl_pos;
1068
+
1069
+ int n_indent = 0;
1070
+ while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
1071
+ n_indent++;
1072
+ pos++;
1073
+ }
1074
+
1075
+ if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
1076
+ slot.stopped_limit = true;
1077
+ slot.has_next_token = false;
1078
+
1079
+ // cut the last line
1080
+ slot.generated_text.erase(pos, std::string::npos);
1081
+
1082
+ SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
1083
+ }
1084
+ }
1085
+
1086
+ // find the next new line
1087
+ {
1088
+ const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
1089
+
1090
+ if (pos != std::string::npos) {
1091
+ slot.last_nl_pos = pos + 1;
1092
+ }
1093
+ }
1094
+ }
1095
+ }
1096
+
1097
+ // check if there is a new line in the generated text
1098
+ if (result.text_to_send.find('\n') != std::string::npos) {
1099
+ slot.has_new_line = true;
1100
+ }
1101
+
1102
+ // if context shift is disabled, we stop when it reaches the context limit
1103
+ if (slot.n_past >= slot.n_ctx) {
1104
+ slot.truncated = true;
1105
+ slot.stopped_limit = true;
1106
+ slot.has_next_token = false;
1107
+
1108
+ SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
1109
+ slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
1266
1110
  }
1267
1111
 
1268
1112
  if (llama_token_is_eog(model, result.tok)) {
1269
1113
  slot.stopped_eos = true;
1270
1114
  slot.has_next_token = false;
1271
1115
 
1272
- LOG_VERBOSE("eos token found", {});
1273
- }
1274
-
1275
- auto n_ctx_train = llama_n_ctx_train(model);
1276
- if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1
1277
- && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
1278
- LOG_WARNING("n_predict is not set and self-context extend is disabled."
1279
- " Limiting generated tokens to n_ctx_train to avoid EOS-less generation infinite loop", {
1280
- { "id_slot", slot.id },
1281
- { "params.n_predict", slot.params.n_predict },
1282
- { "slot.n_prompt_tokens", slot.n_prompt_tokens },
1283
- { "slot.n_decoded", slot.n_decoded },
1284
- { "slot.n_predict", slot.n_predict },
1285
- { "n_slots", params.n_parallel },
1286
- { "slot.n_ctx", slot.n_ctx },
1287
- { "n_ctx", n_ctx },
1288
- { "n_ctx_train", n_ctx_train },
1289
- { "ga_n", slot.ga_n },
1290
- });
1116
+ SLT_DBG(slot, "%s", "stopped by EOS\n");
1117
+ }
1118
+
1119
+ const auto n_ctx_train = llama_n_ctx_train(model);
1120
+
1121
+ if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
1291
1122
  slot.truncated = true;
1292
1123
  slot.stopped_limit = true;
1293
1124
  slot.has_next_token = false; // stop prediction
1125
+
1126
+ SLT_WRN(slot,
1127
+ "n_predict (%d) is set for infinite generation. "
1128
+ "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
1129
+ slot.params.n_predict, n_ctx_train);
1294
1130
  }
1295
1131
 
1296
- LOG_VERBOSE("next token", {
1297
- {"id_slot", slot.id},
1298
- {"id_task", slot.id_task},
1299
- {"token", result.tok},
1300
- {"token_text", tokens_to_output_formatted_string(ctx, result.tok)},
1301
- {"has_next_token", slot.has_next_token},
1302
- {"n_remain", slot.n_remaining},
1303
- {"n_decoded", slot.n_decoded},
1304
- {"stopped_eos", slot.stopped_eos},
1305
- {"stopped_word", slot.stopped_word},
1306
- {"stopped_limit", slot.stopped_limit},
1307
- {"stopping_word", slot.stopping_word},
1308
- });
1132
+ SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
1309
1133
 
1310
1134
  return slot.has_next_token; // continue
1311
1135
  }
1312
1136
 
1313
1137
  json get_formated_generation(const server_slot & slot) const {
1314
- const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
1315
- const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
1316
-
1317
- std::vector<std::string> samplers_sequence;
1318
- samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
1319
- for (const auto & sampler_type : slot.sparams.samplers_sequence) {
1320
- samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
1138
+ std::vector<std::string> samplers;
1139
+ samplers.reserve(slot.sparams.samplers.size());
1140
+ for (const auto & sampler : slot.sparams.samplers) {
1141
+ samplers.emplace_back(common_sampler_type_to_str(sampler));
1321
1142
  }
1322
1143
 
1323
1144
  return json {
1324
1145
  {"n_ctx", slot.n_ctx},
1325
- {"n_predict", slot.n_predict},
1146
+ {"n_predict", slot.n_predict}, // Server configured n_predict
1326
1147
  {"model", params.model_alias},
1327
1148
  {"seed", slot.sparams.seed},
1149
+ {"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
1328
1150
  {"temperature", slot.sparams.temp},
1329
1151
  {"dynatemp_range", slot.sparams.dynatemp_range},
1330
1152
  {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
1331
1153
  {"top_k", slot.sparams.top_k},
1332
1154
  {"top_p", slot.sparams.top_p},
1333
1155
  {"min_p", slot.sparams.min_p},
1334
- {"tfs_z", slot.sparams.tfs_z},
1335
- {"typical_p", slot.sparams.typical_p},
1156
+ {"xtc_probability", slot.sparams.xtc_probability},
1157
+ {"xtc_threshold", slot.sparams.xtc_threshold},
1158
+ {"typical_p", slot.sparams.typ_p},
1336
1159
  {"repeat_last_n", slot.sparams.penalty_last_n},
1337
1160
  {"repeat_penalty", slot.sparams.penalty_repeat},
1338
1161
  {"presence_penalty", slot.sparams.penalty_present},
1339
1162
  {"frequency_penalty", slot.sparams.penalty_freq},
1340
- {"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
1341
- {"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
1163
+ {"dry_multiplier", slot.sparams.dry_multiplier},
1164
+ {"dry_base", slot.sparams.dry_base},
1165
+ {"dry_allowed_length", slot.sparams.dry_allowed_length},
1166
+ {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
1167
+ {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
1342
1168
  {"mirostat", slot.sparams.mirostat},
1343
1169
  {"mirostat_tau", slot.sparams.mirostat_tau},
1344
1170
  {"mirostat_eta", slot.sparams.mirostat_eta},
1345
1171
  {"penalize_nl", slot.sparams.penalize_nl},
1346
1172
  {"stop", slot.params.antiprompt},
1347
- {"n_predict", slot.params.n_predict}, // TODO: fix duplicate key n_predict
1173
+ {"max_tokens", slot.params.n_predict}, // User configured n_predict
1348
1174
  {"n_keep", slot.params.n_keep},
1349
1175
  {"n_discard", slot.params.n_discard},
1350
- {"ignore_eos", ignore_eos},
1176
+ {"ignore_eos", slot.sparams.ignore_eos},
1351
1177
  {"stream", slot.params.stream},
1352
- {"logit_bias", slot.sparams.logit_bias},
1178
+ //{"logit_bias", slot.sparams.logit_bias},
1353
1179
  {"n_probs", slot.sparams.n_probs},
1354
1180
  {"min_keep", slot.sparams.min_keep},
1355
1181
  {"grammar", slot.sparams.grammar},
1356
- {"samplers", samplers_sequence}
1182
+ {"samplers", samplers},
1357
1183
  };
1358
1184
  }
1359
1185
 
1360
1186
  void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1361
- send_error(task.id, task.id_multi, error, type);
1187
+ send_error(task.id, error, type);
1362
1188
  }
1363
1189
 
1364
1190
  void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1365
- send_error(slot.id_task, slot.id_multi, error, type);
1191
+ send_error(slot.id_task, error, type);
1366
1192
  }
1367
1193
 
1368
- void send_error(const int id_task, const int id_multi, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1369
- LOG_ERROR("task error", {
1370
- {"id_multi", id_multi},
1371
- {"id_task", id_task},
1372
- {"error", error},
1373
- });
1194
+ void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1195
+ SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
1374
1196
 
1375
1197
  server_task_result res;
1376
1198
  res.id = id_task;
1377
- res.id_multi = id_multi;
1378
1199
  res.stop = false;
1379
1200
  res.error = true;
1380
1201
  res.data = format_error_response(error, type);
@@ -1385,18 +1206,18 @@ struct server_context {
  void send_partial_response(server_slot & slot, completion_token_output tkn) {
  server_task_result res;
  res.id = slot.id_task;
- res.id_multi = slot.id_multi;
  res.error = false;
  res.stop = false;
  res.data = json {
  {"content", tkn.text_to_send},
  {"stop", false},
  {"id_slot", slot.id},
- {"multimodal", false}
+ {"multimodal", false},
+ {"index", slot.index},
  };

  if (slot.sparams.n_probs > 0) {
- const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
+ const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
  const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
  const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());

@@ -1422,7 +1243,6 @@ struct server_context {
  void send_final_response(const server_slot & slot) {
  server_task_result res;
  res.id = slot.id_task;
- res.id_multi = slot.id_multi;
  res.error = false;
  res.stop = true;
  res.data = json {
@@ -1433,20 +1253,22 @@ struct server_context {
1433
1253
  {"tokens_predicted", slot.n_decoded},
1434
1254
  {"tokens_evaluated", slot.n_prompt_tokens},
1435
1255
  {"generation_settings", get_formated_generation(slot)},
1436
- {"prompt", slot.prompt},
1256
+ {"prompt", common_detokenize(ctx, slot.prompt_tokens)},
1257
+ {"has_new_line", slot.has_new_line},
1437
1258
  {"truncated", slot.truncated},
1438
1259
  {"stopped_eos", slot.stopped_eos},
1439
1260
  {"stopped_word", slot.stopped_word},
1440
1261
  {"stopped_limit", slot.stopped_limit},
1441
1262
  {"stopping_word", slot.stopping_word},
1442
1263
  {"tokens_cached", slot.n_past},
1443
- {"timings", slot.get_formated_timings()}
1264
+ {"timings", slot.get_formated_timings()},
1265
+ {"index", slot.index},
1444
1266
  };
1445
1267
 
1446
1268
  if (slot.sparams.n_probs > 0) {
1447
1269
  std::vector<completion_token_output> probs;
1448
1270
  if (!slot.params.stream && slot.stopped_word) {
1449
- const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
1271
+ const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
1450
1272
 
1451
1273
  size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
1452
1274
  probs = std::vector<completion_token_output>(
@@ -1471,17 +1293,16 @@ struct server_context {
1471
1293
 
1472
1294
  void send_embedding(const server_slot & slot, const llama_batch & batch) {
1473
1295
  server_task_result res;
1474
- res.id = slot.id_task;
1475
- res.id_multi = slot.id_multi;
1476
- res.error = false;
1477
- res.stop = true;
1296
+ res.id = slot.id_task;
1297
+ res.error = false;
1298
+ res.stop = true;
1478
1299
 
1479
1300
  const int n_embd = llama_n_embd(model);
1480
1301
 
1481
1302
  std::vector<float> embd_res(n_embd, 0.0f);
1482
1303
 
1483
1304
  for (int i = 0; i < batch.n_tokens; ++i) {
1484
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
1305
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
1485
1306
  continue;
1486
1307
  }
1487
1308
 
@@ -1491,150 +1312,239 @@ struct server_context {
1491
1312
  }
1492
1313
 
1493
1314
  if (embd == NULL) {
1494
- LOG_ERROR("failed to get embeddings", {
1495
- {"token", batch.token [i]},
1496
- {"seq_id", batch.seq_id[i][0]}
1497
- });
1315
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
1498
1316
 
1499
1317
  res.data = json {
1500
1318
  {"embedding", std::vector<float>(n_embd, 0.0f)},
1319
+ {"index", slot.index},
1501
1320
  };
1502
1321
 
1503
1322
  continue;
1504
1323
  }
1505
1324
 
1506
- llama_embd_normalize(embd, embd_res.data(), n_embd);
1325
+ common_embd_normalize(embd, embd_res.data(), n_embd);
1507
1326
 
1508
1327
  res.data = json {
1509
1328
  {"embedding", embd_res},
1329
+ {"index", slot.index},
1510
1330
  };
1511
1331
  }
1512
1332
 
1333
+ SLT_DBG(slot, "%s", "sending embeddings\n");
1334
+
1513
1335
  queue_results.send(res);
1514
1336
  }
1515
1337
 
1516
- void request_completion(int id_task, int id_multi, json data, bool infill, bool embedding) {
1517
- server_task task;
1518
- task.id = id_task;
1519
- task.id_multi = id_multi;
1520
- task.id_target = 0;
1521
- task.data = std::move(data);
1522
- task.infill = infill;
1523
- task.embedding = embedding;
1524
- task.type = SERVER_TASK_TYPE_COMPLETION;
1525
-
1526
- // when a completion task's prompt array is not a singleton, we split it into multiple requests
1527
- // otherwise, it's a single-prompt task, we actually queue it
1528
- // if there's numbers in the prompt array it will be treated as an array of tokens
1529
- if (task.data.count("prompt") != 0 && task.data.at("prompt").size() > 1) {
1530
- bool numbers = false;
1531
- for (const auto & e : task.data.at("prompt")) {
1532
- if (e.is_number()) {
1533
- numbers = true;
1534
- break;
1535
- }
1338
+ void send_rerank(const server_slot & slot, const llama_batch & batch) {
1339
+ server_task_result res;
1340
+ res.id = slot.id_task;
1341
+ res.error = false;
1342
+ res.stop = true;
1343
+
1344
+ for (int i = 0; i < batch.n_tokens; ++i) {
1345
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
1346
+ continue;
1536
1347
  }
1537
1348
 
1538
- // NOTE: split_multiprompt_task() does not handle a mix of strings and numbers,
1539
- // it will completely stall the server. I don't know where the bug for this is.
1540
- //
1541
- // if there are numbers, it needs to be treated like a single prompt,
1542
- // queue_tasks handles a mix of strings and numbers just fine.
1543
- if (numbers) {
1544
- queue_tasks.post(task);
1545
- } else {
1546
- split_multiprompt_task(id_task, task);
1349
+ const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
1350
+ if (embd == NULL) {
1351
+ embd = llama_get_embeddings_ith(ctx, i);
1547
1352
  }
1548
- } else {
1549
- queue_tasks.post(task);
1550
- }
1551
- }
1552
1353
 
1553
- void request_cancel(int id_task) {
1554
- server_task task;
1555
- task.type = SERVER_TASK_TYPE_CANCEL;
1556
- task.id_target = id_task;
1354
+ if (embd == NULL) {
1355
+ SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
1557
1356
 
1558
- queue_tasks.post(task);
1559
- }
1357
+ res.data = json {
1358
+ {"index", slot.index},
1359
+ {"score", -1e6},
1360
+ };
1560
1361
 
1561
- void split_multiprompt_task(int id_multi, const server_task & multiprompt_task) {
1562
- const int prompt_count = multiprompt_task.data.at("prompt").size();
1563
- if (prompt_count <= 1) {
1564
- send_error(multiprompt_task, "error while handling multiple prompts");
1565
- return;
1566
- }
1362
+ continue;
1363
+ }
1567
1364
 
1568
- // generate all the ID for subtask
1569
- std::vector<int> subtask_ids(prompt_count);
1570
- for (int i = 0; i < prompt_count; i++) {
1571
- subtask_ids[i] = queue_tasks.get_new_id();
1365
+ res.data = json {
1366
+ {"index", slot.index},
1367
+ {"score", embd[0]},
1368
+ };
1572
1369
  }
1573
1370
 
1574
- // queue up the multitask so we can track its subtask progression
1575
- queue_tasks.add_multitask(id_multi, subtask_ids);
1371
+ SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
1576
1372
 
1577
- // add subtasks
1578
- for (int i = 0; i < prompt_count; i++) {
1579
- json subtask_data = multiprompt_task.data;
1580
- subtask_data["prompt"] = subtask_data.at("prompt")[i];
1581
-
1582
- // subtasks inherit everything else (infill mode, embedding mode, etc.)
1583
- request_completion(subtask_ids[i], id_multi, subtask_data, multiprompt_task.infill, multiprompt_task.embedding);
1584
- }
1373
+ queue_results.send(res);
1585
1374
  }
1586
1375
 
1587
- void process_single_task(const server_task & task) {
1588
- switch (task.type) {
1589
- case SERVER_TASK_TYPE_COMPLETION:
1590
- {
1591
- const int id_slot = json_value(task.data, "id_slot", -1);
1376
+ //
1377
+ // Functions to create new task(s) and receive result(s)
1378
+ //
1592
1379
 
1593
- server_slot * slot;
1380
+ // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
1381
+ std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
1382
+ std::vector<server_task> tasks;
1383
+ auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
1384
+ SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
1385
+ server_task task;
1386
+ task.id = queue_tasks.get_new_id();
1387
+ task.inf_type = inf_type;
1388
+ task.type = SERVER_TASK_TYPE_INFERENCE;
1389
+ task.data = task_data;
1390
+ task.prompt_tokens = std::move(prompt_tokens);
1391
+ tasks.push_back(std::move(task));
1392
+ };
1594
1393
 
1595
- if (id_slot != -1) {
1596
- slot = get_slot_by_id(id_slot);
1597
- } else {
1598
- std::string prompt;
1599
- if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
1600
- prompt = json_value(task.data, "prompt", std::string());
1601
- }
1394
+ static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
1395
+ if (!data.contains("prompt")) {
1396
+ throw std::runtime_error(error_msg);
1397
+ }
1602
1398
 
1603
- slot = get_available_slot(prompt);
1399
+ // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
1400
+ bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
1401
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
1402
+ switch (inf_type) {
1403
+ case SERVER_TASK_INF_TYPE_RERANK:
1404
+ {
1405
+ // prompts[0] is the question
1406
+ // the rest are the answers/documents
1407
+ GGML_ASSERT(tokenized_prompts.size() > 1);
1408
+ SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
1409
+ for (size_t i = 1; i < tokenized_prompts.size(); i++) {
1410
+ data["index"] = i - 1;
1411
+ auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
1412
+ create_task(data, tokens);
1604
1413
  }
1414
+ } break;
1415
+ case SERVER_TASK_INF_TYPE_INFILL:
1416
+ {
1417
+ SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
1418
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
1419
+ data["index"] = i;
1420
+ auto tokens = format_infill(
1421
+ ctx,
1422
+ data.at("input_prefix"),
1423
+ data.at("input_suffix"),
1424
+ data.at("input_extra"),
1425
+ params.n_batch,
1426
+ params.n_predict,
1427
+ slots[0].n_ctx, // TODO: there should be a better way
1428
+ params.spm_infill,
1429
+ tokenized_prompts[i]
1430
+ );
1431
+ create_task(data, tokens);
1432
+ }
1433
+ } break;
1434
+ default:
1435
+ {
1436
+ SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
1437
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
1438
+ data["index"] = i;
1439
+ create_task(data, tokenized_prompts[i]);
1440
+ }
1441
+ }
1442
+ }
1443
+
1444
+ return tasks;
1445
+ }
1446
+
1447
+ void cancel_tasks(const std::unordered_set<int> & id_tasks) {
1448
+ std::vector<server_task> cancel_tasks;
1449
+ cancel_tasks.reserve(id_tasks.size());
1450
+ for (const auto & id_task : id_tasks) {
1451
+ SRV_WRN("cancel task, id_task = %d\n", id_task);
1452
+
1453
+ server_task task;
1454
+ task.type = SERVER_TASK_TYPE_CANCEL;
1455
+ task.id_target = id_task;
1456
+ cancel_tasks.push_back(task);
1457
+ queue_results.remove_waiting_task_id(id_task);
1458
+ }
1459
+ // push to beginning of the queue, so it has highest priority
1460
+ queue_tasks.post(cancel_tasks, true);
1461
+ }
1462
+
1463
+ // receive the results from task(s) created by create_tasks_inference
1464
+ void receive_cmpl_results(
1465
+ const std::unordered_set<int> & id_tasks,
1466
+ const std::function<void(std::vector<server_task_result>&)> & result_handler,
1467
+ const std::function<void(json)> & error_handler) {
1468
+ // TODO: currently, there is no way to detect the client has cancelled the request
1469
+ std::vector<server_task_result> results(id_tasks.size());
1470
+ for (size_t i = 0; i < id_tasks.size(); i++) {
1471
+ server_task_result result = queue_results.recv(id_tasks);
1472
+
1473
+ if (result.error) {
1474
+ error_handler(result.data);
1475
+ cancel_tasks(id_tasks);
1476
+ return;
1477
+ }
1478
+
1479
+ const size_t idx = result.data["index"];
1480
+ GGML_ASSERT(idx < results.size() && "index out of range");
1481
+
1482
+ results[idx] = result;
1483
+ }
1484
+ result_handler(results);
1485
+ }
1486
+
1487
+ // receive the results from task(s) created by create_tasks_inference, in stream mode
1488
+ void receive_cmpl_results_stream(
1489
+ const std::unordered_set<int> & id_tasks, const
1490
+ std::function<bool(server_task_result&)> & result_handler, const
1491
+ std::function<void(json)> & error_handler) {
1492
+ size_t n_finished = 0;
1493
+ while (true) {
1494
+ server_task_result result = queue_results.recv(id_tasks);
1495
+ if (!result_handler(result)) {
1496
+ cancel_tasks(id_tasks);
1497
+ break;
1498
+ }
1499
+
1500
+ if (result.error) {
1501
+ error_handler(result.data);
1502
+ cancel_tasks(id_tasks);
1503
+ break;
1504
+ }
1505
+
1506
+ if (result.stop) {
1507
+ if (++n_finished == id_tasks.size()) {
1508
+ break;
1509
+ }
1510
+ }
1511
+ }
1512
+ }
1513
+
1514
+ //
1515
+ // Functions to process the task
1516
+ //
1517
+
1518
+ void process_single_task(server_task task) {
1519
+ switch (task.type) {
1520
+ case SERVER_TASK_TYPE_INFERENCE:
1521
+ {
1522
+ const int id_slot = json_value(task.data, "id_slot", -1);
1523
+
1524
+ server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
1605
1525
 
1606
1526
  if (slot == nullptr) {
1607
1527
  // if no slot is available, we defer this task for processing later
1608
- LOG_VERBOSE("no slot is available", {{"id_task", task.id}});
1528
+ SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id);
1609
1529
  queue_tasks.defer(task);
1610
1530
  break;
1611
1531
  }
1612
- if (!slot->available()) {
1532
+ if (slot->is_processing()) {
1613
1533
  // if requested slot is unavailable, we defer this task for processing later
1614
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1534
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
1615
1535
  queue_tasks.defer(task);
1616
1536
  break;
1617
1537
  }
1618
1538
 
1619
- if (task.data.contains("system_prompt")) {
1620
- std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
1621
- system_prompt_set(sys_prompt);
1622
-
1623
- for (server_slot & slot : slots) {
1624
- slot.n_past = 0;
1625
- slot.n_past_se = 0;
1626
- }
1627
- }
1628
-
1629
1539
  slot->reset();
1630
1540
 
1631
- slot->id_task = task.id;
1632
- slot->id_multi = task.id_multi;
1633
- slot->infill = task.infill;
1634
- slot->embedding = task.embedding;
1541
+ slot->id_task = task.id;
1542
+ slot->inf_type = task.inf_type;
1543
+ slot->index = json_value(task.data, "index", 0);
1544
+ slot->prompt_tokens = std::move(task.prompt_tokens);
1635
1545
 
1636
1546
  if (!launch_slot_with_task(*slot, task)) {
1637
- LOG_ERROR("error while launching slot", task.data);
1547
+ SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
1638
1548
  break;
1639
1549
  }
1640
1550
  } break;
@@ -1661,12 +1571,13 @@ struct server_context {
1661
1571
 
1662
1572
  for (server_slot & slot : slots) {
1663
1573
  json slot_data = get_formated_generation(slot);
1664
- slot_data["id"] = slot.id;
1665
- slot_data["id_task"] = slot.id_task;
1666
- slot_data["state"] = slot.state;
1667
- slot_data["prompt"] = slot.prompt;
1668
- slot_data["next_token"] = {
1574
+ slot_data["id"] = slot.id;
1575
+ slot_data["id_task"] = slot.id_task;
1576
+ slot_data["is_processing"] = slot.is_processing();
1577
+ slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
1578
+ slot_data["next_token"] = {
1669
1579
  {"has_next_token", slot.has_next_token},
1580
+ {"has_new_line", slot.has_new_line},
1670
1581
  {"n_remain", slot.n_remaining},
1671
1582
  {"n_decoded", slot.n_decoded},
1672
1583
  {"stopped_eos", slot.stopped_eos},
@@ -1675,30 +1586,18 @@ struct server_context {
1675
1586
  {"stopping_word", slot.stopping_word},
1676
1587
  };
1677
1588
 
1678
- if (slot_data["state"] == SLOT_STATE_IDLE) {
1679
- n_idle_slots++;
1680
- } else {
1589
+ if (slot.is_processing()) {
1681
1590
  n_processing_slots++;
1591
+ } else {
1592
+ n_idle_slots++;
1682
1593
  }
1683
1594
 
1684
1595
  slots_data.push_back(slot_data);
1685
1596
  }
1686
- LOG_INFO("slot data", {
1687
- {"id_task", task.id},
1688
- {"n_idle_slots", n_idle_slots},
1689
- {"n_processing_slots", n_processing_slots}
1690
- });
1691
-
1692
- LOG_VERBOSE("slot data", {
1693
- {"id_task", task.id},
1694
- {"n_idle_slots", n_idle_slots},
1695
- {"n_processing_slots", n_processing_slots},
1696
- {"slots", slots_data}
1697
- });
1597
+ SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
1698
1598
 
1699
1599
  server_task_result res;
1700
1600
  res.id = task.id;
1701
- res.id_multi = task.id_multi;
1702
1601
  res.stop = true;
1703
1602
  res.error = false;
1704
1603
  res.data = {
@@ -1717,6 +1616,9 @@ struct server_context {
  { "n_tokens_predicted", metrics.n_tokens_predicted},
  { "t_tokens_generation", metrics.t_tokens_generation},

+ { "n_decode_total", metrics.n_decode_total},
+ { "n_busy_slots_total", metrics.n_busy_slots_total},
+
  { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
  { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},

@@ -1736,9 +1638,9 @@ struct server_context {
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
  break;
  }
- if (!slot->available()) {
+ if (slot->is_processing()) {
  // if requested slot is unavailable, we defer this task for processing later
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
  queue_tasks.defer(task);
  break;
  }
@@ -1749,7 +1651,7 @@ struct server_context {
  std::string filename = task.data.at("filename");
  std::string filepath = task.data.at("filepath");

- const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);

  const int64_t t_end = ggml_time_us();
  const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -1777,9 +1679,9 @@ struct server_context {
1777
1679
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1778
1680
  break;
1779
1681
  }
1780
- if (!slot->available()) {
1682
+ if (slot->is_processing()) {
1781
1683
  // if requested slot is unavailable, we defer this task for processing later
1782
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1684
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
1783
1685
  queue_tasks.defer(task);
1784
1686
  break;
1785
1687
  }
@@ -1791,7 +1693,7 @@ struct server_context {
1791
1693
 
1792
1694
  slot->cache_tokens.resize(slot->n_ctx);
1793
1695
  size_t token_count = 0;
1794
- size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
1696
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
1795
1697
  if (nread == 0) {
1796
1698
  slot->cache_tokens.resize(0);
1797
1699
  send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
@@ -1825,16 +1727,16 @@ struct server_context {
1825
1727
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
1826
1728
  break;
1827
1729
  }
1828
- if (!slot->available()) {
1730
+ if (slot->is_processing()) {
1829
1731
  // if requested slot is unavailable, we defer this task for processing later
1830
- LOG_VERBOSE("requested slot is unavailable", {{"id_task", task.id}});
1732
+ SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id);
1831
1733
  queue_tasks.defer(task);
1832
1734
  break;
1833
1735
  }
1834
1736
 
1835
1737
  // Erase token cache
1836
1738
  const size_t n_erased = slot->cache_tokens.size();
1837
- llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
1739
+ llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
1838
1740
  slot->cache_tokens.clear();
1839
1741
 
1840
1742
  server_task_result result;
@@ -1847,69 +1749,34 @@ struct server_context {
1847
1749
  };
1848
1750
  queue_results.send(result);
1849
1751
  } break;
1752
+ case SERVER_TASK_TYPE_SET_LORA:
1753
+ {
1754
+ common_lora_adapters_apply(ctx, loras);
1755
+ server_task_result result;
1756
+ result.id = task.id;
1757
+ result.stop = true;
1758
+ result.error = false;
1759
+ result.data = json{{ "success", true }};
1760
+ queue_results.send(result);
1761
+ } break;
1850
1762
  }
1851
1763
  }
1852
1764
 
1853
- void on_finish_multitask(const server_task_multi & multitask) {
1854
- // all subtasks done == multitask is done
1855
- server_task_result result;
1856
- result.id = multitask.id;
1857
- result.stop = true;
1858
- result.error = false;
1859
-
1860
- // collect json results into one json result
1861
- std::vector<json> result_jsons;
1862
- for (const auto & subres : multitask.results) {
1863
- result_jsons.push_back(subres.data);
1864
- result.error = result.error && subres.error;
1865
- }
1866
- result.data = json {
1867
- { "results", result_jsons }
1868
- };
1869
-
1870
- queue_results.send(result);
1871
- }
1872
-
1873
1765
  void update_slots() {
1874
- if (system_need_update) {
1875
- system_prompt_update();
1876
- }
1877
-
1878
- // release slots
1879
- for (auto & slot : slots) {
1880
- if (slot.command == SLOT_COMMAND_RELEASE) {
1881
- slot.state = SLOT_STATE_IDLE;
1882
- slot.command = SLOT_COMMAND_NONE;
1883
- slot.t_last_used = ggml_time_us();
1884
-
1885
- LOG_INFO("slot released", {
1886
- {"id_slot", slot.id},
1887
- {"id_task", slot.id_task},
1888
- {"n_ctx", n_ctx},
1889
- {"n_past", slot.n_past},
1890
- {"n_system_tokens", system_tokens.size()},
1891
- {"n_cache_tokens", slot.cache_tokens.size()},
1892
- {"truncated", slot.truncated}
1893
- });
1894
-
1895
- queue_tasks.notify_slot_changed();
1896
- }
1897
- }
1898
-
1899
1766
  // check if all slots are idle
1900
1767
  {
1901
1768
  bool all_idle = true;
1902
1769
 
1903
1770
  for (auto & slot : slots) {
1904
- if (slot.state != SLOT_STATE_IDLE || slot.command != SLOT_COMMAND_NONE) {
1771
+ if (slot.is_processing()) {
1905
1772
  all_idle = false;
1906
1773
  break;
1907
1774
  }
1908
1775
  }
1909
1776
 
1910
1777
  if (all_idle) {
1911
- LOG_INFO("all slots are idle", {});
1912
- if (system_prompt.empty() && clean_kv_cache) {
1778
+ SRV_INF("%s", "all slots are idle\n");
1779
+ if (clean_kv_cache) {
1913
1780
  kv_cache_clear();
1914
1781
  }
1915
1782
 
@@ -1918,7 +1785,7 @@ struct server_context {
  }

  {
- LOG_VERBOSE("posting NEXT_RESPONSE", {});
+ SRV_DBG("%s", "posting NEXT_RESPONSE\n");

  server_task task;
  task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
@@ -1930,59 +1797,51 @@ struct server_context {
1930
1797
  // apply context-shift if needed
1931
1798
  // TODO: simplify and improve
1932
1799
  for (server_slot & slot : slots) {
1933
- if (slot.ga_n == 1) {
1934
- if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
1935
- // Shift context
1936
- const int n_keep = slot.params.n_keep + add_bos_token;
1937
- const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
1938
- const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
1939
-
1940
- LOG_INFO("slot context shift", {
1941
- {"id_slot", slot.id},
1942
- {"id_task", slot.id_task},
1943
- {"n_keep", n_keep},
1944
- {"n_left", n_left},
1945
- {"n_discard", n_discard},
1946
- {"n_ctx", n_ctx},
1947
- {"n_past", slot.n_past},
1948
- {"n_system_tokens", system_tokens.size()},
1949
- {"n_cache_tokens", slot.cache_tokens.size()}
1950
- });
1800
+ if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
1801
+ if (!params.ctx_shift) {
1802
+ // this check is redundant (for good)
1803
+ // we should never get here, because generation should already stopped in process_token()
1804
+ slot.release();
1805
+ send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
1806
+ continue;
1807
+ }
1951
1808
 
1952
- llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
1953
- llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
1809
+ // Shift context
1810
+ const int n_keep = slot.params.n_keep + add_bos_token;
1811
+ const int n_left = slot.n_past - n_keep;
1812
+ const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
1954
1813
 
1955
- if (slot.params.cache_prompt) {
1956
- for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
1957
- slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
1958
- }
1814
+ SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
1959
1815
 
1960
- slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
1961
- }
1816
+ llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
1817
+ llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
1962
1818
 
1963
- slot.n_past -= n_discard;
1819
+ if (slot.params.cache_prompt) {
1820
+ for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
1821
+ slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
1822
+ }
1964
1823
 
1965
- slot.truncated = true;
1824
+ slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
1966
1825
  }
1826
+
1827
+ slot.n_past -= n_discard;
1828
+
1829
+ slot.truncated = true;
1967
1830
  }
1968
1831
  }
1969
1832
 
1970
1833
  // start populating the batch for this iteration
1971
- llama_batch_clear(batch);
1834
+ common_batch_clear(batch);
1972
1835
 
1973
1836
  // frist, add sampled tokens from any ongoing sequences
1974
1837
  for (auto & slot : slots) {
1975
- if (slot.state == SLOT_STATE_IDLE) {
1838
+ if (slot.state != SLOT_STATE_GENERATING) {
1976
1839
  continue;
1977
1840
  }
1978
1841
 
1979
1842
  slot.i_batch = batch.n_tokens;
1980
1843
 
1981
- const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
1982
-
1983
- // TODO: we always have to take into account the "system_tokens"
1984
- // this is not great and needs to be improved somehow
1985
- llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
1844
+ common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
1986
1845
 
1987
1846
  slot.n_past += 1;
1988
1847
 
@@ -1990,15 +1849,8 @@ struct server_context {
1990
1849
  slot.cache_tokens.push_back(slot.sampled);
1991
1850
  }
1992
1851
 
1993
- LOG_VERBOSE("slot decode token", {
1994
- {"id_slot", slot.id},
1995
- {"id_task", slot.id_task},
1996
- {"n_ctx", n_ctx},
1997
- {"n_past", slot.n_past},
1998
- {"n_system_tokens", system_tokens.size()},
1999
- {"n_cache_tokens", slot.cache_tokens.size()},
2000
- {"truncated", slot.truncated}
2001
- });
1852
+ SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
1853
+ slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
2002
1854
  }
2003
1855
 
2004
1856
  // process in chunks of params.n_batch
@@ -2008,111 +1860,86 @@ struct server_context {
2008
1860
  // track if this is an embedding or non-embedding batch
2009
1861
  // if we've added sampled tokens above, we are in non-embedding mode
2010
1862
  // -1: none, 0: non-embedding, 1: embedding
1863
+ // TODO: make enum
2011
1864
  int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
2012
1865
 
2013
1866
  // next, batch any pending prompts without exceeding n_batch
2014
1867
  if (params.cont_batching || batch.n_tokens == 0) {
2015
1868
  for (auto & slot : slots) {
2016
1869
  // this slot still has a prompt to be processed
2017
- if (slot.state == SLOT_STATE_IDLE && slot.command == SLOT_COMMAND_LOAD_PROMPT) {
1870
+ if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
2018
1871
  auto & prompt_tokens = slot.prompt_tokens;
2019
1872
 
2020
- // we haven't tokenized the prompt yet - do it now:
2021
- if (prompt_tokens.empty()) {
2022
- LOG_VERBOSE("tokenizing prompt", {
2023
- {"id_slot", slot.id},
2024
- {"id_task", slot.id_task}
2025
- });
2026
-
1873
+ // TODO: maybe move branch to outside of this loop in the future
1874
+ if (slot.state == SLOT_STATE_STARTED) {
2027
1875
  slot.t_start_process_prompt = ggml_time_us();
2028
1876
  slot.t_start_generation = 0;
2029
1877
 
2030
- if (slot.infill) {
2031
- const bool add_bos = llama_should_add_bos_token(model);
2032
- bool suff_rm_leading_spc = true;
2033
- if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
2034
- params.input_suffix.erase(0, 1);
2035
- suff_rm_leading_spc = false;
2036
- }
2037
-
2038
- auto prefix_tokens = tokenize(slot.params.input_prefix, false);
2039
- auto suffix_tokens = tokenize(slot.params.input_suffix, false);
2040
-
2041
- const int space_token = 29871; // TODO: this should not be hardcoded
2042
- if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
2043
- suffix_tokens.erase(suffix_tokens.begin());
2044
- }
2045
-
2046
- prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
2047
- suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
1878
+ slot.n_past = 0;
1879
+ slot.n_prompt_tokens = prompt_tokens.size();
1880
+ slot.state = SLOT_STATE_PROCESSING_PROMPT;
2048
1881
 
2049
- auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
2050
- auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
2051
- if (add_bos) {
2052
- embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
2053
- }
2054
- embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
1882
+ SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
2055
1883
 
2056
- const llama_token middle_token = llama_token_middle(model);
2057
- if (middle_token >= 0) {
2058
- embd_inp.push_back(middle_token);
1884
+ // print prompt tokens (for debugging)
1885
+ if (1) {
1886
+ // first 16 tokens (avoid flooding logs)
1887
+ for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
1888
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
2059
1889
  }
2060
-
2061
- prompt_tokens = embd_inp;
2062
1890
  } else {
2063
- prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
1891
+ // all
1892
+ for (int i = 0; i < (int) prompt_tokens.size(); i++) {
1893
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
1894
+ }
2064
1895
  }
2065
1896
 
2066
- slot.n_past = 0;
2067
- slot.n_prompt_tokens = prompt_tokens.size();
2068
-
2069
- LOG_VERBOSE("prompt tokenized", {
2070
- {"id_slot", slot.id},
2071
- {"id_task", slot.id_task},
2072
- {"n_ctx", slot.n_ctx},
2073
- {"n_keep", slot.params.n_keep},
2074
- {"n_prompt_tokens", slot.n_prompt_tokens},
2075
- {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
2076
- });
2077
-
2078
1897
  // empty prompt passed -> release the slot and send empty response
2079
1898
  if (prompt_tokens.empty()) {
2080
- LOG_INFO("empty prompt - releasing slot", {
2081
- {"id_slot", slot.id},
2082
- {"id_task", slot.id_task}
2083
- });
1899
+ SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
2084
1900
 
2085
- slot.state = SLOT_STATE_PROCESSING;
2086
- slot.command = SLOT_COMMAND_NONE;
2087
1901
  slot.release();
2088
1902
  slot.print_timings();
2089
1903
  send_final_response(slot);
2090
1904
  continue;
2091
1905
  }
2092
1906
 
2093
- if (slot.embedding) {
2094
- // this prompt is too large to process - discard it
1907
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
2095
1908
  if (slot.n_prompt_tokens > n_ubatch) {
2096
- slot.state = SLOT_STATE_PROCESSING;
2097
- slot.command = SLOT_COMMAND_NONE;
2098
1909
  slot.release();
2099
1910
  send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
2100
1911
  continue;
2101
1912
  }
1913
+
1914
+ if (slot.n_prompt_tokens > slot.n_ctx) {
1915
+ slot.release();
1916
+ send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
1917
+ continue;
1918
+ }
2102
1919
  } else {
1920
+ if (!params.ctx_shift) {
1921
+ // if context shift is disabled, we make sure prompt size is smaller than KV size
1922
+ // TODO: there should be a separate parameter that controls prompt truncation
1923
+ // context shift should be applied only during the generation phase
1924
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
1925
+ slot.release();
1926
+ send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
1927
+ continue;
1928
+ }
1929
+ }
2103
1930
  if (slot.params.n_keep < 0) {
2104
1931
  slot.params.n_keep = slot.n_prompt_tokens;
2105
1932
  }
2106
1933
  slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
2107
1934
 
2108
- // if input prompt is too big, truncate it (if group attention self-extend is disabled)
2109
- if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
1935
+ // if input prompt is too big, truncate it
1936
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
2110
1937
  const int n_left = slot.n_ctx - slot.params.n_keep;
2111
1938
 
2112
1939
  const int n_block_size = n_left / 2;
2113
1940
  const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
2114
1941
 
2115
- std::vector<llama_token> new_tokens(
1942
+ llama_tokens new_tokens(
2116
1943
  prompt_tokens.begin(),
2117
1944
  prompt_tokens.begin() + slot.params.n_keep);
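Editor's note: the hunk above replaces the self-extend-aware truncation with a plain block erase. When the prompt no longer fits in the slot context, the first n_keep tokens are kept and whole blocks of size n_left/2 are dropped from what follows. Below is a standalone sketch of that arithmetic; the tail-append step falls in the elided middle of this hunk, so it is reconstructed here as an assumption rather than copied from the source.

    #include <cstdio>
    #include <vector>

    // illustrative stand-in for the slot's prompt truncation
    static std::vector<int> truncate_prompt(const std::vector<int> & prompt, int n_ctx, int n_keep) {
        const int n_prompt = (int) prompt.size();
        if (n_prompt < n_ctx) {
            return prompt; // already fits
        }
        const int n_left        = n_ctx - n_keep;
        const int n_block_size  = n_left / 2;
        const int erased_blocks = (n_prompt - n_keep - n_block_size) / n_block_size;

        // keep the first n_keep tokens ...
        std::vector<int> out(prompt.begin(), prompt.begin() + n_keep);
        // ... skip erased_blocks * n_block_size tokens, keep the rest (assumed tail step)
        out.insert(out.end(), prompt.begin() + n_keep + erased_blocks * n_block_size, prompt.end());
        return out;
    }

    int main() {
        std::vector<int> prompt(100);
        for (int i = 0; i < 100; ++i) prompt[i] = i;
        // n_left = 56, n_block_size = 28, erased_blocks = 2 -> 8 + 36 = 44 tokens remain
        printf("truncated to %zu tokens\n", truncate_prompt(prompt, /*n_ctx =*/ 64, /*n_keep =*/ 8).size());
        return 0;
    }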
2118
1945
 
@@ -2126,54 +1953,73 @@ struct server_context {
2126
1953
  slot.truncated = true;
2127
1954
  slot.n_prompt_tokens = prompt_tokens.size();
2128
1955
 
2129
- LOG_VERBOSE("input truncated", {
2130
- {"id_slot", slot.id},
2131
- {"id_task", slot.id_task},
2132
- {"n_ctx", slot.n_ctx},
2133
- {"n_keep", slot.params.n_keep},
2134
- {"n_left", n_left},
2135
- {"n_prompt_tokens", slot.n_prompt_tokens},
2136
- {"prompt_tokens", tokens_to_str(ctx, prompt_tokens.cbegin(), prompt_tokens.cend())},
2137
- });
1956
+ SLT_WRN(slot, "input truncated, n_ctx = %d, n_keep = %d, n_left = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, n_left, slot.n_prompt_tokens);
2138
1957
 
2139
1958
  GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
2140
1959
  }
2141
1960
 
2142
- llama_sampling_reset(slot.ctx_sampling);
1961
+ if (slot.params.cache_prompt) {
1962
+ // reuse any previously computed tokens that are common with the new prompt
1963
+ slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
2143
1964
 
2144
- if (!slot.params.cache_prompt) {
2145
- slot.n_past_se = 0;
2146
- slot.ga_i = 0;
2147
- } else {
2148
- GGML_ASSERT(slot.ga_n == 1);
1965
+ // reuse chunks from the cached prompt by shifting their KV cache in the new position
1966
+ if (params.n_cache_reuse > 0) {
1967
+ size_t head_c = slot.n_past; // cache
1968
+ size_t head_p = slot.n_past; // current prompt
2149
1969
 
2150
- // reuse any previously computed tokens that are common with the new prompt
2151
- slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
1970
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
1971
+
1972
+ while (head_c < slot.cache_tokens.size() &&
1973
+ head_p < prompt_tokens.size()) {
1974
+
1975
+ size_t n_match = 0;
1976
+ while (head_c + n_match < slot.cache_tokens.size() &&
1977
+ head_p + n_match < prompt_tokens.size() &&
1978
+ slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
1979
+
1980
+ n_match++;
1981
+ }
1982
+
1983
+ if (n_match >= (size_t) params.n_cache_reuse) {
1984
+ SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
1985
+ //for (size_t i = head_p; i < head_p + n_match; i++) {
1986
+ // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
1987
+ //}
1988
+
1989
+ const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
1990
+
1991
+ llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
1992
+ llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
1993
+
1994
+ for (size_t i = 0; i < n_match; i++) {
1995
+ slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
1996
+ slot.n_past++;
1997
+ }
2152
1998
 
2153
- // push the prompt into the sampling context (do not apply grammar)
2154
- for (int i = 0; i < slot.n_past; ++i) {
2155
- llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
1999
+ head_c += n_match;
2000
+ head_p += n_match;
2001
+ } else {
2002
+ head_c += 1;
2003
+ }
2004
+ }
2005
+
2006
+ SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
2156
2007
  }
2157
2008
  }
2158
2009
  }
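Editor's note: the n_cache_reuse path above is the most intricate part of this hunk. After taking the longest common prefix of the cached tokens and the new prompt, it scans the remaining cached tokens for chunks of at least n_cache_reuse matching tokens and shifts their KV cells into the new positions instead of recomputing them. A minimal standalone sketch of the same scan, with the llama_kv_cache_seq_* calls replaced by a counter:

    #include <cstdio>
    #include <vector>

    static size_t longest_common_prefix(const std::vector<int> & a, const std::vector<int> & b) {
        size_t i = 0;
        while (i < a.size() && i < b.size() && a[i] == b[i]) {
            i++;
        }
        return i;
    }

    // returns how many tokens of the new prompt would not need to be re-evaluated
    static size_t count_reusable(const std::vector<int> & cache, const std::vector<int> & prompt, size_t n_cache_reuse) {
        size_t n_past = longest_common_prefix(cache, prompt);

        size_t head_c = n_past; // position in the cached tokens
        size_t head_p = n_past; // position in the new prompt

        while (head_c < cache.size() && head_p < prompt.size()) {
            size_t n_match = 0;
            while (head_c + n_match < cache.size() && head_p + n_match < prompt.size() &&
                   cache[head_c + n_match] == prompt[head_p + n_match]) {
                n_match++;
            }
            if (n_match >= n_cache_reuse) {
                // in the server, cells [head_c, head_c + n_match) are shifted to
                // [head_p, head_p + n_match) here via llama_kv_cache_seq_rm/add
                n_past += n_match;
                head_c += n_match;
                head_p += n_match;
            } else {
                head_c += 1;
            }
        }
        return n_past;
    }

    int main() {
        // the new prompt dropped tokens {4, 5, 6} relative to the cached one
        const std::vector<int> cache  = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
        const std::vector<int> prompt = {1, 2, 3, 7, 8, 9, 10};
        printf("reusable tokens: %zu\n", count_reusable(cache, prompt, 2)); // 3 + 4 = 7
        return 0;
    }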
2159
2010
 
2160
2011
  if (slot.n_past == slot.n_prompt_tokens && slot.n_past > 0) {
2161
2012
  // we have to evaluate at least 1 token to generate logits.
2162
- LOG_INFO("we have to evaluate at least 1 token to generate logits", {
2163
- { "id_slot", slot.id },
2164
- { "id_task", slot.id_task }
2165
- });
2013
+ SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
2166
2014
 
2167
2015
  slot.n_past--;
2168
- if (slot.ga_i > 0) {
2169
- slot.n_past_se--;
2170
- }
2171
2016
  }
2172
2017
 
2173
2018
  slot.n_prompt_tokens_processed = 0;
2174
2019
  }
2175
2020
 
2176
- if (slot.embedding) {
2021
+ // non-causal tasks require to fit the entire prompt in the physical batch
2022
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
2177
2023
  // cannot fit the prompt in the current batch - will try next iter
2178
2024
  if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
2179
2025
  continue;
@@ -2181,7 +2027,10 @@ struct server_context {
2181
2027
  }
2182
2028
 
2183
2029
  // check that we are in the right batch_type, if not defer the slot
2184
- bool slot_type = slot.embedding ? 1 : 0;
2030
+ const bool slot_type =
2031
+ slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
2032
+ slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
2033
+
2185
2034
  if (batch_type == -1) {
2186
2035
  batch_type = slot_type;
2187
2036
  } else if (batch_type != slot_type) {
@@ -2189,88 +2038,53 @@ struct server_context {
2189
2038
  }
2190
2039
 
2191
2040
  // keep only the common part
2192
- int p0 = (int) system_tokens.size() + slot.n_past;
2193
- if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
2041
+ if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
2194
2042
  // could not partially delete (likely using a non-Transformer model)
2195
- llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
2196
-
2197
- p0 = (int) system_tokens.size();
2198
- if (p0 != 0) {
2199
- // copy over the system prompt when there is one
2200
- llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
2201
- }
2043
+ llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
2202
2044
 
2203
- // there is no common part left (except for the system prompt)
2045
+ // there is no common part left
2204
2046
  slot.n_past = 0;
2205
- slot.n_past_se = 0;
2206
- slot.ga_i = 0;
2207
- // TODO: is the system prompt ever in the sampling context?
2208
- llama_sampling_reset(slot.ctx_sampling);
2209
2047
  }
2210
2048
 
2049
+ SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
2050
+
2211
2051
  // remove the non-common part from the cache
2212
2052
  slot.cache_tokens.resize(slot.n_past);
2213
2053
 
2214
- LOG_INFO("kv cache rm [p0, end)", {
2215
- { "id_slot", slot.id },
2216
- { "id_task", slot.id_task },
2217
- { "p0", p0 }
2218
- });
2219
-
2220
- int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
2221
-
2222
- int32_t ga_i = slot.ga_i;
2223
- int32_t ga_n = slot.ga_n;
2224
- int32_t ga_w = slot.ga_w;
2225
-
2226
2054
  // add prompt tokens for processing in the current batch
2227
- // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
2228
- for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
2229
- if (slot.ga_n != 1) {
2230
- while (slot_npast >= ga_i + ga_w) {
2231
- const int bd = (ga_w/ga_n)*(ga_n - 1);
2232
- slot_npast -= bd;
2233
- ga_i += ga_w/ga_n;
2234
- }
2235
- }
2236
-
2237
- llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
2055
+ while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
2056
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
2238
2057
 
2239
2058
  if (slot.params.cache_prompt) {
2240
2059
  slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
2241
2060
  }
2242
2061
 
2243
2062
  slot.n_prompt_tokens_processed++;
2244
- slot_npast++;
2063
+ slot.n_past++;
2245
2064
  }
2246
2065
 
2247
- LOG_VERBOSE("prompt processing progress", {
2248
- {"id_slot", slot.id},
2249
- {"n_past", slot.n_past},
2250
- {"n_ctx", n_ctx},
2251
- {"n_tokens", batch.n_tokens},
2252
- {"progress", (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens},
2253
- });
2066
+ SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
2254
2067
 
2255
- // entire prompt has been processed - start decoding new tokens
2068
+ // entire prompt has been processed
2256
2069
  if (slot.n_past == slot.n_prompt_tokens) {
2257
- slot.state = SLOT_STATE_PROCESSING;
2258
- slot.command = SLOT_COMMAND_NONE;
2070
+ slot.state = SLOT_STATE_DONE_PROMPT;
2259
2071
 
2260
2072
  GGML_ASSERT(batch.n_tokens > 0);
2261
2073
 
2074
+ common_sampler_reset(slot.smpl);
2075
+
2076
+ // Process all prompt tokens through sampler system
2077
+ for (int i = 0; i < slot.n_prompt_tokens; ++i) {
2078
+ common_sampler_accept(slot.smpl, prompt_tokens[i], false);
2079
+ }
2080
+
2262
2081
  // extract the logits only for the last token
2263
2082
  batch.logits[batch.n_tokens - 1] = true;
2264
2083
 
2265
2084
  slot.n_decoded = 0;
2266
2085
  slot.i_batch = batch.n_tokens - 1;
2267
2086
 
2268
- LOG_VERBOSE("prompt done", {
2269
- {"id_slot", slot.id},
2270
- {"n_past", slot.n_past},
2271
- {"n_ctx", n_ctx},
2272
- {"n_tokens", batch.n_tokens},
2273
- });
2087
+ SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens);
2274
2088
  }
2275
2089
  }
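Editor's note: the prompt-filling loop above reduces to the pattern below: each prompt token is appended to the batch with its absolute position and the slot id as the sequence id, the loop stops early once n_batch is reached, and only the final prompt token requests logits. This is a toy stand-in for llama_batch / common_batch_add, not the library's own types.

    #include <cstdio>
    #include <vector>

    struct toy_batch {
        std::vector<int>  token;
        std::vector<int>  pos;
        std::vector<int>  seq_id;
        std::vector<bool> logits;
    };

    int main() {
        const std::vector<int> prompt = {11, 22, 33, 44, 55};
        const int n_batch = 4; // physical batch limit
        const int seq     = 0; // slot id used as sequence id

        toy_batch batch;
        int n_past = 0;
        while (n_past < (int) prompt.size() && (int) batch.token.size() < n_batch) {
            batch.token .push_back(prompt[n_past]);
            batch.pos   .push_back(n_past);
            batch.seq_id.push_back(seq);
            batch.logits.push_back(false);
            n_past++;
        }
        if (n_past == (int) prompt.size()) {
            batch.logits.back() = true; // extract logits only for the last prompt token
        }
        // here only 4 of 5 tokens fit - the rest is batched on the next iteration
        printf("batched %zu of %zu prompt tokens\n", batch.token.size(), prompt.size());
        return 0;
    }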
2276
2090
 
@@ -2281,13 +2095,11 @@ struct server_context {
2281
2095
  }
2282
2096
 
2283
2097
  if (batch.n_tokens == 0) {
2284
- LOG_VERBOSE("no tokens to decode", {});
2098
+ SRV_WRN("%s", "no tokens to decode\n");
2285
2099
  return;
2286
2100
  }
2287
2101
 
2288
- LOG_VERBOSE("decoding batch", {
2289
- {"n_tokens", batch.n_tokens},
2290
- });
2102
+ SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
2291
2103
 
2292
2104
  // make sure we're in the right embedding mode
2293
2105
  llama_set_embeddings(ctx, batch_type == 1);
@@ -2296,35 +2108,6 @@ struct server_context {
2296
2108
  for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
2297
2109
  const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
2298
2110
 
2299
- for (auto & slot : slots) {
2300
- if (slot.ga_n != 1) {
2301
- // context extension via Self-Extend
2302
- // TODO: simplify and/or abstract this
2303
- while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
2304
- const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
2305
- const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
2306
- const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
2307
-
2308
- LOG_TEE("\n");
2309
- LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
2310
- LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
2311
- LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
2312
-
2313
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
2314
- llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
2315
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
2316
-
2317
- slot.n_past_se -= bd;
2318
-
2319
- slot.ga_i += slot.ga_w / slot.ga_n;
2320
-
2321
- LOG_TEE("\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
2322
- }
2323
-
2324
- slot.n_past_se += n_tokens;
2325
- }
2326
- }
2327
-
2328
2111
  llama_batch batch_view = {
2329
2112
  n_tokens,
2330
2113
  batch.token + i,
@@ -2333,22 +2116,16 @@ struct server_context {
2333
2116
  batch.n_seq_id + i,
2334
2117
  batch.seq_id + i,
2335
2118
  batch.logits + i,
2336
- 0, 0, 0, // unused
2337
2119
  };
2338
2120
 
2339
2121
  const int ret = llama_decode(ctx, batch_view);
2122
+ metrics.on_decoded(slots);
2340
2123
 
2341
2124
  if (ret != 0) {
2342
2125
  if (n_batch == 1 || ret < 0) {
2343
2126
  // if you get here, it means the KV cache is full - try increasing it via the context size
2344
- LOG_ERROR("failed to decode the batch: KV cache is full - try increasing it via the context size", {
2345
- {"i", i},
2346
- {"n_batch", ret},
2347
- {"ret", ret},
2348
- });
2127
+ SRV_ERR("failed to decode the batch: KV cache is full - try increasing it via the context size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
2349
2128
  for (auto & slot : slots) {
2350
- slot.state = SLOT_STATE_PROCESSING;
2351
- slot.command = SLOT_COMMAND_NONE;
2352
2129
  slot.release();
2353
2130
  send_error(slot, "Input prompt is too big compared to KV size. Please try increasing KV size.");
2354
2131
  }
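Editor's note: the retry logic above is easy to misread because i is rewound before the for-loop increment runs again. A standalone trace of the same control flow, with llama_decode replaced by a fake that only accepts chunks of up to 8 tokens:

    #include <algorithm>
    #include <cstdio>

    static int fake_decode(int n_tokens) {
        return n_tokens > 8 ? 1 : 0; // pretend the KV cache only has room for 8 tokens
    }

    int main() {
        const int n_tokens_total = 32;
        int n_batch = 32;
        for (int i = 0; i < n_tokens_total; i += n_batch) {
            const int n_tokens = std::min(n_batch, n_tokens_total - i);
            if (fake_decode(n_tokens) != 0) {
                if (n_batch == 1) {
                    printf("decode failed even with a single token - giving up\n");
                    return 1;
                }
                // retry the same span with half the batch size:
                // i -= n_batch cancels out the upcoming i += n_batch
                n_batch /= 2;
                i -= n_batch;
                continue;
            }
            printf("decoded tokens [%d, %d)\n", i, i + n_tokens);
        }
        return 0;
    }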
@@ -2359,32 +2136,42 @@ struct server_context {
2359
2136
  n_batch /= 2;
2360
2137
  i -= n_batch;
2361
2138
 
2362
- LOG_WARNING("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation", {
2363
- {"i", i},
2364
- {"n_batch", n_batch},
2365
- {"ret", ret},
2366
- });
2139
+ SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size - try increasing it via the context size or enable defragmentation, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret);
2367
2140
 
2368
2141
  continue; // continue loop of n_batch
2369
2142
  }
2370
2143
 
2371
2144
  for (auto & slot : slots) {
2372
- if (slot.state != SLOT_STATE_PROCESSING || slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
2145
+ if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
2373
2146
  continue; // continue loop of slots
2374
2147
  }
2375
2148
 
2376
- // prompt evaluated for embedding
2377
- if (slot.embedding) {
2378
- send_embedding(slot, batch_view);
2379
- slot.release();
2380
- slot.i_batch = -1;
2149
+ if (slot.state == SLOT_STATE_DONE_PROMPT) {
2150
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
2151
+ // prompt evaluated for embedding
2152
+ send_embedding(slot, batch_view);
2153
+ slot.release();
2154
+ slot.i_batch = -1;
2155
+ continue; // continue loop of slots
2156
+ }
2157
+
2158
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
2159
+ send_rerank(slot, batch_view);
2160
+ slot.release();
2161
+ slot.i_batch = -1;
2162
+ continue; // continue loop of slots
2163
+ }
2164
+
2165
+ // prompt evaluated for next-token prediction
2166
+ slot.state = SLOT_STATE_GENERATING;
2167
+ } else if (slot.state != SLOT_STATE_GENERATING) {
2381
2168
  continue; // continue loop of slots
2382
2169
  }
2383
2170
 
2384
2171
  completion_token_output result;
2385
- const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
2172
+ const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
2386
2173
 
2387
- llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
2174
+ common_sampler_accept(slot.smpl, id, true);
2388
2175
 
2389
2176
  slot.n_decoded += 1;
2390
2177
  if (slot.n_decoded == 1) {
@@ -2393,37 +2180,19 @@ struct server_context {
2393
2180
  metrics.on_prompt_eval(slot);
2394
2181
  }
2395
2182
 
2396
- llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
2397
2183
  result.tok = id;
2398
2184
 
2399
- const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
2400
- if (n_probs > 0) {
2401
- const size_t n_valid = slot.ctx_sampling->n_valid;
2185
+ const auto * cur_p = common_sampler_get_candidates(slot.smpl);
2402
2186
 
2403
- // Make sure at least n_probs top tokens are at the front of the vector:
2404
- if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
2405
- llama_sample_top_k(ctx, &cur_p, n_probs, 0);
2406
- }
2407
-
2408
- if (slot.sparams.temp == 0.0f) {
2409
- // With greedy sampling the probabilities have possibly not been calculated.
2410
- for (size_t i = 0; i < n_probs; ++i) {
2411
- result.probs.push_back({
2412
- cur_p.data[i].id,
2413
- i == 0 ? 1.0f : 0.0f
2414
- });
2415
- }
2416
- } else {
2417
- for (size_t i = 0; i < n_probs; ++i) {
2418
- result.probs.push_back({
2419
- cur_p.data[i].id,
2420
- i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
2421
- });
2422
- }
2423
- }
2187
+ for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
2188
+ result.probs.push_back({
2189
+ cur_p->data[i].id,
2190
+ i >= cur_p->size ? 0.0f : cur_p->data[i].p,
2191
+ });
2424
2192
  }
2425
2193
 
2426
2194
  if (!process_token(result, slot)) {
2195
+ // release slot because of stop condition
2427
2196
  slot.release();
2428
2197
  slot.print_timings();
2429
2198
  send_final_response(slot);
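Editor's note: the new n_probs path above copies the first slot.sparams.n_probs entries of the sampler's candidate list, reporting zero probability once the list has already been cut short by samplers such as top-k. A plain-types sketch of that copy; the -1 placeholder id for exhausted entries is an illustrative choice here, the server keeps whatever id sits at that index.

    #include <cstdio>
    #include <vector>

    struct token_prob {
        int   id;
        float p;
    };

    int main() {
        // pretend top-k already reduced the candidate list to three entries
        const std::vector<token_prob> candidates = {{42, 0.7f}, {7, 0.2f}, {13, 0.1f}};
        const size_t n_probs = 5;

        std::vector<token_prob> probs;
        for (size_t i = 0; i < n_probs; ++i) {
            probs.push_back({
                i < candidates.size() ? candidates[i].id : -1,   // illustrative placeholder id
                i < candidates.size() ? candidates[i].p  : 0.0f, // filtered-out tokens report 0 probability
            });
        }
        for (const auto & tp : probs) {
            printf("token %3d  p = %.2f\n", tp.id, tp.p);
        }
        return 0;
    }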
@@ -2434,7 +2203,7 @@ struct server_context {
2434
2203
  }
2435
2204
  }
2436
2205
 
2437
- LOG_VERBOSE("run slots completed", {});
2206
+ SRV_DBG("%s", "run slots completed\n");
2438
2207
  }
2439
2208
 
2440
2209
  json model_meta() const {
@@ -2455,19 +2224,10 @@ static void log_server_request(const httplib::Request & req, const httplib::Resp
2455
2224
  return;
2456
2225
  }
2457
2226
 
2458
- LOG_INFO("request", {
2459
- {"remote_addr", req.remote_addr},
2460
- {"remote_port", req.remote_port},
2461
- {"status", res.status},
2462
- {"method", req.method},
2463
- {"path", req.path},
2464
- {"params", req.params},
2465
- });
2227
+ LOG_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
2466
2228
 
2467
- LOG_VERBOSE("request", {
2468
- {"request", req.body},
2469
- {"response", res.body},
2470
- });
2229
+ LOG_DBG("request: %s\n", req.body.c_str());
2230
+ LOG_DBG("response: %s\n", res.body.c_str());
2471
2231
  }
2472
2232
 
2473
2233
  std::function<void(int)> shutdown_handler;
@@ -2485,28 +2245,22 @@ inline void signal_handler(int signal) {
2485
2245
  }
2486
2246
 
2487
2247
  int main(int argc, char ** argv) {
2488
- #if SERVER_VERBOSE != 1
2489
- log_disable();
2490
- #endif
2491
2248
  // own arguments required by this example
2492
- gpt_params params;
2249
+ common_params params;
2493
2250
 
2494
- if (!gpt_params_parse(argc, argv, params)) {
2495
- gpt_params_print_usage(argc, argv, params);
2251
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
2496
2252
  return 1;
2497
2253
  }
2498
2254
 
2499
- // TODO: not great to use extern vars
2500
- server_log_json = params.log_json;
2501
- server_verbose = params.verbosity > 0;
2255
+ common_init();
2256
+
2257
+ // enabling this will output extra debug information in the HTTP responses from the server
2258
+ // see format_final_response_oaicompat()
2259
+ const bool verbose = params.verbosity > 9;
2502
2260
 
2503
2261
  // struct that contains llama context and inference
2504
2262
  server_context ctx_server;
2505
2263
 
2506
- if (!params.system_prompt.empty()) {
2507
- ctx_server.system_prompt_set(params.system_prompt);
2508
- }
2509
-
2510
2264
  if (params.model_alias == "unknown") {
2511
2265
  params.model_alias = params.model;
2512
2266
  }
@@ -2514,58 +2268,60 @@ int main(int argc, char ** argv) {
2514
2268
  llama_backend_init();
2515
2269
  llama_numa_init(params.numa);
2516
2270
 
2517
- LOG_INFO("build info", {
2518
- {"build", LLAMA_BUILD_NUMBER},
2519
- {"commit", LLAMA_COMMIT}
2520
- });
2521
-
2522
- LOG_INFO("system info", {
2523
- {"n_threads", params.n_threads},
2524
- {"n_threads_batch", params.n_threads_batch},
2525
- {"total_threads", std::thread::hardware_concurrency()},
2526
- {"system_info", llama_print_system_info()},
2527
- });
2271
+ LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
2272
+ LOG_INF("\n");
2273
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
2274
+ LOG_INF("\n");
2275
+
2276
+ // static files
2277
+ std::map<std::string, server_static_file> static_files = {
2278
+ { "/", { index_html, index_html_len, "text/html; charset=utf-8" }},
2279
+ { "/completion.js", { completion_js, completion_js_len, "text/javascript; charset=utf-8" }},
2280
+ { "/deps_daisyui.min.css", { deps_daisyui_min_css, deps_daisyui_min_css_len, "text/css; charset=utf-8" }},
2281
+ { "/deps_markdown-it.js", { deps_markdown_it_js, deps_markdown_it_js_len, "text/javascript; charset=utf-8" }},
2282
+ { "/deps_tailwindcss.js", { deps_tailwindcss_js, deps_tailwindcss_js_len, "text/javascript; charset=utf-8" }},
2283
+ { "/deps_vue.esm-browser.js", { deps_vue_esm_browser_js, deps_vue_esm_browser_js_len, "text/javascript; charset=utf-8" }},
2284
+ };
2528
2285
 
2529
2286
  std::unique_ptr<httplib::Server> svr;
2530
2287
  #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2531
2288
  if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
2532
- LOG_INFO("Running with SSL", {{"key", params.ssl_file_key}, {"cert", params.ssl_file_cert}});
2289
+ LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
2533
2290
  svr.reset(
2534
2291
  new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
2535
2292
  );
2536
2293
  } else {
2537
- LOG_INFO("Running without SSL", {});
2294
+ LOG_INF("Running without SSL\n");
2538
2295
  svr.reset(new httplib::Server());
2539
2296
  }
2540
2297
  #else
2298
+ if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
2299
+ LOG_ERR("Server is built without SSL support\n");
2300
+ return 1;
2301
+ }
2541
2302
  svr.reset(new httplib::Server());
2542
2303
  #endif
2543
2304
 
2544
2305
  std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
2545
2306
 
2546
2307
  svr->set_default_headers({{"Server", "llama.cpp"}});
2547
-
2548
- // CORS preflight
2549
- svr->Options(R"(.*)", [](const httplib::Request & req, httplib::Response & res) {
2550
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
2551
- res.set_header("Access-Control-Allow-Credentials", "true");
2552
- res.set_header("Access-Control-Allow-Methods", "POST");
2553
- res.set_header("Access-Control-Allow-Headers", "*");
2554
- return res.set_content("", "application/json; charset=utf-8");
2555
- });
2556
-
2557
2308
  svr->set_logger(log_server_request);
2558
2309
 
2559
- auto res_error = [](httplib::Response & res, json error_data) {
2310
+ auto res_error = [](httplib::Response & res, const json & error_data) {
2560
2311
  json final_response {{"error", error_data}};
2561
- res.set_content(final_response.dump(), "application/json; charset=utf-8");
2312
+ res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
2562
2313
  res.status = json_value(error_data, "code", 500);
2563
2314
  };
2564
2315
 
2316
+ auto res_ok = [](httplib::Response & res, const json & data) {
2317
+ res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
2318
+ res.status = 200;
2319
+ };
2320
+
2565
2321
  svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
2566
2322
  std::string message;
2567
2323
  try {
2568
- std::rethrow_exception(std::move(ep));
2324
+ std::rethrow_exception(ep);
2569
2325
  } catch (std::exception & e) {
2570
2326
  message = e.what();
2571
2327
  } catch (...) {
@@ -2573,7 +2329,7 @@ int main(int argc, char ** argv) {
2573
2329
  }
2574
2330
 
2575
2331
  json formatted_error = format_error_response(message, ERROR_TYPE_SERVER);
2576
- LOG_VERBOSE("Got exception", formatted_error);
2332
+ LOG_WRN("got exception: %s\n", formatted_error.dump().c_str());
2577
2333
  res_error(res, formatted_error);
2578
2334
  });
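Editor's note: res_error and res_ok above both serialize with json::error_handler_t::replace, presumably because generated text can contain partial UTF-8 sequences that the default strict handler would refuse to dump mid-response. A minimal nlohmann::json illustration of the difference:

    #include <cstdio>
    #include <string>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    int main() {
        json data = {
            {"content", std::string("\xC3")}, // truncated UTF-8 sequence
            {"stop",    true},
        };
        // data.dump() with the default strict handler would throw type_error.316 here;
        // error_handler_t::replace substitutes U+FFFD and keeps the response intact
        const std::string body = data.dump(-1, ' ', false, json::error_handler_t::replace);
        printf("%s\n", body.c_str());
        return 0;
    }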
2579
2335
 
@@ -2588,11 +2344,6 @@ int main(int argc, char ** argv) {
2588
2344
  svr->set_read_timeout (params.timeout_read);
2589
2345
  svr->set_write_timeout(params.timeout_write);
2590
2346
 
2591
- if (!svr->bind_to_port(params.hostname, params.port)) {
2592
- fprintf(stderr, "\ncouldn't bind to server socket: hostname=%s port=%d\n\n", params.hostname.c_str(), params.port);
2593
- return 1;
2594
- }
2595
-
2596
2347
  std::unordered_map<std::string, std::string> log_data;
2597
2348
 
2598
2349
  log_data["hostname"] = params.hostname;
@@ -2608,54 +2359,15 @@ int main(int argc, char ** argv) {
2608
2359
  // Necessary similarity of prompt for slot selection
2609
2360
  ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
2610
2361
 
2611
- // load the model
2612
- if (!ctx_server.load_model(params)) {
2613
- state.store(SERVER_STATE_ERROR);
2614
- return 1;
2615
- } else {
2616
- ctx_server.init();
2617
- state.store(SERVER_STATE_READY);
2618
- }
2619
-
2620
- LOG_INFO("model loaded", {});
2621
-
2622
- const auto model_meta = ctx_server.model_meta();
2623
-
2624
- // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
2625
- if (params.chat_template.empty()) {
2626
- if (!ctx_server.validate_model_chat_template()) {
2627
- LOG_WARNING("The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses", {});
2628
- params.chat_template = "chatml";
2629
- }
2630
- }
2631
-
2632
- // print sample chat example to make it clear which template is used
2633
- {
2634
- LOG_INFO("chat template", {
2635
- {"chat_example", llama_chat_format_example(ctx_server.model, params.chat_template)},
2636
- {"built_in", params.chat_template.empty()},
2637
- });
2638
- }
2639
-
2640
2362
  //
2641
2363
  // Middlewares
2642
2364
  //
2643
2365
 
2644
- auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
2645
- // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
2646
- static const std::set<std::string> protected_endpoints = {
2647
- "/props",
2648
- "/completion",
2649
- "/completions",
2650
- "/v1/completions",
2651
- "/chat/completions",
2652
- "/v1/chat/completions",
2653
- "/infill",
2654
- "/tokenize",
2655
- "/detokenize",
2656
- "/embedding",
2657
- "/embeddings",
2658
- "/v1/embeddings",
2366
+ auto middleware_validate_api_key = [&params, &res_error, &static_files](const httplib::Request & req, httplib::Response & res) {
2367
+ static const std::unordered_set<std::string> public_endpoints = {
2368
+ "/health",
2369
+ "/models",
2370
+ "/v1/models",
2659
2371
  };
2660
2372
 
2661
2373
  // If API key is not set, skip validation
@@ -2663,8 +2375,8 @@ int main(int argc, char ** argv) {
2663
2375
  return true;
2664
2376
  }
2665
2377
 
2666
- // If path is not in protected_endpoints list, skip validation
2667
- if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
2378
+ // If path is public or is static file, skip validation
2379
+ if (public_endpoints.find(req.path) != public_endpoints.end() || static_files.find(req.path) != static_files.end()) {
2668
2380
  return true;
2669
2381
  }
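Editor's note: condensed, the API-key middleware above (this hunk plus the next one) behaves like the helper below. The exact header parsing sits in the elided lines, so the "Authorization: Bearer <key>" handling here is an assumption about that code rather than a copy of it, and the container types are illustrative.

    #include <cstdio>
    #include <set>
    #include <string>
    #include <unordered_set>

    static bool is_authorized(const std::string & path,
                              const std::string & auth_header,
                              const std::set<std::string> & api_keys,
                              const std::unordered_set<std::string> & public_endpoints) {
        if (api_keys.empty()) {
            return true; // no API key configured - everything is open
        }
        if (public_endpoints.count(path) > 0) {
            return true; // /health, /models, ... stay reachable without a key
        }
        const std::string prefix = "Bearer ";
        if (auth_header.compare(0, prefix.size(), prefix) == 0) {
            return api_keys.count(auth_header.substr(prefix.size())) > 0;
        }
        return false;
    }

    int main() {
        const std::set<std::string> keys = {"secret"};
        const std::unordered_set<std::string> pub = {"/health", "/models", "/v1/models"};
        printf("/health without key: %d\n", is_authorized("/health", "", keys, pub));                 // 1
        printf("/completion bad key: %d\n", is_authorized("/completion", "Bearer x", keys, pub));     // 0
        printf("/completion ok key:  %d\n", is_authorized("/completion", "Bearer secret", keys, pub)); // 1
        return 0;
    }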
2670
2382
 
@@ -2680,17 +2392,42 @@ int main(int argc, char ** argv) {
2680
2392
  }
2681
2393
 
2682
2394
  // API key is invalid or not provided
2683
- // TODO: make another middleware for CORS related logic
2684
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
2685
2395
  res_error(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION));
2686
2396
 
2687
- LOG_WARNING("Unauthorized: Invalid API Key", {});
2397
+ LOG_WRN("Unauthorized: Invalid API Key\n");
2688
2398
 
2689
2399
  return false;
2690
2400
  };
2691
2401
 
2402
+ auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
2403
+ server_state current_state = state.load();
2404
+ if (current_state == SERVER_STATE_LOADING_MODEL) {
2405
+ auto tmp = string_split<std::string>(req.path, '.');
2406
+ if (req.path == "/" || tmp.back() == "html") {
2407
+ res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
2408
+ res.status = 503;
2409
+ } else {
2410
+ res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
2411
+ }
2412
+ return false;
2413
+ }
2414
+ return true;
2415
+ };
2416
+
2692
2417
  // register server middlewares
2693
- svr->set_pre_routing_handler([&middleware_validate_api_key](const httplib::Request & req, httplib::Response & res) {
2418
+ svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
2419
+ res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
2420
+ // If this is OPTIONS request, skip validation because browsers don't include Authorization header
2421
+ if (req.method == "OPTIONS") {
2422
+ res.set_header("Access-Control-Allow-Credentials", "true");
2423
+ res.set_header("Access-Control-Allow-Methods", "GET, POST");
2424
+ res.set_header("Access-Control-Allow-Headers", "*");
2425
+ res.set_content("", "text/html"); // blank response, no data
2426
+ return httplib::Server::HandlerResponse::Handled; // skip further processing
2427
+ }
2428
+ if (!middleware_server_state(req, res)) {
2429
+ return httplib::Server::HandlerResponse::Handled;
2430
+ }
2694
2431
  if (!middleware_validate_api_key(req, res)) {
2695
2432
  return httplib::Server::HandlerResponse::Handled;
2696
2433
  }
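Editor's note: the dedicated CORS preflight route removed above now lives in the pre-routing handler. A minimal standalone version of that pattern with cpp-httplib: CORS headers are set for every request, and OPTIONS preflights are answered immediately since browsers do not attach the Authorization header to them.

    #include <httplib.h>

    int main() {
        httplib::Server svr;

        svr.set_pre_routing_handler([](const httplib::Request & req, httplib::Response & res) {
            res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
            if (req.method == "OPTIONS") {
                res.set_header("Access-Control-Allow-Credentials", "true");
                res.set_header("Access-Control-Allow-Methods", "GET, POST");
                res.set_header("Access-Control-Allow-Headers", "*");
                res.set_content("", "text/html"); // blank body is enough for a preflight
                return httplib::Server::HandlerResponse::Handled;
            }
            return httplib::Server::HandlerResponse::Unhandled; // fall through to the routes
        });

        svr.Get("/health", [](const httplib::Request &, httplib::Response & res) {
            res.set_content("{\"status\":\"ok\"}", "application/json; charset=utf-8");
        });

        svr.listen("127.0.0.1", 8080);
        return 0;
    }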
@@ -2701,99 +2438,57 @@ int main(int argc, char ** argv) {
2701
2438
  // Route handlers (or controllers)
2702
2439
  //
2703
2440
 
2704
- const auto handle_health = [&](const httplib::Request & req, httplib::Response & res) {
2705
- server_state current_state = state.load();
2706
- switch (current_state) {
2707
- case SERVER_STATE_READY:
2708
- {
2709
- // request slots data using task queue
2710
- server_task task;
2711
- task.id = ctx_server.queue_tasks.get_new_id();
2712
- task.type = SERVER_TASK_TYPE_METRICS;
2713
- task.id_target = -1;
2714
-
2715
- ctx_server.queue_results.add_waiting_task_id(task.id);
2716
- ctx_server.queue_tasks.post(task);
2717
-
2718
- // get the result
2719
- server_task_result result = ctx_server.queue_results.recv(task.id);
2720
- ctx_server.queue_results.remove_waiting_task_id(task.id);
2721
-
2722
- const int n_idle_slots = result.data.at("idle");
2723
- const int n_processing_slots = result.data.at("processing");
2724
-
2725
- json health = {
2726
- {"status", "ok"},
2727
- {"slots_idle", n_idle_slots},
2728
- {"slots_processing", n_processing_slots}
2729
- };
2730
-
2731
- res.status = 200; // HTTP OK
2732
- if (params.endpoint_slots && req.has_param("include_slots")) {
2733
- health["slots"] = result.data.at("slots");
2734
- }
2735
-
2736
- if (n_idle_slots == 0) {
2737
- health["status"] = "no slot available";
2738
- if (req.has_param("fail_on_no_slot")) {
2739
- res.status = 503; // HTTP Service Unavailable
2740
- }
2741
- }
2742
-
2743
- res.set_content(health.dump(), "application/json");
2744
- break;
2745
- }
2746
- case SERVER_STATE_LOADING_MODEL:
2747
- {
2748
- res_error(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE));
2749
- } break;
2750
- case SERVER_STATE_ERROR:
2751
- {
2752
- res_error(res, format_error_response("Model failed to load", ERROR_TYPE_SERVER));
2753
- } break;
2754
- }
2441
+ const auto handle_health = [&](const httplib::Request &, httplib::Response & res) {
2442
+ // error and loading states are handled by middleware
2443
+ json health = {{"status", "ok"}};
2444
+ res_ok(res, health);
2755
2445
  };
2756
2446
 
2757
- const auto handle_slots = [&](const httplib::Request &, httplib::Response & res) {
2447
+ const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
2758
2448
  if (!params.endpoint_slots) {
2759
- res_error(res, format_error_response("This server does not support slots endpoint.", ERROR_TYPE_NOT_SUPPORTED));
2449
+ res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
2760
2450
  return;
2761
2451
  }
2762
2452
 
2763
2453
  // request slots data using task queue
2764
2454
  server_task task;
2765
2455
  task.id = ctx_server.queue_tasks.get_new_id();
2766
- task.id_multi = -1;
2767
- task.id_target = -1;
2768
2456
  task.type = SERVER_TASK_TYPE_METRICS;
2769
2457
 
2770
2458
  ctx_server.queue_results.add_waiting_task_id(task.id);
2771
- ctx_server.queue_tasks.post(task);
2459
+ ctx_server.queue_tasks.post(task, true); // high-priority task
2772
2460
 
2773
2461
  // get the result
2774
2462
  server_task_result result = ctx_server.queue_results.recv(task.id);
2775
2463
  ctx_server.queue_results.remove_waiting_task_id(task.id);
2776
2464
 
2777
- res.set_content(result.data.at("slots").dump(), "application/json");
2778
- res.status = 200; // HTTP OK
2465
+ // optionally return "fail_on_no_slot" error
2466
+ const int n_idle_slots = result.data.at("idle");
2467
+ if (req.has_param("fail_on_no_slot")) {
2468
+ if (n_idle_slots == 0) {
2469
+ res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
2470
+ return;
2471
+ }
2472
+ }
2473
+
2474
+ res_ok(res, result.data.at("slots"));
2779
2475
  };
2780
2476
 
2781
2477
  const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
2782
2478
  if (!params.endpoint_metrics) {
2783
- res_error(res, format_error_response("This server does not support metrics endpoint.", ERROR_TYPE_NOT_SUPPORTED));
2479
+ res_error(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED));
2784
2480
  return;
2785
2481
  }
2786
2482
 
2787
2483
  // request slots data using task queue
2788
2484
  server_task task;
2789
2485
  task.id = ctx_server.queue_tasks.get_new_id();
2790
- task.id_multi = -1;
2791
2486
  task.id_target = -1;
2792
2487
  task.type = SERVER_TASK_TYPE_METRICS;
2793
2488
  task.data.push_back({{"reset_bucket", true}});
2794
2489
 
2795
2490
  ctx_server.queue_results.add_waiting_task_id(task.id);
2796
- ctx_server.queue_tasks.post(task);
2491
+ ctx_server.queue_tasks.post(task, true); // high-priority task
2797
2492
 
2798
2493
  // get the result
2799
2494
  server_task_result result = ctx_server.queue_results.recv(task.id);
@@ -2807,6 +2502,9 @@ int main(int argc, char ** argv) {
2807
2502
  const uint64_t n_tokens_predicted = data.at("n_tokens_predicted");
2808
2503
  const uint64_t t_tokens_generation = data.at("t_tokens_generation");
2809
2504
 
2505
+ const uint64_t n_decode_total = data.at("n_decode_total");
2506
+ const uint64_t n_busy_slots_total = data.at("n_busy_slots_total");
2507
+
2810
2508
  const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");
2811
2509
 
2812
2510
  // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
@@ -2827,6 +2525,14 @@ int main(int argc, char ** argv) {
2827
2525
  {"name", "tokens_predicted_seconds_total"},
2828
2526
  {"help", "Predict process time"},
2829
2527
  {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3}
2528
+ }, {
2529
+ {"name", "n_decode_total"},
2530
+ {"help", "Total number of llama_decode() calls"},
2531
+ {"value", n_decode_total}
2532
+ }, {
2533
+ {"name", "n_busy_slots_per_decode"},
2534
+ {"help", "Average number of busy slots per llama_decode() call"},
2535
+ {"value", (float) n_busy_slots_total / (float) n_decode_total}
2830
2536
  }}},
2831
2537
  {"gauge", {{
2832
2538
  {"name", "prompt_tokens_seconds"},
@@ -2879,7 +2585,7 @@ int main(int argc, char ** argv) {
2879
2585
  res.status = 200; // HTTP OK
2880
2586
  };
2881
2587
 
2882
- const auto handle_slots_save = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
2588
+ const auto handle_slots_save = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
2883
2589
  json request_data = json::parse(req.body);
2884
2590
  std::string filename = request_data.at("filename");
2885
2591
  if (!fs_validate_filename(filename)) {
@@ -2893,7 +2599,7 @@ int main(int argc, char ** argv) {
2893
2599
  task.data = {
2894
2600
  { "id_slot", id_slot },
2895
2601
  { "filename", filename },
2896
- { "filepath", filepath }
2602
+ { "filepath", filepath },
2897
2603
  };
2898
2604
 
2899
2605
  const int id_task = ctx_server.queue_tasks.post(task);
@@ -2905,11 +2611,11 @@ int main(int argc, char ** argv) {
2905
2611
  if (result.error) {
2906
2612
  res_error(res, result.data);
2907
2613
  } else {
2908
- res.set_content(result.data.dump(), "application/json");
2614
+ res_ok(res, result.data);
2909
2615
  }
2910
2616
  };
2911
2617
 
2912
- const auto handle_slots_restore = [&ctx_server, &res_error, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
2618
+ const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
2913
2619
  json request_data = json::parse(req.body);
2914
2620
  std::string filename = request_data.at("filename");
2915
2621
  if (!fs_validate_filename(filename)) {
@@ -2923,7 +2629,7 @@ int main(int argc, char ** argv) {
2923
2629
  task.data = {
2924
2630
  { "id_slot", id_slot },
2925
2631
  { "filename", filename },
2926
- { "filepath", filepath }
2632
+ { "filepath", filepath },
2927
2633
  };
2928
2634
 
2929
2635
  const int id_task = ctx_server.queue_tasks.post(task);
@@ -2935,11 +2641,11 @@ int main(int argc, char ** argv) {
2935
2641
  if (result.error) {
2936
2642
  res_error(res, result.data);
2937
2643
  } else {
2938
- res.set_content(result.data.dump(), "application/json");
2644
+ res_ok(res, result.data);
2939
2645
  }
2940
2646
  };
2941
2647
 
2942
- const auto handle_slots_erase = [&ctx_server, &res_error](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
2648
+ const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
2943
2649
  server_task task;
2944
2650
  task.type = SERVER_TASK_TYPE_SLOT_ERASE;
2945
2651
  task.data = {
@@ -2955,12 +2661,15 @@ int main(int argc, char ** argv) {
2955
2661
  if (result.error) {
2956
2662
  res_error(res, result.data);
2957
2663
  } else {
2958
- res.set_content(result.data.dump(), "application/json");
2664
+ res_ok(res, result.data);
2959
2665
  }
2960
2666
  };
2961
2667
 
2962
- const auto handle_slots_action = [&res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
2963
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
2668
+ const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
2669
+ if (params.slot_save_path.empty()) {
2670
+ res_error(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED));
2671
+ return;
2672
+ }
2964
2673
 
2965
2674
  std::string id_slot_str = req.path_params.at("id_slot");
2966
2675
  int id_slot;
@@ -2985,298 +2694,262 @@ int main(int argc, char ** argv) {
2985
2694
  }
2986
2695
  };
2987
2696
 
2988
- const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
2989
- std::string template_key = "tokenizer.chat_template", curr_tmpl;
2990
- int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
2991
- if (tlen > 0) {
2992
- std::vector<char> curr_tmpl_buf(tlen + 1, 0);
2993
- if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
2994
- curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
2995
- }
2996
- }
2997
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
2697
+ const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
2998
2698
  json data = {
2999
- { "system_prompt", ctx_server.system_prompt.c_str() },
3000
2699
  { "default_generation_settings", ctx_server.default_generation_settings_for_props },
3001
2700
  { "total_slots", ctx_server.params.n_parallel },
3002
- { "chat_template", curr_tmpl.c_str() }
2701
+ { "chat_template", llama_get_chat_template(ctx_server.model) },
3003
2702
  };
3004
2703
 
3005
- res.set_content(data.dump(), "application/json; charset=utf-8");
2704
+ res_ok(res, data);
3006
2705
  };
3007
2706
 
3008
- const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
3009
- if (ctx_server.params.embedding) {
3010
- res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
2707
+ const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
2708
+ if (!ctx_server.params.endpoint_props) {
2709
+ res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
3011
2710
  return;
3012
2711
  }
3013
2712
 
3014
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3015
-
3016
2713
  json data = json::parse(req.body);
3017
2714
 
3018
- const int id_task = ctx_server.queue_tasks.get_new_id();
2715
+ // update any props here
3019
2716
 
3020
- ctx_server.queue_results.add_waiting_task_id(id_task);
3021
- ctx_server.request_completion(id_task, -1, data, false, false);
2717
+ res_ok(res, {{ "success", true }});
2718
+ };
3022
2719
 
3023
- if (!json_value(data, "stream", false)) {
3024
- server_task_result result = ctx_server.queue_results.recv(id_task);
3025
- if (!result.error && result.stop) {
3026
- res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
3027
- } else {
3028
- res_error(res, result.data);
3029
- }
2720
+ const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
2721
+ if (ctx_server.params.embedding) {
2722
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
2723
+ return;
2724
+ }
3030
2725
 
3031
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3032
- } else {
3033
- const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
3034
- while (true) {
3035
- server_task_result result = ctx_server.queue_results.recv(id_task);
3036
- if (!result.error) {
3037
- const std::string str =
3038
- "data: " +
3039
- result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
3040
- "\n\n";
3041
-
3042
- LOG_VERBOSE("data stream", {
3043
- { "to_send", str }
3044
- });
3045
-
3046
- if (!sink.write(str.c_str(), str.size())) {
3047
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3048
- return false;
3049
- }
2726
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
2727
+ ctx_server.queue_results.add_waiting_tasks(tasks);
2728
+ ctx_server.queue_tasks.post(tasks);
3050
2729
 
3051
- if (result.stop) {
3052
- break;
3053
- }
3054
- } else {
3055
- const std::string str =
3056
- "error: " +
3057
- result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
3058
- "\n\n";
3059
-
3060
- LOG_VERBOSE("data stream", {
3061
- { "to_send", str }
3062
- });
3063
-
3064
- if (!sink.write(str.c_str(), str.size())) {
3065
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3066
- return false;
3067
- }
2730
+ bool stream = json_value(data, "stream", false);
2731
+ const auto task_ids = server_task::get_list_id(tasks);
3068
2732
 
3069
- break;
2733
+ if (!stream) {
2734
+ ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
2735
+ if (results.size() == 1) {
2736
+ // single result
2737
+ res_ok(res, results[0].data);
2738
+ } else {
2739
+ // multiple results (multitask)
2740
+ json arr = json::array();
2741
+ for (const auto & res : results) {
2742
+ arr.push_back(res.data);
3070
2743
  }
2744
+ res_ok(res, arr);
3071
2745
  }
2746
+ }, [&](const json & error_data) {
2747
+ res_error(res, error_data);
2748
+ });
3072
2749
 
3073
- ctx_server.queue_results.remove_waiting_task_id(id_task);
2750
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
2751
+ } else {
2752
+ const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
2753
+ ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
2754
+ return server_sent_event(sink, "data", result.data);
2755
+ }, [&](const json & error_data) {
2756
+ server_sent_event(sink, "error", error_data);
2757
+ });
3074
2758
  sink.done();
3075
-
3076
- return true;
2759
+ return false;
3077
2760
  };
3078
2761
 
3079
- auto on_complete = [id_task, &ctx_server] (bool) {
3080
- // cancel
3081
- ctx_server.request_cancel(id_task);
3082
- ctx_server.queue_results.remove_waiting_task_id(id_task);
2762
+ auto on_complete = [task_ids, &ctx_server] (bool) {
2763
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
3083
2764
  };
3084
2765
 
3085
2766
  res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
3086
2767
  }
3087
2768
  };
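Editor's note: the non-streaming branch of handle_completions_generic above returns either a single object or a JSON array, depending on how many tasks the request expanded into. A tiny illustration of that aggregation with plain nlohmann::json (the helper name is illustrative):

    #include <cstdio>
    #include <vector>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    static json aggregate_results(const std::vector<json> & results) {
        if (results.size() == 1) {
            return results[0];  // single result - return the object directly
        }
        json arr = json::array();
        for (const auto & r : results) {
            arr.push_back(r);   // multitask - return an array with one entry per task
        }
        return arr;
    }

    int main() {
        const json r1 = {{"content", "hi"}};
        const json r2 = {{"content", "bye"}};
        printf("%s\n", aggregate_results({r1}).dump().c_str());
        printf("%s\n", aggregate_results({r1, r2}).dump().c_str());
        return 0;
    }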
3088
2769
 
3089
- const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
3090
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3091
-
3092
- json models = {
3093
- {"object", "list"},
3094
- {"data", {
3095
- {
3096
- {"id", params.model_alias},
3097
- {"object", "model"},
3098
- {"created", std::time(0)},
3099
- {"owned_by", "llamacpp"},
3100
- {"meta", model_meta}
3101
- },
3102
- }}
3103
- };
3104
-
3105
- res.set_content(models.dump(), "application/json; charset=utf-8");
2770
+ const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2771
+ json data = json::parse(req.body);
2772
+ return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
3106
2773
  };
3107
2774
 
3108
- const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
3109
- if (ctx_server.params.embedding) {
3110
- res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
2775
+ const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2776
+ // check model compatibility
2777
+ std::string err;
2778
+ if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
2779
+ err += "prefix token is missing. ";
2780
+ }
2781
+ if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
2782
+ err += "suffix token is missing. ";
2783
+ }
2784
+ if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
2785
+ err += "middle token is missing. ";
2786
+ }
2787
+ if (!err.empty()) {
2788
+ res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
3111
2789
  return;
3112
2790
  }
3113
2791
 
3114
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3115
- json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
3116
-
3117
- const int id_task = ctx_server.queue_tasks.get_new_id();
3118
-
3119
- ctx_server.queue_results.add_waiting_task_id(id_task);
3120
- ctx_server.request_completion(id_task, -1, data, false, false);
2792
+ json data = json::parse(req.body);
3121
2793
 
3122
- const auto completion_id = gen_chatcmplid();
3123
- if (!json_value(data, "stream", false)) {
3124
- server_task_result result = ctx_server.queue_results.recv(id_task);
2794
+ // validate input
2795
+ if (!data.contains("input_prefix")) {
2796
+ res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
2797
+ }
3125
2798
 
3126
- if (!result.error && result.stop) {
3127
- json result_oai = format_final_response_oaicompat(data, result.data, completion_id);
2799
+ if (!data.contains("input_suffix")) {
2800
+ res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
2801
+ }
3128
2802
 
3129
- res.set_content(result_oai.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
3130
- } else {
3131
- res_error(res, result.data);
2803
+ if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
2804
+ res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
2805
+ return;
2806
+ }
2807
+ json input_extra = json_value(data, "input_extra", json::array());
2808
+ for (const auto & chunk : input_extra) {
2809
+ // { "text": string, "filename": string }
2810
+ if (!chunk.contains("text") || !chunk.at("text").is_string()) {
2811
+ res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
2812
+ return;
2813
+ }
2814
+ // filename is optional
2815
+ if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
2816
+ res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
2817
+ return;
3132
2818
  }
3133
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3134
- } else {
3135
- const auto chunked_content_provider = [id_task, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
3136
- while (true) {
3137
- server_task_result result = ctx_server.queue_results.recv(id_task);
3138
- if (!result.error) {
3139
- std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
3140
-
3141
- for (auto it = result_array.begin(); it != result_array.end(); ++it) {
3142
- if (!it->empty()) {
3143
- const std::string str =
3144
- "data: " +
3145
- it->dump(-1, ' ', false, json::error_handler_t::replace) +
3146
- "\n\n";
3147
- LOG_VERBOSE("data stream", {{"to_send", str}});
3148
- if (!sink.write(str.c_str(), str.size())) {
3149
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3150
- return false;
3151
- }
3152
- }
3153
- }
3154
- if (result.stop) {
3155
- break;
3156
- }
3157
- } else {
3158
- const std::string str =
3159
- "error: " +
3160
- result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
3161
- "\n\n";
3162
- LOG_VERBOSE("data stream", {{"to_send", str}});
3163
- if (!sink.write(str.c_str(), str.size())) {
3164
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3165
- return false;
3166
- }
3167
- break;
3168
- }
3169
- }
3170
- sink.done();
3171
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3172
- return true;
3173
- };
3174
-
3175
- auto on_complete = [id_task, &ctx_server](bool) {
3176
- // cancel request
3177
- ctx_server.request_cancel(id_task);
3178
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3179
- };
3180
-
3181
- res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
3182
2819
  }
2820
+ data["input_extra"] = input_extra; // default to an empty array if it does not exist
2821
+
2822
+ return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
3183
2823
  };
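Editor's note: handle_infill above first probes the model for FIM prefix/suffix/middle tokens (llama_token_fim_pre/suf/mid) and then validates the request body. The token probe needs a loaded model, so only the JSON validation is reproduced in this standalone sketch; the helper name is illustrative, the field names are taken from the handler.

    #include <cstdio>
    #include <string>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    static std::string validate_infill_request(const json & data) {
        if (!data.contains("input_prefix")) {
            return "\"input_prefix\" is required";
        }
        if (!data.contains("input_suffix")) {
            return "\"input_suffix\" is required";
        }
        if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
            return "\"input_extra\" must be an array of {\"filename\": string, \"text\": string}";
        }
        const json input_extra = data.contains("input_extra") ? data.at("input_extra") : json::array();
        for (const auto & chunk : input_extra) {
            if (!chunk.contains("text") || !chunk.at("text").is_string()) {
                return "extra_context chunk must contain a \"text\" field with a string value";
            }
            if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
                return "extra_context chunk's \"filename\" field must be a string"; // filename stays optional
            }
        }
        return ""; // ok
    }

    int main() {
        const json ok  = {{"input_prefix", "int main() {"}, {"input_suffix", "}"}, {"input_extra", json::array()}};
        const json bad = {{"input_prefix", "int main() {"}};
        printf("ok:  '%s'\n", validate_infill_request(ok).c_str());
        printf("bad: '%s'\n", validate_infill_request(bad).c_str());
        return 0;
    }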
3184
2824
 
3185
- const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
2825
+ // TODO: maybe merge this function with "handle_completions_generic"
2826
+ const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
3186
2827
  if (ctx_server.params.embedding) {
3187
- res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
2828
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
3188
2829
  return;
3189
2830
  }
3190
2831
 
3191
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3192
-
3193
- json data = json::parse(req.body);
2832
+ json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
3194
2833
 
3195
- const int id_task = ctx_server.queue_tasks.get_new_id();
2834
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
2835
+ ctx_server.queue_results.add_waiting_tasks(tasks);
2836
+ ctx_server.queue_tasks.post(tasks);
3196
2837
 
3197
- ctx_server.queue_results.add_waiting_task_id(id_task);
3198
- ctx_server.request_completion(id_task, -1, data, true, false);
2838
+ bool stream = json_value(data, "stream", false);
2839
+ const auto task_ids = server_task::get_list_id(tasks);
2840
+ const auto completion_id = gen_chatcmplid();
3199
2841
 
3200
- if (!json_value(data, "stream", false)) {
3201
- server_task_result result = ctx_server.queue_results.recv(id_task);
3202
- if (!result.error && result.stop) {
3203
- res.set_content(result.data.dump(-1, ' ', false, json::error_handler_t::replace), "application/json; charset=utf-8");
3204
- } else {
3205
- res_error(res, result.data);
3206
- }
2842
+ if (!stream) {
2843
+ ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
2844
+ // multitask is never supported in chat completion, so there is only one result
2845
+ json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
2846
+ res_ok(res, result_oai);
2847
+ }, [&](const json & error_data) {
2848
+ res_error(res, error_data);
2849
+ });
3207
2850
 
3208
- ctx_server.queue_results.remove_waiting_task_id(id_task);
2851
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
3209
2852
  } else {
3210
- const auto chunked_content_provider = [id_task, &ctx_server](size_t, httplib::DataSink & sink) {
3211
- while (true) {
3212
- server_task_result result = ctx_server.queue_results.recv(id_task);
3213
- if (!result.error) {
3214
- const std::string str =
3215
- "data: " +
3216
- result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
3217
- "\n\n";
3218
-
3219
- LOG_VERBOSE("data stream", {
3220
- { "to_send", str }
3221
- });
3222
-
3223
- if (!sink.write(str.c_str(), str.size())) {
3224
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3225
- return false;
2853
+ const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
2854
+ ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
2855
+ std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
2856
+ for (auto & event_data : result_array) {
2857
+ if (event_data.empty()) {
2858
+ continue; // skip the stop token
3226
2859
  }
3227
-
3228
- if (result.stop) {
3229
- break;
2860
+ if (!server_sent_event(sink, "data", event_data)) {
2861
+ return false; // connection is closed
3230
2862
  }
3231
- } else {
3232
- break;
3233
2863
  }
3234
- }
3235
-
3236
- ctx_server.queue_results.remove_waiting_task_id(id_task);
2864
+ return true; // ok
2865
+ }, [&](const json & error_data) {
2866
+ server_sent_event(sink, "error", error_data);
2867
+ });
2868
+ static const std::string ev_done = "data: [DONE]\n\n";
2869
+ sink.write(ev_done.data(), ev_done.size());
3237
2870
  sink.done();
3238
-
3239
2871
  return true;
3240
2872
  };
3241
2873
 
3242
- auto on_complete = [id_task, &ctx_server] (bool) {
3243
- ctx_server.request_cancel(id_task);
2874
+ auto on_complete = [task_ids, &ctx_server] (bool) {
2875
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
3244
2876
  };
3245
2877
 
3246
2878
  res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
3247
2879
  }
3248
2880
  };
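
Note on the streaming branch above: each partial result is now framed as a server-sent event via server_sent_event(sink, "data", event_data), and the stream is terminated with a literal `data: [DONE]` line for OpenAI compatibility. The snippet below is a minimal, self-contained sketch of that wire format only; it writes to std::cout instead of httplib::DataSink, the write_sse_event helper is illustrative rather than part of this diff, and it assumes the nlohmann/json header that the server already vendors.

```cpp
#include <iostream>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::json;

// Illustrative helper: mirrors the "data: <json>\n\n" framing produced by
// server_sent_event() in the handler above.
static void write_sse_event(std::ostream & out, const std::string & event, const json & data) {
    out << event << ": " << data.dump(-1, ' ', false, json::error_handler_t::replace) << "\n\n";
}

int main() {
    // One partial chat-completion chunk (payload shape abbreviated), then the
    // terminator that the new streaming code appends.
    write_sse_event(std::cout, "data",
        json{{"choices", json::array({ json{{"delta", {{"content", "Hello"}}}} })}});
    std::cout << "data: [DONE]\n\n";
    return 0;
}
```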

- const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
+ json models = {
+ {"object", "list"},
+ {"data", {
+ {
+ {"id", params.model_alias},
+ {"object", "model"},
+ {"created", std::time(0)},
+ {"owned_by", "llamacpp"},
+ {"meta", ctx_server.model_meta()}
+ },
+ }}
+ };
+
+ res.set_content(models.dump(), MIMETYPE_JSON);
+ };
+
+ const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
  const json body = json::parse(req.body);

- std::vector<llama_token> tokens;
+ json tokens_response = json::array();
  if (body.count("content") != 0) {
  const bool add_special = json_value(body, "add_special", false);
- tokens = ctx_server.tokenize(body.at("content"), add_special);
+ const bool with_pieces = json_value(body, "with_pieces", false);
+
+ llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
+
+ if (with_pieces) {
+ for (const auto& token : tokens) {
+ std::string piece = common_token_to_piece(ctx_server.ctx, token);
+ json piece_json;
+
+ // Check if the piece is valid UTF-8
+ if (is_valid_utf8(piece)) {
+ piece_json = piece;
+ } else {
+ // If not valid UTF-8, store as array of byte values
+ piece_json = json::array();
+ for (unsigned char c : piece) {
+ piece_json.push_back(static_cast<int>(c));
+ }
+ }
+
+ tokens_response.push_back({
+ {"id", token},
+ {"piece", piece_json}
+ });
+ }
+ } else {
+ tokens_response = tokens;
+ }
  }
- const json data = format_tokenizer_response(tokens);
- return res.set_content(data.dump(), "application/json; charset=utf-8");
+
+ const json data = format_tokenizer_response(tokens_response);
+ res_ok(res, data);
  };
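
The reworked /tokenize handler above accepts two optional flags, add_special and with_pieces; with with_pieces set, each token comes back as an {id, piece} pair, and pieces that are not valid UTF-8 are returned as arrays of byte values. A rough sketch of the request and response shapes follows, assuming nlohmann/json; the token ids and pieces are made up, and the exact response envelope comes from format_tokenizer_response(), which is outside this diff.

```cpp
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    // Request body accepted by POST /tokenize (field names taken from the handler above).
    json request = {
        {"content",     "Hello world"},
        {"add_special", true},   // also emit BOS/special tokens if the model defines them
        {"with_pieces", true},   // return {id, piece} pairs instead of bare ids
    };

    // Illustrative response shape when with_pieces is set; a piece that is not
    // valid UTF-8 would appear as an array of byte values instead of a string.
    json response = {
        {"tokens", json::array({
            json{{"id", 15043}, {"piece", "Hello"}},
            json{{"id",  3186}, {"piece", " world"}},
        })},
    };
    (void) request; (void) response;
    return 0;
}
```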

- const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ const auto handle_detokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
  const json body = json::parse(req.body);

  std::string content;
  if (body.count("tokens") != 0) {
- const std::vector<llama_token> tokens = body.at("tokens");
+ const llama_tokens tokens = body.at("tokens");
  content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
  }

  const json data = format_detokenized_response(content);
- return res.set_content(data.dump(), "application/json; charset=utf-8");
+ res_ok(res, data);
  };

- const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
- res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-
+ const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
  const json body = json::parse(req.body);
  bool is_openai = false;

@@ -3294,42 +2967,157 @@ int main(int argc, char ** argv) {
  }

  // create and queue the task
- json responses;
+ json responses = json::array();
+ bool error = false;
  {
- const int id_task = ctx_server.queue_tasks.get_new_id();
- ctx_server.queue_results.add_waiting_task_id(id_task);
- ctx_server.request_completion(id_task, -1, {{"prompt", prompt}}, false, true);
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(tasks);

  // get the result
- server_task_result result = ctx_server.queue_results.recv(id_task);
- ctx_server.queue_results.remove_waiting_task_id(id_task);
- if (!result.error) {
- if (result.data.count("results")) {
- // result for multi-task
- responses = result.data.at("results");
- } else {
- // result for single task
- responses = std::vector<json>{result.data};
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+ ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+ for (const auto & res : results) {
+ responses.push_back(res.data);
  }
- } else {
- // error received, ignore everything else
- res_error(res, result.data);
- return;
- }
+ }, [&](const json & error_data) {
+ res_error(res, error_data);
+ error = true;
+ });
+
+ ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+ }
+
+ if (error) {
+ return;
  }

  // write JSON response
  json root = is_openai
  ? format_embeddings_response_oaicompat(body, responses)
  : responses[0];
- return res.set_content(root.dump(), "application/json; charset=utf-8");
+ res_ok(res, root);
  };

- auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
- return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
- res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
- return false;
- };
+ const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
+ if (!ctx_server.params.reranking || ctx_server.params.embedding) {
+ res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
+ return;
+ }
+
+ const json body = json::parse(req.body);
+
+ // TODO: implement
+ //int top_n = 1;
+ //if (body.count("top_n") != 1) {
+ // top_n = body.at("top_n");
+ //} else {
+ // res_error(res, format_error_response("\"top_n\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+ // return;
+ //}
+
+ json query;
+ if (body.count("query") == 1) {
+ query = body.at("query");
+ if (!query.is_string()) {
+ res_error(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+ } else {
+ res_error(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ std::vector<std::string> documents = json_value(body, "documents", std::vector<std::string>());
+ if (documents.empty()) {
+ res_error(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
+ return;
+ }
+
+ // construct prompt object: array of ["query", "doc0", "doc1", ...]
+ json prompt;
+ prompt.push_back(query);
+ for (const auto & doc : documents) {
+ prompt.push_back(doc);
+ }
+
+ LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
+
+ // create and queue the task
+ json responses = json::array();
+ bool error = false;
+ {
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
+ ctx_server.queue_results.add_waiting_tasks(tasks);
+ ctx_server.queue_tasks.post(tasks);
+
+ // get the result
+ std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
+
+ ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
+ for (const auto & res : results) {
+ responses.push_back(res.data);
+ }
+ }, [&](const json & error_data) {
+ res_error(res, error_data);
+ error = true;
+ });
+ }
+
+ if (error) {
+ return;
+ }
+
+ // write JSON response
+ json root = format_response_rerank(body, responses);
+ res_ok(res, root);
+ };
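
The new /rerank handler above requires a string "query" and a non-empty string array "documents", and flattens them into a single prompt array of the form ["query", "doc0", "doc1", ...]; "top_n" is still marked TODO and is ignored in this version. A small sketch of that request shape and the flattening step, assuming nlohmann/json; the example strings are made up.

```cpp
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    // Request body accepted by POST /rerank (field names from the handler above).
    json body = {
        {"query",     "what is a panda?"},
        {"documents", {"The giant panda is a bear native to China.",
                       "Paris is the capital of France."}},
    };

    // The handler flattens this into one prompt array: [query, doc0, doc1, ...]
    json prompt;
    prompt.push_back(body.at("query"));
    for (const auto & doc : body.at("documents")) {
        prompt.push_back(doc);
    }
    // prompt == ["what is a panda?", "The giant panda ...", "Paris is ..."]
    (void) prompt;
    return 0;
}
```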
+
+ const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
+ json result = json::array();
+ for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
+ auto & lora = ctx_server.loras[i];
+ result.push_back({
+ {"id", i},
+ {"path", lora.path},
+ {"scale", lora.scale},
+ });
+ }
+ res_ok(res, result);
+ res.status = 200; // HTTP OK
+ };
+
+ const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
+ const std::vector<json> body = json::parse(req.body);
+ int max_idx = ctx_server.loras.size();
+
+ // clear existing value
+ for (auto & lora : ctx_server.loras) {
+ lora.scale = 0.0f;
+ }
+
+ // set value
+ for (auto entry : body) {
+ int id = entry.at("id");
+ float scale = entry.at("scale");
+ if (0 <= id && id < max_idx) {
+ ctx_server.loras[id].scale = scale;
+ } else {
+ throw std::runtime_error("invalid adapter id");
+ }
+ }
+
+ server_task task;
+ task.type = SERVER_TASK_TYPE_SET_LORA;
+ const int id_task = ctx_server.queue_tasks.post(task);
+ ctx_server.queue_results.add_waiting_task_id(id_task);
+
+ server_task_result result = ctx_server.queue_results.recv(id_task);
+ ctx_server.queue_results.remove_waiting_task_id(id_task);
+
+ res_ok(res, result.data);
+ res.status = 200; // HTTP OK
  };
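
The two LoRA handlers above expose GET /lora-adapters, which lists {id, path, scale} per loaded adapter, and POST /lora-adapters, which takes an array of {id, scale} entries; any adapter not listed is reset to scale 0.0 and an out-of-range id is rejected (the handler throws). A sketch of both JSON shapes, assuming nlohmann/json; the paths and values are placeholders.

```cpp
#include <nlohmann/json.hpp>
using json = nlohmann::json;

int main() {
    // Body accepted by POST /lora-adapters: adapters not listed here are reset
    // to scale 0.0f by the handler above.
    json apply_body = json::array({
        json{{"id", 0}, {"scale", 0.75f}},
    });

    // Illustrative shape returned by GET /lora-adapters.
    json list_response = json::array({
        json{{"id", 0}, {"path", "/path/to/adapter-a.gguf"}, {"scale", 0.75f}},
        json{{"id", 1}, {"path", "/path/to/adapter-b.gguf"}, {"scale", 0.0f}},
    });
    (void) apply_body; (void) list_response;
    return 0;
}
```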

  //
@@ -3339,34 +3127,29 @@ int main(int argc, char ** argv) {
  // register static assets routes
  if (!params.public_path.empty()) {
  // Set the base directory for serving static files
- svr->set_base_dir(params.public_path);
- }
-
- // using embedded static files
- svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
- svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
- svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
- svr->Get("/json-schema-to-grammar.mjs", handle_static_file(json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
-
- // add new-ui files
- svr->Get("/colorthemes.css", handle_static_file(colorthemes_css, colorthemes_css_len, "text/css; charset=utf-8"));
- svr->Get("/style.css", handle_static_file(style_css, style_css_len, "text/css; charset=utf-8"));
- svr->Get("/theme-beeninorder.css", handle_static_file(theme_beeninorder_css, theme_beeninorder_css_len, "text/css; charset=utf-8"));
- svr->Get("/theme-ketivah.css", handle_static_file(theme_ketivah_css, theme_ketivah_css_len, "text/css; charset=utf-8"));
- svr->Get("/theme-mangotango.css", handle_static_file(theme_mangotango_css, theme_mangotango_css_len, "text/css; charset=utf-8"));
- svr->Get("/theme-playground.css", handle_static_file(theme_playground_css, theme_playground_css_len, "text/css; charset=utf-8"));
- svr->Get("/theme-polarnight.css", handle_static_file(theme_polarnight_css, theme_polarnight_css_len, "text/css; charset=utf-8"));
- svr->Get("/theme-snowstorm.css", handle_static_file(theme_snowstorm_css, theme_snowstorm_css_len, "text/css; charset=utf-8"));
- svr->Get("/index-new.html", handle_static_file(index_new_html, index_new_html_len, "text/html; charset=utf-8"));
- svr->Get("/system-prompts.js", handle_static_file(system_prompts_js, system_prompts_js_len, "text/javascript; charset=utf-8"));
- svr->Get("/prompt-formats.js", handle_static_file(prompt_formats_js, prompt_formats_js_len, "text/javascript; charset=utf-8"));
+ bool is_found = svr->set_mount_point("/", params.public_path);
+ if (!is_found) {
+ LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
+ return 1;
+ }
+ } else {
+ // using embedded static files
+ for (const auto & it : static_files) {
+ const server_static_file & static_file = it.second;
+ svr->Get(it.first.c_str(), [&static_file](const httplib::Request &, httplib::Response & res) {
+ res.set_content(reinterpret_cast<const char*>(static_file.data), static_file.size, static_file.mime_type);
+ return false;
+ });
+ }
+ }

  // register API routes
- svr->Get ("/health", handle_health);
- svr->Get ("/slots", handle_slots);
+ svr->Get ("/health", handle_health); // public endpoint (no API key check)
  svr->Get ("/metrics", handle_metrics);
  svr->Get ("/props", handle_props);
- svr->Get ("/v1/models", handle_models);
+ svr->Post("/props", handle_props_change);
+ svr->Get ("/models", handle_models); // public endpoint (no API key check)
+ svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
  svr->Post("/completion", handle_completions); // legacy
  svr->Post("/completions", handle_completions);
  svr->Post("/v1/completions", handle_completions);
@@ -3376,12 +3159,18 @@ int main(int argc, char ** argv) {
  svr->Post("/embedding", handle_embeddings); // legacy
  svr->Post("/embeddings", handle_embeddings);
  svr->Post("/v1/embeddings", handle_embeddings);
+ svr->Post("/rerank", handle_rerank);
+ svr->Post("/reranking", handle_rerank);
+ svr->Post("/v1/rerank", handle_rerank);
+ svr->Post("/v1/reranking", handle_rerank);
  svr->Post("/tokenize", handle_tokenize);
  svr->Post("/detokenize", handle_detokenize);
- if (!params.slot_save_path.empty()) {
- // only enable slot endpoints if slot_save_path is set
- svr->Post("/slots/:id_slot", handle_slots_action);
- }
+ // LoRA adapters hotswap
+ svr->Get ("/lora-adapters", handle_lora_adapters_list);
+ svr->Post("/lora-adapters", handle_lora_adapters_apply);
+ // Save & load slots
+ svr->Get ("/slots", handle_slots);
+ svr->Post("/slots/:id_slot", handle_slots_action);

  //
  // Start the server
@@ -3393,36 +3182,67 @@ int main(int argc, char ** argv) {
  log_data["n_threads_http"] = std::to_string(params.n_threads_http);
  svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };

- LOG_INFO("HTTP server listening", log_data);
+ // clean up function, to be called before exit
+ auto clean_up = [&svr]() {
+ svr->stop();
+ llama_backend_free();
+ };
+
+ // bind HTTP listen port, run the HTTP server in a thread
+ if (!svr->bind_to_port(params.hostname, params.port)) {
+ //LOG_ERROR("couldn't bind HTTP server socket", {
+ // {"hostname", params.hostname},
+ // {"port", params.port},
+ //});
+ LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
+ clean_up();
+ return 1;
+ }
+ std::thread t([&]() { svr->listen_after_bind(); });
+ svr->wait_until_ready();
+
+ LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http);

- // run the HTTP server in a thread - see comment below
- std::thread t([&]() {
- if (!svr->listen_after_bind()) {
- state.store(SERVER_STATE_ERROR);
- return 1;
+ // load the model
+ LOG_INF("%s: loading model\n", __func__);
+
+ if (!ctx_server.load_model(params)) {
+ clean_up();
+ t.join();
+ LOG_ERR("%s: exiting due to model loading error\n", __func__);
+ return 1;
+ }
+
+ ctx_server.init();
+ state.store(SERVER_STATE_READY);
+
+ LOG_INF("%s: model loaded\n", __func__);
+
+ // if a custom chat template is not supplied, we will use the one that comes with the model (if any)
+ if (params.chat_template.empty()) {
+ if (!ctx_server.validate_model_chat_template()) {
+ LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
+ params.chat_template = "chatml";
  }
+ }

- return 0;
- });
+ // print sample chat example to make it clear which template is used
+ LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());

  ctx_server.queue_tasks.on_new_task(std::bind(
- &server_context::process_single_task, &ctx_server, std::placeholders::_1));
- ctx_server.queue_tasks.on_finish_multitask(std::bind(
- &server_context::on_finish_multitask, &ctx_server, std::placeholders::_1));
+ &server_context::process_single_task, &ctx_server, std::placeholders::_1));
+
  ctx_server.queue_tasks.on_update_slots(std::bind(
- &server_context::update_slots, &ctx_server));
- ctx_server.queue_results.on_multitask_update(std::bind(
- &server_queue::update_multitask,
- &ctx_server.queue_tasks,
- std::placeholders::_1,
- std::placeholders::_2,
- std::placeholders::_3
- ));
+ &server_context::update_slots, &ctx_server));

  shutdown_handler = [&](int) {
  ctx_server.queue_tasks.terminate();
  };

+ LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
+
+ ctx_server.queue_tasks.start_loop();
+
  #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
  struct sigaction sigint_action;
  sigint_action.sa_handler = signal_handler;
@@ -3437,12 +3257,8 @@ int main(int argc, char ** argv) {
  SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
  #endif

- ctx_server.queue_tasks.start_loop();
-
- svr->stop();
+ clean_up();
  t.join();

- llama_backend_free();
-
  return 0;
  }
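
The final hunks above reorder server startup: the HTTP port is bound and the listener thread started before the model is loaded, the main thread then runs the task loop, and a shared clean_up() stops the server and frees the backend before the listener thread is joined. The skeleton below condenses that ordering with placeholder functions; it is illustrative only, and none of the stand-in names exist in the diff.

```cpp
#include <thread>

// Placeholder stand-ins for the real server pieces; this only illustrates the
// startup/shutdown order introduced by the diff above.
static bool bind_http_port()   { return true; }
static void serve_http()       { /* listen_after_bind() */ }
static bool load_model()       { return true; }
static void run_task_loop()    { /* queue_tasks.start_loop(); returns after terminate() */ }
static void stop_http_server() { /* svr->stop() */ }
static void free_backend()     { /* llama_backend_free() */ }

int main() {
    auto clean_up = [] {
        stop_http_server();
        free_backend();
    };

    if (!bind_http_port()) {        // 1. bind first, so failures are reported early
        clean_up();
        return 1;
    }
    std::thread t(serve_http);       // 2. HTTP thread starts before the model loads

    if (!load_model()) {             // 3. model loading happens on the main thread
        clean_up();
        t.join();
        return 1;
    }

    run_task_loop();                 // 4. main thread blocks in the task loop
    clean_up();                      // 5. on shutdown: stop server, free backend
    t.join();
    return 0;
}
```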