@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275

package/src/llama.cpp/examples/server/tests/requirements.txt
@@ -1,6 +1,7 @@
 aiohttp~=3.9.3
 behave~=1.2.6
-huggingface_hub~=0.20.3
+huggingface_hub~=0.23.2
 numpy~=1.26.4
 openai~=1.30.3
 prometheus-client~=0.20.0
+requests~=2.32.3

package/src/llama.cpp/examples/server/utils.hpp
@@ -1,16 +1,25 @@
 #pragma once
 
-#include "llama.h"
 #include "common.h"
+#include "log.h"
+#include "llama.h"
+
+#ifndef NDEBUG
+// crash the server in debug mode, otherwise send an http 500 error
+#define CPPHTTPLIB_NO_EXCEPTIONS 1
+#endif
+// increase max payload length to allow use of larger context size
+#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
+#include "httplib.h"
 
 // Change JSON_ASSERT from assert() to GGML_ASSERT:
 #define JSON_ASSERT GGML_ASSERT
 #include "json.hpp"
 
+#include <random>
+#include <sstream>
 #include <string>
 #include <vector>
-#include <sstream>
-#include <random>
 
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
@@ -27,32 +36,6 @@ enum error_type {
     ERROR_TYPE_NOT_SUPPORTED, // custom error
 };
 
-extern bool server_verbose;
-extern bool server_log_json;
-
-#ifndef SERVER_VERBOSE
-#define SERVER_VERBOSE 1
-#endif
-
-#if SERVER_VERBOSE != 1
-#define LOG_VERBOSE(MSG, ...)
-#else
-#define LOG_VERBOSE(MSG, ...) \
-    do \
-    { \
-        if (server_verbose) \
-        { \
-            server_log("VERB", __func__, __LINE__, MSG, __VA_ARGS__); \
-        } \
-    } while (0)
-#endif
-
-#define LOG_ERROR( MSG, ...) server_log("ERR", __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_WARNING(MSG, ...) server_log("WARN", __func__, __LINE__, MSG, __VA_ARGS__)
-#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
-
-static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra);
-
 template <typename T>
 static T json_value(const json & body, const std::string & key, const T & default_value) {
     // Fallback null to default value
@@ -60,9 +43,7 @@ static T json_value(const json & body, const std::string & key, const T & defaul
         try {
             return body.at(key);
         } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
-            std::stringstream ss;
-            ss << "Wrong type supplied for parameter '" << key << "'. Expected '" << json(default_value).type_name() << "', using default value.";
-            LOG_WARNING(ss.str().c_str(), body);
+            LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value\n", key.c_str(), json(default_value).type_name());
             return default_value;
         }
     } else {
@@ -70,48 +51,6 @@ static T json_value(const json & body, const std::string & key, const T & defaul
     }
 }
 
-static inline void server_log(const char * level, const char * function, int line, const char * message, const json & extra) {
-    std::stringstream ss_tid;
-    ss_tid << std::this_thread::get_id();
-    json log = json{
-        {"tid", ss_tid.str()},
-        {"timestamp", time(nullptr)},
-    };
-
-    if (server_log_json) {
-        log.merge_patch({
-            {"level", level},
-            {"function", function},
-            {"line", line},
-            {"msg", message},
-        });
-
-        if (!extra.empty()) {
-            log.merge_patch(extra);
-        }
-
-        printf("%s\n", log.dump(-1, ' ', false, json::error_handler_t::replace).c_str());
-    } else {
-        char buf[1024];
-        snprintf(buf, 1024, "%4s [%24s] %s", level, function, message);
-
-        if (!extra.empty()) {
-            log.merge_patch(extra);
-        }
-        std::stringstream ss;
-        ss << buf << " |";
-        for (const auto & el : log.items())
-        {
-            const std::string value = el.value().dump(-1, ' ', false, json::error_handler_t::replace);
-            ss << " " << el.key() << "=" << value;
-        }
-
-        const std::string str = ss.str();
-        printf("%.*s\n", (int)str.size(), str.data());
-    }
-    fflush(stdout);
-}
-
 //
 // chat template utils
 //
@@ -145,8 +84,9 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
         chat.push_back({role, content});
     }
 
-    auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
-    LOG_VERBOSE("formatted_chat", {{"text", formatted_chat.c_str()}});
+    const auto formatted_chat = llama_chat_apply_template(model, tmpl, chat, true);
+    LOG_DBG("formatted_chat: '%s'\n", formatted_chat.c_str());
+
     return formatted_chat;
 }
 
@@ -235,10 +175,7 @@ static std::string random_string() {
 }
 
 static std::string gen_chatcmplid() {
-    std::stringstream chatcmplid;
-    chatcmplid << "chatcmpl-" << random_string();
-
-    return chatcmplid.str();
+    return "chatcmpl-" + random_string();
 }
 
 //
@@ -279,6 +216,18 @@ static size_t find_partial_stop_string(const std::string &stop, const std::strin
     return std::string::npos;
 }
 
+static bool json_is_array_of_numbers(const json & data) {
+    if (data.is_array()) {
+        for (const auto & e : data) {
+            if (!e.is_number()) {
+                return false;
+            }
+        }
+        return true;
+    }
+    return false;
+}
+
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@@ -343,6 +292,17 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
     return out;
 }
 
+static bool server_sent_event(httplib::DataSink & sink, const char * event, const json & data) {
+    const std::string str =
+        std::string(event) + ": " +
+        data.dump(-1, ' ', false, json::error_handler_t::replace) +
+        "\n\n"; // note: these newlines are important (not sure why though, if you know, add a comment to explain)
+
+    LOG_DBG("data stream, to_send: %s", str.c_str());
+
+    return sink.write(str.c_str(), str.size());
+}
+
 //
 // OAI utils
 //
@@ -355,24 +315,6 @@ static json oaicompat_completion_params_parse(
 
     llama_params["__oaicompat"] = true;
 
-    // Map OpenAI parameters to llama.cpp parameters
-    //
-    // For parameters that are defined by the OpenAI documentation (e.g.
-    // temperature), we explicitly specify OpenAI's intended default; we
-    // need to do that because sometimes OpenAI disagrees with llama.cpp
-    //
-    // https://platform.openai.com/docs/api-reference/chat/create
-    llama_sampling_params default_sparams;
-    llama_params["model"] = json_value(body, "model", std::string("unknown"));
-    llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
-    llama_params["logit_bias"] = json_value(body, "logit_bias", json::object());
-    llama_params["n_predict"] = json_value(body, "max_tokens", -1);
-    llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0);
-    llama_params["seed"] = json_value(body, "seed", LLAMA_DEFAULT_SEED);
-    llama_params["stream"] = json_value(body, "stream", false);
-    llama_params["temperature"] = json_value(body, "temperature", 1.0);
-    llama_params["top_p"] = json_value(body, "top_p", 1.0);
-
     // Apply chat template to the list of messages
     llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
 
@@ -389,6 +331,9 @@
         std::string response_type = json_value(response_format, "type", std::string());
         if (response_type == "json_object") {
             llama_params["json_schema"] = json_value(response_format, "schema", json::object());
+        } else if (response_type == "json_schema") {
+            json json_schema = json_value(response_format, "json_schema", json::object());
+            llama_params["json_schema"] = json_value(json_schema, "schema", json::object());
         } else if (!response_type.empty() && response_type != "text") {
             throw std::runtime_error("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type);
         }
@@ -410,7 +355,7 @@
 
     // Params supported by OAI but unsupported by llama.cpp
    static const std::vector<std::string> unsupported_params { "tools", "tool_choice" };
-    for (auto & param : unsupported_params) {
+    for (const auto & param : unsupported_params) {
        if (body.contains(param)) {
            throw std::runtime_error("Unsupported param: " + param);
        }
@@ -429,7 +374,7 @@
     return llama_params;
 }
 
-static json format_final_response_oaicompat(const json & request, json result, const std::string & completion_id, bool streaming = false) {
+static json format_final_response_oaicompat(const json & request, const json & result, const std::string & completion_id, bool streaming = false, bool verbose = false) {
     bool stopped_word = result.count("stopped_word") != 0;
     bool stopped_eos = json_value(result, "stopped_eos", false);
     int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
@@ -466,7 +411,8 @@ static json format_final_response_oaicompat(const json & request, json result, c
         {"id", completion_id}
     };
 
-    if (server_verbose) {
+    // extra fields for debugging purposes
+    if (verbose) {
         res["__verbose"] = result;
     }
 
@@ -478,7 +424,7 @@ static json format_final_response_oaicompat(const json & request, json result, c
 }
 
 // return value is vector as there is one case where we might need to generate two responses
-static std::vector<json> format_partial_response_oaicompat(json result, const std::string & completion_id) {
+static std::vector<json> format_partial_response_oaicompat(const json & result, const std::string & completion_id) {
     if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
         return std::vector<json>({result});
     }
@@ -580,7 +526,7 @@ static std::vector<json> format_partial_response_oaicompat(json result, const st
 static json format_embeddings_response_oaicompat(const json & request, const json & embeddings) {
     json data = json::array();
     int i = 0;
-    for (auto & elem : embeddings) {
+    for (const auto & elem : embeddings) {
         data.push_back(json{
             {"embedding", json_value(elem, "embedding", json::array())},
             {"index", i++},
@@ -591,7 +537,7 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     json res = json {
         {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
         {"object", "list"},
-        {"usage", json {
+        {"usage", json { // TODO: fill
             {"prompt_tokens", 0},
             {"total_tokens", 0}
         }},
@@ -601,7 +547,63 @@ static json format_embeddings_response_oaicompat(const json & request, const jso
     return res;
 }
 
-static json format_tokenizer_response(const std::vector<llama_token> & tokens) {
+static json format_response_rerank(const json & request, const json & ranks) {
+    json data = json::array();
+    int i = 0;
+    for (const auto & rank : ranks) {
+        data.push_back(json{
+            {"index", i++},
+            {"relevance_score", json_value(rank, "score", 0.0)},
+        });
+    }
+
+    json res = json {
+        {"model", json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
+        {"object", "list"},
+        {"usage", json { // TODO: fill
+            {"prompt_tokens", 0},
+            {"total_tokens", 0}
+        }},
+        {"results", data}
+    };
+
+    return res;
+}
+
+static bool is_valid_utf8(const std::string & str) {
+    const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
+    const unsigned char* end = bytes + str.length();
+
+    while (bytes < end) {
+        if (*bytes <= 0x7F) {
+            // 1-byte sequence (0xxxxxxx)
+            bytes++;
+        } else if ((*bytes & 0xE0) == 0xC0) {
+            // 2-byte sequence (110xxxxx 10xxxxxx)
+            if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80)
+                return false;
+            bytes += 2;
+        } else if ((*bytes & 0xF0) == 0xE0) {
+            // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx)
+            if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80)
+                return false;
+            bytes += 3;
+        } else if ((*bytes & 0xF8) == 0xF0) {
+            // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx)
+            if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 ||
+                (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80)
+                return false;
+            bytes += 4;
+        } else {
+            // Invalid UTF-8 lead byte
+            return false;
+        }
+    }
+
+    return true;
+}
+
+static json format_tokenizer_response(const json & tokens) {
     return json {
         {"tokens", tokens}
     };
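
A note on the new server_sent_event() helper above: each streamed chunk is framed as "<event>: <serialized json>" followed by a blank line, and it is the blank line that makes an SSE client dispatch the event. Below is a minimal standalone sketch of that framing, assuming the payload is already-serialized JSON text (the real helper dumps an nlohmann::json value and writes the frame to an httplib::DataSink; make_sse_frame is an illustrative name, not part of the package):

// Sketch of the SSE framing used by server_sent_event() above.
#include <iostream>
#include <string>

// Builds "<event>: <payload>\n\n"; the trailing blank line terminates the SSE event.
static std::string make_sse_frame(const std::string & event, const std::string & payload) {
    return event + ": " + payload + "\n\n";
}

int main() {
    std::cout << make_sse_frame("data",  R"({"content":"Hello"})");
    std::cout << make_sse_frame("error", R"({"message":"something went wrong"})");
    return 0;
}
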
package/src/llama.cpp/examples/simple/simple.cpp
@@ -1,17 +1,14 @@
+#include "arg.h"
 #include "common.h"
+#include "log.h"
 #include "llama.h"
 
-#include <cmath>
-#include <cstdio>
-#include <string>
 #include <vector>
 
-static void print_usage(int argc, char ** argv, const gpt_params & params) {
-    gpt_params_print_usage(argc, argv, params);
-
-    LOG_TEE("\nexample usage:\n");
-    LOG_TEE("\n %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]);
-    LOG_TEE("\n");
+static void print_usage(int, char ** argv) {
+    LOG("\nexample usage:\n");
+    LOG("\n %s -m model.gguf -p \"Hello my name is\" -n 32\n", argv[0]);
+    LOG("\n");
 }
 
 int main(int argc, char ** argv) {
@@ -20,11 +17,12 @@ int main(int argc, char ** argv) {
     params.prompt = "Hello my name is";
     params.n_predict = 32;
 
-    if (!gpt_params_parse(argc, argv, params)) {
-        print_usage(argc, argv, params);
+    if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON, print_usage)) {
         return 1;
     }
 
+    gpt_init();
+
     // total length of the sequence including the prompt
     const int n_predict = params.n_predict;
 
@@ -55,6 +53,14 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    auto sparams = llama_sampler_chain_default_params();
+
+    sparams.no_perf = false;
+
+    llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+
     // tokenize the prompt
 
     std::vector<llama_token> tokens_list;
@@ -63,25 +69,24 @@ int main(int argc, char ** argv) {
     const int n_ctx = llama_n_ctx(ctx);
     const int n_kv_req = tokens_list.size() + (n_predict - tokens_list.size());
 
-    LOG_TEE("\n%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
+    LOG("\n");
+    LOG_INF("%s: n_predict = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_predict, n_ctx, n_kv_req);
 
     // make sure the KV cache is big enough to hold all the prompt and generated tokens
     if (n_kv_req > n_ctx) {
-        LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
-        LOG_TEE("%s: either reduce n_predict or increase n_ctx\n", __func__);
+        LOG_ERR("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
+        LOG_ERR("%s: either reduce n_predict or increase n_ctx\n", __func__);
         return 1;
     }
 
     // print the prompt token-by-token
 
-    fprintf(stderr, "\n");
+    LOG("\n");
 
     for (auto id : tokens_list) {
-        fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+        LOG("%s", llama_token_to_piece(ctx, id).c_str());
     }
 
-    fflush(stderr);
-
     // create a llama_batch with size 512
     // we use this object to submit token data for decoding
 
@@ -96,7 +101,7 @@ int main(int argc, char ** argv) {
     batch.logits[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx, batch) != 0) {
-        LOG_TEE("%s: llama_decode() failed\n", __func__);
+        LOG("%s: llama_decode() failed\n", __func__);
         return 1;
     }
 
@@ -110,29 +115,16 @@ int main(int argc, char ** argv) {
     while (n_cur <= n_predict) {
         // sample the next token
         {
-            auto n_vocab = llama_n_vocab(model);
-            auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-
-            std::vector<llama_token_data> candidates;
-            candidates.reserve(n_vocab);
-
-            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
-            }
-
-            llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
-            // sample the most likely token
-            const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+            const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);
 
             // is it an end of generation?
             if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
-                LOG_TEE("\n");
+                LOG("\n");
 
                 break;
             }
 
-            LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+            LOG("%s", llama_token_to_piece(ctx, new_token_id).c_str());
             fflush(stdout);
 
             // prepare the next batch
@@ -148,24 +140,26 @@ int main(int argc, char ** argv) {
 
         // evaluate the current batch with the transformer model
         if (llama_decode(ctx, batch)) {
-            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
            return 1;
        }
    }
 
-    LOG_TEE("\n");
+    LOG("\n");
 
     const auto t_main_end = ggml_time_us();
 
-    LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+    LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
-    llama_print_timings(ctx);
+    LOG("\n");
+    llama_perf_sampler_print(smpl);
+    llama_perf_context_print(ctx);
 
-    fprintf(stderr, "\n");
+    LOG("\n");
 
     llama_batch_free(batch);
-
+    llama_sampler_free(smpl);
     llama_free(ctx);
     llama_free_model(model);
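
The updated simple example above replaces the manual llama_token_data_array + llama_sample_token_greedy loop with the llama_sampler chain API. Below is a minimal sketch of the new usage, assuming a llama_context created as in the example (greedy_sample_last is an illustrative wrapper name; in the full example the chain is built once before the decode loop and freed after it, rather than per call):

// Sketch of the llama_sampler chain API used by the updated example.
// Assumes "llama.h" from the same llama.cpp revision; error handling omitted.
#include "llama.h"

static llama_token greedy_sample_last(llama_context * ctx) {
    auto sparams = llama_sampler_chain_default_params();
    sparams.no_perf = false;  // keep perf counters, as in the example

    llama_sampler * smpl = llama_sampler_chain_init(sparams);
    llama_sampler_chain_add(smpl, llama_sampler_init_greedy());  // greedy pick, as before

    // idx = -1 samples from the logits of the last token of the previous llama_decode() call
    const llama_token new_token_id = llama_sampler_sample(smpl, ctx, -1);

    llama_sampler_free(smpl);
    return new_token_id;
}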