@fugood/llama.node 0.3.2 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +2 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +1 -1
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +8 -8
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +8 -9
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +4 -4
- package/src/llama.cpp/.github/workflows/build.yml +43 -9
- package/src/llama.cpp/.github/workflows/docker.yml +3 -0
- package/src/llama.cpp/CMakeLists.txt +7 -4
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +0 -2
- package/src/llama.cpp/common/arg.cpp +642 -607
- package/src/llama.cpp/common/arg.h +22 -22
- package/src/llama.cpp/common/common.cpp +79 -281
- package/src/llama.cpp/common/common.h +130 -100
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +50 -50
- package/src/llama.cpp/common/log.h +18 -18
- package/src/llama.cpp/common/ngram-cache.cpp +36 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +116 -108
- package/src/llama.cpp/common/sampling.h +20 -20
- package/src/llama.cpp/docs/build.md +37 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +14 -14
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +10 -11
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +1 -1
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +9 -9
- package/src/llama.cpp/examples/embedding/embedding.cpp +12 -12
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +8 -8
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +5 -5
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +7 -7
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +18 -18
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +20 -11
- package/src/llama.cpp/examples/infill/infill.cpp +40 -86
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +42 -151
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +11 -14
- package/src/llama.cpp/examples/llava/clip.cpp +1 -0
- package/src/llama.cpp/examples/llava/llava-cli.cpp +23 -23
- package/src/llama.cpp/examples/llava/llava.cpp +37 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +21 -21
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +26 -26
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +7 -7
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +14 -14
- package/src/llama.cpp/examples/lookup/lookup.cpp +29 -29
- package/src/llama.cpp/examples/main/main.cpp +64 -109
- package/src/llama.cpp/examples/parallel/parallel.cpp +18 -19
- package/src/llama.cpp/examples/passkey/passkey.cpp +14 -14
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +99 -120
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +10 -9
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +13 -13
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +3 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +34 -17
- package/src/llama.cpp/examples/server/CMakeLists.txt +4 -13
- package/src/llama.cpp/examples/server/server.cpp +553 -691
- package/src/llama.cpp/examples/server/utils.hpp +312 -25
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +128 -96
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +54 -51
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +2 -2
- package/src/llama.cpp/ggml/CMakeLists.txt +15 -9
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +46 -33
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +9 -7
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +12 -12
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +11 -7
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +53 -393
- package/src/llama.cpp/ggml/src/CMakeLists.txt +66 -1149
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +46 -3126
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +23 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +6 -25
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +303 -864
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +213 -65
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +255 -149
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/{ggml-cpu-impl.h → ggml-cpu/ggml-cpu-impl.h} +0 -243
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.cpp +667 -1
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +366 -16
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +238 -72
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +187 -10692
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -125
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +475 -300
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +3 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +40 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +258 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +2 -22
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3584 -4142
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +69 -67
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +3 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +555 -623
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +125 -206
- package/src/llama.cpp/ggml/src/ggml.c +4032 -19890
- package/src/llama.cpp/include/llama.h +67 -33
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-sampling.cpp +745 -105
- package/src/llama.cpp/src/llama-sampling.h +21 -2
- package/src/llama.cpp/src/llama-vocab.cpp +49 -9
- package/src/llama.cpp/src/llama-vocab.h +35 -11
- package/src/llama.cpp/src/llama.cpp +2636 -2406
- package/src/llama.cpp/src/unicode-data.cpp +2 -2
- package/src/llama.cpp/tests/CMakeLists.txt +1 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +14 -14
- package/src/llama.cpp/tests/test-backend-ops.cpp +185 -60
- package/src/llama.cpp/tests/test-barrier.cpp +1 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +17 -4
- package/src/llama.cpp/tests/test-log.cpp +2 -2
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +22 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +1 -0
- package/src/llama.cpp/tests/test-sampling.cpp +162 -137
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +7 -7
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/src/llama.cpp/common/train.cpp +0 -1515
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/CMakeLists.txt +0 -5
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1639
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1683
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
- /package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +0 -0
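Most of the churn in the vendored llama.cpp sources comes from the rename of the shared example helpers from the gpt_* prefix to common_* (common_params, common_params_parse, common_init_from_params, common_sampler_*, common_tokenize, common_token_to_piece, ...), together with a slimmed-down llama_batch_get_one. The excerpts below show this migration in the llava and lookahead/lookup examples. As orientation only, here is a minimal sketch of the renamed front-end flow; it is not code from this package, and the header names, cleanup calls, and the LLAMA_EXAMPLE_COMMON value are assumptions based on the hunks that follow:

```cpp
// Sketch only: mirrors the renamed common_* helpers visible in the diffs below.
// Header names and the teardown calls are assumptions, not taken from this package.
#include <cstdio>
#include <vector>
#include "common.h"    // assumed: common_params, common_init_from_params, ...
#include "sampling.h"  // assumed: common_sampler_*
#include "llama.h"

int main(int argc, char ** argv) {
    common_params params;
    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
        return 1;
    }
    common_init();

    llama_backend_init();
    llama_numa_init(params.numa);

    // model + context built from the parsed parameters
    common_init_result llama_init = common_init_from_params(params);
    llama_model   * model = llama_init.model;
    llama_context * ctx   = llama_init.context;

    // tokenize the prompt and evaluate it with the two-argument llama_batch_get_one
    std::vector<llama_token> inp = common_tokenize(ctx, params.prompt, true, true);
    llama_decode(ctx, llama_batch_get_one(inp.data(), (int32_t) inp.size()));

    // sample one token with the renamed sampler helpers
    struct common_sampler * smpl = common_sampler_init(model, params.sparams);
    const llama_token id = common_sampler_sample(smpl, ctx, -1);
    common_sampler_accept(smpl, id, true);
    printf("%s\n", common_token_to_piece(ctx, id).c_str());

    common_sampler_free(smpl);
    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```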
package/src/llama.cpp/examples/llava/minicpmv-cli.cpp

@@ -25,11 +25,11 @@ static void show_additional_info(int /*argc*/, char ** argv) {
     LOG("\nnote: a lower temperature value like 0.1 is recommended for better quality.\n");
 }

-static struct llama_model * llava_init(
+static struct llama_model * llava_init(common_params * params) {
     llama_backend_init();
     llama_numa_init(params->numa);

-    llama_model_params model_params =
+    llama_model_params model_params = common_model_params_to_llama(*params);

     llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
     if (model == NULL) {
@@ -39,13 +39,13 @@ static struct llama_model * llava_init(gpt_params * params) {
     return model;
 }

-static struct llava_context * llava_init_context(
+static struct llava_context * llava_init_context(common_params * params, llama_model * model) {
     auto prompt = params->prompt;
     if (prompt.empty()) {
         prompt = "describe the image in detail.";
     }

-    llama_context_params ctx_params =
+    llama_context_params ctx_params = common_context_params_to_llama(*params);
     if (params->n_ctx < 2048) {
         // warn user here, "Image processing requires at least 2048 context, setting context to 2048"
         LOG_WRN("%s: Image processing requires at least 2048 context, setting context to 2048\n" , __func__);
@@ -79,7 +79,7 @@ static void llava_free(struct llava_context * ctx_llava) {
     llama_backend_free();
 }

-static struct clip_ctx * clip_init_context(
+static struct clip_ctx * clip_init_context(common_params * params) {
     const char * clip_path = params->mmproj.c_str();

     auto prompt = params->prompt;
@@ -97,7 +97,7 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_toke
         if (n_eval > n_batch) {
             n_eval = n_batch;
         }
-        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval
+        if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) {
             LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
@@ -114,7 +114,7 @@ static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past) {

 static bool eval_string(struct llama_context * ctx_llama, const char* str, int n_batch, int * n_past, bool add_bos){
     std::string str2 = str;
-    std::vector<llama_token> embd_inp =
+    std::vector<llama_token> embd_inp = common_tokenize(ctx_llama, str2, add_bos, true);
     return eval_tokens(ctx_llama, embd_inp, n_batch, n_past);
 }

@@ -129,7 +129,7 @@ static void process_eval_image_embed(struct llava_context * ctx_llava, const str
     llava_image_embed_free(slice_embed);
 }

-static void process_image(struct llava_context * ctx_llava, struct llava_image_embed * embeds,
+static void process_image(struct llava_context * ctx_llava, struct llava_image_embed * embeds, common_params * params, int &n_past) {
     std::string system_prompt;
     int idx = 0;
     int num_image_embeds = embeds->n_image_pos / clip_n_patches(ctx_llava->ctx_clip);
@@ -162,22 +162,22 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
     LOG_INF("%s: image token past: %d\n", __func__, n_past);
 }

-static const char * sample(struct
+static const char * sample(struct common_sampler * smpl,
                            struct llama_context * ctx_llama,
                            int * n_past) {
-    const llama_token id =
-
+    const llama_token id = common_sampler_sample(smpl, ctx_llama, -1);
+    common_sampler_accept(smpl, id, true);
     static std::string ret;
     if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
         ret = "</s>";
     } else {
-        ret =
+        ret = common_token_to_piece(ctx_llama, id);
     }
     eval_id(ctx_llama, id, n_past);
     return ret.c_str();
 }

-static struct llava_context * minicpmv_init(
+static struct llava_context * minicpmv_init(common_params * params, const std::string & fname, int &n_past){
     auto * ctx_clip = clip_init_context(params);
     auto * embeds = llava_image_embed_make_with_filename(ctx_clip, params->cpuparams.n_threads, fname.c_str());
     if (!embeds) {
@@ -213,7 +213,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
     return ctx_llava;
 }

-static struct
+static struct common_sampler * llama_init(struct llava_context * ctx_llava, common_params * params, const std::string & prompt, int & n_past, bool is_first = false){
     std::string user_prompt = prompt;
     int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
     if (!is_first) {
@@ -237,11 +237,11 @@ static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_par

     LOG_INF("\n");

-    struct
+    struct common_sampler * smpl = common_sampler_init(ctx_llava->model, params->sparams);
     return smpl;
 }

-static const char * llama_loop(struct llava_context * ctx_llava,struct
+static const char * llama_loop(struct llava_context * ctx_llava,struct common_sampler * smpl, int &n_past){

     const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
     return tmp;
@@ -250,13 +250,13 @@ static const char * llama_loop(struct llava_context * ctx_llava,struct gpt_sampl
 int main(int argc, char ** argv) {
     ggml_time_init();

-
+    common_params params;

-    if (!
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, show_additional_info)) {
         return 1;
     }

-
+    common_init();

     if (params.mmproj.empty() || (params.image.empty())) {
         show_additional_info(argc, argv);
@@ -290,7 +290,7 @@ int main(int argc, char ** argv) {

             fflush(stdout);
         }
-
+        common_sampler_free(smpl);
     }else {
         while (true) {
             LOG("<user>");
@@ -309,7 +309,7 @@ int main(int argc, char ** argv) {
                 if (strstr(response.c_str(), "<user>")) break; // minicpm-v
                 fflush(stdout);
             }
-
+            common_sampler_free(smpl);
         }
     }
     printf("\n");
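Besides the renames, the eval_tokens hunk above reflects the slimmer llama_batch_get_one, which now takes only a token pointer and a count. A small sketch of a chunked evaluation loop in that style (an assumed standalone helper, not code from the package; error handling trimmed):

```cpp
// Sketch of chunked prompt evaluation using the two-argument
// llama_batch_get_one shown in the eval_tokens hunk above.
// Assumed helper: ctx and n_past are managed by the caller.
#include <algorithm>
#include <vector>
#include "llama.h"

static bool eval_tokens_chunked(llama_context * ctx, std::vector<llama_token> & tokens, int n_batch, int * n_past) {
    const int N = (int) tokens.size();
    for (int i = 0; i < N; i += n_batch) {
        const int n_eval = std::min(n_batch, N - i);
        if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval))) {
            return false; // decode failed
        }
        *n_past += n_eval;
    }
    return true;
}
```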
package/src/llama.cpp/examples/lookahead/lookahead.cpp

@@ -37,13 +37,13 @@ struct ngram_container {
 };

 int main(int argc, char ** argv) {
-
+    common_params params;

-    if (!
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) {
         return 1;
     }

-
+    common_init();

     const int W = 15; // lookahead window
     const int N = 5; // n-gram size
@@ -56,7 +56,7 @@ int main(int argc, char ** argv) {
     llama_numa_init(params.numa);

     // load the target model
-
+    common_init_result llama_init = common_init_from_params(params);

     llama_model * model = llama_init.model;
     llama_context * ctx = llama_init.context;
@@ -65,7 +65,7 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> inp;
     std::vector<llama_token> all;

-    inp =
+    inp = common_tokenize(ctx, params.prompt, true, true);
     all = inp;

     const int max_context_size = llama_n_ctx(ctx);
@@ -79,7 +79,7 @@ int main(int argc, char ** argv) {
     LOG("\n\n");

     for (auto id : inp) {
-        LOG("%s",
+        LOG("%s", common_token_to_piece(ctx, id).c_str());
     }

     fflush(stderr);
@@ -89,8 +89,8 @@ int main(int argc, char ** argv) {
     const auto t_enc_start = ggml_time_us();

     // eval the prompt
-    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1
-    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1
+    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
+    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));

     for (int s = 1; s < W + G + 1; ++s) {
         llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
@@ -115,7 +115,7 @@ int main(int argc, char ** argv) {
     llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);

     // target model sampling context
-    struct
+    struct common_sampler * smpl = common_sampler_init(model, params.sparams);

     // verification n-grams
     std::vector<ngram_data> ngrams_cur(G);
@@ -156,12 +156,12 @@ int main(int argc, char ** argv) {

         // sample first token
         {
-            id =
+            id = common_sampler_sample(smpl, ctx, 0);

-
+            common_sampler_accept(smpl, id, true);

             {
-                const std::string token_str =
+                const std::string token_str = common_token_to_piece(ctx, id);

                 LOG("%s", token_str.c_str());
                 fflush(stdout);
@@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
         // debug
         if (dump_kv_cache) {
             llama_kv_cache_view_update(ctx, &kvc_view);
-
+            common_kv_cache_dump_view_seqs(kvc_view, 40);
         }

         // build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/
@@ -201,10 +201,10 @@ int main(int argc, char ** argv) {
         // V V V V V V
         //   id
         {
-
+            common_batch_clear(batch);

             // current token - first token of the first level
-
+            common_batch_add(batch, id, n_past, seq_id_all, true);

             // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
             {
@@ -229,7 +229,7 @@ int main(int argc, char ** argv) {
                         ngrams_cur[g].tokens [j + 1] = t;
                         ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;

-
+                        common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
                     }
                 }
             }
@@ -241,13 +241,13 @@ int main(int argc, char ** argv) {
                     seq_id_look[j] = i + j + 1;
                 }

-
+                common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
             }

             // fill the rest of the levels
             for (int j = 1; j < N - 1; j++) {
                 for (int i = 0; i < W; i++) {
-
+                    common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
                 }
             }
         }
@@ -281,13 +281,13 @@ int main(int argc, char ** argv) {
             }

             // sample the next token
-            id =
+            id = common_sampler_sample(smpl, ctx, i_batch);

-
+            common_sampler_accept(smpl, id, true);

             // print
             {
-                const std::string token_str =
+                const std::string token_str = common_token_to_piece(ctx, id);

                 if (v == 0) {
                     LOG("%s", token_str.c_str());
@@ -327,7 +327,7 @@ int main(int argc, char ** argv) {
             // print known n-grams starting with token id (debug)
             if (0 && v == 0) {
                 if (ngrams_observed.cnt[id] > 0) {
-                    LOG("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id],
+                    LOG("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], common_token_to_piece(ctx, id).c_str());
                 }

                 for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
@@ -336,7 +336,7 @@ int main(int argc, char ** argv) {
                     const int idx = id*(N - 1)*G + i*(N - 1);

                     for (int j = 0; j < N - 1; j++) {
-                        const std::string token_str =
+                        const std::string token_str = common_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);

                         LOG("%s", token_str.c_str());
                     }
@@ -358,7 +358,7 @@ int main(int argc, char ** argv) {
             if (v == 0) {
                 // sample from the last level
                 for (int i = 0; i < W; i++) {
-                    tokens_j[N - 2][i] =
+                    tokens_j[N - 2][i] = common_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
                 }
             } else {
                 for (int i = 0; i < W; i++) {
@@ -466,9 +466,9 @@ int main(int argc, char ** argv) {
     LOG_INF("n_accept = %d\n", n_accept);

     LOG_INF("\n");
-
+    common_perf_print(ctx, smpl);

-
+    common_sampler_free(smpl);

     llama_kv_cache_view_free(&kvc_view);

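The lookahead hunks above also switch to the common_batch_clear / common_batch_add helpers for building multi-sequence batches. A hedged sketch of that pattern follows; the function name and batch capacity are illustrative assumptions, while the helper signatures follow the calls visible in the diff:

```cpp
// Sketch of the batch-building helpers used in the lookahead diff above.
// The function name and the capacity passed to llama_batch_init are assumptions.
#include "common.h"
#include "llama.h"

static void decode_token_into_two_seqs(llama_context * ctx, llama_token id, int n_past) {
    llama_batch batch = llama_batch_init(512, 0, 2); // up to 512 tokens, 2 sequences

    common_batch_clear(batch);
    common_batch_add(batch, id, n_past, { 0 }, true);  // sequence 0, request logits
    common_batch_add(batch, id, n_past, { 1 }, false); // sequence 1, no logits

    llama_decode(ctx, batch);
    llama_batch_free(batch);
}
```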
package/src/llama.cpp/examples/lookup/lookup-create.cpp

@@ -12,9 +12,9 @@
 #include <vector>

 int main(int argc, char ** argv){
-
+    common_params params;

-    if (!
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
         return 1;
     }

@@ -23,7 +23,7 @@ int main(int argc, char ** argv){
     llama_numa_init(params.numa);

     // load the model
-
+    common_init_result llama_init = common_init_from_params(params);

     llama_model * model = llama_init.model;
     llama_context * ctx = llama_init.context;
@@ -31,15 +31,15 @@ int main(int argc, char ** argv){

     // tokenize the prompt
     std::vector<llama_token> inp;
-    inp =
+    inp = common_tokenize(ctx, params.prompt, true, true);
     fprintf(stderr, "%s: tokenization done\n", __func__);


-
-
+    common_ngram_cache ngram_cache;
+    common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);
     fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str());

-
+    common_ngram_cache_save(ngram_cache, params.lookup_cache_static);

     return 0;
 }
package/src/llama.cpp/examples/lookup/lookup-merge.cpp

@@ -33,15 +33,15 @@ int main(int argc, char ** argv){
     }

     fprintf(stderr, "lookup-merge: loading file %s\n", args[0].c_str());
-
+    common_ngram_cache ngram_cache_merged = common_ngram_cache_load(args[0]);

     for (size_t i = 1; i < args.size()-1; ++i) {
         fprintf(stderr, "lookup-merge: loading file %s\n", args[i].c_str());
-
+        common_ngram_cache ngram_cache = common_ngram_cache_load(args[i]);

-
+        common_ngram_cache_merge(ngram_cache_merged, ngram_cache);
     }

     fprintf(stderr, "lookup-merge: saving file %s\n", args.back().c_str());
-
+    common_ngram_cache_save(ngram_cache_merged, args.back());
 }
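The lookup examples above and below exercise the renamed n-gram cache helpers. A minimal sketch of that lifecycle (update from tokens, load an existing cache, merge, save), assuming an ngram-cache.h header; the pass-by-reference signatures follow the calls visible in these hunks:

```cpp
// Sketch of the renamed n-gram cache lifecycle from the lookup examples.
// The header name and the non-const reference parameters are assumptions.
#include <string>
#include <vector>
#include "common.h"
#include "ngram-cache.h"

static void rebuild_static_cache(std::vector<llama_token> & inp, std::string & fname_in, std::string & fname_out) {
    // hash the tokenized corpus into a fresh static cache
    common_ngram_cache ngram_cache;
    common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true);

    // fold in a previously saved cache, then write the result back out
    common_ngram_cache previous = common_ngram_cache_load(fname_in);
    common_ngram_cache_merge(ngram_cache, previous);
    common_ngram_cache_save(ngram_cache, fname_out);
}
```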
package/src/llama.cpp/examples/lookup/lookup-stats.cpp

@@ -13,13 +13,13 @@
 #include <vector>

 int main(int argc, char ** argv){
-
+    common_params params;

-    if (!
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
         return 1;
     }

-
+    common_init();

     const int n_draft = params.n_draft;

@@ -28,18 +28,18 @@ int main(int argc, char ** argv){
     llama_numa_init(params.numa);

     // load the model
-
+    common_init_result llama_init = common_init_from_params(params);

     llama_model * model = llama_init.model;
     llama_context * ctx = llama_init.context;

     // tokenize the prompt
     std::vector<llama_token> inp;
-    inp =
+    inp = common_tokenize(ctx, params.prompt, true, true);

-
-
-
+    common_ngram_cache ngram_cache_context;
+    common_ngram_cache ngram_cache_dynamic;
+    common_ngram_cache ngram_cache_static;
     int64_t t_draft_flat_us = 0;
     int64_t t_draft_us = 0;

@@ -48,7 +48,7 @@ int main(int argc, char ** argv){

         if (!params.lookup_cache_static.empty()) {
             try {
-                ngram_cache_static =
+                ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
             } catch (std::ifstream::failure const &) {
                 LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
                 exit(1);
@@ -57,7 +57,7 @@ int main(int argc, char ** argv){

         if (!params.lookup_cache_dynamic.empty()) {
             try {
-                ngram_cache_dynamic =
+                ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
             } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
         }

@@ -86,7 +86,7 @@ int main(int argc, char ** argv){

         {
             const int64_t t_start_draft_us = ggml_time_us();
-
+            common_ngram_cache_draft(pseudo_output, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);
             t_draft_us += ggml_time_us() - t_start_draft_us;
         }

@@ -105,7 +105,7 @@ int main(int argc, char ** argv){

             {
                 const int64_t t_start_draft_us = ggml_time_us();
-
+                common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
                 t_draft_us += ggml_time_us() - t_start_draft_us;
             }
         }
@@ -115,7 +115,7 @@ int main(int argc, char ** argv){
             pseudo_output.push_back(inp_slice[pseudo_output.size()]);
             {
                 const int64_t t_start_draft_us = ggml_time_us();
-
+                common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, pseudo_output, 1, false);
                 t_draft_us += ggml_time_us() - t_start_draft_us;
             }
         }
@@ -133,7 +133,7 @@ int main(int argc, char ** argv){
         }

         // After each chunk, update the dynamic ngram cache with the context ngram cache:
-
+        common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
         ngram_cache_context.clear();
     }

package/src/llama.cpp/examples/lookup/lookup.cpp

@@ -13,13 +13,13 @@
 #include <vector>

 int main(int argc, char ** argv){
-
+    common_params params;

-    if (!
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) {
         return 1;
     }

-
+    common_init();

     // max. number of additional tokens to draft if match is found
     const int n_draft = params.n_draft;
@@ -31,29 +31,29 @@ int main(int argc, char ** argv){
     llama_numa_init(params.numa);

     // load the model
-
+    common_init_result llama_init = common_init_from_params(params);

     llama_model * model = llama_init.model;
     llama_context * ctx = llama_init.context;

     // tokenize the prompt
     std::vector<llama_token> inp;
-    inp =
+    inp = common_tokenize(ctx, params.prompt, true, true);

-
-
-
+    common_ngram_cache ngram_cache_context;
+    common_ngram_cache ngram_cache_dynamic;
+    common_ngram_cache ngram_cache_static;
     int64_t t_draft_flat_us = 0;
     int64_t t_draft_us = 0;

     {
         // Fill up context ngram cache with tokens from user input:
         const int64_t t_start_draft_us = ggml_time_us();
-
+        common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false);

         if (!params.lookup_cache_static.empty()) {
             try {
-                ngram_cache_static =
+                ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static);
             } catch (std::ifstream::failure const &) {
                 LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str());
                 exit(1);
@@ -62,7 +62,7 @@ int main(int argc, char ** argv){

         if (!params.lookup_cache_dynamic.empty()) {
             try {
-                ngram_cache_dynamic =
+                ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic);
             } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program
         }

@@ -80,7 +80,7 @@ int main(int argc, char ** argv){
     LOG("\n\n");

     for (auto id : inp) {
-        LOG("%s",
+        LOG("%s", common_token_to_piece(ctx, id).c_str());
     }

     fflush(stderr);
@@ -89,8 +89,8 @@ int main(int argc, char ** argv){

     const auto t_enc_start = ggml_time_us();

-    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1
-    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1
+    llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1));
+    llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));

     const auto t_enc_end = ggml_time_us();

@@ -102,7 +102,7 @@ int main(int argc, char ** argv){

     bool has_eos = false;

-    struct
+    struct common_sampler * smpl = common_sampler_init(model, params.sparams);

     std::vector<llama_token> draft;

@@ -117,7 +117,7 @@ int main(int argc, char ** argv){
         // debug
         if (dump_kv_cache) {
             llama_kv_cache_view_update(ctx, &kvc_view);
-
+            common_kv_cache_dump_view_seqs(kvc_view, 40);
         }

         // print current draft sequence
@@ -126,11 +126,11 @@ int main(int argc, char ** argv){
         int i_dft = 0;
         while (true) {
             // sample from the target model
-            llama_token id =
+            llama_token id = common_sampler_sample(smpl, ctx, i_dft);

-
+            common_sampler_accept(smpl, id, true);

-            const std::string token_str =
+            const std::string token_str = common_token_to_piece(ctx, id);

             if (!params.use_color) {
                 LOG("%s", token_str.c_str());
@@ -152,7 +152,7 @@ int main(int argc, char ** argv){
             {
                 // Update context ngram cache with the newly accepted token:
                 const int64_t t_start_draft_us = ggml_time_us();
-
+                common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
                 t_draft_us += ggml_time_us() - t_start_draft_us;
             }

@@ -178,7 +178,7 @@ int main(int argc, char ** argv){
             {
                 // Update context ngram cache with the newly accepted token:
                 const int64_t t_start_draft_us = ggml_time_us();
-
+                common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, 1, false);
                 t_draft_us += ggml_time_us() - t_start_draft_us;
             }
             break;
@@ -192,18 +192,18 @@ int main(int argc, char ** argv){
         // clean the cache of draft tokens that weren't accepted
         llama_kv_cache_seq_rm(ctx, 0, n_past, -1);

-
-
+        common_batch_clear(batch_tgt);
+        common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

         // Draft already contains a single token sampled from the model:
         GGML_ASSERT(draft.size() == 1);
         GGML_ASSERT(draft[0] == inp.back());
         const int64_t t_start_draft_us = ggml_time_us();

-
+        common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static);

         for (size_t i = 1; i < draft.size(); ++i) {
-
+            common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
         }

         t_draft_us += ggml_time_us() - t_start_draft_us;
@@ -218,8 +218,8 @@ int main(int argc, char ** argv){
     auto t_dec_end = ggml_time_us();

     // Update dynamic ngram cache with context ngram cache and save it to disk:
-
-
+    common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context);
+    common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic);

     LOG("\n\n");

@@ -237,9 +237,9 @@ int main(int argc, char ** argv){
     LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);

     LOG_INF("\ntarget:\n\n");
-
+    common_perf_print(ctx, smpl);

-
+    common_sampler_free(smpl);

     llama_batch_free(batch_tgt);
