@fugood/llama.node 0.0.1-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204)
  1. package/CMakeLists.txt +85 -0
  2. package/README.md +56 -0
  3. package/bin/darwin/arm64/llama-node.node +0 -0
  4. package/bin/darwin/x64/llama-node.node +0 -0
  5. package/bin/linux/arm64/llama-node.node +0 -0
  6. package/bin/linux/x64/llama-node.node +0 -0
  7. package/bin/win32/arm64/llama-node.node +0 -0
  8. package/bin/win32/arm64/node.lib +0 -0
  9. package/bin/win32/x64/llama-node.node +0 -0
  10. package/bin/win32/x64/node.lib +0 -0
  11. package/lib/binding.js +13 -0
  12. package/lib/binding.ts +57 -0
  13. package/lib/index.js +24 -0
  14. package/lib/index.ts +13 -0
  15. package/package.json +65 -0
  16. package/src/addons.cpp +506 -0
  17. package/src/llama.cpp/CMakeLists.txt +1320 -0
  18. package/src/llama.cpp/build.zig +172 -0
  19. package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
  20. package/src/llama.cpp/common/CMakeLists.txt +87 -0
  21. package/src/llama.cpp/common/base64.hpp +392 -0
  22. package/src/llama.cpp/common/common.cpp +2949 -0
  23. package/src/llama.cpp/common/common.h +324 -0
  24. package/src/llama.cpp/common/console.cpp +501 -0
  25. package/src/llama.cpp/common/console.h +19 -0
  26. package/src/llama.cpp/common/grammar-parser.cpp +440 -0
  27. package/src/llama.cpp/common/grammar-parser.h +29 -0
  28. package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
  29. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
  30. package/src/llama.cpp/common/json.hpp +24766 -0
  31. package/src/llama.cpp/common/log.h +724 -0
  32. package/src/llama.cpp/common/ngram-cache.cpp +282 -0
  33. package/src/llama.cpp/common/ngram-cache.h +94 -0
  34. package/src/llama.cpp/common/sampling.cpp +353 -0
  35. package/src/llama.cpp/common/sampling.h +147 -0
  36. package/src/llama.cpp/common/stb_image.h +8396 -0
  37. package/src/llama.cpp/common/train.cpp +1513 -0
  38. package/src/llama.cpp/common/train.h +233 -0
  39. package/src/llama.cpp/examples/CMakeLists.txt +52 -0
  40. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
  41. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
  42. package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
  43. package/src/llama.cpp/examples/batched/batched.cpp +262 -0
  44. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
  45. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
  46. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
  47. package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
  48. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
  49. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
  50. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
  51. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
  52. package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
  54. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
  55. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
  56. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
  58. package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
  59. package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
  60. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
  61. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
  62. package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
  64. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
  65. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
  66. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
  67. package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
  68. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
  69. package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
  70. package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
  71. package/src/llama.cpp/examples/infill/infill.cpp +767 -0
  72. package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
  73. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
  74. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
  75. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
  76. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
  77. package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
  78. package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
  79. package/src/llama.cpp/examples/llava/clip.h +85 -0
  80. package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
  81. package/src/llama.cpp/examples/llava/llava.cpp +426 -0
  82. package/src/llama.cpp/examples/llava/llava.h +50 -0
  83. package/src/llama.cpp/examples/llava/requirements.txt +3 -0
  84. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
  85. package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
  86. package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
  87. package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
  88. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
  89. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
  90. package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
  91. package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
  92. package/src/llama.cpp/examples/main/main.cpp +957 -0
  93. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
  94. package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
  95. package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
  96. package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
  97. package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
  98. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
  99. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
  100. package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
  101. package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
  102. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
  103. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
  104. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
  106. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
  107. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
  108. package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
  109. package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
  110. package/src/llama.cpp/examples/server/httplib.h +9465 -0
  111. package/src/llama.cpp/examples/server/server.cpp +3826 -0
  112. package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
  113. package/src/llama.cpp/examples/server/utils.hpp +653 -0
  114. package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
  115. package/src/llama.cpp/examples/simple/simple.cpp +183 -0
  116. package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
  117. package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
  118. package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
  119. package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
  120. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
  121. package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
  122. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
  123. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
  124. package/src/llama.cpp/ggml-alloc.c +985 -0
  125. package/src/llama.cpp/ggml-alloc.h +76 -0
  126. package/src/llama.cpp/ggml-backend-impl.h +141 -0
  127. package/src/llama.cpp/ggml-backend.c +2099 -0
  128. package/src/llama.cpp/ggml-backend.h +233 -0
  129. package/src/llama.cpp/ggml-common.h +1853 -0
  130. package/src/llama.cpp/ggml-cuda.h +43 -0
  131. package/src/llama.cpp/ggml-impl.h +265 -0
  132. package/src/llama.cpp/ggml-kompute.cpp +2006 -0
  133. package/src/llama.cpp/ggml-kompute.h +46 -0
  134. package/src/llama.cpp/ggml-metal.h +66 -0
  135. package/src/llama.cpp/ggml-mpi.c +216 -0
  136. package/src/llama.cpp/ggml-mpi.h +39 -0
  137. package/src/llama.cpp/ggml-opencl.cpp +2301 -0
  138. package/src/llama.cpp/ggml-opencl.h +36 -0
  139. package/src/llama.cpp/ggml-quants.c +12678 -0
  140. package/src/llama.cpp/ggml-quants.h +133 -0
  141. package/src/llama.cpp/ggml-sycl.cpp +17882 -0
  142. package/src/llama.cpp/ggml-sycl.h +49 -0
  143. package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
  144. package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
  145. package/src/llama.cpp/ggml-vulkan.h +29 -0
  146. package/src/llama.cpp/ggml.c +21819 -0
  147. package/src/llama.cpp/ggml.h +2403 -0
  148. package/src/llama.cpp/llama.cpp +17468 -0
  149. package/src/llama.cpp/llama.h +1117 -0
  150. package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
  151. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
  152. package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
  153. package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
  154. package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
  155. package/src/llama.cpp/prompts/alpaca.txt +1 -0
  156. package/src/llama.cpp/prompts/assistant.txt +31 -0
  157. package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
  158. package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
  159. package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
  160. package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
  161. package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
  162. package/src/llama.cpp/prompts/chat.txt +28 -0
  163. package/src/llama.cpp/prompts/dan-modified.txt +1 -0
  164. package/src/llama.cpp/prompts/dan.txt +1 -0
  165. package/src/llama.cpp/prompts/mnemonics.txt +93 -0
  166. package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
  167. package/src/llama.cpp/prompts/reason-act.txt +18 -0
  168. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
  169. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
  170. package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
  171. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
  172. package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
  173. package/src/llama.cpp/requirements.txt +12 -0
  174. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
  175. package/src/llama.cpp/scripts/xxd.cmake +16 -0
  176. package/src/llama.cpp/sgemm.cpp +999 -0
  177. package/src/llama.cpp/sgemm.h +12 -0
  178. package/src/llama.cpp/tests/CMakeLists.txt +78 -0
  179. package/src/llama.cpp/tests/get-model.cpp +21 -0
  180. package/src/llama.cpp/tests/get-model.h +2 -0
  181. package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
  182. package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
  183. package/src/llama.cpp/tests/test-c.c +7 -0
  184. package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
  185. package/src/llama.cpp/tests/test-double-float.cpp +57 -0
  186. package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
  187. package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
  188. package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
  189. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
  190. package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
  191. package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
  192. package/src/llama.cpp/tests/test-opt.cpp +181 -0
  193. package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
  194. package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
  195. package/src/llama.cpp/tests/test-rope.cpp +221 -0
  196. package/src/llama.cpp/tests/test-sampling.cpp +301 -0
  197. package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
  198. package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
  199. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
  200. package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
  201. package/src/llama.cpp/unicode-data.cpp +1651 -0
  202. package/src/llama.cpp/unicode-data.h +16 -0
  203. package/src/llama.cpp/unicode.cpp +277 -0
  204. package/src/llama.cpp/unicode.h +28 -0
