@fugood/llama.node 0.3.1 → 0.3.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/CMakeLists.txt +1 -8
- package/bin/darwin/arm64/llama-node.node +0 -0
- package/bin/darwin/x64/llama-node.node +0 -0
- package/bin/linux/arm64/llama-node.node +0 -0
- package/bin/linux/x64/llama-node.node +0 -0
- package/bin/linux-vulkan/arm64/llama-node.node +0 -0
- package/bin/linux-vulkan/x64/llama-node.node +0 -0
- package/bin/win32/arm64/llama-node.node +0 -0
- package/bin/win32/arm64/node.lib +0 -0
- package/bin/win32/x64/llama-node.node +0 -0
- package/bin/win32/x64/node.lib +0 -0
- package/bin/win32-vulkan/arm64/llama-node.node +0 -0
- package/bin/win32-vulkan/arm64/node.lib +0 -0
- package/bin/win32-vulkan/x64/llama-node.node +0 -0
- package/bin/win32-vulkan/x64/node.lib +0 -0
- package/package.json +4 -2
- package/src/DetokenizeWorker.cpp +1 -1
- package/src/EmbeddingWorker.cpp +2 -2
- package/src/LlamaCompletionWorker.cpp +10 -10
- package/src/LlamaCompletionWorker.h +2 -2
- package/src/LlamaContext.cpp +14 -17
- package/src/TokenizeWorker.cpp +1 -1
- package/src/common.hpp +5 -4
- package/src/llama.cpp/.github/workflows/build.yml +137 -29
- package/src/llama.cpp/.github/workflows/close-issue.yml +5 -0
- package/src/llama.cpp/.github/workflows/docker.yml +46 -34
- package/src/llama.cpp/.github/workflows/nix-ci-aarch64.yml +7 -0
- package/src/llama.cpp/.github/workflows/nix-ci.yml +7 -0
- package/src/llama.cpp/.github/workflows/python-check-requirements.yml +2 -4
- package/src/llama.cpp/.github/workflows/python-type-check.yml +3 -1
- package/src/llama.cpp/.github/workflows/server.yml +7 -0
- package/src/llama.cpp/CMakeLists.txt +26 -11
- package/src/llama.cpp/cmake/arm64-apple-clang.cmake +16 -0
- package/src/llama.cpp/common/CMakeLists.txt +10 -10
- package/src/llama.cpp/common/arg.cpp +2041 -0
- package/src/llama.cpp/common/arg.h +77 -0
- package/src/llama.cpp/common/common.cpp +523 -1861
- package/src/llama.cpp/common/common.h +234 -106
- package/src/llama.cpp/common/console.cpp +3 -0
- package/src/llama.cpp/common/json-schema-to-grammar.cpp +1 -1
- package/src/llama.cpp/common/log.cpp +401 -0
- package/src/llama.cpp/common/log.h +66 -698
- package/src/llama.cpp/common/ngram-cache.cpp +39 -36
- package/src/llama.cpp/common/ngram-cache.h +19 -19
- package/src/llama.cpp/common/sampling.cpp +356 -350
- package/src/llama.cpp/common/sampling.h +62 -139
- package/src/llama.cpp/common/stb_image.h +5990 -6398
- package/src/llama.cpp/docs/build.md +72 -17
- package/src/llama.cpp/examples/CMakeLists.txt +1 -2
- package/src/llama.cpp/examples/batched/batched.cpp +49 -65
- package/src/llama.cpp/examples/batched-bench/batched-bench.cpp +42 -53
- package/src/llama.cpp/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +55 -52
- package/src/llama.cpp/examples/cvector-generator/cvector-generator.cpp +22 -22
- package/src/llama.cpp/examples/cvector-generator/pca.hpp +3 -13
- package/src/llama.cpp/examples/embedding/embedding.cpp +147 -91
- package/src/llama.cpp/examples/eval-callback/eval-callback.cpp +37 -37
- package/src/llama.cpp/examples/export-lora/export-lora.cpp +39 -38
- package/src/llama.cpp/examples/gbnf-validator/gbnf-validator.cpp +14 -39
- package/src/llama.cpp/examples/{baby-llama → gen-docs}/CMakeLists.txt +2 -2
- package/src/llama.cpp/examples/gen-docs/gen-docs.cpp +83 -0
- package/src/llama.cpp/examples/gguf-split/gguf-split.cpp +58 -39
- package/src/llama.cpp/examples/gritlm/gritlm.cpp +46 -39
- package/src/llama.cpp/examples/imatrix/imatrix.cpp +75 -69
- package/src/llama.cpp/examples/infill/infill.cpp +131 -192
- package/src/llama.cpp/examples/llama-bench/llama-bench.cpp +276 -178
- package/src/llama.cpp/examples/llama.android/llama/build.gradle.kts +1 -0
- package/src/llama.cpp/examples/llama.android/llama/src/main/cpp/llama-android.cpp +40 -36
- package/src/llama.cpp/examples/llava/CMakeLists.txt +7 -0
- package/src/llama.cpp/examples/llava/clip.cpp +686 -150
- package/src/llama.cpp/examples/llava/clip.h +11 -2
- package/src/llama.cpp/examples/llava/llava-cli.cpp +60 -71
- package/src/llama.cpp/examples/llava/llava.cpp +146 -26
- package/src/llama.cpp/examples/llava/llava.h +2 -3
- package/src/llama.cpp/examples/llava/minicpmv-cli.cpp +323 -0
- package/src/llama.cpp/examples/llava/requirements.txt +1 -0
- package/src/llama.cpp/examples/lookahead/lookahead.cpp +55 -56
- package/src/llama.cpp/examples/lookup/lookup-create.cpp +15 -13
- package/src/llama.cpp/examples/lookup/lookup-merge.cpp +4 -4
- package/src/llama.cpp/examples/lookup/lookup-stats.cpp +34 -33
- package/src/llama.cpp/examples/lookup/lookup.cpp +60 -63
- package/src/llama.cpp/examples/main/main.cpp +216 -313
- package/src/llama.cpp/examples/parallel/parallel.cpp +58 -59
- package/src/llama.cpp/examples/passkey/passkey.cpp +53 -61
- package/src/llama.cpp/examples/perplexity/perplexity.cpp +277 -311
- package/src/llama.cpp/examples/quantize/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/quantize/quantize.cpp +27 -9
- package/src/llama.cpp/examples/quantize-stats/quantize-stats.cpp +12 -12
- package/src/llama.cpp/examples/retrieval/retrieval.cpp +57 -52
- package/src/llama.cpp/examples/rpc/rpc-server.cpp +27 -2
- package/src/llama.cpp/examples/save-load-state/save-load-state.cpp +60 -46
- package/src/llama.cpp/examples/server/CMakeLists.txt +7 -18
- package/src/llama.cpp/examples/server/server.cpp +1347 -1531
- package/src/llama.cpp/examples/server/tests/requirements.txt +2 -1
- package/src/llama.cpp/examples/server/utils.hpp +396 -107
- package/src/llama.cpp/examples/simple/CMakeLists.txt +1 -1
- package/src/llama.cpp/examples/simple/simple.cpp +132 -106
- package/src/llama.cpp/examples/simple-chat/CMakeLists.txt +5 -0
- package/src/llama.cpp/examples/simple-chat/simple-chat.cpp +197 -0
- package/src/llama.cpp/examples/speculative/speculative.cpp +153 -124
- package/src/llama.cpp/examples/sycl/run-llama2.sh +10 -19
- package/src/llama.cpp/examples/sycl/win-run-llama2.bat +1 -1
- package/src/llama.cpp/examples/tokenize/tokenize.cpp +27 -29
- package/src/llama.cpp/ggml/CMakeLists.txt +29 -12
- package/src/llama.cpp/ggml/include/ggml-alloc.h +3 -3
- package/src/llama.cpp/ggml/include/ggml-amx.h +25 -0
- package/src/llama.cpp/ggml/include/ggml-backend.h +166 -68
- package/src/llama.cpp/ggml/include/ggml-blas.h +5 -3
- package/src/llama.cpp/ggml/include/ggml-cann.h +17 -19
- package/src/llama.cpp/ggml/include/ggml-cpp.h +38 -0
- package/src/llama.cpp/ggml/include/ggml-cpu.h +177 -0
- package/src/llama.cpp/ggml/include/ggml-cuda.h +17 -17
- package/src/llama.cpp/ggml/include/ggml-kompute.h +7 -3
- package/src/llama.cpp/ggml/include/ggml-metal.h +13 -12
- package/src/llama.cpp/ggml/include/ggml-opt.h +216 -0
- package/src/llama.cpp/ggml/include/ggml-rpc.h +9 -5
- package/src/llama.cpp/ggml/include/ggml-sycl.h +18 -11
- package/src/llama.cpp/ggml/include/ggml-vulkan.h +10 -8
- package/src/llama.cpp/ggml/include/ggml.h +272 -505
- package/src/llama.cpp/ggml/src/CMakeLists.txt +69 -1110
- package/src/llama.cpp/ggml/src/ggml-aarch64.c +52 -2116
- package/src/llama.cpp/ggml/src/ggml-aarch64.h +0 -20
- package/src/llama.cpp/ggml/src/ggml-alloc.c +29 -27
- package/src/llama.cpp/ggml/src/ggml-amx/CMakeLists.txt +107 -0
- package/src/llama.cpp/ggml/src/ggml-amx/common.h +94 -0
- package/src/llama.cpp/ggml/src/ggml-amx/ggml-amx.cpp +446 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.cpp +2510 -0
- package/src/llama.cpp/ggml/src/ggml-amx/mmq.h +17 -0
- package/src/llama.cpp/ggml/src/ggml-backend-impl.h +144 -81
- package/src/llama.cpp/ggml/src/ggml-backend-reg.cpp +195 -0
- package/src/llama.cpp/ggml/src/{ggml-backend.c → ggml-backend.cpp} +394 -635
- package/src/llama.cpp/ggml/src/ggml-blas/CMakeLists.txt +91 -0
- package/src/llama.cpp/ggml/src/{ggml-blas.cpp → ggml-blas/ggml-blas.cpp} +217 -70
- package/src/llama.cpp/ggml/src/ggml-cann/CMakeLists.txt +46 -0
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.cpp +4 -27
- package/src/llama.cpp/ggml/src/ggml-cann/acl_tensor.h +32 -4
- package/src/llama.cpp/ggml/src/ggml-cann/aclnn_ops.cpp +179 -41
- package/src/llama.cpp/ggml/src/ggml-cann/common.h +1 -0
- package/src/llama.cpp/ggml/src/{ggml-cann.cpp → ggml-cann/ggml-cann.cpp} +458 -353
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/CMakeLists.txt +2 -1
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/ascendc_kernels.h +2 -0
- package/src/llama.cpp/ggml/src/ggml-cann/kernels/quantize_float_to_q4_0.cpp +278 -0
- package/src/llama.cpp/ggml/src/ggml-common.h +20 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/CMakeLists.txt +261 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.c +3560 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-aarch64.h +30 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-impl.h +371 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.c +10822 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu-quants.h +63 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.c +13970 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/ggml-cpu.cpp +663 -0
- package/src/llama.cpp/ggml/src/ggml-cpu/llamafile/sgemm.cpp +1885 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/CMakeLists.txt +155 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/cuda.h +14 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/hip.h +178 -0
- package/src/llama.cpp/ggml/src/ggml-cuda/vendors/musa.h +134 -0
- package/src/llama.cpp/ggml/src/ggml-hip/CMakeLists.txt +106 -0
- package/src/llama.cpp/ggml/src/ggml-impl.h +380 -584
- package/src/llama.cpp/ggml/src/ggml-kompute/CMakeLists.txt +162 -0
- package/src/llama.cpp/ggml/src/{ggml-kompute.cpp → ggml-kompute/ggml-kompute.cpp} +233 -87
- package/src/llama.cpp/ggml/src/ggml-metal/CMakeLists.txt +108 -0
- package/src/llama.cpp/ggml/src/ggml-metal/ggml-metal-impl.h +249 -0
- package/src/llama.cpp/ggml/src/ggml-musa/CMakeLists.txt +100 -0
- package/src/llama.cpp/ggml/src/ggml-opt.cpp +867 -0
- package/src/llama.cpp/ggml/src/ggml-quants.c +369 -9994
- package/src/llama.cpp/ggml/src/ggml-quants.h +78 -110
- package/src/llama.cpp/ggml/src/ggml-rpc/CMakeLists.txt +11 -0
- package/src/llama.cpp/ggml/src/{ggml-rpc.cpp → ggml-rpc/ggml-rpc.cpp} +560 -335
- package/src/llama.cpp/ggml/src/ggml-sycl/CMakeLists.txt +81 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/backend.hpp +6 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.cpp +51 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/common.hpp +310 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/concat.cpp +1 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.cpp +99 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/conv.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.cpp +57 -57
- package/src/llama.cpp/ggml/src/ggml-sycl/convert.hpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/dequantize.hpp +106 -106
- package/src/llama.cpp/ggml/src/ggml-sycl/dmmv.cpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/dpct/helper.hpp +18 -25
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.cpp +1011 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/element_wise.hpp +76 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/gemm.hpp +101 -0
- package/src/llama.cpp/ggml/src/{ggml-sycl.cpp → ggml-sycl/ggml-sycl.cpp} +3350 -3980
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.cpp +125 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/im2col.hpp +23 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/mmvq.cpp +70 -68
- package/src/llama.cpp/ggml/src/ggml-sycl/norm.cpp +9 -6
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.cpp +56 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/outprod.hpp +11 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/presets.hpp +8 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/rope.cpp +1 -1
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.cpp +71 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/tsembd.hpp +21 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/vecdotq.hpp +4 -4
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.cpp +138 -0
- package/src/llama.cpp/ggml/src/ggml-sycl/wkv6.hpp +10 -0
- package/src/llama.cpp/ggml/src/ggml-threading.cpp +12 -0
- package/src/llama.cpp/ggml/src/ggml-threading.h +12 -0
- package/src/llama.cpp/ggml/src/ggml-vulkan/CMakeLists.txt +78 -0
- package/src/llama.cpp/ggml/src/{ggml-vulkan.cpp → ggml-vulkan/ggml-vulkan.cpp} +2034 -1718
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/CMakeLists.txt +2 -0
- package/src/llama.cpp/ggml/src/{vulkan-shaders → ggml-vulkan/vulkan-shaders}/vulkan-shaders-gen.cpp +152 -185
- package/src/llama.cpp/ggml/src/ggml.c +2075 -16579
- package/src/llama.cpp/include/llama.h +296 -285
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.inp +112 -0
- package/src/llama.cpp/models/ggml-vocab-chameleon.gguf.out +46 -0
- package/src/llama.cpp/pocs/vdot/q8dot.cpp +4 -3
- package/src/llama.cpp/pocs/vdot/vdot.cpp +8 -7
- package/src/llama.cpp/requirements/requirements-convert_legacy_llama.txt +1 -1
- package/src/llama.cpp/src/CMakeLists.txt +2 -1
- package/src/llama.cpp/src/llama-grammar.cpp +721 -122
- package/src/llama.cpp/src/llama-grammar.h +120 -15
- package/src/llama.cpp/src/llama-impl.h +156 -1
- package/src/llama.cpp/src/llama-sampling.cpp +2058 -346
- package/src/llama.cpp/src/llama-sampling.h +39 -47
- package/src/llama.cpp/src/llama-vocab.cpp +390 -127
- package/src/llama.cpp/src/llama-vocab.h +60 -20
- package/src/llama.cpp/src/llama.cpp +6215 -3263
- package/src/llama.cpp/src/unicode-data.cpp +6 -4
- package/src/llama.cpp/src/unicode-data.h +4 -4
- package/src/llama.cpp/src/unicode.cpp +15 -7
- package/src/llama.cpp/tests/CMakeLists.txt +4 -2
- package/src/llama.cpp/tests/test-arg-parser.cpp +131 -0
- package/src/llama.cpp/tests/test-backend-ops.cpp +1725 -297
- package/src/llama.cpp/tests/test-barrier.cpp +94 -0
- package/src/llama.cpp/tests/test-chat-template.cpp +9 -5
- package/src/llama.cpp/tests/test-grammar-integration.cpp +23 -38
- package/src/llama.cpp/tests/test-grammar-parser.cpp +6 -4
- package/src/llama.cpp/tests/test-json-schema-to-grammar.cpp +23 -8
- package/src/llama.cpp/tests/test-llama-grammar.cpp +9 -8
- package/src/llama.cpp/tests/test-log.cpp +39 -0
- package/src/llama.cpp/tests/test-opt.cpp +853 -142
- package/src/llama.cpp/tests/test-quantize-fns.cpp +28 -19
- package/src/llama.cpp/tests/test-quantize-perf.cpp +16 -14
- package/src/llama.cpp/tests/test-rope.cpp +2 -1
- package/src/llama.cpp/tests/test-sampling.cpp +226 -142
- package/src/llama.cpp/tests/test-tokenizer-0.cpp +56 -36
- package/src/llama.cpp/tests/test-tokenizer-1-bpe.cpp +5 -5
- package/src/llama.cpp/tests/test-tokenizer-1-spm.cpp +5 -5
- package/patches/llama.patch +0 -22
- package/src/llama.cpp/.github/workflows/bench.yml +0 -310
- package/src/llama.cpp/common/grammar-parser.cpp +0 -536
- package/src/llama.cpp/common/grammar-parser.h +0 -29
- package/src/llama.cpp/common/train.cpp +0 -1513
- package/src/llama.cpp/common/train.h +0 -233
- package/src/llama.cpp/examples/baby-llama/baby-llama.cpp +0 -1640
- package/src/llama.cpp/examples/benchmark/CMakeLists.txt +0 -6
- package/src/llama.cpp/examples/benchmark/benchmark-matmult.cpp +0 -275
- package/src/llama.cpp/ggml/src/llamafile/sgemm.cpp +0 -1027
- package/src/llama.cpp/tests/test-grad0.cpp +0 -1566
- /package/src/llama.cpp/ggml/{cmake → src/ggml-cpu/cmake}/FindSIMD.cmake +0 -0
- /package/src/llama.cpp/ggml/src/{llamafile → ggml-cpu/llamafile}/sgemm.h +0 -0
package/src/llama.cpp/examples/speculative/speculative.cpp

@@ -1,11 +1,16 @@
+#include "arg.h"
 #include "common.h"
+#include "sampling.h"
+#include "log.h"
 #include "llama.h"

-#include <
+#include <algorithm>
 #include <cstdio>
+#include <cstring>
+#include <random>
+#include <set>
 #include <string>
 #include <vector>
-#include <set>

 #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 100
 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

@@ -21,19 +26,28 @@ struct seq_draft {
     std::vector<llama_token> tokens;
     std::vector<std::vector<llama_token_data>> dists;

-    struct
+    struct common_sampler * smpl = nullptr;
 };

 int main(int argc, char ** argv) {
-
+    common_params params;
+
+    // needed to get candidate probs even for temp <= 0.0
+    params.sparams.n_probs = 128;

-    if (!
-        gpt_params_print_usage(argc, argv, params);
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) {
         return 1;
     }

+    if (params.n_predict < -1) {
+        LOG_ERR("%s: --n-predict must be >= -1\n", __func__);
+        return 1;
+    }
+
+    common_init();
+
     if (params.model_draft.empty()) {
-
+        LOG_ERR("%s: --model-draft is required\n", __func__);
         return 1;
     }

@@ -43,18 +57,9 @@ int main(int argc, char ** argv) {
     // probability threshold for splitting a draft branch (only for n_seq_dft > 1)
     const float p_split = params.p_split;

-
-        params.seed = time(NULL);
-    }
-    std::default_random_engine rng(params.seed);
+    std::default_random_engine rng(params.sparams.seed == LLAMA_DEFAULT_SEED ? std::random_device()() : params.sparams.seed);
     std::uniform_real_distribution<> u_dist;

-#ifndef LOG_DISABLE_LOGS
-    log_set_target(log_filename_generator("speculative", "log"));
-    LOG_TEE("Log start\n");
-    log_dump_cmdline(argc, argv);
-#endif // LOG_DISABLE_LOGS
-
     // init llama.cpp
     llama_backend_init();
     llama_numa_init(params.numa);

@@ -66,26 +71,31 @@ int main(int argc, char ** argv) {
     llama_context * ctx_dft = NULL;

     // load the target model
-
+    common_init_result llama_init_tgt = common_init_from_params(params);
+    model_tgt = llama_init_tgt.model;
+    ctx_tgt = llama_init_tgt.context;

     // load the draft model
     params.model = params.model_draft;
     params.n_gpu_layers = params.n_gpu_layers_draft;
-    if (params.
-        params.n_threads = params.
+    if (params.draft_cpuparams.n_threads > 0) {
+        params.cpuparams.n_threads = params.draft_cpuparams.n_threads;
     }
-
-
+
+    params.cpuparams_batch.n_threads = params.draft_cpuparams_batch.n_threads;
+    common_init_result llama_init_dft = common_init_from_params(params);
+    model_dft = llama_init_dft.model;
+    ctx_dft = llama_init_dft.context;

     const bool vocab_type_tgt = llama_vocab_type(model_tgt);
-
+    LOG_DBG("vocab_type tgt: %d\n", vocab_type_tgt);

     const bool vocab_type_dft = llama_vocab_type(model_dft);
-
+    LOG_DBG("vocab_type dft: %d\n", vocab_type_dft);

     if (vocab_type_tgt != vocab_type_dft) {
-
-
+        LOG_ERR("%s: draft model vocab type must match target model to use speculation but ", __func__);
+        LOG_ERR("vocab_type_dft = %d while vocab_type_tgt = %d\n", vocab_type_dft, vocab_type_tgt);
         return 1;
     }

@@ -95,7 +105,7 @@ int main(int argc, char ** argv) {
         llama_token_bos(model_tgt) != llama_token_bos(model_dft) ||
         llama_token_eos(model_tgt) != llama_token_eos(model_dft)
     ) {
-
+        LOG_ERR("%s: draft model special tokens must match target model to use speculation\n", __func__);
         return 1;
     }

@@ -107,8 +117,8 @@ int main(int argc, char ** argv) {
         : n_vocab_dft - n_vocab_tgt;

     if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
-
-
+        LOG_ERR("%s: draft model vocab must closely match target model to use speculation but ", __func__);
+        LOG_ERR("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
                 n_vocab_tgt, llama_n_vocab(model_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
         return 1;
     }
@@ -117,10 +127,10 @@ int main(int argc, char ** argv) {
         const char * token_text_tgt = llama_token_get_text(model_tgt, i);
         const char * token_text_dft = llama_token_get_text(model_dft, i);
         if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
-
-
-
-
+            LOG_ERR("%s: draft model vocab must match target model to use speculation but ", __func__);
+            LOG_ERR("token %d content differs - target '%s', draft '%s'\n", i,
+                common_token_to_piece(ctx_tgt, i).c_str(),
+                common_token_to_piece(ctx_dft, i).c_str());
             return 1;
         }
     }
@@ -129,32 +139,30 @@ int main(int argc, char ** argv) {

     // Tokenize the prompt
     std::vector<llama_token> inp;
-    inp =
+    inp = common_tokenize(ctx_tgt, params.prompt, true, true);

     const int max_context_size = llama_n_ctx(ctx_tgt);
     const int max_tokens_list_size = max_context_size - 4;

     if ((int) inp.size() > max_tokens_list_size) {
-
+        LOG_ERR("%s: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
         return 1;
     }

-
+    LOG("\n\n");

     for (auto id : inp) {
-
+        LOG("%s", common_token_to_piece(ctx_tgt, id).c_str());
     }

-    fflush(stderr);
-
     const int n_input = inp.size();

     const auto t_enc_start = ggml_time_us();

     // eval the prompt with both models
-    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1
-    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1
-    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input
+    llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1));
+    llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1));
+    llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input));

     const auto t_enc_end = ggml_time_us();

@@ -174,23 +182,19 @@ int main(int argc, char ** argv) {
     // used to determine end of generation
     bool has_eos = false;

-    // target model sampling context
-    struct
+    // target model sampling context (reuse the llama_context's sampling instance)
+    struct common_sampler * smpl = common_sampler_init(model_tgt, params.sparams);

     // draft sequence data
     std::vector<seq_draft> drafts(n_seq_dft);

-    params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
-    if (params.sparams.temp == 0) {
-        params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
-    }
-
     for (int s = 0; s < n_seq_dft; ++s) {
-
+        // allocate llama_sampler for each draft sequence
+        drafts[s].smpl = common_sampler_init(model_dft, params.sparams);
     }

-    llama_batch batch_dft = llama_batch_init(
-    llama_batch batch_tgt = llama_batch_init(
+    llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
+    llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft);

     const auto t_dec_start = ggml_time_us();

@@ -210,7 +214,7 @@ int main(int argc, char ** argv) {
             active_seqs.insert(s);
             const auto & tokens = drafts[s].tokens;

-
+            LOG_DBG("draft %d: %s\n", s, string_from(ctx_dft, tokens).c_str());
         }

         int i_dft = 0;
@@ -228,12 +232,12 @@ int main(int argc, char ** argv) {
             bool accept = false;
             if (params.sparams.temp > 0) {
                 // stochastic verification
+                common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);

-
-                llama_sample_softmax(ctx_tgt, &dist_tgt);
-                float p_tgt = 0, p_dft = 0;
+                auto & dist_tgt = *common_sampler_get_candidates(smpl);

-
+                float p_tgt = 0.0f;
+                float p_dft = 0.0f;

                 while (active_seqs.size() > 0) {
                     // randomly select a sequence to verify from active sequences
@@ -252,39 +256,43 @@ int main(int argc, char ** argv) {
                        }
                        continue;
                    }
-
+
+                    LOG_DBG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
                    float r = u_dist(rng);
-                    llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
+                    llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
+
+                    //GGML_ASSERT(dist_tgt.size <= dist_dft.size);
+
                    // acquire the token probabilities assigned by the draft and target models
                    for (size_t i = 0; i < dist_tgt.size; i++) {
                        if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
                            p_tgt = dist_tgt.data[i].p;
+                            break;
                        }
+                    }
+                    for (size_t i = 0; i < dist_dft.size; i++) {
                        if (dist_dft.data[i].id == drafts[s].tokens[i_dft]) {
                            p_dft = dist_dft.data[i].p;
-                        }
-                        if (p_tgt && p_dft) {
                            break;
                        }
                    }
-
+                    LOG_DBG("r = %f, p_dft = %f, p_tgt = %f\n", r, p_dft, p_tgt);
                    if (r <= p_tgt / p_dft) {
                        s_keep = s;
                        accept = true;
                        token_id = drafts[s].tokens[i_dft];
-                        token_str =
-
+                        token_str = common_token_to_piece(ctx_tgt, token_id);
+                        common_sampler_accept(smpl, token_id, true);

-
+                        LOG_DBG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
                        break;
                    } else {
-
+                        LOG_DBG("draft token %d of sequence %d (%d, '%s') rejected\n", i_dft, s, drafts[s].tokens[i_dft], common_token_to_piece(ctx_tgt, drafts[s].tokens[i_dft]).c_str());
                        drafts[s].active = false;

                        // calculate residual probability
                        GGML_ASSERT(dist_tgt.sorted);
                        GGML_ASSERT(dist_dft.sorted);
-                        float sum_probs = 0.0f;

                        // sort dist by id
                        std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
@@ -294,10 +302,18 @@ int main(int argc, char ** argv) {
                            return a.id < b.id;
                        });

+                        float sum_probs = 0.0f;
+
                        for (size_t i = 0; i < dist_tgt.size; i++) {
-
+                            if (i < dist_dft.size) {
+                                dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
+                            } else {
+                                dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p);
+                            }
+
                            sum_probs += dist_tgt.data[i].p;
                        }
+
                        for (size_t i = 0; i < dist_tgt.size; i++) {
                            dist_tgt.data[i].p /= sum_probs;
                        }
@@ -326,24 +342,30 @@ int main(int argc, char ** argv) {
            if (!accept) {
                // all drafted tokens were rejected
                // sample from the target model
-
-
-
-
-
+                LOG_DBG("all drafted tokens were rejected, sampling from residual distribution\n");
+                std::vector<float> probs(dist_tgt.size);
+                for (size_t i = 0; i < dist_tgt.size; ++i) {
+                    probs[i] = dist_tgt.data[i].p;
+                }
+
+                std::discrete_distribution<> dist(probs.begin(), probs.end());
+
+                const int idx = dist(rng);

+                token_id = dist_tgt.data[idx].id;
+                common_sampler_accept(smpl, token_id, true);
+                token_str = common_token_to_piece(ctx_tgt, token_id);
+            }
        } else {
            // greedy verification

            // sample from the target model
-
-            token_id =
-
-            llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
+            LOG_DBG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
+            token_id = common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);

-
+            common_sampler_accept(smpl, token_id, true);

-            token_str =
+            token_str = common_token_to_piece(ctx_tgt, token_id);

            for (int s = 0; s < n_seq_dft; ++s) {
                if (!drafts[s].active) {
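Note: the verification hunks above implement standard speculative sampling. A drafted token t is accepted when a uniform draw r satisfies r <= p_tgt(t) / p_dft(t); on rejection the target distribution is replaced by the residual max(0, p_tgt - p_dft), renormalized, and the replacement token is drawn from it (the std::discrete_distribution block). A self-contained toy illustration of just that rule, with made-up probabilities and no llama.cpp dependencies:

    #include <algorithm>
    #include <cstdio>
    #include <random>
    #include <vector>

    int main() {
        std::vector<float> p_tgt = {0.50f, 0.30f, 0.15f, 0.05f}; // target-model probabilities per token id
        std::vector<float> p_dft = {0.20f, 0.60f, 0.15f, 0.05f}; // draft-model probabilities per token id
        const int drafted = 1;                                   // token id proposed by the draft model

        std::mt19937 rng(std::random_device{}());
        std::uniform_real_distribution<float> u_dist(0.0f, 1.0f);

        const float r = u_dist(rng);
        if (r <= p_tgt[drafted] / p_dft[drafted]) {
            std::printf("accepted draft token %d\n", drafted);
        } else {
            // residual distribution: max(0, p_tgt - p_dft), renormalized
            std::vector<float> residual(p_tgt.size());
            float sum = 0.0f;
            for (size_t i = 0; i < p_tgt.size(); ++i) {
                residual[i] = std::max(0.0f, p_tgt[i] - p_dft[i]);
                sum += residual[i];
            }
            for (float & p : residual) { p /= sum; }

            std::discrete_distribution<int> dist(residual.begin(), residual.end());
            std::printf("rejected, resampled token %d\n", dist(rng));
        }
        return 0;
    }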
@@ -351,7 +373,7 @@ int main(int argc, char ** argv) {
                }

                if (i_dft < (int) drafts[s].tokens.size() && token_id == drafts[s].tokens[i_dft]) {
-
+                    LOG_DBG("the sampled target token matches the %dth drafted token of sequence %d (%d, '%s') - accepted\n", i_dft, s, token_id, token_str.c_str());

                    s_keep = s;
                    accept = true;
@@ -373,26 +395,24 @@ int main(int argc, char ** argv) {
                ++i_dft;
                if (params.use_color) {
                    // Color token according to its origin sequence
-
+                    LOG("\u001b[%dm%s\u001b[37m", (36 - s_keep % 6), token_str.c_str());
                } else {
-
+                    LOG("%s", token_str.c_str());
                }
-                fflush(stdout);
                continue;
            } else {
-
-                fflush(stdout);
+                LOG("%s", token_str.c_str());
                break;
            }
            }
        }

        {
-
+            LOG_DBG("the sampled target token (%d, '%s') did not match, or we ran out of drafted tokens\n", token_id, token_str.c_str());

            // TODO: simplify
            {
-
+                LOG_DBG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft);

                llama_kv_cache_seq_keep(ctx_dft, s_keep);
                llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1);
@@ -415,21 +435,24 @@ int main(int argc, char ** argv) {
            drafts[0].dists.push_back(std::vector<llama_token_data>());
            drafts[0].i_batch_tgt.push_back(0);

-
-
+            common_batch_clear(batch_dft);
+            common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true);

            llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
-            //
+            // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str());
            llama_decode(ctx_dft, batch_dft);

            ++n_past_dft;
        }

-        if (n_predict > params.n_predict || has_eos) {
+        if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
            break;
        }

-
+        if (drafts[0].smpl) {
+            common_sampler_free(drafts[0].smpl);
+        }
+        drafts[0].smpl = common_sampler_clone(smpl);

        int n_seq_cur = 1;
        int n_past_cur = n_past_dft;
@@ -442,8 +465,8 @@ int main(int argc, char ** argv) {
        drafts[0].drafting = true;
        drafts[0].i_batch_dft = 0;

-
-
+        common_batch_clear(batch_tgt);
+        common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true);

        // sample n_draft tokens from the draft model using tree-based sampling
        for (int i = 0; i < n_draft; ++i) {
@@ -458,21 +481,21 @@ int main(int argc, char ** argv) {
                    continue;
                }

-
+                common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);

-                const auto
+                const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl);

-                for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p
-
-                        k, s, i, cur_p[k].id, cur_p[k].p,
+                for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
+                    LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
+                        k, s, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
                }

                std::vector<int> sa(1, s);

                // attempt to split the branch if the probability is high enough
                for (int f = 1; f < 8; ++f) {
-                    if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
-
+                    if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
+                        LOG_DBG("splitting seq %3d into %3d\n", s, n_seq_cur);

                        llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
                        llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1);
@@ -498,7 +521,10 @@ int main(int argc, char ** argv) {
                        drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
                        drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;

-
+                        if (drafts[n_seq_cur].smpl) {
+                            common_sampler_free(drafts[n_seq_cur].smpl);
+                        }
+                        drafts[n_seq_cur].smpl = common_sampler_clone(drafts[s].smpl);

                        sa.push_back(n_seq_cur);

@@ -510,25 +536,25 @@ int main(int argc, char ** argv) {

                // add drafted token for each sequence
                for (int is = 0; is < (int) sa.size(); ++is) {
-                    const llama_token id = cur_p[is].id;
+                    const llama_token id = cur_p->data[is].id;

                    const int s = sa[is];

-
+                    common_sampler_accept(drafts[s].smpl, id, true);

                    drafts[s].tokens.push_back(id);
                    // save cur_p.data into drafts[s].dists
-                    drafts[s].dists.push_back(cur_p);
+                    drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});

                    // add unique drafted tokens to the target batch
                    drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);

-
+                    common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);

                    // add the token to the batch for batched decoding with the draft model
                    drafts[s].i_batch_dft = batch_dft.n_tokens;

-
+                    common_batch_add(batch_dft, id, n_past_cur, { s }, true);

                    if (batch_tgt.n_tokens > n_draft) {
                        drafts[s].drafting = false;
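Note: the drafting hunks above stop filling llama_batch fields by hand and instead go through the common_batch_clear / common_batch_add helpers, with per-sequence samplers cloned via common_sampler_clone. A minimal sketch of that batching pattern, assuming the common.h and llama.h bundled in this package (illustrative only, not the tree-drafting loop from the diff):

    #include "common.h"
    #include "llama.h"

    // feed one already-sampled token to sequence 0 of a context and advance its position
    static void decode_one(llama_context * ctx, llama_batch & batch, llama_token token_id, int & n_past) {
        common_batch_clear(batch);
        // arguments: batch, token, position, sequence ids, whether logits are wanted for this entry
        common_batch_add(batch, token_id, n_past, { 0 }, true);

        llama_decode(ctx, batch);
        ++n_past;
    }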
@@ -558,7 +584,7 @@ int main(int argc, char ** argv) {
            llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1);
        }

-        //
+        // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str());
        llama_decode(ctx_tgt, batch_tgt);
        ++n_past_tgt;
    }
@@ -576,27 +602,30 @@ int main(int argc, char ** argv) {

    auto t_dec_end = ggml_time_us();

-
+    LOG("\n\n");

-
-
+    LOG_INF("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
+    LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));

-
-
-
-
-
-
+    LOG_INF("\n");
+    LOG_INF("n_draft = %d\n", n_draft);
+    LOG_INF("n_predict = %d\n", n_predict);
+    LOG_INF("n_drafted = %d\n", n_drafted);
+    LOG_INF("n_accept = %d\n", n_accept);
+    LOG_INF("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);

-
-
+    LOG_INF("\n");
+    LOG_INF("draft:\n\n");
+    // TODO: print sampling/grammar timings for all drafts
+    llama_perf_context_print(ctx_dft);

-
-
+    LOG_INF("\n");
+    LOG_INF("target:\n\n");
+    common_perf_print(ctx_tgt, smpl);

-
+    common_sampler_free(smpl);
    for (int s = 0; s < n_seq_dft; ++s) {
-
+        common_sampler_free(drafts[s].smpl);
    }

    llama_batch_free(batch_dft);
@@ -609,7 +638,7 @@ int main(int argc, char ** argv) {

    llama_backend_free();

-
+    LOG("\n\n");

    return 0;
 }
package/src/llama.cpp/examples/sycl/run-llama2.sh

@@ -4,33 +4,24 @@
 # Copyright (C) 2024 Intel Corporation
 # SPDX-License-Identifier: MIT

-INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
 source /opt/intel/oneapi/setvars.sh

-if [ $# -gt 0 ]; then
-    GGML_SYCL_DEVICE=$1
-    GGML_SYCL_SINGLE_GPU=1
-else
-    GGML_SYCL_DEVICE=0
-    GGML_SYCL_SINGLE_GPU=0
-fi
-
 #export GGML_SYCL_DEBUG=1

-
 #ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer.

-
+INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
+MODEL_FILE=models/llama-2-7b.Q4_0.gguf
+NGL=33
+CONEXT=8192
+
+if [ $# -gt 0 ]; then
+    GGML_SYCL_DEVICE=$1
     echo "use $GGML_SYCL_DEVICE as main GPU"
     #use signle GPU only
-    ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m
+    ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONEXT} -mg $GGML_SYCL_DEVICE -sm none
+
 else
     #use multiple GPUs with same max compute units
-    ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m
+    ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONEXT}
 fi
-
-#use main GPU only
-#ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0 -mg $GGML_SYCL_DEVICE -sm none
-
-#use multiple GPUs with same max compute units
-#ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "${INPUT2}" -n 400 -e -ngl 33 -s 0
package/src/llama.cpp/examples/sycl/win-run-llama2.bat

@@ -6,4 +6,4 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
 @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force


-.\build\bin\
+.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 33 -s 0