@fugood/llama.node 0.3.2 → 0.3.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. package/CMakeLists.txt +2 -0
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +1 -1
  17. package/src/DetokenizeWorker.cpp +1 -1
  18. package/src/EmbeddingWorker.cpp +2 -2
  19. package/src/LlamaCompletionWorker.cpp +8 -8
  20. package/src/LlamaCompletionWorker.h +2 -2
  21. package/src/LlamaContext.cpp +8 -9
  22. package/src/TokenizeWorker.cpp +1 -1
  23. package/src/common.hpp +4 -4
  24. package/src/llama.cpp/.github/workflows/build.yml +43 -9
  25. package/src/llama.cpp/.github/workflows/docker.yml +3 -0
  26. package/src/llama.cpp/CMakeLists.txt +7 -4
  27. package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
  28. package/src/llama.cpp/common/CMakeLists.txt +0 -2
  29. package/src/llama.cpp/common/arg.cpp +642 -607
  30. package/src/llama.cpp/common/arg.h +22 -22
  31. package/src/llama.cpp/common/common.cpp +79 -281
  32. package/src/llama.cpp/common/common.h +130 -100
  33. package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
  34. package/src/llama.cpp/common/log.cpp +50 -50
  35. package/src/llama.cpp/common/log.h +18 -18
  36. package/src/llama.cpp/common/ngram-cache.cpp +36 -36
  37. package/src/llama.cpp/common/ngram-cache.h +19 -19
  38. package/src/llama.cpp/common/sampling.cpp +116 -108
  39. package/src/llama.cpp/common/sampling.h +20 -20
  40. package/src/llama.cpp/docs/build.md +37 -17
  41. package/src/llama.cpp/examples/CMakeLists.txt +1 -1
  42. package/src/llama.cpp/examples/batched/batched.cpp +14 -14
  43. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
  44. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
  45. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
  46. package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
  47. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
  48. package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
  49. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
  50. package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
  51. package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
  52. package/src/llama.cpp/examples/infill/infill.cpp +40 -86
  53. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
  54. package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
  55. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
  56. package/src/llama.cpp/examples/llava/clip.cpp +1 -0
  57. package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
  58. package/src/llama.cpp/examples/llava/llava.cpp +37 -3
  59. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
  60. package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
  61. package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
  62. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
  63. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
  64. package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
  65. package/src/llama.cpp/examples/main/main.cpp +64 -109
  66. package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
  67. package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
  68. package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
  69. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
  70. package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
  71. package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
  72. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
  73. package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
  74. package/src/llama.cpp/examples/server/server.cpp +553 -691
  75. package/src/llama.cpp/examples/server/utils.hpp +312 -25
  76. package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
  77. package/src/llama.cpp/examples/simple/simple.cpp +128 -96
  78. package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
  79. package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
  80. package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
  81. package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
  82. package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
  83. package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
  84. package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
  85. package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
  86. package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
  87. package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
  88. package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
  89. package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
  90. package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
  91. package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
  92. package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
  93. package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
  94. package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
  95. package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
  96. package/src/llama.cpp/ggml/include/ggml.h +53 -393
  97. package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
  98. package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
  99. package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
  100. package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
  101. package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
  102. package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
  103. package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
  104. package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
  105. package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
  106. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
  107. package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
  108. package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
  109. package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
  110. package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
  111. package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
  112. package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
  113. package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
  114. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
  115. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
  116. package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
  117. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
  118. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
  120. package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
  121. package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
  122. package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
  123. package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
  124. package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
  125. package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
  126. package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
  127. package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
  128. package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
  129. package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
  130. package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
  131. package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
  132. package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
  133. package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
  134. package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
  135. package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
  136. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  137. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
  138. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
  141. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
  142. package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
  143. package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
  144. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
  145. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
  146. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
  148. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
  149. package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
  150. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
  151. package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
  152. package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
  153. package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
  154. package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
  155. package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
  156. package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
  157. package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
  158. package/src/llama.cpp/include/llama.h +67 -33
  159. package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
  160. package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
  161. package/src/llama.cpp/src/CMakeLists.txt +2 -1
  162. package/src/llama.cpp/src/llama-sampling.cpp +745 -105
  163. package/src/llama.cpp/src/llama-sampling.h +21 -2
  164. package/src/llama.cpp/src/llama-vocab.cpp +49 -9
  165. package/src/llama.cpp/src/llama-vocab.h +35 -11
  166. package/src/llama.cpp/src/llama.cpp +2636 -2406
  167. package/src/llama.cpp/src/unicode-data.cpp +2 -2
  168. package/src/llama.cpp/tests/CMakeLists.txt +1 -2
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
  171. package/src/llama.cpp/tests/test-barrier.cpp +1 -0
  172. package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
  173. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
  174. package/src/llama.cpp/tests/test-log.cpp +2 -2
  175. package/src/llama.cpp/tests/test-opt.cpp +853 -142
  176. package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
  177. package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
  178. package/src/llama.cpp/tests/test-rope.cpp +1 -0
  179. package/src/llama.cpp/tests/test-sampling.cpp +162 -137
  180. package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
  181. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
  182. package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
  183. package/src/llama.cpp/common/train.cpp +0 -1515
  184. package/src/llama.cpp/common/train.h +0 -233
  185. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
  186. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
  187. package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
  188. /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
  189. /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
  190. /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
@@ -14,22 +14,13 @@
  #define MIMETYPE_JSON "application/json; charset=utf-8"

  // auto generated files (update with ./deps.sh)
- #include "colorthemes.css.hpp"
- #include "style.css.hpp"
- #include "theme-beeninorder.css.hpp"
- #include "theme-ketivah.css.hpp"
- #include "theme-mangotango.css.hpp"
- #include "theme-playground.css.hpp"
- #include "theme-polarnight.css.hpp"
- #include "theme-snowstorm.css.hpp"
  #include "index.html.hpp"
- #include "index-new.html.hpp"
- #include "index.js.hpp"
  #include "completion.js.hpp"
- #include "system-prompts.js.hpp"
- #include "prompt-formats.js.hpp"
- #include "json-schema-to-grammar.mjs.hpp"
  #include "loading.html.hpp"
+ #include "deps_daisyui.min.css.hpp"
+ #include "deps_markdown-it.js.hpp"
+ #include "deps_tailwindcss.js.hpp"
+ #include "deps_vue.esm-browser.js.hpp"

  #include <atomic>
  #include <condition_variable>
@@ -43,21 +34,6 @@
  #include <unordered_map>
  #include <unordered_set>

- #define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
- #define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
- #define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
- #define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, (slot).id_task, __VA_ARGS__)
-
- #define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
- #define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
- #define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
- #define SRV_DBG(fmt, ...) LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-
- #define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
- #define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
- #define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
- #define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__)
-
  using json = nlohmann::ordered_json;

  enum stop_type {
@@ -68,6 +44,7 @@ enum stop_type {
  // state diagram: https://github.com/ggerganov/llama.cpp/pull/9283
  enum slot_state {
  SLOT_STATE_IDLE,
+ SLOT_STATE_STARTED, // TODO: this state is only used for setting up the initial prompt processing; maybe merge it with launch_slot_with_task in the future
  SLOT_STATE_PROCESSING_PROMPT,
  SLOT_STATE_DONE_PROMPT,
  SLOT_STATE_GENERATING,
@@ -79,7 +56,7 @@ enum server_state {
  };

  enum server_task_type {
- SERVER_TASK_TYPE_COMPLETION,
+ SERVER_TASK_TYPE_INFERENCE,
  SERVER_TASK_TYPE_CANCEL,
  SERVER_TASK_TYPE_NEXT_RESPONSE,
  SERVER_TASK_TYPE_METRICS,
@@ -89,21 +66,22 @@ enum server_task_type {
  SERVER_TASK_TYPE_SET_LORA,
  };

- enum server_task_cmpl_type {
- SERVER_TASK_CMPL_TYPE_NORMAL,
- SERVER_TASK_CMPL_TYPE_EMBEDDING,
- SERVER_TASK_CMPL_TYPE_RERANK,
- SERVER_TASK_CMPL_TYPE_INFILL,
+ enum server_task_inf_type {
+ SERVER_TASK_INF_TYPE_COMPLETION,
+ SERVER_TASK_INF_TYPE_EMBEDDING,
+ SERVER_TASK_INF_TYPE_RERANK,
+ SERVER_TASK_INF_TYPE_INFILL,
  };

  struct server_task {
  int id = -1; // to be filled by server_queue
  int id_target = -1; // used by SERVER_TASK_TYPE_CANCEL

+ llama_tokens prompt_tokens;
  server_task_type type;
  json data;

- server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+ server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;

  // utility function
  static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -124,18 +102,25 @@ struct server_task_result {
  bool error;
  };

+ struct server_static_file {
+ const unsigned char * data;
+ unsigned int size;
+ const char * mime_type;
+ };
+
  struct slot_params {
  bool stream = true;
  bool cache_prompt = false; // remember the prompt to avoid reprocessing all prompt

- int32_t n_keep = 0; // number of tokens to keep from initial prompt
- int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
- int32_t n_predict = -1; // new tokens to predict
+ int32_t n_keep = 0; // number of tokens to keep from initial prompt
+ int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
+ int32_t n_predict = -1; // new tokens to predict
+ int32_t n_indent = 0; // mininum line indentation for the generated text in number of whitespace characters

- std::vector<std::string> antiprompt;
+ int64_t t_max_prompt_ms = -1; // TODO: implement
+ int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

- json input_prefix;
- json input_suffix;
+ std::vector<std::string> antiprompt;
  };

  struct server_slot {
@@ -160,21 +145,23 @@ struct server_slot {
  int32_t i_batch = -1;
  int32_t n_predict = -1; // TODO: disambiguate from params.n_predict

+ // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated
  int32_t n_prompt_tokens = 0;
  int32_t n_prompt_tokens_processed = 0;

- json prompt; // can be either a string, array of strings or array of token ids
+ // input prompt tokens
+ llama_tokens prompt_tokens;

- // when a task is submitted, we first tokenize the prompt and store it here
- std::vector<llama_token> prompt_tokens;
+ size_t last_nl_pos = 0;

  std::string generated_text;
- std::vector<llama_token> cache_tokens;
+ llama_tokens cache_tokens;
  std::vector<completion_token_output> generated_token_probs;

- server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+ server_task_inf_type inf_type = SERVER_TASK_INF_TYPE_COMPLETION;

  bool has_next_token = true;
+ bool has_new_line = false;
  bool truncated = false;
  bool stopped_eos = false;
  bool stopped_word = false;
@@ -188,26 +175,20 @@ struct server_slot {
  // sampling
  json json_schema;

- struct gpt_sampler_params sparams;
- struct gpt_sampler * smpl = nullptr;
+ struct common_sampler_params sparams;
+ struct common_sampler * smpl = nullptr;

  llama_token sampled;

- int32_t ga_i = 0; // group-attention state
- int32_t ga_n = 1; // group-attention factor
- int32_t ga_w = 512; // group-attention width
-
- int32_t n_past_se = 0; // self-extend
-
  // stats
- size_t n_sent_text = 0; // number of sent text character
+ size_t n_sent_text = 0; // number of sent text character
  size_t n_sent_token_probs = 0;

  int64_t t_start_process_prompt;
  int64_t t_start_generation;

  double t_prompt_processing; // ms
- double t_token_generation; // ms
+ double t_token_generation; // ms

  std::function<void(int)> callback_on_release;

@@ -215,7 +196,9 @@ struct server_slot {
  SLT_DBG(*this, "%s", "\n");

  n_prompt_tokens = 0;
+ last_nl_pos = 0;
  generated_text = "";
+ has_new_line = false;
  truncated = false;
  stopped_eos = false;
  stopped_word = false;
@@ -224,14 +207,12 @@ struct server_slot {
  n_past = 0;
  n_sent_text = 0;
  n_sent_token_probs = 0;
- cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
- ga_i = 0;
- n_past_se = 0;
+ inf_type = SERVER_TASK_INF_TYPE_COMPLETION;

  generated_token_probs.clear();
  }

- bool has_budget(gpt_params &global_params) {
+ bool has_budget(common_params &global_params) {
  if (params.n_predict == -1 && global_params.n_predict == -1) {
  return true; // limitless
  }
@@ -263,6 +244,7 @@ struct server_slot {
  if (is_processing()) {
  SLT_INF(*this, "stop processing: n_past = %d, truncated = %d\n", n_past, truncated);

+ t_last_used = ggml_time_us();
  t_token_generation = (ggml_time_us() - t_start_generation) / 1e3;
  state = SLOT_STATE_IDLE;
  callback_on_release(id);
@@ -393,8 +375,8 @@ struct server_queue {
  std::condition_variable condition_tasks;

  // callback functions
- std::function<void(server_task&)> callback_new_task;
- std::function<void(void)> callback_update_slots;
+ std::function<void(server_task)> callback_new_task;
+ std::function<void(void)> callback_update_slots;

  // Add a new task to the end of the queue
  int post(server_task task, bool front = false) {
@@ -446,7 +428,7 @@ struct server_queue {
  }

  // Register function to process a new task
- void on_new_task(std::function<void(server_task &)> callback) {
+ void on_new_task(std::function<void(server_task)> callback) {
  callback_new_task = std::move(callback);
  }

@@ -496,7 +478,7 @@ struct server_queue {
  lock.unlock();

  QUE_DBG("processing task, id = %d\n", task.id);
- callback_new_task(task);
+ callback_new_task(std::move(task));
  }

  // all tasks in the current loop is processed, slots data is now ready
@@ -611,9 +593,9 @@ struct server_response {
  struct server_context {
  llama_model * model = nullptr;
  llama_context * ctx = nullptr;
- std::vector<llama_lora_adapter_container> loras;
+ std::vector<common_lora_adapter_container> loras;

- gpt_params params;
+ common_params params;

  llama_batch batch = {};

@@ -623,12 +605,6 @@ struct server_context {

  int32_t n_ctx; // total context for all clients / slots

- // system prompt
- bool system_need_update = false;
-
- std::string system_prompt;
- std::vector<llama_token> system_tokens;
-
  // slots / clients
  std::vector<server_slot> slots;
  json default_generation_settings_for_props;
@@ -655,27 +631,22 @@ struct server_context {
  // Clear any sampling context
  for (server_slot & slot : slots) {
  if (slot.smpl != nullptr) {
- gpt_sampler_free(slot.smpl);
+ common_sampler_free(slot.smpl);
  }
  }

  llama_batch_free(batch);
  }

- bool load_model(const gpt_params & params_) {
+ bool load_model(const common_params & params_) {
  params = params_;

- // dedicate one sequence to the system prompt
- params.n_parallel += 1;
-
- llama_init_result llama_init = llama_init_from_gpt_params(params);
+ common_init_result llama_init = common_init_from_params(params);

  model = llama_init.model;
  ctx = llama_init.context;
  loras = llama_init.lora_adapters;

- params.n_parallel -= 1; // but be sneaky about it
-
  if (model == nullptr) {
  SRV_ERR("failed to load model, '%s'\n", params.model.c_str());
  return false;
@@ -690,11 +661,16 @@ struct server_context {
  }

  bool validate_model_chat_template() const {
- llama_chat_message chat[] = {{"user", "test"}};
-
- const int res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
-
- return res > 0;
+ std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
+ std::string template_key = "tokenizer.chat_template";
+ int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
+ if (res >= 0) {
+ llama_chat_message chat[] = {{"user", "test"}};
+ std::string tmpl = std::string(model_template.data(), model_template.size());
+ int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
+ return chat_res > 0;
+ }
+ return false;
  }

  void init() {
@@ -711,22 +687,6 @@ struct server_context {

  SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

- const int ga_n = params.grp_attn_n;
- const int ga_w = params.grp_attn_w;
-
- if (ga_n != 1) {
- GGML_ASSERT(ga_n > 0 && "ga_n must be positive"); // NOLINT
- GGML_ASSERT(ga_w % ga_n == 0 && "ga_w must be a multiple of ga_n"); // NOLINT
- //GGML_ASSERT(n_ctx_train % ga_w == 0 && "n_ctx_train must be a multiple of ga_w"); // NOLINT
- //GGML_ASSERT(n_ctx >= n_ctx_train * ga_n && "n_ctx must be at least n_ctx_train * ga_n"); // NOLINT
-
- SLT_INF(slot, "slot self-extend: ga_n = %d, ga_w = %d\n", ga_n, ga_w);
- }
-
- slot.ga_i = 0;
- slot.ga_n = ga_n;
- slot.ga_w = ga_w;
-
  slot.sparams = params.sparams;

  slot.callback_on_release = [this](int) {
@@ -753,47 +713,6 @@ struct server_context {
  metrics.init();
  }

- std::vector<llama_token> tokenize(const json & json_prompt, bool add_special) const {
- // TODO: currently, we tokenize using special tokens by default
- // this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
- // but it's better compared to completely ignoring ChatML and other chat templates
- const bool TMP_FORCE_SPECIAL = true;
-
- // If `add_bos` is true, we only add BOS, when json_prompt is a string,
- // or the first element of the json_prompt array is a string.
- std::vector<llama_token> prompt_tokens;
-
- if (json_prompt.is_array()) {
- bool first = true;
- for (const auto & p : json_prompt) {
- if (p.is_string()) {
- auto s = p.template get<std::string>();
-
- std::vector<llama_token> p;
- if (first) {
- p = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
- first = false;
- } else {
- p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
- }
-
- prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
- } else {
- if (first) {
- first = false;
- }
-
- prompt_tokens.push_back(p.template get<llama_token>());
- }
- }
- } else {
- auto s = json_prompt.template get<std::string>();
- prompt_tokens = ::llama_tokenize(ctx, s, add_special, TMP_FORCE_SPECIAL);
- }
-
- return prompt_tokens;
- }
-
  server_slot * get_slot_by_id(int id) {
  for (server_slot & slot : slots) {
  if (slot.id == id) {
@@ -804,12 +723,12 @@ struct server_context {
  return nullptr;
  }

- server_slot * get_available_slot(const std::string & prompt) {
+ server_slot * get_available_slot(const server_task & task) {
  server_slot * ret = nullptr;

  // find the slot that has at least n% prompt similarity
- if (ret == nullptr && slot_prompt_similarity != 0.0f && !prompt.empty()) {
- int max_lcp_len = 0;
+ if (ret == nullptr && slot_prompt_similarity != 0.0f) {
+ int lcs_len = 0;
  float similarity = 0;

  for (server_slot & slot : slots) {
@@ -818,32 +737,27 @@ struct server_context {
  continue;
  }

- // skip the slot if it does not contains prompt
- if (!slot.prompt.is_string()) {
+ // skip the slot if it does not contains cached tokens
+ if (slot.cache_tokens.empty()) {
  continue;
  }

- // current slot's prompt
- std::string slot_prompt = slot.prompt.get<std::string>();
-
- // length of the current slot's prompt
- int slot_prompt_len = slot_prompt.size();
+ // length of the Longest Common Subsequence between the current slot's prompt and the input prompt
+ int cur_lcs_len = longest_common_subsequence(slot.cache_tokens, task.prompt_tokens);

- // length of the Longest Common Prefix between the current slot's prompt and the input prompt
- int lcp_len = common_part(slot_prompt, prompt);
-
- // fraction of the common substring length compared to the current slot's prompt length
- similarity = static_cast<float>(lcp_len) / slot_prompt_len;
+ // fraction of the common subsequence length compared to the current slot's prompt length
+ float cur_similarity = static_cast<float>(cur_lcs_len) / static_cast<int>(slot.cache_tokens.size());

  // select the current slot if the criteria match
- if (lcp_len > max_lcp_len && similarity > slot_prompt_similarity) {
- max_lcp_len = lcp_len;
+ if (cur_lcs_len > lcs_len && cur_similarity > slot_prompt_similarity) {
+ lcs_len = cur_lcs_len;
+ similarity = cur_similarity;
  ret = &slot;
  }
  }

  if (ret != nullptr) {
- SLT_DBG(*ret, "selected slot by lcp similarity, max_lcp_len = %d, similarity = %f\n", max_lcp_len, similarity);
+ SLT_DBG(*ret, "selected slot by lcs similarity, lcs_len = %d, similarity = %f\n", lcs_len, similarity);
  }
  }

@@ -885,30 +799,57 @@ struct server_context {
  slot.oaicompat_model = "";
  }

- slot.params.stream = json_value(data, "stream", false);
- slot.params.cache_prompt = json_value(data, "cache_prompt", false);
- slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
- slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
- slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
- slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
- slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
- slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
- slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
- slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
- slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
- slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
- slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
- slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
- slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
- slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
- slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
- slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
- slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
- slot.params.n_keep = json_value(data, "n_keep", slot.params.n_keep);
- slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
- slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
- slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
- slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
+ slot.params.stream = json_value(data, "stream", false);
+ slot.params.cache_prompt = json_value(data, "cache_prompt", false);
+ slot.params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", default_params.n_predict));
+ slot.params.n_indent = json_value(data, "n_indent", default_params.n_indent);
+ slot.sparams.top_k = json_value(data, "top_k", default_sparams.top_k);
+ slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
+ slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
+ slot.sparams.xtc_probability = json_value(data, "xtc_probability", default_sparams.xtc_probability);
+ slot.sparams.xtc_threshold = json_value(data, "xtc_threshold", default_sparams.xtc_threshold);
+ slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
+ slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
+ slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
+ slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
+ slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
+ slot.sparams.penalty_repeat = json_value(data, "repeat_penalty", default_sparams.penalty_repeat);
+ slot.sparams.penalty_freq = json_value(data, "frequency_penalty", default_sparams.penalty_freq);
+ slot.sparams.penalty_present = json_value(data, "presence_penalty", default_sparams.penalty_present);
+ slot.sparams.dry_multiplier = json_value(data, "dry_multiplier", default_sparams.dry_multiplier);
+ slot.sparams.dry_base = json_value(data, "dry_base", default_sparams.dry_base);
+ slot.sparams.dry_allowed_length = json_value(data, "dry_allowed_length", default_sparams.dry_allowed_length);
+ slot.sparams.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", default_sparams.dry_penalty_last_n);
+ slot.sparams.mirostat = json_value(data, "mirostat", default_sparams.mirostat);
+ slot.sparams.mirostat_tau = json_value(data, "mirostat_tau", default_sparams.mirostat_tau);
+ slot.sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
+ slot.sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
+ slot.params.n_keep = json_value(data, "n_keep", default_params.n_keep);
+ slot.params.n_discard = json_value(data, "n_discard", default_params.n_discard);
+ slot.sparams.seed = json_value(data, "seed", default_sparams.seed);
+ slot.sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
+ slot.sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
+ //slot.params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", default_params.t_max_prompt_ms); // TODO: implement
+ slot.params.t_max_predict_ms = json_value(data, "t_max_predict_ms", default_params.t_max_predict_ms);
+
+ if (slot.sparams.dry_base < 1.0f)
+ {
+ slot.sparams.dry_base = default_sparams.dry_base;
+ }
+
+ // sequence breakers for DRY
+ {
+ // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format
+ // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39
+
+ if (data.contains("dry_sequence_breakers")) {
+ slot.sparams.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector<std::string>());
+ if (slot.sparams.dry_sequence_breakers.empty()) {
+ send_error(task, "Error: dry_sequence_breakers must be a non-empty array of strings", ERROR_TYPE_INVALID_REQUEST);
+ return false;
+ }
+ }
+ }

  // process "json_schema" and "grammar"
  if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
@@ -917,19 +858,14 @@ struct server_context {
  }
  if (data.contains("json_schema") && !data.contains("grammar")) {
  try {
- auto schema = json_value(data, "json_schema", json::object());
- slot.sparams.grammar = json_schema_to_grammar(schema);
+ auto schema = json_value(data, "json_schema", json::object());
+ slot.sparams.grammar = json_schema_to_grammar(schema);
  } catch (const std::exception & e) {
  send_error(task, std::string("\"json_schema\": ") + e.what(), ERROR_TYPE_INVALID_REQUEST);
  return false;
  }
  } else {
- slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
- }
-
- if (slot.params.cache_prompt && slot.ga_n != 1) {
- slot.params.cache_prompt = false;
- SLT_WRN(slot, "%s", "group-attention is not supported with prompt caching. disabling cache\n");
+ slot.sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
  }

  if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -938,39 +874,6 @@ struct server_context {
  SLT_WRN(slot, "n_predict = %d exceeds server configuration, setting to %d", slot.n_predict, slot.n_predict);
  }

- // infill
- slot.params.input_prefix = json_value(data, "input_prefix", default_params.input_prefix);
- slot.params.input_suffix = json_value(data, "input_suffix", default_params.input_suffix);
-
- // get prompt
- if (task.cmpl_type != SERVER_TASK_CMPL_TYPE_INFILL) {
- const auto & prompt = data.find("prompt");
- if (prompt == data.end()) {
- send_error(task, "\"prompt\" must be provided", ERROR_TYPE_INVALID_REQUEST);
- return false;
- }
-
- if ((prompt->is_string()) ||
- (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_string()) ||
- (prompt->is_array() && !prompt->empty() && prompt->at(0).is_number_integer())) {
- slot.prompt = *prompt;
- } else if (prompt->is_array() && prompt->size() == 1 && prompt->at(0).is_array()) {
- slot.prompt = prompt->at(0);
- } else if (prompt->is_array() && prompt->size() > 1) {
- // array of strings
- for (const auto & el : *prompt) {
- if (!el.is_string()) {
- send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
- return false;
- }
- }
- slot.prompt = *prompt;
- } else {
- send_error(task, "\"prompt\" must be a string, an array of strings or an array of integers", ERROR_TYPE_INVALID_REQUEST);
- return false;
- }
- }
-
  {
  slot.sparams.logit_bias.clear();

@@ -999,7 +902,7 @@ struct server_context {
  slot.sparams.logit_bias.push_back({tok, bias});
  }
  } else if (el[0].is_string()) {
- auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
+ auto toks = common_tokenize(model, el[0].get<std::string>(), false);
  for (auto tok : toks) {
  slot.sparams.logit_bias.push_back({tok, bias});
  }
@@ -1024,14 +927,22 @@ struct server_context {

  {
  const auto & samplers = data.find("samplers");
- if (samplers != data.end() && samplers->is_array()) {
- std::vector<std::string> sampler_names;
- for (const auto & name : *samplers) {
- if (name.is_string()) {
- sampler_names.emplace_back(name);
+ if (samplers != data.end()) {
+ if (samplers->is_array()) {
+ std::vector<std::string> sampler_names;
+ for (const auto & name : *samplers) {
+ if (name.is_string()) {
+ sampler_names.emplace_back(name);
+ }
  }
+ slot.sparams.samplers = common_sampler_types_from_names(sampler_names, false);
+ } else if (samplers->is_string()){
+ std::string sampler_string;
+ for (const auto & name : *samplers) {
+ sampler_string += name;
+ }
+ slot.sparams.samplers = common_sampler_types_from_chars(sampler_string);
  }
- slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
  } else {
  slot.sparams.samplers = default_sparams.samplers;
  }
@@ -1039,10 +950,10 @@ struct server_context {

  {
  if (slot.smpl != nullptr) {
- gpt_sampler_free(slot.smpl);
+ common_sampler_free(slot.smpl);
  }

- slot.smpl = gpt_sampler_init(model, slot.sparams);
+ slot.smpl = common_sampler_init(model, slot.sparams);
  if (slot.smpl == nullptr) {
  // for now, the only error that may happen here is invalid grammar
  send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
@@ -1050,8 +961,7 @@ struct server_context {
  }
  }

- slot.state = SLOT_STATE_PROCESSING_PROMPT;
- slot.prompt_tokens.clear();
+ slot.state = SLOT_STATE_STARTED;

  SLT_INF(slot, "%s", "processing task\n");

@@ -1066,59 +976,9 @@ struct server_context {
  clean_kv_cache = false;
  }

- void system_prompt_update() {
- SRV_DBG("updating system prompt: '%s'\n", system_prompt.c_str());
-
- kv_cache_clear();
- system_tokens.clear();
-
- if (!system_prompt.empty()) {
- system_tokens = ::llama_tokenize(ctx, system_prompt, true);
-
- const int32_t n_batch = llama_n_batch(ctx);
- const int32_t n_tokens_prompt = system_tokens.size();
-
- for (int32_t i = 0; i < n_tokens_prompt; i += n_batch) {
- const int32_t n_tokens = std::min(n_batch, n_tokens_prompt - i);
-
- llama_batch_clear(batch);
-
- for (int32_t j = 0; j < n_tokens; ++j) {
- llama_batch_add(batch, system_tokens[i + j], i + j, { 0 }, false);
- }
-
- if (llama_decode(ctx, batch) != 0) {
- SRV_ERR("%s", "llama_decode() failed\n");
- return;
- }
- }
-
- // assign the system KV cache to all parallel sequences
- for (int32_t i = 1; i <= params.n_parallel; ++i) {
- llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
- }
- }
-
- system_need_update = false;
- }
-
- bool system_prompt_set(const std::string & sys_prompt) {
- SRV_DBG("system prompt set: '%s'\n", system_prompt.c_str());
-
- system_prompt = sys_prompt;
-
- // release all slots
- for (server_slot & slot : slots) {
- slot.release();
- }
-
- system_need_update = true;
- return true;
- }
-
  bool process_token(completion_token_output & result, server_slot & slot) {
  // remember which tokens were sampled - used for repetition penalties during sampling
- const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
+ const std::string token_str = common_token_to_piece(ctx, result.tok, params.special);
  slot.sampled = result.tok;

  // search stop word and delete it
@@ -1151,22 +1011,21 @@ struct server_context {
  size_t pos = std::min(slot.n_sent_text, slot.generated_text.size());

  const std::string str_test = slot.generated_text.substr(pos);
- bool is_stop_full = false;
+ bool send_text = true;

  size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
  if (stop_pos != std::string::npos) {
- is_stop_full = true;
  slot.generated_text.erase(
  slot.generated_text.begin() + pos + stop_pos,
  slot.generated_text.end());
  pos = std::min(slot.n_sent_text, slot.generated_text.size());
- } else {
- is_stop_full = false;
+ } else if (slot.has_next_token) {
  stop_pos = slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_PARTIAL);
+ send_text = stop_pos == std::string::npos;
  }

  // check if there is any token to predict
- if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) {
+ if (send_text) {
  // no send the stop word in the response
  result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
  slot.n_sent_text += result.text_to_send.size();
@@ -1191,13 +1050,63 @@ struct server_context {
  SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
  }

+ if (slot.has_new_line) {
+ // if we have already seen a new line, we stop after a certain time limit
+ if (slot.params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.params.t_max_predict_ms)) {
+ slot.stopped_limit = true;
+ slot.has_next_token = false;
+
+ SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.params.t_max_predict_ms);
+ }
+
+ // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent
+ if (slot.params.n_indent > 0) {
+ // check the current indentation
+ // TODO: improve by not doing it more than once for each new line
+ if (slot.last_nl_pos > 0) {
+ size_t pos = slot.last_nl_pos;
+
+ int n_indent = 0;
+ while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) {
+ n_indent++;
+ pos++;
+ }
+
+ if (pos < slot.generated_text.size() && n_indent < slot.params.n_indent) {
+ slot.stopped_limit = true;
+ slot.has_next_token = false;
+
+ // cut the last line
+ slot.generated_text.erase(pos, std::string::npos);
+
+ SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent);
+ }
+ }
+
+ // find the next new line
+ {
+ const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos);
+
+ if (pos != std::string::npos) {
+ slot.last_nl_pos = pos + 1;
+ }
+ }
+ }
+ }
+
+ // check if there is a new line in the generated text
+ if (result.text_to_send.find('\n') != std::string::npos) {
+ slot.has_new_line = true;
+ }
+
  // if context shift is disabled, we stop when it reaches the context limit
- if (slot.n_decoded >= slot.n_ctx) {
+ if (slot.n_past >= slot.n_ctx) {
  slot.truncated = true;
  slot.stopped_limit = true;
  slot.has_next_token = false;

- SLT_DBG(slot, "stopped due to running out of context capacity, n_decoded = %d, n_ctx = %d\n", slot.n_decoded, slot.n_ctx);
+ SLT_DBG(slot, "stopped due to running out of context capacity, n_past = %d, n_prompt_tokens = %d, n_decoded = %d, n_ctx = %d\n",
+ slot.n_decoded, slot.n_prompt_tokens, slot.n_past, slot.n_ctx);
  }

  if (llama_token_is_eog(model, result.tok)) {
@@ -1209,18 +1118,18 @@ struct server_context {

  const auto n_ctx_train = llama_n_ctx_train(model);

- if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.ga_n == 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
+ if (slot.params.n_predict < 1 && slot.n_predict < 1 && slot.n_prompt_tokens + slot.n_decoded >= n_ctx_train) {
  slot.truncated = true;
  slot.stopped_limit = true;
  slot.has_next_token = false; // stop prediction

  SLT_WRN(slot,
- "n_predict (%d) is not set and self-context extend is disabled. "
+ "n_predict (%d) is set for infinite generation. "
  "Limiting generated tokens to n_ctx_train (%d) to avoid EOS-less generation infinite loop\n",
  slot.params.n_predict, n_ctx_train);
  }

- SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: '%s'\n", slot.n_decoded, slot.n_remaining, token_str.c_str());
+ SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str());

  return slot.has_next_token; // continue
  }
@@ -1229,7 +1138,7 @@ struct server_context {
  std::vector<std::string> samplers;
  samplers.reserve(slot.sparams.samplers.size());
  for (const auto & sampler : slot.sparams.samplers) {
- samplers.emplace_back(gpt_sampler_type_to_str(sampler));
+ samplers.emplace_back(common_sampler_type_to_str(sampler));
  }

  return json {
@@ -1237,19 +1146,25 @@ struct server_context {
  {"n_predict", slot.n_predict}, // Server configured n_predict
  {"model", params.model_alias},
  {"seed", slot.sparams.seed},
- {"seed_cur", slot.smpl ? gpt_sampler_get_seed(slot.smpl) : 0},
+ {"seed_cur", slot.smpl ? common_sampler_get_seed(slot.smpl) : 0},
  {"temperature", slot.sparams.temp},
  {"dynatemp_range", slot.sparams.dynatemp_range},
  {"dynatemp_exponent", slot.sparams.dynatemp_exponent},
  {"top_k", slot.sparams.top_k},
  {"top_p", slot.sparams.top_p},
  {"min_p", slot.sparams.min_p},
- {"tfs_z", slot.sparams.tfs_z},
+ {"xtc_probability", slot.sparams.xtc_probability},
+ {"xtc_threshold", slot.sparams.xtc_threshold},
  {"typical_p", slot.sparams.typ_p},
  {"repeat_last_n", slot.sparams.penalty_last_n},
  {"repeat_penalty", slot.sparams.penalty_repeat},
  {"presence_penalty", slot.sparams.penalty_present},
  {"frequency_penalty", slot.sparams.penalty_freq},
+ {"dry_multiplier", slot.sparams.dry_multiplier},
+ {"dry_base", slot.sparams.dry_base},
+ {"dry_allowed_length", slot.sparams.dry_allowed_length},
+ {"dry_penalty_last_n", slot.sparams.dry_penalty_last_n},
+ {"dry_sequence_breakers", slot.sparams.dry_sequence_breakers},
  {"mirostat", slot.sparams.mirostat},
  {"mirostat_tau", slot.sparams.mirostat_tau},
  {"mirostat_eta", slot.sparams.mirostat_eta},
@@ -1302,7 +1217,7 @@ struct server_context {
  };

  if (slot.sparams.n_probs > 0) {
- const std::vector<llama_token> to_send_toks = llama_tokenize(ctx, tkn.text_to_send, false);
+ const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
  const size_t probs_pos = std::min(slot.n_sent_token_probs, slot.generated_token_probs.size());
  const size_t probs_stop_pos = std::min(slot.n_sent_token_probs + to_send_toks.size(), slot.generated_token_probs.size());

@@ -1338,7 +1253,8 @@ struct server_context {
  {"tokens_predicted", slot.n_decoded},
  {"tokens_evaluated", slot.n_prompt_tokens},
  {"generation_settings", get_formated_generation(slot)},
- {"prompt", slot.prompt},
+ {"prompt", common_detokenize(ctx, slot.prompt_tokens)},
+ {"has_new_line", slot.has_new_line},
  {"truncated", slot.truncated},
  {"stopped_eos", slot.stopped_eos},
  {"stopped_word", slot.stopped_word},
@@ -1352,7 +1268,7 @@ struct server_context {
  if (slot.sparams.n_probs > 0) {
  std::vector<completion_token_output> probs;
  if (!slot.params.stream && slot.stopped_word) {
- const std::vector<llama_token> stop_word_toks = llama_tokenize(ctx, slot.stopping_word, false);
+ const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false);

  size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size());
  probs = std::vector<completion_token_output>(
@@ -1377,16 +1293,16 @@ struct server_context {

  void send_embedding(const server_slot & slot, const llama_batch & batch) {
  server_task_result res;
- res.id = slot.id_task;
- res.error = false;
- res.stop = true;
+ res.id = slot.id_task;
+ res.error = false;
+ res.stop = true;

  const int n_embd = llama_n_embd(model);

  std::vector<float> embd_res(n_embd, 0.0f);

  for (int i = 0; i < batch.n_tokens; ++i) {
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
  continue;
  }

@@ -1406,7 +1322,7 @@ struct server_context {
  continue;
  }

- llama_embd_normalize(embd, embd_res.data(), n_embd);
+ common_embd_normalize(embd, embd_res.data(), n_embd);

  res.data = json {
  {"embedding", embd_res},
@@ -1421,12 +1337,12 @@ struct server_context {

  void send_rerank(const server_slot & slot, const llama_batch & batch) {
  server_task_result res;
- res.id = slot.id_task;
- res.error = false;
- res.stop = true;
+ res.id = slot.id_task;
+ res.error = false;
+ res.stop = true;

  for (int i = 0; i < batch.n_tokens; ++i) {
- if (!batch.logits[i] || batch.seq_id[i][0] != slot.id + 1) {
+ if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) {
  continue;
  }

@@ -1461,19 +1377,17 @@ struct server_context {
  // Functions to create new task(s) and receive result(s)
  //

- std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+ // break the input "prompt" into multiple tasks if needed, then format and tokenize the input prompt(s)
+ std::vector<server_task> create_tasks_inference(json data, server_task_inf_type inf_type) {
  std::vector<server_task> tasks;
- auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
+ auto create_task = [&](json & task_data, llama_tokens & prompt_tokens) {
+ SRV_DBG("create task, n_tokens = %d\n", (int) prompt_tokens.size());
  server_task task;
- task.id = queue_tasks.get_new_id();
- task.cmpl_type = cmpl_type;
- task.type = SERVER_TASK_TYPE_COMPLETION;
- if (replace_prompt) {
- task.data = task_data;
- task.data["prompt"] = std::move(prompt);
- } else {
- task.data = std::move(task_data);
- }
+ task.id = queue_tasks.get_new_id();
+ task.inf_type = inf_type;
+ task.type = SERVER_TASK_TYPE_INFERENCE;
+ task.data = task_data;
+ task.prompt_tokens = std::move(prompt_tokens);
  tasks.push_back(std::move(task));
  };

@@ -1482,43 +1396,49 @@ struct server_context {
1482
1396
  throw std::runtime_error(error_msg);
1483
1397
  }
1484
1398
 
1485
- json prompt = data.at("prompt");
1486
-
1487
- // if the prompt is a singleton (i.e. a string or a list of tokens), we only need to create single task
1488
- if (prompt.is_string() || json_is_array_of_numbers(prompt)) {
1489
- data["index"] = 0;
1490
- create_task(data, false, nullptr);
1491
- }
1492
- // otherwise, it's a multiple-prompt task, we break it into smaller tasks
1493
- else if (prompt.is_array()) {
1494
- std::vector<json> prompts = prompt;
1495
- if (cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
1496
- // prompts[0] is the question
1497
- // the rest are the answers/documents
1498
- SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) prompts.size() - 1);
1499
- for (size_t i = 1; i < prompts.size(); i++) {
1500
- json qd;
1501
- qd.push_back(prompts[0]);
1502
- qd.push_back(prompts[i]);
1503
- data["index"] = i - 1;
1504
- create_task(data, true, qd);
1505
- }
1506
- } else {
1507
- SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) prompts.size());
1508
- for (size_t i = 0; i < prompts.size(); i++) {
1509
- const auto & e = prompts[i];
1510
- if (e.is_string() || json_is_array_of_numbers(e)) {
1399
+ // because llama_tokenize api is thread-safe, we can tokenize the prompt from HTTP thread
1400
+ bool add_special = inf_type != SERVER_TASK_INF_TYPE_RERANK && inf_type != SERVER_TASK_INF_TYPE_INFILL;
1401
+ std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx, data.at("prompt"), add_special, true);
1402
+ switch (inf_type) {
1403
+ case SERVER_TASK_INF_TYPE_RERANK:
1404
+ {
1405
+ // prompts[0] is the question
1406
+ // the rest are the answers/documents
1407
+ GGML_ASSERT(tokenized_prompts.size() > 1);
1408
+ SRV_DBG("creating rerank tasks, n_prompts = %d\n", (int) tokenized_prompts.size() - 1);
1409
+ for (size_t i = 1; i < tokenized_prompts.size(); i++) {
1410
+ data["index"] = i - 1;
1411
+ auto tokens = format_rerank(model, tokenized_prompts[0], tokenized_prompts[i]);
1412
+ create_task(data, tokens);
1413
+ }
1414
+ } break;
1415
+ case SERVER_TASK_INF_TYPE_INFILL:
1416
+ {
1417
+ SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
1418
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
1511
1419
  data["index"] = i;
1512
- create_task(data, true, e);
1513
- } else {
1514
- throw std::runtime_error(error_msg);
1420
+ auto tokens = format_infill(
1421
+ ctx,
1422
+ data.at("input_prefix"),
1423
+ data.at("input_suffix"),
1424
+ data.at("input_extra"),
1425
+ params.n_batch,
1426
+ params.n_predict,
1427
+ slots[0].n_ctx, // TODO: there should be a better way
1428
+ params.spm_infill,
1429
+ tokenized_prompts[i]
1430
+ );
1431
+ create_task(data, tokens);
1432
+ }
1433
+ } break;
1434
+ default:
1435
+ {
1436
+ SRV_DBG("creating multi-prompt tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
1437
+ for (size_t i = 0; i < tokenized_prompts.size(); i++) {
1438
+ data["index"] = i;
1439
+ create_task(data, tokenized_prompts[i]);
1515
1440
  }
1516
1441
  }
1517
- }
1518
- }
1519
- // invalid case
1520
- else {
1521
- throw std::runtime_error(error_msg);
1522
1442
  }
1523
1443
 
1524
1444
  return tasks;
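The refactor above moves prompt tokenization into create_tasks_inference and, for rerank requests, pairs the first prompt (the query) with each remaining prompt (the documents) before queuing one task per pair. Below is a minimal standalone sketch of just that splitting step, using plain token vectors and a hypothetical make_rerank_prompt helper in place of the server's format_rerank/create_task internals:

#include <cstdint>
#include <cstdio>
#include <vector>

using tokens = std::vector<int32_t>;

// hypothetical stand-in for format_rerank: [query][separator][document]
static tokens make_rerank_prompt(const tokens & query, const tokens & doc) {
    tokens out = query;
    out.push_back(-1); // placeholder separator token
    out.insert(out.end(), doc.begin(), doc.end());
    return out;
}

int main() {
    // prompts[0] is the query, the rest are documents (already tokenized)
    std::vector<tokens> prompts = { {1, 2, 3}, {10, 11}, {20, 21, 22} };

    std::vector<tokens> tasks;
    for (size_t i = 1; i < prompts.size(); i++) {
        tasks.push_back(make_rerank_prompt(prompts[0], prompts[i]));
    }

    printf("created %zu rerank tasks from %zu prompts\n", tasks.size(), prompts.size());
    return 0;
}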
@@ -1540,7 +1460,7 @@ struct server_context {
  queue_tasks.post(cancel_tasks, true);
  }
 
- // receive the results from task(s) created by create_tasks_cmpl
+ // receive the results from task(s) created by create_tasks_inference
  void receive_cmpl_results(
  const std::unordered_set<int> & id_tasks,
  const std::function<void(std::vector<server_task_result>&)> & result_handler,
@@ -1564,7 +1484,7 @@ struct server_context {
  result_handler(results);
  }
 
- // receive the results from task(s) created by create_tasks_cmpl, in stream mode
+ // receive the results from task(s) created by create_tasks_inference, in stream mode
  void receive_cmpl_results_stream(
  const std::unordered_set<int> & id_tasks, const
  std::function<bool(server_task_result&)> & result_handler, const
@@ -1595,24 +1515,13 @@ struct server_context {
1595
1515
  // Functions to process the task
1596
1516
  //
1597
1517
 
1598
- void process_single_task(const server_task & task) {
1518
+ void process_single_task(server_task task) {
1599
1519
  switch (task.type) {
1600
- case SERVER_TASK_TYPE_COMPLETION:
1520
+ case SERVER_TASK_TYPE_INFERENCE:
1601
1521
  {
1602
1522
  const int id_slot = json_value(task.data, "id_slot", -1);
1603
1523
 
1604
- server_slot * slot;
1605
-
1606
- if (id_slot != -1) {
1607
- slot = get_slot_by_id(id_slot);
1608
- } else {
1609
- std::string prompt;
1610
- if (task.data.contains("prompt") && task.data.at("prompt").is_string()) {
1611
- prompt = json_value(task.data, "prompt", std::string());
1612
- }
1613
-
1614
- slot = get_available_slot(prompt);
1615
- }
1524
+ server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task);
1616
1525
 
1617
1526
  if (slot == nullptr) {
1618
1527
  // if no slot is available, we defer this task for processing later
@@ -1627,21 +1536,12 @@ struct server_context {
1627
1536
  break;
1628
1537
  }
1629
1538
 
1630
- if (task.data.contains("system_prompt")) {
1631
- std::string sys_prompt = json_value(task.data, "system_prompt", std::string());
1632
- system_prompt_set(sys_prompt);
1633
-
1634
- for (server_slot & slot : slots) {
1635
- slot.n_past = 0;
1636
- slot.n_past_se = 0;
1637
- }
1638
- }
1639
-
1640
1539
  slot->reset();
1641
1540
 
1642
- slot->id_task = task.id;
1643
- slot->cmpl_type = task.cmpl_type;
1644
- slot->index = json_value(task.data, "index", 0);
1541
+ slot->id_task = task.id;
1542
+ slot->inf_type = task.inf_type;
1543
+ slot->index = json_value(task.data, "index", 0);
1544
+ slot->prompt_tokens = std::move(task.prompt_tokens);
1645
1545
 
1646
1546
  if (!launch_slot_with_task(*slot, task)) {
1647
1547
  SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id);
@@ -1671,12 +1571,13 @@ struct server_context {
1671
1571
 
1672
1572
  for (server_slot & slot : slots) {
1673
1573
  json slot_data = get_formated_generation(slot);
1674
- slot_data["id"] = slot.id;
1675
- slot_data["id_task"] = slot.id_task;
1676
- slot_data["state"] = slot.state;
1677
- slot_data["prompt"] = slot.prompt;
1678
- slot_data["next_token"] = {
1574
+ slot_data["id"] = slot.id;
1575
+ slot_data["id_task"] = slot.id_task;
1576
+ slot_data["is_processing"] = slot.is_processing();
1577
+ slot_data["prompt"] = common_detokenize(ctx, slot.prompt_tokens);
1578
+ slot_data["next_token"] = {
1679
1579
  {"has_next_token", slot.has_next_token},
1580
+ {"has_new_line", slot.has_new_line},
1680
1581
  {"n_remain", slot.n_remaining},
1681
1582
  {"n_decoded", slot.n_decoded},
1682
1583
  {"stopped_eos", slot.stopped_eos},
@@ -1685,10 +1586,10 @@ struct server_context {
1685
1586
  {"stopping_word", slot.stopping_word},
1686
1587
  };
1687
1588
 
1688
- if (slot_data["state"] == SLOT_STATE_IDLE) {
1689
- n_idle_slots++;
1690
- } else {
1589
+ if (slot.is_processing()) {
1691
1590
  n_processing_slots++;
1591
+ } else {
1592
+ n_idle_slots++;
1692
1593
  }
1693
1594
 
1694
1595
  slots_data.push_back(slot_data);
@@ -1750,7 +1651,7 @@ struct server_context {
  std::string filename = task.data.at("filename");
  std::string filepath = task.data.at("filepath");
 
- const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);
+ const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), token_count);
 
  const int64_t t_end = ggml_time_us();
  const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -1792,7 +1693,7 @@ struct server_context {
1792
1693
 
1793
1694
  slot->cache_tokens.resize(slot->n_ctx);
1794
1695
  size_t token_count = 0;
1795
- size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
1696
+ size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
1796
1697
  if (nread == 0) {
1797
1698
  slot->cache_tokens.resize(0);
1798
1699
  send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
@@ -1835,7 +1736,7 @@ struct server_context {
 
  // Erase token cache
  const size_t n_erased = slot->cache_tokens.size();
- llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+ llama_kv_cache_seq_rm(ctx, slot->id, -1, -1);
  slot->cache_tokens.clear();
 
  server_task_result result;
@@ -1850,7 +1751,7 @@ struct server_context {
1850
1751
  } break;
1851
1752
  case SERVER_TASK_TYPE_SET_LORA:
1852
1753
  {
1853
- llama_lora_adapters_apply(ctx, loras);
1754
+ common_lora_adapters_apply(ctx, loras);
1854
1755
  server_task_result result;
1855
1756
  result.id = task.id;
1856
1757
  result.stop = true;
@@ -1862,10 +1763,6 @@ struct server_context {
1862
1763
  }
1863
1764
 
1864
1765
  void update_slots() {
1865
- if (system_need_update) {
1866
- system_prompt_update();
1867
- }
1868
-
1869
1766
  // check if all slots are idle
1870
1767
  {
1871
1768
  bool all_idle = true;
@@ -1879,7 +1776,7 @@ struct server_context {
1879
1776
 
1880
1777
  if (all_idle) {
1881
1778
  SRV_INF("%s", "all slots are idle\n");
1882
- if (system_prompt.empty() && clean_kv_cache) {
1779
+ if (clean_kv_cache) {
1883
1780
  kv_cache_clear();
1884
1781
  }
1885
1782
 
@@ -1900,43 +1797,41 @@ struct server_context {
1900
1797
  // apply context-shift if needed
1901
1798
  // TODO: simplify and improve
1902
1799
  for (server_slot & slot : slots) {
1903
- if (slot.ga_n == 1) {
1904
- if (slot.is_processing() && (int) system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
1905
- if (!params.ctx_shift) {
1906
- // this check is redundant (for good)
1907
- // we should never get here, because generation should already stopped in process_token()
1908
- slot.release();
1909
- send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
1910
- continue;
1911
- }
1912
-
1913
- // Shift context
1914
- const int n_keep = slot.params.n_keep + add_bos_token;
1915
- const int n_left = (int) system_tokens.size() + slot.n_past - n_keep;
1916
- const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
1800
+ if (slot.is_processing() && slot.n_past + 1 >= slot.n_ctx) {
1801
+ if (!params.ctx_shift) {
1802
+ // this check is redundant (for good)
1803
+ // we should never get here, because generation should already stopped in process_token()
1804
+ slot.release();
1805
+ send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER);
1806
+ continue;
1807
+ }
1917
1808
 
1918
- SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
1809
+ // Shift context
1810
+ const int n_keep = slot.params.n_keep + add_bos_token;
1811
+ const int n_left = slot.n_past - n_keep;
1812
+ const int n_discard = slot.params.n_discard ? slot.params.n_discard : (n_left / 2);
1919
1813
 
1920
- llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
1921
- llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
1814
+ SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard);
1922
1815
 
1923
- if (slot.params.cache_prompt) {
1924
- for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
1925
- slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
1926
- }
1816
+ llama_kv_cache_seq_rm (ctx, slot.id, n_keep , n_keep + n_discard);
1817
+ llama_kv_cache_seq_add(ctx, slot.id, n_keep + n_discard, slot.n_past, -n_discard);
1927
1818
 
1928
- slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
1819
+ if (slot.params.cache_prompt) {
1820
+ for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
1821
+ slot.cache_tokens[i - n_discard] = slot.cache_tokens[i];
1929
1822
  }
1930
1823
 
1931
- slot.n_past -= n_discard;
1932
-
1933
- slot.truncated = true;
1824
+ slot.cache_tokens.resize(slot.cache_tokens.size() - n_discard);
1934
1825
  }
1826
+
1827
+ slot.n_past -= n_discard;
1828
+
1829
+ slot.truncated = true;
1935
1830
  }
1936
1831
  }
1937
1832
 
1938
1833
  // start populating the batch for this iteration
1939
- llama_batch_clear(batch);
1834
+ common_batch_clear(batch);
1940
1835
 
1941
1836
  // frist, add sampled tokens from any ongoing sequences
1942
1837
  for (auto & slot : slots) {
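With the system prompt and self-extend paths gone, the context shift above reduces to plain bookkeeping on n_past: keep the first n_keep tokens, drop n_discard = n_left / 2 of the remainder (when no explicit n_discard is set), and slide the tail down. A self-contained sketch of that arithmetic on a plain token vector (the llama_kv_cache_seq_* calls are elided and the numbers are illustrative):

#include <cstdio>
#include <vector>

int main() {
    std::vector<int> cache(32);
    for (int i = 0; i < (int) cache.size(); i++) cache[i] = i;

    int n_past          = (int) cache.size();
    const int n_keep    = 4;
    const int n_left    = n_past - n_keep;
    const int n_discard = n_left / 2;

    // shift the tail left over the discarded window, mirroring the cache_tokens update
    for (size_t i = n_keep + n_discard; i < cache.size(); i++) {
        cache[i - n_discard] = cache[i];
    }
    cache.resize(cache.size() - n_discard);
    n_past -= n_discard;

    printf("n_keep = %d, n_discard = %d, new n_past = %d, cache size = %zu\n",
           n_keep, n_discard, n_past, cache.size());
    return 0;
}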
@@ -1946,11 +1841,7 @@ struct server_context {
1946
1841
 
1947
1842
  slot.i_batch = batch.n_tokens;
1948
1843
 
1949
- const int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
1950
-
1951
- // TODO: we always have to take into account the "system_tokens"
1952
- // this is not great and needs to be improved somehow
1953
- llama_batch_add(batch, slot.sampled, system_tokens.size() + slot_npast, { slot.id + 1 }, true);
1844
+ common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
1954
1845
 
1955
1846
  slot.n_past += 1;
1956
1847
 
@@ -1958,8 +1849,8 @@ struct server_context {
1958
1849
  slot.cache_tokens.push_back(slot.sampled);
1959
1850
  }
1960
1851
 
1961
- SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_system_tokens = %d, n_cache_tokens = %d, truncated = %d\n",
1962
- slot.n_ctx, slot.n_past, (int) system_tokens.size(), (int) slot.cache_tokens.size(), slot.truncated);
1852
+ SLT_DBG(slot, "slot decode token, n_ctx = %d, n_past = %d, n_cache_tokens = %d, truncated = %d\n",
1853
+ slot.n_ctx, slot.n_past, (int) slot.cache_tokens.size(), slot.truncated);
1963
1854
  }
1964
1855
 
1965
1856
  // process in chunks of params.n_batch
@@ -1976,80 +1867,33 @@ struct server_context {
1976
1867
  if (params.cont_batching || batch.n_tokens == 0) {
1977
1868
  for (auto & slot : slots) {
1978
1869
  // this slot still has a prompt to be processed
1979
- if (slot.state == SLOT_STATE_PROCESSING_PROMPT) {
1870
+ if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
1980
1871
  auto & prompt_tokens = slot.prompt_tokens;
1981
1872
 
1982
- // we haven't tokenized the prompt yet - do it now:
1983
- if (prompt_tokens.empty()) {
1984
- SLT_INF(slot, "tokenizing prompt, len = %d\n", (int) slot.prompt.size());
1985
-
1873
+ // TODO: maybe move branch to outside of this loop in the future
1874
+ if (slot.state == SLOT_STATE_STARTED) {
1986
1875
  slot.t_start_process_prompt = ggml_time_us();
1987
1876
  slot.t_start_generation = 0;
1988
1877
 
1989
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
1990
- const bool add_bos = llama_add_bos_token(model);
1991
- bool suff_rm_leading_spc = true;
1992
- if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
1993
- params.input_suffix.erase(0, 1);
1994
- suff_rm_leading_spc = false;
1995
- }
1996
-
1997
- auto prefix_tokens = tokenize(slot.params.input_prefix, false);
1998
- auto suffix_tokens = tokenize(slot.params.input_suffix, false);
1999
-
2000
- const int space_token = 29871; // TODO: this should not be hardcoded
2001
- if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
2002
- suffix_tokens.erase(suffix_tokens.begin());
2003
- }
2004
-
2005
- prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
2006
- suffix_tokens.insert(suffix_tokens.begin(), llama_token_suffix(model));
2007
-
2008
- auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
2009
- auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
2010
- if (add_bos) {
2011
- embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
2012
- }
2013
- embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
2014
-
2015
- const llama_token middle_token = llama_token_middle(model);
2016
- if (middle_token >= 0) {
2017
- embd_inp.push_back(middle_token);
2018
- }
1878
+ slot.n_past = 0;
1879
+ slot.n_prompt_tokens = prompt_tokens.size();
1880
+ slot.state = SLOT_STATE_PROCESSING_PROMPT;
2019
1881
 
2020
- prompt_tokens = embd_inp;
2021
- } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2022
- // require slot.prompt to be array of 2 strings
2023
- if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
2024
- SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
2025
- slot.release();
2026
- send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
2027
- continue;
2028
- }
1882
+ SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
2029
1883
 
2030
- // prompt: [BOS]query[EOS][SEP]doc[EOS]
2031
- prompt_tokens.clear();
2032
- prompt_tokens.push_back(llama_token_bos(model));
2033
- {
2034
- const auto part = tokenize(slot.prompt[0], false);
2035
- prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
2036
- }
2037
- prompt_tokens.push_back(llama_token_eos(model));
2038
- prompt_tokens.push_back(llama_token_sep(model));
2039
- {
2040
- const auto part = tokenize(slot.prompt[1], false);
2041
- prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
1884
+ // print prompt tokens (for debugging)
1885
+ if (1) {
1886
+ // first 16 tokens (avoid flooding logs)
1887
+ for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
1888
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
2042
1889
  }
2043
- prompt_tokens.push_back(llama_token_eos(model));
2044
1890
  } else {
2045
- prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
1891
+ // all
1892
+ for (int i = 0; i < (int) prompt_tokens.size(); i++) {
1893
+ SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
1894
+ }
2046
1895
  }
2047
1896
 
2048
- slot.n_past = 0;
2049
- slot.n_prompt_tokens = prompt_tokens.size();
2050
-
2051
- SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);
2052
-
2053
1897
  // empty prompt passed -> release the slot and send empty response
2054
1898
  if (prompt_tokens.empty()) {
2055
1899
  SLT_WRN(slot, "%s", "empty prompt - releasing slot\n");
@@ -2060,17 +1904,24 @@ struct server_context {
2060
1904
  continue;
2061
1905
  }
2062
1906
 
2063
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2064
- // this prompt is too large to process - discard it
1907
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
2065
1908
  if (slot.n_prompt_tokens > n_ubatch) {
2066
1909
  slot.release();
2067
1910
  send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER);
2068
1911
  continue;
2069
1912
  }
1913
+
1914
+ if (slot.n_prompt_tokens > slot.n_ctx) {
1915
+ slot.release();
1916
+ send_error(slot, "input is larger than the max context size. skipping", ERROR_TYPE_SERVER);
1917
+ continue;
1918
+ }
2070
1919
  } else {
2071
1920
  if (!params.ctx_shift) {
2072
1921
  // if context shift is disabled, we make sure prompt size is smaller than KV size
2073
- if ((int) system_tokens.size() + slot.n_prompt_tokens >= slot.n_ctx) {
1922
+ // TODO: there should be a separate parameter that control prompt truncation
1923
+ // context shift should be applied only during the generation phase
1924
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
2074
1925
  slot.release();
2075
1926
  send_error(slot, "the request exceeds the available context size. try increasing the context size or enable context shift", ERROR_TYPE_INVALID_REQUEST);
2076
1927
  continue;
@@ -2081,14 +1932,14 @@ struct server_context {
2081
1932
  }
2082
1933
  slot.params.n_keep = std::min(slot.n_ctx - 4, slot.params.n_keep);
2083
1934
 
2084
- // if input prompt is too big, truncate it (if group attention self-extend is disabled)
2085
- if (slot.ga_n == 1 && slot.n_prompt_tokens >= slot.n_ctx) {
1935
+ // if input prompt is too big, truncate it
1936
+ if (slot.n_prompt_tokens >= slot.n_ctx) {
2086
1937
  const int n_left = slot.n_ctx - slot.params.n_keep;
2087
1938
 
2088
1939
  const int n_block_size = n_left / 2;
2089
1940
  const int erased_blocks = (slot.n_prompt_tokens - slot.params.n_keep - n_block_size) / n_block_size;
2090
1941
 
2091
- std::vector<llama_token> new_tokens(
1942
+ llama_tokens new_tokens(
2092
1943
  prompt_tokens.begin(),
2093
1944
  prompt_tokens.begin() + slot.params.n_keep);
2094
1945
 
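When the prompt itself no longer fits in the slot context, the code above keeps the first n_keep tokens and then only the newest blocks of the remainder. A standalone sketch of that truncation rule, with erased_blocks chosen exactly as in the diff (values are illustrative):

#include <cstdio>
#include <vector>

int main() {
    const int n_ctx  = 16;
    const int n_keep = 4;

    std::vector<int> prompt(40);
    for (int i = 0; i < (int) prompt.size(); i++) prompt[i] = i;

    const int n_left        = n_ctx - n_keep;
    const int n_block_size  = n_left / 2;
    const int erased_blocks = ((int) prompt.size() - n_keep - n_block_size) / n_block_size;

    // keep the n_keep prefix, then everything after the erased blocks
    std::vector<int> truncated(prompt.begin(), prompt.begin() + n_keep);
    truncated.insert(truncated.end(),
                     prompt.begin() + n_keep + erased_blocks * n_block_size,
                     prompt.end());

    printf("prompt %zu -> %zu tokens (n_ctx = %d)\n", prompt.size(), truncated.size(), n_ctx);
    return 0;
}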
@@ -2107,20 +1958,52 @@ struct server_context {
  GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
  }
 
- gpt_sampler_reset(slot.smpl);
+ if (slot.params.cache_prompt) {
+ // reuse any previously computed tokens that are common with the new prompt
+ slot.n_past = longest_common_prefix(slot.cache_tokens, prompt_tokens);
+
+ // reuse chunks from the cached prompt by shifting their KV cache in the new position
+ if (params.n_cache_reuse > 0) {
+ size_t head_c = slot.n_past; // cache
+ size_t head_p = slot.n_past; // current prompt
 
- if (!slot.params.cache_prompt) {
- slot.n_past_se = 0;
- slot.ga_i = 0;
- } else {
- GGML_ASSERT(slot.ga_n == 1);
+ SLT_DBG(slot, "trying to reuse chunks with size > %d, slot.n_past = %d\n", params.n_cache_reuse, slot.n_past);
 
- // reuse any previously computed tokens that are common with the new prompt
- slot.n_past = common_part(slot.cache_tokens, prompt_tokens);
+ while (head_c < slot.cache_tokens.size() &&
+ head_p < prompt_tokens.size()) {
 
- // push the prompt into the sampling context (do not apply grammar)
- for (int i = 0; i < slot.n_past; ++i) {
- gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
+ size_t n_match = 0;
+ while (head_c + n_match < slot.cache_tokens.size() &&
+ head_p + n_match < prompt_tokens.size() &&
+ slot.cache_tokens[head_c + n_match] == prompt_tokens[head_p + n_match]) {
+
+ n_match++;
+ }
+
+ if (n_match >= (size_t) params.n_cache_reuse) {
+ SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match);
+ //for (size_t i = head_p; i < head_p + n_match; i++) {
+ // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
+ //}
+
+ const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c;
+
+ llama_kv_cache_seq_rm (ctx, slot.id, head_p, head_c);
+ llama_kv_cache_seq_add(ctx, slot.id, head_c, -1, kv_shift);
+
+ for (size_t i = 0; i < n_match; i++) {
+ slot.cache_tokens[head_p + i] = slot.cache_tokens[head_c + i];
+ slot.n_past++;
+ }
+
+ head_c += n_match;
+ head_p += n_match;
+ } else {
+ head_c += 1;
+ }
+ }
+
+ SLT_DBG(slot, "after context reuse, new slot.n_past = %d\n", slot.n_past);
  }
  }
  }
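The new cache-reuse path first takes the longest common prefix of the cached and incoming token streams, then (when n_cache_reuse > 0) scans for later chunks of at least that many identical tokens that can be kept by shifting their KV-cache positions instead of re-evaluating them. A simplified sketch of the matching loop over two token vectors, counting how many tokens would be reused (the actual KV-cache shifts are omitted):

#include <cstdio>
#include <vector>

int main() {
    const std::vector<int> cached = {1, 2, 3, 7, 8, 9, 4, 5, 6};
    const std::vector<int> prompt = {1, 2, 3, 4, 5, 6};
    const size_t n_cache_reuse = 2; // minimum chunk size worth shifting

    // longest common prefix
    size_t n_past = 0;
    while (n_past < cached.size() && n_past < prompt.size() && cached[n_past] == prompt[n_past]) {
        n_past++;
    }

    // scan the rest of the cache for reusable chunks
    size_t head_c = n_past; // position in the cache
    size_t head_p = n_past; // position in the new prompt
    while (head_c < cached.size() && head_p < prompt.size()) {
        size_t n_match = 0;
        while (head_c + n_match < cached.size() && head_p + n_match < prompt.size() &&
               cached[head_c + n_match] == prompt[head_p + n_match]) {
            n_match++;
        }
        if (n_match >= n_cache_reuse) {
            // a real server would shift the KV cache by (head_p - head_c) positions here
            n_past += n_match;
            head_c += n_match;
            head_p += n_match;
        } else {
            head_c += 1;
        }
    }

    printf("tokens reused without re-evaluation: %zu of %zu\n", n_past, prompt.size());
    return 0;
}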
@@ -2130,16 +2013,13 @@ struct server_context {
2130
2013
  SLT_WRN(slot, "need to evaluate at least 1 token to generate logits, n_past = %d, n_prompt_tokens = %d\n", slot.n_past, slot.n_prompt_tokens);
2131
2014
 
2132
2015
  slot.n_past--;
2133
- if (slot.ga_i > 0) {
2134
- slot.n_past_se--;
2135
- }
2136
2016
  }
2137
2017
 
2138
2018
  slot.n_prompt_tokens_processed = 0;
2139
2019
  }
2140
2020
 
2141
2021
  // non-causal tasks require to fit the entire prompt in the physical batch
2142
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING || slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2022
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING || slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
2143
2023
  // cannot fit the prompt in the current batch - will try next iter
2144
2024
  if (batch.n_tokens + slot.n_prompt_tokens > n_batch) {
2145
2025
  continue;
@@ -2148,8 +2028,8 @@ struct server_context {
2148
2028
 
2149
2029
  // check that we are in the right batch_type, if not defer the slot
2150
2030
  const bool slot_type =
2151
- slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING ||
2152
- slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK ? 1 : 0;
2031
+ slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING ||
2032
+ slot.inf_type == SERVER_TASK_INF_TYPE_RERANK ? 1 : 0;
2153
2033
 
2154
2034
  if (batch_type == -1) {
2155
2035
  batch_type = slot_type;
@@ -2158,55 +2038,29 @@ struct server_context {
2158
2038
  }
2159
2039
 
2160
2040
  // keep only the common part
2161
- int p0 = (int) system_tokens.size() + slot.n_past;
2162
- if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
2041
+ if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
2163
2042
  // could not partially delete (likely using a non-Transformer model)
2164
- llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
2043
+ llama_kv_cache_seq_rm(ctx, slot.id, -1, -1);
2165
2044
 
2166
- p0 = (int) system_tokens.size();
2167
- if (p0 != 0) {
2168
- // copy over the system prompt when there is one
2169
- llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
2170
- }
2171
-
2172
- // there is no common part left (except for the system prompt)
2045
+ // there is no common part left
2173
2046
  slot.n_past = 0;
2174
- slot.n_past_se = 0;
2175
- slot.ga_i = 0;
2176
- // TODO: is the system prompt ever in the sampling context?
2177
- gpt_sampler_reset(slot.smpl);
2178
2047
  }
2179
2048
 
2049
+ SLT_INF(slot, "kv cache rm [%d, end)\n", slot.n_past);
2050
+
2180
2051
  // remove the non-common part from the cache
2181
2052
  slot.cache_tokens.resize(slot.n_past);
2182
2053
 
2183
- SLT_INF(slot, "kv cache rm [%d, end)\n", p0);
2184
-
2185
- int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;
2186
-
2187
- int32_t ga_i = slot.ga_i;
2188
- int32_t ga_n = slot.ga_n;
2189
- int32_t ga_w = slot.ga_w;
2190
-
2191
2054
  // add prompt tokens for processing in the current batch
2192
- // TODO: the self-extend stuff here is a mess - simplify and/or abstract it somehow
2193
- for (; slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch; ++slot.n_past) {
2194
- if (slot.ga_n != 1) {
2195
- while (slot_npast >= ga_i + ga_w) {
2196
- const int bd = (ga_w/ga_n)*(ga_n - 1);
2197
- slot_npast -= bd;
2198
- ga_i += ga_w/ga_n;
2199
- }
2200
- }
2201
-
2202
- llama_batch_add(batch, prompt_tokens[slot.n_past], system_tokens.size() + slot_npast, { slot.id + 1 }, false);
2055
+ while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) {
2056
+ common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, false);
2203
2057
 
2204
2058
  if (slot.params.cache_prompt) {
2205
2059
  slot.cache_tokens.push_back(prompt_tokens[slot.n_past]);
2206
2060
  }
2207
2061
 
2208
2062
  slot.n_prompt_tokens_processed++;
2209
- slot_npast++;
2063
+ slot.n_past++;
2210
2064
  }
2211
2065
 
2212
2066
  SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens);
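Prompt tokens are now appended directly from slot.prompt_tokens until either the prompt is exhausted or the batch reaches n_batch, with the remainder picked up on the next update_slots pass. A minimal sketch of that chunking pattern, independent of the server's batch structures (sizes are illustrative):

#include <cstdio>
#include <vector>

int main() {
    const int n_batch = 8;
    std::vector<int> prompt(19);
    for (int i = 0; i < (int) prompt.size(); i++) prompt[i] = 100 + i;

    int n_past = 0;
    int n_iter = 0;
    while (n_past < (int) prompt.size()) {
        std::vector<int> batch;
        while (n_past < (int) prompt.size() && (int) batch.size() < n_batch) {
            batch.push_back(prompt[n_past]); // common_batch_add(...) in the server
            n_past++;
        }
        n_iter++;
        printf("iteration %d: decoded %zu tokens, n_past = %d\n", n_iter, batch.size(), n_past);
    }
    return 0;
}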
@@ -2217,6 +2071,13 @@ struct server_context {
 
  GGML_ASSERT(batch.n_tokens > 0);
 
+ common_sampler_reset(slot.smpl);
+
+ // Process all prompt tokens through sampler system
+ for (int i = 0; i < slot.n_prompt_tokens; ++i) {
+ common_sampler_accept(slot.smpl, prompt_tokens[i], false);
+ }
+
  // extract the logits only for the last token
  batch.logits[batch.n_tokens - 1] = true;
 
@@ -2247,34 +2108,6 @@ struct server_context {
2247
2108
  for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
2248
2109
  const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
2249
2110
 
2250
- for (auto & slot : slots) {
2251
- if (slot.ga_n != 1) {
2252
- // context extension via Self-Extend
2253
- // TODO: simplify and/or abstract this
2254
- while (slot.n_past_se >= slot.ga_i + slot.ga_w) {
2255
- const int ib = (slot.ga_n * slot.ga_i) / slot.ga_w;
2256
- const int bd = (slot.ga_w / slot.ga_n) * (slot.ga_n - 1);
2257
- const int dd = (slot.ga_w / slot.ga_n) - ib * bd - slot.ga_w;
2258
-
2259
- SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i, slot.n_past_se, ib * bd, slot.ga_i + ib * bd, slot.n_past_se + ib * bd);
2260
- SLT_DBG(slot, "div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
2261
- SLT_DBG(slot, "shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
2262
-
2263
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
2264
- llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
2265
- llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
2266
-
2267
- slot.n_past_se -= bd;
2268
-
2269
- slot.ga_i += slot.ga_w / slot.ga_n;
2270
-
2271
- SLT_DBG(slot, "\nn_past_old = %d, n_past = %d, ga_i = %d\n\n", slot.n_past_se + bd, slot.n_past_se, slot.ga_i);
2272
- }
2273
-
2274
- slot.n_past_se += n_tokens;
2275
- }
2276
- }
2277
-
2278
2111
  llama_batch batch_view = {
2279
2112
  n_tokens,
2280
2113
  batch.token + i,
@@ -2283,7 +2116,6 @@ struct server_context {
2283
2116
  batch.n_seq_id + i,
2284
2117
  batch.seq_id + i,
2285
2118
  batch.logits + i,
2286
- 0, 0, 0, // unused
2287
2119
  };
2288
2120
 
2289
2121
  const int ret = llama_decode(ctx, batch_view);
@@ -2315,7 +2147,7 @@ struct server_context {
2315
2147
  }
2316
2148
 
2317
2149
  if (slot.state == SLOT_STATE_DONE_PROMPT) {
2318
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_EMBEDDING) {
2150
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_EMBEDDING) {
2319
2151
  // prompt evaluated for embedding
2320
2152
  send_embedding(slot, batch_view);
2321
2153
  slot.release();
@@ -2323,7 +2155,7 @@ struct server_context {
2323
2155
  continue; // continue loop of slots
2324
2156
  }
2325
2157
 
2326
- if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
2158
+ if (slot.inf_type == SERVER_TASK_INF_TYPE_RERANK) {
2327
2159
  send_rerank(slot, batch_view);
2328
2160
  slot.release();
2329
2161
  slot.i_batch = -1;
@@ -2337,9 +2169,9 @@ struct server_context {
2337
2169
  }
2338
2170
 
2339
2171
  completion_token_output result;
2340
- const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
2172
+ const llama_token id = common_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
2341
2173
 
2342
- gpt_sampler_accept(slot.smpl, id, true);
2174
+ common_sampler_accept(slot.smpl, id, true);
2343
2175
 
2344
2176
  slot.n_decoded += 1;
2345
2177
  if (slot.n_decoded == 1) {
@@ -2350,7 +2182,7 @@ struct server_context {
2350
2182
 
2351
2183
  result.tok = id;
2352
2184
 
2353
- const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
2185
+ const auto * cur_p = common_sampler_get_candidates(slot.smpl);
2354
2186
 
2355
2187
  for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
2356
2188
  result.probs.push_back({
@@ -2414,13 +2246,13 @@ inline void signal_handler(int signal) {
2414
2246
 
2415
2247
  int main(int argc, char ** argv) {
2416
2248
  // own arguments required by this example
2417
- gpt_params params;
2249
+ common_params params;
2418
2250
 
2419
- if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
2251
+ if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
2420
2252
  return 1;
2421
2253
  }
2422
2254
 
2423
- gpt_init();
2255
+ common_init();
2424
2256
 
2425
2257
  // enabling this will output extra debug information in the HTTP responses from the server
2426
2258
  // see format_final_response_oaicompat()
@@ -2429,10 +2261,6 @@ int main(int argc, char ** argv) {
2429
2261
  // struct that contains llama context and inference
2430
2262
  server_context ctx_server;
2431
2263
 
2432
- if (!params.system_prompt.empty()) {
2433
- ctx_server.system_prompt_set(params.system_prompt);
2434
- }
2435
-
2436
2264
  if (params.model_alias == "unknown") {
2437
2265
  params.model_alias = params.model;
2438
2266
  }
@@ -2442,9 +2270,19 @@ int main(int argc, char ** argv) {
2442
2270
 
2443
2271
  LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
2444
2272
  LOG_INF("\n");
2445
- LOG_INF("%s\n", gpt_params_get_system_info(params).c_str());
2273
+ LOG_INF("%s\n", common_params_get_system_info(params).c_str());
2446
2274
  LOG_INF("\n");
2447
2275
 
2276
+ // static files
2277
+ std::map<std::string, server_static_file> static_files = {
2278
+ { "/", { index_html, index_html_len, "text/html; charset=utf-8" }},
2279
+ { "/completion.js", { completion_js, completion_js_len, "text/javascript; charset=utf-8" }},
2280
+ { "/deps_daisyui.min.css", { deps_daisyui_min_css, deps_daisyui_min_css_len, "text/css; charset=utf-8" }},
2281
+ { "/deps_markdown-it.js", { deps_markdown_it_js, deps_markdown_it_js_len, "text/javascript; charset=utf-8" }},
2282
+ { "/deps_tailwindcss.js", { deps_tailwindcss_js, deps_tailwindcss_js_len, "text/javascript; charset=utf-8" }},
2283
+ { "/deps_vue.esm-browser.js", { deps_vue_esm_browser_js, deps_vue_esm_browser_js_len, "text/javascript; charset=utf-8" }},
2284
+ };
2285
+
2448
2286
  std::unique_ptr<httplib::Server> svr;
2449
2287
  #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
2450
2288
  if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
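The embedded web assets are now collected in a static_files map that both the router and the API-key middleware consult. The server_static_file type itself is not shown in this diff; judging from how it is used further down (data, size, mime_type), it is presumably a small aggregate along these lines — a sketch with assumed field names and an illustrative payload standing in for the generated index_html/index_html_len symbols:

#include <cstddef>
#include <cstdio>
#include <map>
#include <string>

// assumed shape of server_static_file, inferred from its usage in the router
struct server_static_file {
    const unsigned char * data;
    size_t                size;
    const char *          mime_type;
};

// illustrative payload; the real server embeds the generated web UI here
static const char kIndexHtml[] = "<html><body>hello</body></html>";

static const std::map<std::string, server_static_file> static_files = {
    { "/", { reinterpret_cast<const unsigned char *>(kIndexHtml), sizeof(kIndexHtml) - 1, "text/html; charset=utf-8" } },
};

int main() {
    for (const auto & it : static_files) {
        printf("%s -> %zu bytes (%s)\n", it.first.c_str(), it.second.size, it.second.mime_type);
    }
    return 0;
}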
@@ -2467,16 +2305,6 @@ int main(int argc, char ** argv) {
2467
2305
  std::atomic<server_state> state{SERVER_STATE_LOADING_MODEL};
2468
2306
 
2469
2307
  svr->set_default_headers({{"Server", "llama.cpp"}});
2470
-
2471
- // CORS preflight
2472
- svr->Options(R"(.*)", [](const httplib::Request &, httplib::Response & res) {
2473
- // Access-Control-Allow-Origin is already set by middleware
2474
- res.set_header("Access-Control-Allow-Credentials", "true");
2475
- res.set_header("Access-Control-Allow-Methods", "POST");
2476
- res.set_header("Access-Control-Allow-Headers", "*");
2477
- return res.set_content("", "text/html"); // blank response, no data
2478
- });
2479
-
2480
2308
  svr->set_logger(log_server_request);
2481
2309
 
2482
2310
  auto res_error = [](httplib::Response & res, const json & error_data) {
@@ -2535,21 +2363,11 @@ int main(int argc, char ** argv) {
2535
2363
  // Middlewares
2536
2364
  //
2537
2365
 
2538
- auto middleware_validate_api_key = [&params, &res_error](const httplib::Request & req, httplib::Response & res) {
2539
- // TODO: should we apply API key to all endpoints, including "/health" and "/models"?
2540
- static const std::unordered_set<std::string> protected_endpoints = {
2541
- "/props",
2542
- "/completion",
2543
- "/completions",
2544
- "/v1/completions",
2545
- "/chat/completions",
2546
- "/v1/chat/completions",
2547
- "/infill",
2548
- "/tokenize",
2549
- "/detokenize",
2550
- "/embedding",
2551
- "/embeddings",
2552
- "/v1/embeddings",
2366
+ auto middleware_validate_api_key = [&params, &res_error, &static_files](const httplib::Request & req, httplib::Response & res) {
2367
+ static const std::unordered_set<std::string> public_endpoints = {
2368
+ "/health",
2369
+ "/models",
2370
+ "/v1/models",
2553
2371
  };
2554
2372
 
2555
2373
  // If API key is not set, skip validation
@@ -2557,8 +2375,8 @@ int main(int argc, char ** argv) {
2557
2375
  return true;
2558
2376
  }
2559
2377
 
2560
- // If path is not in protected_endpoints list, skip validation
2561
- if (protected_endpoints.find(req.path) == protected_endpoints.end()) {
2378
+ // If path is public or is static file, skip validation
2379
+ if (public_endpoints.find(req.path) != public_endpoints.end() || static_files.find(req.path) != static_files.end()) {
2562
2380
  return true;
2563
2381
  }
2564
2382
 
@@ -2584,7 +2402,7 @@ int main(int argc, char ** argv) {
2584
2402
  auto middleware_server_state = [&res_error, &state](const httplib::Request & req, httplib::Response & res) {
2585
2403
  server_state current_state = state.load();
2586
2404
  if (current_state == SERVER_STATE_LOADING_MODEL) {
2587
- auto tmp = string_split(req.path, '.');
2405
+ auto tmp = string_split<std::string>(req.path, '.');
2588
2406
  if (req.path == "/" || tmp.back() == "html") {
2589
2407
  res.set_content(reinterpret_cast<const char*>(loading_html), loading_html_len, "text/html; charset=utf-8");
2590
2408
  res.status = 503;
@@ -2599,6 +2417,14 @@ int main(int argc, char ** argv) {
  // register server middlewares
  svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) {
  res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+ // If this is OPTIONS request, skip validation because browsers don't include Authorization header
+ if (req.method == "OPTIONS") {
+ res.set_header("Access-Control-Allow-Credentials", "true");
+ res.set_header("Access-Control-Allow-Methods", "GET, POST");
+ res.set_header("Access-Control-Allow-Headers", "*");
+ res.set_content("", "text/html"); // blank response, no data
+ return httplib::Server::HandlerResponse::Handled; // skip further processing
+ }
  if (!middleware_server_state(req, res)) {
  return httplib::Server::HandlerResponse::Handled;
  }
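The CORS preflight handling moves from a dedicated Options route into the pre-routing handler, so OPTIONS requests are answered before the API-key check runs (browsers do not attach Authorization headers to preflights). A cut-down sketch of just that handler, assuming the same cpp-httplib server the example links against; the header values mirror the diff:

#include "httplib.h" // cpp-httplib, as used by the server example

int main() {
    httplib::Server svr;

    svr.set_pre_routing_handler([](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        if (req.method == "OPTIONS") {
            // answer the CORS preflight here; no auth header is present on these requests
            res.set_header("Access-Control-Allow-Credentials", "true");
            res.set_header("Access-Control-Allow-Methods", "GET, POST");
            res.set_header("Access-Control-Allow-Headers", "*");
            res.set_content("", "text/html");
            return httplib::Server::HandlerResponse::Handled;
        }
        return httplib::Server::HandlerResponse::Unhandled; // fall through to normal routing
    });

    svr.Get("/health", [](const httplib::Request &, httplib::Response & res) {
        res.set_content("{\"status\":\"ok\"}", "application/json");
    });

    svr.listen("127.0.0.1", 8080);
    return 0;
}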
@@ -2620,7 +2446,7 @@ int main(int argc, char ** argv) {
2620
2446
 
2621
2447
  const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) {
2622
2448
  if (!params.endpoint_slots) {
2623
- res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED));
2449
+ res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED));
2624
2450
  return;
2625
2451
  }
2626
2452
 
@@ -2869,31 +2695,35 @@ int main(int argc, char ** argv) {
2869
2695
  };
2870
2696
 
2871
2697
  const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
2872
- std::string template_key = "tokenizer.chat_template", curr_tmpl;
2873
- int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0);
2874
- if (tlen > 0) {
2875
- std::vector<char> curr_tmpl_buf(tlen + 1, 0);
2876
- if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) {
2877
- curr_tmpl = std::string(curr_tmpl_buf.data(), tlen);
2878
- }
2879
- }
2880
2698
  json data = {
2881
- { "system_prompt", ctx_server.system_prompt.c_str() },
2882
2699
  { "default_generation_settings", ctx_server.default_generation_settings_for_props },
2883
2700
  { "total_slots", ctx_server.params.n_parallel },
2884
- { "chat_template", curr_tmpl.c_str() },
2701
+ { "chat_template", llama_get_chat_template(ctx_server.model) },
2885
2702
  };
2886
2703
 
2887
2704
  res_ok(res, data);
2888
2705
  };
2889
2706
 
2890
- const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) {
2891
- if (ctx_server.params.embedding || ctx_server.params.reranking) {
2892
- res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
2707
+ const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
2708
+ if (!ctx_server.params.endpoint_props) {
2709
+ res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED));
2893
2710
  return;
2894
2711
  }
2895
2712
 
2896
- std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
2713
+ json data = json::parse(req.body);
2714
+
2715
+ // update any props here
2716
+
2717
+ res_ok(res, {{ "success", true }});
2718
+ };
2719
+
2720
+ const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_inf_type inf_type, json & data, httplib::Response & res) {
2721
+ if (ctx_server.params.embedding) {
2722
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
2723
+ return;
2724
+ }
2725
+
2726
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, inf_type);
2897
2727
  ctx_server.queue_results.add_waiting_tasks(tasks);
2898
2728
  ctx_server.queue_tasks.post(tasks);
2899
2729
 
@@ -2939,24 +2769,69 @@ int main(int argc, char ** argv) {
2939
2769
 
2940
2770
  const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2941
2771
  json data = json::parse(req.body);
2942
- return handle_completions_generic(SERVER_TASK_CMPL_TYPE_NORMAL, data, res);
2772
+ return handle_completions_generic(SERVER_TASK_INF_TYPE_COMPLETION, data, res);
2943
2773
  };
2944
2774
 
2945
- const auto handle_infill = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2775
+ const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
2776
+ // check model compatibility
2777
+ std::string err;
2778
+ if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
2779
+ err += "prefix token is missing. ";
2780
+ }
2781
+ if (llama_token_fim_suf(ctx_server.model) == LLAMA_TOKEN_NULL) {
2782
+ err += "suffix token is missing. ";
2783
+ }
2784
+ if (llama_token_fim_mid(ctx_server.model) == LLAMA_TOKEN_NULL) {
2785
+ err += "middle token is missing. ";
2786
+ }
2787
+ if (!err.empty()) {
2788
+ res_error(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED));
2789
+ return;
2790
+ }
2791
+
2946
2792
  json data = json::parse(req.body);
2947
- return handle_completions_generic(SERVER_TASK_CMPL_TYPE_INFILL, data, res);
2793
+
2794
+ // validate input
2795
+ if (!data.contains("input_prefix")) {
2796
+ res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
2797
+ }
2798
+
2799
+ if (!data.contains("input_suffix")) {
2800
+ res_error(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST));
2801
+ }
2802
+
2803
+ if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
2804
+ res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
2805
+ return;
2806
+ }
2807
+ json input_extra = json_value(data, "input_extra", json::array());
2808
+ for (const auto & chunk : input_extra) {
2809
+ // { "text": string, "filename": string }
2810
+ if (!chunk.contains("text") || !chunk.at("text").is_string()) {
2811
+ res_error(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
2812
+ return;
2813
+ }
2814
+ // filename is optional
2815
+ if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
2816
+ res_error(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
2817
+ return;
2818
+ }
2819
+ }
2820
+ data["input_extra"] = input_extra; // default to empty array if it's not exist
2821
+
2822
+ return handle_completions_generic(SERVER_TASK_INF_TYPE_INFILL, data, res);
2948
2823
  };
2949
2824
 
2950
2825
  // TODO: maybe merge this function with "handle_completions_generic"
2951
2826
  const auto handle_chat_completions = [&ctx_server, &params, &res_error, &res_ok, verbose](const httplib::Request & req, httplib::Response & res) {
2952
- if (ctx_server.params.embedding || ctx_server.params.reranking) {
2953
- res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
2827
+ if (ctx_server.params.embedding) {
2828
+ res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
2954
2829
  return;
2955
2830
  }
2956
2831
 
2957
2832
  json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
2958
2833
 
2959
- std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
2834
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference(data, SERVER_TASK_INF_TYPE_COMPLETION);
2960
2835
  ctx_server.queue_results.add_waiting_tasks(tasks);
2961
2836
  ctx_server.queue_tasks.post(tasks);
2962
2837
 
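The reworked /infill handler now requires input_prefix and input_suffix and accepts an optional input_extra array of {filename, text} chunks, validated before a task is created. A sketch of a request body that satisfies those checks, built with the same nlohmann::json library the server uses (the include path may differ from the server's vendored json.hpp; this is not an exhaustive list of accepted fields):

#include <iostream>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    // one optional extra-context chunk: { "filename": string, "text": string }
    json chunk = {
        { "filename", "math.h" },
        { "text", "int add(int a, int b);\n" },
    };

    json req = {
        { "input_prefix", "int add(int a, int b) {\n    return " },
        { "input_suffix", ";\n}\n" },
        { "input_extra", json::array({ chunk }) },
        { "n_predict", 32 }, // regular completion parameters still apply
    };

    // this JSON would be POSTed to /infill
    std::cout << req.dump(2) << std::endl;
    return 0;
}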
@@ -3028,11 +2903,12 @@ int main(int argc, char ** argv) {
3028
2903
  if (body.count("content") != 0) {
3029
2904
  const bool add_special = json_value(body, "add_special", false);
3030
2905
  const bool with_pieces = json_value(body, "with_pieces", false);
3031
- std::vector<llama_token> tokens = ctx_server.tokenize(body.at("content"), add_special);
2906
+
2907
+ llama_tokens tokens = tokenize_mixed(ctx_server.ctx, body.at("content"), add_special, true);
3032
2908
 
3033
2909
  if (with_pieces) {
3034
2910
  for (const auto& token : tokens) {
3035
- std::string piece = llama_token_to_piece(ctx_server.ctx, token);
2911
+ std::string piece = common_token_to_piece(ctx_server.ctx, token);
3036
2912
  json piece_json;
3037
2913
 
3038
2914
  // Check if the piece is valid UTF-8
@@ -3065,7 +2941,7 @@ int main(int argc, char ** argv) {
3065
2941
 
3066
2942
  std::string content;
3067
2943
  if (body.count("tokens") != 0) {
3068
- const std::vector<llama_token> tokens = body.at("tokens");
2944
+ const llama_tokens tokens = body.at("tokens");
3069
2945
  content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
3070
2946
  }
3071
2947
 
@@ -3074,11 +2950,6 @@ int main(int argc, char ** argv) {
3074
2950
  };
3075
2951
 
3076
2952
  const auto handle_embeddings = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3077
- // TODO: somehow clean up this checks in the future
3078
- if (!ctx_server.params.embedding || ctx_server.params.reranking) {
3079
- res_error(res, format_error_response("This server does not support embeddings. Start it with `--embeddings` and without `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
3080
- return;
3081
- }
3082
2953
  const json body = json::parse(req.body);
3083
2954
  bool is_openai = false;
3084
2955
 
@@ -3099,7 +2970,7 @@ int main(int argc, char ** argv) {
3099
2970
  json responses = json::array();
3100
2971
  bool error = false;
3101
2972
  {
3102
- std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
2973
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_EMBEDDING);
3103
2974
  ctx_server.queue_results.add_waiting_tasks(tasks);
3104
2975
  ctx_server.queue_tasks.post(tasks);
3105
2976
 
@@ -3130,10 +3001,11 @@ int main(int argc, char ** argv) {
3130
3001
  };
3131
3002
 
3132
3003
  const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
3133
- if (!ctx_server.params.reranking) {
3134
- res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
3004
+ if (!ctx_server.params.reranking || ctx_server.params.embedding) {
3005
+ res_error(res, format_error_response("This server does not support reranking. Start it with `--reranking` and without `--embedding`", ERROR_TYPE_NOT_SUPPORTED));
3135
3006
  return;
3136
3007
  }
3008
+
3137
3009
  const json body = json::parse(req.body);
3138
3010
 
3139
3011
  // TODO: implement
@@ -3176,7 +3048,7 @@ int main(int argc, char ** argv) {
3176
3048
  json responses = json::array();
3177
3049
  bool error = false;
3178
3050
  {
3179
- std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
3051
+ std::vector<server_task> tasks = ctx_server.create_tasks_inference({{"prompt", prompt}}, SERVER_TASK_INF_TYPE_RERANK);
3180
3052
  ctx_server.queue_results.add_waiting_tasks(tasks);
3181
3053
  ctx_server.queue_tasks.post(tasks);
3182
3054
 
@@ -3248,13 +3120,6 @@ int main(int argc, char ** argv) {
3248
3120
  res.status = 200; // HTTP OK
3249
3121
  };
3250
3122
 
3251
- auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
3252
- return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
3253
- res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
3254
- return false;
3255
- };
3256
- };
3257
-
3258
3123
  //
3259
3124
  // Router
3260
3125
  //
@@ -3262,33 +3127,29 @@ int main(int argc, char ** argv) {
3262
3127
  // register static assets routes
3263
3128
  if (!params.public_path.empty()) {
3264
3129
  // Set the base directory for serving static files
3265
- svr->set_base_dir(params.public_path);
3266
- }
3267
-
3268
- // using embedded static files
3269
- svr->Get("/", handle_static_file(index_html, index_html_len, "text/html; charset=utf-8"));
3270
- svr->Get("/index.js", handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8"));
3271
- svr->Get("/completion.js", handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8"));
3272
- svr->Get("/json-schema-to-grammar.mjs", handle_static_file(json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
3273
-
3274
- // add new-ui files
3275
- svr->Get("/colorthemes.css", handle_static_file(colorthemes_css, colorthemes_css_len, "text/css; charset=utf-8"));
3276
- svr->Get("/style.css", handle_static_file(style_css, style_css_len, "text/css; charset=utf-8"));
3277
- svr->Get("/theme-beeninorder.css", handle_static_file(theme_beeninorder_css, theme_beeninorder_css_len, "text/css; charset=utf-8"));
3278
- svr->Get("/theme-ketivah.css", handle_static_file(theme_ketivah_css, theme_ketivah_css_len, "text/css; charset=utf-8"));
3279
- svr->Get("/theme-mangotango.css", handle_static_file(theme_mangotango_css, theme_mangotango_css_len, "text/css; charset=utf-8"));
3280
- svr->Get("/theme-playground.css", handle_static_file(theme_playground_css, theme_playground_css_len, "text/css; charset=utf-8"));
3281
- svr->Get("/theme-polarnight.css", handle_static_file(theme_polarnight_css, theme_polarnight_css_len, "text/css; charset=utf-8"));
3282
- svr->Get("/theme-snowstorm.css", handle_static_file(theme_snowstorm_css, theme_snowstorm_css_len, "text/css; charset=utf-8"));
3283
- svr->Get("/index-new.html", handle_static_file(index_new_html, index_new_html_len, "text/html; charset=utf-8"));
3284
- svr->Get("/system-prompts.js", handle_static_file(system_prompts_js, system_prompts_js_len, "text/javascript; charset=utf-8"));
3285
- svr->Get("/prompt-formats.js", handle_static_file(prompt_formats_js, prompt_formats_js_len, "text/javascript; charset=utf-8"));
3130
+ bool is_found = svr->set_mount_point("/", params.public_path);
3131
+ if (!is_found) {
3132
+ LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
3133
+ return 1;
3134
+ }
3135
+ } else {
3136
+ // using embedded static files
3137
+ for (const auto & it : static_files) {
3138
+ const server_static_file & static_file = it.second;
3139
+ svr->Get(it.first.c_str(), [&static_file](const httplib::Request &, httplib::Response & res) {
3140
+ res.set_content(reinterpret_cast<const char*>(static_file.data), static_file.size, static_file.mime_type);
3141
+ return false;
3142
+ });
3143
+ }
3144
+ }
3286
3145
 
3287
3146
  // register API routes
3288
- svr->Get ("/health", handle_health);
3147
+ svr->Get ("/health", handle_health); // public endpoint (no API key check)
3289
3148
  svr->Get ("/metrics", handle_metrics);
3290
3149
  svr->Get ("/props", handle_props);
3291
- svr->Get ("/v1/models", handle_models);
3150
+ svr->Post("/props", handle_props_change);
3151
+ svr->Get ("/models", handle_models); // public endpoint (no API key check)
3152
+ svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
3292
3153
  svr->Post("/completion", handle_completions); // legacy
3293
3154
  svr->Post("/completions", handle_completions);
3294
3155
  svr->Post("/v1/completions", handle_completions);
@@ -3366,10 +3227,11 @@ int main(int argc, char ** argv) {
3366
3227
  }
3367
3228
 
3368
3229
  // print sample chat example to make it clear which template is used
3369
- LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), llama_chat_format_example(ctx_server.model, params.chat_template).c_str());
3230
+ LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());
3370
3231
 
3371
3232
  ctx_server.queue_tasks.on_new_task(std::bind(
3372
3233
  &server_context::process_single_task, &ctx_server, std::placeholders::_1));
3234
+
3373
3235
  ctx_server.queue_tasks.on_update_slots(std::bind(
3374
3236
  &server_context::update_slots, &ctx_server));
3375
3237
 
@@ -3377,7 +3239,7 @@ int main(int argc, char ** argv) {
  ctx_server.queue_tasks.terminate();
  };
 
- LOG_INF("%s: server is listening on %s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
+ LOG_INF("%s: server is listening on http://%s:%d - starting the main loop\n", __func__, params.hostname.c_str(), params.port);
 
  ctx_server.queue_tasks.start_loop();