package/src/llama.cpp/examples/batched/CMakeLists.txt
@@ -0,0 +1,5 @@
+ set(TARGET batched)
+ add_executable(${TARGET} batched.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_11)
package/src/llama.cpp/examples/batched/batched.cpp
@@ -0,0 +1,262 @@
+ #include "common.h"
+ #include "llama.h"
+
+ #include <algorithm>
+ #include <cmath>
+ #include <cstdio>
+ #include <string>
+ #include <vector>
+
+ int main(int argc, char ** argv) {
+     gpt_params params;
+
+     if (argc == 1 || argv[1][0] == '-') {
+         printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL]\n" , argv[0]);
+         return 1 ;
+     }
+
+     // number of parallel batches
+     int n_parallel = 1;
+
+     // total length of the sequences including the prompt
+     int n_len = 32;
+
+     // number of layers to offload to the GPU
+     int n_gpu_layers = 0;
+
+     if (argc >= 2) {
+         params.model = argv[1];
+     }
+
+     if (argc >= 3) {
+         params.prompt = argv[2];
+     }
+
+     if (argc >= 4) {
+         n_parallel = std::atoi(argv[3]);
+     }
+
+     if (argc >= 5) {
+         n_len = std::atoi(argv[4]);
+     }
+
+     if (argc >= 6) {
+         n_gpu_layers = std::atoi(argv[5]);
+     }
+
+     if (params.prompt.empty()) {
+         params.prompt = "Hello my name is";
+     }
+
+     process_escapes(params.prompt);
+
+     // init LLM
+
+     llama_backend_init();
+     llama_numa_init(params.numa);
+
+     // initialize the model
+
+     llama_model_params model_params = llama_model_default_params();
+
+     model_params.n_gpu_layers = n_gpu_layers;
+
+     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+
+     if (model == NULL) {
+         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+         return 1;
+     }
+
+     // tokenize the prompt
+
+     std::vector<llama_token> tokens_list;
+     tokens_list = ::llama_tokenize(model, params.prompt, true);
+
+     const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
+
+     // initialize the context
+
+     llama_context_params ctx_params = llama_context_default_params();
+
+     ctx_params.seed = 1234;
+     ctx_params.n_ctx = n_kv_req;
+     ctx_params.n_batch = std::max(n_len, n_parallel);
+     ctx_params.n_seq_max = n_parallel;
+     ctx_params.n_threads = params.n_threads;
+     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
+     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+
+     if (ctx == NULL) {
+         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+         return 1;
+     }
+
+     const int n_ctx = llama_n_ctx(ctx);
+
+     LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %u, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
+
+     // make sure the KV cache is big enough to hold all the prompt and generated tokens
+     if (n_kv_req > n_ctx) {
+         LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
+         LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
+         return 1;
+     }
+
+     // print the prompt token-by-token
+
+     fprintf(stderr, "\n");
+
+     for (auto id : tokens_list) {
+         fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+     }
+
+     fflush(stderr);
+
+     // create a llama_batch
+     // we use this object to submit token data for decoding
+     llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0, 1);
+
+     // evaluate the initial prompt
+     for (size_t i = 0; i < tokens_list.size(); ++i) {
+         llama_batch_add(batch, tokens_list[i], i, { 0 }, false);
+     }
+     GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
+
+     // llama_decode will output logits only for the last token of the prompt
+     batch.logits[batch.n_tokens - 1] = true;
+
+     if (llama_decode(ctx, batch) != 0) {
+         LOG_TEE("%s: llama_decode() failed\n", __func__);
+         return 1;
+     }
+
+     // assign the system KV cache to all parallel sequences
+     // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
+     for (int32_t i = 1; i < n_parallel; ++i) {
+         llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+     }
+
+     if (n_parallel > 1) {
+         LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
+     }
+
+     // main loop
+
+     // we will store the parallel decoded sequences in this vector
+     std::vector<std::string> streams(n_parallel);
+
+     // remember the batch index of the last token for each parallel sequence
+     // we need this to determine which logits to sample from
+     std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
+
+     int n_cur = batch.n_tokens;
+     int n_decode = 0;
+
+     const auto t_main_start = ggml_time_us();
+
+     while (n_cur <= n_len) {
+         // prepare the next batch
+         llama_batch_clear(batch);
+
+         // sample the next token for each parallel sequence / stream
+         for (int32_t i = 0; i < n_parallel; ++i) {
+             if (i_batch[i] < 0) {
+                 // the stream has already finished
+                 continue;
+             }
+
+             auto n_vocab = llama_n_vocab(model);
+             auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
+
+             std::vector<llama_token_data> candidates;
+             candidates.reserve(n_vocab);
+
+             for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+                 candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
+             }
+
+             llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
+
+             const int top_k = 40;
+             const float top_p = 0.9f;
+             const float temp = 0.4f;
+
+             llama_sample_top_k(ctx, &candidates_p, top_k, 1);
+             llama_sample_top_p(ctx, &candidates_p, top_p, 1);
+             llama_sample_temp (ctx, &candidates_p, temp);
+
+             const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
+
+             //const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+
+             // is it an end of generation? -> mark the stream as finished
+             if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
+                 i_batch[i] = -1;
+                 LOG_TEE("\n");
+                 if (n_parallel > 1) {
+                     LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
+                 }
+
+                 continue;
+             }
+
+             // if there is only one stream, we print immediately to stdout
+             if (n_parallel == 1) {
+                 LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
+                 fflush(stdout);
+             }
+
+             streams[i] += llama_token_to_piece(ctx, new_token_id);
+
+             i_batch[i] = batch.n_tokens;
+
+             // push this new token for next evaluation
+             llama_batch_add(batch, new_token_id, n_cur, { i }, true);
+
+             n_decode += 1;
+         }
+
+         // all streams are finished
+         if (batch.n_tokens == 0) {
+             break;
+         }
+
+         n_cur += 1;
+
+         // evaluate the current batch with the transformer model
+         if (llama_decode(ctx, batch)) {
+             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+             return 1;
+         }
+     }
+
+     LOG_TEE("\n");
+
+     if (n_parallel > 1) {
+         LOG_TEE("\n");
+
+         for (int32_t i = 0; i < n_parallel; ++i) {
+             LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
+         }
+     }
+
+     const auto t_main_end = ggml_time_us();
+
+     LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
+             __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+
+     llama_print_timings(ctx);
+
+     fprintf(stderr, "\n");
+
+     llama_batch_free(batch);
+
+     llama_free(ctx);
+     llama_free_model(model);
+
+     llama_backend_free();
+
+     return 0;
+ }
package/src/llama.cpp/examples/batched-bench/CMakeLists.txt
@@ -0,0 +1,5 @@
+ set(TARGET batched-bench)
+ add_executable(${TARGET} batched-bench.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_11)
package/src/llama.cpp/examples/batched-bench/batched-bench.cpp
@@ -0,0 +1,261 @@
+ #include "common.h"
+ #include "llama.h"
+
+ #include <algorithm>
+ #include <cmath>
+ #include <cstdio>
+ #include <string>
+ #include <vector>
+
+ // mutates the input string
+ static std::vector<int> parse_list(char * p) {
+     std::vector<int> ret;
+
+     char * q = p;
+
+     while (*p) {
+         if (*p == ',') {
+             *p = '\0';
+             ret.push_back(std::atoi(q));
+             q = p + 1;
+         }
+
+         ++p;
+     }
+
+     ret.push_back(std::atoi(q));
+
+     return ret;
+ }
+
+ int main(int argc, char ** argv) {
+     gpt_params params;
+
+     if (argc == 1 || argv[1][0] == '-') {
+         printf("usage: %s MODEL_PATH [N_KV_MAX] [N_BATCH] [N_UBATCH] [IS_PP_SHARED] [NGL] <PP> <TG> <PL>\n" , argv[0]);
+         printf(" <PP>, <TG> and PL are comma-separated lists of numbers without spaces\n\n");
+         printf(" example: %s ggml-model-f16.gguf 2048 2048 512 0 999 128,256,512 128,256 1,2,4,8,16,32\n\n", argv[0]);
+         return 1 ;
+     }
+
+     int n_kv_max = 2048;
+     int n_batch = 2048;
+     int n_ubatch = 512;
+     int is_pp_shared = 0;
+     int n_gpu_layers = 0;
+
+     std::vector<int> n_pp = { 128, 256, 512, 1024, 2048, 3584, 7680, };
+     std::vector<int> n_tg = { 128, 256, };
+     std::vector<int> n_pl = { 1, 2, 4, 8, 16, 32, };
+     //std::vector<int> n_pl = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 32, };
+
+     if (argc >= 2) {
+         params.model = argv[1];
+     }
+
+     if (argc >= 3) {
+         n_kv_max = std::atoi(argv[2]);
+     }
+
+     if (argc >= 4) {
+         n_batch = std::atoi(argv[3]);
+     }
+
+     if (argc >= 5) {
+         n_ubatch = std::atoi(argv[4]);
+     }
+
+     if (argc >= 6) {
+         is_pp_shared = std::atoi(argv[5]);
+     }
+
+     if (argc >= 7) {
+         n_gpu_layers = std::atoi(argv[6]);
+     }
+
+     if (argc >= 8) {
+         n_pp = parse_list(argv[7]);
+     }
+
+     if (argc >= 9) {
+         n_tg = parse_list(argv[8]);
+     }
+
+     if (argc >= 10) {
+         n_pl = parse_list(argv[9]);
+     }
+
+     // init LLM
+
+     llama_backend_init();
+     llama_numa_init(params.numa);
+
+     // initialize the model
+
+     llama_model_params model_params = llama_model_default_params();
+
+     const std::vector<float> t_split(llama_max_devices(), 0.0f);
+
+     model_params.n_gpu_layers = n_gpu_layers;
+     model_params.tensor_split = t_split.data();
+
+     llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+
+     if (model == NULL) {
+         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+         return 1;
+     }
+
+     llama_context_params ctx_params = llama_context_default_params();
+
+     ctx_params.seed = 1234;
+     ctx_params.n_ctx = n_kv_max;
+     ctx_params.n_batch = n_batch;
+     ctx_params.n_ubatch = n_ubatch;
+
+     ctx_params.n_threads = params.n_threads;
+     ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+
+     // ensure enough sequences are available
+     ctx_params.n_seq_max = *std::max_element(n_pl.begin(), n_pl.end());
+
+     llama_context * ctx = llama_new_context_with_model(model, ctx_params);
+
+     if (ctx == NULL) {
+         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+         return 1;
+     }
+
+     llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
+
+     // decode in batches of ctx_params.n_batch tokens
+     auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
+         for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+             const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+             llama_batch batch_view = {
+                 n_tokens,
+                 batch.token + i,
+                 nullptr,
+                 batch.pos + i,
+                 batch.n_seq_id + i,
+                 batch.seq_id + i,
+                 batch.logits + i,
+                 0, 0, 0, // unused
+             };
+
+             const int ret = llama_decode(ctx, batch_view);
+             if (ret != 0) {
+                 LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+                 return false;
+             }
+
+             llama_synchronize(ctx);
+         }
+
+         return true;
+     };
+
+     // warm up
+     {
+         for (int i = 0; i < 16; ++i) {
+             llama_batch_add(batch, 0, i, { 0 }, false);
+         }
+
+         if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+             LOG_TEE("%s: llama_decode() failed\n", __func__);
+             return 1;
+         }
+     }
+
+     LOG_TEE("\n");
+     LOG_TEE("%s: n_kv_max = %d, n_batch = %d, n_ubatch = %d, is_pp_shared = %d, n_gpu_layers = %d, n_threads = %u, n_threads_batch = %u\n", __func__, n_kv_max, n_batch, n_ubatch, is_pp_shared, n_gpu_layers, ctx_params.n_threads, ctx_params.n_threads_batch);
+     LOG_TEE("\n");
+
+     LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP", "TG", "B", "N_KV", "T_PP s", "S_PP t/s", "T_TG s", "S_TG t/s", "T s", "S t/s");
+     LOG_TEE("|%6s-|-%6s-|-%4s-|-%6s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|-%8s-|\n", "------", "------", "----", "------", "--------", "--------", "--------", "--------", "--------", "--------");
+
+     for ( int i_pp = 0; i_pp < (int) n_pp.size(); ++i_pp) {
+         for ( int i_tg = 0; i_tg < (int) n_tg.size(); ++i_tg) {
+             for (int i_pl = 0; i_pl < (int) n_pl.size(); ++i_pl) {
+                 const int pp = n_pp[i_pp];
+                 const int tg = n_tg[i_tg];
+                 const int pl = n_pl[i_pl];
+
+                 const int n_ctx_req = is_pp_shared ? pp + pl*tg : pl*(pp + tg);
+
+                 if (n_ctx_req > n_kv_max) {
+                     continue;
+                 }
+
+                 llama_batch_clear(batch);
+
+                 for (int i = 0; i < pp; ++i) {
+                     for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
+                         llama_batch_add(batch, 0, i, { j }, false);
+                     }
+                 }
+                 batch.logits[batch.n_tokens - 1] = true;
+
+                 const auto t_pp_start = ggml_time_us();
+
+                 llama_kv_cache_clear(ctx);
+
+                 if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                     LOG_TEE("%s: llama_decode() failed\n", __func__);
+                     return 1;
+                 }
+
+                 if (is_pp_shared) {
+                     for (int32_t i = 1; i < pl; ++i) {
+                         llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+                     }
+                 }
+
+                 const auto t_pp_end = ggml_time_us();
+
+                 const auto t_tg_start = ggml_time_us();
+
+                 for (int i = 0; i < tg; ++i) {
+                     llama_batch_clear(batch);
+
+                     for (int j = 0; j < pl; ++j) {
+                         llama_batch_add(batch, 0, pp + i, { j }, true);
+                     }
+
+                     if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
+                         LOG_TEE("%s: llama_decode() failed\n", __func__);
+                         return 1;
+                     }
+                 }
+
+                 const auto t_tg_end = ggml_time_us();
+
+                 const int32_t n_kv = n_ctx_req;
+
+                 const float t_pp = (t_pp_end - t_pp_start) / 1000000.0f;
+                 const float t_tg = (t_tg_end - t_tg_start) / 1000000.0f;
+                 const float t = t_pp + t_tg;
+
+                 const float speed_pp = is_pp_shared ? pp / t_pp : pl*pp / t_pp;
+                 const float speed_tg = pl*tg / t_tg;
+                 const float speed = n_kv / t;
+
+                 LOG_TEE("|%6d | %6d | %4d | %6d | %8.3f | %8.2f | %8.3f | %8.2f | %8.3f | %8.2f |\n", pp, tg, pl, n_kv, t_pp, speed_pp, t_tg, speed_tg, t, speed);
+             }
+         }
+     }
+
+     llama_print_timings(ctx);
+
+     llama_batch_free(batch);
+
+     llama_free(ctx);
+     llama_free_model(model);
+
+     llama_backend_free();
+
+     fprintf(stderr, "\n\n");
+
+     return 0;
+ }
package/src/llama.cpp/examples/beam-search/CMakeLists.txt
@@ -0,0 +1,5 @@
+ set(TARGET beam-search)
+ add_executable(${TARGET} beam-search.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_11)