@fugood/llama.node 0.0.1-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +85 -0
- package/README.md +56 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/lib/binding.js +13 -0
- package/lib/binding.ts +57 -0
- package/lib/index.js +24 -0
- package/lib/index.ts +13 -0
- package/package.json +65 -0
- package/src/addons.cpp +506 -0
- package/src/llama.cpp/CMakeLists.txt +1320 -0
- package/src/llama.cpp/build.zig +172 -0
- package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
- package/src/llama.cpp/common/CMakeLists.txt +87 -0
- package/src/llama.cpp/common/base64.hpp +392 -0
- package/src/llama.cpp/common/common.cpp +2949 -0
- package/src/llama.cpp/common/common.h +324 -0
- package/src/llama.cpp/common/console.cpp +501 -0
- package/src/llama.cpp/common/console.h +19 -0
- package/src/llama.cpp/common/grammar-parser.cpp +440 -0
- package/src/llama.cpp/common/grammar-parser.h +29 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
- package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
- package/src/llama.cpp/common/json.hpp +24766 -0
- package/src/llama.cpp/common/log.h +724 -0
- package/src/llama.cpp/common/ngram-cache.cpp +282 -0
- package/src/llama.cpp/common/ngram-cache.h +94 -0
- package/src/llama.cpp/common/sampling.cpp +353 -0
- package/src/llama.cpp/common/sampling.h +147 -0
- package/src/llama.cpp/common/stb_image.h +8396 -0
- package/src/llama.cpp/common/train.cpp +1513 -0
- package/src/llama.cpp/common/train.h +233 -0
- package/src/llama.cpp/examples/CMakeLists.txt +52 -0
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
- package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/batched/batched.cpp +262 -0
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
- package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
- package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
- package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/infill/infill.cpp +767 -0
- package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
- package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
- package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
- package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
- package/src/llama.cpp/examples/llava/clip.h +85 -0
- package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
- package/src/llama.cpp/examples/llava/llava.cpp +426 -0
- package/src/llama.cpp/examples/llava/llava.h +50 -0
- package/src/llama.cpp/examples/llava/requirements.txt +3 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
- package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
- package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/main/main.cpp +957 -0
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
- package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
- package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
- package/src/llama.cpp/examples/server/httplib.h +9465 -0
- package/src/llama.cpp/examples/server/server.cpp +3826 -0
- package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
- package/src/llama.cpp/examples/server/utils.hpp +653 -0
- package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple/simple.cpp +183 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
- package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
- package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
- package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
- package/src/llama.cpp/ggml-alloc.c +985 -0
- package/src/llama.cpp/ggml-alloc.h +76 -0
- package/src/llama.cpp/ggml-backend-impl.h +141 -0
- package/src/llama.cpp/ggml-backend.c +2099 -0
- package/src/llama.cpp/ggml-backend.h +233 -0
- package/src/llama.cpp/ggml-common.h +1853 -0
- package/src/llama.cpp/ggml-cuda.h +43 -0
- package/src/llama.cpp/ggml-impl.h +265 -0
- package/src/llama.cpp/ggml-kompute.cpp +2006 -0
- package/src/llama.cpp/ggml-kompute.h +46 -0
- package/src/llama.cpp/ggml-metal.h +66 -0
- package/src/llama.cpp/ggml-mpi.c +216 -0
- package/src/llama.cpp/ggml-mpi.h +39 -0
- package/src/llama.cpp/ggml-opencl.cpp +2301 -0
- package/src/llama.cpp/ggml-opencl.h +36 -0
- package/src/llama.cpp/ggml-quants.c +12678 -0
- package/src/llama.cpp/ggml-quants.h +133 -0
- package/src/llama.cpp/ggml-sycl.cpp +17882 -0
- package/src/llama.cpp/ggml-sycl.h +49 -0
- package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
- package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
- package/src/llama.cpp/ggml-vulkan.h +29 -0
- package/src/llama.cpp/ggml.c +21819 -0
- package/src/llama.cpp/ggml.h +2403 -0
- package/src/llama.cpp/llama.cpp +17468 -0
- package/src/llama.cpp/llama.h +1117 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
- package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
- package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
- package/src/llama.cpp/prompts/alpaca.txt +1 -0
- package/src/llama.cpp/prompts/assistant.txt +31 -0
- package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
- package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
- package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
- package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
- package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
- package/src/llama.cpp/prompts/chat.txt +28 -0
- package/src/llama.cpp/prompts/dan-modified.txt +1 -0
- package/src/llama.cpp/prompts/dan.txt +1 -0
- package/src/llama.cpp/prompts/mnemonics.txt +93 -0
- package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
- package/src/llama.cpp/prompts/reason-act.txt +18 -0
- package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
- package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
- package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
- package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
- package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
- package/src/llama.cpp/requirements.txt +12 -0
- package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
- package/src/llama.cpp/scripts/xxd.cmake +16 -0
- package/src/llama.cpp/sgemm.cpp +999 -0
- package/src/llama.cpp/sgemm.h +12 -0
- package/src/llama.cpp/tests/CMakeLists.txt +78 -0
- package/src/llama.cpp/tests/get-model.cpp +21 -0
- package/src/llama.cpp/tests/get-model.h +2 -0
- package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
- package/src/llama.cpp/tests/test-c.c +7 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
- package/src/llama.cpp/tests/test-double-float.cpp +57 -0
- package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
- package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
- package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
- package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
- package/src/llama.cpp/tests/test-opt.cpp +181 -0
- package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
- package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
- package/src/llama.cpp/tests/test-rope.cpp +221 -0
- package/src/llama.cpp/tests/test-sampling.cpp +301 -0
- package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
- package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
- package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
- package/src/llama.cpp/unicode-data.cpp +1651 -0
- package/src/llama.cpp/unicode-data.h +16 -0
- package/src/llama.cpp/unicode.cpp +277 -0
- package/src/llama.cpp/unicode.h +28 -0
|
@@ -0,0 +1,485 @@
|
|
|
1
|
+
#include "common.h"
|
|
2
|
+
#include "llama.h"
|
|
3
|
+
|
|
4
|
+
#include <cmath>
|
|
5
|
+
#include <cstdio>
|
|
6
|
+
#include <string>
|
|
7
|
+
#include <vector>
|
|
8
|
+
|
|
9
|
+
struct ngram_data {
|
|
10
|
+
bool active = false;
|
|
11
|
+
|
|
12
|
+
llama_seq_id seq_id = -1;
|
|
13
|
+
|
|
14
|
+
std::vector<int> i_batch;
|
|
15
|
+
|
|
16
|
+
std::vector<llama_token> tokens;
|
|
17
|
+
};
|
|
18
|
+
|
|
19
|
+
// n-gram container
|
|
20
|
+
struct ngram_container {
|
|
21
|
+
ngram_container(int n_vocab, int N, int G) {
|
|
22
|
+
cnt.resize(n_vocab);
|
|
23
|
+
head.resize(n_vocab);
|
|
24
|
+
tokens.resize(n_vocab * G * (N - 1));
|
|
25
|
+
}
|
|
26
|
+
|
|
27
|
+
int n_total = 0;
|
|
28
|
+
|
|
29
|
+
std::vector<int> cnt;
|
|
30
|
+
std::vector<int> head;
|
|
31
|
+
|
|
32
|
+
// [n_vocab][G][N - 1]
|
|
33
|
+
// for each token of the vocab, keep a ring-buffer of capacity G of n-grams of size N - 1
|
|
34
|
+
std::vector<llama_token> tokens;
|
|
35
|
+
};
|
|
36
|
+
|
|
37
|
+
int main(int argc, char ** argv) {
|
|
38
|
+
gpt_params params;
|
|
39
|
+
|
|
40
|
+
if (gpt_params_parse(argc, argv, params) == false) {
|
|
41
|
+
return 1;
|
|
42
|
+
}
|
|
43
|
+
|
|
44
|
+
const int W = 15; // lookahead window
|
|
45
|
+
const int N = 5; // n-gram size
|
|
46
|
+
const int G = 15; // max verification n-grams
|
|
47
|
+
|
|
48
|
+
const bool dump_kv_cache = params.dump_kv_cache;
|
|
49
|
+
|
|
50
|
+
#ifndef LOG_DISABLE_LOGS
|
|
51
|
+
log_set_target(log_filename_generator("lookahead", "log"));
|
|
52
|
+
LOG_TEE("Log start\n");
|
|
53
|
+
log_dump_cmdline(argc, argv);
|
|
54
|
+
#endif // LOG_DISABLE_LOGS
|
|
55
|
+
|
|
56
|
+
// init llama.cpp
|
|
57
|
+
llama_backend_init();
|
|
58
|
+
llama_numa_init(params.numa);
|
|
59
|
+
|
|
60
|
+
llama_model * model = NULL;
|
|
61
|
+
llama_context * ctx = NULL;
|
|
62
|
+
|
|
63
|
+
// load the target model
|
|
64
|
+
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
|
65
|
+
|
|
66
|
+
// Tokenize the prompt
|
|
67
|
+
std::vector<llama_token> inp;
|
|
68
|
+
std::vector<llama_token> all;
|
|
69
|
+
|
|
70
|
+
inp = ::llama_tokenize(ctx, params.prompt, true, true);
|
|
71
|
+
all = inp;
|
|
72
|
+
|
|
73
|
+
const int max_context_size = llama_n_ctx(ctx);
|
|
74
|
+
const int max_tokens_list_size = max_context_size - 4;
|
|
75
|
+
|
|
76
|
+
if ((int) inp.size() > max_tokens_list_size) {
|
|
77
|
+
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
|
|
78
|
+
return 1;
|
|
79
|
+
}
|
|
80
|
+
|
|
81
|
+
fprintf(stderr, "\n\n");
|
|
82
|
+
|
|
83
|
+
for (auto id : inp) {
|
|
84
|
+
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
|
|
85
|
+
}
|
|
86
|
+
|
|
87
|
+
fflush(stderr);
|
|
88
|
+
|
|
89
|
+
const int n_input = inp.size();
|
|
90
|
+
|
|
91
|
+
const auto t_enc_start = ggml_time_us();
|
|
92
|
+
|
|
93
|
+
// eval the prompt
|
|
94
|
+
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
|
|
95
|
+
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
|
|
96
|
+
|
|
97
|
+
for (int s = 1; s < W + G + 1; ++s) {
|
|
98
|
+
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
const auto t_enc_end = ggml_time_us();
|
|
102
|
+
|
|
103
|
+
int n_predict = 0;
|
|
104
|
+
int n_accept = 0;
|
|
105
|
+
|
|
106
|
+
int n_past = inp.size();
|
|
107
|
+
|
|
108
|
+
llama_token id = 0;
|
|
109
|
+
|
|
110
|
+
// used to determine end of generation
|
|
111
|
+
bool has_eos = false;
|
|
112
|
+
|
|
113
|
+
// for each decoded batch, we have at most W + G + 1 distinct sequences:
|
|
114
|
+
// seq_id == 0 : the current input token
|
|
115
|
+
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
|
|
116
|
+
// seq_id [W + 1, W + G] : verification n-grams
|
|
117
|
+
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
|
|
118
|
+
|
|
119
|
+
// target model sampling context
|
|
120
|
+
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
|
121
|
+
|
|
122
|
+
// verification n-grams
|
|
123
|
+
std::vector<ngram_data> ngrams_cur(G);
|
|
124
|
+
|
|
125
|
+
// tokens for the past N - 1 Jacobi iterations
|
|
126
|
+
std::vector<llama_token> tokens_j_prev(W);
|
|
127
|
+
std::vector<std::vector<llama_token>> tokens_j(N - 1);
|
|
128
|
+
for (int j = 0; j < N - 1; j++) {
|
|
129
|
+
tokens_j[j].resize(W);
|
|
130
|
+
|
|
131
|
+
for (int i = 0; i < W; i++) {
|
|
132
|
+
// there are different ways to init these tokens
|
|
133
|
+
if (0) {
|
|
134
|
+
// initialize randomly from the prompt tokens
|
|
135
|
+
tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
|
|
136
|
+
} else {
|
|
137
|
+
// initialize with a sequence of increasing numbers
|
|
138
|
+
tokens_j[j][i] = 100 + i;
|
|
139
|
+
}
|
|
140
|
+
}
|
|
141
|
+
}
|
|
142
|
+
|
|
143
|
+
std::vector<llama_seq_id> seq_id_look;
|
|
144
|
+
|
|
145
|
+
// the input token belongs both to all sequences
|
|
146
|
+
std::vector<llama_seq_id> seq_id_all(W + G + 1);
|
|
147
|
+
for (int i = 0; i < W + G + 1; i++) {
|
|
148
|
+
seq_id_all[i] = i;
|
|
149
|
+
}
|
|
150
|
+
|
|
151
|
+
// here we keep adding new n-grams as we go
|
|
152
|
+
ngram_container ngrams_observed(llama_n_vocab(model), N, G);
|
|
153
|
+
|
|
154
|
+
// debug
|
|
155
|
+
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
|
|
156
|
+
|
|
157
|
+
const auto t_dec_start = ggml_time_us();
|
|
158
|
+
|
|
159
|
+
// sample first token
|
|
160
|
+
{
|
|
161
|
+
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
|
|
162
|
+
|
|
163
|
+
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
|
164
|
+
|
|
165
|
+
{
|
|
166
|
+
const std::string token_str = llama_token_to_piece(ctx, id);
|
|
167
|
+
|
|
168
|
+
printf("%s", token_str.c_str());
|
|
169
|
+
fflush(stdout);
|
|
170
|
+
}
|
|
171
|
+
}
|
|
172
|
+
|
|
173
|
+
while (true) {
|
|
174
|
+
// debug
|
|
175
|
+
if (dump_kv_cache) {
|
|
176
|
+
llama_kv_cache_view_update(ctx, &kvc_view);
|
|
177
|
+
dump_kv_cache_view_seqs(kvc_view, 40);
|
|
178
|
+
}
|
|
179
|
+
|
|
180
|
+
// build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/
|
|
181
|
+
//
|
|
182
|
+
// Example for W = 5, N = 4, G = 2:
|
|
183
|
+
// (I = input, L = lookahead, V = verification)
|
|
184
|
+
//
|
|
185
|
+
// Batch: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
|
|
186
|
+
// T: -2 -2 -2 -2 -1 -1 -1 -1 -1 0 0 0 0 0 0
|
|
187
|
+
// Info: I L L L L L L L L L L L L L L V V V V V V
|
|
188
|
+
// Pos: 0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 1 2 3 1 2 3 (+ n_past)
|
|
189
|
+
// Logits: 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
|
|
190
|
+
// ---------------------------------------------------------------------
|
|
191
|
+
// Seq: 0
|
|
192
|
+
// 1 1 1
|
|
193
|
+
// 2 2 2 2
|
|
194
|
+
// 3 3 3 3 3
|
|
195
|
+
// 4 4 4 4 4 4
|
|
196
|
+
// 5 5 5 5 5 5 5
|
|
197
|
+
// 6 6 6 6
|
|
198
|
+
// 7 7 7 7
|
|
199
|
+
// ---------------------------------------------------------------------
|
|
200
|
+
// | | | | | | | | | | |
|
|
201
|
+
// V V V V V | | | | | |
|
|
202
|
+
// j_tokens | | | | | |
|
|
203
|
+
// V V V V V V
|
|
204
|
+
// id
|
|
205
|
+
{
|
|
206
|
+
llama_batch_clear(batch);
|
|
207
|
+
|
|
208
|
+
// current token - first token of the first level
|
|
209
|
+
llama_batch_add(batch, id, n_past, seq_id_all, true);
|
|
210
|
+
|
|
211
|
+
// verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
|
|
212
|
+
{
|
|
213
|
+
const int g_cur = ngrams_observed.cnt[id];
|
|
214
|
+
|
|
215
|
+
ngrams_cur.resize(g_cur);
|
|
216
|
+
for (int g = 0; g < g_cur; g++) {
|
|
217
|
+
ngrams_cur[g].active = true;
|
|
218
|
+
ngrams_cur[g].tokens.resize(N);
|
|
219
|
+
ngrams_cur[g].i_batch.resize(N);
|
|
220
|
+
ngrams_cur[g].seq_id = W + 1 + g;
|
|
221
|
+
ngrams_cur[g].i_batch[0] = 0;
|
|
222
|
+
ngrams_cur[g].tokens [0] = id;
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
for (int j = 0; j < N - 1; j++) {
|
|
226
|
+
for (int g = 0; g < g_cur; g++) {
|
|
227
|
+
const int idx = id*(N - 1)*G + g*(N - 1);
|
|
228
|
+
|
|
229
|
+
const llama_token t = ngrams_observed.tokens[idx + j];
|
|
230
|
+
|
|
231
|
+
ngrams_cur[g].tokens [j + 1] = t;
|
|
232
|
+
ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
|
|
233
|
+
|
|
234
|
+
llama_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
|
|
235
|
+
}
|
|
236
|
+
}
|
|
237
|
+
}
|
|
238
|
+
|
|
239
|
+
// fill the remaining W - 1 tokens for the first level
|
|
240
|
+
for (int i = 1; i < W; i++) {
|
|
241
|
+
seq_id_look.resize(W - i);
|
|
242
|
+
for (int j = 0; j < W - i; j++) {
|
|
243
|
+
seq_id_look[j] = i + j + 1;
|
|
244
|
+
}
|
|
245
|
+
|
|
246
|
+
llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
|
|
247
|
+
}
|
|
248
|
+
|
|
249
|
+
// fill the rest of the levels
|
|
250
|
+
for (int j = 1; j < N - 1; j++) {
|
|
251
|
+
for (int i = 0; i < W; i++) {
|
|
252
|
+
llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
|
|
253
|
+
}
|
|
254
|
+
}
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
if (llama_decode(ctx, batch) != 0) {
|
|
258
|
+
fprintf(stderr, "\n\n%s: error: llama_decode failed - increase KV cache size\n", __func__);
|
|
259
|
+
return 1;
|
|
260
|
+
}
|
|
261
|
+
|
|
262
|
+
int seq_id_best = 0;
|
|
263
|
+
|
|
264
|
+
for (int v = 0; v < N; ++v) {
|
|
265
|
+
int i_batch = 0;
|
|
266
|
+
|
|
267
|
+
// if no active ngrams are left, it means the sampled token does not pass the verification
|
|
268
|
+
if (v > 0) {
|
|
269
|
+
for (int g = 0; g < (int) ngrams_cur.size(); g++) {
|
|
270
|
+
if (ngrams_cur[g].active) {
|
|
271
|
+
i_batch = ngrams_cur[g].i_batch[v];
|
|
272
|
+
seq_id_best = ngrams_cur[g].seq_id;
|
|
273
|
+
|
|
274
|
+
++n_accept;
|
|
275
|
+
break;
|
|
276
|
+
}
|
|
277
|
+
}
|
|
278
|
+
|
|
279
|
+
// no more matches -> create a new batch
|
|
280
|
+
if (i_batch == 0) {
|
|
281
|
+
break;
|
|
282
|
+
}
|
|
283
|
+
}
|
|
284
|
+
|
|
285
|
+
// sample the next token
|
|
286
|
+
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
|
|
287
|
+
|
|
288
|
+
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
|
289
|
+
|
|
290
|
+
// print
|
|
291
|
+
{
|
|
292
|
+
const std::string token_str = llama_token_to_piece(ctx, id);
|
|
293
|
+
|
|
294
|
+
if (v == 0) {
|
|
295
|
+
printf("%s", token_str.c_str());
|
|
296
|
+
} else {
|
|
297
|
+
// print light cyan
|
|
298
|
+
printf("\033[0;96m%s\033[0m", token_str.c_str());
|
|
299
|
+
}
|
|
300
|
+
fflush(stdout);
|
|
301
|
+
|
|
302
|
+
if (llama_token_is_eog(model, id)) {
|
|
303
|
+
has_eos = true;
|
|
304
|
+
}
|
|
305
|
+
|
|
306
|
+
all.push_back(id);
|
|
307
|
+
}
|
|
308
|
+
|
|
309
|
+
++n_predict;
|
|
310
|
+
++n_past;
|
|
311
|
+
|
|
312
|
+
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
|
313
|
+
break;
|
|
314
|
+
}
|
|
315
|
+
|
|
316
|
+
// verify across active n-grams
|
|
317
|
+
for (int g = 0; g < (int) ngrams_cur.size(); g++) {
|
|
318
|
+
if (ngrams_cur[g].active) {
|
|
319
|
+
if (v == N - 1) {
|
|
320
|
+
ngrams_cur[g].active = false;
|
|
321
|
+
} else {
|
|
322
|
+
if (id != ngrams_cur[g].tokens[v + 1]) {
|
|
323
|
+
ngrams_cur[g].active = false;
|
|
324
|
+
}
|
|
325
|
+
}
|
|
326
|
+
}
|
|
327
|
+
}
|
|
328
|
+
|
|
329
|
+
// print known n-grams starting with token id (debug)
|
|
330
|
+
if (0 && v == 0) {
|
|
331
|
+
if (ngrams_observed.cnt[id] > 0) {
|
|
332
|
+
printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
|
|
333
|
+
}
|
|
334
|
+
|
|
335
|
+
for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
|
|
336
|
+
printf(" - ngram %2d: ", i);
|
|
337
|
+
|
|
338
|
+
const int idx = id*(N - 1)*G + i*(N - 1);
|
|
339
|
+
|
|
340
|
+
for (int j = 0; j < N - 1; j++) {
|
|
341
|
+
const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
|
|
342
|
+
|
|
343
|
+
printf("%s", token_str.c_str());
|
|
344
|
+
}
|
|
345
|
+
|
|
346
|
+
printf("\n");
|
|
347
|
+
}
|
|
348
|
+
}
|
|
349
|
+
|
|
350
|
+
// update lookahead tokens
|
|
351
|
+
{
|
|
352
|
+
for (int i = 0; i < W; i++) {
|
|
353
|
+
tokens_j_prev[i] = tokens_j[0][i];
|
|
354
|
+
}
|
|
355
|
+
|
|
356
|
+
for (int j = 0; j < N - 2; j++) {
|
|
357
|
+
tokens_j[j] = tokens_j[j + 1];
|
|
358
|
+
}
|
|
359
|
+
|
|
360
|
+
if (v == 0) {
|
|
361
|
+
// sample from the last level
|
|
362
|
+
for (int i = 0; i < W; i++) {
|
|
363
|
+
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
|
|
364
|
+
}
|
|
365
|
+
} else {
|
|
366
|
+
for (int i = 0; i < W; i++) {
|
|
367
|
+
// there are different ways to init these tokens
|
|
368
|
+
if (0) {
|
|
369
|
+
// random init
|
|
370
|
+
tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
|
|
371
|
+
} else {
|
|
372
|
+
// init from the previous level
|
|
373
|
+
tokens_j[N - 2][i] = tokens_j[0][i];
|
|
374
|
+
}
|
|
375
|
+
}
|
|
376
|
+
}
|
|
377
|
+
}
|
|
378
|
+
|
|
379
|
+
// update observed ngrams
|
|
380
|
+
if (v == 0) {
|
|
381
|
+
// the first token of the n-gram is determined by the index in the container so it is not stored
|
|
382
|
+
std::vector<llama_token> ngram(N - 1);
|
|
383
|
+
|
|
384
|
+
// n-gram generation
|
|
385
|
+
// ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518
|
|
386
|
+
for (int f = 0; f < W; ++f) {
|
|
387
|
+
const int ft = tokens_j_prev[f]; // first token of the n-gram
|
|
388
|
+
|
|
389
|
+
for (int j = 0; j < N - 1; ++j) {
|
|
390
|
+
ngram[j] = tokens_j[j][f];
|
|
391
|
+
}
|
|
392
|
+
|
|
393
|
+
// filter-out repeating n-grams
|
|
394
|
+
{
|
|
395
|
+
bool is_unique = true;
|
|
396
|
+
|
|
397
|
+
for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) {
|
|
398
|
+
const int idx = ft*(N - 1)*G + k*(N - 1);
|
|
399
|
+
|
|
400
|
+
bool is_match = true;
|
|
401
|
+
for (int j = 0; j < N - 1; ++j) {
|
|
402
|
+
if (ngrams_observed.tokens[idx + j] != ngram[j]) {
|
|
403
|
+
is_match = false;
|
|
404
|
+
break;
|
|
405
|
+
}
|
|
406
|
+
}
|
|
407
|
+
|
|
408
|
+
if (is_match) {
|
|
409
|
+
is_unique = false;
|
|
410
|
+
break;
|
|
411
|
+
}
|
|
412
|
+
}
|
|
413
|
+
|
|
414
|
+
if (!is_unique) {
|
|
415
|
+
continue;
|
|
416
|
+
}
|
|
417
|
+
}
|
|
418
|
+
|
|
419
|
+
const int head = ngrams_observed.head[ft];
|
|
420
|
+
const int idx = ft*(N - 1)*G + head*(N - 1);
|
|
421
|
+
|
|
422
|
+
for (int i = 0; i < N - 1; i++) {
|
|
423
|
+
ngrams_observed.tokens[idx + i] = ngram[i];
|
|
424
|
+
}
|
|
425
|
+
|
|
426
|
+
ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1);
|
|
427
|
+
ngrams_observed.head[ft] = (head + 1) % G;
|
|
428
|
+
|
|
429
|
+
ngrams_observed.n_total++;
|
|
430
|
+
}
|
|
431
|
+
}
|
|
432
|
+
}
|
|
433
|
+
|
|
434
|
+
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
|
|
435
|
+
break;
|
|
436
|
+
}
|
|
437
|
+
|
|
438
|
+
// KV cache management
|
|
439
|
+
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
|
|
440
|
+
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
|
|
441
|
+
|
|
442
|
+
if (seq_id_best != 0) {
|
|
443
|
+
// if a verification token matched, we keep the best sequence and remove the rest
|
|
444
|
+
// this leads to some KV cache fragmentation
|
|
445
|
+
llama_kv_cache_seq_keep(ctx, seq_id_best);
|
|
446
|
+
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
|
|
447
|
+
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
|
|
448
|
+
|
|
449
|
+
for (int s = 1; s < W + G + 1; ++s) {
|
|
450
|
+
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
|
|
451
|
+
}
|
|
452
|
+
}
|
|
453
|
+
}
|
|
454
|
+
|
|
455
|
+
auto t_dec_end = ggml_time_us();
|
|
456
|
+
|
|
457
|
+
LOG_TEE("\n\n");
|
|
458
|
+
|
|
459
|
+
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
|
|
460
|
+
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
|
461
|
+
|
|
462
|
+
LOG_TEE("\n");
|
|
463
|
+
LOG_TEE("W = %2d\n", W);
|
|
464
|
+
LOG_TEE("N = %2d\n", N);
|
|
465
|
+
LOG_TEE("G = %2d\n", G);
|
|
466
|
+
LOG_TEE("\n");
|
|
467
|
+
LOG_TEE("n_predict = %d\n", n_predict);
|
|
468
|
+
LOG_TEE("n_accept = %d\n", n_accept);
|
|
469
|
+
|
|
470
|
+
llama_print_timings(ctx);
|
|
471
|
+
|
|
472
|
+
llama_kv_cache_view_free(&kvc_view);
|
|
473
|
+
llama_sampling_free(ctx_sampling);
|
|
474
|
+
|
|
475
|
+
llama_batch_free(batch);
|
|
476
|
+
|
|
477
|
+
llama_free(ctx);
|
|
478
|
+
llama_free_model(model);
|
|
479
|
+
|
|
480
|
+
llama_backend_free();
|
|
481
|
+
|
|
482
|
+
fprintf(stderr, "\n\n");
|
|
483
|
+
|
|
484
|
+
return 0;
|
|
485
|
+
}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# The four n-gram lookup tools share one recipe: a single source file named after the
# target, installed as a runtime binary, linked against the common helpers and llama,
# and compiled as C++11.
foreach(LOOKUP_TOOL lookup lookup-create lookup-merge lookup-stats)
    set(TARGET ${LOOKUP_TOOL})
    add_executable(${TARGET} ${TARGET}.cpp)
    install(TARGETS ${TARGET} RUNTIME)
    target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
    target_compile_features(${TARGET} PRIVATE cxx_std_11)
