@fugood/llama.node 0.0.1-alpha.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +85 -0
- package/README.md +56 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/lib/binding.js +13 -0
- package/lib/binding.ts +57 -0
- package/lib/index.js +24 -0
- package/lib/index.ts +13 -0
- package/package.json +65 -0
- package/src/addons.cpp +506 -0
- package/src/llama.cpp/CMakeLists.txt +1320 -0
- package/src/llama.cpp/build.zig +172 -0
- package/src/llama.cpp/cmake/FindSIMD.cmake +100 -0
- package/src/llama.cpp/common/CMakeLists.txt +87 -0
- package/src/llama.cpp/common/base64.hpp +392 -0
- package/src/llama.cpp/common/common.cpp +2949 -0
- package/src/llama.cpp/common/common.h +324 -0
- package/src/llama.cpp/common/console.cpp +501 -0
- package/src/llama.cpp/common/console.h +19 -0
- package/src/llama.cpp/common/grammar-parser.cpp +440 -0
- package/src/llama.cpp/common/grammar-parser.h +29 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +764 -0
- package/src/llama.cpp/common/json-schema-to-grammar.h +4 -0
- package/src/llama.cpp/common/json.hpp +24766 -0
- package/src/llama.cpp/common/log.h +724 -0
- package/src/llama.cpp/common/ngram-cache.cpp +282 -0
- package/src/llama.cpp/common/ngram-cache.h +94 -0
- package/src/llama.cpp/common/sampling.cpp +353 -0
- package/src/llama.cpp/common/sampling.h +147 -0
- package/src/llama.cpp/common/stb_image.h +8396 -0
- package/src/llama.cpp/common/train.cpp +1513 -0
- package/src/llama.cpp/common/train.h +233 -0
- package/src/llama.cpp/examples/CMakeLists.txt +52 -0
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +1640 -0
- package/src/llama.cpp/examples/batched/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/batched/batched.cpp +262 -0
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +261 -0
- package/src/llama.cpp/examples/beam-search/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/beam-search/beam-search.cpp +188 -0
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +6 -0
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +275 -0
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +936 -0
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/embedding/embedding.cpp +211 -0
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +9 -0
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +195 -0
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +462 -0
- package/src/llama.cpp/examples/finetune/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/finetune/finetune.cpp +1861 -0
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +132 -0
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gguf/gguf.cpp +256 -0
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +553 -0
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +215 -0
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +655 -0
- package/src/llama.cpp/examples/infill/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/infill/infill.cpp +767 -0
- package/src/llama.cpp/examples/jeopardy/questions.txt +100 -0
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +1286 -0
- package/src/llama.cpp/examples/llama.android/app/src/main/cpp/CMakeLists.txt +50 -0
- package/src/llama.cpp/examples/llama.android/app/src/main/cpp/llama-android.cpp +443 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +37 -0
- package/src/llama.cpp/examples/llava/clip.cpp +2027 -0
- package/src/llama.cpp/examples/llava/clip.h +85 -0
- package/src/llama.cpp/examples/llava/llava-cli.cpp +309 -0
- package/src/llama.cpp/examples/llava/llava.cpp +426 -0
- package/src/llama.cpp/examples/llava/llava.h +50 -0
- package/src/llama.cpp/examples/llava/requirements.txt +3 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +485 -0
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +23 -0
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +41 -0
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +47 -0
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +160 -0
- package/src/llama.cpp/examples/lookup/lookup.cpp +258 -0
- package/src/llama.cpp/examples/main/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/main/main.cpp +957 -0
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +33 -0
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/parallel/parallel.cpp +427 -0
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/passkey/passkey.cpp +302 -0
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +1943 -0
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +6 -0
- package/src/llama.cpp/examples/quantize/quantize.cpp +423 -0
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +6 -0
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +424 -0
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +350 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +246 -0
- package/src/llama.cpp/examples/server/CMakeLists.txt +40 -0
- package/src/llama.cpp/examples/server/bench/requirements.txt +2 -0
- package/src/llama.cpp/examples/server/httplib.h +9465 -0
- package/src/llama.cpp/examples/server/server.cpp +3826 -0
- package/src/llama.cpp/examples/server/tests/requirements.txt +6 -0
- package/src/llama.cpp/examples/server/utils.hpp +653 -0
- package/src/llama.cpp/examples/simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple/simple.cpp +183 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +614 -0
- package/src/llama.cpp/examples/sycl/CMakeLists.txt +9 -0
- package/src/llama.cpp/examples/sycl/ls-sycl-device.cpp +13 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +42 -0
- package/src/llama.cpp/examples/train-text-from-scratch/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/train-text-from-scratch/train-text-from-scratch.cpp +1252 -0
- package/src/llama.cpp/ggml-alloc.c +985 -0
- package/src/llama.cpp/ggml-alloc.h +76 -0
- package/src/llama.cpp/ggml-backend-impl.h +141 -0
- package/src/llama.cpp/ggml-backend.c +2099 -0
- package/src/llama.cpp/ggml-backend.h +233 -0
- package/src/llama.cpp/ggml-common.h +1853 -0
- package/src/llama.cpp/ggml-cuda.h +43 -0
- package/src/llama.cpp/ggml-impl.h +265 -0
- package/src/llama.cpp/ggml-kompute.cpp +2006 -0
- package/src/llama.cpp/ggml-kompute.h +46 -0
- package/src/llama.cpp/ggml-metal.h +66 -0
- package/src/llama.cpp/ggml-mpi.c +216 -0
- package/src/llama.cpp/ggml-mpi.h +39 -0
- package/src/llama.cpp/ggml-opencl.cpp +2301 -0
- package/src/llama.cpp/ggml-opencl.h +36 -0
- package/src/llama.cpp/ggml-quants.c +12678 -0
- package/src/llama.cpp/ggml-quants.h +133 -0
- package/src/llama.cpp/ggml-sycl.cpp +17882 -0
- package/src/llama.cpp/ggml-sycl.h +49 -0
- package/src/llama.cpp/ggml-vulkan-shaders.hpp +69849 -0
- package/src/llama.cpp/ggml-vulkan.cpp +6442 -0
- package/src/llama.cpp/ggml-vulkan.h +29 -0
- package/src/llama.cpp/ggml.c +21819 -0
- package/src/llama.cpp/ggml.h +2403 -0
- package/src/llama.cpp/llama.cpp +17468 -0
- package/src/llama.cpp/llama.h +1117 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +12 -0
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +9 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +172 -0
- package/src/llama.cpp/pocs/vdot/vdot.cpp +310 -0
- package/src/llama.cpp/prompts/LLM-questions.txt +49 -0
- package/src/llama.cpp/prompts/alpaca.txt +1 -0
- package/src/llama.cpp/prompts/assistant.txt +31 -0
- package/src/llama.cpp/prompts/chat-with-baichuan.txt +4 -0
- package/src/llama.cpp/prompts/chat-with-bob.txt +7 -0
- package/src/llama.cpp/prompts/chat-with-qwen.txt +1 -0
- package/src/llama.cpp/prompts/chat-with-vicuna-v0.txt +7 -0
- package/src/llama.cpp/prompts/chat-with-vicuna-v1.txt +7 -0
- package/src/llama.cpp/prompts/chat.txt +28 -0
- package/src/llama.cpp/prompts/dan-modified.txt +1 -0
- package/src/llama.cpp/prompts/dan.txt +1 -0
- package/src/llama.cpp/prompts/mnemonics.txt +93 -0
- package/src/llama.cpp/prompts/parallel-questions.txt +43 -0
- package/src/llama.cpp/prompts/reason-act.txt +18 -0
- package/src/llama.cpp/requirements/requirements-convert-hf-to-gguf.txt +3 -0
- package/src/llama.cpp/requirements/requirements-convert-llama-ggml-to-gguf.txt +1 -0
- package/src/llama.cpp/requirements/requirements-convert-lora-to-ggml.txt +2 -0
- package/src/llama.cpp/requirements/requirements-convert-persimmon-to-gguf.txt +2 -0
- package/src/llama.cpp/requirements/requirements-convert.txt +5 -0
- package/src/llama.cpp/requirements.txt +12 -0
- package/src/llama.cpp/scripts/gen-build-info-cpp.cmake +24 -0
- package/src/llama.cpp/scripts/xxd.cmake +16 -0
- package/src/llama.cpp/sgemm.cpp +999 -0
- package/src/llama.cpp/sgemm.h +12 -0
- package/src/llama.cpp/tests/CMakeLists.txt +78 -0
- package/src/llama.cpp/tests/get-model.cpp +21 -0
- package/src/llama.cpp/tests/get-model.h +2 -0
- package/src/llama.cpp/tests/test-autorelease.cpp +24 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +2266 -0
- package/src/llama.cpp/tests/test-c.c +7 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +107 -0
- package/src/llama.cpp/tests/test-double-float.cpp +57 -0
- package/src/llama.cpp/tests/test-grad0.cpp +1606 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +243 -0
- package/src/llama.cpp/tests/test-grammar-parser.cpp +250 -0
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +899 -0
- package/src/llama.cpp/tests/test-llama-grammar.cpp +402 -0
- package/src/llama.cpp/tests/test-model-load-cancel.cpp +27 -0
- package/src/llama.cpp/tests/test-opt.cpp +181 -0
- package/src/llama.cpp/tests/test-quantize-fns.cpp +185 -0
- package/src/llama.cpp/tests/test-quantize-perf.cpp +363 -0
- package/src/llama.cpp/tests/test-rope.cpp +221 -0
- package/src/llama.cpp/tests/test-sampling.cpp +301 -0
- package/src/llama.cpp/tests/test-tokenizer-0-falcon.cpp +187 -0
- package/src/llama.cpp/tests/test-tokenizer-0-llama.cpp +190 -0
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +123 -0
- package/src/llama.cpp/tests/test-tokenizer-1-llama.cpp +111 -0
- package/src/llama.cpp/unicode-data.cpp +1651 -0
- package/src/llama.cpp/unicode-data.h +16 -0
- package/src/llama.cpp/unicode.cpp +277 -0
- package/src/llama.cpp/unicode.h +28 -0
|
@@ -0,0 +1,614 @@
|
|
|
1
|
+
#include "common.h"
|
|
2
|
+
#include "llama.h"
|
|
3
|
+
|
|
4
|
+
#include <cmath>
|
|
5
|
+
#include <cstdio>
|
|
6
|
+
#include <string>
|
|
7
|
+
#include <vector>
|
|
8
|
+
#include <set>
|
|
9
|
+
|
|
10
|
+
// Largest allowed absolute difference between the target and draft vocabulary
// sizes; if the sizes differ by more than this, speculation is refused.
constexpr int SPEC_VOCAB_MAX_SIZE_DIFFERENCE = 100;

