@fugood/llama.node 0.0.1-alpha.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204)
  1. package/CMakeLists.txt +85 -0
  2. package/README.md +56 -0
  3. package/bin/darwin/arm64/llama-node.node +0 -0
  4. package/bin/darwin/x64/llama-node.node +0 -0
  5. package/bin/linux/arm64/llama-node.node +0 -0
  6. package/bin/linux/x64/llama-node.node +0 -0
  7. package/bin/win32/arm64/llama-node.node +0 -0
  8. package/bin/win32/arm64/node.lib +0 -0
  9. package/bin/win32/x64/llama-node.node +0 -0
  10. package/bin/win32/x64/node.lib +0 -0
  11. package/lib/binding.js +13 -0
  12. package/lib/binding.ts +57 -0
  13. package/lib/index.js +24 -0
  14. package/lib/index.ts +13 -0
  15. package/package.json +65 -0
  16. package/src/addons.cpp +506 -0
  17. package/src/llama.cpp/CMakeLists.txt +1320 -0
  18. package/src/llama.cpp/build.zig +172 -0
  19. package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
  20. package/src/llama.cpp/common/CMakeLists.txt +87 -0
  21. package/src/llama.cpp/common/base64.hpp +392 -0
  22. package/src/llama.cpp/common/common.cpp +2949 -0
  23. package/src/llama.cpp/common/common.h +324 -0
  24. package/src/llama.cpp/common/console.cpp +501 -0
  25. package/src/llama.cpp/common/console.h +19 -0
  26. package/src/llama.cpp/common/grammar-parser.cpp +440 -0
  27. package/src/llama.cpp/common/grammar-parser.h +29 -0
  28. package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
  29. package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
  30. package/src/llama.cpp/common/json.hpp +24766 -0
  31. package/src/llama.cpp/common/log.h +724 -0
  32. package/src/llama.cpp/common/ngram-cache.cpp +282 -0
  33. package/src/llama.cpp/common/ngram-cache.h +94 -0
  34. package/src/llama.cpp/common/sampling.cpp +353 -0
  35. package/src/llama.cpp/common/sampling.h +147 -0
  36. package/src/llama.cpp/common/stb_image.h +8396 -0
  37. package/src/llama.cpp/common/train.cpp +1513 -0
  38. package/src/llama.cpp/common/train.h +233 -0
  39. package/src/llama.cpp/examples/CMakeLists.txt +52 -0
  40. package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
  41. package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
  42. package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
  43. package/src/llama.cpp/examples/batched/batched.cpp +262 -0
  44. package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
  45. package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
  46. package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
  47. package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
  48. package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
  49. package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
  50. package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
  51. package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
  52. package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
  53. package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
  54. package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
  55. package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
  56. package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
  57. package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
  58. package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
  59. package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
  60. package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
  61. package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
  62. package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
  63. package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
  64. package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
  65. package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
  66. package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
  67. package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
  68. package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
  69. package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
  70. package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
  71. package/src/llama.cpp/examples/infill/infill.cpp +767 -0
  72. package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
  73. package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
  74. package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
  75. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
  76. package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
  77. package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
  78. package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
  79. package/src/llama.cpp/examples/llava/clip.h +85 -0
  80. package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
  81. package/src/llama.cpp/examples/llava/llava.cpp +426 -0
  82. package/src/llama.cpp/examples/llava/llava.h +50 -0
  83. package/src/llama.cpp/examples/llava/requirements.txt +3 -0
  84. package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
  85. package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
  86. package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
  87. package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
  88. package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
  89. package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
  90. package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
  91. package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
  92. package/src/llama.cpp/examples/main/main.cpp +957 -0
  93. package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
  94. package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
  95. package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
  96. package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
  97. package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
  98. package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
  99. package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
  100. package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
  101. package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
  102. package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
  103. package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
  104. package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
  105. package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
  106. package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
  107. package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
  108. package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
  109. package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
  110. package/src/llama.cpp/examples/server/httplib.h +9465 -0
  111. package/src/llama.cpp/examples/server/server.cpp +3826 -0
  112. package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
  113. package/src/llama.cpp/examples/server/utils.hpp +653 -0
  114. package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
  115. package/src/llama.cpp/examples/simple/simple.cpp +183 -0
  116. package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
  117. package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
  118. package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
  119. package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
  120. package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
  121. package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
  122. package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
  123. package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
  124. package/src/llama.cpp/ggml-alloc.c +985 -0
  125. package/src/llama.cpp/ggml-alloc.h +76 -0
  126. package/src/llama.cpp/ggml-backend-impl.h +141 -0
  127. package/src/llama.cpp/ggml-backend.c +2099 -0
  128. package/src/llama.cpp/ggml-backend.h +233 -0
  129. package/src/llama.cpp/ggml-common.h +1853 -0
  130. package/src/llama.cpp/ggml-cuda.h +43 -0
  131. package/src/llama.cpp/ggml-impl.h +265 -0
  132. package/src/llama.cpp/ggml-kompute.cpp +2006 -0
  133. package/src/llama.cpp/ggml-kompute.h +46 -0
  134. package/src/llama.cpp/ggml-metal.h +66 -0
  135. package/src/llama.cpp/ggml-mpi.c +216 -0
  136. package/src/llama.cpp/ggml-mpi.h +39 -0
  137. package/src/llama.cpp/ggml-opencl.cpp +2301 -0
  138. package/src/llama.cpp/ggml-opencl.h +36 -0
  139. package/src/llama.cpp/ggml-quants.c +12678 -0
  140. package/src/llama.cpp/ggml-quants.h +133 -0
  141. package/src/llama.cpp/ggml-sycl.cpp +17882 -0
  142. package/src/llama.cpp/ggml-sycl.h +49 -0
  143. package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
  144. package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
  145. package/src/llama.cpp/ggml-vulkan.h +29 -0
  146. package/src/llama.cpp/ggml.c +21819 -0
  147. package/src/llama.cpp/ggml.h +2403 -0
  148. package/src/llama.cpp/llama.cpp +17468 -0
  149. package/src/llama.cpp/llama.h +1117 -0
  150. package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
  151. package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
  152. package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
  153. package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
  154. package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
  155. package/src/llama.cpp/prompts/alpaca.txt +1 -0
  156. package/src/llama.cpp/prompts/assistant.txt +31 -0
  157. package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
  158. package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
  159. package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
  160. package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
  161. package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
  162. package/src/llama.cpp/prompts/chat.txt +28 -0
  163. package/src/llama.cpp/prompts/dan-modified.txt +1 -0
  164. package/src/llama.cpp/prompts/dan.txt +1 -0
  165. package/src/llama.cpp/prompts/mnemonics.txt +93 -0
  166. package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
  167. package/src/llama.cpp/prompts/reason-act.txt +18 -0
  168. package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
  169. package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
  170. package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
  171. package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
  172. package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
  173. package/src/llama.cpp/requirements.txt +12 -0
  174. package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
  175. package/src/llama.cpp/scripts/xxd.cmake +16 -0
  176. package/src/llama.cpp/sgemm.cpp +999 -0
  177. package/src/llama.cpp/sgemm.h +12 -0
  178. package/src/llama.cpp/tests/CMakeLists.txt +78 -0
  179. package/src/llama.cpp/tests/get-model.cpp +21 -0
  180. package/src/llama.cpp/tests/get-model.h +2 -0
  181. package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
  182. package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
  183. package/src/llama.cpp/tests/test-c.c +7 -0
  184. package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
  185. package/src/llama.cpp/tests/test-double-float.cpp +57 -0
  186. package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
  187. package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
  188. package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
  189. package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
  190. package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
  191. package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
  192. package/src/llama.cpp/tests/test-opt.cpp +181 -0
  193. package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
  194. package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
  195. package/src/llama.cpp/tests/test-rope.cpp +221 -0
  196. package/src/llama.cpp/tests/test-sampling.cpp +301 -0
  197. package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
  198. package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
  199. package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
  200. package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
  201. package/src/llama.cpp/unicode-data.cpp +1651 -0
  202. package/src/llama.cpp/unicode-data.h +16 -0
  203. package/src/llama.cpp/unicode.cpp +277 -0
  204. package/src/llama.cpp/unicode.h +28 -0

package/src/llama.cpp/examples/lookup/lookup-stats.cpp
@@ -0,0 +1,160 @@
+ #include "ggml.h"
+ #include "common.h"
+ #include "llama.h"
+ #include "log.h"
+ #include "ngram-cache.h"
+
+ #include <cmath>
+ #include <cstdint>
+ #include <cstdio>
+ #include <fstream>
+ #include <string>
+ #include <vector>
+ #include <unordered_map>
+
+ int main(int argc, char ** argv){
+     gpt_params params;
+
+     if (!gpt_params_parse(argc, argv, params)) {
+         return 1;
+     }
+
+     const int n_draft = params.n_draft;
+
+     // init llama.cpp
+     llama_backend_init();
+     llama_numa_init(params.numa);
+
+     llama_model * model = NULL;
+     llama_context * ctx = NULL;
+
+     // load the model
+     std::tie(model, ctx) = llama_init_from_gpt_params(params);
+     llama_set_rng_seed(ctx, params.seed);
+     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
+
+     // tokenize the prompt
+     std::vector<llama_token> inp;
+     inp = ::llama_tokenize(ctx, params.prompt, true, true);
+
+     llama_ngram_cache ngram_cache_context;
+     llama_ngram_cache ngram_cache_dynamic;
+     llama_ngram_cache ngram_cache_static;
+     int64_t t_draft_flat_us = 0;
+     int64_t t_draft_us = 0;
+
+     {
+         const int64_t t_start_draft_us = ggml_time_us();
+
+         if (!params.lookup_cache_static.empty()) {
+             try {
+                 ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
+             } catch (std::ifstream::failure const &) {
+                 fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
+                 exit(1);
+             }
+         }
+
+         if (!params.lookup_cache_dynamic.empty()) {
+             try {
+                 ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
+             } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
+         }
+
+         t_draft_flat_us += ggml_time_us() - t_start_draft_us;
+     }
+
+     const int n_input = inp.size();
+     const int n_ctx = params.n_ctx;
+
+     int n_drafted = 0;
+     int n_accept = 0;
+
+     const int64_t t_start_ms = ggml_time_ms();
+
+     // Iterate over input tokens in chunks of size n_ctx.
+     // Each chunk is treated as if a sequential generation but with pre-determined tokens to ensure reproducibility.
+     for (int i_start = 0; i_start + n_ctx < n_input; i_start += n_ctx) {
+         const std::vector<llama_token> inp_slice(inp.begin() + i_start, inp.begin() + i_start + n_ctx);
+         std::vector<llama_token> pseudo_output;
+         pseudo_output.push_back(inp_slice[0]);
+
+         while ((int) pseudo_output.size() < n_ctx) {
+             // Simulate drafting and decoding from draft:
+             std::vector<llama_token> draft;
+             draft.push_back(pseudo_output.back());
+
+             {
+                 const int64_t t_start_draft_us = ggml_time_us();
+                 llama_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+                 t_draft_us += ggml_time_us() - t_start_draft_us;
+             }
+
+             n_drafted += draft.size() - 1;
+
+             for (size_t j = 1; j < draft.size() && (int) pseudo_output.size() < n_ctx; ++j) {
+                 const llama_token ground_truth = inp_slice[pseudo_output.size()];
+                 const llama_token drafted = draft[j];
+
+                 if (ground_truth != drafted) {
+                     break;
+                 }
+
+                 ++n_accept;
+                 pseudo_output.push_back(ground_truth);
+
+                 {
+                     const int64_t t_start_draft_us = ggml_time_us();
+                     llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
+                     t_draft_us += ggml_time_us() - t_start_draft_us;
+                 }
+             }
+
+             // After each simulated batch decoding simulate the sampling of a single token:
+             if ((int) pseudo_output.size() < n_ctx) {
+                 pseudo_output.push_back(inp_slice[pseudo_output.size()]);
+                 {
+                     const int64_t t_start_draft_us = ggml_time_us();
+                     llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
+                     t_draft_us += ggml_time_us() - t_start_draft_us;
+                 }
+             }
+
+             draft.erase(draft.begin());
+
+         }
+         if (i_start > 0 && i_start / 100000 != (i_start - n_ctx) / 100000) {
+             const int64_t t_now_ms = ggml_time_ms();
+             const int64_t eta_ms = (n_input - i_start) * (t_now_ms - t_start_ms) / i_start;
+             const int64_t eta_min = eta_ms / (60*1000);
+             const int64_t eta_s = (eta_ms - 60*1000*eta_min) / 1000;
+
+             LOG_TEE("lookup-stats: %d/%d done, ETA: %02" PRId64 ":%02" PRId64 "\n", i_start, n_input, eta_min, eta_s);
+         }
+
+         // After each chunk, update the dynamic ngram cache with the context ngram cache:
+         llama_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
+         ngram_cache_context.clear();
+     }
+
+     LOG_TEE("\n");
+
+     LOG_TEE("\n");
+     LOG_TEE("n_draft = %d\n", n_draft);
+     LOG_TEE("n_predict = %d\n", n_input - n_input % n_ctx);
+     LOG_TEE("n_drafted = %d\n", n_drafted);
+     LOG_TEE("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3);
+     LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
+             t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
+     LOG_TEE("n_accept = %d\n", n_accept);
+     LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+     llama_free(ctx);
+     llama_free_model(model);
+
+     llama_backend_free();
+
+     fprintf(stderr, "\n\n");
+
+     return 0;
+ }
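
The lookup-stats.cpp example above never runs the model: it replays the tokenized prompt in n_ctx-sized chunks, drafts continuations from the n-gram caches, and counts how many drafted tokens agree with the ground-truth tokens. A minimal self-contained sketch of that idea, using a plain std::map and a made-up token stream in place of the llama_ngram_cache API:

    // Sketch only (not the llama.cpp API): map every n-gram seen so far to the
    // token that followed it, then measure how often that "draft" matches the
    // next ground-truth token.
    #include <cstdio>
    #include <map>
    #include <vector>

    int main() {
        const std::vector<int> inp = {1, 2, 3, 1, 2, 4, 1, 2, 3, 1, 2, 3, 5}; // hypothetical tokens
        const size_t N = 2;                                                   // n-gram size

        std::map<std::vector<int>, int> cache; // n-gram -> most recently observed successor
        int n_drafted = 0;
        int n_accept  = 0;

        for (size_t i = N; i < inp.size(); ++i) {
            const std::vector<int> key(inp.begin() + i - N, inp.begin() + i);
            const auto it = cache.find(key);
            if (it != cache.end()) {
                ++n_drafted;                               // a draft token exists for this position
                n_accept += it->second == inp[i] ? 1 : 0;  // does it match the ground truth?
            }
            cache[key] = inp[i];                           // record the observed continuation
        }

        printf("n_drafted = %d\nn_accept = %d\naccept = %.3f%%\n",
               n_drafted, n_accept, n_drafted ? 100.0 * n_accept / n_drafted : 0.0);
        return 0;
    }

The real example additionally separates context, dynamic, and static caches and times the drafting calls, but the acceptance statistic it reports is the same ratio computed here.
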
package/src/llama.cpp/examples/lookup/lookup.cpp
@@ -0,0 +1,258 @@
+ #include "ggml.h"
+ #include "llama.h"
+ #include "common.h"
+ #include "ngram-cache.h"
+
+ #include <cmath>
+ #include <cstdint>
+ #include <cstdio>
+ #include <fstream>
+ #include <string>
+ #include <vector>
+ #include <unordered_map>
+
+ int main(int argc, char ** argv){
+     gpt_params params;
+
+     if (!gpt_params_parse(argc, argv, params)) {
+         return 1;
+     }
+
+     // max. number of additional tokens to draft if match is found
+     const int n_draft = params.n_draft;
+
+     const bool dump_kv_cache = params.dump_kv_cache;
+
+ #ifndef LOG_DISABLE_LOGS
+     log_set_target(log_filename_generator("lookup", "log"));
+     LOG_TEE("Log start\n");
+     log_dump_cmdline(argc, argv);
+ #endif // LOG_DISABLE_LOGS
+
+     // init llama.cpp
+     llama_backend_init();
+     llama_numa_init(params.numa);
+
+     llama_model * model = NULL;
+     llama_context * ctx = NULL;
+
+     // load the model
+     std::tie(model, ctx) = llama_init_from_gpt_params(params);
+     llama_set_rng_seed(ctx, params.seed);
+     GGML_ASSERT(llama_n_vocab(model) < (1 << 16));
+
+     // tokenize the prompt
+     std::vector<llama_token> inp;
+     inp = ::llama_tokenize(ctx, params.prompt, true, true);
+
+     llama_ngram_cache ngram_cache_context;
+     llama_ngram_cache ngram_cache_dynamic;
+     llama_ngram_cache ngram_cache_static;
+     int64_t t_draft_flat_us = 0;
+     int64_t t_draft_us = 0;
+
+     {
+         // Fill up context ngram cache with tokens from user input:
+         const int64_t t_start_draft_us = ggml_time_us();
+         llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);
+
+         if (!params.lookup_cache_static.empty()) {
+             try {
+                 ngram_cache_static = llama_ngram_cache_load(params.lookup_cache_static);
+             } catch (std::ifstream::failure const &) {
+                 fprintf(stderr, "error: failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
+                 exit(1);
+             }
+         }
+
+         if (!params.lookup_cache_dynamic.empty()) {
+             try {
+                 ngram_cache_dynamic = llama_ngram_cache_load(params.lookup_cache_dynamic);
+             } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
+         }
+
+         t_draft_flat_us += ggml_time_us() - t_start_draft_us;
+     }
+
+     const int max_context_size = llama_n_ctx(ctx);
+     const int max_tokens_list_size = max_context_size - 4;
+
+     if ((int) inp.size() > max_tokens_list_size) {
+         fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
+         return 1;
+     }
+
+     fprintf(stderr, "\n\n");
+
+     for (auto id : inp) {
+         fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
+     }
+
+     fflush(stderr);
+
+     const int n_input = inp.size();
+
+     const auto t_enc_start = ggml_time_us();
+
+     llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
+     llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
+
+     const auto t_enc_end = ggml_time_us();
+
+     int n_predict = 0;
+     int n_drafted = 0;
+     int n_accept = 0;
+
+     int n_past = inp.size();
+
+     bool has_eos = false;
+
+     struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
+
+     std::vector<llama_token> draft;
+
+     llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1);
+
+     // debug
+     struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1);
+
+     const auto t_dec_start = ggml_time_us();
+
+     while (true) {
+         // debug
+         if (dump_kv_cache) {
+             llama_kv_cache_view_update(ctx, &kvc_view);
+             dump_kv_cache_view_seqs(kvc_view, 40);
+         }
+
+         // print current draft sequence
+         LOG("drafted %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, draft).c_str());
+
+         int i_dft = 0;
+         while (true) {
+             // sample from the target model
+             llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
+
+             llama_sampling_accept(ctx_sampling, ctx, id, true);
+
+             const std::string token_str = llama_token_to_piece(ctx, id);
+
+             if (!params.use_color) {
+                 printf("%s", token_str.c_str());
+             }
+
+             if (llama_token_is_eog(model, id)) {
+                 has_eos = true;
+             }
+
+             ++n_predict;
+
+             // check if the target token matches the draft
+             if (i_dft < (int) draft.size() && id == draft[i_dft]) {
+                 LOG("the sampled target token matches the %dth drafted token (%d, '%s') - accepted\n", i_dft, id, token_str.c_str());
+                 ++n_accept;
+                 ++n_past;
+                 ++i_dft;
+                 inp.push_back(id);
+                 {
+                     // Update context ngram cache with the newly accepted token:
+                     const int64_t t_start_draft_us = ggml_time_us();
+                     llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
+                     t_draft_us += ggml_time_us() - t_start_draft_us;
+                 }
+
+                 if (params.use_color) {
+                     // color accepted draft token
+                     printf("\033[34m%s\033[0m", token_str.c_str());
+                     fflush(stdout);
+                 }
+                 continue;
+             }
+
+             if (params.use_color) {
+                 printf("%s", token_str.c_str());
+             }
+             fflush(stdout);
+
+
+             LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", id, token_str.c_str());
+
+             draft.clear();
+             draft.push_back(id);
+             inp.push_back(id);
+             {
+                 // Update context ngram cache with the newly accepted token:
+                 const int64_t t_start_draft_us = ggml_time_us();
+                 llama_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
+                 t_draft_us += ggml_time_us() - t_start_draft_us;
+             }
+             break;
+         }
+
+         if ((params.n_predict > 0 && n_predict > params.n_predict) || has_eos) {
+             break;
+         }
+
+         // KV cache management
+         // clean the cache of draft tokens that weren't accepted
+         llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
+
+         llama_batch_clear(batch_tgt);
+         llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);
+
+         // Draft already contains a single token sampled from the model:
+         GGML_ASSERT(draft.size() == 1);
+         GGML_ASSERT(draft[0] == inp.back());
+         const int64_t t_start_draft_us = ggml_time_us();
+
+         llama_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
+
+         for (size_t i = 1; i < draft.size(); ++i) {
+             llama_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+         }
+
+         t_draft_us += ggml_time_us() - t_start_draft_us;
+         n_drafted += draft.size() - 1;
+
+         llama_decode(ctx, batch_tgt);
+         ++n_past;
+
+         draft.erase(draft.begin());
+     }
+
+     auto t_dec_end = ggml_time_us();
+
+     // Update dynamic ngram cache with context ngram cache and save it to disk:
+     llama_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
+     llama_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);
+
+     LOG_TEE("\n\n");
+
+     LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+     LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
+
+     LOG_TEE("\n");
+     LOG_TEE("n_draft = %d\n", n_draft);
+     LOG_TEE("n_predict = %d\n", n_predict);
+     LOG_TEE("n_drafted = %d\n", n_drafted);
+     LOG_TEE("t_draft_flat = %.2f ms\n", t_draft_flat_us*1e-3);
+     LOG_TEE("t_draft = %.2f ms, %.2f us per token, %.2f tokens per second\n",
+             t_draft_us*1e-3, 1.0f*t_draft_us/n_drafted, n_drafted/(1e-6*t_draft_us));
+     LOG_TEE("n_accept = %d\n", n_accept);
+     LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+     LOG_TEE("\ntarget:\n");
+     llama_print_timings(ctx);
+
+     llama_sampling_free(ctx_sampling);
+     llama_batch_free(batch_tgt);
+
+     llama_free(ctx);
+     llama_free_model(model);
+
+     llama_backend_free();
+
+     fprintf(stderr, "\n\n");
+
+     return 0;
+ }
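
In lookup.cpp above, each iteration drafts a continuation from the n-gram caches, decodes the whole draft in one batch, then samples from the target model position by position and keeps only the longest prefix of the draft that the target agrees with. A stripped-down sketch of just that verification step, with a hypothetical sample_target() standing in for llama_sampling_sample():

    // Sketch of draft verification: accept drafted tokens only while they match
    // what the target model samples; stop at the first mismatch.
    #include <cstdio>
    #include <vector>

    // Hypothetical stand-in for the target model's sampler.
    static int sample_target(size_t pos) {
        const int stream[] = {7, 8, 9, 10};
        return stream[pos % 4];
    }

    int main() {
        const std::vector<int> draft = {7, 8, 42, 10}; // hypothetical drafted tokens
        size_t n_accept = 0;

        for (size_t i = 0; i < draft.size(); ++i) {
            if (sample_target(i) != draft[i]) {
                break;          // first mismatch: the rest of the draft is discarded
            }
            ++n_accept;         // matching draft tokens cost no extra decode calls
        }

        printf("accepted %zu of %zu drafted tokens\n", n_accept, draft.size());
        return 0;
    }
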
package/src/llama.cpp/examples/main/CMakeLists.txt
@@ -0,0 +1,5 @@
+ set(TARGET main)
+ add_executable(${TARGET} main.cpp)
+ install(TARGETS ${TARGET} RUNTIME)
+ target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
+ target_compile_features(${TARGET} PRIVATE cxx_std_11)