@fugood/llama.node 0.3.0 → 0.3.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (187)
  1. package/CMakeLists.txt +1 -10
  2. package/bin/darwin/arm64/llama-node.node +0 -0
  3. package/bin/darwin/x64/llama-node.node +0 -0
  4. package/bin/linux/arm64/llama-node.node +0 -0
  5. package/bin/linux/x64/llama-node.node +0 -0
  6. package/bin/linux-vulkan/arm64/llama-node.node +0 -0
  7. package/bin/linux-vulkan/x64/llama-node.node +0 -0
  8. package/bin/win32/arm64/llama-node.node +0 -0
  9. package/bin/win32/arm64/node.lib +0 -0
  10. package/bin/win32/x64/llama-node.node +0 -0
  11. package/bin/win32/x64/node.lib +0 -0
  12. package/bin/win32-vulkan/arm64/llama-node.node +0 -0
  13. package/bin/win32-vulkan/arm64/node.lib +0 -0
  14. package/bin/win32-vulkan/x64/llama-node.node +0 -0
  15. package/bin/win32-vulkan/x64/node.lib +0 -0
  16. package/package.json +6 -4
  17. package/src/LlamaCompletionWorker.cpp +6 -6
  18. package/src/LlamaContext.cpp +7 -9
  19. package/src/common.hpp +2 -1
  20. package/src/llama.cpp/.github/workflows/build.yml +98 -24
  21. package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
  22. package/src/llama.cpp/.github/workflows/docker.yml +43 -34
  23. package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
  24. package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
  25. package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
  26. package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
  27. package/src/llama.cpp/.github/workflows/server.yml +7 -0
  28. package/src/llama.cpp/CMakeLists.txt +20 -8
  29. package/src/llama.cpp/common/CMakeLists.txt +12 -10
  30. package/src/llama.cpp/common/arg.cpp +2006 -0
  31. package/src/llama.cpp/common/arg.h +77 -0
  32. package/src/llama.cpp/common/common.cpp +496 -1632
  33. package/src/llama.cpp/common/common.h +161 -63
  34. package/src/llama.cpp/common/console.cpp +3 -0
  35. package/src/llama.cpp/common/log.cpp +401 -0
  36. package/src/llama.cpp/common/log.h +66 -698
  37. package/src/llama.cpp/common/ngram-cache.cpp +3 -0
  38. package/src/llama.cpp/common/sampling.cpp +348 -350
  39. package/src/llama.cpp/common/sampling.h +62 -139
  40. package/src/llama.cpp/common/stb_image.h +5990 -6398
  41. package/src/llama.cpp/common/train.cpp +2 -0
  42. package/src/llama.cpp/docs/build.md +36 -1
  43. package/src/llama.cpp/examples/CMakeLists.txt +0 -1
  44. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1 -2
  45. package/src/llama.cpp/examples/batched/batched.cpp +39 -55
  46. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +34 -44
  47. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
  48. package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +15 -15
  49. package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
  50. package/src/llama.cpp/examples/embedding/embedding.cpp +143 -87
  51. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +33 -33
  52. package/src/llama.cpp/examples/export-lora/export-lora.cpp +36 -35
  53. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
  54. package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +5 -0
  55. package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
  56. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
  57. package/src/llama.cpp/examples/gritlm/gritlm.cpp +34 -27
  58. package/src/llama.cpp/examples/imatrix/imatrix.cpp +59 -62
  59. package/src/llama.cpp/examples/infill/infill.cpp +117 -132
  60. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +265 -58
  61. package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +29 -22
  62. package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
  63. package/src/llama.cpp/examples/llava/clip.cpp +685 -150
  64. package/src/llama.cpp/examples/llava/clip.h +11 -2
  65. package/src/llama.cpp/examples/llava/llava-cli.cpp +47 -58
  66. package/src/llama.cpp/examples/llava/llava.cpp +110 -24
  67. package/src/llama.cpp/examples/llava/llava.h +2 -3
  68. package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
  69. package/src/llama.cpp/examples/llava/requirements.txt +1 -0
  70. package/src/llama.cpp/examples/lookahead/lookahead.cpp +42 -43
  71. package/src/llama.cpp/examples/lookup/lookup-create.cpp +10 -8
  72. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +23 -22
  73. package/src/llama.cpp/examples/lookup/lookup.cpp +40 -43
  74. package/src/llama.cpp/examples/main/main.cpp +210 -262
  75. package/src/llama.cpp/examples/parallel/parallel.cpp +49 -49
  76. package/src/llama.cpp/examples/passkey/passkey.cpp +42 -50
  77. package/src/llama.cpp/examples/perplexity/perplexity.cpp +187 -200
  78. package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
  79. package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
  80. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +2 -3
  81. package/src/llama.cpp/examples/retrieval/retrieval.cpp +49 -44
  82. package/src/llama.cpp/examples/rpc/rpc-server.cpp +24 -1
  83. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +32 -35
  84. package/src/llama.cpp/examples/server/CMakeLists.txt +3 -5
  85. package/src/llama.cpp/examples/server/server.cpp +1027 -1073
  86. package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
  87. package/src/llama.cpp/examples/server/utils.hpp +107 -105
  88. package/src/llama.cpp/examples/simple/simple.cpp +35 -41
  89. package/src/llama.cpp/examples/speculative/speculative.cpp +129 -103
  90. package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
  91. package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
  92. package/src/llama.cpp/examples/tokenize/tokenize.cpp +25 -27
  93. package/src/llama.cpp/ggml/CMakeLists.txt +14 -3
  94. package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
  95. package/src/llama.cpp/ggml/include/ggml-backend.h +145 -60
  96. package/src/llama.cpp/ggml/include/ggml-blas.h +3 -3
  97. package/src/llama.cpp/ggml/include/ggml-cann.h +15 -19
  98. package/src/llama.cpp/ggml/include/ggml-cuda.h +16 -16
  99. package/src/llama.cpp/ggml/include/ggml-metal.h +5 -8
  100. package/src/llama.cpp/ggml/include/ggml-rpc.h +5 -5
  101. package/src/llama.cpp/ggml/include/ggml-sycl.h +8 -8
  102. package/src/llama.cpp/ggml/include/ggml-vulkan.h +7 -7
  103. package/src/llama.cpp/ggml/include/ggml.h +293 -186
  104. package/src/llama.cpp/ggml/src/CMakeLists.txt +86 -44
  105. package/src/llama.cpp/ggml/src/ggml-aarch64.c +2135 -1119
  106. package/src/llama.cpp/ggml/src/ggml-alloc.c +6 -0
  107. package/src/llama.cpp/ggml/src/ggml-backend-impl.h +152 -70
  108. package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +606 -286
  109. package/src/llama.cpp/ggml/src/ggml-blas.cpp +9 -10
  110. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
  111. package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
  112. package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
  113. package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
  114. package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
  115. package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
  116. package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
  117. package/src/llama.cpp/ggml/src/ggml-cann.cpp +215 -216
  118. package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
  119. package/src/llama.cpp/ggml/src/ggml-cpu-impl.h +614 -0
  120. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
  121. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
  122. package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
  123. package/src/llama.cpp/ggml/src/ggml-impl.h +49 -603
  124. package/src/llama.cpp/ggml/src/ggml-kompute.cpp +4 -24
  125. package/src/llama.cpp/ggml/src/ggml-quants.c +972 -92
  126. package/src/llama.cpp/ggml/src/ggml-quants.h +15 -0
  127. package/src/llama.cpp/ggml/src/ggml-rpc.cpp +116 -66
  128. package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
  129. package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +11 -0
  130. package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +52 -0
  131. package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
  132. package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
  133. package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
  134. package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
  135. package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
  136. package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
  137. package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +16 -3
  138. package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
  139. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
  140. package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
  141. package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +1 -1
  142. package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +6 -3
  143. package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +2 -0
  144. package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
  145. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
  146. package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
  147. package/src/llama.cpp/ggml/src/ggml-sycl.cpp +97 -169
  148. package/src/llama.cpp/ggml/src/ggml-vulkan.cpp +1508 -1124
  149. package/src/llama.cpp/ggml/src/ggml.c +3001 -1647
  150. package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +192 -0
  151. package/src/llama.cpp/ggml/src/vulkan-shaders/CMakeLists.txt +2 -0
  152. package/src/llama.cpp/ggml/src/vulkan-shaders/vulkan-shaders-gen.cpp +88 -40
  153. package/src/llama.cpp/include/llama.h +241 -264
  154. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
  155. package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
  156. package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
  157. package/src/llama.cpp/src/llama-grammar.cpp +721 -122
  158. package/src/llama.cpp/src/llama-grammar.h +120 -15
  159. package/src/llama.cpp/src/llama-impl.h +156 -1
  160. package/src/llama.cpp/src/llama-sampling.cpp +1375 -303
  161. package/src/llama.cpp/src/llama-sampling.h +20 -47
  162. package/src/llama.cpp/src/llama-vocab.cpp +343 -120
  163. package/src/llama.cpp/src/llama-vocab.h +33 -17
  164. package/src/llama.cpp/src/llama.cpp +4247 -1525
  165. package/src/llama.cpp/src/unicode-data.cpp +6 -4
  166. package/src/llama.cpp/src/unicode-data.h +4 -4
  167. package/src/llama.cpp/src/unicode.cpp +15 -7
  168. package/src/llama.cpp/tests/CMakeLists.txt +3 -0
  169. package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
  170. package/src/llama.cpp/tests/test-backend-ops.cpp +1592 -289
  171. package/src/llama.cpp/tests/test-barrier.cpp +93 -0
  172. package/src/llama.cpp/tests/test-grad0.cpp +187 -70
  173. package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
  174. package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
  175. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +6 -4
  176. package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
  177. package/src/llama.cpp/tests/test-log.cpp +39 -0
  178. package/src/llama.cpp/tests/test-quantize-fns.cpp +6 -0
  179. package/src/llama.cpp/tests/test-rope.cpp +1 -1
  180. package/src/llama.cpp/tests/test-sampling.cpp +157 -98
  181. package/src/llama.cpp/tests/test-tokenizer-0.cpp +55 -35
  182. package/patches/llama.patch +0 -22
  183. package/src/llama.cpp/.github/workflows/bench.yml +0 -310
  184. package/src/llama.cpp/common/grammar-parser.cpp +0 -536
  185. package/src/llama.cpp/common/grammar-parser.h +0 -29
  186. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
  187. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
package/src/llama.cpp/examples/parallel/parallel.cpp

@@ -1,7 +1,10 @@
  // A basic application simulating a server with multiple clients.
  // The clients submit requests to the server and they are processed in parallel.
 
+ #include "arg.h"
  #include "common.h"
+ #include "sampling.h"
+ #include "log.h"
  #include "llama.h"
 
  #include <cmath>
@@ -50,8 +53,8 @@ static std::vector<std::string> k_prompts = {
 
  struct client {
  ~client() {
- if (ctx_sampling) {
- llama_sampling_free(ctx_sampling);
+ if (smpl) {
+ gpt_sampler_free(smpl);
  }
  }
 
@@ -72,7 +75,7 @@ struct client {
  std::string prompt;
  std::string response;
 
- struct llama_sampling_context * ctx_sampling = nullptr;
+ struct gpt_sampler * smpl = nullptr;
  };
 
  static void print_date_time() {
@@ -81,7 +84,9 @@ static void print_date_time() {
  char buffer[80];
  strftime(buffer, sizeof(buffer), "%Y-%m-%d %H:%M:%S", local_time);
 
- printf("\n\033[35mrun parameters as at %s\033[0m\n", buffer);
+ LOG_INF("\n");
+ LOG_INF("\033[35mrun parameters as of %s\033[0m\n", buffer);
+ LOG_INF("\n");
  }
 
  // Define a split string function to ...
@@ -100,11 +105,12 @@ int main(int argc, char ** argv) {
 
  gpt_params params;
 
- if (!gpt_params_parse(argc, argv, params)) {
- gpt_params_print_usage(argc, argv, params);
+ if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
  return 1;
  }
 
+ gpt_init();
+
  // number of simultaneous "clients" to simulate
  const int32_t n_clients = params.n_parallel;
 
@@ -119,41 +125,34 @@ int main(int argc, char ** argv) {
 
  const bool dump_kv_cache = params.dump_kv_cache;
 
- #ifndef LOG_DISABLE_LOGS
- log_set_target(log_filename_generator("parallel", "log"));
- LOG_TEE("Log start\n");
- log_dump_cmdline(argc, argv);
- #endif // LOG_DISABLE_LOGS
-
  // init llama.cpp
  llama_backend_init();
  llama_numa_init(params.numa);
 
- llama_model * model = NULL;
- llama_context * ctx = NULL;
-
  // load the target model
- std::tie(model, ctx) = llama_init_from_gpt_params(params);
+ llama_init_result llama_init = llama_init_from_gpt_params(params);
+
+ llama_model * model = llama_init.model;
+ llama_context * ctx = llama_init.context;
 
  // load the prompts from an external file if there are any
  if (params.prompt.empty()) {
- printf("\n\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
+ LOG_INF("\033[32mNo new questions so proceed with build-in defaults.\033[0m\n");
  } else {
  // Output each line of the input params.prompts vector and copy to k_prompts
  int index = 0;
- printf("\n\033[32mNow printing the external prompt file %s\033[0m\n\n", params.prompt_file.c_str());
+ LOG_INF("\033[32mNow printing the external prompt file %s\033[0m\n\n", params.prompt_file.c_str());
 
  std::vector<std::string> prompts = split_string(params.prompt, '\n');
  for (const auto& prompt : prompts) {
  k_prompts.resize(index + 1);
  k_prompts[index] = prompt;
  index++;
- printf("%3d prompt: %s\n", index, prompt.c_str());
+ LOG_INF("%3d prompt: %s\n", index, prompt.c_str());
  }
  }
 
- fprintf(stderr, "\n\n");
- fflush(stderr);
+ LOG_INF("\n\n");
 
  const int n_ctx = llama_n_ctx(ctx);
 
@@ -161,7 +160,7 @@ int main(int argc, char ** argv) {
  for (size_t i = 0; i < clients.size(); ++i) {
  auto & client = clients[i];
  client.id = i;
- client.ctx_sampling = llama_sampling_init(params.sparams);
+ client.smpl = gpt_sampler_init(model, params.sparams);
  }
 
  std::vector<llama_token> tokens_system;
@@ -182,19 +181,19 @@ int main(int argc, char ** argv) {
 
  const auto t_main_start = ggml_time_us();
 
- LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
- LOG_TEE("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
- LOG_TEE("\n");
+ LOG_INF("%s: Simulating parallel requests from clients:\n", __func__);
+ LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
+ LOG_INF("\n");
 
  {
- LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
+ LOG_INF("%s: Evaluating the system prompt ...\n", __func__);
 
  for (int32_t i = 0; i < n_tokens_system; ++i) {
  llama_batch_add(batch, tokens_system[i], i, { 0 }, false);
  }
 
  if (llama_decode(ctx, batch) != 0) {
- LOG_TEE("%s: llama_decode() failed\n", __func__);
+ LOG_ERR("%s: llama_decode() failed\n", __func__);
  return 1;
  }
 
@@ -203,10 +202,10 @@ int main(int argc, char ** argv) {
  llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
  }
 
- LOG_TEE("\n");
+ LOG_INF("\n");
  }
 
- LOG_TEE("Processing requests ...\n\n");
+ LOG_INF("Processing requests ...\n\n");
 
  while (true) {
  if (dump_kv_cache) {
@@ -237,7 +236,7 @@ int main(int argc, char ** argv) {
  llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
  }
 
- LOG_TEE("%s: clearing the KV cache\n", __func__);
+ LOG_INF("%s: clearing the KV cache\n", __func__);
  }
 
  // insert new sequences for decoding
@@ -253,7 +252,7 @@ int main(int argc, char ** argv) {
  client.prompt = client.input + "\nAssistant:";
  client.response = "";
 
- llama_sampling_reset(client.ctx_sampling);
+ gpt_sampler_reset(client.smpl);
 
  // do not prepend BOS because we have a system prompt!
  std::vector<llama_token> tokens_prompt;
@@ -272,7 +271,7 @@ int main(int argc, char ** argv) {
  client.n_decoded = 0;
  client.i_batch = batch.n_tokens - 1;
 
- LOG_TEE("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
+ LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
 
  g_seq_id += 1;
 
@@ -316,11 +315,11 @@ int main(int argc, char ** argv) {
  if (ret != 0) {
  if (n_batch == 1 || ret < 0) {
  // if you get here, it means the KV cache is full - try increasing it via the context size
- LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
+ LOG_ERR("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
  return 1;
  }
 
- LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
+ LOG_ERR("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
 
  n_cache_miss += 1;
 
@@ -331,7 +330,7 @@ int main(int argc, char ** argv) {
  continue;
  }
 
- LOG("%s : decoded batch of %d tokens\n", __func__, n_tokens);
+ LOG_DBG("%s : decoded batch of %d tokens\n", __func__, n_tokens);
 
  for (auto & client : clients) {
  if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
@@ -341,9 +340,9 @@ int main(int argc, char ** argv) {
  //printf("client %d, seq %d, token %d, pos %d, batch %d\n",
  // client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
 
- const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
+ const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i);
 
- llama_sampling_accept(client.ctx_sampling, ctx, id, true);
+ gpt_sampler_accept(client.smpl, id, true);
 
  if (client.n_decoded == 1) {
  // start measuring generation time after the first token to make sure all concurrent clients
@@ -371,12 +370,12 @@ int main(int argc, char ** argv) {
  }
 
  // delete only the generated part of the sequence, i.e. keep the system prompt in the cache
- llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
+ llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
  llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
 
  const auto t_main_end = ggml_time_us();
 
- LOG_TEE("\033[31mClient %3d, seq %3d/%3d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \nInput: %s\n\033[35mResponse: %s\033[0m\n\n",
+ LOG_INF("\033[31mClient %3d, seq %3d/%3d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \n\nInput: %s\n\033[35mResponse: %s\033[0m\n\n",
  client.id, client.seq_id, n_seq, client.n_prompt, client.n_decoded,
  (t_main_end - client.t_start_prompt) / 1e6,
  (double) (client.n_prompt + client.n_decoded) / (t_main_end - client.t_start_prompt) * 1e6,
@@ -399,21 +398,22 @@ int main(int argc, char ** argv) {
 
  print_date_time();
 
- LOG_TEE("\n%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
+ LOG_INF("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
  if (params.prompt_file.empty()) {
  params.prompt_file = "used built-in defaults";
  }
- LOG_TEE("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str());
- LOG_TEE("Model and path used: \033[32m%s\033[0m\n\n", params.model.c_str());
+ LOG_INF("External prompt file: \033[32m%s\033[0m\n", params.prompt_file.c_str());
+ LOG_INF("Model and path used: \033[32m%s\033[0m\n\n", params.model.c_str());
 
- LOG_TEE("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6);
- LOG_TEE("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6);
- LOG_TEE("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6);
- LOG_TEE("Cache misses: %6d\n", n_cache_miss);
+ LOG_INF("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6);
+ LOG_INF("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6);
+ LOG_INF("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6);
+ LOG_INF("Cache misses: %6d\n", n_cache_miss);
 
- LOG_TEE("\n");
+ LOG_INF("\n");
 
- llama_print_timings(ctx);
+ // TODO: print sampling/grammar timings for all clients
+ llama_perf_context_print(ctx);
 
  llama_batch_free(batch);
 
@@ -422,7 +422,7 @@ int main(int argc, char ** argv) {
 
  llama_backend_free();
 
- fprintf(stderr, "\n\n");
+ LOG("\n\n");
 
  return 0;
  }
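
The parallel.cpp hunks above track the llama.cpp common-library refactor picked up in this release: argument parsing moves to arg.h and gpt_params_parse now takes an example id, logging goes through the LOG_INF/LOG_ERR/LOG_DBG macros after gpt_init(), llama_init_from_gpt_params returns a llama_init_result struct, and the per-client llama_sampling_context is replaced by gpt_sampler. A minimal sketch of the resulting skeleton follows, built only from names visible in the diff plus ordinary llama.cpp cleanup calls; it is an illustrative outline, not code shipped in the package:

    // Sketch of the post-refactor example skeleton (illustrative only).
    #include "arg.h"
    #include "common.h"
    #include "sampling.h"
    #include "log.h"
    #include "llama.h"

    int main(int argc, char ** argv) {
        gpt_params params;

        if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PARALLEL)) {
            return 1;
        }

        gpt_init(); // common init (logging); added where the old LOG_DISABLE_LOGS block was removed

        llama_backend_init();
        llama_numa_init(params.numa);

        // llama_init_from_gpt_params now returns a struct instead of a model/context pair
        llama_init_result llama_init = llama_init_from_gpt_params(params);

        llama_model   * model = llama_init.model;
        llama_context * ctx   = llama_init.context;

        // one sampler per sequence; previously llama_sampling_init(params.sparams)
        gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);

        // ... build a batch, call llama_decode(ctx, batch), then for each sequence slot i_batch:
        //     const llama_token id = gpt_sampler_sample(smpl, ctx, i_batch);
        //     gpt_sampler_accept(smpl, id, true);
        // and gpt_sampler_reset(smpl) when a sequence is restarted

        gpt_sampler_free(smpl);

        llama_perf_context_print(ctx); // replaces llama_print_timings(ctx)

        llama_free(ctx);
        llama_free_model(model);

        llama_backend_free();

        return 0;
    }

The same migration (arg.h parsing, gpt_init, LOG_* macros, llama_perf_context_print in place of llama_print_timings) recurs across the other example diffs in this release, including the passkey example below.
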
package/src/llama.cpp/examples/passkey/passkey.cpp

@@ -1,4 +1,6 @@
+ #include "arg.h"
  #include "common.h"
+ #include "log.h"
  #include "llama.h"
 
  #include <cmath>
@@ -6,12 +8,10 @@
  #include <string>
  #include <vector>
 
- static void print_usage(int argc, char ** argv, const gpt_params & params) {
- gpt_params_print_usage(argc, argv, params);
-
- LOG_TEE("\nexample usage:\n");
- LOG_TEE("\n %s -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]\n", argv[0]);
- LOG_TEE("\n");
+ static void print_usage(int, char ** argv) {
+ LOG("\nexample usage:\n");
+ LOG("\n %s -m model.gguf --junk 250 --pos 90 --keep 32 --grp-attn-n 2 [--seed 1234]\n", argv[0]);
+ LOG("\n");
  }
 
  int main(int argc, char ** argv) {
@@ -21,12 +21,11 @@ int main(int argc, char ** argv) {
  params.n_keep = 32;
  params.i_pos = -1;
 
- if (!gpt_params_parse(argc, argv, params)) {
- print_usage(argc, argv, params);
+ if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_PASSKEY, print_usage)) {
  return 1;
  }
 
- srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed);
+ gpt_init();
 
  int n_junk = params.n_junk;
  int n_keep = params.n_keep;
@@ -67,7 +66,7 @@ int main(int argc, char ** argv) {
  llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
 
  if (model == NULL) {
- fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+ LOG_ERR("%s: unable to load model\n" , __func__);
  return 1;
  }
 
@@ -80,12 +79,17 @@ int main(int argc, char ** argv) {
  GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
 
  llama_context * ctx = llama_new_context_with_model(model, ctx_params);
-
  if (ctx == NULL) {
- fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+ LOG_ERR("%s: failed to create the llama_context\n" , __func__);
  return 1;
  }
 
+ auto sparams = llama_sampler_chain_default_params();
+
+ llama_sampler * smpl = llama_sampler_chain_init(sparams);
+
+ llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
+
  // tokenize the prompt
  std::vector<llama_token> tokens_list;
  tokens_list = ::llama_tokenize(ctx, params.prompt, true);
@@ -106,14 +110,14 @@ int main(int argc, char ** argv) {
  const int n_batch = ctx_params.n_batch;
  const int n_batch_grp = ctx_params.n_batch/n_grp;
 
- LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos);
+ LOG_INF("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d, n_grp = %d, n_batch = %d, n_junk = %d, i_pos = %d\n", __func__, n_len, n_ctx, n_kv_req, n_grp, n_batch, n_junk, i_pos);
 
  // print the prompt token-by-token
 
- LOG_TEE("\n");
- LOG_TEE("prefix tokens: %d\n", n_tokens_prefix);
- LOG_TEE("prompt tokens: %d\n", n_tokens_all);
- //LOG_TEE("prompt: %s\n", params.prompt.c_str());
+ LOG_INF("\n");
+ LOG_INF("prefix tokens: %d\n", n_tokens_prefix);
+ LOG_INF("prompt tokens: %d\n", n_tokens_all);
+ //LOG_INF("prompt: %s\n", params.prompt.c_str());
 
  llama_batch batch = llama_batch_init(params.n_batch, 0, 1);
 
@@ -144,11 +148,11 @@ int main(int argc, char ** argv) {
  }
 
  if (llama_decode(ctx, batch) != 0) {
- LOG_TEE("%s: llama_decode() failed\n", __func__);
+ LOG_INF("%s: llama_decode() failed\n", __func__);
  return 1;
  }
 
- LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
+ LOG_INF("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
 
  if (i + n_batch >= n_tokens_all) {
  break;
@@ -158,7 +162,7 @@ int main(int argc, char ** argv) {
  for (int i = n_ctx; i < n_tokens_all; i += n_batch) {
  const int n_discard = n_batch;
 
- LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard);
+ LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
 
  llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
  llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
@@ -178,18 +182,18 @@ int main(int argc, char ** argv) {
  }
 
  if (llama_decode(ctx, batch) != 0) {
- LOG_TEE("%s: llama_decode() failed\n", __func__);
+ LOG_ERR("%s: llama_decode() failed\n", __func__);
  return 1;
  }
 
- LOG_TEE("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
+ LOG_INF("%s: processed: [%6d, %6d)\n", __func__, i, std::min(i + n_batch, n_tokens_all));
  }
 
  {
  const int n_discard = n_past - n_ctx + n_predict;
 
  if (n_discard > 0) {
- LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
+ LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
 
  llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
  llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
@@ -200,47 +204,32 @@ int main(int argc, char ** argv) {
  }
  }
 
- LOG_TEE("\n");
- LOG_TEE("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk);
- LOG_TEE("\n");
+ LOG_INF("\n");
+ LOG_INF("%s: passkey = %d, inserted at position %d / %d (token pos: ~%d)\n", __func__, passkey, i_pos, n_junk, (i_pos * n_tokens_all) / n_junk);
+ LOG_INF("\n");
 
  // main loop
 
  int n_cur = n_tokens_all;
  int n_decode = 0;
 
- LOG_TEE("%s", prompt_suffix.c_str());
- fflush(stdout);
+ LOG_INF("%s", prompt_suffix.c_str());
 
  const auto t_main_start = ggml_time_us();
 
  while (n_cur <= n_len) {
  // sample the next token
  {
- auto n_vocab = llama_n_vocab(model);
- auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
-
- std::vector<llama_token_data> candidates;
- candidates.reserve(n_vocab);
-
- for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
- candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
- }
-
- llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
-
- // sample the most likely token
- const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+ const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
 
  // is it an end of generation?
  if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
- LOG_TEE("\n");
+ LOG("\n");
 
  break;
  }
 
- LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
- fflush(stdout);
+ LOG("%s", llama_token_to_piece(ctx, new_token_id).c_str());
 
  n_decode += 1;
 
@@ -255,21 +244,24 @@ int main(int argc, char ** argv) {
 
  // evaluate the current batch with the transformer model
  if (llama_decode(ctx, batch)) {
- fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+ LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
  return 1;
  }
  }
 
- LOG_TEE("\n");
+ LOG("\n");
 
  const auto t_main_end = ggml_time_us();
 
- LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+ LOG_INF("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
  __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
 
- llama_print_timings(ctx);
+ LOG("\n");
+ llama_perf_context_print(ctx);
+
+ LOG("\n");
 
- fprintf(stderr, "\n");
+ llama_sampler_free(smpl);
 
  llama_batch_free(batch);
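
The passkey.cpp hunks also show the new core sampling API: the hand-rolled candidates vector and llama_sample_token_greedy call are replaced by a llama_sampler chain that is built once and reused for every generated token. A small sketch of just that piece, assuming a context and batch set up as in the example; the helper name is illustrative:

    // Greedy sampling via the new llama_sampler chain; replaces the removed
    // candidates / llama_sample_token_greedy block.
    #include "llama.h"

    // Build the sampler once, before the generation loop (hypothetical helper name).
    static llama_sampler * make_greedy_sampler() {
        auto sparams = llama_sampler_chain_default_params();

        llama_sampler * smpl = llama_sampler_chain_init(sparams);

        llama_sampler_chain_add(smpl, llama_sampler_init_greedy());

        return smpl;
    }

    // Inside the generation loop, a single call replaces the old candidates vector:
    //     const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
    // and after the loop the sampler is released with:
    //     llama_sampler_free(smpl);

The chain is extensible: additional samplers can be added with further llama_sampler_chain_add calls before sampling, which is how non-greedy configurations are composed in this API.
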