// First token id whose text is compared between the target and draft vocabularies.
// Lower ids are not compared (presumably special/control tokens that may
// legitimately differ between models — TODO confirm).
constexpr int SPEC_VOCAB_CHECK_START_TOKEN_ID = 5;
|
|
12
|
+
|
|
13
|
+
struct seq_draft {
|
|
14
|
+
bool active = false;
|
|
15
|
+
bool drafting = false;
|
|
16
|
+
bool skip = false;
|
|
17
|
+
|
|
18
|
+
int i_batch_dft = 0;
|
|
19
|
+
std::vector<int> i_batch_tgt;
|
|
20
|
+
|
|
21
|
+
std::vector<llama_token> tokens;
|
|
22
|
+
std::vector<std::vector<llama_token_data>> dists;
|
|
23
|
+
|
|
24
|
+
struct llama_sampling_context * ctx_sampling;
|
|
25
|
+
};
|
|
26
|
+
|
|
27
|
+
int main(int argc, char ** argv) {
|
|
28
|
+
gpt_params params;
|
|
29
|
+
|
|
30
|
+
if (gpt_params_parse(argc, argv, params) == false) {
|
|
31
|
+
return 1;
|
|
32
|
+
}
|
|
33
|
+
|
|
34
|
+
if (params.model_draft.empty()) {
|
|
35
|
+
fprintf(stderr, "%s: error: --model-draft is required\n", __func__);
|
|
36
|
+
return 1;
|
|
37
|
+
}
|
|
38
|
+
|
|
39
|
+
// max number of parallel drafting sequences (i.e. tree branches)
|
|
40
|
+
const int n_seq_dft = params.n_parallel;
|
|
41
|
+
|
|
42
|
+
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
|
|
43
|
+
const float p_split = params.p_split;
|
|
44
|
+
|
|
45
|
+
if (params.seed == LLAMA_DEFAULT_SEED) {
|
|
46
|
+
params.seed = time(NULL);
|
|
47
|
+
}
|
|
48
|
+
std::default_random_engine rng(params.seed);
|
|
49
|
+
std::uniform_real_distribution<> u_dist;
|
|
50
|
+
|
|
51
|
+
#ifndef LOG_DISABLE_LOGS
|
|
52
|
+
log_set_target(log_filename_generator("speculative", "log"));
|
|
53
|
+
LOG_TEE("Log start\n");
|
|
54
|
+
log_dump_cmdline(argc, argv);
|
|
55
|
+
#endif // LOG_DISABLE_LOGS
|
|
56
|
+
|
|
57
|
+
// init llama.cpp
|
|
58
|
+
llama_backend_init();
|
|
59
|
+
llama_numa_init(params.numa);
|
|
60
|
+
|
|
61
|
+
llama_model * model_tgt = NULL;
|
|
62
|
+
llama_model * model_dft = NULL;
|
|
63
|
+
|
|
64
|
+
llama_context * ctx_tgt = NULL;
|
|
65
|
+
llama_context * ctx_dft = NULL;
|
|
66
|
+
|
|
67
|
+
// load the target model
|
|
68
|
+
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
|
|
69
|
+
|
|
70
|
+
// load the draft model
|
|
71
|
+
params.model = params.model_draft;
|
|
72
|
+
params.n_gpu_layers = params.n_gpu_layers_draft;
|
|
73
|
+
if (params.n_threads_draft > 0) {
|
|
74
|
+
params.n_threads = params.n_threads_draft;
|
|
75
|
+
}
|
|
76
|
+
params.n_threads_batch = params.n_threads_batch_draft;
|
|
77
|
+
std::tie(model_dft, ctx_dft) = llama_init_from_gpt_params(params);
|
|
78
|
+
|
|
79
|
+
const bool vocab_type_tgt = llama_vocab_type(model_tgt);
|
|
80
|
+
LOG("vocab_type tgt: %d\n", vocab_type_tgt);
|
|
81
|
+
|
|
82
|
+
const bool vocab_type_dft = llama_vocab_type(model_dft);
|
|
83
|
+
LOG("vocab_type dft: %d\n", vocab_type_dft);
|
|
84
|
+
|
|
85
|
+
if (vocab_type_tgt != vocab_type_dft) {
|
|
86
|
+
fprintf(stderr, "%s: error: draft model vocab type must match target model to use speculation but ", __func__);
|
|
87
|
+
fprintf(stderr, "vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
|
|
88
|
+
return 1;
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
if (
|
|
92
|
+
llama_add_bos_token(model_tgt) != llama_add_bos_token(model_dft) ||
|
|
93
|
+
llama_add_eos_token(model_tgt) != llama_add_eos_token(model_dft) ||
|
|
94
|
+
llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
|
|
95
|
+
llama_token_eos(model_tgt) != llama_token_eos(model_dft)
|
|
96
|
+
) {
|
|
97
|
+
fprintf(stderr, "%s: error: draft model special tokens must match target model to use speculation\n", __func__);
|
|
98
|
+
return 1;
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
{
|
|
102
|
+
const int n_vocab_tgt = llama_n_vocab(model_tgt);
|
|
103
|
+
const int n_vocab_dft = llama_n_vocab(model_dft);
|
|
104
|
+
const int vocab_diff = n_vocab_tgt > n_vocab_dft
|
|
105
|
+
? n_vocab_tgt - n_vocab_dft
|
|
106
|
+
: n_vocab_dft - n_vocab_tgt;
|
|
107
|
+
|
|
108
|
+
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
|
|
109
|
+
fprintf(stderr, "%s: error: draft model vocab must closely match target model to use speculation but ", __func__);
|
|
110
|
+
fprintf(stderr, "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
|
|
111
|
+
n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
|
|
112
|
+
return 1;
|
|
113
|
+
}
|
|
114
|
+
|
|
115
|
+
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
|
|
116
|
+
const char * token_text_tgt = llama_token_get_text(model_tgt, i);
|
|
117
|
+
const char * token_text_dft = llama_token_get_text(model_dft, i);
|
|
118
|
+
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
|
|
119
|
+
fprintf(stderr, "%s: error: draft model vocab must match target model to use speculation but ", __func__);
|
|
120
|
+
fprintf(stderr, "token %d content differs - target '%s', draft '%s'\n", i,
|
|
121
|
+
llama_token_to_piece(ctx_tgt, i).c_str(),
|
|
122
|
+
llama_token_to_piece(ctx_dft, i).c_str());
|
|
123
|
+
return 1;
|
|
124
|
+
}
|
|
125
|
+
}
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
// Tokenize the prompt
|
|
130
|
+
std::vector<llama_token> inp;
|
|
131
|
+
inp = ::llama_tokenize(ctx_tgt, params.prompt, true, true);
|
|
132
|
+
|
|
133
|
+
const int max_context_size = llama_n_ctx(ctx_tgt);
|
|
134
|
+
const int max_tokens_list_size = max_context_size - 4;
|
|
135
|
+
|
|
136
|
+
if ((int) inp.size() > max_tokens_list_size) {
|
|
137
|
+
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
|
|
138
|
+
return 1;
|
|
139
|
+
}
|
|
140
|
+
|
|
141
|
+
fprintf(stderr, "\n\n");
|
|
142
|
+
|
|
143
|
+
for (auto id : inp) {
|
|
144
|
+
fprintf(stderr, "%s", llama_token_to_piece(ctx_tgt, id).c_str());
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
fflush(stderr);
|
|
148
|
+
|
|
149
|
+
const int n_input = inp.size();
|
|
150
|
+
|
|
151
|
+
const auto t_enc_start = ggml_time_us();
|
|
152
|
+
|
|
153
|
+
// eval the prompt with both models
|
|
154
|
+
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
|
|
155
|
+
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
|
|
156
|
+
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0));
|
|
157
|
+
|
|
158
|
+
const auto t_enc_end = ggml_time_us();
|
|
159
|
+
|
|
160
|
+
// the 2 models should have the same vocab
|
|
161
|
+
//GGML_ASSERT(n_vocab == llama_n_vocab(model_dft));
|
|
162
|
+
|
|
163
|
+
// how many tokens to draft each time
|
|
164
|
+
int n_draft = params.n_draft;
|
|
165
|
+
|
|
166
|
+
int n_predict = 0;
|
|
167
|
+
int n_drafted = 0;
|
|
168
|
+
int n_accept = 0;
|
|
169
|
+
|
|
170
|
+
int n_past_tgt = inp.size();
|
|
171
|
+
int n_past_dft = inp.size();
|
|
172
|
+
|
|
173
|
+
// used to determine end of generation
|
|
174
|
+
bool has_eos = false;
|
|
175
|
+
|
|
176
|
+
// target model sampling context
|
|
177
|
+
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
|
178
|
+
|
|
179
|
+
// draft sequence data
|
|
180
|
+
std::vector<seq_draft> drafts(n_seq_dft);
|
|
181
|
+
|
|
182
|
+
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
|
183
|
+
if (params.sparams.temp == 0) {
|
|
184
|
+
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
for (int s = 0; s < n_seq_dft; ++s) {
|
|
188
|
+
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
|
189
|
+
}
|
|
190
|
+
|
|
191
|
+
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
|
|
192
|
+
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, n_seq_dft);
|
|
193
|
+
|
|
194
|
+
const auto t_dec_start = ggml_time_us();
|
|
195
|
+
|
|
196
|
+
// sample from the last token of the prompt
|
|
197
|
+
drafts[0].i_batch_tgt.resize(1);
|
|
198
|
+
drafts[0].i_batch_tgt[0] = 0;
|
|
199
|
+
|
|
200
|
+
while (true) {
|
|
201
|
+
std::set<int> active_seqs = {};
|
|
202
|
+
|
|
203
|
+
// print current draft sequences
|
|
204
|
+
for (int s = 0; s < n_seq_dft; ++s) {
|
|
205
|
+
if (!drafts[s].active) {
|
|
206
|
+
continue;
|
|
207
|
+
}
|
|
208
|
+
|
|
209
|
+
active_seqs.insert(s);
|
|
210
|
+
const auto & tokens = drafts[s].tokens;
|
|
211
|
+
|
|
212
|
+
LOG("draft %d: %s\n", s, LOG_TOKENS_TOSTR_PRETTY(ctx_dft, tokens).c_str());
|
|
213
|
+
}
|
|
214
|
+
|
|
215
|
+
int i_dft = 0;
|
|
216
|
+
int s_keep = 0;
|
|
217
|
+
|
|
218
|
+
llama_token token_id;
|
|
219
|
+
std::string token_str;
|
|
220
|
+
|
|
221
|
+
// loop until we fail to accept a drafted token or we run out of drafted tokens
|
|
222
|
+
while (true) {
|
|
223
|
+
|
|
224
|
+
// check if the target token matches any of the drafts
|
|
225
|
+
// for stochastic sampling, attempt to match the token with the drafted tokens
|
|
226
|
+
{
|
|
227
|
+
bool accept = false;
|
|
228
|
+
if (params.sparams.temp > 0) {
|
|
229
|
+
// stochastic verification
|
|
230
|
+
|
|
231
|
+
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
|
|
232
|
+
llama_sample_softmax(ctx_tgt, &dist_tgt);
|
|
233
|
+
float p_tgt = 0, p_dft = 0;
|
|
234
|
+
|
|
235
|
+
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
|
236
|
+
|
|
237
|
+
while (active_seqs.size() > 0) {
|
|
238
|
+
// randomly select a sequence to verify from active sequences
|
|
239
|
+
std::uniform_int_distribution<unsigned int> u_int_dist(0, active_seqs.size() - 1);
|
|
240
|
+
int s = *std::next(active_seqs.begin(), u_int_dist(rng));
|
|
241
|
+
if (i_dft >= (int) drafts[s].tokens.size()) {
|
|
242
|
+
drafts[s].active = false;
|
|
243
|
+
active_seqs.erase(s);
|
|
244
|
+
continue;
|
|
245
|
+
}
|
|
246
|
+
if (accept) {
|
|
247
|
+
// if we already accepted a token, we can skip the rest
|
|
248
|
+
if (drafts[s].tokens[i_dft] != drafts[s_keep].tokens[i_dft]) {
|
|
249
|
+
drafts[s].active = false;
|
|
250
|
+
active_seqs.erase(s);
|
|
251
|
+
}
|
|
252
|
+
continue;
|
|
253
|
+
}
|
|
254
|
+
LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
|
|
255
|
+
float r = u_dist(rng);
|
|
256
|
+
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
|
|
257
|
+
// acquire the token probabilities assigned by the draft and target models
|
|
258
|
+
for (size_t i = 0; i < dist_tgt.size; i++) {
|
|
259
|
+
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
|
|
260
|
+
p_tgt = dist_tgt.data[i].p;
|
|
261
|
+
}
|
|
262
|
+
if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
|
|
263
|
+
p_dft = dist_dft.data[i].p;
|
|
264
|
+
}
|
|
265
|
+
if (p_tgt && p_dft) {
|
|
266
|
+
break;
|
|
267
|
+
}
|
|
268
|
+
}
|
|
269
|
+
LOG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
|
|
270
|
+
if (r <= p_tgt / p_dft) {
|
|
271
|
+
s_keep = s;
|
|
272
|
+
accept = true;
|
|
273
|
+
token_id = drafts[s].tokens[i_dft];
|
|
274
|
+
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
|
275
|
+
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
|
276
|
+
|
|
277
|
+
LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
|
|
278
|
+
break;
|
|
279
|
+
} else {
|
|
280
|
+
LOG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], llama_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
|
|
281
|
+
drafts[s].active = false;
|
|
282
|
+
|
|
283
|
+
// calculate residual probability
|
|
284
|
+
GGML_ASSERT(dist_tgt.sorted);
|
|
285
|
+
GGML_ASSERT(dist_dft.sorted);
|
|
286
|
+
float sum_probs = 0.0f;
|
|
287
|
+
|
|
288
|
+
// sort dist by id
|
|
289
|
+
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
|
290
|
+
return a.id < b.id;
|
|
291
|
+
});
|
|
292
|
+
std::sort(dist_dft.data, dist_dft.data + dist_dft.size, [](const llama_token_data &a, const llama_token_data &b) {
|
|
293
|
+
return a.id < b.id;
|
|
294
|
+
});
|
|
295
|
+
|
|
296
|
+
for (size_t i = 0; i < dist_tgt.size; i++) {
|
|
297
|
+
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
|
|
298
|
+
sum_probs += dist_tgt.data[i].p;
|
|
299
|
+
}
|
|
300
|
+
for (size_t i = 0; i < dist_tgt.size; i++) {
|
|
301
|
+
dist_tgt.data[i].p /= sum_probs;
|
|
302
|
+
}
|
|
303
|
+
|
|
304
|
+
// sort dist_tgt by p desc
|
|
305
|
+
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
|
306
|
+
return a.p > b.p;
|
|
307
|
+
});
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
active_seqs.erase(s);
|
|
311
|
+
for(int i = 0; i < n_seq_dft; i++) {
|
|
312
|
+
if (i == s) {
|
|
313
|
+
continue;
|
|
314
|
+
}
|
|
315
|
+
if (drafts[i].tokens[i_dft] == drafts[s].tokens[i_dft]) {
|
|
316
|
+
// synchronize active status for sequences with the same drafted token
|
|
317
|
+
drafts[i].active = drafts[i].active && accept;
|
|
318
|
+
if (!drafts[i].active) {
|
|
319
|
+
active_seqs.erase(s);
|
|
320
|
+
}
|
|
321
|
+
}
|
|
322
|
+
}
|
|
323
|
+
}
|
|
324
|
+
|
|
325
|
+
if (!accept) {
|
|
326
|
+
// all drafted tokens were rejected
|
|
327
|
+
// sample from the target model
|
|
328
|
+
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
|
|
329
|
+
token_id = llama_sample_token(ctx_tgt, &dist_tgt);
|
|
330
|
+
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
|
331
|
+
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
|
332
|
+
}
|
|
333
|
+
|
|
334
|
+
} else {
|
|
335
|
+
// greedy verification
|
|
336
|
+
|
|
337
|
+
// sample from the target model
|
|
338
|
+
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
339
|
+
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
|
340
|
+
|
|
341
|
+
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
|
342
|
+
|
|
343
|
+
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
|
344
|
+
|
|
345
|
+
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
|
346
|
+
|
|
347
|
+
for (int s = 0; s < n_seq_dft; ++s) {
|
|
348
|
+
if (!drafts[s].active) {
|
|
349
|
+
continue;
|
|
350
|
+
}
|
|
351
|
+
|
|
352
|
+
if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
|
|
353
|
+
LOG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());
|
|
354
|
+
|
|
355
|
+
s_keep = s;
|
|
356
|
+
accept = true;
|
|
357
|
+
} else {
|
|
358
|
+
drafts[s].active = false;
|
|
359
|
+
}
|
|
360
|
+
}
|
|
361
|
+
}
|
|
362
|
+
|
|
363
|
+
if (llama_token_is_eog(model_tgt, token_id)) {
|
|
364
|
+
has_eos = true;
|
|
365
|
+
}
|
|
366
|
+
++n_predict;
|
|
367
|
+
|
|
368
|
+
if (accept) {
|
|
369
|
+
++n_accept;
|
|
370
|
+
++n_past_tgt;
|
|
371
|
+
++n_past_dft;
|
|
372
|
+
++i_dft;
|
|
373
|
+
if (params.use_color) {
|
|
374
|
+
// Color token according to its origin sequence
|
|
375
|
+
printf("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
|
|
376
|
+
} else {
|
|
377
|
+
printf("%s", token_str.c_str());
|
|
378
|
+
}
|
|
379
|
+
fflush(stdout);
|
|
380
|
+
continue;
|
|
381
|
+
} else {
|
|
382
|
+
printf("%s", token_str.c_str());
|
|
383
|
+
fflush(stdout);
|
|
384
|
+
break;
|
|
385
|
+
}
|
|
386
|
+
}
|
|
387
|
+
}
|
|
388
|
+
|
|
389
|
+
{
|
|
390
|
+
LOG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());
|
|
391
|
+
|
|
392
|
+
// TODO: simplify
|
|
393
|
+
{
|
|
394
|
+
LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);
|
|
395
|
+
|
|
396
|
+
llama_kv_cache_seq_keep(ctx_dft, s_keep);
|
|
397
|
+
llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
|
|
398
|
+
llama_kv_cache_seq_keep(ctx_dft, 0);
|
|
399
|
+
|
|
400
|
+
llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1);
|
|
401
|
+
llama_kv_cache_seq_keep(ctx_tgt, s_keep);
|
|
402
|
+
llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1);
|
|
403
|
+
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
|
404
|
+
}
|
|
405
|
+
|
|
406
|
+
for (int s = 0; s < n_seq_dft; ++s) {
|
|
407
|
+
drafts[s].active = false;
|
|
408
|
+
drafts[s].tokens.clear();
|
|
409
|
+
drafts[s].i_batch_tgt.clear();
|
|
410
|
+
drafts[s].dists.clear();
|
|
411
|
+
}
|
|
412
|
+
// note: will be erased after the speculation phase
|
|
413
|
+
drafts[0].tokens.push_back(token_id);
|
|
414
|
+
drafts[0].dists.push_back(std::vector<llama_token_data>());
|
|
415
|
+
drafts[0].i_batch_tgt.push_back(0);
|
|
416
|
+
|
|
417
|
+
llama_batch_clear(batch_dft);
|
|
418
|
+
llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);
|
|
419
|
+
|
|
420
|
+
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
|
421
|
+
// LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
|
|
422
|
+
llama_decode(ctx_dft, batch_dft);
|
|
423
|
+
|
|
424
|
+
++n_past_dft;
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
if (n_predict > params.n_predict || has_eos) {
|
|
428
|
+
break;
|
|
429
|
+
}
|
|
430
|
+
|
|
431
|
+
llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
|
|
432
|
+
|
|
433
|
+
int n_seq_cur = 1;
|
|
434
|
+
int n_past_cur = n_past_dft;
|
|
435
|
+
|
|
436
|
+
for (int s = 0; s < n_seq_dft; ++s) {
|
|
437
|
+
drafts[s].active = false;
|
|
438
|
+
drafts[s].drafting = false;
|
|
439
|
+
}
|
|
440
|
+
drafts[0].active = true;
|
|
441
|
+
drafts[0].drafting = true;
|
|
442
|
+
drafts[0].i_batch_dft = 0;
|
|
443
|
+
|
|
444
|
+
llama_batch_clear(batch_tgt);
|
|
445
|
+
llama_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);
|
|
446
|
+
|
|
447
|
+
// sample n_draft tokens from the draft model using tree-based sampling
|
|
448
|
+
for (int i = 0; i < n_draft; ++i) {
|
|
449
|
+
batch_dft.n_tokens = 0;
|
|
450
|
+
|
|
451
|
+
for (int s = 0; s < n_seq_dft; ++s) {
|
|
452
|
+
drafts[s].skip = false;
|
|
453
|
+
}
|
|
454
|
+
|
|
455
|
+
for (int s = 0; s < n_seq_dft; ++s) {
|
|
456
|
+
if (!drafts[s].drafting || drafts[s].skip) {
|
|
457
|
+
continue;
|
|
458
|
+
}
|
|
459
|
+
|
|
460
|
+
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
|
|
461
|
+
|
|
462
|
+
const auto & cur_p = drafts[s].ctx_sampling->cur;
|
|
463
|
+
|
|
464
|
+
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
|
|
465
|
+
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
|
466
|
+
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
|
|
467
|
+
}
|
|
468
|
+
|
|
469
|
+
std::vector<int> sa(1, s);
|
|
470
|
+
|
|
471
|
+
// attempt to split the branch if the probability is high enough
|
|
472
|
+
for (int f = 1; f < 8; ++f) {
|
|
473
|
+
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
|
|
474
|
+
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
|
475
|
+
|
|
476
|
+
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
|
477
|
+
llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
|
|
478
|
+
|
|
479
|
+
// all previous tokens from this branch are now also part of the new branch
|
|
480
|
+
for (int t = 0; t < batch_tgt.n_tokens; ++t) {
|
|
481
|
+
for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) {
|
|
482
|
+
if (batch_tgt.seq_id[t][p] == s) {
|
|
483
|
+
batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur;
|
|
484
|
+
batch_tgt.n_seq_id[t]++;
|
|
485
|
+
break;
|
|
486
|
+
}
|
|
487
|
+
}
|
|
488
|
+
}
|
|
489
|
+
|
|
490
|
+
// copy the draft state
|
|
491
|
+
drafts[n_seq_cur].active = true;
|
|
492
|
+
drafts[n_seq_cur].drafting = true;
|
|
493
|
+
drafts[n_seq_cur].skip = true;
|
|
494
|
+
|
|
495
|
+
drafts[n_seq_cur].tokens = drafts[s].tokens;
|
|
496
|
+
drafts[n_seq_cur].dists = drafts[s].dists;
|
|
497
|
+
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
|
|
498
|
+
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
|
|
499
|
+
|
|
500
|
+
llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
|
|
501
|
+
|
|
502
|
+
sa.push_back(n_seq_cur);
|
|
503
|
+
|
|
504
|
+
n_seq_cur++;
|
|
505
|
+
} else {
|
|
506
|
+
break;
|
|
507
|
+
}
|
|
508
|
+
}
|
|
509
|
+
|
|
510
|
+
// add drafted token for each sequence
|
|
511
|
+
for (int is = 0; is < (int) sa.size(); ++is) {
|
|
512
|
+
const llama_token id = cur_p[is].id;
|
|
513
|
+
|
|
514
|
+
const int s = sa[is];
|
|
515
|
+
|
|
516
|
+
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
|
|
517
|
+
|
|
518
|
+
drafts[s].tokens.push_back(id);
|
|
519
|
+
// save cur_p.data into drafts[s].dists
|
|
520
|
+
drafts[s].dists.push_back(cur_p);
|
|
521
|
+
|
|
522
|
+
// add unique drafted tokens to the target batch
|
|
523
|
+
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
|
524
|
+
|
|
525
|
+
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
|
|
526
|
+
|
|
527
|
+
// add the token to the batch for batched decoding with the draft model
|
|
528
|
+
drafts[s].i_batch_dft = batch_dft.n_tokens;
|
|
529
|
+
|
|
530
|
+
llama_batch_add(batch_dft, id, n_past_cur, { s }, true);
|
|
531
|
+
|
|
532
|
+
if (batch_tgt.n_tokens > n_draft) {
|
|
533
|
+
drafts[s].drafting = false;
|
|
534
|
+
}
|
|
535
|
+
}
|
|
536
|
+
}
|
|
537
|
+
|
|
538
|
+
// no sequence is drafting anymore
|
|
539
|
+
if (batch_dft.n_tokens == 0) {
|
|
540
|
+
break;
|
|
541
|
+
}
|
|
542
|
+
|
|
543
|
+
// evaluate the drafted tokens on the draft model
|
|
544
|
+
llama_decode(ctx_dft, batch_dft);
|
|
545
|
+
++n_past_cur;
|
|
546
|
+
++n_drafted;
|
|
547
|
+
|
|
548
|
+
if (batch_tgt.n_tokens > n_draft) {
|
|
549
|
+
break;
|
|
550
|
+
}
|
|
551
|
+
}
|
|
552
|
+
|
|
553
|
+
// evaluate the target model on the drafted tokens
|
|
554
|
+
{
|
|
555
|
+
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
|
556
|
+
for (int s = 1; s < n_seq_dft; ++s) {
|
|
557
|
+
llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
|
|
558
|
+
}
|
|
559
|
+
|
|
560
|
+
// LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
|
|
561
|
+
llama_decode(ctx_tgt, batch_tgt);
|
|
562
|
+
++n_past_tgt;
|
|
563
|
+
}
|
|
564
|
+
|
|
565
|
+
// the first token is always proposed by the target model before the speculation loop so we erase it here
|
|
566
|
+
for (int s = 0; s < n_seq_dft; ++s) {
|
|
567
|
+
if (!drafts[s].active) {
|
|
568
|
+
continue;
|
|
569
|
+
}
|
|
570
|
+
|
|
571
|
+
drafts[s].tokens.erase(drafts[s].tokens.begin());
|
|
572
|
+
drafts[s].dists.erase(drafts[s].dists.begin());
|
|
573
|
+
}
|
|
574
|
+
}
|
|
575
|
+
|
|
576
|
+
auto t_dec_end = ggml_time_us();
|
|
577
|
+
|
|
578
|
+
LOG_TEE("\n\n");
|
|
579
|
+
|
|
580
|
+
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
|
|
581
|
+
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
|
|
582
|
+
|
|
583
|
+
LOG_TEE("\n");
|
|
584
|
+
LOG_TEE("n_draft = %d\n", n_draft);
|
|
585
|
+
LOG_TEE("n_predict = %d\n", n_predict);
|
|
586
|
+
LOG_TEE("n_drafted = %d\n", n_drafted);
|
|
587
|
+
LOG_TEE("n_accept = %d\n", n_accept);
|
|
588
|
+
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
|
|
589
|
+
|
|
590
|
+
LOG_TEE("\ndraft:\n");
|
|
591
|
+
llama_print_timings(ctx_dft);
|
|
592
|
+
|
|
593
|
+
LOG_TEE("\ntarget:\n");
|
|
594
|
+
llama_print_timings(ctx_tgt);
|
|
595
|
+
|
|
596
|
+
llama_sampling_free(ctx_sampling);
|
|
597
|
+
for (int s = 0; s < n_seq_dft; ++s) {
|
|
598
|
+
llama_sampling_free(drafts[s].ctx_sampling);
|
|
599
|
+
}
|
|
600
|
+
|
|
601
|
+
llama_batch_free(batch_dft);
|
|
602
|
+
|
|
603
|
+
llama_free(ctx_tgt);
|
|
604
|
+
llama_free_model(model_tgt);
|
|
605
|
+
|
|
606
|
+
llama_free(ctx_dft);
|
|
607
|
+
llama_free_model(model_dft);
|
|
608
|
+
|
|
609
|
+
llama_backend_free();
|
|
610
|
+
|
|
611
|
+
fprintf(stderr, "\n\n");
|
|
612
|
+
|
|
613
|
+
return 0;
|
|
614
|
+
}
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
# MIT license
|
|
2
|
+
# Copyright (C) 2024 Intel Corporation
|
|
3
|
+
# SPDX-License-Identifier: MIT
|
|
4
|
+
|
|
5
|
+
# Build the ls-sycl-device executable from ls-sycl-device.cpp and install it.
set(TARGET ls-sycl-device)

