@fugood/llama.node 0.0.1-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204)
  1. package/CMakeLists.txt +85 -0
  2. package/README.md +56 -0
  3. package/bin/darwin/arm64/llama-node.node +0 -0
  4. package/bin/darwin/x64/llama-node.node +0 -0
  5. package/bin/linux/arm64/llama-node.node +0 -0
  6. package/bin/linux/x64/llama-node.node +0 -0
  7. package/bin/win32/arm64/llama-node.node +0 -0
  8. package/bin/win32/arm64/node.lib +0 -0
  9. package/bin/win32/x64/llama-node.node +0 -0
  10. package/bin/win32/x64/node.lib +0 -0
  11. package/lib/binding.js +13 -0
  12. package/lib/binding.ts +57 -0
  13. package/lib/index.js +24 -0
  14. package/lib/index.ts +13 -0
  15. package/package.json +65 -0
  16. package/src/addons.cpp +506 -0
  17. package/src/llama.cpp/CMakeLists.txt +1320 -0
  18. package/src/llama.cpp/build.zig +172 -0
  19. package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
  20. package/src/llama.cpp/common/CMakeLists.txt +87 -0
  21. package/src/llama.cpp/common/base64.hpp +392 -0
  22. package/src/llama.cpp/common/common.cpp +2949 -0
  23. package/src/llama.cpp/common/common.h +324 -0
  24. package/src/llama.cpp/common/console.cpp +501 -0
  25. package/src/llama.cpp/common/console.h +19 -0
  26. package/src/llama.cpp/common/grammar-parser.cpp +440 -0
  27. package/src/llama.cpp/common/grammar-parser.h +29 -0
  28. package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
  29. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
  30. package/src/llama.cpp/common/json.hpp +24766 -0
  31. package/src/llama.cpp/common/log.h +724 -0
  32. package/src/llama.cpp/common/ngram-cache.cpp +282 -0
  33. package/src/llama.cpp/common/ngram-cache.h +94 -0
  34. package/src/llama.cpp/common/sampling.cpp +353 -0
  35. package/src/llama.cpp/common/sampling.h +147 -0
  36. package/src/llama.cpp/common/stb_image.h +8396 -0
  37. package/src/llama.cpp/common/train.cpp +1513 -0
  38. package/src/llama.cpp/common/train.h +233 -0
  39. package/src/llama.cpp/examples/CMakeLists.txt +52 -0
  40. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
  41. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
  42. package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
  43. package/src/llama.cpp/examples/batched/batched.cpp +262 -0
  44. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
  45. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
  46. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
  47. package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
  48. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
  49. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
  50. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
  51. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
  52. package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
  54. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
  55. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
  56. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
  58. package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
  59. package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
  60. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
  61. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
  62. package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
  64. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
  65. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
  66. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
  67. package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
  68. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
  69. package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
  70. package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
  71. package/src/llama.cpp/examples/infill/infill.cpp +767 -0
  72. package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
  73. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
  74. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
  75. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
  76. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
  77. package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
  78. package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
  79. package/src/llama.cpp/examples/llava/clip.h +85 -0
  80. package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
  81. package/src/llama.cpp/examples/llava/llava.cpp +426 -0
  82. package/src/llama.cpp/examples/llava/llava.h +50 -0
  83. package/src/llama.cpp/examples/llava/requirements.txt +3 -0
  84. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
  85. package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
  86. package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
  87. package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
  88. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
  89. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
  90. package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
  91. package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
  92. package/src/llama.cpp/examples/main/main.cpp +957 -0
  93. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
  94. package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
  95. package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
  96. package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
  97. package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
  98. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
  99. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
  100. package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
  101. package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
  102. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
  103. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
  104. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
  106. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
  107. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
  108. package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
  109. package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
  110. package/src/llama.cpp/examples/server/httplib.h +9465 -0
  111. package/src/llama.cpp/examples/server/server.cpp +3826 -0
  112. package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
  113. package/src/llama.cpp/examples/server/utils.hpp +653 -0
  114. package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
  115. package/src/llama.cpp/examples/simple/simple.cpp +183 -0
  116. package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
  117. package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
  118. package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
  119. package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
  120. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
  121. package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
  122. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
  123. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
  124. package/src/llama.cpp/ggml-alloc.c +985 -0
  125. package/src/llama.cpp/ggml-alloc.h +76 -0
  126. package/src/llama.cpp/ggml-backend-impl.h +141 -0
  127. package/src/llama.cpp/ggml-backend.c +2099 -0
  128. package/src/llama.cpp/ggml-backend.h +233 -0
  129. package/src/llama.cpp/ggml-common.h +1853 -0
  130. package/src/llama.cpp/ggml-cuda.h +43 -0
  131. package/src/llama.cpp/ggml-impl.h +265 -0
  132. package/src/llama.cpp/ggml-kompute.cpp +2006 -0
  133. package/src/llama.cpp/ggml-kompute.h +46 -0
  134. package/src/llama.cpp/ggml-metal.h +66 -0
  135. package/src/llama.cpp/ggml-mpi.c +216 -0
  136. package/src/llama.cpp/ggml-mpi.h +39 -0
  137. package/src/llama.cpp/ggml-opencl.cpp +2301 -0
  138. package/src/llama.cpp/ggml-opencl.h +36 -0
  139. package/src/llama.cpp/ggml-quants.c +12678 -0
  140. package/src/llama.cpp/ggml-quants.h +133 -0
  141. package/src/llama.cpp/ggml-sycl.cpp +17882 -0
  142. package/src/llama.cpp/ggml-sycl.h +49 -0
  143. package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
  144. package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
  145. package/src/llama.cpp/ggml-vulkan.h +29 -0
  146. package/src/llama.cpp/ggml.c +21819 -0
  147. package/src/llama.cpp/ggml.h +2403 -0
  148. package/src/llama.cpp/llama.cpp +17468 -0
  149. package/src/llama.cpp/llama.h +1117 -0
  150. package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
  151. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
  152. package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
  153. package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
  154. package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
  155. package/src/llama.cpp/prompts/alpaca.txt +1 -0
  156. package/src/llama.cpp/prompts/assistant.txt +31 -0
  157. package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
  158. package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
  159. package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
  160. package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
  161. package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
  162. package/src/llama.cpp/prompts/chat.txt +28 -0
  163. package/src/llama.cpp/prompts/dan-modified.txt +1 -0
  164. package/src/llama.cpp/prompts/dan.txt +1 -0
  165. package/src/llama.cpp/prompts/mnemonics.txt +93 -0
  166. package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
  167. package/src/llama.cpp/prompts/reason-act.txt +18 -0
  168. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
  169. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
  170. package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
  171. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
  172. package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
  173. package/src/llama.cpp/requirements.txt +12 -0
  174. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
  175. package/src/llama.cpp/scripts/xxd.cmake +16 -0
  176. package/src/llama.cpp/sgemm.cpp +999 -0
  177. package/src/llama.cpp/sgemm.h +12 -0
  178. package/src/llama.cpp/tests/CMakeLists.txt +78 -0
  179. package/src/llama.cpp/tests/get-model.cpp +21 -0
  180. package/src/llama.cpp/tests/get-model.h +2 -0
  181. package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
  182. package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
  183. package/src/llama.cpp/tests/test-c.c +7 -0
  184. package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
  185. package/src/llama.cpp/tests/test-double-float.cpp +57 -0
  186. package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
  187. package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
  188. package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
  189. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
  190. package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
  191. package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
  192. package/src/llama.cpp/tests/test-opt.cpp +181 -0
  193. package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
  194. package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
  195. package/src/llama.cpp/tests/test-rope.cpp +221 -0
  196. package/src/llama.cpp/tests/test-sampling.cpp +301 -0
  197. package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
  198. package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
  199. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
  200. package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
  201. package/src/llama.cpp/unicode-data.cpp +1651 -0
  202. package/src/llama.cpp/unicode-data.h +16 -0
  203. package/src/llama.cpp/unicode.cpp +277 -0
  204. package/src/llama.cpp/unicode.h +28 -0
package/src/llama.cpp/examples/lookahead/lookahead.cpp (new file)
@@ -0,0 +1,485 @@
+ #include "common.h"
+ #include "llama.h"
+
+ #include <cmath>
+ #include <cstdio>
+ #include <string>
+ #include <vector>
+
+ struct ngram_data {
+     bool active = false;
+
+     llama_seq_id seq_id = -1;
+
+     std::vector<int> i_batch;
+
+     std::vector<llama_token> tokens;
+ };
+
+ // n-gram container
+ struct ngram_container {
+     ngram_container(int n_vocab, int N, int G) {
+         cnt.resize(n_vocab);
+         head.resize(n_vocab);
+         tokens.resize(n_vocab * G * (N - 1));
+     }
+
+     int n_total = 0;
+
+     std::vector<int> cnt;
+     std::vector<int> head;
+
+     // [n_vocab][G][N - 1]
+     // for each token of the vocab, keep a ring-buffer of capacity G of n-grams of size N - 1
+     std::vector<llama_token> tokens;
+ };
+
+ int main(int argc, char ** argv) {
+     gpt_params params;
+
+     if (gpt_params_parse(argc, argv, params) == false) {
+         return 1;
+     }
+
+     const int W = 15; // lookahead window
+     const int N = 5; // n-gram size
+     const int G = 15; // max verification n-grams
+
+     const bool dump_kv_cache = params.dump_kv_cache;
+
+ #ifndef LOG_DISABLE_LOGS
+     log_set_target(log_filename_generator("lookahead", "log"));
+     LOG_TEE("Log start\n");
+     log_dump_cmdline(argc, argv);
+ #endif // LOG_DISABLE_LOGS
+
+     // init llama.cpp
+     llama_backend_init();
+     llama_numa_init(params.numa);
+
+     llama_model * model = NULL;
+     llama_context * ctx = NULL;
+
+     // load the target model
+     std::tie(model, ctx) = llama_init_from_gpt_params(params);
+
+     // Tokenize the prompt
+     std::vector<llama_token> inp;
+     std::vector<llama_token> all;
+
+     inp = ::llama_tokenize(ctx, params.prompt, true, true);
+     all = inp;
+
+     const int max_context_size = llama_n_ctx(ctx);
+     const int max_tokens_list_size = max_context_size - 4;
+
+     if ((int) inp.size() > max_tokens_list_size) {
+         fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
+         return 1;
+     }
+
+     fprintf(stderr, "\n\n");
+
+     for (auto id : inp) {
+         fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+     }
+
+     fflush(stderr);
+
+     const int n_input = inp.size();
+
+     const auto t_enc_start = ggml_time_us();
+
+     // eval the prompt
+     llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
+     llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
+
+     for (int s = 1; s < W + G + 1; ++s) {
+         llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+     }
+
+     const auto t_enc_end = ggml_time_us();
+
+     int n_predict = 0;
+     int n_accept = 0;
+
+     int n_past = inp.size();
+
+     llama_token id = 0;
+
+     // used to determine end of generation
+     bool has_eos = false;
+
+     // for each decoded batch, we have at most W + G + 1 distinct sequences:
+     // seq_id == 0 : the current input token
+     // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
+     // seq_id [W + 1, W + G] : verification n-grams
+     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
+
+     // target model sampling context
+     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+
+     // verification n-grams
+     std::vector<ngram_data> ngrams_cur(G);
+
+     // tokens for the past N - 1 Jacobi iterations
+     std::vector<llama_token> tokens_j_prev(W);
+     std::vector<std::vector<llama_token>> tokens_j(N - 1);
+     for (int j = 0; j < N - 1; j++) {
+         tokens_j[j].resize(W);
+
+         for (int i = 0; i < W; i++) {
+             // there are different ways to init these tokens
+             if (0) {
+                 // initialize randomly from the prompt tokens
+                 tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
+             } else {
+                 // initialize with a sequence of increasing numbers
+                 tokens_j[j][i] = 100 + i;
+             }
+         }
+     }
+
+     std::vector<llama_seq_id> seq_id_look;
+
+     // the input token belongs both to all sequences
+     std::vector<llama_seq_id> seq_id_all(W + G + 1);
+     for (int i = 0; i < W + G + 1; i++) {
+         seq_id_all[i] = i;
+     }
+
+     // here we keep adding new n-grams as we go
+     ngram_container ngrams_observed(llama_n_vocab(model), N, G);
+
+     // debug
+     struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
+
+     const auto t_dec_start = ggml_time_us();
+
+     // sample first token
+     {
+         id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
+
+         llama_sampling_accept(ctx_sampling, ctx, id, true);
+
+         {
+             const std::string token_str = llama_token_to_piece(ctx, id);
+
+             printf("%s", token_str.c_str());
+             fflush(stdout);
+         }
+     }
+
+     while (true) {
+         // debug
+         if (dump_kv_cache) {
+             llama_kv_cache_view_update(ctx, &kvc_view);
+             dump_kv_cache_view_seqs(kvc_view, 40);
+         }
+
+         // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/
+         //
+         // Example for W = 5, N = 4, G = 2:
+         // (I = input, L = lookahead, V = verification)
+         //
+         // Batch: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
+         // T: -2 -2 -2 -2 -1 -1 -1 -1 -1 0 0 0 0 0 0
+         // Info: I L L L L L L L L L L L L L L V V V V V V
+         // Pos: 0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 1 2 3 1 2 3 (+ n_past)
+         // Logits: 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
+         // ---------------------------------------------------------------------
+         // Seq: 0
+         // 1 1 1
+         // 2 2 2 2
+         // 3 3 3 3 3
+         // 4 4 4 4 4 4
+         // 5 5 5 5 5 5 5
+         // 6 6 6 6
+         // 7 7 7 7
+         // ---------------------------------------------------------------------
+         // | | | | | | | | | | |
+         // V V V V V | | | | | |
+         // j_tokens | | | | | |
+         // V V V V V V
+         // id
+         {
+             llama_batch_clear(batch);
+
+             // current token - first token of the first level
+             llama_batch_add(batch, id, n_past, seq_id_all, true);
+
+             // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
+             {
+                 const int g_cur = ngrams_observed.cnt[id];
+
+                 ngrams_cur.resize(g_cur);
+                 for (int g = 0; g < g_cur; g++) {
+                     ngrams_cur[g].active = true;
+                     ngrams_cur[g].tokens.resize(N);
+                     ngrams_cur[g].i_batch.resize(N);
+                     ngrams_cur[g].seq_id = W + 1 + g;
+                     ngrams_cur[g].i_batch[0] = 0;
+                     ngrams_cur[g].tokens [0] = id;
+                 }
+
+                 for (int j = 0; j < N - 1; j++) {
+                     for (int g = 0; g < g_cur; g++) {
+                         const int idx = id*(N - 1)*G + g*(N - 1);
+
+                         const llama_token t = ngrams_observed.tokens[idx + j];
+
+                         ngrams_cur[g].tokens [j + 1] = t;
+                         ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
+
+                         llama_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
+                     }
+                 }
+             }
+
+             // fill the remaining W - 1 tokens for the first level
+             for (int i = 1; i < W; i++) {
+                 seq_id_look.resize(W - i);
+                 for (int j = 0; j < W - i; j++) {
+                     seq_id_look[j] = i + j + 1;
+                 }
+
+                 llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
+             }
+
+             // fill the rest of the levels
+             for (int j = 1; j < N - 1; j++) {
+                 for (int i = 0; i < W; i++) {
+                     llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
+                 }
+             }
+         }
+
+         if (llama_decode(ctx, batch) != 0) {
+             fprintf(stderr, "\n\n%s: error: llama_decode failed - increase KV cache size\n", __func__);
+             return 1;
+         }
+
+         int seq_id_best = 0;
+
+         for (int v = 0; v < N; ++v) {
+             int i_batch = 0;
+
+             // if no active ngrams are left, it means the sampled token does not pass the verification
+             if (v > 0) {
+                 for (int g = 0; g < (int) ngrams_cur.size(); g++) {
+                     if (ngrams_cur[g].active) {
+                         i_batch = ngrams_cur[g].i_batch[v];
+                         seq_id_best = ngrams_cur[g].seq_id;
+
+                         ++n_accept;
+                         break;
+                     }
+                 }
+
+                 // no more matches -> create a new batch
+                 if (i_batch == 0) {
+                     break;
+                 }
+             }
+
+             // sample the next token
+             id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
+
+             llama_sampling_accept(ctx_sampling, ctx, id, true);
+
+             // print
+             {
+                 const std::string token_str = llama_token_to_piece(ctx, id);
+
+                 if (v == 0) {
+                     printf("%s", token_str.c_str());
+                 } else {
+                     // print light cyan
+                     printf("\033[0;96m%s\033[0m", token_str.c_str());
+                 }
+                 fflush(stdout);
+
+                 if (llama_token_is_eog(model, id)) {
+                     has_eos = true;
+                 }
+
+                 all.push_back(id);
+             }
+
+             ++n_predict;
+             ++n_past;
+
+             if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+                 break;
+             }
+
+             // verify across active n-grams
+             for (int g = 0; g < (int) ngrams_cur.size(); g++) {
+                 if (ngrams_cur[g].active) {
+                     if (v == N - 1) {
+                         ngrams_cur[g].active = false;
+                     } else {
+                         if (id != ngrams_cur[g].tokens[v + 1]) {
+                             ngrams_cur[g].active = false;
+                         }
+                     }
+                 }
+             }
+
+             // print known n-grams starting with token id (debug)
+             if (0 && v == 0) {
+                 if (ngrams_observed.cnt[id] > 0) {
+                     printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
+                 }
+
+                 for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
+                     printf(" - ngram %2d: ", i);
+
+                     const int idx = id*(N - 1)*G + i*(N - 1);
+
+                     for (int j = 0; j < N - 1; j++) {
+                         const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
+
+                         printf("%s", token_str.c_str());
+                     }
+
+                     printf("\n");
+                 }
+             }
+
+             // update lookahead tokens
+             {
+                 for (int i = 0; i < W; i++) {
+                     tokens_j_prev[i] = tokens_j[0][i];
+                 }
+
+                 for (int j = 0; j < N - 2; j++) {
+                     tokens_j[j] = tokens_j[j + 1];
+                 }
+
+                 if (v == 0) {
+                     // sample from the last level
+                     for (int i = 0; i < W; i++) {
+                         tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
+                     }
+                 } else {
+                     for (int i = 0; i < W; i++) {
+                         // there are different ways to init these tokens
+                         if (0) {
+                             // random init
+                             tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
+                         } else {
+                             // init from the previous level
+                             tokens_j[N - 2][i] = tokens_j[0][i];
+                         }
+                     }
+                 }
+             }
+
+             // update observed ngrams
+             if (v == 0) {
+                 // the first token of the n-gram is determined by the index in the container so it is not stored
+                 std::vector<llama_token> ngram(N - 1);
+
+                 // n-gram generation
+                 // ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518
+                 for (int f = 0; f < W; ++f) {
+                     const int ft = tokens_j_prev[f]; // first token of the n-gram
+
+                     for (int j = 0; j < N - 1; ++j) {
+                         ngram[j] = tokens_j[j][f];
+                     }
+
+                     // filter-out repeating n-grams
+                     {
+                         bool is_unique = true;
+
+                         for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) {
+                             const int idx = ft*(N - 1)*G + k*(N - 1);
+
+                             bool is_match = true;
+                             for (int j = 0; j < N - 1; ++j) {
+                                 if (ngrams_observed.tokens[idx + j] != ngram[j]) {
+                                     is_match = false;
+                                     break;
+                                 }
+                             }
+
+                             if (is_match) {
+                                 is_unique = false;
+                                 break;
+                             }
+                         }
+
+                         if (!is_unique) {
+                             continue;
+                         }
+                     }
+
+                     const int head = ngrams_observed.head[ft];
+                     const int idx = ft*(N - 1)*G + head*(N - 1);
+
+                     for (int i = 0; i < N - 1; i++) {
+                         ngrams_observed.tokens[idx + i] = ngram[i];
+                     }
+
+                     ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1);
+                     ngrams_observed.head[ft] = (head + 1) % G;
+
+                     ngrams_observed.n_total++;
+                 }
+             }
+         }
+
+         if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+             break;
+         }
+
+         // KV cache management
+         // if no verification token matched, we simply remove all cells from this batch -> no fragmentation
+         llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
+
+         if (seq_id_best != 0) {
+             // if a verification token matched, we keep the best sequence and remove the rest
+             // this leads to some KV cache fragmentation
+             llama_kv_cache_seq_keep(ctx, seq_id_best);
+             llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
+             llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
+
+             for (int s = 1; s < W + G + 1; ++s) {
+                 llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
+             }
+         }
+     }
+
+     auto t_dec_end = ggml_time_us();
+
+     LOG_TEE("\n\n");
+
+     LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+     LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
+
+     LOG_TEE("\n");
+     LOG_TEE("W = %2d\n", W);
+     LOG_TEE("N = %2d\n", N);
+     LOG_TEE("G = %2d\n", G);
+     LOG_TEE("\n");
+     LOG_TEE("n_predict = %d\n", n_predict);
+     LOG_TEE("n_accept = %d\n", n_accept);
+
+     llama_print_timings(ctx);
+
+     llama_kv_cache_view_free(&kvc_view);
+     llama_sampling_free(ctx_sampling);
+
+     llama_batch_free(batch);
+
+     llama_free(ctx);
+     llama_free_model(model);
+
+     llama_backend_free();
+
+     fprintf(stderr, "\n\n");
+
+     return 0;
+ }
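
Note: the central data structure in the lookahead example above is the flat ring buffer inside ngram_container. For a first token ft, slot k of its G-entry buffer starts at offset ft*(N - 1)*G + k*(N - 1) in tokens, cnt[ft] saturates at G, and head[ft] marks the slot that will be overwritten next. The following standalone C++ sketch (not part of the package; n_vocab, N and G are toy values) isolates just that bookkeeping:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int n_vocab = 8; // toy vocabulary; the real code uses llama_n_vocab(model)
    const int N = 3;       // n-gram size
    const int G = 2;       // ring-buffer capacity per first token

    std::vector<int> cnt(n_vocab, 0), head(n_vocab, 0);
    std::vector<int> tokens(n_vocab * G * (N - 1), -1);

    // insert the (N - 1)-token tail of an n-gram that starts with token ft
    auto insert = [&](int ft, const std::vector<int> & tail) {
        const int idx = ft*(N - 1)*G + head[ft]*(N - 1);
        std::copy(tail.begin(), tail.end(), tokens.begin() + idx);
        cnt [ft] = std::min(G, cnt[ft] + 1); // saturate at G entries
        head[ft] = (head[ft] + 1) % G;       // oldest entry is overwritten next
    };

    insert(5, {1, 2});
    insert(5, {3, 4});
    insert(5, {6, 7}); // wraps around, overwriting {1, 2}

    for (int k = 0; k < cnt[5]; ++k) {
        const int idx = 5*(N - 1)*G + k*(N - 1);
        printf("slot %d: %d %d\n", k, tokens[idx], tokens[idx + 1]);
    }
    return 0;
}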
package/src/llama.cpp/examples/lookup/CMakeLists.txt (new file)
@@ -0,0 +1,23 @@
+ set(TARGET lookup)
+ add_executable(${TARGET} lookup.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+ set(TARGET lookup-create)
+ add_executable(${TARGET} lookup-create.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+ set(TARGET lookup-merge)
+ add_executable(${TARGET} lookup-merge.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_11)
+
+ set(TARGET lookup-stats)
+ add_executable(${TARGET} lookup-stats.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_11)
package/src/llama.cpp/examples/lookup/lookup-create.cpp (new file)
@@ -0,0 +1,41 @@
+ #include "ggml.h"
+ #include "llama.h"
+ #include "common.h"
+ #include "ngram-cache.h"
+
+ #include <cstdint>
+ #include <fstream>
+ #include <iostream>
+ #include <string>
+ #include <unordered_map>
+ #include <vector>
+
+ int main(int argc, char ** argv){
+     gpt_params params;
+
+     if (!gpt_params_parse(argc, argv, params)) {
+         return 1;
+     }
+     // init llama.cpp
+     llama_backend_init();
+     llama_numa_init(params.numa);
+
+     llama_model * model = NULL;
+     llama_context * ctx = NULL;
+
+     // load the model
+     std::tie(model, ctx) = llama_init_from_gpt_params(params);
+     GGML_ASSERT(model != nullptr);
+
+     // tokenize the prompt
+     std::vector<llama_token> inp;
+     inp = ::llama_tokenize(ctx, params.prompt, true, true);
+     fprintf(stderr, "%s: tokenization done\n", __func__);
+
+
+     llama_ngram_cache ngram_cache;
+     llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
+     fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
+
+     llama_ngram_cache_save(ngram_cache, params.lookup_cache_static);
+ }
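
Note: lookup-create builds a static n-gram cache from a single tokenized prompt and writes it to params.lookup_cache_static. The real cache layout is defined in common/ngram-cache.h (shipped in this package, see the file list above). Purely as a simplified mental model, and not the package's actual data structure, such a cache can be pictured as a map from each n-gram to the counts of the tokens that followed it:

// Conceptual sketch only; hypothetical types, not the real llama_ngram_cache.
#include <cstdio>
#include <map>
#include <vector>

using token         = int;
using ngram_t       = std::vector<token>;
using ngram_cache_t = std::map<ngram_t, std::map<token, int>>;

// count, for every n-gram of size n in inp, which token followed it
static void cache_update(ngram_cache_t & cache, const std::vector<token> & inp, int n) {
    for (size_t i = 0; i + n < inp.size(); ++i) {
        ngram_t key(inp.begin() + i, inp.begin() + i + n);
        cache[key][inp[i + n]]++;
    }
}

int main() {
    ngram_cache_t cache;
    cache_update(cache, {1, 2, 3, 1, 2, 4, 1, 2, 3}, 2);

    // "1 2" was followed twice by 3 and once by 4, so 3 is the better draft token
    for (const auto & kv : cache[{1, 2}]) {
        printf("next %d seen %d times\n", kv.first, kv.second);
    }
    return 0;
}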
package/src/llama.cpp/examples/lookup/lookup-merge.cpp (new file)
@@ -0,0 +1,47 @@
+ #include "ggml.h"
+ #include "llama.h"
+ #include "common.h"
+ #include "ngram-cache.h"
+
+ #include <cstdint>
+ #include <cstdio>
+ #include <fstream>
+ #include <iostream>
+ #include <string>
+ #include <unordered_map>
+ #include <vector>
+
+ static void print_usage() {
+     fprintf(stderr, "Merges multiple lookup cache files into a single one.\n");
+     fprintf(stderr, "Usage: lookup-merge [--help] lookup_part_1.bin lookup_part_2.bin ... lookup_merged.bin\n");
+ }
+
+ int main(int argc, char ** argv){
+     if (argc < 3) {
+         print_usage();
+         exit(1);
+     }
+
+     std::vector<std::string> args;
+     args.resize(argc-1);
+     for (int i = 0; i < argc-1; ++i) {
+         args[i] = argv[i+1];
+         if (args[i] == "-h" || args[i] == "--help") {
+             print_usage();
+             exit(0);
+         }
+     }
+
+     fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
+     llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]);
+
+     for (size_t i = 1; i < args.size()-1; ++i) {
+         fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
+         llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]);
+
+         llama_ngram_cache_merge(ngram_cache_merged, ngram_cache);
+     }
+
+     fprintf(stderr, "lookup-merge: saving file %s\n", args.back().c_str());
+     llama_ngram_cache_save(ngram_cache_merged, args.back());
+ }
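
Note: continuing the same simplified model (again an assumption, not the real llama_ngram_cache), merging caches amounts to an entry-wise sum of continuation counts, which is why lookup-merge can combine partial caches built from different text shards:

// Conceptual sketch only; hypothetical types, not the real llama_ngram_cache.
#include <cstdio>
#include <map>
#include <vector>

using token         = int;
using ngram_t       = std::vector<token>;
using ngram_cache_t = std::map<ngram_t, std::map<token, int>>;

// add every (n-gram -> next-token) count from src into dst
static void cache_merge(ngram_cache_t & dst, const ngram_cache_t & src) {
    for (const auto & kv : src) {
        for (const auto & nc : kv.second) {
            dst[kv.first][nc.first] += nc.second;
        }
    }
}

int main() {
    ngram_cache_t a, b;
    a[{1, 2}][3] = 2;
    b[{1, 2}][3] = 1;
    b[{1, 2}][4] = 5;

    cache_merge(a, b);
    printf("count of 3 after {1 2}: %d\n", a[{1, 2}][3]); // 3
    printf("count of 4 after {1 2}: %d\n", a[{1, 2}][4]); // 5
    return 0;
}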