@fugood/llama.node 0.3.3 → 0.3.4
This diff shows the changes between the publicly released contents of these two package versions, as published to their respective public registries, and is provided for informational purposes only.
- package/CMakeLists.txt +5 -0
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/lib/binding.ts +18 -1
- package/package.json +1 -1
- package/src/EmbeddingWorker.cpp +15 -5
- package/src/EmbeddingWorker.h +2 -1
- package/src/LlamaCompletionWorker.cpp +1 -1
- package/src/LlamaContext.cpp +81 -18
- package/src/LlamaContext.h +2 -0
- package/src/llama.cpp/.github/workflows/build.yml +197 -159
- package/src/llama.cpp/.github/workflows/docker.yml +5 -8
- package/src/llama.cpp/.github/workflows/python-lint.yml +8 -1
- package/src/llama.cpp/.github/workflows/server.yml +21 -14
- package/src/llama.cpp/CMakeLists.txt +11 -6
- package/src/llama.cpp/Sources/llama/llama.h +4 -0
- package/src/llama.cpp/cmake/common.cmake +33 -0
- package/src/llama.cpp/cmake/x64-windows-llvm.cmake +11 -0
- package/src/llama.cpp/common/CMakeLists.txt +6 -2
- package/src/llama.cpp/common/arg.cpp +426 -245
- package/src/llama.cpp/common/common.cpp +143 -80
- package/src/llama.cpp/common/common.h +81 -24
- package/src/llama.cpp/common/sampling.cpp +53 -19
- package/src/llama.cpp/common/sampling.h +22 -1
- package/src/llama.cpp/common/speculative.cpp +274 -0
- package/src/llama.cpp/common/speculative.h +28 -0
- package/src/llama.cpp/docs/build.md +101 -148
- package/src/llama.cpp/examples/CMakeLists.txt +32 -13
- package/src/llama.cpp/examples/batched/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/batched/batched.cpp +5 -4
- package/src/llama.cpp/examples/batched-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/cvector-generator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/deprecation-warning/deprecation-warning.cpp +1 -1
- package/src/llama.cpp/examples/embedding/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/eval-callback/CMakeLists.txt +3 -2
- package/src/llama.cpp/examples/export-lora/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +4 -7
- package/src/llama.cpp/examples/gen-docs/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-hash/CMakeLists.txt +8 -1
- package/src/llama.cpp/examples/gguf-split/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +2 -2
- package/src/llama.cpp/examples/gritlm/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +1 -1
- package/src/llama.cpp/examples/imatrix/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +11 -2
- package/src/llama.cpp/examples/infill/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/infill/infill.cpp +1 -1
- package/src/llama.cpp/examples/llama-bench/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +405 -316
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llava/CMakeLists.txt +10 -3
- package/src/llama.cpp/examples/llava/clip.cpp +262 -66
- package/src/llama.cpp/examples/llava/clip.h +8 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/llava.cpp +46 -19
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +1 -1
- package/src/llama.cpp/examples/llava/qwen2vl-cli.cpp +581 -0
- package/src/llama.cpp/examples/lookahead/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +1 -1
- package/src/llama.cpp/examples/lookup/CMakeLists.txt +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +2 -1
- package/src/llama.cpp/examples/lookup/lookup.cpp +2 -2
- package/src/llama.cpp/examples/main/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/main/main.cpp +9 -5
- package/src/llama.cpp/examples/main-cmake-pkg/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/parallel/parallel.cpp +1 -1
- package/src/llama.cpp/examples/passkey/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/perplexity/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +0 -3
- package/src/llama.cpp/examples/quantize-stats/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +4 -4
- package/src/llama.cpp/examples/run/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/run/run.cpp +911 -0
- package/src/llama.cpp/examples/save-load-state/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +4 -4
- package/src/llama.cpp/examples/server/CMakeLists.txt +3 -7
- package/src/llama.cpp/examples/server/server.cpp +1758 -886
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -2
- package/src/llama.cpp/examples/server/utils.hpp +94 -304
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +4 -0
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +3 -0
- package/src/llama.cpp/examples/speculative/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/speculative/speculative.cpp +16 -15
- package/src/llama.cpp/examples/speculative-simple/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp +265 -0
- package/src/llama.cpp/examples/tokenize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +1 -1
- package/src/llama.cpp/examples/tts/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/tts/tts.cpp +932 -0
- package/src/llama.cpp/ggml/CMakeLists.txt +46 -34
- package/src/llama.cpp/ggml/include/ggml-backend.h +16 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +7 -49
- package/src/llama.cpp/ggml/include/ggml-opencl.h +26 -0
- package/src/llama.cpp/ggml/include/ggml.h +106 -24
- package/src/llama.cpp/ggml/src/CMakeLists.txt +73 -24
- package/src/llama.cpp/ggml/src/ggml-alloc.c +0 -1
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +51 -11
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +379 -22
- package/src/llama.cpp/ggml/src/ggml-backend.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +3 -7
- package/src/llama.cpp/ggml/src/ggml-blas/ggml-blas.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +33 -3
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +456 -111
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +6 -3
- package/src/llama.cpp/ggml/src/ggml-cann/ggml-cann.cpp +95 -35
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -5
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/dup.cpp +22 -9
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f16.cpp +24 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_f32.cpp +23 -13
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/get_row_q4_0.cpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f16_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_f32_q8_0.cpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +17 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +42 -42
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +288 -213
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.cpp +220 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/amx/amx.h +8 -0
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/common.h +19 -22
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.cpp +93 -92
- package/src/llama.cpp/ggml/src/{ggml-amx → ggml-cpu/amx}/mmq.h +2 -9
- package/src/llama.cpp/ggml/src/ggml-cpu/cpu-feats-x86.cpp +323 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/{ggml-cpu-aarch64.c → ggml-cpu-aarch64.cpp} +892 -190
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +2 -24
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.cpp +55 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-hbm.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +15 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +38 -25
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.cpp +36 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-traits.h +38 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +552 -399
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +101 -136
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +7 -10
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +8 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +4 -6
- package/src/llama.cpp/ggml/src/ggml-impl.h +32 -11
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +13 -9
- package/src/llama.cpp/ggml/src/ggml-kompute/ggml-kompute.cpp +131 -64
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +3 -6
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +39 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +14 -7
- package/src/llama.cpp/ggml/src/ggml-opencl/CMakeLists.txt +147 -0
- package/src/llama.cpp/ggml/src/ggml-opencl/ggml-opencl.cpp +4004 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +67 -80
- package/src/llama.cpp/ggml/src/ggml-quants.c +0 -9
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +3 -5
- package/src/llama.cpp/ggml/src/ggml-rpc/ggml-rpc.cpp +5 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +13 -10
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +2 -11
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +2 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +5 -5
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +32 -13
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +80 -61
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/ggml-sycl.cpp +159 -114
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +3 -2
- package/src/llama.cpp/ggml/src/ggml-sycl/mmq.cpp +6 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +6 -20
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +8 -8
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +4 -3
- package/src/llama.cpp/ggml/src/ggml-sycl/softmax.cpp +7 -7
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +4 -1
- package/src/llama.cpp/ggml/src/ggml-threading.h +4 -2
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +21 -7
- package/src/llama.cpp/ggml/src/ggml-vulkan/ggml-vulkan.cpp +1718 -399
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt +3 -1
- package/src/llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +105 -31
- package/src/llama.cpp/ggml/src/ggml.c +367 -207
- package/src/llama.cpp/include/llama-cpp.h +25 -0
- package/src/llama.cpp/include/llama.h +26 -19
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-roberta-bpe.gguf.out +46 -0
- package/src/llama.cpp/pocs/CMakeLists.txt +3 -1
- package/src/llama.cpp/pocs/vdot/CMakeLists.txt +2 -2
- package/src/llama.cpp/src/CMakeLists.txt +2 -7
- package/src/llama.cpp/src/llama-grammar.cpp +15 -15
- package/src/llama.cpp/src/llama-grammar.h +2 -5
- package/src/llama.cpp/src/llama-sampling.cpp +35 -90
- package/src/llama.cpp/src/llama-vocab.cpp +6 -1
- package/src/llama.cpp/src/llama.cpp +1748 -640
- package/src/llama.cpp/src/unicode.cpp +62 -51
- package/src/llama.cpp/src/unicode.h +9 -10
- package/src/llama.cpp/tests/CMakeLists.txt +48 -37
- package/src/llama.cpp/tests/test-arg-parser.cpp +2 -2
- package/src/llama.cpp/tests/test-backend-ops.cpp +140 -21
- package/src/llama.cpp/tests/test-chat-template.cpp +50 -4
- package/src/llama.cpp/tests/test-gguf.cpp +1303 -0
- package/src/llama.cpp/tests/test-grammar-integration.cpp +3 -6
- package/src/llama.cpp/tests/test-llama-grammar.cpp +2 -4
- package/src/llama.cpp/tests/test-quantize-fns.cpp +3 -3
- package/src/llama.cpp/tests/test-rope.cpp +61 -20
- package/src/llama.cpp/tests/test-sampling.cpp +2 -2
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +0 -72
- package/src/llama.cpp/.github/workflows/nix-ci.yml +0 -79
- package/src/llama.cpp/.github/workflows/nix-flake-update.yml +0 -22
- package/src/llama.cpp/.github/workflows/nix-publish-flake.yml +0 -36
- package/src/llama.cpp/ggml/include/ggml-amx.h +0 -25
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +0 -129
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -19
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +0 -107
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +0 -446
package/src/llama.cpp/examples/speculative-simple/speculative-simple.cpp
@@ -0,0 +1,265 @@
+#include "arg.h"
+#include "common.h"
+#include "sampling.h"
+#include "speculative.h"
+#include "log.h"
+#include "llama.h"
+
+#include <cstdio>
+#include <cstring>
+#include <string>
+#include <vector>
+
+int main(int argc, char ** argv) {
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
+        return 1;
+    }
+
+    if (params.n_predict < -1) {
+        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+        return 1;
+    }
+
+    common_init();
+
+    if (params.speculative.model.empty()) {
+        LOG_ERR("%s: --model-draft is required\n", __func__);
+        return 1;
+    }
+
+    // init llama.cpp
+    llama_backend_init();
+    llama_numa_init(params.numa);
+
+    llama_model * model_tgt = NULL;
+    llama_model * model_dft = NULL;
+
+    llama_context * ctx_tgt = NULL;
+    llama_context * ctx_dft = NULL;
+
+    // load the target model
+    common_init_result llama_init_tgt = common_init_from_params(params);
+
+    model_tgt = llama_init_tgt.model;
+    ctx_tgt   = llama_init_tgt.context;
+
+    // load the draft model
+    params.devices      = params.speculative.devices;
+    params.model        = params.speculative.model;
+    params.n_ctx        = params.speculative.n_ctx;
+    params.n_batch      = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch;
+    params.n_gpu_layers = params.speculative.n_gpu_layers;
+
+    if (params.speculative.cpuparams.n_threads > 0) {
+        params.cpuparams.n_threads = params.speculative.cpuparams.n_threads;
+    }
+
+    params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads;
+    common_init_result llama_init_dft = common_init_from_params(params);
+
+    model_dft = llama_init_dft.model;
+    ctx_dft   = llama_init_dft.context;
+
+    if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) {
+        return 1;
+    }
+
+    // Tokenize the prompt
+    std::vector<llama_token> inp;
+    inp = common_tokenize(ctx_tgt, params.prompt, true, true);
+
+    if (llama_n_ctx(ctx_tgt) < (uint32_t) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the context size (%d tokens, ctx %d)\n", __func__, (int) inp.size(), llama_n_ctx(ctx_tgt));
+
+        return 1;
+    }
+
+    if (llama_n_batch(ctx_tgt) < (uint32_t) inp.size()) {
+        LOG_ERR("%s: the prompt exceeds the batch size (%d tokens, batch %d)\n", __func__, (int) inp.size(), llama_n_batch(ctx_tgt));
+
+        return 1;
+    }
+
+    LOG("\n\n");
+
+    for (auto id : inp) {
+        LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
+    }
+
+    // how many tokens to draft each time
+    int n_draft = params.speculative.n_max;
+    int n_draft_min = params.speculative.n_min;
+
+    float p_min = params.speculative.p_min;
+
+    int n_predict = 0;
+    int n_drafted = 0;
+    int n_accept  = 0;
+
+    // used to determine end of generation
+    bool has_eos = false;
+
+    // ================================================
+    // everything until here is standard initialization
+    // the relevant stuff for speculative decoding starts here
+
+    const auto t_enc_start = ggml_time_us();
+
+    // target model sampling context
+    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling);
+
+    // eval the prompt
+    llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1));
+
+    // note: keep the last token separate!
+    llama_token id_last = inp.back();
+
+    // all tokens currently in the target context
+    llama_tokens prompt_tgt(inp.begin(), inp.end() - 1);
+    prompt_tgt.reserve(llama_n_ctx(ctx_tgt));
+
+    int n_past = inp.size() - 1;
+
+    // init the speculator
+    struct common_speculative_params params_spec;
+    params_spec.n_draft = n_draft;
+    params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft;
+    params_spec.p_min   = p_min;
+
+    struct common_speculative * spec = common_speculative_init(ctx_dft);
+
+    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1);
+
+    const auto t_enc_end = ggml_time_us();
+
+    const auto t_dec_start = ggml_time_us();
+
+    while (true) {
+        // optionally, generate draft tokens that can be appended to the target batch
+        //
+        // this is the most important part of the speculation. the more probable tokens that are provided here
+        // the better the performance will be. in theory, this computation can be performed asynchronously and even
+        // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens
+        // from a cache or lookup tables.
+        //
+        llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last);
+
+        //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str());
+
+        // always have a token to evaluate from before - id_last
+        common_batch_clear(batch_tgt);
+        common_batch_add  (batch_tgt, id_last, n_past++, { 0 }, true);
+
+        // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1]
+        {
+            // do not waste time on small drafts
+            if (draft.size() < (size_t) n_draft_min) {
+                draft.clear();
+            }
+
+            for (size_t i = 0; i < draft.size(); ++i) {
+                common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true);
+            }
+
+            //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str());
+
+            llama_decode(ctx_tgt, batch_tgt);
+        }
+
+        // sample from the full target batch and return the accepted tokens based on the target sampler
+        //
+        // for each token to be accepted, the sampler would have to sample that same token
+        // in such cases, instead of decoding the sampled token as we normally do, we simply continue with the
+        // available logits from the batch and sample the next token until we run out of logits or the sampler
+        // disagrees with the draft
+        //
+        const auto ids = common_sampler_sample_and_accept_n(smpl, ctx_tgt, draft);
+
+        //LOG_DBG("ids: %s\n", string_from(ctx_tgt, ids).c_str());
+
+        GGML_ASSERT(ids.size() > 0); // there will always be at least one accepted token
+
+        n_past    += ids.size() - 1;
+        n_drafted += draft.size(); // note: we ignore the discarded small drafts
+        n_accept  += ids.size() - 1;
+        n_predict += ids.size();
+
+        // process the accepted tokens and update contexts
+        //
+        // this is the standard token post-processing that we normally do
+        // in this case, we do it for a group of accepted tokens at once
+        //
+        for (size_t i = 0; i < ids.size(); ++i) {
+            prompt_tgt.push_back(id_last);
+
+            id_last = ids[i];
+
+            if (llama_token_is_eog(model_tgt, id_last)) {
+                has_eos = true;
+                break;
+            }
+
+            const std::string token_str = common_token_to_piece(ctx_tgt, id_last);
+
+            if (params.use_color && i + 1 < ids.size()) {
+                LOG("\u001b[%dm%s\u001b[37m", (36 - 0 % 6), token_str.c_str());
+            } else {
+                LOG("%s", token_str.c_str());
+            }
+        }
+
+        LOG_DBG("accepted %d/%d draft tokens, the last target token is: (%d)\n", (int) ids.size() - 1, (int) draft.size(), id_last);
+
+        {
+            LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
+
+            llama_kv_cache_seq_rm(ctx_tgt, 0, n_past, -1);
+        }
+
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
+            break;
+        }
+    }
+
+    auto t_dec_end = ggml_time_us();
+
+    const int n_input = inp.size();
+
+    LOG("\n\n");
+
+    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input,   (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict  / ((t_dec_end - t_dec_start) / 1e6f));
+
+    LOG_INF("\n");
+    LOG_INF("n_draft   = %d\n", n_draft);
+    LOG_INF("n_predict = %d\n", n_predict);
+    LOG_INF("n_drafted = %d\n", n_drafted);
+    LOG_INF("n_accept  = %d\n", n_accept);
+    LOG_INF("accept    = %.3f%%\n", 100.0f * n_accept / n_drafted);
+
+    LOG_INF("\n");
+    LOG_INF("draft:\n\n");
+
+    llama_perf_context_print(ctx_dft);
+
+    LOG_INF("\n");
+    LOG_INF("target:\n\n");
+    common_perf_print(ctx_tgt, smpl);
+
+    common_sampler_free(smpl);
+    common_speculative_free(spec);
+
+    llama_free(ctx_tgt);
+    llama_free_model(model_tgt);
+
+    llama_free(ctx_dft);
+    llama_free_model(model_dft);
+
+    llama_backend_free();
+
+    LOG("\n\n");
+
+    return 0;
+}
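Note on the accept step above: common_sampler_sample_and_accept_n() (added to common/sampling.{h,cpp} in this update) is what decides how many drafted tokens survive each iteration. The snippet below is only a minimal conceptual sketch of that greedy verification idea, not the library's implementation; sample_and_accept, the pick callback, and the toy main() are hypothetical names invented for illustration.

// Conceptual sketch of draft verification in speculative decoding.
// `pick(i)` stands in for "the token the target sampler would choose from the
// logits at batch position i" (position 0 follows id_last, position i follows
// draft[i-1]). This is NOT the real common_sampler_sample_and_accept_n().
#include <cstdio>
#include <functional>
#include <vector>

using token = int;

static std::vector<token> sample_and_accept(const std::vector<token> & draft,
                                            const std::function<token(size_t)> & pick) {
    std::vector<token> out;
    for (size_t i = 0; i < draft.size(); ++i) {
        const token t = pick(i);
        out.push_back(t);          // the target's own choice is always kept
        if (t != draft[i]) {
            return out;            // first disagreement: drop the rest of the draft
        }
    }
    out.push_back(pick(draft.size())); // whole draft accepted: one extra "free" token
    return out;                        // never empty, matching the GGML_ASSERT above
}

int main() {
    // Toy target "model" that deterministically continues 1, 2, 3, ...
    const std::vector<token> draft = {1, 2, 9}; // draft diverges at its third token
    const auto ids = sample_and_accept(draft, [](size_t i) { return (token) (i + 1); });
    for (token t : ids) {
        printf("%d ", t); // prints: 1 2 3
    }
    printf("\n");
    return 0;
}

Because the target's pick is kept even at the first disagreement, and a bonus token is taken when the whole draft matches, every loop iteration advances by at least one token; that is what makes the ids.size() > 0 assertion in the example safe.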
package/src/llama.cpp/examples/tokenize/CMakeLists.txt
@@ -2,4 +2,4 @@ set(TARGET llama-tokenize)
 add_executable(${TARGET} tokenize.cpp)
 install(TARGETS ${TARGET} RUNTIME)
 target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE
+target_compile_features(${TARGET} PRIVATE cxx_std_17)