add_executable(${TARGET} ls-sycl-device.cpp)

# the tool is compiled as C++17
target_compile_features(${TARGET} PRIVATE cxx_std_17)

target_link_libraries(${TARGET}
    PRIVATE
        common
        llama
        ${CMAKE_THREAD_LIBS_INIT}
)

install(TARGETS ${TARGET} RUNTIME)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
#include "common.h"
|
|
2
|
+
#include "llama.h"
|
|
3
|
+
|
|
4
|
+
#include <cmath>
|
|
5
|
+
#include <cstdio>
|
|
6
|
+
#include <string>
|
|
7
|
+
#include <vector>
|
|
8
|
+
|
|
9
|
+
int main(int argc, char ** argv) {
|
|
10
|
+
if (argc < 3 || argv[1][0] == '-') {
|
|
11
|
+
printf("usage: %s MODEL_PATH PROMPT [--ids]\n" , argv[0]);
|
|
12
|
+
return 1;
|
|
13
|
+
}
|
|
14
|
+
|
|
15
|
+
const char * model_path = argv[1];
|
|
16
|
+
const char * prompt = argv[2];
|
|
17
|
+
|
|
18
|
+
const bool printing_ids = argc > 3 && std::string(argv[3]) == "--ids";
|
|
19
|
+
|
|
20
|
+
llama_backend_init();
|
|
21
|
+
|
|
22
|
+
llama_model_params model_params = llama_model_default_params();
|
|
23
|
+
model_params.vocab_only = true;
|
|
24
|
+
llama_model * model = llama_load_model_from_file(model_path, model_params);
|
|
25
|
+
|
|
26
|
+
llama_context_params ctx_params = llama_context_default_params();
|
|
27
|
+
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
|
28
|
+
|
|
29
|
+
std::vector<llama_token> tokens;
|
|
30
|
+
|
|
31
|
+
tokens = ::llama_tokenize(model, prompt, true, true);
|
|
32
|
+
|
|
33
|
+
for (int i = 0; i < (int) tokens.size(); i++) {
|
|
34
|
+
if (printing_ids) {
|
|
35
|
+
printf("%d\n", tokens[i]);
|
|
36
|
+
} else {
|
|
37
|
+
printf("%6d -> '%s'\n", tokens[i], llama_token_to_piece(ctx, tokens[i]).c_str());
|
|
38
|
+
}
|
|
39
|
+
}
|
|
40
|
+
|
|
41
|
+
return 0;
|
|
42
|
+
}
|