endforeach()
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
#include "ggml.h"
|
|
2
|
+
#include "llama.h"
|
|
3
|
+
#include "common.h"
|
|
4
|
+
#include "ngram-cache.h"
|
|
5
|
+
|
|
6
|
+
#include <cstdint>
|
|
7
|
+
#include <fstream>
|
|
8
|
+
#include <iostream>
|
|
9
|
+
#include <string>
|
|
10
|
+
#include <unordered_map>
|
|
11
|
+
#include <vector>
|
|
12
|
+
|
|
13
|
+
int main(int argc, char ** argv){
|
|
14
|
+
gpt_params params;
|
|
15
|
+
|
|
16
|
+
if (!gpt_params_parse(argc, argv, params)) {
|
|
17
|
+
return 1;
|
|
18
|
+
}
|
|
19
|
+
// init llama.cpp
|
|
20
|
+
llama_backend_init();
|
|
21
|
+
llama_numa_init(params.numa);
|
|
22
|
+
|
|
23
|
+
llama_model * model = NULL;
|
|
24
|
+
llama_context * ctx = NULL;
|
|
25
|
+
|
|
26
|
+
// load the model
|
|
27
|
+
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
|
28
|
+
GGML_ASSERT(model != nullptr);
|
|
29
|
+
|
|
30
|
+
// tokenize the prompt
|
|
31
|
+
std::vector<llama_token> inp;
|
|
32
|
+
inp = ::llama_tokenize(ctx, params.prompt, true, true);
|
|
33
|
+
fprintf(stderr, "%s: tokenization done\n", __func__);
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
llama_ngram_cache ngram_cache;
|
|
37
|
+
llama_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
|
|
38
|
+
fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());
|
|
39
|
+
|
|
40
|
+
llama_ngram_cache_save(ngram_cache, params.lookup_cache_static);
|
|
41
|
+
}
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
#include "ggml.h"
|
|
2
|
+
#include "llama.h"
|
|
3
|
+
#include "common.h"
|
|
4
|
+
#include "ngram-cache.h"
|
|
5
|
+
|
|
6
|
+
#include <cstdint>
|
|
7
|
+
#include <cstdio>
|
|
8
|
+
#include <fstream>
|
|
9
|
+
#include <iostream>
|
|
10
|
+
#include <string>
|
|
11
|
+
#include <unordered_map>
|
|
12
|
+
#include <vector>
|
|
13
|
+
|
|
14
|
+
// Writes the tool description and its command-line synopsis to stderr.
static void print_usage() {
    static const char * const usage_lines[] = {
        "Merges multiple lookup cache files into a single one.\n",
        "Usage: lookup-merge [--help] lookup_part_1.bin lookup_part_2.bin ... lookup_merged.bin\n",
    };

    for (const char * line : usage_lines) {
        fputs(line, stderr);
    }
}
|
|
18
|
+
|
|
19
|
+
int main(int argc, char ** argv){
|
|
20
|
+
if (argc < 3) {
|
|
21
|
+
print_usage();
|
|
22
|
+
exit(1);
|
|
23
|
+
}
|
|
24
|
+
|
|
25
|
+
std::vector<std::string> args;
|
|
26
|
+
args.resize(argc-1);
|
|
27
|
+
for (int i = 0; i < argc-1; ++i) {
|
|
28
|
+
args[i] = argv[i+1];
|
|
29
|
+
if (args[i] == "-h" || args[i] == "--help") {
|
|
30
|
+
print_usage();
|
|
31
|
+
exit(0);
|
|
32
|
+
}
|
|
33
|
+
}
|
|
34
|
+
|
|
35
|
+
fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
|
|
36
|
+
llama_ngram_cache ngram_cache_merged = llama_ngram_cache_load(args[0]);
|
|
37
|
+
|
|
38
|
+
for (size_t i = 1; i < args.size()-1; ++i) {
|
|
39
|
+
fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
|
|
40
|
+
llama_ngram_cache ngram_cache = llama_ngram_cache_load(args[i]);
|
|
41
|
+
|
|
42
|
+
llama_ngram_cache_merge(ngram_cache_merged, ngram_cache);
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
fprintf(stderr, "lookup-merge: saving file %s\n", args.back().c_str());
|
|
46
|
+
llama_ngram_cache_save(ngram_cache_merged, args.back());
|
|
47
|
+
}
|