@fugood/llama.node 0.3.2 → 0.3.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (286) hide show
  1. package/CMakeLists.txt +7 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/lib/binding.ts +18 -1
  17. package/package.json +1 -1
  18. package/src/DetokenizeWorker.cpp +1 -1
  19. package/src/EmbeddingWorker.cpp +17 -7
  20. package/src/EmbeddingWorker.h +2 -1
  21. package/src/LlamaCompletionWorker.cpp +8 -8
  22. package/src/LlamaCompletionWorker.h +2 -2
  23. package/src/LlamaContext.cpp +89 -27
  24. package/src/LlamaContext.h +2 -0
  25. package/src/TokenizeWorker.cpp +1 -1
  26. package/src/common.hpp +4 -4
  27. package/src/llama.cpp/.github/workflows/build.yml +240 -168
  28. package/src/llama.cpp/.github/workflows/docker.yml +8 -8
  29. package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
  30. package/src/llama.cpp/.github/workflows/server.yml +21 -14
  31. package/src/llama.cpp/CMakeLists.txt +14 -6
  32. package/src/llama.cpp/Sources/llama/llama.h +4 -0
  33. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  34. package/src/llama.cpp/cmake/common.cmake +33 -0
  35. package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
  36. package/src/llama.cpp/common/CMakeLists.txt +6 -4
  37. package/src/llama.cpp/common/arg.cpp +986 -770
  38. package/src/llama.cpp/common/arg.h +22 -22
  39. package/src/llama.cpp/common/common.cpp +212 -351
  40. package/src/llama.cpp/common/common.h +204 -117
  41. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  42. package/src/llama.cpp/common/log.cpp +50 -50
  43. package/src/llama.cpp/common/log.h +18 -18
  44. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  45. package/src/llama.cpp/common/ngram-cache.h +19 -19
  46. package/src/llama.cpp/common/sampling.cpp +163 -121
  47. package/src/llama.cpp/common/sampling.h +41 -20
  48. package/src/llama.cpp/common/speculative.cpp +274 -0
  49. package/src/llama.cpp/common/speculative.h +28 -0
  50. package/src/llama.cpp/docs/build.md +134 -161
  51. package/src/llama.cpp/examples/CMakeLists.txt +33 -14
  52. package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
  53. package/src/llama.cpp/examples/batched/batched.cpp +19 -18
  54. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
  55. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  56. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
  57. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  58. package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
  59. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  60. package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
  61. package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
  62. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  63. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
  64. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  65. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
  66. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  67. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
  68. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
  69. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
  70. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  71. package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
  72. package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
  73. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
  74. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
  75. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
  76. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  77. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
  78. package/src/llama.cpp/examples/imatrix/imatrix.cpp +31 -13
  79. package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
  80. package/src/llama.cpp/examples/infill/infill.cpp +41 -87
  81. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
  82. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +439 -459
  83. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +2 -0
  84. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  85. package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
  86. package/src/llama.cpp/examples/llava/clip.cpp +263 -66
  87. package/src/llama.cpp/examples/llava/clip.h +8 -2
  88. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  89. package/src/llama.cpp/examples/llava/llava.cpp +83 -22
  90. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  91. package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
  92. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
  93. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  94. package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
  95. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  96. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  97. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +16 -15
  98. package/src/llama.cpp/examples/lookup/lookup.cpp +30 -30
  99. package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
  100. package/src/llama.cpp/examples/main/main.cpp +73 -114
  101. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
  102. package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
  103. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  104. package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
  105. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  106. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
  107. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  108. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  109. package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
  110. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
  111. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  112. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
  113. package/src/llama.cpp/examples/retrieval/retrieval.cpp +16 -16
  114. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  115. package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
  116. package/src/llama.cpp/examples/run/run.cpp +911 -0
  117. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
  118. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +38 -21
  119. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -16
  120. package/src/llama.cpp/examples/server/server.cpp +2073 -1339
  121. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
  122. package/src/llama.cpp/examples/server/utils.hpp +354 -277
  123. package/src/llama.cpp/examples/simple/CMakeLists.txt +2 -2
  124. package/src/llama.cpp/examples/simple/simple.cpp +130 -94
  125. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  126. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +200 -0
  127. package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
  128. package/src/llama.cpp/examples/speculative/speculative.cpp +68 -64
  129. package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
  130. package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
  131. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
  132. package/src/llama.cpp/examples/tokenize/tokenize.cpp +3 -3
  133. package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
  134. package/src/llama.cpp/examples/tts/tts.cpp +932 -0
  135. package/src/llama.cpp/ggml/CMakeLists.txt +54 -36
  136. package/src/llama.cpp/ggml/include/ggml-backend.h +63 -34
  137. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  138. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  139. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  140. package/src/llama.cpp/ggml/include/ggml-cpu.h +135 -0
  141. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  142. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  143. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  144. package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
  145. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  146. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  147. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  148. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  149. package/src/llama.cpp/ggml/include/ggml.h +159 -417
  150. package/src/llama.cpp/ggml/src/CMakeLists.txt +121 -1155
  151. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -28
  152. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +57 -36
  153. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +552 -0
  154. package/src/llama.cpp/ggml/src/ggml-backend.cpp +306 -867
  155. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +87 -0
  156. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +216 -65
  157. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +76 -0
  158. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
  159. package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
  160. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +343 -177
  161. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
  162. package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
  163. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
  164. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
  165. package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
  166. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
  167. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
  168. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
  169. package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
  170. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +336 -0
  171. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
  172. package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
  173. package/src/llama.cpp/ggml/src/ggml-cpu/amx/common.h +91 -0
  174. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.cpp +2511 -0
  175. package/src/llama.cpp/ggml/src/ggml-cpu/amx/mmq.h +10 -0
  176. package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
  177. package/src/llama.cpp/ggml/src/{ggml-aarch64.c → ggml-cpu/ggml-cpu-aarch64.cpp} +1299 -246
  178. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +8 -0
  179. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
  180. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
  181. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +14 -242
  182. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10835 -0
  183. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  184. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
  185. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
  186. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +14123 -0
  187. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +628 -0
  188. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +666 -0
  189. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +152 -0
  190. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
  191. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +104 -0
  192. package/src/llama.cpp/ggml/src/ggml-impl.h +393 -22
  193. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +166 -0
  194. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +360 -127
  195. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +105 -0
  196. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +288 -0
  197. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +107 -0
  198. package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
  199. package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
  200. package/src/llama.cpp/ggml/src/ggml-opt.cpp +854 -0
  201. package/src/llama.cpp/ggml/src/ggml-quants.c +188 -10702
  202. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  203. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +9 -0
  204. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +478 -300
  205. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +84 -0
  206. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  207. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +36 -5
  208. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +259 -0
  209. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +3 -2
  210. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
  211. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
  212. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +34 -35
  213. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1030 -0
  214. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  215. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
  216. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3638 -4151
  217. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
  218. package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
  219. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +75 -87
  220. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +7 -6
  221. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  222. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  223. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  224. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
  225. package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
  226. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
  227. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  228. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +141 -0
  229. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  230. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  231. package/src/llama.cpp/ggml/src/ggml-threading.h +14 -0
  232. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +92 -0
  233. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2138 -887
  234. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +3 -1
  235. package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +593 -0
  236. package/src/llama.cpp/ggml/src/ggml.c +4427 -20125
  237. package/src/llama.cpp/include/llama-cpp.h +25 -0
  238. package/src/llama.cpp/include/llama.h +93 -52
  239. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
  240. package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
  241. package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
  242. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
  243. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  244. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  245. package/src/llama.cpp/src/CMakeLists.txt +4 -8
  246. package/src/llama.cpp/src/llama-grammar.cpp +15 -15
  247. package/src/llama.cpp/src/llama-grammar.h +2 -5
  248. package/src/llama.cpp/src/llama-sampling.cpp +779 -194
  249. package/src/llama.cpp/src/llama-sampling.h +21 -2
  250. package/src/llama.cpp/src/llama-vocab.cpp +55 -10
  251. package/src/llama.cpp/src/llama-vocab.h +35 -11
  252. package/src/llama.cpp/src/llama.cpp +4317 -2979
  253. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  254. package/src/llama.cpp/src/unicode.cpp +62 -51
  255. package/src/llama.cpp/src/unicode.h +9 -10
  256. package/src/llama.cpp/tests/CMakeLists.txt +48 -38
  257. package/src/llama.cpp/tests/test-arg-parser.cpp +15 -15
  258. package/src/llama.cpp/tests/test-backend-ops.cpp +324 -80
  259. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  260. package/src/llama.cpp/tests/test-chat-template.cpp +59 -9
  261. package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
  262. package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
  263. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  264. package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
  265. package/src/llama.cpp/tests/test-log.cpp +2 -2
  266. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  267. package/src/llama.cpp/tests/test-quantize-fns.cpp +24 -21
  268. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  269. package/src/llama.cpp/tests/test-rope.cpp +62 -20
  270. package/src/llama.cpp/tests/test-sampling.cpp +163 -138
  271. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  272. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  273. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  274. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
  275. package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
  276. package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
  277. package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
  278. package/src/llama.cpp/common/train.cpp +0 -1515
  279. package/src/llama.cpp/common/train.h +0 -233
  280. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  281. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  282. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -39
  283. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +0 -600
  284. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  285. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  286. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
@@ -2,10 +2,11 @@
2
2
 
3
3
  #include "arg.h"
4
4
  #include "common.h"
5
- #include "log.h"
6
- #include "sampling.h"
7
5
  #include "json-schema-to-grammar.h"
8
6
  #include "llama.h"
7
+ #include "log.h"
8
+ #include "sampling.h"
9
+ #include "speculative.h"
9
10
 
10
11
  // Change JSON_ASSERT from assert() to GGML_ASSERT:
11
12
  #define JSON_ASSERT GGML_ASSERT
@@ -14,21 +15,7 @@
14
15
  #define MIMETYPE_JSON "application/json; charset=utf-8"
15
16
 
16
17
  // auto generated files (update with ./deps.sh)
17
- #include "colorthemes.css.hpp"
18
- #include "style.css.hpp"
19
- #include "theme-beeninorder.css.hpp"
20
- #include "theme-ketivah.css.hpp"
21
- #include "theme-mangotango.css.hpp"
22
- #include "theme-playground.css.hpp"
23
- #include "theme-polarnight.css.hpp"
24
- #include "theme-snowstorm.css.hpp"
25
- #include "index.html.hpp"
26
- #include "index-new.html.hpp"
27
- #include "index.js.hpp"
28
- #include "completion.js.hpp"
29
- #include "system-prompts.js.hpp"
30
- #include "prompt-formats.js.hpp"
31
- #include "json-schema-to-grammar.mjs.hpp"
18
+ #include "index.html.gz.hpp"
32
19
  #include "loading.html.hpp"
33
20
 
34
21
  #include <atomic>
@@ -43,31 +30,19 @@
43
30
  #include <unordered_map>
44
31
  #include <unordered_set>
45
32
 
46
- #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
47
- #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
48
- #define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
49
- #define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
50
-
51
- #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
52
- #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
53
- #define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
54
- #define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
55
-
56
- #define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
57
- #define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
58
- #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
59
- #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
60
-
61
33
  using json = nlohmann::ordered_json;
62
34
 
63
35
  enum stop_type {
64
- STOP_TYPE_FULL,
65
- STOP_TYPE_PARTIAL,
36
+ STOP_TYPE_NONE,
37
+ STOP_TYPE_EOS,
38
+ STOP_TYPE_WORD,
39
+ STOP_TYPE_LIMIT,
66
40
  };
67
41
 
68
42
  // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
69
43
  enum slot_state {
70
44
  SLOT_STATE_IDLE,
45
+ SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
71
46
  SLOT_STATE_PROCESSING_PROMPT,
72
47
  SLOT_STATE_DONE_PROMPT,
73
48
  SLOT_STATE_GENERATING,
@@ -80,6 +55,9 @@ enum server_state {
80
55
 
81
56
  enum server_task_type {
82
57
  SERVER_TASK_TYPE_COMPLETION,
58
+ SERVER_TASK_TYPE_EMBEDDING,
59
+ SERVER_TASK_TYPE_RERANK,
60
+ SERVER_TASK_TYPE_INFILL,
83
61
  SERVER_TASK_TYPE_CANCEL,
84
62
  SERVER_TASK_TYPE_NEXT_RESPONSE,
85
63
  SERVER_TASK_TYPE_METRICS,
@@ -89,21 +67,309 @@ enum server_task_type {
89
67
  SERVER_TASK_TYPE_SET_LORA,
90
68
  };
91
69
 
92
- enum server_task_cmpl_type {
93
- SERVER_TASK_CMPL_TYPE_NORMAL,
94
- SERVER_TASK_CMPL_TYPE_EMBEDDING,
95
- SERVER_TASK_CMPL_TYPE_RERANK,
96
- SERVER_TASK_CMPL_TYPE_INFILL,
70
+ // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
71
+ enum error_type {
72
+ ERROR_TYPE_INVALID_REQUEST,
73
+ ERROR_TYPE_AUTHENTICATION,
74
+ ERROR_TYPE_SERVER,
75
+ ERROR_TYPE_NOT_FOUND,
76
+ ERROR_TYPE_PERMISSION,
77
+ ERROR_TYPE_UNAVAILABLE, // custom error
78
+ ERROR_TYPE_NOT_SUPPORTED, // custom error
79
+ };
80
+
81
+ struct slot_params {
82
+ bool stream = true;
83
+ bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt
84
+ bool return_tokens = false;
85
+
86
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
87
+ int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
88
+ int32_t n_predict = -1; // new tokens to predict
89
+ int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters
90
+
91
+ int64_t t_max_prompt_ms = -1; // TODO: implement
92
+ int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
93
+
94
+ std::vector<std::string> antiprompt;
95
+ bool timings_per_token = false;
96
+ bool post_sampling_probs = false;
97
+ bool ignore_eos = false;
98
+
99
+ struct common_params_sampling sampling;
100
+ struct common_params_speculative speculative;
101
+
102
+ // OAI-compat fields
103
+ bool verbose = false;
104
+ bool oaicompat = false;
105
+ bool oaicompat_chat = true;
106
+ std::string oaicompat_model;
107
+ std::string oaicompat_cmpl_id;
108
+
109
+ json to_json() const {
110
+ std::vector<std::string> samplers;
111
+ samplers.reserve(sampling.samplers.size());
112
+ for (const auto & sampler : sampling.samplers) {
113
+ samplers.emplace_back(common_sampler_type_to_str(sampler));
114
+ }
115
+
116
+ return json {
117
+ {"n_predict", n_predict}, // Server configured n_predict
118
+ {"seed", sampling.seed},
119
+ {"temperature", sampling.temp},
120
+ {"dynatemp_range", sampling.dynatemp_range},
121
+ {"dynatemp_exponent", sampling.dynatemp_exponent},
122
+ {"top_k", sampling.top_k},
123
+ {"top_p", sampling.top_p},
124
+ {"min_p", sampling.min_p},
125
+ {"xtc_probability", sampling.xtc_probability},
126
+ {"xtc_threshold", sampling.xtc_threshold},
127
+ {"typical_p", sampling.typ_p},
128
+ {"repeat_last_n", sampling.penalty_last_n},
129
+ {"repeat_penalty", sampling.penalty_repeat},
130
+ {"presence_penalty", sampling.penalty_present},
131
+ {"frequency_penalty", sampling.penalty_freq},
132
+ {"dry_multiplier", sampling.dry_multiplier},
133
+ {"dry_base", sampling.dry_base},
134
+ {"dry_allowed_length", sampling.dry_allowed_length},
135
+ {"dry_penalty_last_n", sampling.dry_penalty_last_n},
136
+ {"dry_sequence_breakers", sampling.dry_sequence_breakers},
137
+ {"mirostat", sampling.mirostat},
138
+ {"mirostat_tau", sampling.mirostat_tau},
139
+ {"mirostat_eta", sampling.mirostat_eta},
140
+ {"stop", antiprompt},
141
+ {"max_tokens", n_predict}, // User configured n_predict
142
+ {"n_keep", n_keep},
143
+ {"n_discard", n_discard},
144
+ {"ignore_eos", sampling.ignore_eos},
145
+ {"stream", stream},
146
+ {"logit_bias", format_logit_bias(sampling.logit_bias)},
147
+ {"n_probs", sampling.n_probs},
148
+ {"min_keep", sampling.min_keep},
149
+ {"grammar", sampling.grammar},
150
+ {"samplers", samplers},
151
+ {"speculative.n_max", speculative.n_max},
152
+ {"speculative.n_min", speculative.n_min},
153
+ {"speculative.p_min", speculative.p_min},
154
+ {"timings_per_token", timings_per_token},
155
+ {"post_sampling_probs", post_sampling_probs},
156
+ };
157
+ }
97
158
  };
98
159
 
99
160
  struct server_task {
100
- int id = -1; // to be filled by server_queue
101
- int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL
161
+ int id = -1; // to be filled by server_queue
162
+ int index = -1; // used when there are multiple prompts (batch request)
102
163
 
103
164
  server_task_type type;
104
- json data;
105
165
 
106
- server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
166
+ // used by SERVER_TASK_TYPE_CANCEL
167
+ int id_target = -1;
168
+
169
+ // used by SERVER_TASK_TYPE_INFERENCE
170
+ slot_params params;
171
+ llama_tokens prompt_tokens;
172
+ int id_selected_slot = -1;
173
+
174
+ // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE
175
+ struct slot_action {
176
+ int slot_id;
177
+ std::string filename;
178
+ std::string filepath;
179
+ };
180
+ slot_action slot_action;
181
+
182
+ // used by SERVER_TASK_TYPE_METRICS
183
+ bool metrics_reset_bucket = false;
184
+
185
+ server_task(server_task_type type) : type(type) {}
186
+
187
+ static slot_params params_from_json_cmpl(
188
+ const llama_model * model,
189
+ const llama_context * ctx,
190
+ const common_params & params_base,
191
+ const json & data) {
192
+ slot_params params;
193
+
194
+ // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
195
+ slot_params defaults;
196
+ defaults.sampling = params_base.sampling;
197
+ defaults.speculative = params_base.speculative;
198
+
199
+ // enabling this will output extra debug information in the HTTP responses from the server
200
+ params.verbose = params_base.verbosity > 9;
201
+ params.timings_per_token = json_value(data, "timings_per_token", false);
202
+
203
+ params.stream = json_value(data, "stream", false);
204
+ params.cache_prompt = json_value(data, "cache_prompt", true);
205
+ params.return_tokens = json_value(data, "return_tokens", false);
206
+ params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict));
207
+ params.n_indent = json_value(data, "n_indent", defaults.n_indent);
208
+ params.n_keep = json_value(data, "n_keep", defaults.n_keep);
209
+ params.n_discard = json_value(data, "n_discard", defaults.n_discard);
210
+ //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
211
+ params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
212
+
213
+ params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
214
+ params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
215
+ params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
216
+ params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability);
217
+ params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold);
218
+ params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p);
219
+ params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp);
220
+ params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range);
221
+ params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent);
222
+ params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n);
223
+ params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat);
224
+ params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq);
225
+ params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present);
226
+ params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier);
227
+ params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base);
228
+ params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length);
229
+ params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n);
230
+ params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat);
231
+ params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau);
232
+ params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta);
233
+ params.sampling.seed = json_value(data, "seed", defaults.sampling.seed);
234
+ params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs);
235
+ params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep);
236
+ params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs);
237
+
238
+ params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min);
239
+ params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max);
240
+ params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min);
241
+
242
+ params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min);
243
+ params.speculative.n_min = std::max(params.speculative.n_min, 2);
244
+ params.speculative.n_max = std::max(params.speculative.n_max, 0);
245
+
246
+ // TODO: add more sanity checks for the input parameters
247
+
248
+ if (params.sampling.penalty_last_n < -1) {
249
+ throw std::runtime_error("Error: repeat_last_n must be >= -1");
250
+ }
251
+
252
+ if (params.sampling.dry_penalty_last_n < -1) {
253
+ throw std::runtime_error("Error: dry_penalty_last_n must be >= -1");
254
+ }
255
+
256
+ if (params.sampling.penalty_last_n == -1) {
257
+ // note: should be the slot's context and not the full context, but it's ok
258
+ params.sampling.penalty_last_n = llama_n_ctx(ctx);
259
+ }
260
+
261
+ if (params.sampling.dry_penalty_last_n == -1) {
262
+ params.sampling.dry_penalty_last_n = llama_n_ctx(ctx);
263
+ }
264
+
265
+ if (params.sampling.dry_base < 1.0f) {
266
+ params.sampling.dry_base = defaults.sampling.dry_base;
267
+ }
268
+
269
+ // sequence breakers for DRY
270
+ {
271
+ // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
272
+ // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
273
+
274
+ if (data.contains("dry_sequence_breakers")) {
275
+ params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
276
+ if (params.sampling.dry_sequence_breakers.empty()) {
277
+ throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings");
278
+ }
279
+ }
280
+ }
281
+
282
+ // process "json_schema" and "grammar"
283
+ if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
284
+ throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
285
+ }
286
+ if (data.contains("json_schema") && !data.contains("grammar")) {
287
+ try {
288
+ auto schema = json_value(data, "json_schema", json::object());
289
+ params.sampling.grammar = json_schema_to_grammar(schema);
290
+ } catch (const std::exception & e) {
291
+ throw std::runtime_error(std::string("\"json_schema\": ") + e.what());
292
+ }
293
+ } else {
294
+ params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar);
295
+ }
296
+
297
+ {
298
+ params.sampling.logit_bias.clear();
299
+ params.ignore_eos = json_value(data, "ignore_eos", false);
300
+
301
+ const auto & logit_bias = data.find("logit_bias");
302
+ if (logit_bias != data.end() && logit_bias->is_array()) {
303
+ const int n_vocab = llama_n_vocab(model);
304
+ for (const auto & el : *logit_bias) {
305
+ // TODO: we may want to throw errors here, in case "el" is incorrect
306
+ if (el.is_array() && el.size() == 2) {
307
+ float bias;
308
+ if (el[1].is_number()) {
309
+ bias = el[1].get<float>();
310
+ } else if (el[1].is_boolean() && !el[1].get<bool>()) {
311
+ bias = -INFINITY;
312
+ } else {
313
+ continue;
314
+ }
315
+
316
+ if (el[0].is_number_integer()) {
317
+ llama_token tok = el[0].get<llama_token>();
318
+ if (tok >= 0 && tok < n_vocab) {
319
+ params.sampling.logit_bias.push_back({tok, bias});
320
+ }
321
+ } else if (el[0].is_string()) {
322
+ auto toks = common_tokenize(model, el[0].get<std::string>(), false);
323
+ for (auto tok : toks) {
324
+ params.sampling.logit_bias.push_back({tok, bias});
325
+ }
326
+ }
327
+ }
328
+ }
329
+ }
330
+ }
331
+
332
+ {
333
+ params.antiprompt.clear();
334
+
335
+ const auto & stop = data.find("stop");
336
+ if (stop != data.end() && stop->is_array()) {
337
+ for (const auto & word : *stop) {
338
+ if (!word.empty()) {
339
+ params.antiprompt.push_back(word);
340
+ }
341
+ }
342
+ }
343
+ }
344
+
345
+ {
346
+ const auto & samplers = data.find("samplers");
347
+ if (samplers != data.end()) {
348
+ if (samplers->is_array()) {
349
+ std::vector<std::string> sampler_names;
350
+ for (const auto & name : *samplers) {
351
+ if (name.is_string()) {
352
+ sampler_names.emplace_back(name);
353
+ }
354
+ }
355
+ params.sampling.samplers = common_sampler_types_from_names(sampler_names, false);
356
+ } else if (samplers->is_string()){
357
+ std::string sampler_string;
358
+ for (const auto & name : *samplers) {
359
+ sampler_string += name;
360
+ }
361
+ params.sampling.samplers = common_sampler_types_from_chars(sampler_string);
362
+ }
363
+ } else {
364
+ params.sampling.samplers = defaults.sampling.samplers;
365
+ }
366
+ }
367
+
368
+ std::string model_name = params_base.model_alias.empty() ? DEFAULT_OAICOMPAT_MODEL : params_base.model_alias;
369
+ params.oaicompat_model = json_value(data, "model", model_name);
370
+
371
+ return params;
372
+ }
107
373
 
108
374
  // utility function
109
375
  static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -115,33 +381,628 @@ struct server_task {
115
381
  }
116
382
  };
117
383
 
384
+ struct result_timings {
385
+ int32_t prompt_n = -1;
386
+ double prompt_ms;
387
+ double prompt_per_token_ms;
388
+ double prompt_per_second;
389
+
390
+ int32_t predicted_n = -1;
391
+ double predicted_ms;
392
+ double predicted_per_token_ms;
393
+ double predicted_per_second;
394
+
395
+ json to_json() const {
396
+ return {
397
+ {"prompt_n", prompt_n},
398
+ {"prompt_ms", prompt_ms},
399
+ {"prompt_per_token_ms", prompt_per_token_ms},
400
+ {"prompt_per_second", prompt_per_second},
401
+
402
+ {"predicted_n", predicted_n},
403
+ {"predicted_ms", predicted_ms},
404
+ {"predicted_per_token_ms", predicted_per_token_ms},
405
+ {"predicted_per_second", predicted_per_second},
406
+ };
407
+ }
408
+ };
409
+
118
410
  struct server_task_result {
119
- int id = -1;
411
+ int id = -1;
412
+ int id_slot = -1;
413
+ virtual bool is_error() {
414
+ // only used by server_task_result_error
415
+ return false;
416
+ }
417
+ virtual bool is_stop() {
418
+ // only used by server_task_result_cmpl_*
419
+ return false;
420
+ }
421
+ virtual int get_index() {
422
+ return -1;
423
+ }
424
+ virtual json to_json() = 0;
425
+ virtual ~server_task_result() = default;
426
+ };
120
427
 
121
- json data;
428
+ // using shared_ptr for polymorphism of server_task_result
429
+ using server_task_result_ptr = std::unique_ptr<server_task_result>;
122
430
 
123
- bool stop;
124
- bool error;
431
+ inline std::string stop_type_to_str(stop_type type) {
432
+ switch (type) {
433
+ case STOP_TYPE_EOS: return "eos";
434
+ case STOP_TYPE_WORD: return "word";
435
+ case STOP_TYPE_LIMIT: return "limit";
436
+ default: return "none";
437
+ }
438
+ }
439
+
440
+ struct completion_token_output {
441
+ llama_token tok;
442
+ float prob;
443
+ std::string text_to_send;
444
+ struct prob_info {
445
+ llama_token tok;
446
+ std::string txt;
447
+ float prob;
448
+ };
449
+ std::vector<prob_info> probs;
450
+
451
+ json to_json(bool post_sampling_probs) const {
452
+ json probs_for_token = json::array();
453
+ for (const auto & p : probs) {
454
+ std::string txt(p.txt);
455
+ txt.resize(validate_utf8(txt));
456
+ probs_for_token.push_back(json {
457
+ {"id", p.tok},
458
+ {"token", txt},
459
+ {"bytes", str_to_bytes(p.txt)},
460
+ {
461
+ post_sampling_probs ? "prob" : "logprob",
462
+ post_sampling_probs ? p.prob : logarithm(p.prob)
463
+ },
464
+ });
465
+ }
466
+ return probs_for_token;
467
+ }
468
+
469
+ static json probs_vector_to_json(const std::vector<completion_token_output> & probs, bool post_sampling_probs) {
470
+ json out = json::array();
471
+ for (const auto & p : probs) {
472
+ std::string txt(p.text_to_send);
473
+ txt.resize(validate_utf8(txt));
474
+ out.push_back(json {
475
+ {"id", p.tok},
476
+ {"token", txt},
477
+ {"bytes", str_to_bytes(p.text_to_send)},
478
+ {
479
+ post_sampling_probs ? "prob" : "logprob",
480
+ post_sampling_probs ? p.prob : logarithm(p.prob)
481
+ },
482
+ {
483
+ post_sampling_probs ? "top_probs" : "top_logprobs",
484
+ p.to_json(post_sampling_probs)
485
+ },
486
+ });
487
+ }
488
+ return out;
489
+ }
490
+
491
+ static float logarithm(float x) {
492
+ // nlohmann::json converts -inf to null, so we need to prevent that
493
+ return x == 0.0f ? std::numeric_limits<float>::lowest() : std::log(x);
494
+ }
495
+
496
+ static std::vector<unsigned char> str_to_bytes(const std::string & str) {
497
+ std::vector<unsigned char> bytes;
498
+ for (unsigned char c : str) {
499
+ bytes.push_back(c);
500
+ }
501
+ return bytes;
502
+ }
125
503
  };
126
504
 
127
- struct slot_params {
128
- bool stream = true;
129
- bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt
505
+ struct server_task_result_cmpl_final : server_task_result {
506
+ int index = 0;
130
507
 
131
- int32_t n_keep = 0; // number of tokens to keep from initial prompt
132
- int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
133
- int32_t n_predict = -1; // new tokens to predict
508
+ std::string content;
509
+ llama_tokens tokens;
134
510
 
135
- std::vector<std::string> antiprompt;
511
+ bool stream;
512
+ result_timings timings;
513
+ std::string prompt;
514
+
515
+ bool truncated;
516
+ int32_t n_decoded;
517
+ int32_t n_prompt_tokens;
518
+ int32_t n_tokens_cached;
519
+ bool has_new_line;
520
+ std::string stopping_word;
521
+ stop_type stop = STOP_TYPE_NONE;
136
522
 
137
- json input_prefix;
138
- json input_suffix;
523
+ bool post_sampling_probs;
524
+ std::vector<completion_token_output> probs_output;
525
+
526
+ slot_params generation_params;
527
+
528
+ // OAI-compat fields
529
+ bool verbose = false;
530
+ bool oaicompat = false;
531
+ bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
532
+ std::string oaicompat_model;
533
+ std::string oaicompat_cmpl_id;
534
+
535
+ virtual int get_index() override {
536
+ return index;
537
+ }
538
+
539
+ virtual bool is_stop() override {
540
+ return true; // in stream mode, final responses are considered stop
541
+ }
542
+
543
+ virtual json to_json() override {
544
+ return oaicompat
545
+ ? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
546
+ : to_json_non_oaicompat();
547
+ }
548
+
549
+ json to_json_non_oaicompat() {
550
+ json res = json {
551
+ {"index", index},
552
+ {"content", stream ? "" : content}, // in stream mode, content is already in last partial chunk
553
+ {"tokens", stream ? llama_tokens {} : tokens},
554
+ {"id_slot", id_slot},
555
+ {"stop", true},
556
+ {"model", oaicompat_model},
557
+ {"tokens_predicted", n_decoded},
558
+ {"tokens_evaluated", n_prompt_tokens},
559
+ {"generation_settings", generation_params.to_json()},
560
+ {"prompt", prompt},
561
+ {"has_new_line", has_new_line},
562
+ {"truncated", truncated},
563
+ {"stop_type", stop_type_to_str(stop)},
564
+ {"stopping_word", stopping_word},
565
+ {"tokens_cached", n_tokens_cached},
566
+ {"timings", timings.to_json()},
567
+ };
568
+ if (!stream && !probs_output.empty()) {
569
+ res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs);
570
+ }
571
+ return res;
572
+ }
573
+
574
+ json to_json_oaicompat_chat() {
575
+ std::string finish_reason = "length";
576
+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
577
+ finish_reason = "stop";
578
+ }
579
+
580
+ json choice = json{
581
+ {"finish_reason", finish_reason},
582
+ {"index", 0},
583
+ {"message", json {
584
+ {"content", content},
585
+ {"role", "assistant"}
586
+ }
587
+ }};
588
+
589
+ if (!stream && probs_output.size() > 0) {
590
+ choice["logprobs"] = json{
591
+ {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
592
+ };
593
+ }
594
+
595
+ std::time_t t = std::time(0);
596
+
597
+ json res = json {
598
+ {"choices", json::array({choice})},
599
+ {"created", t},
600
+ {"model", oaicompat_model},
601
+ {"object", "chat.completion"},
602
+ {"usage", json {
603
+ {"completion_tokens", n_decoded},
604
+ {"prompt_tokens", n_prompt_tokens},
605
+ {"total_tokens", n_decoded + n_prompt_tokens}
606
+ }},
607
+ {"id", oaicompat_cmpl_id}
608
+ };
609
+
610
+ // extra fields for debugging purposes
611
+ if (verbose) {
612
+ res["__verbose"] = to_json_non_oaicompat();
613
+ }
614
+ if (timings.prompt_n >= 0) {
615
+ res.push_back({"timings", timings.to_json()});
616
+ }
617
+
618
+ return res;
619
+ }
620
+
621
+ json to_json_oaicompat_chat_stream() {
622
+ std::time_t t = std::time(0);
623
+ std::string finish_reason = "length";
624
+ if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
625
+ finish_reason = "stop";
626
+ }
627
+
628
+ json choice = json{
629
+ {"finish_reason", finish_reason},
630
+ {"index", 0},
631
+ {"delta", json::object()}
632
+ };
633
+
634
+ json ret = json {
635
+ {"choices", json::array({choice})},
636
+ {"created", t},
637
+ {"id", oaicompat_cmpl_id},
638
+ {"model", oaicompat_model},
639
+ {"object", "chat.completion.chunk"},
640
+ {"usage", json {
641
+ {"completion_tokens", n_decoded},
642
+ {"prompt_tokens", n_prompt_tokens},
643
+ {"total_tokens", n_decoded + n_prompt_tokens},
644
+ }},
645
+ };
646
+
647
+ if (timings.prompt_n >= 0) {
648
+ ret.push_back({"timings", timings.to_json()});
649
+ }
650
+
651
+ return ret;
652
+ }
653
+ };
654
+
655
+ struct server_task_result_cmpl_partial : server_task_result {
656
+ int index = 0;
657
+
658
+ std::string content;
659
+ llama_tokens tokens;
660
+
661
+ int32_t n_decoded;
662
+ int32_t n_prompt_tokens;
663
+
664
+ bool post_sampling_probs;
665
+ completion_token_output prob_output;
666
+ result_timings timings;
667
+
668
+ // OAI-compat fields
669
+ bool verbose = false;
670
+ bool oaicompat = false;
671
+ bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
672
+ std::string oaicompat_model;
673
+ std::string oaicompat_cmpl_id;
674
+
675
+ virtual int get_index() override {
676
+ return index;
677
+ }
678
+
679
+ virtual bool is_stop() override {
680
+ return false; // in stream mode, partial responses are not considered stop
681
+ }
682
+
683
+ virtual json to_json() override {
684
+ return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
685
+ }
686
+
687
+ json to_json_non_oaicompat() {
688
+ // non-OAI-compat JSON
689
+ json res = json {
690
+ {"index", index},
691
+ {"content", content},
692
+ {"tokens", tokens},
693
+ {"stop", false},
694
+ {"id_slot", id_slot},
695
+ {"tokens_predicted", n_decoded},
696
+ {"tokens_evaluated", n_prompt_tokens},
697
+ };
698
+ // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
699
+ if (timings.prompt_n > 0) {
700
+ res.push_back({"timings", timings.to_json()});
701
+ }
702
+ if (!prob_output.probs.empty()) {
703
+ res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs);
704
+ }
705
+ return res;
706
+ }
707
+
708
+ json to_json_oaicompat() {
709
+ bool first = n_decoded == 0;
710
+ std::time_t t = std::time(0);
711
+ json choices;
712
+
713
+ if (first) {
714
+ if (content.empty()) {
715
+ choices = json::array({json{{"finish_reason", nullptr},
716
+ {"index", 0},
717
+ {"delta", json{{"role", "assistant"}}}}});
718
+ } else {
719
+ // We have to send this as two updates to conform to openai behavior
720
+ json initial_ret = json{{"choices", json::array({json{
721
+ {"finish_reason", nullptr},
722
+ {"index", 0},
723
+ {"delta", json{
724
+ {"role", "assistant"}
725
+ }}}})},
726
+ {"created", t},
727
+ {"id", oaicompat_cmpl_id},
728
+ {"model", oaicompat_model},
729
+ {"object", "chat.completion.chunk"}};
730
+
731
+ json second_ret = json{
732
+ {"choices", json::array({json{{"finish_reason", nullptr},
733
+ {"index", 0},
734
+ {"delta", json {
735
+ {"content", content}}}
736
+ }})},
737
+ {"created", t},
738
+ {"id", oaicompat_cmpl_id},
739
+ {"model", oaicompat_model},
740
+ {"object", "chat.completion.chunk"}};
741
+
742
+ return std::vector<json>({initial_ret, second_ret});
743
+ }
744
+ } else {
745
+ choices = json::array({json{
746
+ {"finish_reason", nullptr},
747
+ {"index", 0},
748
+ {"delta",
749
+ json {
750
+ {"content", content},
751
+ }},
752
+ }});
753
+ }
754
+
755
+ GGML_ASSERT(choices.size() >= 1);
756
+
757
+ if (prob_output.probs.size() > 0) {
758
+ choices[0]["logprobs"] = json{
759
+ {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
760
+ };
761
+ }
762
+
763
+ json ret = json {
764
+ {"choices", choices},
765
+ {"created", t},
766
+ {"id", oaicompat_cmpl_id},
767
+ {"model", oaicompat_model},
768
+ {"object", "chat.completion.chunk"}
769
+ };
770
+
771
+ if (timings.prompt_n >= 0) {
772
+ ret.push_back({"timings", timings.to_json()});
773
+ }
774
+
775
+ return std::vector<json>({ret});
776
+ }
777
+ };
778
+
779
+ struct server_task_result_embd : server_task_result {
780
+ int index = 0;
781
+ std::vector<std::vector<float>> embedding;
782
+
783
+ int32_t n_tokens;
784
+
785
+ // OAI-compat fields
786
+ bool oaicompat = false;
787
+
788
+ virtual int get_index() override {
789
+ return index;
790
+ }
791
+
792
+ virtual json to_json() override {
793
+ return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
794
+ }
795
+
796
+ json to_json_non_oaicompat() {
797
+ return json {
798
+ {"index", index},
799
+ {"embedding", embedding},
800
+ };
801
+ }
802
+
803
+ json to_json_oaicompat() {
804
+ return json {
805
+ {"index", index},
806
+ {"embedding", embedding[0]},
807
+ {"tokens_evaluated", n_tokens},
808
+ };
809
+ }
810
+ };
811
+
812
+ struct server_task_result_rerank : server_task_result {
813
+ int index = 0;
814
+ float score = -1e6;
815
+
816
+ int32_t n_tokens;
817
+
818
+ virtual int get_index() override {
819
+ return index;
820
+ }
821
+
822
+ virtual json to_json() override {
823
+ return json {
824
+ {"index", index},
825
+ {"score", score},
826
+ {"tokens_evaluated", n_tokens},
827
+ };
828
+ }
829
+ };
830
+
831
+ // this function maybe used outside of server_task_result_error
832
+ static json format_error_response(const std::string & message, const enum error_type type) {
833
+ std::string type_str;
834
+ int code = 500;
835
+ switch (type) {
836
+ case ERROR_TYPE_INVALID_REQUEST:
837
+ type_str = "invalid_request_error";
838
+ code = 400;
839
+ break;
840
+ case ERROR_TYPE_AUTHENTICATION:
841
+ type_str = "authentication_error";
842
+ code = 401;
843
+ break;
844
+ case ERROR_TYPE_NOT_FOUND:
845
+ type_str = "not_found_error";
846
+ code = 404;
847
+ break;
848
+ case ERROR_TYPE_SERVER:
849
+ type_str = "server_error";
850
+ code = 500;
851
+ break;
852
+ case ERROR_TYPE_PERMISSION:
853
+ type_str = "permission_error";
854
+ code = 403;
855
+ break;
856
+ case ERROR_TYPE_NOT_SUPPORTED:
857
+ type_str = "not_supported_error";
858
+ code = 501;
859
+ break;
860
+ case ERROR_TYPE_UNAVAILABLE:
861
+ type_str = "unavailable_error";
862
+ code = 503;
863
+ break;
864
+ }
865
+ return json {
866
+ {"code", code},
867
+ {"message", message},
868
+ {"type", type_str},
869
+ };
870
+ }
871
+
872
+ struct server_task_result_error : server_task_result {
873
+ int index = 0;
874
+ error_type err_type = ERROR_TYPE_SERVER;
875
+ std::string err_msg;
876
+
877
+ virtual bool is_error() override {
878
+ return true;
879
+ }
880
+
881
+ virtual json to_json() override {
882
+ return format_error_response(err_msg, err_type);
883
+ }
884
+ };
885
+
886
+ struct server_task_result_metrics : server_task_result {
887
+ int n_idle_slots;
888
+ int n_processing_slots;
889
+ int n_tasks_deferred;
890
+ int64_t t_start;
891
+
892
+ int32_t kv_cache_tokens_count;
893
+ int32_t kv_cache_used_cells;
894
+
895
+ // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields
896
+ uint64_t n_prompt_tokens_processed_total = 0;
897
+ uint64_t t_prompt_processing_total = 0;
898
+ uint64_t n_tokens_predicted_total = 0;
899
+ uint64_t t_tokens_generation_total = 0;
900
+
901
+ uint64_t n_prompt_tokens_processed = 0;
902
+ uint64_t t_prompt_processing = 0;
903
+
904
+ uint64_t n_tokens_predicted = 0;
905
+ uint64_t t_tokens_generation = 0;
906
+
907
+ uint64_t n_decode_total = 0;
908
+ uint64_t n_busy_slots_total = 0;
909
+
910
+ // while we can also use std::vector<server_slot> this requires copying the slot object which can be quite messy
911
+ // therefore, we use json to temporarily store the slot.to_json() result
912
+ json slots_data = json::array();
913
+
914
+ virtual json to_json() override {
915
+ return json {
916
+ { "idle", n_idle_slots },
917
+ { "processing", n_processing_slots },
918
+ { "deferred", n_tasks_deferred },
919
+ { "t_start", t_start },
920
+
921
+ { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total },
922
+ { "t_tokens_generation_total", t_tokens_generation_total },
923
+ { "n_tokens_predicted_total", n_tokens_predicted_total },
924
+ { "t_prompt_processing_total", t_prompt_processing_total },
925
+
926
+ { "n_prompt_tokens_processed", n_prompt_tokens_processed },
927
+ { "t_prompt_processing", t_prompt_processing },
928
+ { "n_tokens_predicted", n_tokens_predicted },
929
+ { "t_tokens_generation", t_tokens_generation },
930
+
931
+ { "n_decode_total", n_decode_total },
932
+ { "n_busy_slots_total", n_busy_slots_total },
933
+
934
+ { "kv_cache_tokens_count", kv_cache_tokens_count },
935
+ { "kv_cache_used_cells", kv_cache_used_cells },
936
+
937
+ { "slots", slots_data },
938
+ };
939
+ }
940
+ };
941
+
942
+ struct server_task_result_slot_save_load : server_task_result {
943
+ std::string filename;
944
+ bool is_save; // true = save, false = load
945
+
946
+ size_t n_tokens;
947
+ size_t n_bytes;
948
+ double t_ms;
949
+
950
+ virtual json to_json() override {
951
+ if (is_save) {
952
+ return json {
953
+ { "id_slot", id_slot },
954
+ { "filename", filename },
955
+ { "n_saved", n_tokens },
956
+ { "n_written", n_bytes },
957
+ { "timings", {
958
+ { "save_ms", t_ms }
959
+ }},
960
+ };
961
+ } else {
962
+ return json {
963
+ { "id_slot", id_slot },
964
+ { "filename", filename },
965
+ { "n_restored", n_tokens },
966
+ { "n_read", n_bytes },
967
+ { "timings", {
968
+ { "restore_ms", t_ms }
969
+ }},
970
+ };
971
+ }
972
+ }
973
+ };
974
+
975
+ struct server_task_result_slot_erase : server_task_result {
976
+ size_t n_erased;
977
+
978
+ virtual json to_json() override {
979
+ return json {
980
+ { "id_slot", id_slot },
981
+ { "n_erased", n_erased },
982
+ };
983
+ }
984
+ };
985
+
986
+ struct server_task_result_apply_lora : server_task_result {
987
+ virtual json to_json() override {
988
+ return json {{ "success", true }};
989
+ }
139
990
  };
140
991
 
141
992
  struct server_slot {
142
993
  int id;
143
994
  int id_task = -1;
144
995
 
996
+ // only used for completion/embedding/infill/rerank
997
+ server_task_type task_type = SERVER_TASK_TYPE_COMPLETION;
998
+
999
+ llama_batch batch_spec = {};
1000
+
1001
+ llama_context * ctx = nullptr;
1002
+ llama_context * ctx_dft = nullptr;
1003
+
1004
+ common_speculative * spec = nullptr;
1005
+
145
1006
  // the index relative to completion multi-task request
146
1007
  size_t index = 0;
147
1008
 
@@ -160,54 +1021,44 @@ struct server_slot {
160
1021
  int32_t i_batch = -1;
161
1022
  int32_t n_predict = -1; // TODO: disambiguate from params.n_predict
162
1023
 
1024
+ // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
163
1025
  int32_t n_prompt_tokens = 0;
164
1026
  int32_t n_prompt_tokens_processed = 0;
165
1027
 
166
- json prompt; // can be either a string, array of strings or array of token ids
1028
+ // input prompt tokens
1029
+ llama_tokens prompt_tokens;
167
1030
 
168
- // when a task is submitted, we first tokenize the prompt and store it here
169
- std::vector<llama_token> prompt_tokens;
1031
+ size_t last_nl_pos = 0;
170
1032
 
171
- std::string generated_text;
172
- std::vector<llama_token> cache_tokens;
173
- std::vector<completion_token_output> generated_token_probs;
1033
+ std::string generated_text;
1034
+ llama_tokens generated_tokens;
174
1035
 
175
- server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
1036
+ llama_tokens cache_tokens;
1037
+
1038
+ std::vector<completion_token_output> generated_token_probs;
176
1039
 
177
1040
  bool has_next_token = true;
1041
+ bool has_new_line = false;
178
1042
  bool truncated = false;
179
- bool stopped_eos = false;
180
- bool stopped_word = false;
181
- bool stopped_limit = false;
182
-
183
- bool oaicompat = false;
1043
+ stop_type stop;
184
1044
 
185
- std::string oaicompat_model;
186
1045
  std::string stopping_word;
187
1046
 
188
1047
  // sampling
189
1048
  json json_schema;
190
1049
 
191
- struct gpt_sampler_params sparams;
192
- struct gpt_sampler * smpl = nullptr;
1050
+ struct common_sampler * smpl = nullptr;
193
1051
 
194
1052
  llama_token sampled;
195
1053
 
196
- int32_t ga_i = 0; // group-attention state
197
- int32_t ga_n = 1; // group-attention factor
198
- int32_t ga_w = 512; // group-attention width
199
-
200
- int32_t n_past_se = 0; // self-extend
201
-
202
1054
  // stats
203
- size_t n_sent_text = 0; // number of sent text character
204
- size_t n_sent_token_probs = 0;
1055
+ size_t n_sent_text = 0; // number of sent text characters
205
1056
 
206
1057
  int64_t t_start_process_prompt;
207
1058
  int64_t t_start_generation;
208
1059
 
209
1060
  double t_prompt_processing; // ms
210
- double t_token_generation; // ms
1061
+ double t_token_generation; // ms
211
1062
 
212
1063
  std::function<void(int)> callback_on_release;
213
1064
 
@@ -215,23 +1066,25 @@ struct server_slot {
215
1066
  SLT_DBG(*this, "%s", "\n");
216
1067
 
217
1068
  n_prompt_tokens = 0;
1069
+ last_nl_pos = 0;
218
1070
  generated_text = "";
1071
+ has_new_line = false;
219
1072
  truncated = false;
220
- stopped_eos = false;
221
- stopped_word = false;
222
- stopped_limit = false;
1073
+ stop = STOP_TYPE_NONE;
223
1074
  stopping_word = "";
224
1075
  n_past = 0;
225
1076
  n_sent_text = 0;
226
- n_sent_token_probs = 0;
227
- cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
228
- ga_i = 0;
229
- n_past_se = 0;
1077
+ task_type = SERVER_TASK_TYPE_COMPLETION;
230
1078
 
1079
+ generated_tokens.clear();
231
1080
  generated_token_probs.clear();
232
1081
  }
233
1082
 
234
- bool has_budget(gpt_params &global_params) {
1083
+ bool is_non_causal() const {
1084
+ return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
1085
+ }
1086
+
1087
+ bool has_budget(const common_params & global_params) {
235
1088
  if (params.n_predict == -1 && global_params.n_predict == -1) {
236
1089
  return true; // limitless
237
1090
  }
@@ -251,6 +1104,10 @@ struct server_slot {
251
1104
  return state != SLOT_STATE_IDLE;
252
1105
  }
253
1106
 
1107
+ bool can_speculate() const {
1108
+ return ctx_dft && params.speculative.n_max > 0 && params.cache_prompt;
1109
+ }
1110
+
254
1111
  void add_token(const completion_token_output & token) {
255
1112
  if (!is_processing()) {
256
1113
  SLT_WRN(*this, "%s", "slot is not processing\n");
@@ -263,44 +1120,47 @@ struct server_slot {
263
1120
  if (is_processing()) {
264
1121
  SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);
265
1122
 
1123
+ t_last_used = ggml_time_us();
266
1124
  t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
267
1125
  state = SLOT_STATE_IDLE;
268
1126
  callback_on_release(id);
269
1127
  }
270
1128
  }
271
1129
 
272
- json get_formated_timings() const {
273
- return json {
274
- {"prompt_n", n_prompt_tokens_processed},
275
- {"prompt_ms", t_prompt_processing},
276
- {"prompt_per_token_ms", t_prompt_processing / n_prompt_tokens_processed},
277
- {"prompt_per_second", 1e3 / t_prompt_processing * n_prompt_tokens_processed},
278
-
279
- {"predicted_n", n_decoded},
280
- {"predicted_ms", t_token_generation},
281
- {"predicted_per_token_ms", t_token_generation / n_decoded},
282
- {"predicted_per_second", 1e3 / t_token_generation * n_decoded},
283
- };
1130
+ result_timings get_timings() const {
1131
+ result_timings timings;
1132
+ timings.prompt_n = n_prompt_tokens_processed;
1133
+ timings.prompt_ms = t_prompt_processing;
1134
+ timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed;
1135
+ timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed;
1136
+
1137
+ timings.predicted_n = n_decoded;
1138
+ timings.predicted_ms = t_token_generation;
1139
+ timings.predicted_per_token_ms = t_token_generation / n_decoded;
1140
+ timings.predicted_per_second = 1e3 / t_token_generation * n_decoded;
1141
+
1142
+ return timings;
284
1143
  }
285
1144
 
286
- size_t find_stopping_strings(const std::string & text, const size_t last_token_size, const stop_type type) {
1145
+ size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) {
287
1146
  size_t stop_pos = std::string::npos;
288
1147
 
289
1148
  for (const std::string & word : params.antiprompt) {
290
1149
  size_t pos;
291
1150
 
292
- if (type == STOP_TYPE_FULL) {
1151
+ if (is_full_stop) {
293
1152
  const size_t tmp = word.size() + last_token_size;
294
1153
  const size_t from_pos = text.size() > tmp ? text.size() - tmp : 0;
295
1154
 
296
1155
  pos = text.find(word, from_pos);
297
1156
  } else {
1157
+ // otherwise, partial stop
298
1158
  pos = find_partial_stop_string(word, text);
299
1159
  }
300
1160
 
301
1161
  if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) {
302
- if (type == STOP_TYPE_FULL) {
303
- stopped_word = true;
1162
+ if (is_full_stop) {
1163
+ stop = STOP_TYPE_WORD;
304
1164
  stopping_word = word;
305
1165
  has_next_token = false;
306
1166
  }
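Note: the new get_timings() helper above carries over the arithmetic of the old get_formated_timings() JSON: per-token latency is the total time divided by the token count, and tokens-per-second inverts it. A minimal standalone check of those formulas, using made-up numbers (t_prompt_ms and n_prompt are hypothetical, not measurements from this diff):

    #include <cstdio>

    int main() {
        const double t_prompt_ms = 250.0; // hypothetical prompt-processing time, in ms
        const int    n_prompt    = 100;   // hypothetical number of prompt tokens processed

        const double per_token_ms = t_prompt_ms / n_prompt;       // 2.50 ms per token
        const double per_second   = 1e3 / t_prompt_ms * n_prompt; // 400.00 tokens per second

        std::printf("%.2f ms per token, %.2f tokens per second\n", per_token_ms, per_second);
        return 0;
    }
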
@@ -320,13 +1180,35 @@ struct server_slot {
320
1180
 
321
1181
  SLT_INF(*this,
322
1182
  "\n"
323
- "\rprompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
324
- "\r eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
325
- "\r total time = %10.2f ms / %5d tokens\n",
1183
+ "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
1184
+ " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n"
1185
+ " total time = %10.2f ms / %5d tokens\n",
326
1186
  t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second,
327
1187
  t_token_generation, n_decoded, t_gen, n_gen_second,
328
1188
  t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded);
329
1189
  }
1190
+
1191
+ json to_json() const {
1192
+ return json {
1193
+ {"id", id},
1194
+ {"id_task", id_task},
1195
+ {"n_ctx", n_ctx},
1196
+ {"speculative", can_speculate()},
1197
+ {"is_processing", is_processing()},
1198
+ {"non_causal", is_non_causal()},
1199
+ {"params", params.to_json()},
1200
+ {"prompt", common_detokenize(ctx, prompt_tokens)},
1201
+ {"next_token",
1202
+ {
1203
+ {"has_next_token", has_next_token},
1204
+ {"has_new_line", has_new_line},
1205
+ {"n_remain", n_remaining},
1206
+ {"n_decoded", n_decoded},
1207
+ {"stopping_word", stopping_word},
1208
+ }
1209
+ },
1210
+ };
1211
+ }
330
1212
  };
331
1213
 
332
1214
  struct server_metrics {
@@ -393,15 +1275,13 @@ struct server_queue {
393
1275
  std::condition_variable condition_tasks;
394
1276
 
395
1277
  // callback functions
396
- std::function<void(server_task&)> callback_new_task;
397
- std::function<void(void)> callback_update_slots;
1278
+ std::function<void(server_task)> callback_new_task;
1279
+ std::function<void(void)> callback_update_slots;
398
1280
 
399
1281
  // Add a new task to the end of the queue
400
1282
  int post(server_task task, bool front = false) {
401
1283
  std::unique_lock<std::mutex> lock(mutex_tasks);
402
- if (task.id == -1) {
403
- task.id = id++;
404
- }
1284
+ GGML_ASSERT(task.id != -1);
405
1285
  QUE_DBG("new task, id = %d, front = %d\n", task.id, front);
406
1286
  if (front) {
407
1287
  queue_tasks.push_front(std::move(task));
@@ -446,7 +1326,7 @@ struct server_queue {
446
1326
  }
447
1327
 
448
1328
  // Register function to process a new task
449
- void on_new_task(std::function<void(server_task &)> callback) {
1329
+ void on_new_task(std::function<void(server_task)> callback) {
450
1330
  callback_new_task = std::move(callback);
451
1331
  }
452
1332
 
@@ -496,7 +1376,7 @@ struct server_queue {
496
1376
  lock.unlock();
497
1377
 
498
1378
  QUE_DBG("processing task, id = %d\n", task.id);
499
- callback_new_task(task);
1379
+ callback_new_task(std::move(task));
500
1380
  }
501
1381
 
502
1382
  // all tasks in the current loop are processed, slots data is now ready
@@ -525,8 +1405,8 @@ struct server_response {
525
1405
  // for keeping track of all tasks waiting for the result
526
1406
  std::unordered_set<int> waiting_task_ids;
527
1407
 
528
- // the main result queue
529
- std::vector<server_task_result> queue_results;
1408
+ // the main result queue (using ptr for polymorphism)
1409
+ std::vector<server_task_result_ptr> queue_results;
530
1410
 
531
1411
  std::mutex mutex_results;
532
1412
  std::condition_variable condition_results;
@@ -566,7 +1446,7 @@ struct server_response {
566
1446
  }
567
1447
 
568
1448
  // This function blocks the thread until there is a response for one of the id_tasks
569
- server_task_result recv(const std::unordered_set<int> & id_tasks) {
1449
+ server_task_result_ptr recv(const std::unordered_set<int> & id_tasks) {
570
1450
  while (true) {
571
1451
  std::unique_lock<std::mutex> lock(mutex_results);
572
1452
  condition_results.wait(lock, [&]{
@@ -574,8 +1454,8 @@ struct server_response {
574
1454
  });
575
1455
 
576
1456
  for (int i = 0; i < (int) queue_results.size(); i++) {
577
- if (id_tasks.find(queue_results[i].id) != id_tasks.end()) {
578
- server_task_result res = queue_results[i];
1457
+ if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
1458
+ server_task_result_ptr res = std::move(queue_results[i]);
579
1459
  queue_results.erase(queue_results.begin() + i);
580
1460
  return res;
581
1461
  }
@@ -586,21 +1466,21 @@ struct server_response {
586
1466
  }
587
1467
 
588
1468
  // single-task version of recv()
589
- server_task_result recv(int id_task) {
1469
+ server_task_result_ptr recv(int id_task) {
590
1470
  std::unordered_set<int> id_tasks = {id_task};
591
1471
  return recv(id_tasks);
592
1472
  }
593
1473
 
594
1474
  // Send a new result to a waiting id_task
595
- void send(server_task_result & result) {
596
- SRV_DBG("sending result for task id = %d\n", result.id);
1475
+ void send(server_task_result_ptr && result) {
1476
+ SRV_DBG("sending result for task id = %d\n", result->id);
597
1477
 
598
1478
  std::unique_lock<std::mutex> lock(mutex_results);
599
1479
  for (const auto & id_task : waiting_task_ids) {
600
- if (result.id == id_task) {
601
- SRV_DBG("task id = %d moved to result queue\n", result.id);
1480
+ if (result->id == id_task) {
1481
+ SRV_DBG("task id = %d pushed to result queue\n", result->id);
602
1482
 
603
- queue_results.push_back(std::move(result));
1483
+ queue_results.emplace_back(std::move(result));
604
1484
  condition_results.notify_all();
605
1485
  return;
606
1486
  }
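Note: the server_response changes above move the result queue from by-value server_task_result objects to server_task_result_ptr (a unique_ptr to a polymorphic base), transferred with std::move in send() and recv(). A minimal sketch of that ownership pattern, with simplified stand-in types (result_base and result_error are illustrative, not the server's real structs):

    #include <cstdio>
    #include <memory>
    #include <string>
    #include <utility>
    #include <vector>

    struct result_base {
        int id = -1;
        virtual ~result_base() = default;
        virtual std::string to_string() const = 0; // each subclass renders its own output
    };

    struct result_error : result_base {
        std::string msg;
        std::string to_string() const override { return "error: " + msg; }
    };

    int main() {
        std::vector<std::unique_ptr<result_base>> queue_results;

        auto res = std::make_unique<result_error>();
        res->id  = 7;
        res->msg = "failed to parse grammar";
        queue_results.emplace_back(std::move(res));          // ownership moves into the queue

        std::unique_ptr<result_base> out = std::move(queue_results.front());
        queue_results.erase(queue_results.begin());          // ownership moves back out
        std::printf("task %d -> %s\n", out->id, out->to_string().c_str());
        return 0;
    }
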
@@ -609,11 +1489,14 @@ struct server_response {
609
1489
  };
610
1490
 
611
1491
  struct server_context {
1492
+ common_params params_base;
1493
+
612
1494
  llama_model * model = nullptr;
613
1495
  llama_context * ctx = nullptr;
614
- std::vector<llama_lora_adapter_container> loras;
1496
+ std::vector<common_lora_adapter_container> loras;
615
1497
 
616
- gpt_params params;
1498
+ llama_model * model_dft = nullptr;
1499
+ llama_context_params cparams_dft;
617
1500
 
618
1501
  llama_batch batch = {};
619
1502
 
@@ -623,12 +1506,6 @@ struct server_context {
623
1506
 
624
1507
  int32_t n_ctx; // total context for all clients / slots
625
1508
 
626
- // system prompt
627
- bool system_need_update = false;
628
-
629
- std::string system_prompt;
630
- std::vector<llama_token> system_tokens;
631
-
632
1509
  // slots / clients
633
1510
  std::vector<server_slot> slots;
634
1511
  json default_generation_settings_for_props;
@@ -652,82 +1529,139 @@ struct server_context {
652
1529
  model = nullptr;
653
1530
  }
654
1531
 
1532
+ if (model_dft) {
1533
+ llama_free_model(model_dft);
1534
+ model_dft = nullptr;
1535
+ }
1536
+
655
1537
  // Clear any sampling context
656
1538
  for (server_slot & slot : slots) {
657
- if (slot.smpl != nullptr) {
658
- gpt_sampler_free(slot.smpl);
659
- }
1539
+ common_sampler_free(slot.smpl);
1540
+ slot.smpl = nullptr;
1541
+
1542
+ llama_free(slot.ctx_dft);
1543
+ slot.ctx_dft = nullptr;
1544
+
1545
+ common_speculative_free(slot.spec);
1546
+ slot.spec = nullptr;
1547
+
1548
+ llama_batch_free(slot.batch_spec);
660
1549
  }
661
1550
 
662
1551
  llama_batch_free(batch);
663
1552
  }
664
1553
 
665
- bool load_model(const gpt_params & params_) {
666
- params = params_;
1554
+ bool load_model(const common_params & params) {
1555
+ SRV_INF("loading model '%s'\n", params.model.c_str());
667
1556
 
668
- // dedicate one sequence to the system prompt
669
- params.n_parallel += 1;
1557
+ params_base = params;
670
1558
 
671
- llama_init_result llama_init = llama_init_from_gpt_params(params);
1559
+ common_init_result llama_init = common_init_from_params(params_base);
672
1560
 
673
1561
  model = llama_init.model;
674
1562
  ctx = llama_init.context;
675
1563
  loras = llama_init.lora_adapters;
676
1564
 
677
- params.n_parallel -= 1; // but be sneaky about it
678
-
679
1565
  if (model == nullptr) {
680
- SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
1566
+ SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
681
1567
  return false;
682
1568
  }
683
1569
 
684
1570
  n_ctx = llama_n_ctx(ctx);
685
1571
 
686
1572
  add_bos_token = llama_add_bos_token(model);
687
- has_eos_token = !llama_add_eos_token(model);
1573
+ has_eos_token = llama_token_eos(model) != LLAMA_TOKEN_NULL;
1574
+
1575
+ if (!params_base.speculative.model.empty()) {
1576
+ SRV_INF("loading draft model '%s'\n", params_base.speculative.model.c_str());
1577
+
1578
+ auto params_dft = params_base;
1579
+
1580
+ params_dft.devices = params_base.speculative.devices;
1581
+ params_dft.model = params_base.speculative.model;
1582
+ params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? params_base.n_ctx / params_base.n_parallel : params_base.speculative.n_ctx;
1583
+ params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers;
1584
+ params_dft.n_parallel = 1;
1585
+
1586
+ common_init_result llama_init_dft = common_init_from_params(params_dft);
1587
+
1588
+ model_dft = llama_init_dft.model;
1589
+
1590
+ if (model_dft == nullptr) {
1591
+ SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.c_str());
1592
+ return false;
1593
+ }
1594
+
1595
+ if (!common_speculative_are_compatible(ctx, llama_init_dft.context)) {
1596
+ SRV_ERR("the draft model '%s' is not compatible with the target model '%s'\n", params_base.speculative.model.c_str(), params_base.model.c_str());
1597
+
1598
+ llama_free (llama_init_dft.context);
1599
+ llama_free_model(llama_init_dft.model);
1600
+
1601
+ return false;
1602
+ }
1603
+
1604
+ const int n_ctx_dft = llama_n_ctx(llama_init_dft.context);
1605
+
1606
+ cparams_dft = common_context_params_to_llama(params_dft);
1607
+ cparams_dft.n_batch = n_ctx_dft;
1608
+
1609
+ // force F16 KV cache for the draft model for extra performance
1610
+ cparams_dft.type_k = GGML_TYPE_F16;
1611
+ cparams_dft.type_v = GGML_TYPE_F16;
1612
+
1613
+ // the context is not needed - we will create one for each slot
1614
+ llama_free(llama_init_dft.context);
1615
+ }
688
1616
 
689
1617
  return true;
690
1618
  }
691
1619
 
692
1620
  bool validate_model_chat_template() const {
693
- llama_chat_message chat[] = {{"user", "test"}};
694
-
695
- const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
696
-
697
- return res > 0;
1621
+ std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
1622
+ std::string template_key = "tokenizer.chat_template";
1623
+ int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
1624
+ if (res >= 0) {
1625
+ llama_chat_message chat[] = {{"user", "test"}};
1626
+ std::string tmpl = std::string(model_template.data(), model_template.size());
1627
+ int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
1628
+ return chat_res > 0;
1629
+ }
1630
+ return false;
698
1631
  }
699
1632
 
700
1633
  void init() {
701
- const int32_t n_ctx_slot = n_ctx / params.n_parallel;
1634
+ const int32_t n_ctx_slot = n_ctx / params_base.n_parallel;
702
1635
 
703
- SRV_INF("initializing slots, n_slots = %d\n", params.n_parallel);
1636
+ SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
704
1637
 
705
- for (int i = 0; i < params.n_parallel; i++) {
1638
+ for (int i = 0; i < params_base.n_parallel; i++) {
706
1639
  server_slot slot;
707
1640
 
708
1641
  slot.id = i;
1642
+ slot.ctx = ctx;
709
1643
  slot.n_ctx = n_ctx_slot;
710
- slot.n_predict = params.n_predict;
1644
+ slot.n_predict = params_base.n_predict;
711
1645
 
712
- SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
713
-
714
- const int ga_n = params.grp_attn_n;
715
- const int ga_w = params.grp_attn_w;
1646
+ if (model_dft) {
1647
+ slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1);
716
1648
 
717
- if (ga_n != 1) {
718
- GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
719
- GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
720
- //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
721
- //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
1649
+ slot.ctx_dft = llama_new_context_with_model(model_dft, cparams_dft);
1650
+ if (slot.ctx_dft == nullptr) {
1651
+ SRV_ERR("%s", "failed to create draft context\n");
1652
+ return;
1653
+ }
722
1654
 
723
- SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
1655
+ slot.spec = common_speculative_init(slot.ctx_dft);
1656
+ if (slot.spec == nullptr) {
1657
+ SRV_ERR("%s", "failed to create speculator\n");
1658
+ return;
1659
+ }
724
1660
  }
725
1661
 
726
- slot.ga_i = 0;
727
- slot.ga_n = ga_n;
728
- slot.ga_w = ga_w;
1662
+ SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);
729
1663
 
730
- slot.sparams = params.sparams;
1664
+ slot.params.sampling = params_base.sampling;
731
1665
 
732
1666
  slot.callback_on_release = [this](int) {
733
1667
  queue_tasks.pop_deferred_task();
@@ -738,60 +1672,18 @@ struct server_context {
738
1672
  slots.push_back(slot);
739
1673
  }
740
1674
 
741
- default_generation_settings_for_props = get_formated_generation(slots.front());
742
- default_generation_settings_for_props["seed"] = -1;
1675
+ default_generation_settings_for_props = slots[0].to_json();
743
1676
 
744
1677
  // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens
745
1678
  // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used)
746
- {
747
- const int32_t n_batch = llama_n_batch(ctx);
748
-
749
- // only a single seq_id per token is needed
750
- batch = llama_batch_init(std::max(n_batch, params.n_parallel), 0, 1);
751
- }
752
-
753
- metrics.init();
754
- }
755
-
756
- std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
757
- // TODO: currently, we tokenize using special tokens by default
758
- // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
759
- // but it's better compared to completely ignoring ChatML and other chat templates
760
- const bool TMP_FORCE_SPECIAL = true;
761
-
762
- // If `add_bos` is true, we only add BOS, when json_prompt is a string,
763
- // or the first element of the json_prompt array is a string.
764
- std::vector<llama_token> prompt_tokens;
765
-
766
- if (json_prompt.is_array()) {
767
- bool first = true;
768
- for (const auto & p : json_prompt) {
769
- if (p.is_string()) {
770
- auto s = p.template get<std::string>();
771
-
772
- std::vector<llama_token> p;
773
- if (first) {
774
- p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
775
- first = false;
776
- } else {
777
- p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
778
- }
779
-
780
- prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
781
- } else {
782
- if (first) {
783
- first = false;
784
- }
1679
+ {
1680
+ const int32_t n_batch = llama_n_batch(ctx);
785
1681
 
786
- prompt_tokens.push_back(p.template get<llama_token>());
787
- }
788
- }
789
- } else {
790
- auto s = json_prompt.template get<std::string>();
791
- prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
1682
+ // only a single seq_id per token is needed
1683
+ batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1);
792
1684
  }
793
1685
 
794
- return prompt_tokens;
1686
+ metrics.init();
795
1687
  }
796
1688
 
797
1689
  server_slot * get_slot_by_id(int id) {
@@ -804,12 +1696,12 @@ struct server_context {
804
1696
  return nullptr;
805
1697
  }
806
1698
 
807
- server_slot * get_available_slot(const std::string & prompt) {
1699
+ server_slot * get_available_slot(const server_task & task) {
808
1700
  server_slot * ret = nullptr;
809
1701
 
810
1702
  // find the slot that has at least n% prompt similarity
811
- if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
812
- int max_lcp_len = 0;
1703
+ if (ret == nullptr && slot_prompt_similarity != 0.0f) {
1704
+ int lcs_len = 0;
813
1705
  float similarity = 0;
814
1706
 
815
1707
  for (server_slot & slot : slots) {
@@ -818,32 +1710,27 @@ struct server_context {
818
1710
  continue;
819
1711
  }
820
1712
 
821
- // skip the slot if it does not contains prompt
822
- if (!slot.prompt.is_string()) {
1713
+ // skip the slot if it does not contain cached tokens
1714
+ if (slot.cache_tokens.empty()) {
823
1715
  continue;
824
1716
  }
825
1717
 
826
- // current slot's prompt
827
- std::string slot_prompt = slot.prompt.get<std::string>();
828
-
829
- // length of the current slot's prompt
830
- int slot_prompt_len = slot_prompt.size();
1718
+ // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
1719
+ int cur_lcs_len = common_lcs(slot.cache_tokens, task.prompt_tokens);
831
1720
 
832
- // length of the Longest Common Prefix between the current slot's prompt and the input prompt
833
- int lcp_len = common_part(slot_prompt, prompt);
834
-
835
- // fraction of the common substring length compared to the current slot's prompt length
836
- similarity = static_cast<float>(lcp_len) / slot_prompt_len;
1721
+ // fraction of the common subsequence length compared to the current slot's prompt length
1722
+ float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());
837
1723
 
838
1724
  // select the current slot if the criteria match
839
- if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
840
- max_lcp_len = lcp_len;
1725
+ if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
1726
+ lcs_len = cur_lcs_len;
1727
+ similarity = cur_similarity;
841
1728
  ret = &slot;
842
1729
  }
843
1730
  }
844
1731
 
845
1732
  if (ret != nullptr) {
846
- SLT_DBG(*ret, "selected slot by lcp similarity, max_lcp_len = %d, similarity = %f\n", max_lcp_len, similarity);
1733
+ SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity);
847
1734
  }
848
1735
  }
849
1736
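Note: slot selection above now scores candidates by the length of the longest common subsequence between a slot's cached tokens and the incoming prompt tokens (common_lcs), rather than by a string prefix match. A rough standalone sketch of that idea, assuming plain int token ids and a textbook O(n*m) dynamic program (lcs_len and the sample token values are hypothetical):

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    static size_t lcs_len(const std::vector<int> & a, const std::vector<int> & b) {
        // dp[i][j] = LCS length of a[0..i) and b[0..j)
        std::vector<std::vector<size_t>> dp(a.size() + 1, std::vector<size_t>(b.size() + 1, 0));
        for (size_t i = 1; i <= a.size(); ++i) {
            for (size_t j = 1; j <= b.size(); ++j) {
                dp[i][j] = (a[i - 1] == b[j - 1]) ? dp[i - 1][j - 1] + 1
                                                  : std::max(dp[i - 1][j], dp[i][j - 1]);
            }
        }
        return dp[a.size()][b.size()];
    }

    int main() {
        const std::vector<int> cached = {1, 15043, 3186, 29991}; // hypothetical cached slot tokens
        const std::vector<int> prompt = {1, 15043, 3186, 1738};  // hypothetical incoming prompt tokens

        const size_t len        = lcs_len(cached, prompt);
        const float  similarity = (float) len / (float) cached.size();
        std::printf("lcs_len = %zu, similarity = %.2f\n", len, similarity);
        return 0;
    }
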
 
@@ -872,65 +1759,14 @@ struct server_context {
872
1759
  }
873
1760
 
874
1761
  bool launch_slot_with_task(server_slot & slot, const server_task & task) {
875
- slot_params default_params;
876
- // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
877
- auto default_sparams = params.sparams;
878
- const auto & data = task.data;
879
-
880
- if (data.count("__oaicompat") != 0) {
881
- slot.oaicompat = true;
882
- slot.oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
883
- } else {
884
- slot.oaicompat = false;
885
- slot.oaicompat_model = "";
886
- }
887
-
888
- slot.params.stream = json_value(data, "stream", false);
889
- slot.params.cache_prompt = json_value(data, "cache_prompt", false);
890
- slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
891
- slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
892
- slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
893
- slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
894
- slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
895
- slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
896
- slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
897
- slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
898
- slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
899
- slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
900
- slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
901
- slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
902
- slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
903
- slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
904
- slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
905
- slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
906
- slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
907
- slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
908
- slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
909
- slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
910
- slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
911
- slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
912
-
913
- // process "json_schema" and "grammar"
914
- if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
915
- send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
916
- return false;
917
- }
918
- if (data.contains("json_schema") && !data.contains("grammar")) {
919
- try {
920
- auto schema = json_value(data, "json_schema", json::object());
921
- slot.sparams.grammar = json_schema_to_grammar(schema);
922
- } catch (const std::exception & e) {
923
- send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
924
- return false;
925
- }
926
- } else {
927
- slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
928
- }
1762
+ slot.reset();
1763
+ slot.id_task = task.id;
1764
+ slot.index = task.index;
1765
+ slot.task_type = task.type;
1766
+ slot.params = std::move(task.params);
1767
+ slot.prompt_tokens = std::move(task.prompt_tokens);
929
1768
 
930
- if (slot.params.cache_prompt && slot.ga_n != 1) {
931
- slot.params.cache_prompt = false;
932
- SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
933
- }
1769
+ SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
934
1770
 
935
1771
  if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
936
1772
  // Might be better to reject the request with a 400 ?
@@ -938,111 +1774,16 @@ struct server_context {
938
1774
  SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
939
1775
  }
940
1776
 
941
- // infill
942
- slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
943
- slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
944
-
945
- // get prompt
946
- if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
947
- const auto & prompt = data.find("prompt");
948
- if (prompt == data.end()) {
949
- send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
950
- return false;
951
- }
952
-
953
- if ((prompt->is_string()) ||
954
- (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
955
- (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
956
- slot.prompt = *prompt;
957
- } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
958
- slot.prompt = prompt->at(0);
959
- } else if (prompt->is_array() && prompt->size() > 1) {
960
- // array of strings
961
- for (const auto & el : *prompt) {
962
- if (!el.is_string()) {
963
- send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
964
- return false;
965
- }
966
- }
967
- slot.prompt = *prompt;
968
- } else {
969
- send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
970
- return false;
971
- }
972
- }
973
-
974
- {
975
- slot.sparams.logit_bias.clear();
976
-
977
- if (json_value(data, "ignore_eos", false) && has_eos_token) {
978
- slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
979
- }
980
-
981
- const auto & logit_bias = data.find("logit_bias");
982
- if (logit_bias != data.end() && logit_bias->is_array()) {
983
- const int n_vocab = llama_n_vocab(model);
984
- for (const auto & el : *logit_bias) {
985
- // TODO: we may want to throw errors here, in case "el" is incorrect
986
- if (el.is_array() && el.size() == 2) {
987
- float bias;
988
- if (el[1].is_number()) {
989
- bias = el[1].get<float>();
990
- } else if (el[1].is_boolean() && !el[1].get<bool>()) {
991
- bias = -INFINITY;
992
- } else {
993
- continue;
994
- }
995
-
996
- if (el[0].is_number_integer()) {
997
- llama_token tok = el[0].get<llama_token>();
998
- if (tok >= 0 && tok < n_vocab) {
999
- slot.sparams.logit_bias.push_back({tok, bias});
1000
- }
1001
- } else if (el[0].is_string()) {
1002
- auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
1003
- for (auto tok : toks) {
1004
- slot.sparams.logit_bias.push_back({tok, bias});
1005
- }
1006
- }
1007
- }
1008
- }
1009
- }
1010
- }
1011
-
1012
- {
1013
- slot.params.antiprompt.clear();
1014
-
1015
- const auto & stop = data.find("stop");
1016
- if (stop != data.end() && stop->is_array()) {
1017
- for (const auto & word : *stop) {
1018
- if (!word.empty()) {
1019
- slot.params.antiprompt.push_back(word);
1020
- }
1021
- }
1022
- }
1023
- }
1024
-
1025
- {
1026
- const auto & samplers = data.find("samplers");
1027
- if (samplers != data.end() && samplers->is_array()) {
1028
- std::vector<std::string> sampler_names;
1029
- for (const auto & name : *samplers) {
1030
- if (name.is_string()) {
1031
- sampler_names.emplace_back(name);
1032
- }
1033
- }
1034
- slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
1035
- } else {
1036
- slot.sparams.samplers = default_sparams.samplers;
1037
- }
1777
+ if (slot.params.ignore_eos && has_eos_token) {
1778
+ slot.params.sampling.logit_bias.push_back({llama_token_eos(model), -INFINITY});
1038
1779
  }
1039
1780
 
1040
1781
  {
1041
1782
  if (slot.smpl != nullptr) {
1042
- gpt_sampler_free(slot.smpl);
1783
+ common_sampler_free(slot.smpl);
1043
1784
  }
1044
1785
 
1045
- slot.smpl = gpt_sampler_init(model, slot.sparams);
1786
+ slot.smpl = common_sampler_init(model, slot.params.sampling);
1046
1787
  if (slot.smpl == nullptr) {
1047
1788
  // for now, the only error that may happen here is invalid grammar
1048
1789
  send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -1050,8 +1791,13 @@ struct server_context {
1050
1791
  }
1051
1792
  }
1052
1793
 
1053
- slot.state = SLOT_STATE_PROCESSING_PROMPT;
1054
- slot.prompt_tokens.clear();
1794
+ if (slot.ctx_dft) {
1795
+ llama_batch_free(slot.batch_spec);
1796
+
1797
+ slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1);
1798
+ }
1799
+
1800
+ slot.state = SLOT_STATE_STARTED;
1055
1801
 
1056
1802
  SLT_INF(slot, "%s", "processing task\n");
1057
1803
 
@@ -1066,107 +1812,40 @@ struct server_context {
1066
1812
  clean_kv_cache = false;
1067
1813
  }
1068
1814
 
1069
- void system_prompt_update() {
1070
- SRV_DBG("updating system prompt: '%s'\n", system_prompt.c_str());
1071
-
1072
- kv_cache_clear();
1073
- system_tokens.clear();
1074
-
1075
- if (!system_prompt.empty()) {
1076
- system_tokens = ::llama_tokenize(ctx, system_prompt, true);
1077
-
1078
- const int32_t n_batch = llama_n_batch(ctx);
1079
- const int32_t n_tokens_prompt = system_tokens.size();
1080
-
1081
- for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
1082
- const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
1083
-
1084
- llama_batch_clear(batch);
1085
-
1086
- for (int32_t j = 0; j < n_tokens; ++j) {
1087
- llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
1088
- }
1089
-
1090
- if (llama_decode(ctx, batch) != 0) {
1091
- SRV_ERR("%s", "llama_decode() failed\n");
1092
- return;
1093
- }
1094
- }
1095
-
1096
- // assign the system KV cache to all parallel sequences
1097
- for (int32_t i = 1; i <= params.n_parallel; ++i) {
1098
- llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
1099
- }
1100
- }
1101
-
1102
- system_need_update = false;
1103
- }
1104
-
1105
- bool system_prompt_set(const std::string & sys_prompt) {
1106
- SRV_DBG("system prompt set: '%s'\n", system_prompt.c_str());
1107
-
1108
- system_prompt = sys_prompt;
1109
-
1110
- // release all slots
1111
- for (server_slot & slot : slots) {
1112
- slot.release();
1113
- }
1114
-
1115
- system_need_update = true;
1116
- return true;
1117
- }
1118
-
1119
1815
  bool process_token(completion_token_output & result, server_slot & slot) {
1120
1816
  // remember which tokens were sampled - used for repetition penalties during sampling
1121
- const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
1817
+ const std::string token_str = result.text_to_send;
1122
1818
  slot.sampled = result.tok;
1123
1819
 
1124
- // search stop word and delete it
1125
1820
  slot.generated_text += token_str;
1821
+ if (slot.params.return_tokens) {
1822
+ slot.generated_tokens.push_back(result.tok);
1823
+ }
1126
1824
  slot.has_next_token = true;
1127
1825
 
1128
1826
  // check if there is incomplete UTF-8 character at the end
1129
- bool incomplete = false;
1130
- for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
1131
- unsigned char c = slot.generated_text[slot.generated_text.size() - i];
1132
- if ((c & 0xC0) == 0x80) {
1133
- // continuation byte: 10xxxxxx
1134
- continue;
1135
- }
1136
- if ((c & 0xE0) == 0xC0) {
1137
- // 2-byte character: 110xxxxx ...
1138
- incomplete = i < 2;
1139
- } else if ((c & 0xF0) == 0xE0) {
1140
- // 3-byte character: 1110xxxx ...
1141
- incomplete = i < 3;
1142
- } else if ((c & 0xF8) == 0xF0) {
1143
- // 4-byte character: 11110xxx ...
1144
- incomplete = i < 4;
1145
- }
1146
- // else 1-byte character or invalid byte
1147
- break;
1148
- }
1827
+ bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size();
1149
1828
 
1829
+ // search stop word and delete it
1150
1830
  if (!incomplete) {
1151
1831
  size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());
1152
1832
 
1153
1833
  const std::string str_test = slot.generated_text.substr(pos);
1154
- bool is_stop_full = false;
1834
+ bool send_text = true;
1155
1835
 
1156
- size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
1836
+ size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
1157
1837
  if (stop_pos != std::string::npos) {
1158
- is_stop_full = true;
1159
1838
  slot.generated_text.erase(
1160
1839
  slot.generated_text.begin() + pos + stop_pos,
1161
1840
  slot.generated_text.end());
1162
1841
  pos = std::min(slot.n_sent_text, slot.generated_text.size());
1163
- } else {
1164
- is_stop_full = false;
1165
- stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
1842
+ } else if (slot.has_next_token) {
1843
+ stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
1844
+ send_text = stop_pos == std::string::npos;
1166
1845
  }
1167
1846
 
1168
1847
  // check if there is any token to predict
1169
- if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
1848
+ if (send_text) {
1170
1849
  // do not send the stop word in the response
1171
1850
  result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
1172
1851
  slot.n_sent_text += result.text_to_send.size();
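Note: the hand-rolled UTF-8 tail check deleted above is now a single call to validate_utf8(), which is assumed to return the length of the longest valid prefix of the generated text. The removed logic boils down to the following standalone sketch (ends_with_incomplete_utf8 is a local name used only for this example):

    #include <string>

    static bool ends_with_incomplete_utf8(const std::string & s) {
        // walk back over at most 4 trailing bytes until the lead byte of the last sequence
        for (size_t i = 1; i <= 4 && i <= s.size(); ++i) {
            const unsigned char c = s[s.size() - i];
            if ((c & 0xC0) == 0x80) {
                continue;           // continuation byte: 10xxxxxx
            }
            if ((c & 0xE0) == 0xC0) {
                return i < 2;       // 2-byte lead: 110xxxxx
            }
            if ((c & 0xF0) == 0xE0) {
                return i < 3;       // 3-byte lead: 1110xxxx
            }
            if ((c & 0xF8) == 0xF0) {
                return i < 4;       // 4-byte lead: 11110xxx
            }
            return false;           // ASCII or invalid lead byte
        }
        return false;
    }

    int main() {
        // "e acute" is 0xC3 0xA9 in UTF-8; with the last byte missing the tail is incomplete
        return ends_with_incomplete_utf8(std::string("caf\xC3")) ? 0 : 1;
    }
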
@@ -1184,24 +1863,74 @@ struct server_context {
1184
1863
  }
1185
1864
 
1186
1865
  // check the limits
1187
- if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params)) {
1188
- slot.stopped_limit = true;
1866
+ if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) {
1867
+ slot.stop = STOP_TYPE_LIMIT;
1189
1868
  slot.has_next_token = false;
1190
1869
 
1191
1870
  SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
1192
1871
  }
1193
1872
 
1873
+ if (slot.has_new_line) {
1874
+ // if we have already seen a new line, we stop after a certain time limit
1875
+ if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
1876
+ slot.stop = STOP_TYPE_LIMIT;
1877
+ slot.has_next_token = false;
1878
+
1879
+ SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
1880
+ }
1881
+
1882
+ // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
1883
+ if (slot.params.n_indent > 0) {
1884
+ // check the current indentation
1885
+ // TODO: improve by not doing it more than once for each new line
1886
+ if (slot.last_nl_pos > 0) {
1887
+ size_t pos = slot.last_nl_pos;
1888
+
1889
+ int n_indent = 0;
1890
+ while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
1891
+ n_indent++;
1892
+ pos++;
1893
+ }
1894
+
1895
+ if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
1896
+ slot.stop = STOP_TYPE_LIMIT;
1897
+ slot.has_next_token = false;
1898
+
1899
+ // cut the last line
1900
+ slot.generated_text.erase(pos, std::string::npos);
1901
+
1902
+ SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
1903
+ }
1904
+ }
1905
+
1906
+ // find the next new line
1907
+ {
1908
+ const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
1909
+
1910
+ if (pos != std::string::npos) {
1911
+ slot.last_nl_pos = pos + 1;
1912
+ }
1913
+ }
1914
+ }
1915
+ }
1916
+
1917
+ // check if there is a new line in the generated text
1918
+ if (result.text_to_send.find('\n') != std::string::npos) {
1919
+ slot.has_new_line = true;
1920
+ }
1921
+
1194
1922
  // if context shift is disabled, we stop when it reaches the context limit
1195
- if (slot.n_decoded >= slot.n_ctx) {
1923
+ if (slot.n_past >= slot.n_ctx) {
1196
1924
  slot.truncated = true;
1197
- slot.stopped_limit = true;
1925
+ slot.stop = STOP_TYPE_LIMIT;
1198
1926
  slot.has_next_token = false;
1199
1927
 
1200
- SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
1928
+ SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
1929
+ slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
1201
1930
  }
1202
1931
 
1203
1932
  if (llama_token_is_eog(model, result.tok)) {
1204
- slot.stopped_eos = true;
1933
+ slot.stop = STOP_TYPE_EOS;
1205
1934
  slot.has_next_token = false;
1206
1935
 
1207
1936
  SLT_DBG(slot, "%s", "stopped by EOS\n");
@@ -1209,63 +1938,69 @@ struct server_context {
1209
1938
 
1210
1939
  const auto n_ctx_train = llama_n_ctx_train(model);
1211
1940
 
1212
- if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
1941
+ if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
1213
1942
  slot.truncated = true;
1214
- slot.stopped_limit = true;
1943
+ slot.stop = STOP_TYPE_LIMIT;
1215
1944
  slot.has_next_token = false; // stop prediction
1216
1945
 
1217
1946
  SLT_WRN(slot,
1218
- "n_predict (%d) is not set and self-context extend is disabled. "
1947
+ "n_predict (%d) is set for infinite generation. "
1219
1948
  "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
1220
1949
  slot.params.n_predict, n_ctx_train);
1221
1950
  }
1222
1951
 
1223
- SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
1952
+ SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());
1224
1953
 
1225
1954
  return slot.has_next_token; // continue
1226
1955
  }
1227
1956
 
1228
- json get_formated_generation(const server_slot & slot) const {
1229
- std::vector<std::string> samplers;
1230
- samplers.reserve(slot.sparams.samplers.size());
1231
- for (const auto & sampler : slot.sparams.samplers) {
1232
- samplers.emplace_back(gpt_sampler_type_to_str(sampler));
1233
- }
1957
+ void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) {
1958
+ size_t n_probs = slot.params.sampling.n_probs;
1959
+ size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
1960
+ if (post_sampling) {
1961
+ const auto * cur_p = common_sampler_get_candidates(slot.smpl);
1962
+ const size_t max_probs = cur_p->size;
1963
+
1964
+ // set probability for sampled token
1965
+ for (size_t i = 0; i < max_probs; i++) {
1966
+ if (cur_p->data[i].id == result.tok) {
1967
+ result.prob = cur_p->data[i].p;
1968
+ break;
1969
+ }
1970
+ }
1234
1971
 
1235
- return json {
1236
- {"n_ctx", slot.n_ctx},
1237
- {"n_predict", slot.n_predict}, // Server configured n_predict
1238
- {"model", params.model_alias},
1239
- {"seed", slot.sparams.seed},
1240
- {"seed_cur", slot.smpl ? gpt_sampler_get_seed(slot.smpl) : 0},
1241
- {"temperature", slot.sparams.temp},
1242
- {"dynatemp_range", slot.sparams.dynatemp_range},
1243
- {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
1244
- {"top_k", slot.sparams.top_k},
1245
- {"top_p", slot.sparams.top_p},
1246
- {"min_p", slot.sparams.min_p},
1247
- {"tfs_z", slot.sparams.tfs_z},
1248
- {"typical_p", slot.sparams.typ_p},
1249
- {"repeat_last_n", slot.sparams.penalty_last_n},
1250
- {"repeat_penalty", slot.sparams.penalty_repeat},
1251
- {"presence_penalty", slot.sparams.penalty_present},
1252
- {"frequency_penalty", slot.sparams.penalty_freq},
1253
- {"mirostat", slot.sparams.mirostat},
1254
- {"mirostat_tau", slot.sparams.mirostat_tau},
1255
- {"mirostat_eta", slot.sparams.mirostat_eta},
1256
- {"penalize_nl", slot.sparams.penalize_nl},
1257
- {"stop", slot.params.antiprompt},
1258
- {"max_tokens", slot.params.n_predict}, // User configured n_predict
1259
- {"n_keep", slot.params.n_keep},
1260
- {"n_discard", slot.params.n_discard},
1261
- {"ignore_eos", slot.sparams.ignore_eos},
1262
- {"stream", slot.params.stream},
1263
- //{"logit_bias", slot.sparams.logit_bias},
1264
- {"n_probs", slot.sparams.n_probs},
1265
- {"min_keep", slot.sparams.min_keep},
1266
- {"grammar", slot.sparams.grammar},
1267
- {"samplers", samplers},
1268
- };
1972
+ // set probability for top n_probs tokens
1973
+ result.probs.reserve(max_probs);
1974
+ for (size_t i = 0; i < std::min(max_probs, n_probs); i++) {
1975
+ result.probs.push_back({
1976
+ cur_p->data[i].id,
1977
+ common_detokenize(ctx, {cur_p->data[i].id}, special),
1978
+ cur_p->data[i].p
1979
+ });
1980
+ }
1981
+ } else {
1982
+ // TODO: optimize this with min-p optimization
1983
+ std::vector<llama_token_data> cur = get_token_probabilities(ctx, idx);
1984
+
1985
+ // set probability for sampled token
1986
+ for (size_t i = 0; i < n_vocab; i++) {
1987
+ // set probability for sampled token
1988
+ if (cur[i].id == result.tok) {
1989
+ result.prob = cur[i].p;
1990
+ break;
1991
+ }
1992
+ }
1993
+
1994
+ // set probability for top n_probs tokens
1995
+ result.probs.reserve(n_probs);
1996
+ for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) {
1997
+ result.probs.push_back({
1998
+ cur[i].id,
1999
+ common_detokenize(ctx, {cur[i].id}, special),
2000
+ cur[i].p
2001
+ });
2002
+ }
2003
+ }
1269
2004
  }
1270
2005
 
1271
2006
  void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
@@ -1279,114 +2014,106 @@ struct server_context {
1279
2014
  void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) {
1280
2015
  SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str());
1281
2016
 
1282
- server_task_result res;
1283
- res.id = id_task;
1284
- res.stop = false;
1285
- res.error = true;
1286
- res.data = format_error_response(error, type);
1287
-
1288
- queue_results.send(res);
1289
- }
1290
-
1291
- void send_partial_response(server_slot & slot, completion_token_output tkn) {
1292
- server_task_result res;
1293
- res.id = slot.id_task;
1294
- res.error = false;
1295
- res.stop = false;
1296
- res.data = json {
1297
- {"content", tkn.text_to_send},
1298
- {"stop", false},
1299
- {"id_slot", slot.id},
1300
- {"multimodal", false},
1301
- {"index", slot.index},
1302
- };
2017
+ auto res = std::make_unique<server_task_result_error>();
2018
+ res->id = id_task;
2019
+ res->err_type = type;
2020
+ res->err_msg = error;
1303
2021
 
1304
- if (slot.sparams.n_probs > 0) {
1305
- const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
1306
- const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
1307
- const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());
2022
+ queue_results.send(std::move(res));
2023
+ }
1308
2024
 
1309
- std::vector<completion_token_output> probs_output;
1310
- if (probs_pos < probs_stop_pos) {
1311
- probs_output = std::vector<completion_token_output>(
1312
- slot.generated_token_probs.begin() + probs_pos,
1313
- slot.generated_token_probs.begin() + probs_stop_pos);
1314
- }
1315
- slot.n_sent_token_probs = probs_stop_pos;
2025
+ void send_partial_response(server_slot & slot, const completion_token_output & tkn) {
2026
+ auto res = std::make_unique<server_task_result_cmpl_partial>();
2027
+
2028
+ res->id = slot.id_task;
2029
+ res->index = slot.index;
2030
+ res->content = tkn.text_to_send;
2031
+ res->tokens = { tkn.tok };
2032
+
2033
+ res->n_decoded = slot.n_decoded;
2034
+ res->n_prompt_tokens = slot.n_prompt_tokens;
2035
+ res->post_sampling_probs = slot.params.post_sampling_probs;
1316
2036
 
1317
- res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
2037
+ res->verbose = slot.params.verbose;
2038
+ res->oaicompat = slot.params.oaicompat;
2039
+ res->oaicompat_chat = slot.params.oaicompat_chat;
2040
+ res->oaicompat_model = slot.params.oaicompat_model;
2041
+ res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
2042
+
2043
+ // populate res.probs_output
2044
+ if (slot.params.sampling.n_probs > 0) {
2045
+ res->prob_output = tkn; // copy the token probs
1318
2046
  }
1319
2047
 
1320
- if (slot.oaicompat) {
1321
- res.data["oaicompat_token_ctr"] = slot.n_decoded;
1322
- res.data["model"] = slot.oaicompat_model;
2048
+ // populate timings if this is final response or timings_per_token is enabled
2049
+ if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
2050
+ res->timings = slot.get_timings();
1323
2051
  }
1324
2052
 
1325
- queue_results.send(res);
2053
+ queue_results.send(std::move(res));
1326
2054
  }
1327
2055
 
1328
- void send_final_response(const server_slot & slot) {
1329
- server_task_result res;
1330
- res.id = slot.id_task;
1331
- res.error = false;
1332
- res.stop = true;
1333
- res.data = json {
1334
- {"content", !slot.params.stream ? slot.generated_text : ""},
1335
- {"id_slot", slot.id},
1336
- {"stop", true},
1337
- {"model", params.model_alias},
1338
- {"tokens_predicted", slot.n_decoded},
1339
- {"tokens_evaluated", slot.n_prompt_tokens},
1340
- {"generation_settings", get_formated_generation(slot)},
1341
- {"prompt", slot.prompt},
1342
- {"truncated", slot.truncated},
1343
- {"stopped_eos", slot.stopped_eos},
1344
- {"stopped_word", slot.stopped_word},
1345
- {"stopped_limit", slot.stopped_limit},
1346
- {"stopping_word", slot.stopping_word},
1347
- {"tokens_cached", slot.n_past},
1348
- {"timings", slot.get_formated_timings()},
1349
- {"index", slot.index},
1350
- };
1351
-
1352
- if (slot.sparams.n_probs > 0) {
1353
- std::vector<completion_token_output> probs;
1354
- if (!slot.params.stream && slot.stopped_word) {
1355
- const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
2056
+ void send_final_response(server_slot & slot) {
2057
+ auto res = std::make_unique<server_task_result_cmpl_final>();
2058
+ res->id = slot.id_task;
2059
+ res->id_slot = slot.id;
2060
+
2061
+ res->index = slot.index;
2062
+ res->content = slot.generated_text;
2063
+ res->tokens = slot.generated_tokens;
2064
+ res->timings = slot.get_timings();
2065
+ res->prompt = common_detokenize(ctx, slot.prompt_tokens, true);
2066
+
2067
+ res->truncated = slot.truncated;
2068
+ res->n_decoded = slot.n_decoded;
2069
+ res->n_prompt_tokens = slot.n_prompt_tokens;
2070
+ res->n_tokens_cached = slot.n_past;
2071
+ res->has_new_line = slot.has_new_line;
2072
+ res->stopping_word = slot.stopping_word;
2073
+ res->stop = slot.stop;
2074
+ res->post_sampling_probs = slot.params.post_sampling_probs;
2075
+
2076
+ res->verbose = slot.params.verbose;
2077
+ res->stream = slot.params.stream;
2078
+ res->oaicompat = slot.params.oaicompat;
2079
+ res->oaicompat_chat = slot.params.oaicompat_chat;
2080
+ res->oaicompat_model = slot.params.oaicompat_model;
2081
+ res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
2082
+
2083
+ // populate res.probs_output
2084
+ if (slot.params.sampling.n_probs > 0) {
2085
+ if (!slot.params.stream && slot.stop == STOP_TYPE_WORD) {
2086
+ const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);
1356
2087
 
1357
2088
  size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
1358
- probs = std::vector<completion_token_output>(
2089
+ res->probs_output = std::vector<completion_token_output>(
1359
2090
  slot.generated_token_probs.begin(),
1360
2091
  slot.generated_token_probs.end() - safe_offset);
1361
2092
  } else {
1362
- probs = std::vector<completion_token_output>(
2093
+ res->probs_output = std::vector<completion_token_output>(
1363
2094
  slot.generated_token_probs.begin(),
1364
2095
  slot.generated_token_probs.end());
1365
2096
  }
1366
-
1367
- res.data["completion_probabilities"] = probs_vector_to_json(ctx, probs);
1368
2097
  }
1369
2098
 
1370
- if (slot.oaicompat) {
1371
- res.data["oaicompat_token_ctr"] = slot.n_decoded;
1372
- res.data["model"] = slot.oaicompat_model;
1373
- }
2099
+ res->generation_params = slot.params; // copy the parameters
1374
2100
 
1375
- queue_results.send(res);
2101
+ queue_results.send(std::move(res));
1376
2102
  }
1377
2103
 
1378
2104
  void send_embedding(const server_slot & slot, const llama_batch & batch) {
1379
- server_task_result res;
1380
- res.id = slot.id_task;
1381
- res.error = false;
1382
- res.stop = true;
2105
+ auto res = std::make_unique<server_task_result_embd>();
2106
+ res->id = slot.id_task;
2107
+ res->index = slot.index;
2108
+ res->n_tokens = slot.n_prompt_tokens;
2109
+ res->oaicompat = slot.params.oaicompat;
1383
2110
 
1384
2111
  const int n_embd = llama_n_embd(model);
1385
2112
 
1386
2113
  std::vector<float> embd_res(n_embd, 0.0f);
1387
2114
 
1388
2115
  for (int i = 0; i < batch.n_tokens; ++i) {
1389
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
2116
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
1390
2117
  continue;
1391
2118
  }
1392
2119
 
@@ -1398,35 +2125,33 @@ struct server_context {
1398
2125
  if (embd == NULL) {
1399
2126
  SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
1400
2127
 
1401
- res.data = json {
1402
- {"embedding", std::vector<float>(n_embd, 0.0f)},
1403
- {"index", slot.index},
1404
- };
1405
-
2128
+ res->embedding.push_back(std::vector<float>(n_embd, 0.0f));
1406
2129
  continue;
1407
2130
  }
1408
2131
 
1409
- llama_embd_normalize(embd, embd_res.data(), n_embd);
1410
-
1411
- res.data = json {
1412
- {"embedding", embd_res},
1413
- {"index", slot.index},
1414
- };
2132
+ // normalize only when there is pooling
2133
+ // TODO: configurable
2134
+ if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) {
2135
+ common_embd_normalize(embd, embd_res.data(), n_embd, 2);
2136
+ res->embedding.push_back(embd_res);
2137
+ } else {
2138
+ res->embedding.push_back({ embd, embd + n_embd });
2139
+ }
1415
2140
  }
1416
2141
 
1417
2142
  SLT_DBG(slot, "%s", "sending embeddings\n");
1418
2143
 
1419
- queue_results.send(res);
2144
+ queue_results.send(std::move(res));
1420
2145
  }
1421
2146
 
1422
2147
  void send_rerank(const server_slot & slot, const llama_batch & batch) {
1423
- server_task_result res;
1424
- res.id = slot.id_task;
1425
- res.error = false;
1426
- res.stop = true;
2148
+ auto res = std::make_unique<server_task_result_rerank>();
2149
+ res->id = slot.id_task;
2150
+ res->index = slot.index;
2151
+ res->n_tokens = slot.n_prompt_tokens;
1427
2152
 
1428
2153
  for (int i = 0; i < batch.n_tokens; ++i) {
1429
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
2154
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
1430
2155
  continue;
1431
2156
  }
1432
2157
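Note: send_embedding() above now normalizes pooled embeddings through common_embd_normalize(..., 2), which is assumed to perform plain Euclidean (p = 2) normalization. A standalone sketch of that operation (l2_normalize is a local name used only for this example):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    static std::vector<float> l2_normalize(const std::vector<float> & embd) {
        double sum = 0.0;
        for (const float v : embd) {
            sum += (double) v * v;   // accumulate squared magnitude
        }
        const double norm = std::sqrt(sum);

        std::vector<float> out(embd.size(), 0.0f);
        if (norm > 0.0) {
            for (size_t i = 0; i < embd.size(); ++i) {
                out[i] = (float) (embd[i] / norm);
            }
        }
        return out;
    }

    int main() {
        const std::vector<float> v = l2_normalize({3.0f, 4.0f}); // -> {0.60, 0.80}
        std::printf("%.2f %.2f\n", v[0], v[1]);
        return 0;
    }
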
 
@@ -1438,100 +2163,29 @@ struct server_context {
1438
2163
  if (embd == NULL) {
1439
2164
  SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]);
1440
2165
 
1441
- res.data = json {
1442
- {"index", slot.index},
1443
- {"score", -1e6},
1444
- };
1445
-
2166
+ res->score = -1e6;
1446
2167
  continue;
1447
2168
  }
1448
2169
 
1449
- res.data = json {
1450
- {"index", slot.index},
1451
- {"score", embd[0]},
1452
- };
2170
+ res->score = embd[0];
1453
2171
  }
1454
2172
 
1455
- SLT_DBG(slot, "sending rerank result, res = '%s'\n", res.data.dump().c_str());
2173
+ SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score);
1456
2174
 
1457
- queue_results.send(res);
2175
+ queue_results.send(std::move(res));
1458
2176
  }
1459
2177
 
1460
2178
  //
1461
2179
  // Functions to create new task(s) and receive result(s)
1462
2180
  //
1463
2181
 
1464
- std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
1465
- std::vector<server_task> tasks;
1466
- auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
1467
- server_task task;
1468
- task.id = queue_tasks.get_new_id();
1469
- task.cmpl_type = cmpl_type;
1470
- task.type = SERVER_TASK_TYPE_COMPLETION;
1471
- if (replace_prompt) {
1472
- task.data = task_data;
1473
- task.data["prompt"] = std::move(prompt);
1474
- } else {
1475
- task.data = std::move(task_data);
1476
- }
1477
- tasks.push_back(std::move(task));
1478
- };
1479
-
1480
- static constexpr const char * error_msg = "\"prompt\" must be a string, an array of token ids or an array of prompts";
1481
- if (!data.contains("prompt")) {
1482
- throw std::runtime_error(error_msg);
1483
- }
1484
-
1485
- json prompt = data.at("prompt");
1486
-
1487
- // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
1488
- if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
1489
- data["index"] = 0;
1490
- create_task(data, false, nullptr);
1491
- }
1492
- // otherwise, it's a multiple-prompt task, we break it into smaller tasks
1493
- else if (prompt.is_array()) {
1494
- std::vector<json> prompts = prompt;
1495
- if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
1496
- // prompts[0] is the question
1497
- // the rest are the answers/documents
1498
- SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
1499
- for (size_t i = 1; i < prompts.size(); i++) {
1500
- json qd;
1501
- qd.push_back(prompts[0]);
1502
- qd.push_back(prompts[i]);
1503
- data["index"] = i - 1;
1504
- create_task(data, true, qd);
1505
- }
1506
- } else {
1507
- SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
1508
- for (size_t i = 0; i < prompts.size(); i++) {
1509
- const auto & e = prompts[i];
1510
- if (e.is_string() || json_is_array_of_numbers(e)) {
1511
- data["index"] = i;
1512
- create_task(data, true, e);
1513
- } else {
1514
- throw std::runtime_error(error_msg);
1515
- }
1516
- }
1517
- }
1518
- }
1519
- // invalid case
1520
- else {
1521
- throw std::runtime_error(error_msg);
1522
- }
1523
-
1524
- return tasks;
1525
- }
1526
-
1527
2182
  void cancel_tasks(const std::unordered_set<int> & id_tasks) {
1528
2183
  std::vector<server_task> cancel_tasks;
1529
2184
  cancel_tasks.reserve(id_tasks.size());
1530
2185
  for (const auto & id_task : id_tasks) {
1531
2186
  SRV_WRN("cancel task, id_task = %d\n", id_task);
1532
2187
 
1533
- server_task task;
1534
- task.type = SERVER_TASK_TYPE_CANCEL;
2188
+ server_task task(SERVER_TASK_TYPE_CANCEL);
1535
2189
  task.id_target = id_task;
1536
2190
  cancel_tasks.push_back(task);
1537
2191
  queue_results.remove_waiting_task_id(id_task);
@@ -1540,50 +2194,58 @@ struct server_context {
1540
2194
  queue_tasks.post(cancel_tasks, true);
1541
2195
  }
1542
2196
 
1543
- // receive the results from task(s) created by create_tasks_cmpl
1544
- void receive_cmpl_results(
2197
+ // receive the results from task(s)
2198
+ void receive_multi_results(
1545
2199
  const std::unordered_set<int> & id_tasks,
1546
- const std::function<void(std::vector<server_task_result>&)> & result_handler,
2200
+ const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
1547
2201
  const std::function<void(json)> & error_handler) {
1548
- // TODO: currently, there is no way to detect the client has cancelled the request
1549
- std::vector<server_task_result> results(id_tasks.size());
2202
+ std::vector<server_task_result_ptr> results(id_tasks.size());
1550
2203
  for (size_t i = 0; i < id_tasks.size(); i++) {
1551
- server_task_result result = queue_results.recv(id_tasks);
2204
+ server_task_result_ptr result = queue_results.recv(id_tasks);
1552
2205
 
1553
- if (result.error) {
1554
- error_handler(result.data);
2206
+ if (result->is_error()) {
2207
+ error_handler(result->to_json());
1555
2208
  cancel_tasks(id_tasks);
1556
2209
  return;
1557
2210
  }
1558
2211
 
1559
- const size_t idx = result.data["index"];
2212
+ GGML_ASSERT(
2213
+ dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
2214
+ || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
2215
+ || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
2216
+ );
2217
+ const size_t idx = result->get_index();
1560
2218
  GGML_ASSERT(idx < results.size() && "index out of range");
1561
-
1562
- results[idx] = result;
2219
+ results[idx] = std::move(result);
1563
2220
  }
1564
2221
  result_handler(results);
1565
2222
  }
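receive_multi_results above stores each incoming result at result->get_index(), so the caller sees them in request order even though slots finish at different times. A self-contained sketch of that index-keyed gathering, using a toy Result type as a stand-in for server_task_result_ptr (assumed names, illustration only):

    #include <cassert>
    #include <cstdio>
    #include <memory>
    #include <vector>

    // toy stand-in for a task result that knows its position in the original request
    struct Result {
        size_t index;
        int    value;
    };

    int main() {
        // results arriving out of order (different slots finish at different times)
        std::vector<std::unique_ptr<Result>> incoming;
        incoming.push_back(std::make_unique<Result>(Result{2, 30}));
        incoming.push_back(std::make_unique<Result>(Result{0, 10}));
        incoming.push_back(std::make_unique<Result>(Result{1, 20}));

        // gather them back into request order, keyed by their index
        std::vector<std::unique_ptr<Result>> ordered(incoming.size());
        for (auto & r : incoming) {
            const size_t idx = r->index;
            assert(idx < ordered.size() && "index out of range");
            ordered[idx] = std::move(r);
        }

        for (const auto & r : ordered) {
            std::printf("index %zu -> value %d\n", r->index, r->value);
        }
        return 0;
    }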
1566
2223
 
1567
- // receive the results from task(s) created by create_tasks_cmpl, in stream mode
2224
+ // receive the results from task(s), in stream mode
1568
2225
  void receive_cmpl_results_stream(
1569
- const std::unordered_set<int> & id_tasks, const
1570
- std::function<bool(server_task_result&)> & result_handler, const
1571
- std::function<void(json)> & error_handler) {
2226
+ const std::unordered_set<int> & id_tasks,
2227
+ const std::function<bool(server_task_result_ptr&)> & result_handler,
2228
+ const std::function<void(json)> & error_handler) {
1572
2229
  size_t n_finished = 0;
1573
2230
  while (true) {
1574
- server_task_result result = queue_results.recv(id_tasks);
1575
- if (!result_handler(result)) {
2231
+ server_task_result_ptr result = queue_results.recv(id_tasks);
2232
+
2233
+ if (result->is_error()) {
2234
+ error_handler(result->to_json());
1576
2235
  cancel_tasks(id_tasks);
1577
- break;
2236
+ return;
1578
2237
  }
1579
2238
 
1580
- if (result.error) {
1581
- error_handler(result.data);
2239
+ GGML_ASSERT(
2240
+ dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
2241
+ || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
2242
+ );
2243
+ if (!result_handler(result)) {
1582
2244
  cancel_tasks(id_tasks);
1583
2245
  break;
1584
2246
  }
1585
2247
 
1586
- if (result.stop) {
2248
+ if (result->is_stop()) {
1587
2249
  if (++n_finished == id_tasks.size()) {
1588
2250
  break;
1589
2251
  }
@@ -1595,24 +2257,16 @@ struct server_context {
1595
2257
  // Functions to process the task
1596
2258
  //
1597
2259
 
1598
- void process_single_task(const server_task & task) {
2260
+ void process_single_task(server_task task) {
1599
2261
  switch (task.type) {
1600
2262
  case SERVER_TASK_TYPE_COMPLETION:
2263
+ case SERVER_TASK_TYPE_INFILL:
2264
+ case SERVER_TASK_TYPE_EMBEDDING:
2265
+ case SERVER_TASK_TYPE_RERANK:
1601
2266
  {
1602
- const int id_slot = json_value(task.data, "id_slot", -1);
1603
-
1604
- server_slot * slot;
1605
-
1606
- if (id_slot != -1) {
1607
- slot = get_slot_by_id(id_slot);
1608
- } else {
1609
- std::string prompt;
1610
- if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
1611
- prompt = json_value(task.data, "prompt", std::string());
1612
- }
2267
+ const int id_slot = task.id_selected_slot;
1613
2268
 
1614
- slot = get_available_slot(prompt);
1615
- }
2269
+ server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
1616
2270
 
1617
2271
  if (slot == nullptr) {
1618
2272
  // if no slot is available, we defer this task for processing later
@@ -1627,22 +2281,6 @@ struct server_context {
1627
2281
  break;
1628
2282
  }
1629
2283
 
1630
- if (task.data.contains("system_prompt")) {
1631
- std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
1632
- system_prompt_set(sys_prompt);
1633
-
1634
- for (server_slot & slot : slots) {
1635
- slot.n_past = 0;
1636
- slot.n_past_se = 0;
1637
- }
1638
- }
1639
-
1640
- slot->reset();
1641
-
1642
- slot->id_task = task.id;
1643
- slot->cmpl_type = task.cmpl_type;
1644
- slot->index = json_value(task.data, "index", 0);
1645
-
1646
2284
  if (!launch_slot_with_task(*slot, task)) {
1647
2285
  SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
1648
2286
  break;
@@ -1670,68 +2308,50 @@ struct server_context {
1670
2308
  int n_processing_slots = 0;
1671
2309
 
1672
2310
  for (server_slot & slot : slots) {
1673
- json slot_data = get_formated_generation(slot);
1674
- slot_data["id"] = slot.id;
1675
- slot_data["id_task"] = slot.id_task;
1676
- slot_data["state"] = slot.state;
1677
- slot_data["prompt"] = slot.prompt;
1678
- slot_data["next_token"] = {
1679
- {"has_next_token", slot.has_next_token},
1680
- {"n_remain", slot.n_remaining},
1681
- {"n_decoded", slot.n_decoded},
1682
- {"stopped_eos", slot.stopped_eos},
1683
- {"stopped_word", slot.stopped_word},
1684
- {"stopped_limit", slot.stopped_limit},
1685
- {"stopping_word", slot.stopping_word},
1686
- };
1687
-
1688
- if (slot_data["state"] == SLOT_STATE_IDLE) {
1689
- n_idle_slots++;
1690
- } else {
2311
+ json slot_data = slot.to_json();
2312
+
2313
+ if (slot.is_processing()) {
1691
2314
  n_processing_slots++;
2315
+ } else {
2316
+ n_idle_slots++;
1692
2317
  }
1693
2318
 
1694
2319
  slots_data.push_back(slot_data);
1695
2320
  }
1696
2321
  SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots);
1697
2322
 
1698
- server_task_result res;
1699
- res.id = task.id;
1700
- res.stop = true;
1701
- res.error = false;
1702
- res.data = {
1703
- { "idle", n_idle_slots },
1704
- { "processing", n_processing_slots },
1705
- { "deferred", queue_tasks.queue_tasks_deferred.size() },
1706
- { "t_start", metrics.t_start},
1707
-
1708
- { "n_prompt_tokens_processed_total", metrics.n_prompt_tokens_processed_total},
1709
- { "t_tokens_generation_total", metrics.t_tokens_generation_total},
1710
- { "n_tokens_predicted_total", metrics.n_tokens_predicted_total},
1711
- { "t_prompt_processing_total", metrics.t_prompt_processing_total},
2323
+ auto res = std::make_unique<server_task_result_metrics>();
2324
+ res->id = task.id;
2325
+ res->slots_data = std::move(slots_data);
2326
+ res->n_idle_slots = n_idle_slots;
2327
+ res->n_processing_slots = n_processing_slots;
2328
+ res->n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
2329
+ res->t_start = metrics.t_start;
1712
2330
 
1713
- { "n_prompt_tokens_processed", metrics.n_prompt_tokens_processed},
1714
- { "t_prompt_processing", metrics.t_prompt_processing},
1715
- { "n_tokens_predicted", metrics.n_tokens_predicted},
1716
- { "t_tokens_generation", metrics.t_tokens_generation},
2331
+ res->kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
2332
+ res->kv_cache_used_cells = llama_get_kv_cache_used_cells(ctx);
1717
2333
 
1718
- { "n_decode_total", metrics.n_decode_total},
1719
- { "n_busy_slots_total", metrics.n_busy_slots_total},
2334
+ res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total;
2335
+ res->t_prompt_processing_total = metrics.t_prompt_processing_total;
2336
+ res->n_tokens_predicted_total = metrics.n_tokens_predicted_total;
2337
+ res->t_tokens_generation_total = metrics.t_tokens_generation_total;
1720
2338
 
1721
- { "kv_cache_tokens_count", llama_get_kv_cache_token_count(ctx)},
1722
- { "kv_cache_used_cells", llama_get_kv_cache_used_cells(ctx)},
2339
+ res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed;
2340
+ res->t_prompt_processing = metrics.t_prompt_processing;
2341
+ res->n_tokens_predicted = metrics.n_tokens_predicted;
2342
+ res->t_tokens_generation = metrics.t_tokens_generation;
1723
2343
 
1724
- { "slots", slots_data },
1725
- };
2344
+ res->n_decode_total = metrics.n_decode_total;
2345
+ res->n_busy_slots_total = metrics.n_busy_slots_total;
1726
2346
 
1727
- if (json_value(task.data, "reset_bucket", false)) {
2347
+ if (task.metrics_reset_bucket) {
1728
2348
  metrics.reset_bucket();
1729
2349
  }
1730
- queue_results.send(res);
2350
+ queue_results.send(std::move(res));
1731
2351
  } break;
1732
2352
  case SERVER_TASK_TYPE_SLOT_SAVE:
1733
2353
  {
1734
- int id_slot = task.data.at("id_slot");
2354
+ int id_slot = task.slot_action.slot_id;
1735
2355
  server_slot * slot = get_slot_by_id(id_slot);
1736
2356
  if (slot == nullptr) {
1737
2357
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1747,32 +2367,27 @@ struct server_context {
1747
2367
  const size_t token_count = slot->cache_tokens.size();
1748
2368
  const int64_t t_start = ggml_time_us();
1749
2369
 
1750
- std::string filename = task.data.at("filename");
1751
- std::string filepath = task.data.at("filepath");
2370
+ std::string filename = task.slot_action.filename;
2371
+ std::string filepath = task.slot_action.filepath;
1752
2372
 
1753
- const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
2373
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
1754
2374
 
1755
2375
  const int64_t t_end = ggml_time_us();
1756
2376
  const double t_save_ms = (t_end - t_start) / 1000.0;
1757
2377
 
1758
- server_task_result result;
1759
- result.id = task.id;
1760
- result.stop = true;
1761
- result.error = false;
1762
- result.data = json {
1763
- { "id_slot", id_slot },
1764
- { "filename", filename },
1765
- { "n_saved", token_count }, // tokens saved
1766
- { "n_written", nwrite }, // bytes written
1767
- { "timings", {
1768
- { "save_ms", t_save_ms }
1769
- } }
1770
- };
1771
- queue_results.send(result);
2378
+ auto res = std::make_unique<server_task_result_slot_save_load>();
2379
+ res->id = task.id;
2380
+ res->id_slot = id_slot;
2381
+ res->filename = filename;
2382
+ res->is_save = true;
2383
+ res->n_tokens = token_count;
2384
+ res->n_bytes = nwrite;
2385
+ res->t_ms = t_save_ms;
2386
+ queue_results.send(std::move(res));
1772
2387
  } break;
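The SLOT_SAVE case above (and the matching SLOT_RESTORE case below) each reduce to one call into the llama.cpp sequence-state API. A hedged sketch of that round trip, assuming the llama_state_seq_save_file / llama_state_seq_load_file signatures used in this diff; the file path, capacity, and sequence id are placeholders.

    // sketch: persist and restore the KV cache of one sequence, mirroring the
    // SLOT_SAVE / SLOT_RESTORE cases (paths, capacity and token ids are placeholders)
    #include <cstdio>
    #include <vector>
    #include "llama.h"

    void save_and_restore(llama_context * ctx, llama_seq_id seq_id,
                          std::vector<llama_token> & cache_tokens,
                          const char * filepath) {
        // write the sequence state plus the tokens it was built from
        const size_t n_written = llama_state_seq_save_file(
            ctx, filepath, seq_id, cache_tokens.data(), cache_tokens.size());
        std::printf("saved %zu tokens, %zu bytes\n", cache_tokens.size(), n_written);

        // later (possibly in a fresh context): read the state back into the same slot
        std::vector<llama_token> restored(4096); // capacity must cover the saved prompt
        size_t n_tokens = 0;
        const size_t n_read = llama_state_seq_load_file(
            ctx, filepath, seq_id, restored.data(), restored.size(), &n_tokens);
        if (n_read == 0) {
            std::printf("restore failed (no KV space or invalid file)\n");
            return;
        }
        restored.resize(n_tokens);
        std::printf("restored %zu tokens, %zu bytes\n", n_tokens, n_read);
    }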
1773
2388
  case SERVER_TASK_TYPE_SLOT_RESTORE:
1774
2389
  {
1775
- int id_slot = task.data.at("id_slot");
2390
+ int id_slot = task.slot_action.slot_id;
1776
2391
  server_slot * slot = get_slot_by_id(id_slot);
1777
2392
  if (slot == nullptr) {
1778
2393
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1787,12 +2402,12 @@ struct server_context {
1787
2402
 
1788
2403
  const int64_t t_start = ggml_time_us();
1789
2404
 
1790
- std::string filename = task.data.at("filename");
1791
- std::string filepath = task.data.at("filepath");
2405
+ std::string filename = task.slot_action.filename;
2406
+ std::string filepath = task.slot_action.filepath;
1792
2407
 
1793
2408
  slot->cache_tokens.resize(slot->n_ctx);
1794
2409
  size_t token_count = 0;
1795
- size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
2410
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
1796
2411
  if (nread == 0) {
1797
2412
  slot->cache_tokens.resize(0);
1798
2413
  send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
@@ -1803,24 +2418,19 @@ struct server_context {
1803
2418
  const int64_t t_end = ggml_time_us();
1804
2419
  const double t_restore_ms = (t_end - t_start) / 1000.0;
1805
2420
 
1806
- server_task_result result;
1807
- result.id = task.id;
1808
- result.stop = true;
1809
- result.error = false;
1810
- result.data = json {
1811
- { "id_slot", id_slot },
1812
- { "filename", filename },
1813
- { "n_restored", token_count }, // tokens restored
1814
- { "n_read", nread }, // bytes read
1815
- { "timings", {
1816
- { "restore_ms", t_restore_ms }
1817
- } }
1818
- };
1819
- queue_results.send(result);
2421
+ auto res = std::make_unique<server_task_result_slot_save_load>();
2422
+ res->id = task.id;
2423
+ res->id_slot = id_slot;
2424
+ res->filename = filename;
2425
+ res->is_save = false;
2426
+ res->n_tokens = token_count;
2427
+ res->n_bytes = nread;
2428
+ res->t_ms = t_restore_ms;
2429
+ queue_results.send(std::move(res));
1820
2430
  } break;
1821
2431
  case SERVER_TASK_TYPE_SLOT_ERASE:
1822
2432
  {
1823
- int id_slot = task.data.at("id_slot");
2433
+ int id_slot = task.slot_action.slot_id;
1824
2434
  server_slot * slot = get_slot_by_id(id_slot);
1825
2435
  if (slot == nullptr) {
1826
2436
  send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
@@ -1835,37 +2445,26 @@ struct server_context {
1835
2445
 
1836
2446
  // Erase token cache
1837
2447
  const size_t n_erased = slot->cache_tokens.size();
1838
- llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
2448
+ llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
1839
2449
  slot->cache_tokens.clear();
1840
2450
 
1841
- server_task_result result;
1842
- result.id = task.id;
1843
- result.stop = true;
1844
- result.error = false;
1845
- result.data = json {
1846
- { "id_slot", id_slot },
1847
- { "n_erased", n_erased }
1848
- };
1849
- queue_results.send(result);
2451
+ auto res = std::make_unique<server_task_result_slot_erase>();
2452
+ res->id = task.id;
2453
+ res->id_slot = id_slot;
2454
+ res->n_erased = n_erased;
2455
+ queue_results.send(std::move(res));
1850
2456
  } break;
1851
2457
  case SERVER_TASK_TYPE_SET_LORA:
1852
2458
  {
1853
- llama_lora_adapters_apply(ctx, loras);
1854
- server_task_result result;
1855
- result.id = task.id;
1856
- result.stop = true;
1857
- result.error = false;
1858
- result.data = json{{ "success", true }};
1859
- queue_results.send(result);
2459
+ common_lora_adapters_apply(ctx, loras);
2460
+ auto res = std::make_unique<server_task_result_apply_lora>();
2461
+ res->id = task.id;
2462
+ queue_results.send(std::move(res));
1860
2463
  } break;
1861
2464
  }
1862
2465
  }
1863
2466
 
1864
2467
  void update_slots() {
1865
- if (system_need_update) {
1866
- system_prompt_update();
1867
- }
1868
-
1869
2468
  // check if all slots are idle
1870
2469
  {
1871
2470
  bool all_idle = true;
@@ -1879,7 +2478,7 @@ struct server_context {
1879
2478
 
1880
2479
  if (all_idle) {
1881
2480
  SRV_INF("%s", "all slots are idle\n");
1882
- if (system_prompt.empty() && clean_kv_cache) {
2481
+ if (clean_kv_cache) {
1883
2482
  kv_cache_clear();
1884
2483
  }
1885
2484
 
@@ -1890,53 +2489,49 @@ struct server_context {
1890
2489
  {
1891
2490
  SRV_DBG("%s", "posting NEXT_RESPONSE\n");
1892
2491
 
1893
- server_task task;
1894
- task.type = SERVER_TASK_TYPE_NEXT_RESPONSE;
1895
- task.id_target = -1;
1896
-
2492
+ server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE);
2493
+ task.id = queue_tasks.get_new_id();
1897
2494
  queue_tasks.post(task);
1898
2495
  }
1899
2496
 
1900
2497
  // apply context-shift if needed
1901
2498
  // TODO: simplify and improve
1902
2499
  for (server_slot & slot : slots) {
1903
- if (slot.ga_n == 1) {
1904
- if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
1905
- if (!params.ctx_shift) {
1906
- // this check is redundant (for good)
1907
- // we should never get here, because generation should already stopped in process_token()
1908
- slot.release();
1909
- send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
1910
- continue;
1911
- }
1912
-
1913
- // Shift context
1914
- const int n_keep = slot.params.n_keep + add_bos_token;
1915
- const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
1916
- const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
2500
+ if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
2501
+ if (!params_base.ctx_shift) {
2502
+ // this check is redundant (for good)
2503
+ // we should never get here, because generation should have already stopped in process_token()
2504
+ slot.release();
2505
+ send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
2506
+ continue;
2507
+ }
1917
2508
 
1918
- SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
2509
+ // Shift context
2510
+ const int n_keep = slot.params.n_keep + add_bos_token;
2511
+ const int n_left = slot.n_past - n_keep;
2512
+ const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
1919
2513
 
1920
- llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
1921
- llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
2514
+ SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
1922
2515
 
1923
- if (slot.params.cache_prompt) {
1924
- for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
1925
- slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
1926
- }
2516
+ llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
2517
+ llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
1927
2518
 
1928
- slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
2519
+ if (slot.params.cache_prompt) {
2520
+ for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
2521
+ slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
1929
2522
  }
1930
2523
 
1931
- slot.n_past -= n_discard;
1932
-
1933
- slot.truncated = true;
2524
+ slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
1934
2525
  }
2526
+
2527
+ slot.n_past -= n_discard;
2528
+
2529
+ slot.truncated = true;
1935
2530
  }
1936
2531
  }
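The shift above keeps the first n_keep tokens of the sequence, discards the next n_discard, and slides the remaining cells back so generation can continue near the context limit. A minimal standalone sketch of the same index bookkeeping on a plain vector (no llama.cpp calls; the values are made up):

    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
        // pretend the slot has filled its context: tokens 0..n_past-1 are in the KV cache
        const int n_ctx  = 16;
        const int n_keep = 4;
        int n_past = n_ctx - 1;              // slot.n_past + 1 >= slot.n_ctx triggers the shift

        std::vector<int> cache(n_past);
        std::iota(cache.begin(), cache.end(), 0);

        const int n_left    = n_past - n_keep;
        const int n_discard = n_left / 2;    // default when params.n_discard == 0

        // cells [n_keep, n_keep + n_discard) are dropped, the tail moves back by n_discard
        // (llama_kv_cache_seq_rm + llama_kv_cache_seq_add do this on the real KV cache)
        for (size_t i = n_keep + n_discard; i < cache.size(); i++) {
            cache[i - n_discard] = cache[i];
        }
        cache.resize(cache.size() - n_discard);
        n_past -= n_discard;

        std::printf("n_keep=%d n_left=%d n_discard=%d new n_past=%d cache size=%zu\n",
                    n_keep, n_left, n_discard, n_past, cache.size());
        return 0;
    }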
1937
2532
 
1938
2533
  // start populating the batch for this iteration
1939
- llama_batch_clear(batch);
2534
+ common_batch_clear(batch);
1940
2535
 
1941
2536
  // first, add sampled tokens from any ongoing sequences
1942
2537
  for (auto & slot : slots) {
@@ -1946,11 +2541,7 @@ struct server_context {
1946
2541
 
1947
2542
  slot.i_batch = batch.n_tokens;
1948
2543
 
1949
- const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
1950
-
1951
- // TODO: we always have to take into account the "system_tokens"
1952
- // this is not great and needs to be improved somehow
1953
- llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
2544
+ common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
1954
2545
 
1955
2546
  slot.n_past += 1;
1956
2547
 
@@ -1958,8 +2549,8 @@ struct server_context {
1958
2549
  slot.cache_tokens.push_back(slot.sampled);
1959
2550
  }
1960
2551
 
1961
- SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_system_tokens = %d, n_cache_tokens = %d, truncated = %d\n",
1962
- slot.n_ctx, slot.n_past, (int) system_tokens.size(), (int) slot.cache_tokens.size(), slot.truncated);
2552
+ SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
2553
+ slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
1963
2554
  }
1964
2555
 
1965
2556
  // process in chunks of params.n_batch
@@ -1973,82 +2564,35 @@ struct server_context {
1973
2564
  int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
1974
2565
 
1975
2566
  // next, batch any pending prompts without exceeding n_batch
1976
- if (params.cont_batching || batch.n_tokens == 0) {
2567
+ if (params_base.cont_batching || batch.n_tokens == 0) {
1977
2568
  for (auto & slot : slots) {
1978
2569
  // this slot still has a prompt to be processed
1979
- if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
2570
+ if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
1980
2571
  auto & prompt_tokens = slot.prompt_tokens;
1981
2572
 
1982
- // we haven't tokenized the prompt yet - do it now:
1983
- if (prompt_tokens.empty()) {
1984
- SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
1985
-
1986
- slot.t_start_process_prompt = ggml_time_us();
1987
- slot.t_start_generation = 0;
1988
-
1989
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
1990
- const bool add_bos = llama_add_bos_token(model);
1991
- bool suff_rm_leading_spc = true;
1992
- if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
1993
- params.input_suffix.erase(0, 1);
1994
- suff_rm_leading_spc = false;
1995
- }
1996
-
1997
- auto prefix_tokens = tokenize(slot.params.input_prefix, false);
1998
- auto suffix_tokens = tokenize(slot.params.input_suffix, false);
1999
-
2000
- const int space_token = 29871; // TODO: this should not be hardcoded
2001
- if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
2002
- suffix_tokens.erase(suffix_tokens.begin());
2003
- }
2004
-
2005
- prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
2006
- suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
2007
-
2008
- auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
2009
- auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
2010
- if (add_bos) {
2011
- embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
2012
- }
2013
- embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
2014
-
2015
- const llama_token middle_token = llama_token_middle(model);
2016
- if (middle_token >= 0) {
2017
- embd_inp.push_back(middle_token);
2018
- }
2019
-
2020
- prompt_tokens = embd_inp;
2021
- } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2022
- // require slot.prompt to be array of 2 strings
2023
- if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
2024
- SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
2025
- slot.release();
2026
- send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
2027
- continue;
2028
- }
2029
-
2030
- // prompt: [BOS]query[EOS][SEP]doc[EOS]
2031
- prompt_tokens.clear();
2032
- prompt_tokens.push_back(llama_token_bos(model));
2033
- {
2034
- const auto part = tokenize(slot.prompt[0], false);
2035
- prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
2036
- }
2037
- prompt_tokens.push_back(llama_token_eos(model));
2038
- prompt_tokens.push_back(llama_token_sep(model));
2039
- {
2040
- const auto part = tokenize(slot.prompt[1], false);
2041
- prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
2042
- }
2043
- prompt_tokens.push_back(llama_token_eos(model));
2044
- } else {
2045
- prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
2046
- }
2047
-
2573
+ // TODO: maybe move branch to outside of this loop in the future
2574
+ if (slot.state == SLOT_STATE_STARTED) {
2575
+ slot.t_start_process_prompt = ggml_time_us();
2576
+ slot.t_start_generation = 0;
2577
+
2048
2578
  slot.n_past = 0;
2049
2579
  slot.n_prompt_tokens = prompt_tokens.size();
2580
+ slot.state = SLOT_STATE_PROCESSING_PROMPT;
2581
+
2582
+ SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
2050
2583
 
2051
- SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
2584
+ // print prompt tokens (for debugging)
2585
+ if (1) {
2586
+ // first 16 tokens (avoid flooding logs)
2587
+ for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
2588
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
2589
+ }
2590
+ } else {
2591
+ // all
2592
+ for (int i = 0; i < (int) prompt_tokens.size(); i++) {
2593
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
2594
+ }
2595
+ }
2052
2596
 
2053
2597
  // empty prompt passed -> release the slot and send empty response
2054
2598
  if (prompt_tokens.empty()) {
@@ -2060,17 +2604,24 @@ struct server_context {
2060
2604
  continue;
2061
2605
  }
2062
2606
 
2063
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2064
- // this prompt is too large to process - discard it
2607
+ if (slot.is_non_causal()) {
2065
2608
  if (slot.n_prompt_tokens > n_ubatch) {
2066
2609
  slot.release();
2067
2610
  send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
2068
2611
  continue;
2069
2612
  }
2613
+
2614
+ if (slot.n_prompt_tokens > slot.n_ctx) {
2615
+ slot.release();
2616
+ send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
2617
+ continue;
2618
+ }
2070
2619
  } else {
2071
- if (!params.ctx_shift) {
2620
+ if (!params_base.ctx_shift) {
2072
2621
  // if context shift is disabled, we make sure prompt size is smaller than KV size
2073
- if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
2622
+ // TODO: there should be a separate parameter that controls prompt truncation
2623
+ // context shift should be applied only during the generation phase
2624
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
2074
2625
  slot.release();
2075
2626
  send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
2076
2627
  continue;
@@ -2081,14 +2632,14 @@ struct server_context {
2081
2632
  }
2082
2633
  slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
2083
2634
 
2084
- // if input prompt is too big, truncate it (if group attention self-extend is disabled)
2085
- if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
2635
+ // if input prompt is too big, truncate it
2636
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
2086
2637
  const int n_left = slot.n_ctx - slot.params.n_keep;
2087
2638
 
2088
2639
  const int n_block_size = n_left / 2;
2089
2640
  const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
2090
2641
 
2091
- std::vector<llama_token> new_tokens(
2642
+ llama_tokens new_tokens(
2092
2643
  prompt_tokens.begin(),
2093
2644
  prompt_tokens.begin() + slot.params.n_keep);
2094
2645
 
@@ -2107,20 +2658,52 @@ struct server_context {
2107
2658
  GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
2108
2659
  }
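A standalone sketch of the truncation rule applied above: keep the first n_keep tokens, drop whole blocks of n_block_size from the middle, and keep the tail so the result fits below n_ctx. The tail-insertion step sits partly outside this hunk, so the sketch reconstructs it from the visible n_keep / n_block_size / erased_blocks arithmetic, with made-up values.

    #include <cstdio>
    #include <numeric>
    #include <vector>

    int main() {
        const int n_ctx    = 32;
        const int n_keep   = 4;
        const int n_prompt = 50;                       // too large for the slot context

        std::vector<int> prompt(n_prompt);
        std::iota(prompt.begin(), prompt.end(), 0);

        const int n_left        = n_ctx - n_keep;      // room left after the kept prefix
        const int n_block_size  = n_left / 2;
        const int erased_blocks = (n_prompt - n_keep - n_block_size) / n_block_size;

        // keep the prefix ...
        std::vector<int> truncated(prompt.begin(), prompt.begin() + n_keep);
        // ... then skip erased_blocks * n_block_size tokens and keep the rest
        truncated.insert(truncated.end(),
                         prompt.begin() + n_keep + erased_blocks * n_block_size,
                         prompt.end());

        std::printf("kept %zu of %d tokens (n_ctx = %d)\n", truncated.size(), n_prompt, n_ctx);
        return 0;
    }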
2109
2660
 
2110
- gpt_sampler_reset(slot.smpl);
2661
+ if (slot.params.cache_prompt) {
2662
+ // reuse any previously computed tokens that are common with the new prompt
2663
+ slot.n_past = common_lcp(slot.cache_tokens, prompt_tokens);
2111
2664
 
2112
- if (!slot.params.cache_prompt) {
2113
- slot.n_past_se = 0;
2114
- slot.ga_i = 0;
2115
- } else {
2116
- GGML_ASSERT(slot.ga_n == 1);
2665
+ // reuse chunks from the cached prompt by shifting their KV cache to the new position
2666
+ if (params_base.n_cache_reuse > 0) {
2667
+ size_t head_c = slot.n_past; // cache
2668
+ size_t head_p = slot.n_past; // current prompt
2117
2669
 
2118
- // reuse any previously computed tokens that are common with the new prompt
2119
- slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
2670
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params_base.n_cache_reuse, slot.n_past);
2671
+
2672
+ while (head_c < slot.cache_tokens.size() &&
2673
+ head_p < prompt_tokens.size()) {
2674
+
2675
+ size_t n_match = 0;
2676
+ while (head_c + n_match < slot.cache_tokens.size() &&
2677
+ head_p + n_match < prompt_tokens.size() &&
2678
+ slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
2120
2679
 
2121
- // push the prompt into the sampling context (do not apply grammar)
2122
- for (int i = 0; i < slot.n_past; ++i) {
2123
- gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
2680
+ n_match++;
2681
+ }
2682
+
2683
+ if (n_match >= (size_t) params_base.n_cache_reuse) {
2684
+ SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
2685
+ //for (size_t i = head_p; i < head_p + n_match; i++) {
2686
+ // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
2687
+ //}
2688
+
2689
+ const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
2690
+
2691
+ llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
2692
+ llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
2693
+
2694
+ for (size_t i = 0; i < n_match; i++) {
2695
+ slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
2696
+ slot.n_past++;
2697
+ }
2698
+
2699
+ head_c += n_match;
2700
+ head_p += n_match;
2701
+ } else {
2702
+ head_c += 1;
2703
+ }
2704
+ }
2705
+
2706
+ SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
2124
2707
  }
2125
2708
  }
2126
2709
  }
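Two pieces of arithmetic drive the prompt-cache path above: the longest common prefix between the cached tokens and the new prompt (what common_lcp presumably computes), and then, when n_cache_reuse is set, a scan for further matching chunks whose KV cells can be shifted into place. A minimal sketch of that matching on plain vectors (no KV-cache calls; illustrative token values):

    #include <cstdio>
    #include <vector>

    // longest common prefix of two token sequences (assumed behaviour of common_lcp)
    size_t lcp(const std::vector<int> & a, const std::vector<int> & b) {
        size_t n = 0;
        while (n < a.size() && n < b.size() && a[n] == b[n]) {
            n++;
        }
        return n;
    }

    int main() {
        const std::vector<int> cache  = {1, 2, 3, 4, 9, 9, 5, 6, 7, 8};
        const std::vector<int> prompt = {1, 2, 3, 4, 5, 6, 7, 8};
        const size_t n_cache_reuse = 3;  // minimum chunk size worth shifting

        size_t n_past = lcp(cache, prompt);  // tokens reused directly from the prefix
        size_t head_c = n_past;              // cursor into the cache
        size_t head_p = n_past;              // cursor into the new prompt

        while (head_c < cache.size() && head_p < prompt.size()) {
            size_t n_match = 0;
            while (head_c + n_match < cache.size() && head_p + n_match < prompt.size() &&
                   cache[head_c + n_match] == prompt[head_p + n_match]) {
                n_match++;
            }
            if (n_match >= n_cache_reuse) {
                // in the server, this chunk's KV cells are shifted from head_c to head_p
                n_past += n_match;
                head_c += n_match;
                head_p += n_match;
            } else {
                head_c += 1; // skip one cached token and try again
            }
        }
        std::printf("reusable tokens after prefix + chunk matching: %zu\n", n_past);
        return 0;
    }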
@@ -2130,16 +2713,13 @@ struct server_context {
2130
2713
  SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
2131
2714
 
2132
2715
  slot.n_past--;
2133
- if (slot.ga_i > 0) {
2134
- slot.n_past_se--;
2135
- }
2136
2716
  }
2137
2717
 
2138
2718
  slot.n_prompt_tokens_processed = 0;
2139
2719
  }
2140
2720
 
2141
2721
  // non-causal tasks require to fit the entire prompt in the physical batch
2142
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2722
+ if (slot.is_non_causal()) {
2143
2723
  // cannot fit the prompt in the current batch - will try next iter
2144
2724
  if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
2145
2725
  continue;
@@ -2147,10 +2727,7 @@ struct server_context {
2147
2727
  }
2148
2728
 
2149
2729
  // check that we are in the right batch_type, if not defer the slot
2150
- const bool slot_type =
2151
- slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
2152
- slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
2153
-
2730
+ int slot_type = slot.is_non_causal();
2154
2731
  if (batch_type == -1) {
2155
2732
  batch_type = slot_type;
2156
2733
  } else if (batch_type != slot_type) {
@@ -2158,55 +2735,32 @@ struct server_context {
2158
2735
  }
2159
2736
 
2160
2737
  // keep only the common part
2161
- int p0 = (int) system_tokens.size() + slot.n_past;
2162
- if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
2738
+ if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
2163
2739
  // could not partially delete (likely using a non-Transformer model)
2164
- llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
2740
+ llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
2165
2741
 
2166
- p0 = (int) system_tokens.size();
2167
- if (p0 != 0) {
2168
- // copy over the system prompt when there is one
2169
- llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
2170
- }
2171
-
2172
- // there is no common part left (except for the system prompt)
2742
+ // there is no common part left
2173
2743
  slot.n_past = 0;
2174
- slot.n_past_se = 0;
2175
- slot.ga_i = 0;
2176
- // TODO: is the system prompt ever in the sampling context?
2177
- gpt_sampler_reset(slot.smpl);
2178
2744
  }
2179
2745
 
2746
+ SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
2747
+
2180
2748
  // remove the non-common part from the cache
2181
2749
  slot.cache_tokens.resize(slot.n_past);
2182
2750
 
2183
- SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
2184
-
2185
- int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
2186
-
2187
- int32_t ga_i = slot.ga_i;
2188
- int32_t ga_n = slot.ga_n;
2189
- int32_t ga_w = slot.ga_w;
2190
-
2191
2751
  // add prompt tokens for processing in the current batch
2192
- // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
2193
- for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
2194
- if (slot.ga_n != 1) {
2195
- while (slot_npast >= ga_i + ga_w) {
2196
- const int bd = (ga_w/ga_n)*(ga_n - 1);
2197
- slot_npast -= bd;
2198
- ga_i += ga_w/ga_n;
2199
- }
2200
- }
2752
+ while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
2753
+ // without pooling, we want to output the embeddings for all the tokens in the batch
2754
+ const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE;
2201
2755
 
2202
- llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
2756
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd);
2203
2757
 
2204
2758
  if (slot.params.cache_prompt) {
2205
2759
  slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
2206
2760
  }
2207
2761
 
2208
2762
  slot.n_prompt_tokens_processed++;
2209
- slot_npast++;
2763
+ slot.n_past++;
2210
2764
  }
2211
2765
 
2212
2766
  SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
@@ -2217,6 +2771,13 @@ struct server_context {
2217
2771
 
2218
2772
  GGML_ASSERT(batch.n_tokens > 0);
2219
2773
 
2774
+ common_sampler_reset(slot.smpl);
2775
+
2776
+ // Process all prompt tokens through sampler system
2777
+ for (int i = 0; i < slot.n_prompt_tokens; ++i) {
2778
+ common_sampler_accept(slot.smpl, prompt_tokens[i], false);
2779
+ }
2780
+
2220
2781
  // extract the logits only for the last token
2221
2782
  batch.logits[batch.n_tokens - 1] = true;
2222
2783
 
@@ -2247,34 +2808,6 @@ struct server_context {
2247
2808
  for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
2248
2809
  const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
2249
2810
 
2250
- for (auto & slot : slots) {
2251
- if (slot.ga_n != 1) {
2252
- // context extension via Self-Extend
2253
- // TODO: simplify and/or abstract this
2254
- while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
2255
- const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
2256
- const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
2257
- const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
2258
-
2259
- SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
2260
- SLT_DBG(slot, "div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
2261
- SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
2262
-
2263
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
2264
- llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
2265
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
2266
-
2267
- slot.n_past_se -= bd;
2268
-
2269
- slot.ga_i += slot.ga_w / slot.ga_n;
2270
-
2271
- SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
2272
- }
2273
-
2274
- slot.n_past_se += n_tokens;
2275
- }
2276
- }
2277
-
2278
2811
  llama_batch batch_view = {
2279
2812
  n_tokens,
2280
2813
  batch.token + i,
@@ -2283,7 +2816,6 @@ struct server_context {
2283
2816
  batch.n_seq_id + i,
2284
2817
  batch.seq_id + i,
2285
2818
  batch.logits + i,
2286
- 0, 0, 0, // unused
2287
2819
  };
2288
2820
 
2289
2821
  const int ret = llama_decode(ctx, batch_view);
@@ -2315,7 +2847,7 @@ struct server_context {
2315
2847
  }
2316
2848
 
2317
2849
  if (slot.state == SLOT_STATE_DONE_PROMPT) {
2318
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
2850
+ if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) {
2319
2851
  // prompt evaluated for embedding
2320
2852
  send_embedding(slot, batch_view);
2321
2853
  slot.release();
@@ -2323,7 +2855,7 @@ struct server_context {
2323
2855
  continue; // continue loop of slots
2324
2856
  }
2325
2857
 
2326
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2858
+ if (slot.task_type == SERVER_TASK_TYPE_RERANK) {
2327
2859
  send_rerank(slot, batch_view);
2328
2860
  slot.release();
2329
2861
  slot.i_batch = -1;
@@ -2336,27 +2868,33 @@ struct server_context {
2336
2868
  continue; // continue loop of slots
2337
2869
  }
2338
2870
 
2339
- completion_token_output result;
2340
- const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
2871
+ const int tok_idx = slot.i_batch - i;
2872
+
2873
+ llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx);
2341
2874
 
2342
- gpt_sampler_accept(slot.smpl, id, true);
2875
+ slot.i_batch = -1;
2876
+
2877
+ common_sampler_accept(slot.smpl, id, true);
2343
2878
 
2344
2879
  slot.n_decoded += 1;
2880
+
2881
+ const int64_t t_current = ggml_time_us();
2882
+
2345
2883
  if (slot.n_decoded == 1) {
2346
- slot.t_start_generation = ggml_time_us();
2884
+ slot.t_start_generation = t_current;
2347
2885
  slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3;
2348
2886
  metrics.on_prompt_eval(slot);
2349
2887
  }
2350
2888
 
2351
- result.tok = id;
2889
+ slot.t_token_generation = (t_current - slot.t_start_generation) / 1e3;
2352
2890
 
2353
- const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
2891
+ completion_token_output result;
2892
+ result.tok = id;
2893
+ result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
2894
+ result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs
2354
2895
 
2355
- for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
2356
- result.probs.push_back({
2357
- cur_p->data[i].id,
2358
- i >= cur_p->size ? 0.0f : cur_p->data[i].p,
2359
- });
2896
+ if (slot.params.sampling.n_probs > 0) {
2897
+ populate_token_probs(slot, result, slot.params.post_sampling_probs, params_base.special, tok_idx);
2360
2898
  }
2361
2899
 
2362
2900
  if (!process_token(result, slot)) {
@@ -2365,9 +2903,98 @@ struct server_context {
2365
2903
  slot.print_timings();
2366
2904
  send_final_response(slot);
2367
2905
  metrics.on_prediction(slot);
2906
+ continue;
2907
+ }
2908
+ }
2909
+
2910
+ // do speculative decoding
2911
+ for (auto & slot : slots) {
2912
+ if (!slot.is_processing() || !slot.can_speculate()) {
2913
+ continue;
2368
2914
  }
2369
2915
 
2370
- slot.i_batch = -1;
2916
+ if (slot.state != SLOT_STATE_GENERATING) {
2917
+ continue;
2918
+ }
2919
+
2920
+ // determine the max draft that fits the current slot state
2921
+ int n_draft_max = slot.params.speculative.n_max;
2922
+
2923
+ // note: n_past is not yet increased for the `id` token sampled above
2924
+ // also, need to leave space for 1 extra token to allow context shifts
2925
+ n_draft_max = std::min(n_draft_max, slot.n_ctx - slot.n_past - 2);
2926
+
2927
+ if (slot.n_remaining > 0) {
2928
+ n_draft_max = std::min(n_draft_max, slot.n_remaining - 1);
2929
+ }
2930
+
2931
+ SLT_DBG(slot, "max possible draft: %d\n", n_draft_max);
2932
+
2933
+ if (n_draft_max < slot.params.speculative.n_min) {
2934
+ SLT_DBG(slot, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, slot.params.speculative.n_min);
2935
+
2936
+ continue;
2937
+ }
2938
+
2939
+ llama_token id = slot.sampled;
2940
+
2941
+ struct common_speculative_params params_spec;
2942
+ params_spec.n_draft = n_draft_max;
2943
+ params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.params.speculative.n_max;
2944
+ params_spec.p_min = slot.params.speculative.p_min;
2945
+
2946
+ llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, slot.cache_tokens, id);
2947
+
2948
+ // ignore small drafts
2949
+ if (slot.params.speculative.n_min > (int) draft.size()) {
2950
+ SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.params.speculative.n_min);
2951
+
2952
+ continue;
2953
+ }
2954
+
2955
+ // construct the speculation batch
2956
+ common_batch_clear(slot.batch_spec);
2957
+ common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true);
2958
+
2959
+ for (size_t i = 0; i < draft.size(); ++i) {
2960
+ common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true);
2961
+ }
2962
+
2963
+ SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens);
2964
+
2965
+ llama_decode(ctx, slot.batch_spec);
2966
+
2967
+ // the accepted tokens from the speculation
2968
+ const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft);
2969
+
2970
+ slot.n_past += ids.size();
2971
+ slot.n_decoded += ids.size();
2972
+
2973
+ slot.cache_tokens.push_back(id);
2974
+ slot.cache_tokens.insert(slot.cache_tokens.end(), ids.begin(), ids.end() - 1);
2975
+
2976
+ llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1);
2977
+
2978
+ for (size_t i = 0; i < ids.size(); ++i) {
2979
+ completion_token_output result;
2980
+
2981
+ result.tok = ids[i];
2982
+ result.text_to_send = common_token_to_piece(ctx, result.tok, params_base.special);
2983
+ result.prob = 1.0f; // set later
2984
+
2985
+ // TODO: set result.probs
2986
+
2987
+ if (!process_token(result, slot)) {
2988
+ // release slot because of stop condition
2989
+ slot.release();
2990
+ slot.print_timings();
2991
+ send_final_response(slot);
2992
+ metrics.on_prediction(slot);
2993
+ break;
2994
+ }
2995
+ }
2996
+
2997
+ SLT_DBG(slot, "accepted %d/%d draft tokens, new n_past = %d\n", (int) ids.size() - 1, (int) draft.size(), slot.n_past);
2371
2998
  }
2372
2999
  }
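The speculative path above follows the usual draft-and-verify loop: the draft model proposes up to n_draft_max tokens (common_speculative_gen_draft), the target model scores them in one batch, and common_sampler_sample_and_accept_n keeps the longest prefix the target agrees with. A self-contained toy of that control flow, with draft_next/target_next standing in for the two models (assumed, illustrative names):

    #include <cstdio>
    #include <vector>

    // toy "models": each maps the last token to the next one
    int draft_next (int tok) { return (tok * 3 + 1) % 7; }   // cheap, sometimes wrong
    int target_next(int tok) { return (tok * 3 + 1) % 11; }  // the model we trust

    int main() {
        int last = 2;
        const int n_draft = 4;

        // 1) the draft model proposes a short continuation
        std::vector<int> draft;
        for (int i = 0, t = last; i < n_draft; i++) {
            t = draft_next(t);
            draft.push_back(t);
        }

        // 2) verify: accept draft tokens as long as the target model would have
        //    produced the same token; stop at the first disagreement and keep the
        //    target's token there - roughly what sample_and_accept_n amounts to
        std::vector<int> accepted;
        int t = last;
        for (int d : draft) {
            const int want = target_next(t);
            accepted.push_back(want);   // the target's token is always kept
            if (want != d) {
                break;                  // mismatch: discard the rest of the draft
            }
            t = want;
        }

        std::printf("draft size = %zu, accepted = %zu\n", draft.size(), accepted.size());
        return 0;
    }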
2373
3000
 
@@ -2414,35 +3041,23 @@ inline void signal_handler(int signal) {
2414
3041
 
2415
3042
  int main(int argc, char ** argv) {
2416
3043
  // own arguments required by this example
2417
- gpt_params params;
3044
+ common_params params;
2418
3045
 
2419
- if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
3046
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
2420
3047
  return 1;
2421
3048
  }
2422
3049
 
2423
- gpt_init();
2424
-
2425
- // enabling this will output extra debug information in the HTTP responses from the server
2426
- // see format_final_response_oaicompat()
2427
- const bool verbose = params.verbosity > 9;
3050
+ common_init();
2428
3051
 
2429
3052
  // struct that contains llama context and inference
2430
3053
  server_context ctx_server;
2431
3054
 
2432
- if (!params.system_prompt.empty()) {
2433
- ctx_server.system_prompt_set(params.system_prompt);
2434
- }
2435
-
2436
- if (params.model_alias == "unknown") {
2437
- params.model_alias = params.model;
2438
- }
2439
-
2440
3055
  llama_backend_init();
2441
3056
  llama_numa_init(params.numa);
2442
3057
 
2443
3058
  LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
2444
3059
  LOG_INF("\n");
2445
- LOG_INF("%s\n", gpt_params_get_system_info(params).c_str());
3060
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
2446
3061
  LOG_INF("\n");
2447
3062
 
2448
3063
  std::unique_ptr<httplib::Server> svr;
@@ -2467,34 +3082,24 @@ int main(int argc, char ** argv) {
2467
3082
  std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
2468
3083
 
2469
3084
  svr->set_default_headers({{"Server", "llama.cpp"}});
2470
-
2471
- // CORS preflight
2472
- svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
2473
- // Access-Control-Allow-Origin is already set by middleware
2474
- res.set_header("Access-Control-Allow-Credentials", "true");
2475
- res.set_header("Access-Control-Allow-Methods", "POST");
2476
- res.set_header("Access-Control-Allow-Headers", "*");
2477
- return res.set_content("", "text/html"); // blank response, no data
2478
- });
2479
-
2480
3085
  svr->set_logger(log_server_request);
2481
3086
 
2482
3087
  auto res_error = [](httplib::Response & res, const json & error_data) {
2483
3088
  json final_response {{"error", error_data}};
2484
- res.set_content(final_response.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
3089
+ res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON);
2485
3090
  res.status = json_value(error_data, "code", 500);
2486
3091
  };
2487
3092
 
2488
3093
  auto res_ok = [](httplib::Response & res, const json & data) {
2489
- res.set_content(data.dump(-1, ' ', false, json::error_handler_t::replace), MIMETYPE_JSON);
3094
+ res.set_content(safe_json_to_str(data), MIMETYPE_JSON);
2490
3095
  res.status = 200;
2491
3096
  };
2492
3097
 
2493
- svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, std::exception_ptr ep) {
3098
+ svr->set_exception_handler([&res_error](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
2494
3099
  std::string message;
2495
3100
  try {
2496
3101
  std::rethrow_exception(ep);
2497
- } catch (std::exception & e) {
3102
+ } catch (const std::exception & e) {
2498
3103
  message = e.what();
2499
3104
  } catch (...) {
2500
3105
  message = "Unknown Exception";
@@ -2536,20 +3141,10 @@ int main(int argc, char ** argv) {
2536
3141
  //
2537
3142
 
2538
3143
  auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
2539
- // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
2540
- static const std::unordered_set<std::string> protected_endpoints = {
2541
- "/props",
2542
- "/completion",
2543
- "/completions",
2544
- "/v1/completions",
2545
- "/chat/completions",
2546
- "/v1/chat/completions",
2547
- "/infill",
2548
- "/tokenize",
2549
- "/detokenize",
2550
- "/embedding",
2551
- "/embeddings",
2552
- "/v1/embeddings",
3144
+ static const std::unordered_set<std::string> public_endpoints = {
3145
+ "/health",
3146
+ "/models",
3147
+ "/v1/models",
2553
3148
  };
2554
3149
 
2555
3150
  // If API key is not set, skip validation
@@ -2557,8 +3152,8 @@ int main(int argc, char ** argv) {
2557
3152
  return true;
2558
3153
  }
2559
3154
 
2560
- // If path is not in protected_endpoints list, skip validation
2561
- if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
3155
+ // If path is public or is static file, skip validation
3156
+ if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
2562
3157
  return true;
2563
3158
  }
2564
3159
 
@@ -2584,7 +3179,7 @@ int main(int argc, char ** argv) {
2584
3179
  auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
2585
3180
  server_state current_state = state.load();
2586
3181
  if (current_state == SERVER_STATE_LOADING_MODEL) {
2587
- auto tmp = string_split(req.path, '.');
3182
+ auto tmp = string_split<std::string>(req.path, '.');
2588
3183
  if (req.path == "/" || tmp.back() == "html") {
2589
3184
  res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
2590
3185
  res.status = 503;
@@ -2599,6 +3194,14 @@ int main(int argc, char ** argv) {
2599
3194
  // register server middlewares
2600
3195
  svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
2601
3196
  res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
3197
+ // If this is an OPTIONS request, skip validation because browsers don't include the Authorization header
3198
+ if (req.method == "OPTIONS") {
3199
+ res.set_header("Access-Control-Allow-Credentials", "true");
3200
+ res.set_header("Access-Control-Allow-Methods", "GET, POST");
3201
+ res.set_header("Access-Control-Allow-Headers", "*");
3202
+ res.set_content("", "text/html"); // blank response, no data
3203
+ return httplib::Server::HandlerResponse::Handled; // skip further processing
3204
+ }
2602
3205
  if (!middleware_server_state(req, res)) {
2603
3206
  return httplib::Server::HandlerResponse::Handled;
2604
3207
  }
@@ -2620,32 +3223,38 @@ int main(int argc, char ** argv) {
2620
3223
 
2621
3224
  const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
2622
3225
  if (!params.endpoint_slots) {
2623
- res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
3226
+ res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
2624
3227
  return;
2625
3228
  }
2626
3229
 
2627
3230
  // request slots data using task queue
2628
- server_task task;
3231
+ server_task task(SERVER_TASK_TYPE_METRICS);
2629
3232
  task.id = ctx_server.queue_tasks.get_new_id();
2630
- task.type = SERVER_TASK_TYPE_METRICS;
2631
-
2632
3233
  ctx_server.queue_results.add_waiting_task_id(task.id);
2633
3234
  ctx_server.queue_tasks.post(task, true); // high-priority task
2634
3235
 
2635
3236
  // get the result
2636
- server_task_result result = ctx_server.queue_results.recv(task.id);
3237
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
2637
3238
  ctx_server.queue_results.remove_waiting_task_id(task.id);
2638
3239
 
3240
+ if (result->is_error()) {
3241
+ res_error(res, result->to_json());
3242
+ return;
3243
+ }
3244
+
3245
+ // TODO: get rid of this dynamic_cast
3246
+ auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
3247
+ GGML_ASSERT(res_metrics != nullptr);
3248
+
2639
3249
  // optionally return "fail_on_no_slot" error
2640
- const int n_idle_slots = result.data.at("idle");
2641
3250
  if (req.has_param("fail_on_no_slot")) {
2642
- if (n_idle_slots == 0) {
3251
+ if (res_metrics->n_idle_slots == 0) {
2643
3252
  res_error(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE));
2644
3253
  return;
2645
3254
  }
2646
3255
  }
2647
3256
 
2648
- res_ok(res, result.data.at("slots"));
3257
+ res_ok(res, res_metrics->slots_data);
2649
3258
  };
2650
3259
 
2651
3260
  const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) {
@@ -2655,83 +3264,77 @@ int main(int argc, char ** argv) {
2655
3264
  }
2656
3265
 
2657
3266
  // request slots data using task queue
2658
- server_task task;
3267
+ server_task task(SERVER_TASK_TYPE_METRICS);
2659
3268
  task.id = ctx_server.queue_tasks.get_new_id();
2660
- task.id_target = -1;
2661
- task.type = SERVER_TASK_TYPE_METRICS;
2662
- task.data.push_back({{"reset_bucket", true}});
3269
+ task.metrics_reset_bucket = true;
2663
3270
 
2664
3271
  ctx_server.queue_results.add_waiting_task_id(task.id);
2665
3272
  ctx_server.queue_tasks.post(task, true); // high-priority task
2666
3273
 
2667
3274
  // get the result
2668
- server_task_result result = ctx_server.queue_results.recv(task.id);
3275
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
2669
3276
  ctx_server.queue_results.remove_waiting_task_id(task.id);
2670
3277
 
2671
- json data = result.data;
2672
-
2673
- const uint64_t n_prompt_tokens_processed = data.at("n_prompt_tokens_processed");
2674
- const uint64_t t_prompt_processing = data.at("t_prompt_processing");
2675
-
2676
- const uint64_t n_tokens_predicted = data.at("n_tokens_predicted");
2677
- const uint64_t t_tokens_generation = data.at("t_tokens_generation");
2678
-
2679
- const uint64_t n_decode_total = data.at("n_decode_total");
2680
- const uint64_t n_busy_slots_total = data.at("n_busy_slots_total");
3278
+ if (result->is_error()) {
3279
+ res_error(res, result->to_json());
3280
+ return;
3281
+ }
2681
3282
 
2682
- const int32_t kv_cache_used_cells = data.at("kv_cache_used_cells");
3283
+ // TODO: get rid of this dynamic_cast
3284
+ auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
3285
+ GGML_ASSERT(res_metrics != nullptr);
2683
3286
 
2684
3287
  // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names
2685
3288
  json all_metrics_def = json {
2686
3289
  {"counter", {{
2687
3290
  {"name", "prompt_tokens_total"},
2688
3291
  {"help", "Number of prompt tokens processed."},
2689
- {"value", (uint64_t) data.at("n_prompt_tokens_processed_total")}
3292
+ {"value", (uint64_t) res_metrics->n_prompt_tokens_processed_total}
2690
3293
  }, {
2691
3294
  {"name", "prompt_seconds_total"},
2692
3295
  {"help", "Prompt process time"},
2693
- {"value", (uint64_t) data.at("t_prompt_processing_total") / 1.e3}
3296
+ {"value", (uint64_t) res_metrics->t_prompt_processing_total / 1.e3}
2694
3297
  }, {
2695
3298
  {"name", "tokens_predicted_total"},
2696
3299
  {"help", "Number of generation tokens processed."},
2697
- {"value", (uint64_t) data.at("n_tokens_predicted_total")}
3300
+ {"value", (uint64_t) res_metrics->n_tokens_predicted_total}
2698
3301
  }, {
2699
3302
  {"name", "tokens_predicted_seconds_total"},
2700
3303
  {"help", "Predict process time"},
2701
- {"value", (uint64_t) data.at("t_tokens_generation_total") / 1.e3}
3304
+ {"value", (uint64_t) res_metrics->t_tokens_generation_total / 1.e3}
2702
3305
  }, {
2703
3306
  {"name", "n_decode_total"},
2704
3307
  {"help", "Total number of llama_decode() calls"},
2705
- {"value", n_decode_total}
3308
+ {"value", res_metrics->n_decode_total}
2706
3309
  }, {
2707
3310
  {"name", "n_busy_slots_per_decode"},
2708
3311
  {"help", "Average number of busy slots per llama_decode() call"},
2709
- {"value", (float) n_busy_slots_total / (float) n_decode_total}
3312
+ {"value", (float) res_metrics->n_busy_slots_total / (float) res_metrics->n_decode_total}
2710
3313
  }}},
2711
3314
  {"gauge", {{
2712
3315
  {"name", "prompt_tokens_seconds"},
2713
3316
  {"help", "Average prompt throughput in tokens/s."},
2714
- {"value", n_prompt_tokens_processed ? 1.e3 / t_prompt_processing * n_prompt_tokens_processed : 0.}
3317
+ {"value", res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.}
2715
3318
  },{
2716
3319
  {"name", "predicted_tokens_seconds"},
2717
3320
  {"help", "Average generation throughput in tokens/s."},
2718
- {"value", n_tokens_predicted ? 1.e3 / t_tokens_generation * n_tokens_predicted : 0.}
3321
+ {"value", res_metrics->n_tokens_predicted ? 1.e3 / res_metrics->t_tokens_generation * res_metrics->n_tokens_predicted : 0.}
2719
3322
  },{
2720
3323
  {"name", "kv_cache_usage_ratio"},
2721
3324
  {"help", "KV-cache usage. 1 means 100 percent usage."},
2722
- {"value", 1. * kv_cache_used_cells / params.n_ctx}
3325
+ {"value", 1. * res_metrics->kv_cache_used_cells / params.n_ctx}
2723
3326
  },{
2724
3327
  {"name", "kv_cache_tokens"},
2725
3328
  {"help", "KV-cache tokens."},
2726
- {"value", (uint64_t) data.at("kv_cache_tokens_count")}
3329
+ {"value", (uint64_t) res_metrics->kv_cache_tokens_count}
2727
3330
  },{
2728
3331
  {"name", "requests_processing"},
2729
3332
  {"help", "Number of request processing."},
2730
- {"value", (uint64_t) data.at("processing")}
3333
+ {"value", (uint64_t) res_metrics->n_processing_slots}
2731
3334
  },{
2732
3335
  {"name", "requests_deferred"},
2733
3336
  {"help", "Number of request deferred."},
2734
- {"value", (uint64_t) data.at("deferred")}
3337
+ {"value", (uint64_t) res_metrics->n_tasks_deferred}
2735
3338
  }}}
2736
3339
  };
2737
3340
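The rewritten `/metrics` handler above now reads fields off a typed `server_task_result_metrics` object instead of pulling values out of a JSON blob, and collects them in `all_metrics_def` grouped by Prometheus type (`counter`, `gauge`). The loop that turns this structure into the Prometheus text exposition format sits in an unchanged part of server.cpp and is not shown in this hunk; the sketch below illustrates how such a structure can be flattened. The helper name `render_prometheus()` and the `llamacpp:` metric prefix are assumptions for this example, not taken from the hunk.

```cpp
// Hedged sketch: flattening an all_metrics_def-style object into the
// Prometheus text exposition format. render_prometheus() is a made-up name;
// the real rendering loop lives in the unchanged part of server.cpp.
#include <nlohmann/json.hpp>
#include <iostream>
#include <sstream>
#include <string>

using json = nlohmann::ordered_json;

static std::string render_prometheus(const json & all_metrics_def) {
    std::stringstream prometheus;
    for (const auto & el : all_metrics_def.items()) {
        const auto & type        = el.key();   // "counter" or "gauge"
        const auto & metrics_def = el.value(); // array of {name, help, value}
        for (const auto & metric_def : metrics_def) {
            const std::string name  = metric_def.at("name");
            const std::string help  = metric_def.at("help");
            const double      value = metric_def.at("value");
            prometheus << "# HELP llamacpp:" << name << " " << help << "\n"
                       << "# TYPE llamacpp:" << name << " " << type << "\n"
                       << "llamacpp:" << name << " " << value << "\n";
        }
    }
    return prometheus.str();
}

int main() {
    const json all_metrics_def = {
        {"counter", {{
            {"name",  "prompt_tokens_total"},
            {"help",  "Number of prompt tokens processed."},
            {"value", 128}
        }}},
        {"gauge", {{
            {"name",  "kv_cache_usage_ratio"},
            {"help",  "KV-cache usage. 1 means 100 percent usage."},
            {"value", 0.25}
        }}},
    };
    std::cout << render_prometheus(all_metrics_def);
}
```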
 
@@ -2752,8 +3355,7 @@ int main(int argc, char ** argv) {
2752
3355
  }
2753
3356
  }
2754
3357
 
2755
- const int64_t t_start = data.at("t_start");
2756
- res.set_header("Process-Start-Time-Unix", std::to_string(t_start));
3358
+ res.set_header("Process-Start-Time-Unix", std::to_string(res_metrics->t_start));
2757
3359
 
2758
3360
  res.set_content(prometheus.str(), "text/plain; version=0.0.4");
2759
3361
  res.status = 200; // HTTP OK
@@ -2768,25 +3370,24 @@ int main(int argc, char ** argv) {
2768
3370
  }
2769
3371
  std::string filepath = params.slot_save_path + filename;
2770
3372
 
2771
- server_task task;
2772
- task.type = SERVER_TASK_TYPE_SLOT_SAVE;
2773
- task.data = {
2774
- { "id_slot", id_slot },
2775
- { "filename", filename },
2776
- { "filepath", filepath },
2777
- };
3373
+ server_task task(SERVER_TASK_TYPE_SLOT_SAVE);
3374
+ task.id = ctx_server.queue_tasks.get_new_id();
3375
+ task.slot_action.slot_id = id_slot;
3376
+ task.slot_action.filename = filename;
3377
+ task.slot_action.filepath = filepath;
2778
3378
 
2779
- const int id_task = ctx_server.queue_tasks.post(task);
2780
- ctx_server.queue_results.add_waiting_task_id(id_task);
3379
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3380
+ ctx_server.queue_tasks.post(task);
2781
3381
 
2782
- server_task_result result = ctx_server.queue_results.recv(id_task);
2783
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3382
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3383
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
2784
3384
 
2785
- if (result.error) {
2786
- res_error(res, result.data);
2787
- } else {
2788
- res_ok(res, result.data);
3385
+ if (result->is_error()) {
3386
+ res_error(res, result->to_json());
3387
+ return;
2789
3388
  }
3389
+
3390
+ res_ok(res, result->to_json());
2790
3391
  };
2791
3392
 
2792
3393
  const auto handle_slots_restore = [&ctx_server, &res_error, &res_ok, &params](const httplib::Request & req, httplib::Response & res, int id_slot) {
@@ -2798,45 +3399,45 @@ int main(int argc, char ** argv) {
2798
3399
  }
2799
3400
  std::string filepath = params.slot_save_path + filename;
2800
3401
 
2801
- server_task task;
2802
- task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
2803
- task.data = {
2804
- { "id_slot", id_slot },
2805
- { "filename", filename },
2806
- { "filepath", filepath },
2807
- };
3402
+ server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
3403
+ task.id = ctx_server.queue_tasks.get_new_id();
3404
+ task.slot_action.slot_id = id_slot;
3405
+ task.slot_action.filename = filename;
3406
+ task.slot_action.filepath = filepath;
2808
3407
 
2809
- const int id_task = ctx_server.queue_tasks.post(task);
2810
- ctx_server.queue_results.add_waiting_task_id(id_task);
3408
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3409
+ ctx_server.queue_tasks.post(task);
2811
3410
 
2812
- server_task_result result = ctx_server.queue_results.recv(id_task);
2813
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3411
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3412
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
2814
3413
 
2815
- if (result.error) {
2816
- res_error(res, result.data);
2817
- } else {
2818
- res_ok(res, result.data);
3414
+ if (result->is_error()) {
3415
+ res_error(res, result->to_json());
3416
+ return;
2819
3417
  }
3418
+
3419
+ GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
3420
+ res_ok(res, result->to_json());
2820
3421
  };
2821
3422
 
2822
3423
  const auto handle_slots_erase = [&ctx_server, &res_error, &res_ok](const httplib::Request & /* req */, httplib::Response & res, int id_slot) {
2823
- server_task task;
2824
- task.type = SERVER_TASK_TYPE_SLOT_ERASE;
2825
- task.data = {
2826
- { "id_slot", id_slot },
2827
- };
3424
+ server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
3425
+ task.id = ctx_server.queue_tasks.get_new_id();
3426
+ task.slot_action.slot_id = id_slot;
2828
3427
 
2829
- const int id_task = ctx_server.queue_tasks.post(task);
2830
- ctx_server.queue_results.add_waiting_task_id(id_task);
3428
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3429
+ ctx_server.queue_tasks.post(task);
2831
3430
 
2832
- server_task_result result = ctx_server.queue_results.recv(id_task);
2833
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3431
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3432
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
2834
3433
 
2835
- if (result.error) {
2836
- res_error(res, result.data);
2837
- } else {
2838
- res_ok(res, result.data);
3434
+ if (result->is_error()) {
3435
+ res_error(res, result->to_json());
3436
+ return;
2839
3437
  }
3438
+
3439
+ GGML_ASSERT(dynamic_cast<server_task_result_slot_erase*>(result.get()) != nullptr);
3440
+ res_ok(res, result->to_json());
2840
3441
  };
2841
3442
 
2842
3443
  const auto handle_slots_action = [&params, &res_error, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) {
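All three slot handlers above (save, restore, erase) now follow the same pattern: construct a typed `server_task`, take an id from `queue_tasks.get_new_id()`, register that id with `queue_results` *before* posting, then block on `recv()` and inspect the polymorphic `server_task_result_ptr`. Registering first matters because a result whose id nobody is waiting for can otherwise be dropped. Below is a minimal, self-contained sketch of that ordering with toy types; it is not the real `server_queue`/`server_response` implementation, only an illustration of the race it avoids.

```cpp
// Toy stand-ins for server_task_result_ptr / server_response: enough to show
// why add_waiting_task_id() must run before the task is posted.
#include <condition_variable>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <unordered_map>
#include <unordered_set>

struct toy_result {
    int  id;
    bool error;
};
using toy_result_ptr = std::unique_ptr<toy_result>;

struct toy_response_queue {
    std::mutex                              m;
    std::condition_variable                 cv;
    std::unordered_set<int>                 waiting;
    std::unordered_map<int, toy_result_ptr> results;

    void add_waiting_task_id(int id)    { std::lock_guard<std::mutex> l(m); waiting.insert(id); }
    void remove_waiting_task_id(int id) { std::lock_guard<std::mutex> l(m); waiting.erase(id); }

    void send(toy_result_ptr res) {
        std::lock_guard<std::mutex> l(m);
        if (waiting.count(res->id)) {        // results nobody registered for are dropped
            results[res->id] = std::move(res);
            cv.notify_all();
        }
    }

    toy_result_ptr recv(int id) {
        std::unique_lock<std::mutex> l(m);
        cv.wait(l, [&] { return results.count(id) != 0; });
        toy_result_ptr res = std::move(results[id]);
        results.erase(id);
        return res;
    }
};

int main() {
    toy_response_queue queue_results;
    const int id = 42;

    queue_results.add_waiting_task_id(id);   // 1. register interest first ...
    std::thread worker([&] {                 // 2. ... then "post" the task
        queue_results.send(toy_result_ptr(new toy_result{id, /*error=*/false}));
    });

    toy_result_ptr result = queue_results.recv(id);  // 3. block until the result arrives
    queue_results.remove_waiting_task_id(id);
    worker.join();

    std::cout << (result->error ? "error" : "ok") << std::endl;
    return 0;
}
```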
@@ -2869,31 +3470,74 @@ int main(int argc, char ** argv) {
2869
3470
  };
2870
3471
 
2871
3472
  const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
2872
- std::string template_key = "tokenizer.chat_template", curr_tmpl;
2873
- int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
2874
- if (tlen > 0) {
2875
- std::vector<char> curr_tmpl_buf(tlen + 1, 0);
2876
- if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
2877
- curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
2878
- }
2879
- }
3473
+ // this endpoint is publicly available, please only return what is safe to be exposed
2880
3474
  json data = {
2881
- { "system_prompt", ctx_server.system_prompt.c_str() },
2882
3475
  { "default_generation_settings", ctx_server.default_generation_settings_for_props },
2883
- { "total_slots", ctx_server.params.n_parallel },
2884
- { "chat_template", curr_tmpl.c_str() },
3476
+ { "total_slots", ctx_server.params_base.n_parallel },
3477
+ { "model_path", ctx_server.params_base.model },
3478
+ { "chat_template", llama_get_chat_template(ctx_server.model) },
2885
3479
  };
2886
3480
 
2887
3481
  res_ok(res, data);
2888
3482
  };
2889
3483
 
2890
- const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
2891
- if (ctx_server.params.embedding || ctx_server.params.reranking) {
2892
- res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
3484
+ const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3485
+ if (!ctx_server.params_base.endpoint_props) {
3486
+ res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
3487
+ return;
3488
+ }
3489
+
3490
+ json data = json::parse(req.body);
3491
+
3492
+ // update any props here
3493
+
3494
+ res_ok(res, {{ "success", true }});
3495
+ };
3496
+
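`handle_props` above now delegates the chat-template lookup that this hunk removes to a `llama_get_chat_template()` helper defined outside this hunk. A plausible reconstruction of that helper, based on the removed inline code, is sketched below; the exact signature and error handling are assumptions.

```cpp
// Hedged reconstruction of llama_get_chat_template(), based on the inline
// lookup removed from handle_props above. The real helper lives outside this
// hunk, so treat the details as an assumption.
#include <string>
#include <vector>
#include "llama.h"

static std::string llama_get_chat_template(const struct llama_model * model) {
    const std::string template_key = "tokenizer.chat_template";
    // first call with a null buffer to query the length of the metadata value
    const int32_t tlen = llama_model_meta_val_str(model, template_key.c_str(), nullptr, 0);
    if (tlen <= 0) {
        return ""; // model has no embedded chat template
    }
    std::vector<char> buf(tlen + 1, 0);
    if (llama_model_meta_val_str(model, template_key.c_str(), buf.data(), buf.size()) != tlen) {
        return "";
    }
    return std::string(buf.data(), tlen);
}
```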
3497
+ // handle completion-like requests (completion, chat, infill)
3498
+ // we can optionally provide a custom format for partial results and final results
3499
+ const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
3500
+ server_task_type type,
3501
+ json & data,
3502
+ httplib::Response & res,
3503
+ bool oaicompat = false,
3504
+ bool oaicompat_chat = false) {
3505
+ GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
3506
+
3507
+ if (ctx_server.params_base.embedding) {
3508
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
3509
+ return;
3510
+ }
3511
+
3512
+ auto completion_id = gen_chatcmplid();
3513
+ std::vector<server_task> tasks;
3514
+
3515
+ try {
3516
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, data.at("prompt"), true, true);
3517
+ tasks.reserve(tokenized_prompts.size());
3518
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
3519
+ server_task task = server_task(type);
3520
+
3521
+ task.id = ctx_server.queue_tasks.get_new_id();
3522
+ task.index = i;
3523
+
3524
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
3525
+ task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
3526
+ task.id_selected_slot = json_value(data, "id_slot", -1);
3527
+
3528
+ // OAI-compat
3529
+ task.params.oaicompat = oaicompat;
3530
+ task.params.oaicompat_chat = oaicompat_chat;
3531
+ task.params.oaicompat_cmpl_id = completion_id;
3532
+ // oaicompat_model is already populated by params_from_json_cmpl
3533
+
3534
+ tasks.push_back(task);
3535
+ }
3536
+ } catch (const std::exception & e) {
3537
+ res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
2893
3538
  return;
2894
3539
  }
2895
3540
 
2896
- std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
2897
3541
  ctx_server.queue_results.add_waiting_tasks(tasks);
2898
3542
  ctx_server.queue_tasks.post(tasks);
2899
3543
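The rewritten `handle_completions_generic` above tokenizes `data.at("prompt")` up front with `tokenize_input_prompts()` and creates one `server_task` per tokenized prompt, all sharing a single OAI-style completion id from `gen_chatcmplid()`. In practice a request body can therefore carry a single prompt or a batch. Below is a hedged illustration of such a payload; the field names (`prompt`, `stream`, `id_slot`) come from the hunk, while `n_predict` and the concrete values are invented for the example.

```cpp
// Hedged illustration: a /completions request body that yields two tasks in
// the loop above (one per prompt). Values are invented.
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::ordered_json;

int main() {
    const json data = {
        // a batch: each element becomes its own server_task with its own index
        {"prompt",    {"Write a haiku about autumn.", "Write a haiku about winter."}},
        {"n_predict", 32},
        {"stream",    false},
        {"id_slot",   -1},   // -1 lets the server pick a slot (task.id_selected_slot)
    };
    std::cout << data.dump(2) << std::endl;
}
```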
 
@@ -2901,15 +3545,15 @@ int main(int argc, char ** argv) {
2901
3545
  const auto task_ids = server_task::get_list_id(tasks);
2902
3546
 
2903
3547
  if (!stream) {
2904
- ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
3548
+ ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
2905
3549
  if (results.size() == 1) {
2906
3550
  // single result
2907
- res_ok(res, results[0].data);
3551
+ res_ok(res, results[0]->to_json());
2908
3552
  } else {
2909
3553
  // multiple results (multitask)
2910
3554
  json arr = json::array();
2911
- for (const auto & res : results) {
2912
- arr.push_back(res.data);
3555
+ for (auto & res : results) {
3556
+ arr.push_back(res->to_json());
2913
3557
  }
2914
3558
  res_ok(res, arr);
2915
3559
  }
@@ -2919,12 +3563,26 @@ int main(int argc, char ** argv) {
2919
3563
 
2920
3564
  ctx_server.queue_results.remove_waiting_task_ids(task_ids);
2921
3565
  } else {
2922
- const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
2923
- ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
2924
- return server_sent_event(sink, "data", result.data);
3566
+ const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
3567
+ ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
3568
+ json res_json = result->to_json();
3569
+ if (res_json.is_array()) {
3570
+ for (const auto & res : res_json) {
3571
+ if (!server_sent_event(sink, "data", res)) {
3572
+ return false;
3573
+ }
3574
+ }
3575
+ return true;
3576
+ } else {
3577
+ return server_sent_event(sink, "data", res_json);
3578
+ }
2925
3579
  }, [&](const json & error_data) {
2926
3580
  server_sent_event(sink, "error", error_data);
2927
3581
  });
3582
+ if (oaicompat) {
3583
+ static const std::string ev_done = "data: [DONE]\n\n";
3584
+ sink.write(ev_done.data(), ev_done.size());
3585
+ }
2928
3586
  sink.done();
2929
3587
  return false;
2930
3588
  };
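The streaming branch above forwards each result through `server_sent_event()` and, only for OAI-compatible requests, appends the `data: [DONE]` trailer that OpenAI clients use to detect end-of-stream. A hedged sketch of what a `server_sent_event()` helper of this shape looks like follows; the real one is defined in the server's utility code, so the exact body (logging, error handling) is an assumption.

```cpp
// Hedged sketch of a server_sent_event() helper as used above: serialize the
// JSON payload as one SSE event and push it into the httplib::DataSink.
// Returning false signals that the client connection was closed.
#include <nlohmann/json.hpp>
#include <string>
#include "httplib.h"

using json = nlohmann::ordered_json;

static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
    const std::string str =
        std::string(event) + ": " +
        data.dump(-1, ' ', false, json::error_handler_t::replace) +
        "\n\n"; // the blank line terminates the SSE event
    return sink.write(str.c_str(), str.size());
}
```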
@@ -2939,72 +3597,102 @@ int main(int argc, char ** argv) {
2939
3597
 
2940
3598
  const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2941
3599
  json data = json::parse(req.body);
2942
- return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
2943
- };
2944
-
2945
- const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2946
- json data = json::parse(req.body);
2947
- return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
3600
+ return handle_completions_generic(
3601
+ SERVER_TASK_TYPE_COMPLETION,
3602
+ data,
3603
+ res,
3604
+ /* oaicompat */ false,
3605
+ /* oaicompat_chat */ false);
2948
3606
  };
2949
3607
 
2950
- // TODO: maybe merge this function with "handle_completions_generic"
2951
- const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
2952
- if (ctx_server.params.embedding || ctx_server.params.reranking) {
2953
- res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
3608
+ const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
3609
+ // check model compatibility
3610
+ std::string err;
3611
+ if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
3612
+ err += "prefix token is missing. ";
3613
+ }
3614
+ if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
3615
+ err += "suffix token is missing. ";
3616
+ }
3617
+ if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
3618
+ err += "middle token is missing. ";
3619
+ }
3620
+ if (!err.empty()) {
3621
+ res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
2954
3622
  return;
2955
3623
  }
2956
3624
 
2957
- json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
3625
+ json data = json::parse(req.body);
2958
3626
 
2959
- std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
2960
- ctx_server.queue_results.add_waiting_tasks(tasks);
2961
- ctx_server.queue_tasks.post(tasks);
3627
+ // validate input
3628
+ if (data.contains("prompt") && !data.at("prompt").is_string()) {
3629
+ // prompt is optional
3630
+ res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
3631
+ }
2962
3632
 
2963
- bool stream = json_value(data, "stream", false);
2964
- const auto task_ids = server_task::get_list_id(tasks);
2965
- const auto completion_id = gen_chatcmplid();
3633
+ if (!data.contains("input_prefix")) {
3634
+ res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
3635
+ }
2966
3636
 
2967
- if (!stream) {
2968
- ctx_server.receive_cmpl_results(task_ids, [&](const std::vector<server_task_result> & results) {
2969
- // multitask is never support in chat completion, there is only one result
2970
- json result_oai = format_final_response_oaicompat(data, results[0].data, completion_id, /*.streaming =*/ false, verbose);
2971
- res_ok(res, result_oai);
2972
- }, [&](const json & error_data) {
2973
- res_error(res, error_data);
2974
- });
3637
+ if (!data.contains("input_suffix")) {
3638
+ res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
3639
+ }
2975
3640
 
2976
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
2977
- } else {
2978
- const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
2979
- ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
2980
- std::vector<json> result_array = format_partial_response_oaicompat(result.data, completion_id);
2981
- for (auto & event_data : result_array) {
2982
- if (event_data.empty()) {
2983
- continue; // skip the stop token
2984
- }
2985
- if (!server_sent_event(sink, "data", event_data)) {
2986
- return false; // connection is closed
2987
- }
2988
- }
2989
- return true; // ok
2990
- }, [&](const json & error_data) {
2991
- server_sent_event(sink, "error", error_data);
2992
- });
2993
- static const std::string ev_done = "data: [DONE]\n\n";
2994
- sink.write(ev_done.data(), ev_done.size());
2995
- sink.done();
2996
- return true;
2997
- };
3641
+ if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
3642
+ // input_extra is optional
3643
+ res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
3644
+ return;
3645
+ }
2998
3646
 
2999
- auto on_complete = [task_ids, &ctx_server] (bool) {
3000
- ctx_server.queue_results.remove_waiting_task_ids(task_ids);
3001
- };
3647
+ json input_extra = json_value(data, "input_extra", json::array());
3648
+ for (const auto & chunk : input_extra) {
3649
+ // { "text": string, "filename": string }
3650
+ if (!chunk.contains("text") || !chunk.at("text").is_string()) {
3651
+ res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
3652
+ return;
3653
+ }
3654
+ // filename is optional
3655
+ if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
3656
+ res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
3657
+ return;
3658
+ }
3659
+ }
3660
+ data["input_extra"] = input_extra; // default to an empty array if it does not exist
3661
+
3662
+ std::string prompt = json_value(data, "prompt", std::string());
3663
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
3664
+ SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
3665
+ data["prompt"] = format_infill(
3666
+ ctx_server.ctx,
3667
+ data.at("input_prefix"),
3668
+ data.at("input_suffix"),
3669
+ data.at("input_extra"),
3670
+ ctx_server.params_base.n_batch,
3671
+ ctx_server.params_base.n_predict,
3672
+ ctx_server.slots[0].n_ctx, // TODO: there should be a better way
3673
+ ctx_server.params_base.spm_infill,
3674
+ tokenized_prompts[0]
3675
+ );
3002
3676
 
3003
- res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
3677
+ return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
3678
+ };
3679
+
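The new `/infill` validation above spells out the request contract: the model must expose FIM prefix/suffix/middle tokens, `input_prefix` and `input_suffix` are required, `prompt` (if present) must be a string, and `input_extra` is an optional array of `{"filename", "text"}` chunks that `format_infill()` folds into the FIM prompt. A hedged example of a body that passes all of these checks; only the field names and shapes come from the hunk, the file names and code fragments are invented.

```cpp
// Hedged example of an /infill request body that satisfies the validation
// added above. Values are invented.
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::ordered_json;

int main() {
    const json body = {
        {"input_prefix", "def fib(n):\n    "},
        {"input_suffix", "\n\nprint(fib(10))\n"},
        {"prompt",       ""},                        // optional, must be a string if present
        {"input_extra",  {{                          // optional extra-context chunks
            {"filename", "util.py"},                 // "filename" is optional per chunk
            {"text",     "def helper(x):\n    return x * 2\n"}
        }}},
        {"n_predict",    64},
        {"stream",       false},
    };
    std::cout << body.dump(2) << std::endl;
}
```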
3680
+ const auto handle_chat_completions = [&ctx_server, &params, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
3681
+ if (ctx_server.params_base.embedding) {
3682
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
3683
+ return;
3004
3684
  }
3685
+
3686
+ json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
3687
+ return handle_completions_generic(
3688
+ SERVER_TASK_TYPE_COMPLETION,
3689
+ data,
3690
+ res,
3691
+ /* oaicompat */ true,
3692
+ /* oaicompat_chat */ true);
3005
3693
  };
3006
3694
 
3007
- const auto handle_models = [&params, &ctx_server](const httplib::Request &, httplib::Response & res) {
3695
+ const auto handle_models = [&params, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
3008
3696
  json models = {
3009
3697
  {"object", "list"},
3010
3698
  {"data", {
@@ -3018,7 +3706,7 @@ int main(int argc, char ** argv) {
3018
3706
  }}
3019
3707
  };
3020
3708
 
3021
- res.set_content(models.dump(), MIMETYPE_JSON);
3709
+ res_ok(res, models);
3022
3710
  };
3023
3711
 
3024
3712
  const auto handle_tokenize = [&ctx_server, &res_ok](const httplib::Request & req, httplib::Response & res) {
@@ -3028,11 +3716,12 @@ int main(int argc, char ** argv) {
3028
3716
  if (body.count("content") != 0) {
3029
3717
  const bool add_special = json_value(body, "add_special", false);
3030
3718
  const bool with_pieces = json_value(body, "with_pieces", false);
3031
- std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special);
3719
+
3720
+ llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
3032
3721
 
3033
3722
  if (with_pieces) {
3034
3723
  for (const auto& token : tokens) {
3035
- std::string piece = llama_token_to_piece(ctx_server.ctx, token);
3724
+ std::string piece = common_token_to_piece(ctx_server.ctx, token);
3036
3725
  json piece_json;
3037
3726
 
3038
3727
  // Check if the piece is valid UTF-8
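The `/tokenize` handler now goes through `tokenize_mixed()` (with a trailing `true` that presumably enables parsing of special tokens) and `common_token_to_piece()`. Judging from the helper's name and its use here, `content` may be a plain string or an array mixing text fragments and raw token ids; that mixed shape, and the token id used below, are assumptions for illustration only.

```cpp
// Hedged example of a /tokenize request body. The mixed string/token-id
// array shape is assumed from tokenize_mixed()'s name and usage, and 15043
// is just an arbitrary token id.
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::ordered_json;

int main() {
    const json body = {
        {"content",     {"Hello", 15043, " world"}}, // strings and raw token ids
        {"add_special", true},                       // let the model prepend BOS etc.
        {"with_pieces", false},                      // true also returns the text piece per token
    };
    std::cout << body.dump(2) << std::endl;
}
```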
@@ -3065,7 +3754,7 @@ int main(int argc, char ** argv) {
3065
3754
 
3066
3755
  std::string content;
3067
3756
  if (body.count("tokens") != 0) {
3068
- const std::vector<llama_token> tokens = body.at("tokens");
3757
+ const llama_tokens tokens = body.at("tokens");
3069
3758
  content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
3070
3759
  }
3071
3760
 
@@ -3073,42 +3762,63 @@ int main(int argc, char ** argv) {
3073
3762
  res_ok(res, data);
3074
3763
  };
3075
3764
 
3076
- const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3077
- // TODO: somehow clean up this checks in the future
3078
- if (!ctx_server.params.embedding || ctx_server.params.reranking) {
3079
- res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
3765
+ const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
3766
+ const json body = json::parse(req.body);
3767
+
3768
+ if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
3769
+ res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
3080
3770
  return;
3081
3771
  }
3082
- const json body = json::parse(req.body);
3083
- bool is_openai = false;
3084
3772
 
3085
- // an input prompt can be a string or a list of tokens (integer)
3773
+ // for the shape of input/content, see tokenize_input_prompts()
3086
3774
  json prompt;
3087
3775
  if (body.count("input") != 0) {
3088
- is_openai = true;
3089
3776
  prompt = body.at("input");
3090
- } else if (body.count("content") != 0) {
3091
- // with "content", we only support single prompt
3092
- prompt = std::vector<std::string>{body.at("content")};
3777
+ } else if (body.contains("content")) {
3778
+ oaicompat = false;
3779
+ prompt = body.at("content");
3093
3780
  } else {
3094
3781
  res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
3095
3782
  return;
3096
3783
  }
3097
3784
 
3785
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
3786
+ for (const auto & tokens : tokenized_prompts) {
3787
+ // this check is necessary for models that do not add BOS token to the input
3788
+ if (tokens.empty()) {
3789
+ res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
3790
+ return;
3791
+ }
3792
+ }
3793
+
3098
3794
  // create and queue the task
3099
3795
  json responses = json::array();
3100
3796
  bool error = false;
3101
3797
  {
3102
- std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
3798
+ std::vector<server_task> tasks;
3799
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
3800
+ server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
3801
+
3802
+ task.id = ctx_server.queue_tasks.get_new_id();
3803
+ task.index = i;
3804
+ task.prompt_tokens = std::move(tokenized_prompts[i]);
3805
+
3806
+ // OAI-compat
3807
+ task.params.oaicompat = oaicompat;
3808
+
3809
+ tasks.push_back(task);
3810
+ }
3811
+
3103
3812
  ctx_server.queue_results.add_waiting_tasks(tasks);
3104
3813
  ctx_server.queue_tasks.post(tasks);
3105
3814
 
3106
3815
  // get the result
3107
3816
  std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
3108
3817
 
3109
- ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
3110
- for (const auto & res : results) {
3111
- responses.push_back(res.data);
3818
+ ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
3819
+ for (auto & res : results) {
3820
+ GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
3821
+ responses.push_back(res->to_json());
3112
3822
  }
3113
3823
  }, [&](const json & error_data) {
3114
3824
  res_error(res, error_data);
@@ -3123,17 +3833,24 @@ int main(int argc, char ** argv) {
3123
3833
  }
3124
3834
 
3125
3835
  // write JSON response
3126
- json root = is_openai
3127
- ? format_embeddings_response_oaicompat(body, responses)
3128
- : responses[0];
3836
+ json root = oaicompat ? format_embeddings_response_oaicompat(body, responses) : json(responses);
3129
3837
  res_ok(res, root);
3130
3838
  };
3131
3839
 
3840
+ const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3841
+ handle_embeddings_impl(req, res, false);
3842
+ };
3843
+
3844
+ const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
3845
+ handle_embeddings_impl(req, res, true);
3846
+ };
3847
+
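The split above gives two entry points into the same `handle_embeddings_impl`: the legacy `/embedding(s)` routes keep the non-OAI behavior (a `content` field, any pooling type, `oaicompat` forced to false), while `/v1/embeddings` enforces OAI compatibility (an `input` field, rejected when the pooling type is `none`). Hedged request examples for both follow; field names come from the hunk, the texts are invented.

```cpp
// Hedged request bodies for the two embedding entry points wired up above.
#include <nlohmann/json.hpp>
#include <iostream>

using json = nlohmann::ordered_json;

int main() {
    // POST /embeddings (legacy; using "content" switches oaicompat off)
    const json legacy_body = {
        {"content", "A quick brown fox"},
    };

    // POST /v1/embeddings (OAI-compatible; "input" may be a string or an array,
    // and pooling type "none" is rejected for this route)
    const json oai_body = {
        {"input", {"A quick brown fox", "jumps over the lazy dog"}},
    };

    std::cout << legacy_body.dump(2) << "\n" << oai_body.dump(2) << std::endl;
}
```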
3132
3848
  const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3133
- if (!ctx_server.params.reranking) {
3134
- res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
3849
+ if (!ctx_server.params_base.reranking || ctx_server.params_base.embedding) {
3850
+ res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
3135
3851
  return;
3136
3852
  }
3853
+
3137
3854
  const json body = json::parse(req.body);
3138
3855
 
3139
3856
  // TODO: implement
@@ -3163,29 +3880,33 @@ int main(int argc, char ** argv) {
3163
3880
  return;
3164
3881
  }
3165
3882
 
3166
- // construct prompt object: array of ["query", "doc0", "doc1", ...]
3167
- json prompt;
3168
- prompt.push_back(query);
3169
- for (const auto & doc : documents) {
3170
- prompt.push_back(doc);
3171
- }
3172
-
3173
- LOG_DBG("rerank prompt: %s\n", prompt.dump().c_str());
3883
+ llama_tokens tokenized_query = tokenize_input_prompts(ctx_server.ctx, query, /* add_special */ false, true)[0];
3174
3884
 
3175
3885
  // create and queue the task
3176
3886
  json responses = json::array();
3177
3887
  bool error = false;
3178
3888
  {
3179
- std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
3889
+ std::vector<server_task> tasks;
3890
+ std::vector<llama_tokens> tokenized_docs = tokenize_input_prompts(ctx_server.ctx, documents, /* add_special */ false, true);
3891
+ tasks.reserve(tokenized_docs.size());
3892
+ for (size_t i = 0; i < tokenized_docs.size(); i++) {
3893
+ server_task task = server_task(SERVER_TASK_TYPE_RERANK);
3894
+ task.id = ctx_server.queue_tasks.get_new_id();
3895
+ task.index = i;
3896
+ task.prompt_tokens = format_rerank(ctx_server.model, tokenized_query, tokenized_docs[i]);
3897
+ tasks.push_back(task);
3898
+ }
3899
+
3180
3900
  ctx_server.queue_results.add_waiting_tasks(tasks);
3181
3901
  ctx_server.queue_tasks.post(tasks);
3182
3902
 
3183
3903
  // get the result
3184
3904
  std::unordered_set<int> task_ids = server_task::get_list_id(tasks);
3185
3905
 
3186
- ctx_server.receive_cmpl_results(task_ids, [&](std::vector<server_task_result> & results) {
3187
- for (const auto & res : results) {
3188
- responses.push_back(res.data);
3906
+ ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
3907
+ for (auto & res : results) {
3908
+ GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
3909
+ responses.push_back(res->to_json());
3189
3910
  }
3190
3911
  }, [&](const json & error_data) {
3191
3912
  res_error(res, error_data);
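Reranking now tokenizes the query once and each document separately, relying on a `format_rerank()` helper (defined outside this hunk) to assemble the per-document prompt handed to each rerank task. A hedged reconstruction of the kind of layout such a helper produces is sketched below; the BOS + query + EOS + SEP + document + EOS ordering is an assumption, not something this diff confirms.

```cpp
// Hedged reconstruction of a format_rerank()-style helper as used above.
// The exact token layout is an assumption; the real helper lives in the
// server's utility code. llama_tokens mirrors the std::vector<llama_token>
// alias used throughout the hunk.
#include <vector>
#include "llama.h"

using llama_tokens = std::vector<llama_token>;

static llama_tokens format_rerank(const struct llama_model * model,
                                  const llama_tokens & query,
                                  const llama_tokens & doc) {
    llama_tokens result;
    result.reserve(query.size() + doc.size() + 4);
    result.push_back(llama_token_bos(model));   // [BOS] query [EOS] [SEP] doc [EOS]
    result.insert(result.end(), query.begin(), query.end());
    result.push_back(llama_token_eos(model));
    result.push_back(llama_token_sep(model));
    result.insert(result.end(), doc.begin(), doc.end());
    result.push_back(llama_token_eos(model));
    return result;
}
```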
@@ -3236,59 +3957,59 @@ int main(int argc, char ** argv) {
3236
3957
  }
3237
3958
  }
3238
3959
 
3239
- server_task task;
3240
- task.type = SERVER_TASK_TYPE_SET_LORA;
3241
- const int id_task = ctx_server.queue_tasks.post(task);
3242
- ctx_server.queue_results.add_waiting_task_id(id_task);
3960
+ server_task task(SERVER_TASK_TYPE_SET_LORA);
3961
+ task.id = ctx_server.queue_tasks.get_new_id();
3962
+ ctx_server.queue_results.add_waiting_task_id(task.id);
3963
+ ctx_server.queue_tasks.post(task);
3243
3964
 
3244
- server_task_result result = ctx_server.queue_results.recv(id_task);
3245
- ctx_server.queue_results.remove_waiting_task_id(id_task);
3965
+ server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
3966
+ ctx_server.queue_results.remove_waiting_task_id(task.id);
3246
3967
 
3247
- res_ok(res, result.data);
3248
- res.status = 200; // HTTP OK
3249
- };
3968
+ if (result->is_error()) {
3969
+ res_error(res, result->to_json());
3970
+ return;
3971
+ }
3250
3972
 
3251
- auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
3252
- return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
3253
- res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
3254
- return false;
3255
- };
3973
+ GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
3974
+ res_ok(res, result->to_json());
3256
3975
  };
3257
3976
 
3258
3977
  //
3259
3978
  // Router
3260
3979
  //
3261
3980
 
3262
- // register static assets routes
3263
- if (!params.public_path.empty()) {
3264
- // Set the base directory for serving static files
3265
- svr->set_base_dir(params.public_path);
3266
- }
3267
-
3268
- // using embedded static files
3269
- svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
3270
- svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
3271
- svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
3272
- svr->Get("/json-schema-to-grammar.mjs", handle_static_file(json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
3273
-
3274
- // add new-ui files
3275
- svr->Get("/colorthemes.css", handle_static_file(colorthemes_css, colorthemes_css_len, "text/css; charset=utf-8"));
3276
- svr->Get("/style.css", handle_static_file(style_css, style_css_len, "text/css; charset=utf-8"));
3277
- svr->Get("/theme-beeninorder.css", handle_static_file(theme_beeninorder_css, theme_beeninorder_css_len, "text/css; charset=utf-8"));
3278
- svr->Get("/theme-ketivah.css", handle_static_file(theme_ketivah_css, theme_ketivah_css_len, "text/css; charset=utf-8"));
3279
- svr->Get("/theme-mangotango.css", handle_static_file(theme_mangotango_css, theme_mangotango_css_len, "text/css; charset=utf-8"));
3280
- svr->Get("/theme-playground.css", handle_static_file(theme_playground_css, theme_playground_css_len, "text/css; charset=utf-8"));
3281
- svr->Get("/theme-polarnight.css", handle_static_file(theme_polarnight_css, theme_polarnight_css_len, "text/css; charset=utf-8"));
3282
- svr->Get("/theme-snowstorm.css", handle_static_file(theme_snowstorm_css, theme_snowstorm_css_len, "text/css; charset=utf-8"));
3283
- svr->Get("/index-new.html", handle_static_file(index_new_html, index_new_html_len, "text/html; charset=utf-8"));
3284
- svr->Get("/system-prompts.js", handle_static_file(system_prompts_js, system_prompts_js_len, "text/javascript; charset=utf-8"));
3285
- svr->Get("/prompt-formats.js", handle_static_file(prompt_formats_js, prompt_formats_js_len, "text/javascript; charset=utf-8"));
3981
+ if (!params.webui) {
3982
+ LOG_INF("Web UI is disabled\n");
3983
+ } else {
3984
+ // register static assets routes
3985
+ if (!params.public_path.empty()) {
3986
+ // Set the base directory for serving static files
3987
+ bool is_found = svr->set_mount_point("/", params.public_path);
3988
+ if (!is_found) {
3989
+ LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
3990
+ return 1;
3991
+ }
3992
+ } else {
3993
+ // using embedded static index.html
3994
+ svr->Get("/", [](const httplib::Request & req, httplib::Response & res) {
3995
+ if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
3996
+ res.set_content("Error: gzip is not supported by this browser", "text/plain");
3997
+ } else {
3998
+ res.set_header("Content-Encoding", "gzip");
3999
+ res.set_content(reinterpret_cast<const char*>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
4000
+ }
4001
+ return false;
4002
+ });
4003
+ }
4004
+ }
3286
4005
 
3287
4006
  // register API routes
3288
- svr->Get ("/health", handle_health);
4007
+ svr->Get ("/health", handle_health); // public endpoint (no API key check)
3289
4008
  svr->Get ("/metrics", handle_metrics);
3290
4009
  svr->Get ("/props", handle_props);
3291
- svr->Get ("/v1/models", handle_models);
4010
+ svr->Post("/props", handle_props_change);
4011
+ svr->Get ("/models", handle_models); // public endpoint (no API key check)
4012
+ svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
3292
4013
  svr->Post("/completion", handle_completions); // legacy
3293
4014
  svr->Post("/completions", handle_completions);
3294
4015
  svr->Post("/v1/completions", handle_completions);
@@ -3297,7 +4018,7 @@ int main(int argc, char ** argv) {
3297
4018
  svr->Post("/infill", handle_infill);
3298
4019
  svr->Post("/embedding", handle_embeddings); // legacy
3299
4020
  svr->Post("/embeddings", handle_embeddings);
3300
- svr->Post("/v1/embeddings", handle_embeddings);
4021
+ svr->Post("/v1/embeddings", handle_embeddings_oai);
3301
4022
  svr->Post("/rerank", handle_rerank);
3302
4023
  svr->Post("/reranking", handle_rerank);
3303
4024
  svr->Post("/v1/rerank", handle_rerank);
@@ -3327,8 +4048,18 @@ int main(int argc, char ** argv) {
3327
4048
  llama_backend_free();
3328
4049
  };
3329
4050
 
3330
- // bind HTTP listen port, run the HTTP server in a thread
3331
- if (!svr->bind_to_port(params.hostname, params.port)) {
4051
+ // bind HTTP listen port
4052
+ bool was_bound = false;
4053
+ if (params.port == 0) {
4054
+ int bound_port = svr->bind_to_any_port(params.hostname);
4055
+ if ((was_bound = (bound_port >= 0))) {
4056
+ params.port = bound_port;
4057
+ }
4058
+ } else {
4059
+ was_bound = svr->bind_to_port(params.hostname, params.port);
4060
+ }
4061
+
4062
+ if (!was_bound) {
3332
4063
  //LOG_ERROR("couldn't bind HTTP server socket", {
3333
4064
  // {"hostname", params.hostname},
3334
4065
  // {"port", params.port},
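Port binding above now supports an ephemeral port: when the configured port is 0, cpp-httplib's `bind_to_any_port()` picks a free port and the chosen value is written back into `params.port`, so the later "server is listening on http://host:port" log line reports the real port. Below is a minimal standalone illustration of that cpp-httplib bind-then-listen pattern; it is not the server's own code.

```cpp
// Minimal cpp-httplib sketch of the pattern used above: bind_to_any_port()
// returns the OS-assigned port, listen_after_bind() then serves on it.
#include <cstdio>
#include <thread>
#include "httplib.h"

int main() {
    httplib::Server svr;
    svr.Get("/health", [](const httplib::Request &, httplib::Response & res) {
        res.set_content("{\"status\":\"ok\"}", "application/json");
    });

    const int port = svr.bind_to_any_port("127.0.0.1"); // OS picks a free port
    if (port < 0) {
        std::fprintf(stderr, "failed to bind\n");
        return 1;
    }
    std::printf("listening on http://127.0.0.1:%d\n", port);

    std::thread t([&] { svr.listen_after_bind(); });
    t.join();
}
```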
@@ -3337,6 +4068,8 @@ int main(int argc, char ** argv) {
3337
4068
  clean_up();
3338
4069
  return 1;
3339
4070
  }
4071
+
4072
+ // run the HTTP server in a thread
3340
4073
  std::thread t([&]() { svr->listen_after_bind(); });
3341
4074
  svr->wait_until_ready();
3342
4075
 
@@ -3366,10 +4099,11 @@ int main(int argc, char ** argv) {
3366
4099
  }
3367
4100
 
3368
4101
  // print sample chat example to make it clear which template is used
3369
- LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
4102
+ LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());
3370
4103
 
3371
4104
  ctx_server.queue_tasks.on_new_task(std::bind(
3372
4105
  &server_context::process_single_task, &ctx_server, std::placeholders::_1));
4106
+
3373
4107
  ctx_server.queue_tasks.on_update_slots(std::bind(
3374
4108
  &server_context::update_slots, &ctx_server));
3375
4109
 
@@ -3377,7 +4111,7 @@ int main(int argc, char ** argv) {
3377
4111
  ctx_server.queue_tasks.terminate();
3378
4112
  };
3379
4113
 
3380
- LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
4114
+ LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
3381
4115
 
3382
4116
  ctx_server.queue_tasks.start_loop();
3383
4117